add QTableFrame and clickable states that display a gui

- remove org.javaTuple in favour of org.apache.common for tuples and circleQueue
- remove ViewListener from non-GUI Controller
- stateActionTable saves the last 10 states that changed. They will get displayed in QTable Frame
in JTextAreas
This commit is contained in:
Jan Löwenstrom 2020-01-01 23:54:18 +01:00
parent a8f8af1102
commit f4f1f7bd37
16 changed files with 334 additions and 136 deletions

View File

@ -20,9 +20,10 @@ dependencies {
testCompile group: 'junit', name: 'junit', version: '4.12' testCompile group: 'junit', name: 'junit', version: '4.12'
compileOnly 'org.projectlombok:lombok:1.18.10' compileOnly 'org.projectlombok:lombok:1.18.10'
annotationProcessor 'org.projectlombok:lombok:1.18.10' annotationProcessor 'org.projectlombok:lombok:1.18.10'
compile 'org.javatuples:javatuples:1.2' // https://mvnrepository.com/artifact/org.apache.commons/commons-lang3
// https://mvnrepository.com/artifact/org.javatuples/javatuples compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.0'
compile group: 'org.javatuples', name: 'javatuples', version: '1.2' // https://mvnrepository.com/artifact/org.apache.commons/commons-collections4
compile group: 'org.apache.commons', name: 'commons-collections4', version: '4.1'
} }

View File

@ -1,9 +1,9 @@
package core; package core;
import java.io.Serializable; import java.io.Serializable;
import java.util.LinkedHashMap; import java.util.*;
import java.util.Map;
import org.apache.commons.collections4.queue.CircularFifoQueue;
/** /**
* Premise: All states have the complete action space * Premise: All states have the complete action space
*/ */
@ -11,10 +11,12 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
private final Map<State, Map<A, Double>> table; private final Map<State, Map<A, Double>> table;
private DiscreteActionSpace<A> discreteActionSpace; private DiscreteActionSpace<A> discreteActionSpace;
private Queue<Map.Entry<State, Map<A, Double>>> latestChanges;
public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){ public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){
table = new LinkedHashMap<>(); table = new LinkedHashMap<>();
this.discreteActionSpace = discreteActionSpace; this.discreteActionSpace = discreteActionSpace;
latestChanges = new CircularFifoQueue<>(10);
} }
/** /**
@ -57,6 +59,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
actionValues = createDefaultActionValues(); actionValues = createDefaultActionValues();
table.put(state, actionValues); table.put(state, actionValues);
} }
latestChanges.add(new AbstractMap.SimpleEntry<>(state, actionValues));
actionValues.put(action, value); actionValues.put(action, value);
} }
@ -72,6 +75,11 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
return table.get(state); return table.get(state);
} }
@Override
public Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView() {
return latestChanges;
}
/** /**
* @return Map with initial values for every available action * @return Map with initial values for every available action
*/ */

View File

@ -1,6 +1,7 @@
package core; package core;
import java.util.Map; import java.util.Map;
import java.util.Queue;
/** /**
* Q-Table which saves all seen states, all available actions for each state * Q-Table which saves all seen states, all available actions for each state
@ -15,4 +16,6 @@ public interface StateActionTable<A extends Enum> {
void setValue(State state, A action, double value); void setValue(State state, A action, double value);
int getStateCount(); int getStateCount();
Map<A, Double> getActionValues(State state); Map<A, Double> getActionValues(State state);
Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView();
} }

View File

@ -3,7 +3,8 @@ package core.algo.mc;
import core.*; import core.*;
import core.algo.EpisodicLearning; import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy; import core.policy.EpsilonGreedyPolicy;
import org.javatuples.Pair; import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.io.IOException; import java.io.IOException;
import java.io.ObjectInputStream; import java.io.ObjectInputStream;
@ -80,7 +81,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>(); Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
for (StepResult<A> sr : episode) { for (StepResult<A> sr : episode) {
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
} }
//System.out.println("stateActionPairs " + stateActionPairs.size()); //System.out.println("stateActionPairs " + stateActionPairs.size());
@ -88,7 +89,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
int firstOccurenceIndex = 0; int firstOccurenceIndex = 0;
// find first occurance of state action pair // find first occurance of state action pair
for (StepResult<A> sr : episode) { for (StepResult<A> sr : episode) {
if (stateActionPair.getValue0().equals(sr.getState()) && stateActionPair.getValue1().equals(sr.getAction())) { if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
break; break;
} }
firstOccurenceIndex++; firstOccurenceIndex++;
@ -102,7 +103,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
// if the key does not exists, it will create a new entry with G as default value // if the key does not exists, it will create a new entry with G as default value
returnSum.merge(stateActionPair, G, Double::sum); returnSum.merge(stateActionPair, G, Double::sum);
returnCount.merge(stateActionPair, 1, Integer::sum); returnCount.merge(stateActionPair, 1, Integer::sum);
stateActionTable.setValue(stateActionPair.getValue0(), stateActionPair.getValue1(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair)); stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
} }
} }

View File

@ -15,10 +15,9 @@ import lombok.Setter;
import javax.swing.*; import javax.swing.*;
import java.io.*; import java.io.*;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController<A extends Enum> implements ViewListener, LearningListener {
public class RLController<A extends Enum> implements LearningListener {
protected final String folderPrefix = "learningStates" + File.separator; protected final String folderPrefix = "learningStates" + File.separator;
protected Environment<A> environment; protected Environment<A> environment;
protected DiscreteActionSpace<A> discreteActionSpace; protected DiscreteActionSpace<A> discreteActionSpace;
@ -37,15 +36,15 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
protected int prevDelay; protected int prevDelay;
protected volatile boolean printNextEpisode; protected volatile boolean printNextEpisode;
public RLController(Environment<A> env, Method method, A... actions){ public RLController(Environment<A> env, Method method, A... actions) {
setEnvironment(env); setEnvironment(env);
setMethod(method); setMethod(method);
setAllowedActions(actions); setAllowedActions(actions);
printNextEpisode = true; printNextEpisode = true;
} }
public void start(){ public void start() {
switch (method){ switch(method) {
case MC_ONPOLICY_EGREEDY: case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break; break;
@ -60,13 +59,13 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
initLearning(); initLearning();
} }
protected void initListeners(){ protected void initListeners() {
learning.addListener(this); learning.addListener(this);
new Thread(() -> { new Thread(() -> {
while (true){ while(true) {
printNextEpisode = true; printNextEpisode = true;
try { try {
Thread.sleep(30*1000); Thread.sleep(30 * 1000);
} catch (InterruptedException e) { } catch (InterruptedException e) {
e.printStackTrace(); e.printStackTrace();
} }
@ -74,35 +73,42 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}).start(); }).start();
} }
private void initLearning(){ private void initLearning() {
if(learning instanceof EpisodicLearning){ if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes"); System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes); ((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{ } else {
learning.learn(); learning.learn();
} }
} }
/************************************************* protected void changeLearningDelay(int delay) {
** VIEW LISTENERS ** learning.setDelay(delay);
*************************************************/ }
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes){ protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning){ if(learning instanceof EpisodicLearning) {
((EpisodicLearning) learning).learn(nrOfEpisodes); ((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{ } else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
} }
} }
@Override protected void changeEpsilon(float epsilon) {
public void onLoadState(String fileName) { if(learning.getPolicy() instanceof EpsilonPolicy) {
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
} else {
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
}
protected void saveState(String fileName) {
FileInputStream fis; FileInputStream fis;
ObjectInputStream in; ObjectInputStream in;
try { try {
fis = new FileInputStream(fileName); fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis); in = new ObjectInputStream(fis);
System.out.println("interrupt" + Thread.currentThread().getId()); System.out.println("interrup" + Thread.currentThread().getId());
learning.interruptLearning(); learning.interruptLearning();
learning.load(in); learning.load(in);
in.close(); in.close();
@ -111,46 +117,26 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
} }
} }
@Override protected void loadState(String fileName) {
public void onSaveState(String fileName) {
FileOutputStream fos; FileOutputStream fos;
ObjectOutputStream out; ObjectOutputStream out;
try{ try {
fos = new FileOutputStream(folderPrefix + fileName); fos = new FileOutputStream(folderPrefix + fileName);
out = new ObjectOutputStream(fos); out = new ObjectOutputStream(fos);
learning.interruptLearning(); learning.interruptLearning();
learning.save(out); learning.save(out);
out.close(); out.close();
}catch (IOException e){ } catch (IOException e) {
e.printStackTrace(); e.printStackTrace();
} }
} }
@Override protected void changeFastLearning(boolean fastLearn) {
public void onEpsilonChange(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy){
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
}else{
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
}
@Override
public void onDelayChange(int delay) {
changeLearningDelay(delay);
}
protected void changeLearningDelay(int delay){
learning.setDelay(delay);
}
@Override
public void onFastLearnChange(boolean fastLearn) {
this.fastLearning = fastLearn; this.fastLearning = fastLearn;
if(fastLearn){ if(fastLearn) {
prevDelay = learning.getDelay(); prevDelay = learning.getDelay();
changeLearningDelay(0); changeLearningDelay(0);
}else{ } else {
changeLearningDelay(prevDelay); changeLearningDelay(prevDelay);
} }
} }
@ -165,7 +151,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
@Override @Override
public void onLearningEnd() { public void onLearningEnd() {
System.out.println("Learning finished"); System.out.println("Learning finished");
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
} }
@Override @Override
@ -176,9 +161,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
@Override @Override
public void onEpisodeEnd(List<Double> rewardHistory) { public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory; latestRewardsHistory = rewardHistory;
if(printNextEpisode){ if(printNextEpisode) {
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1)); System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
printNextEpisode = false; printNextEpisode = false;
} }
} }
@ -192,22 +177,22 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
** SETTERS ** ** SETTERS **
*************************************************/ *************************************************/
private void setEnvironment(Environment<A> environment){ private void setEnvironment(Environment<A> environment) {
if(environment == null){ if(environment == null) {
throw new IllegalArgumentException("Environment cannot be null"); throw new IllegalArgumentException("Environment cannot be null");
} }
this.environment = environment; this.environment = environment;
} }
private void setMethod(Method method){ private void setMethod(Method method) {
if(method == null){ if(method == null) {
throw new IllegalArgumentException("Method cannot be null"); throw new IllegalArgumentException("Method cannot be null");
} }
this.method = method; this.method = method;
} }
private void setAllowedActions(A[] actions){ private void setAllowedActions(A[] actions) {
if(actions == null || actions.length == 0){ if(actions == null || actions.length == 0) {
throw new IllegalArgumentException("There has to be at least one action"); throw new IllegalArgumentException("There has to be at least one action");
} }
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);

View File

@ -1,14 +1,16 @@
package core.controller; package core.controller;
import core.Environment; import core.Environment;
import core.algo.EpisodicLearning;
import core.algo.Method; import core.algo.Method;
import core.gui.LearningView; import core.gui.LearningView;
import core.gui.View; import core.gui.View;
import core.listener.ViewListener;
import javax.swing.*; import javax.swing.*;
import java.util.List; import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> { public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
private LearningView learningView; private LearningView learningView;
public RLControllerGUI(Environment<A> env, Method method, A... actions) { public RLControllerGUI(Environment<A> env, Method method, A... actions) {
@ -23,21 +25,41 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
}); });
} }
@Override /*************************************************
public void onLearnMoreEpisodes(int nrOfEpisodes) { ** View LISTENERS **
super.onLearnMoreEpisodes(nrOfEpisodes); *************************************************/
learningView.updateLearningInfoPanel();
}
@Override @Override
public void onLoadState(String fileName) { public void onLoadState(String fileName) {
super.onLoadState(fileName); super.loadState(fileName);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
} }
@Override
public void onSaveState(String fileName) {
super.saveState(fileName);
}
@Override
public void onShowQTable() {
learningView.showQTableFrame();
}
@Override @Override
public void onEpsilonChange(float epsilon) { public void onEpsilonChange(float epsilon) {
super.onEpsilonChange(epsilon); super.changeEpsilon(epsilon);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onDelayChange(int delay) {
super.changeLearningDelay(delay);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onFastLearnChange(boolean isFastLearn) {
super.changeFastLearning(isFastLearn);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
} }
@ -48,17 +70,23 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
} }
@Override @Override
public void onLearningEnd() { public void onLearnMoreEpisodes(int nrOfEpisodes) {
super.onLearningEnd(); super.learnMoreEpisodes(nrOfEpisodes);
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory)); learningView.updateLearningInfoPanel();
} }
/*************************************************
** LEARNING LISTENERS **
*************************************************/
@Override @Override
public void onEpisodeEnd(List<Double> rewardHistory) { public void onEpisodeEnd(List<Double> rewardHistory) {
super.onEpisodeEnd(rewardHistory); super.onEpisodeEnd(rewardHistory);
SwingUtilities.invokeLater(() -> { SwingUtilities.invokeLater(() -> {
if (!fastLearning) { if(!fastLearning) {
learningView.updateRewardGraph(latestRewardsHistory); learningView.updateRewardGraph(latestRewardsHistory);
learningView.updateQTable();
} }
learningView.updateLearningInfoPanel(); learningView.updateLearningInfoPanel();
}); });
@ -66,8 +94,15 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
@Override @Override
public void onStepEnd() { public void onStepEnd() {
if (!fastLearning) { if(!fastLearning) {
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment()); SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
} }
} }
@Override
public void onLearningEnd() {
super.onLearningEnd();
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
} }

View File

@ -26,24 +26,25 @@ public class LearningInfoPanel extends JPanel {
private JCheckBox drawEnvironmentCheckbox; private JCheckBox drawEnvironmentCheckbox;
private JTextField learnMoreEpisodesInput; private JTextField learnMoreEpisodesInput;
private JButton learnMoreEpisodesButton; private JButton learnMoreEpisodesButton;
private JButton showQTableButton;
public LearningInfoPanel(Learning learning, ViewListener viewListener){ public LearningInfoPanel(Learning learning, ViewListener viewListener) {
this.learning = learning; this.learning = learning;
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
policyLabel = new JLabel(); policyLabel = new JLabel();
discountLabel = new JLabel(); discountLabel = new JLabel();
delayLabel = new JLabel(); delayLabel = new JLabel();
if(learning instanceof Episodic){ if(learning instanceof Episodic) {
episodeLabel = new JLabel(); episodeLabel = new JLabel();
add(episodeLabel); add(episodeLabel);
} }
delaySlider = new JSlider(0,1000, learning.getDelay()); delaySlider = new JSlider(0, 1000, learning.getDelay());
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue())); delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel); add(policyLabel);
add(discountLabel); add(discountLabel);
if(learning.getPolicy() instanceof EpsilonPolicy){ if(learning.getPolicy() instanceof EpsilonPolicy) {
epsilonLabel = new JLabel(); epsilonLabel = new JLabel();
epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100); epsilonSlider = new JSlider(0, 100, (int) ((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100);
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f)); epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
add(epsilonLabel); add(epsilonLabel);
add(epsilonSlider); add(epsilonSlider);
@ -51,7 +52,7 @@ public class LearningInfoPanel extends JPanel {
toggleFastLearningButton = new JButton("Enable fast-learn"); toggleFastLearningButton = new JButton("Enable fast-learn");
fastLearning = false; fastLearning = false;
toggleFastLearningButton.addActionListener(e->{ toggleFastLearningButton.addActionListener(e -> {
fastLearning = !fastLearning; fastLearning = !fastLearning;
delaySlider.setEnabled(!fastLearning); delaySlider.setEnabled(!fastLearning);
epsilonSlider.setEnabled(!fastLearning); epsilonSlider.setEnabled(!fastLearning);
@ -71,10 +72,10 @@ public class LearningInfoPanel extends JPanel {
if(learning instanceof EpisodicLearning) { if(learning instanceof EpisodicLearning) {
learnMoreEpisodesInput = new JTextField(); learnMoreEpisodesInput = new JTextField();
learnMoreEpisodesInput.setMaximumSize(new Dimension(200,20)); learnMoreEpisodesInput.setMaximumSize(new Dimension(200, 20));
learnMoreEpisodesButton = new JButton("Learn More Episodes"); learnMoreEpisodesButton = new JButton("Learn More Episodes");
learnMoreEpisodesButton.addActionListener(e -> { learnMoreEpisodesButton.addActionListener(e -> {
if (Util.isNumeric(learnMoreEpisodesInput.getText())) { if(Util.isNumeric(learnMoreEpisodesInput.getText())) {
viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText())); viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText()));
} else { } else {
learnMoreEpisodesInput.setText(""); learnMoreEpisodesInput.setText("");
@ -83,9 +84,14 @@ public class LearningInfoPanel extends JPanel {
add(learnMoreEpisodesInput); add(learnMoreEpisodesInput);
add(learnMoreEpisodesButton); add(learnMoreEpisodesButton);
} }
showQTableButton = new JButton("Show Q-Table");
showQTableButton.addActionListener(e -> {
viewListener.onShowQTable();
});
add(drawEnvironmentCheckbox); add(drawEnvironmentCheckbox);
add(smoothGraphCheckbox); add(smoothGraphCheckbox);
add(last100Checkbox); add(last100Checkbox);
add(showQTableButton);
refreshLabels(); refreshLabels();
setVisible(true); setVisible(true);
} }
@ -93,17 +99,17 @@ public class LearningInfoPanel extends JPanel {
public void refreshLabels() { public void refreshLabels() {
policyLabel.setText("Policy: " + learning.getPolicy().getClass()); policyLabel.setText("Policy: " + learning.getPolicy().getClass());
discountLabel.setText("Discount factor: " + learning.getDiscountFactor()); discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
if(learning instanceof Episodic){ if(learning instanceof Episodic) {
episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode() + episodeLabel.setText("Episode: " + ((Episodic) (learning)).getCurrentEpisode() +
"\t Episodes to go: " + ((Episodic)(learning)).getEpisodesToGo() + "\t Episodes to go: " + ((Episodic) (learning)).getEpisodesToGo() +
"\t Eps/Sec: " + ((Episodic)(learning)).getEpisodesPerSecond()); "\t Eps/Sec: " + ((Episodic) (learning)).getEpisodesPerSecond());
} }
if (learning.getPolicy() instanceof EpsilonPolicy) { if(learning.getPolicy() instanceof EpsilonPolicy) {
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon()); epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100)); epsilonSlider.setValue((int) (((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
} }
delayLabel.setText("Delay (ms): " + learning.getDelay()); delayLabel.setText("Delay (ms): " + learning.getDelay());
if(delaySlider.isEnabled()){ if(delaySlider.isEnabled()) {
delaySlider.setValue(learning.getDelay()); delaySlider.setValue(learning.getDelay());
} }
toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning"); toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
@ -112,11 +118,12 @@ public class LearningInfoPanel extends JPanel {
protected boolean isSmoothenGraphSelected() { protected boolean isSmoothenGraphSelected() {
return smoothGraphCheckbox.isSelected(); return smoothGraphCheckbox.isSelected();
} }
protected boolean isLast100Selected(){
protected boolean isLast100Selected() {
return last100Checkbox.isSelected(); return last100Checkbox.isSelected();
} }
protected boolean isDrawEnvironmentSelected(){ protected boolean isDrawEnvironmentSelected() {
return drawEnvironmentCheckbox.isSelected(); return drawEnvironmentCheckbox.isSelected();
} }
} }

View File

@ -8,5 +8,7 @@ import java.util.List;
public interface LearningView { public interface LearningView {
void repaintEnvironment(); void repaintEnvironment();
void updateLearningInfoPanel(); void updateLearningInfoPanel();
void updateQTable();
void updateRewardGraph(final List<Double> rewardHistory); void updateRewardGraph(final List<Double> rewardHistory);
void showQTableFrame();
} }

View File

@ -0,0 +1,57 @@
package core.gui;
import core.State;
import core.StateActionTable;
import javax.swing.*;
import java.awt.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class QTableFrame<A extends Enum> extends JFrame {
private JLabel stateCountLabel;
private StateActionTable<A> stateActionTable;
private List<StateActionRow<A>> rows;
private JPanel areaWrapper;
public QTableFrame(StateActionTable<A> stateActionTable) {
super("Q-Table");
this.stateActionTable = stateActionTable;
rows = new ArrayList<>(10);
setDefaultCloseOperation(WindowConstants.HIDE_ON_CLOSE);
setLayout(new BorderLayout());
setPreferredSize(new Dimension(500, 500));
stateCountLabel = new JLabel();
add(BorderLayout.NORTH, stateCountLabel);
areaWrapper = new JPanel();
areaWrapper.setLayout(new BoxLayout(areaWrapper, BoxLayout.Y_AXIS));
for(int i = 0; i < 10; ++i) {
StateActionRow<A> a = new StateActionRow<>();
rows.add(a);
areaWrapper.add(a);
}
add(BorderLayout.CENTER, areaWrapper);
setVisible(false);
pack();
}
private void refreshAllTextAreas(){
for(StateActionRow<A> row : rows){
row.refreshLabels();
}
}
protected void refreshQTable() {
System.out.println("ref");
int stateCount = stateActionTable.getStateCount();
stateCountLabel.setText("Total states: " + stateCount);
int idx = -1;
for(Map.Entry<State, Map<A, Double>> entry : stateActionTable.getFirstStateEntriesForView()) {
if(++idx > rows.size() -1) break;
StateActionRow<A> row = rows.get(idx);
row.setState(entry.getKey());
row.setActionValues(entry.getValue());
}
refreshAllTextAreas();
}
}

View File

@ -0,0 +1,55 @@
package core.gui;
import core.State;
import lombok.Setter;
import javax.swing.*;
import java.awt.*;
import java.awt.event.MouseAdapter;
import java.awt.event.MouseEvent;
import java.util.Map;
@Setter
public class StateActionRow<A extends Enum> extends JTextArea {
private State state;
private Map<A, Double> actionValues;
public StateActionRow(){
this.state = null;
this.actionValues = null;
setMaximumSize(new Dimension(600, 100));
setEditable(false);
addMouseListener(new MouseAdapter() {
@Override
public void mousePressed(MouseEvent e) {
super.mousePressed(e);
showState();
}
});
}
protected void refreshLabels(){
if(state == null || actionValues == null) return;
System.out.println("refreshing");
StringBuilder sb = new StringBuilder(state.toString()).append("\n");
for(Map.Entry<A, Double> actionValue: actionValues.entrySet()){
sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");
}
setText(sb.toString());
}
private void showState() {
if(state != null && state instanceof Visualizable){
new JFrame() {
{
JComponent stateComponent = ((Visualizable)state).visualize();
setPreferredSize(stateComponent.getPreferredSize());
setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
add(stateComponent);
pack();
setVisible(true);
}
};
}
}
}

View File

@ -4,7 +4,8 @@ import core.Environment;
import core.algo.Learning; import core.algo.Learning;
import core.listener.ViewListener; import core.listener.ViewListener;
import lombok.Getter; import lombok.Getter;
import org.javatuples.Pair; import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.knowm.xchart.QuickChart; import org.knowm.xchart.QuickChart;
import org.knowm.xchart.XChartPanel; import org.knowm.xchart.XChartPanel;
import org.knowm.xchart.XYChart; import org.knowm.xchart.XYChart;
@ -16,7 +17,7 @@ import java.io.File;
import java.util.List; import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
public class View<A extends Enum> implements LearningView{ public class View<A extends Enum> implements LearningView {
private Learning<A> learning; private Learning<A> learning;
private Environment<A> environment; private Environment<A> environment;
@Getter @Getter
@ -26,6 +27,7 @@ public class View<A extends Enum> implements LearningView{
@Getter @Getter
private JFrame mainFrame; private JFrame mainFrame;
private JFrame environmentFrame; private JFrame environmentFrame;
private QTableFrame<A> qTableFrame;
private XChartPanel<XYChart> rewardChartPanel; private XChartPanel<XYChart> rewardChartPanel;
private ViewListener viewListener; private ViewListener viewListener;
private JMenuBar menuBar; private JMenuBar menuBar;
@ -36,11 +38,12 @@ public class View<A extends Enum> implements LearningView{
this.environment = environment; this.environment = environment;
this.viewListener = viewListener; this.viewListener = viewListener;
initMainFrame(); initMainFrame();
initQTableFrame();
} }
private void initMainFrame() { private void initMainFrame() {
mainFrame = new JFrame(); mainFrame = new JFrame();
mainFrame.setPreferredSize(new Dimension(1280, 720)); mainFrame.setPreferredSize(new Dimension(1000, 400));
mainFrame.setLayout(new BorderLayout()); mainFrame.setLayout(new BorderLayout());
menuBar = new JMenuBar(); menuBar = new JMenuBar();
fileMenu = new JMenu("File"); fileMenu = new JMenu("File");
@ -52,7 +55,7 @@ public class View<A extends Enum> implements LearningView{
fc.setCurrentDirectory(new File(System.getProperty("user.dir"))); fc.setCurrentDirectory(new File(System.getProperty("user.dir")));
int returnVal = fc.showOpenDialog(mainFrame); int returnVal = fc.showOpenDialog(mainFrame);
if (returnVal == JFileChooser.APPROVE_OPTION) { if(returnVal == JFileChooser.APPROVE_OPTION) {
viewListener.onLoadState(fc.getSelectedFile().toString()); viewListener.onLoadState(fc.getSelectedFile().toString());
} }
} }
@ -62,7 +65,7 @@ public class View<A extends Enum> implements LearningView{
@Override @Override
public void actionPerformed(ActionEvent e) { public void actionPerformed(ActionEvent e) {
String fileName = JOptionPane.showInputDialog("Enter file name", "path/to/file"); String fileName = JOptionPane.showInputDialog("Enter file name", "path/to/file");
if(fileName != null){ if(fileName != null) {
viewListener.onSaveState(fileName); viewListener.onSaveState(fileName);
} }
} }
@ -78,7 +81,7 @@ public class View<A extends Enum> implements LearningView{
mainFrame.pack(); mainFrame.pack();
mainFrame.setVisible(true); mainFrame.setVisible(true);
if (environment instanceof Visualizable) { if(environment instanceof Visualizable) {
environmentFrame = new JFrame() { environmentFrame = new JFrame() {
{ {
add(((Visualizable) environment).visualize()); add(((Visualizable) environment).visualize());
@ -86,9 +89,21 @@ public class View<A extends Enum> implements LearningView{
setVisible(true); setVisible(true);
} }
}; };
} }
} }
private void initQTableFrame(){
qTableFrame = new QTableFrame<>(learning.getStateActionTable());
}
@Override
public void updateQTable() {
qTableFrame.refreshQTable();
}
public void showQTableFrame(){
updateQTable();
qTableFrame.setVisible(true);
}
private void initLearningInfoPanel() { private void initLearningInfoPanel() {
learningInfoPanel = new LearningInfoPanel(learning, viewListener); learningInfoPanel = new LearningInfoPanel(learning, viewListener);
@ -109,32 +124,21 @@ public class View<A extends Enum> implements LearningView{
rewardChartPanel.setPreferredSize(new Dimension(300, 300)); rewardChartPanel.setPreferredSize(new Dimension(300, 300));
} }
public void showState(Visualizable state) {
new JFrame() {
{
JComponent stateComponent = state.visualize();
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
add(stateComponent);
setVisible(true);
}
};
}
public void updateRewardGraph(final List<Double> rewardHistory) { public void updateRewardGraph(final List<Double> rewardHistory) {
List<Integer> xValues; List<Integer> xValues;
List<Double> yValues; List<Double> yValues;
if(learningInfoPanel.isLast100Selected()){ if(learningInfoPanel.isLast100Selected()) {
yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size())); yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size()));
xValues = new CopyOnWriteArrayList<>(); xValues = new CopyOnWriteArrayList<>();
for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i <rewardHistory.size(); ++i){ for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i < rewardHistory.size(); ++i) {
xValues.add(i); xValues.add(i);
} }
}else{ } else {
if(learningInfoPanel.isSmoothenGraphSelected()){ if(learningInfoPanel.isSmoothenGraphSelected()) {
Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory); Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory);
xValues = XYvalues.getValue0(); xValues = XYvalues.getKey();
yValues = XYvalues.getValue1(); yValues = XYvalues.getValue();
}else{ } else {
xValues = null; xValues = null;
yValues = rewardHistory; yValues = rewardHistory;
} }
@ -145,37 +149,37 @@ public class View<A extends Enum> implements LearningView{
rewardChartPanel.repaint(); rewardChartPanel.repaint();
} }
private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original){ private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original) {
int totalXPoints = 100; int totalXPoints = 100;
List<Integer> xValues = new CopyOnWriteArrayList<>(); List<Integer> xValues = new CopyOnWriteArrayList<>();
List<Double> tmp = new CopyOnWriteArrayList<>(); List<Double> tmp = new CopyOnWriteArrayList<>();
int meanBatch = original.size() / totalXPoints; int meanBatch = original.size() / totalXPoints;
if(meanBatch < 1){ if(meanBatch < 1) {
meanBatch = 1; meanBatch = 1;
} }
int idx = 0; int idx = 0;
int batchIdx = 0; int batchIdx = 0;
double batchSum = 0; double batchSum = 0;
for(Double x: original) { for(Double x : original) {
++idx; ++idx;
batchSum += x; batchSum += x;
if (idx == 1 || ++batchIdx % meanBatch == 0) { if(idx == 1 || ++batchIdx % meanBatch == 0) {
tmp.add(batchSum / meanBatch); tmp.add(batchSum / meanBatch);
xValues.add(idx); xValues.add(idx);
batchSum = 0; batchSum = 0;
} }
} }
return new Pair<>(xValues, tmp); return new ImmutablePair<>(xValues, tmp);
} }
public void updateLearningInfoPanel() { public void updateLearningInfoPanel() {
this.learningInfoPanel.refreshLabels(); this.learningInfoPanel.refreshLabels();
} }
public void repaintEnvironment(){ public void repaintEnvironment() {
if (environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) { if(environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) {
environmentFrame.repaint(); environmentFrame.repaint();
} }
} }

View File

@ -12,4 +12,5 @@ public interface ViewListener {
void onLearnMoreEpisodes(int nrOfEpisodes); void onLearnMoreEpisodes(int nrOfEpisodes);
void onLoadState(String fileName); void onLoadState(String fileName);
void onSaveState(String fileName); void onSaveState(String fileName);
void onShowQTable();
} }

View File

@ -1,8 +1,8 @@
package evironment.jumpingDino; package evironment.jumpingDino;
public class Config { public class Config {
public static final int FRAME_WIDTH = 1280; public static final int FRAME_WIDTH = 800;
public static final int FRAME_HEIGHT = 720; public static final int FRAME_HEIGHT = 300;
public static final int GROUND_Y = 50; public static final int GROUND_Y = 50;
public static final int DINO_STARTING_X = 50; public static final int DINO_STARTING_X = 50;
public static final int DINO_SIZE = 50; public static final int DINO_SIZE = 50;

View File

@ -1,16 +1,20 @@
package evironment.jumpingDino; package evironment.jumpingDino;
import core.State; import core.State;
import core.gui.Visualizable;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import javax.swing.*;
import java.awt.*;
import java.io.Serializable; import java.io.Serializable;
import java.util.Objects; import java.util.Objects;
@AllArgsConstructor @AllArgsConstructor
@Getter @Getter
public class DinoState implements State, Serializable { public class DinoState implements State, Serializable, Visualizable {
private int xDistanceToObstacle; private int xDistanceToObstacle;
protected final double scale = 0.5;
@Override @Override
public String toString() { public String toString() {
@ -31,4 +35,31 @@ public class DinoState implements State, Serializable {
public int hashCode() { public int hashCode() {
return Objects.hash(xDistanceToObstacle); return Objects.hash(xDistanceToObstacle);
} }
@Override
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@Override
protected void paintComponent(Graphics g) {
super.paintComponents(g);
drawObjects(g);
}
};
}
public void drawObjects(Graphics g){
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
}
} }

View File

@ -1,11 +1,13 @@
package evironment.jumpingDino; package evironment.jumpingDino;
import core.gui.Visualizable;
import lombok.Getter; import lombok.Getter;
import java.awt.*;
import java.util.Objects; import java.util.Objects;
@Getter @Getter
public class DinoStateWithSpeed extends DinoState{ public class DinoStateWithSpeed extends DinoState implements Visualizable {
private int obstacleSpeed; private int obstacleSpeed;
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) { public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
@ -33,4 +35,10 @@ public class DinoStateWithSpeed extends DinoState{
public int hashCode() { public int hashCode() {
return Objects.hash(super.hashCode(), getObstacleSpeed()); return Objects.hash(super.hashCode(), getObstacleSpeed());
} }
@Override
public void drawObjects(Graphics g) {
super.drawObjects(g);
g.drawString("Speed: " + obstacleSpeed, (int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)) );
}
} }

View File

@ -16,7 +16,7 @@ public class JumpingDino {
Method.MC_ONPOLICY_EGREEDY, Method.MC_ONPOLICY_EGREEDY,
DinoAction.values()); DinoAction.values());
rl.setDelay(0); rl.setDelay(100);
rl.setDiscountFactor(1f); rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f); rl.setEpsilon(0.15f);
rl.setNrOfEpisodes(100000); rl.setNrOfEpisodes(100000);