add QTableFrame and clickable states that display a gui
- remove org.javaTuple in favour of org.apache.common for tuples and circleQueue - remove ViewListener from non-GUI Controller - stateActionTable saves the last 10 states that changed. They will get displayed in QTable Frame in JTextAreas
This commit is contained in:
parent
a8f8af1102
commit
f4f1f7bd37
|
@ -20,9 +20,10 @@ dependencies {
|
|||
testCompile group: 'junit', name: 'junit', version: '4.12'
|
||||
compileOnly 'org.projectlombok:lombok:1.18.10'
|
||||
annotationProcessor 'org.projectlombok:lombok:1.18.10'
|
||||
compile 'org.javatuples:javatuples:1.2'
|
||||
// https://mvnrepository.com/artifact/org.javatuples/javatuples
|
||||
compile group: 'org.javatuples', name: 'javatuples', version: '1.2'
|
||||
// https://mvnrepository.com/artifact/org.apache.commons/commons-lang3
|
||||
compile group: 'org.apache.commons', name: 'commons-lang3', version: '3.0'
|
||||
// https://mvnrepository.com/artifact/org.apache.commons/commons-collections4
|
||||
compile group: 'org.apache.commons', name: 'commons-collections4', version: '4.1'
|
||||
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
package core;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
|
||||
import org.apache.commons.collections4.queue.CircularFifoQueue;
|
||||
/**
|
||||
* Premise: All states have the complete action space
|
||||
*/
|
||||
|
@ -11,10 +11,12 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
|||
|
||||
private final Map<State, Map<A, Double>> table;
|
||||
private DiscreteActionSpace<A> discreteActionSpace;
|
||||
private Queue<Map.Entry<State, Map<A, Double>>> latestChanges;
|
||||
|
||||
public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){
|
||||
table = new LinkedHashMap<>();
|
||||
this.discreteActionSpace = discreteActionSpace;
|
||||
latestChanges = new CircularFifoQueue<>(10);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -57,6 +59,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
|||
actionValues = createDefaultActionValues();
|
||||
table.put(state, actionValues);
|
||||
}
|
||||
latestChanges.add(new AbstractMap.SimpleEntry<>(state, actionValues));
|
||||
actionValues.put(action, value);
|
||||
}
|
||||
|
||||
|
@ -72,6 +75,11 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
|||
return table.get(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView() {
|
||||
return latestChanges;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return Map with initial values for every available action
|
||||
*/
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package core;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Queue;
|
||||
|
||||
/**
|
||||
* Q-Table which saves all seen states, all available actions for each state
|
||||
|
@ -15,4 +16,6 @@ public interface StateActionTable<A extends Enum> {
|
|||
void setValue(State state, A action, double value);
|
||||
int getStateCount();
|
||||
Map<A, Double> getActionValues(State state);
|
||||
|
||||
Queue<Map.Entry<State, Map<A, Double>>> getFirstStateEntriesForView();
|
||||
}
|
||||
|
|
|
@ -3,7 +3,8 @@ package core.algo.mc;
|
|||
import core.*;
|
||||
import core.algo.EpisodicLearning;
|
||||
import core.policy.EpsilonGreedyPolicy;
|
||||
import org.javatuples.Pair;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInputStream;
|
||||
|
@ -80,7 +81,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
|||
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
|
||||
|
||||
for (StepResult<A> sr : episode) {
|
||||
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
|
||||
stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
|
||||
}
|
||||
|
||||
//System.out.println("stateActionPairs " + stateActionPairs.size());
|
||||
|
@ -88,7 +89,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
|||
int firstOccurenceIndex = 0;
|
||||
// find first occurance of state action pair
|
||||
for (StepResult<A> sr : episode) {
|
||||
if (stateActionPair.getValue0().equals(sr.getState()) && stateActionPair.getValue1().equals(sr.getAction())) {
|
||||
if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
|
||||
break;
|
||||
}
|
||||
firstOccurenceIndex++;
|
||||
|
@ -102,7 +103,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
|
|||
// if the key does not exists, it will create a new entry with G as default value
|
||||
returnSum.merge(stateActionPair, G, Double::sum);
|
||||
returnCount.merge(stateActionPair, 1, Integer::sum);
|
||||
stateActionTable.setValue(stateActionPair.getValue0(), stateActionPair.getValue1(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
||||
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -15,10 +15,9 @@ import lombok.Setter;
|
|||
import javax.swing.*;
|
||||
import java.io.*;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
||||
|
||||
public class RLController<A extends Enum> implements LearningListener {
|
||||
protected final String folderPrefix = "learningStates" + File.separator;
|
||||
protected Environment<A> environment;
|
||||
protected DiscreteActionSpace<A> discreteActionSpace;
|
||||
|
@ -37,15 +36,15 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
protected int prevDelay;
|
||||
protected volatile boolean printNextEpisode;
|
||||
|
||||
public RLController(Environment<A> env, Method method, A... actions){
|
||||
public RLController(Environment<A> env, Method method, A... actions) {
|
||||
setEnvironment(env);
|
||||
setMethod(method);
|
||||
setAllowedActions(actions);
|
||||
printNextEpisode = true;
|
||||
}
|
||||
|
||||
public void start(){
|
||||
switch (method){
|
||||
public void start() {
|
||||
switch(method) {
|
||||
case MC_ONPOLICY_EGREEDY:
|
||||
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
|
||||
break;
|
||||
|
@ -60,13 +59,13 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
initLearning();
|
||||
}
|
||||
|
||||
protected void initListeners(){
|
||||
protected void initListeners() {
|
||||
learning.addListener(this);
|
||||
new Thread(() -> {
|
||||
while (true){
|
||||
while(true) {
|
||||
printNextEpisode = true;
|
||||
try {
|
||||
Thread.sleep(30*1000);
|
||||
Thread.sleep(30 * 1000);
|
||||
} catch (InterruptedException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
|
@ -74,35 +73,42 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}).start();
|
||||
}
|
||||
|
||||
private void initLearning(){
|
||||
if(learning instanceof EpisodicLearning){
|
||||
private void initLearning() {
|
||||
if(learning instanceof EpisodicLearning) {
|
||||
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
|
||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||
}else{
|
||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||
} else {
|
||||
learning.learn();
|
||||
}
|
||||
}
|
||||
|
||||
/*************************************************
|
||||
** VIEW LISTENERS **
|
||||
*************************************************/
|
||||
@Override
|
||||
public void onLearnMoreEpisodes(int nrOfEpisodes){
|
||||
if(learning instanceof EpisodicLearning){
|
||||
protected void changeLearningDelay(int delay) {
|
||||
learning.setDelay(delay);
|
||||
}
|
||||
|
||||
protected void learnMoreEpisodes(int nrOfEpisodes) {
|
||||
if(learning instanceof EpisodicLearning) {
|
||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||
}else{
|
||||
} else {
|
||||
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLoadState(String fileName) {
|
||||
protected void changeEpsilon(float epsilon) {
|
||||
if(learning.getPolicy() instanceof EpsilonPolicy) {
|
||||
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
||||
} else {
|
||||
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
||||
}
|
||||
}
|
||||
|
||||
protected void saveState(String fileName) {
|
||||
FileInputStream fis;
|
||||
ObjectInputStream in;
|
||||
try {
|
||||
fis = new FileInputStream(fileName);
|
||||
in = new ObjectInputStream(fis);
|
||||
System.out.println("interrupt" + Thread.currentThread().getId());
|
||||
System.out.println("interrup" + Thread.currentThread().getId());
|
||||
learning.interruptLearning();
|
||||
learning.load(in);
|
||||
in.close();
|
||||
|
@ -111,46 +117,26 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSaveState(String fileName) {
|
||||
protected void loadState(String fileName) {
|
||||
FileOutputStream fos;
|
||||
ObjectOutputStream out;
|
||||
try{
|
||||
fos = new FileOutputStream(folderPrefix + fileName);
|
||||
try {
|
||||
fos = new FileOutputStream(folderPrefix + fileName);
|
||||
out = new ObjectOutputStream(fos);
|
||||
learning.interruptLearning();
|
||||
learning.save(out);
|
||||
out.close();
|
||||
}catch (IOException e){
|
||||
} catch (IOException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEpsilonChange(float epsilon) {
|
||||
if(learning.getPolicy() instanceof EpsilonPolicy){
|
||||
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
|
||||
}else{
|
||||
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onDelayChange(int delay) {
|
||||
changeLearningDelay(delay);
|
||||
}
|
||||
|
||||
protected void changeLearningDelay(int delay){
|
||||
learning.setDelay(delay);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFastLearnChange(boolean fastLearn) {
|
||||
protected void changeFastLearning(boolean fastLearn) {
|
||||
this.fastLearning = fastLearn;
|
||||
if(fastLearn){
|
||||
if(fastLearn) {
|
||||
prevDelay = learning.getDelay();
|
||||
changeLearningDelay(0);
|
||||
}else{
|
||||
} else {
|
||||
changeLearningDelay(prevDelay);
|
||||
}
|
||||
}
|
||||
|
@ -165,7 +151,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
@Override
|
||||
public void onLearningEnd() {
|
||||
System.out.println("Learning finished");
|
||||
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -176,9 +161,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
@Override
|
||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||
latestRewardsHistory = rewardHistory;
|
||||
if(printNextEpisode){
|
||||
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1));
|
||||
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
|
||||
if(printNextEpisode) {
|
||||
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
||||
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
|
||||
printNextEpisode = false;
|
||||
}
|
||||
}
|
||||
|
@ -192,22 +177,22 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
** SETTERS **
|
||||
*************************************************/
|
||||
|
||||
private void setEnvironment(Environment<A> environment){
|
||||
if(environment == null){
|
||||
private void setEnvironment(Environment<A> environment) {
|
||||
if(environment == null) {
|
||||
throw new IllegalArgumentException("Environment cannot be null");
|
||||
}
|
||||
this.environment = environment;
|
||||
}
|
||||
|
||||
private void setMethod(Method method){
|
||||
if(method == null){
|
||||
private void setMethod(Method method) {
|
||||
if(method == null) {
|
||||
throw new IllegalArgumentException("Method cannot be null");
|
||||
}
|
||||
this.method = method;
|
||||
}
|
||||
|
||||
private void setAllowedActions(A[] actions){
|
||||
if(actions == null || actions.length == 0){
|
||||
private void setAllowedActions(A[] actions) {
|
||||
if(actions == null || actions.length == 0) {
|
||||
throw new IllegalArgumentException("There has to be at least one action");
|
||||
}
|
||||
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
||||
|
|
|
@ -1,14 +1,16 @@
|
|||
package core.controller;
|
||||
|
||||
import core.Environment;
|
||||
import core.algo.EpisodicLearning;
|
||||
import core.algo.Method;
|
||||
import core.gui.LearningView;
|
||||
import core.gui.View;
|
||||
import core.listener.ViewListener;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.util.List;
|
||||
|
||||
public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
||||
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
|
||||
private LearningView learningView;
|
||||
|
||||
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
|
||||
|
@ -23,21 +25,41 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
|||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLearnMoreEpisodes(int nrOfEpisodes) {
|
||||
super.onLearnMoreEpisodes(nrOfEpisodes);
|
||||
learningView.updateLearningInfoPanel();
|
||||
}
|
||||
/*************************************************
|
||||
** View LISTENERS **
|
||||
*************************************************/
|
||||
|
||||
@Override
|
||||
public void onLoadState(String fileName) {
|
||||
super.onLoadState(fileName);
|
||||
super.loadState(fileName);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onSaveState(String fileName) {
|
||||
super.saveState(fileName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onShowQTable() {
|
||||
learningView.showQTableFrame();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onEpsilonChange(float epsilon) {
|
||||
super.onEpsilonChange(epsilon);
|
||||
super.changeEpsilon(epsilon);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onDelayChange(int delay) {
|
||||
super.changeLearningDelay(delay);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFastLearnChange(boolean isFastLearn) {
|
||||
super.changeFastLearning(isFastLearn);
|
||||
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
|
||||
}
|
||||
|
||||
|
@ -48,17 +70,23 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void onLearningEnd() {
|
||||
super.onLearningEnd();
|
||||
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
||||
public void onLearnMoreEpisodes(int nrOfEpisodes) {
|
||||
super.learnMoreEpisodes(nrOfEpisodes);
|
||||
learningView.updateLearningInfoPanel();
|
||||
}
|
||||
|
||||
|
||||
/*************************************************
|
||||
** LEARNING LISTENERS **
|
||||
*************************************************/
|
||||
|
||||
@Override
|
||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||
super.onEpisodeEnd(rewardHistory);
|
||||
SwingUtilities.invokeLater(() -> {
|
||||
if (!fastLearning) {
|
||||
if(!fastLearning) {
|
||||
learningView.updateRewardGraph(latestRewardsHistory);
|
||||
learningView.updateQTable();
|
||||
}
|
||||
learningView.updateLearningInfoPanel();
|
||||
});
|
||||
|
@ -66,8 +94,15 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> {
|
|||
|
||||
@Override
|
||||
public void onStepEnd() {
|
||||
if (!fastLearning) {
|
||||
if(!fastLearning) {
|
||||
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onLearningEnd() {
|
||||
super.onLearningEnd();
|
||||
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
|
||||
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,24 +26,25 @@ public class LearningInfoPanel extends JPanel {
|
|||
private JCheckBox drawEnvironmentCheckbox;
|
||||
private JTextField learnMoreEpisodesInput;
|
||||
private JButton learnMoreEpisodesButton;
|
||||
private JButton showQTableButton;
|
||||
|
||||
public LearningInfoPanel(Learning learning, ViewListener viewListener){
|
||||
public LearningInfoPanel(Learning learning, ViewListener viewListener) {
|
||||
this.learning = learning;
|
||||
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
|
||||
policyLabel = new JLabel();
|
||||
discountLabel = new JLabel();
|
||||
delayLabel = new JLabel();
|
||||
if(learning instanceof Episodic){
|
||||
if(learning instanceof Episodic) {
|
||||
episodeLabel = new JLabel();
|
||||
add(episodeLabel);
|
||||
}
|
||||
delaySlider = new JSlider(0,1000, learning.getDelay());
|
||||
delaySlider = new JSlider(0, 1000, learning.getDelay());
|
||||
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
|
||||
add(policyLabel);
|
||||
add(discountLabel);
|
||||
if(learning.getPolicy() instanceof EpsilonPolicy){
|
||||
if(learning.getPolicy() instanceof EpsilonPolicy) {
|
||||
epsilonLabel = new JLabel();
|
||||
epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100);
|
||||
epsilonSlider = new JSlider(0, 100, (int) ((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100);
|
||||
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
|
||||
add(epsilonLabel);
|
||||
add(epsilonSlider);
|
||||
|
@ -51,7 +52,7 @@ public class LearningInfoPanel extends JPanel {
|
|||
|
||||
toggleFastLearningButton = new JButton("Enable fast-learn");
|
||||
fastLearning = false;
|
||||
toggleFastLearningButton.addActionListener(e->{
|
||||
toggleFastLearningButton.addActionListener(e -> {
|
||||
fastLearning = !fastLearning;
|
||||
delaySlider.setEnabled(!fastLearning);
|
||||
epsilonSlider.setEnabled(!fastLearning);
|
||||
|
@ -71,10 +72,10 @@ public class LearningInfoPanel extends JPanel {
|
|||
|
||||
if(learning instanceof EpisodicLearning) {
|
||||
learnMoreEpisodesInput = new JTextField();
|
||||
learnMoreEpisodesInput.setMaximumSize(new Dimension(200,20));
|
||||
learnMoreEpisodesInput.setMaximumSize(new Dimension(200, 20));
|
||||
learnMoreEpisodesButton = new JButton("Learn More Episodes");
|
||||
learnMoreEpisodesButton.addActionListener(e -> {
|
||||
if (Util.isNumeric(learnMoreEpisodesInput.getText())) {
|
||||
if(Util.isNumeric(learnMoreEpisodesInput.getText())) {
|
||||
viewListener.onLearnMoreEpisodes(Integer.parseInt(learnMoreEpisodesInput.getText()));
|
||||
} else {
|
||||
learnMoreEpisodesInput.setText("");
|
||||
|
@ -83,9 +84,14 @@ public class LearningInfoPanel extends JPanel {
|
|||
add(learnMoreEpisodesInput);
|
||||
add(learnMoreEpisodesButton);
|
||||
}
|
||||
showQTableButton = new JButton("Show Q-Table");
|
||||
showQTableButton.addActionListener(e -> {
|
||||
viewListener.onShowQTable();
|
||||
});
|
||||
add(drawEnvironmentCheckbox);
|
||||
add(smoothGraphCheckbox);
|
||||
add(last100Checkbox);
|
||||
add(showQTableButton);
|
||||
refreshLabels();
|
||||
setVisible(true);
|
||||
}
|
||||
|
@ -93,17 +99,17 @@ public class LearningInfoPanel extends JPanel {
|
|||
public void refreshLabels() {
|
||||
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
|
||||
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
|
||||
if(learning instanceof Episodic){
|
||||
episodeLabel.setText("Episode: " + ((Episodic)(learning)).getCurrentEpisode() +
|
||||
"\t Episodes to go: " + ((Episodic)(learning)).getEpisodesToGo() +
|
||||
"\t Eps/Sec: " + ((Episodic)(learning)).getEpisodesPerSecond());
|
||||
if(learning instanceof Episodic) {
|
||||
episodeLabel.setText("Episode: " + ((Episodic) (learning)).getCurrentEpisode() +
|
||||
"\t Episodes to go: " + ((Episodic) (learning)).getEpisodesToGo() +
|
||||
"\t Eps/Sec: " + ((Episodic) (learning)).getEpisodesPerSecond());
|
||||
}
|
||||
if (learning.getPolicy() instanceof EpsilonPolicy) {
|
||||
if(learning.getPolicy() instanceof EpsilonPolicy) {
|
||||
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
|
||||
epsilonSlider.setValue((int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
|
||||
epsilonSlider.setValue((int) (((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
|
||||
}
|
||||
delayLabel.setText("Delay (ms): " + learning.getDelay());
|
||||
if(delaySlider.isEnabled()){
|
||||
if(delaySlider.isEnabled()) {
|
||||
delaySlider.setValue(learning.getDelay());
|
||||
}
|
||||
toggleFastLearningButton.setText(fastLearning ? "Disable fast-learning" : "Enable fast-learning");
|
||||
|
@ -112,11 +118,12 @@ public class LearningInfoPanel extends JPanel {
|
|||
protected boolean isSmoothenGraphSelected() {
|
||||
return smoothGraphCheckbox.isSelected();
|
||||
}
|
||||
protected boolean isLast100Selected(){
|
||||
|
||||
protected boolean isLast100Selected() {
|
||||
return last100Checkbox.isSelected();
|
||||
}
|
||||
|
||||
protected boolean isDrawEnvironmentSelected(){
|
||||
protected boolean isDrawEnvironmentSelected() {
|
||||
return drawEnvironmentCheckbox.isSelected();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,5 +8,7 @@ import java.util.List;
|
|||
public interface LearningView {
|
||||
void repaintEnvironment();
|
||||
void updateLearningInfoPanel();
|
||||
void updateQTable();
|
||||
void updateRewardGraph(final List<Double> rewardHistory);
|
||||
void showQTableFrame();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
package core.gui;
|
||||
|
||||
import core.State;
|
||||
import core.StateActionTable;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class QTableFrame<A extends Enum> extends JFrame {
|
||||
private JLabel stateCountLabel;
|
||||
private StateActionTable<A> stateActionTable;
|
||||
private List<StateActionRow<A>> rows;
|
||||
private JPanel areaWrapper;
|
||||
|
||||
public QTableFrame(StateActionTable<A> stateActionTable) {
|
||||
super("Q-Table");
|
||||
this.stateActionTable = stateActionTable;
|
||||
rows = new ArrayList<>(10);
|
||||
setDefaultCloseOperation(WindowConstants.HIDE_ON_CLOSE);
|
||||
setLayout(new BorderLayout());
|
||||
setPreferredSize(new Dimension(500, 500));
|
||||
stateCountLabel = new JLabel();
|
||||
add(BorderLayout.NORTH, stateCountLabel);
|
||||
areaWrapper = new JPanel();
|
||||
areaWrapper.setLayout(new BoxLayout(areaWrapper, BoxLayout.Y_AXIS));
|
||||
for(int i = 0; i < 10; ++i) {
|
||||
StateActionRow<A> a = new StateActionRow<>();
|
||||
rows.add(a);
|
||||
areaWrapper.add(a);
|
||||
}
|
||||
add(BorderLayout.CENTER, areaWrapper);
|
||||
setVisible(false);
|
||||
pack();
|
||||
}
|
||||
|
||||
private void refreshAllTextAreas(){
|
||||
for(StateActionRow<A> row : rows){
|
||||
row.refreshLabels();
|
||||
}
|
||||
}
|
||||
protected void refreshQTable() {
|
||||
System.out.println("ref");
|
||||
int stateCount = stateActionTable.getStateCount();
|
||||
stateCountLabel.setText("Total states: " + stateCount);
|
||||
int idx = -1;
|
||||
for(Map.Entry<State, Map<A, Double>> entry : stateActionTable.getFirstStateEntriesForView()) {
|
||||
if(++idx > rows.size() -1) break;
|
||||
StateActionRow<A> row = rows.get(idx);
|
||||
row.setState(entry.getKey());
|
||||
row.setActionValues(entry.getValue());
|
||||
}
|
||||
refreshAllTextAreas();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package core.gui;
|
||||
|
||||
import core.State;
|
||||
import lombok.Setter;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.awt.event.MouseAdapter;
|
||||
import java.awt.event.MouseEvent;
|
||||
import java.util.Map;
|
||||
|
||||
@Setter
|
||||
public class StateActionRow<A extends Enum> extends JTextArea {
|
||||
private State state;
|
||||
private Map<A, Double> actionValues;
|
||||
|
||||
public StateActionRow(){
|
||||
this.state = null;
|
||||
this.actionValues = null;
|
||||
setMaximumSize(new Dimension(600, 100));
|
||||
setEditable(false);
|
||||
addMouseListener(new MouseAdapter() {
|
||||
@Override
|
||||
public void mousePressed(MouseEvent e) {
|
||||
super.mousePressed(e);
|
||||
showState();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
protected void refreshLabels(){
|
||||
if(state == null || actionValues == null) return;
|
||||
System.out.println("refreshing");
|
||||
StringBuilder sb = new StringBuilder(state.toString()).append("\n");
|
||||
for(Map.Entry<A, Double> actionValue: actionValues.entrySet()){
|
||||
sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");
|
||||
}
|
||||
setText(sb.toString());
|
||||
}
|
||||
|
||||
private void showState() {
|
||||
if(state != null && state instanceof Visualizable){
|
||||
new JFrame() {
|
||||
{
|
||||
JComponent stateComponent = ((Visualizable)state).visualize();
|
||||
setPreferredSize(stateComponent.getPreferredSize());
|
||||
setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE);
|
||||
add(stateComponent);
|
||||
pack();
|
||||
setVisible(true);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,7 +4,8 @@ import core.Environment;
|
|||
import core.algo.Learning;
|
||||
import core.listener.ViewListener;
|
||||
import lombok.Getter;
|
||||
import org.javatuples.Pair;
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.knowm.xchart.QuickChart;
|
||||
import org.knowm.xchart.XChartPanel;
|
||||
import org.knowm.xchart.XYChart;
|
||||
|
@ -16,7 +17,7 @@ import java.io.File;
|
|||
import java.util.List;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
public class View<A extends Enum> implements LearningView{
|
||||
public class View<A extends Enum> implements LearningView {
|
||||
private Learning<A> learning;
|
||||
private Environment<A> environment;
|
||||
@Getter
|
||||
|
@ -26,6 +27,7 @@ public class View<A extends Enum> implements LearningView{
|
|||
@Getter
|
||||
private JFrame mainFrame;
|
||||
private JFrame environmentFrame;
|
||||
private QTableFrame<A> qTableFrame;
|
||||
private XChartPanel<XYChart> rewardChartPanel;
|
||||
private ViewListener viewListener;
|
||||
private JMenuBar menuBar;
|
||||
|
@ -36,11 +38,12 @@ public class View<A extends Enum> implements LearningView{
|
|||
this.environment = environment;
|
||||
this.viewListener = viewListener;
|
||||
initMainFrame();
|
||||
initQTableFrame();
|
||||
}
|
||||
|
||||
private void initMainFrame() {
|
||||
mainFrame = new JFrame();
|
||||
mainFrame.setPreferredSize(new Dimension(1280, 720));
|
||||
mainFrame.setPreferredSize(new Dimension(1000, 400));
|
||||
mainFrame.setLayout(new BorderLayout());
|
||||
menuBar = new JMenuBar();
|
||||
fileMenu = new JMenu("File");
|
||||
|
@ -52,7 +55,7 @@ public class View<A extends Enum> implements LearningView{
|
|||
fc.setCurrentDirectory(new File(System.getProperty("user.dir")));
|
||||
int returnVal = fc.showOpenDialog(mainFrame);
|
||||
|
||||
if (returnVal == JFileChooser.APPROVE_OPTION) {
|
||||
if(returnVal == JFileChooser.APPROVE_OPTION) {
|
||||
viewListener.onLoadState(fc.getSelectedFile().toString());
|
||||
}
|
||||
}
|
||||
|
@ -62,7 +65,7 @@ public class View<A extends Enum> implements LearningView{
|
|||
@Override
|
||||
public void actionPerformed(ActionEvent e) {
|
||||
String fileName = JOptionPane.showInputDialog("Enter file name", "path/to/file");
|
||||
if(fileName != null){
|
||||
if(fileName != null) {
|
||||
viewListener.onSaveState(fileName);
|
||||
}
|
||||
}
|
||||
|
@ -78,7 +81,7 @@ public class View<A extends Enum> implements LearningView{
|
|||
mainFrame.pack();
|
||||
mainFrame.setVisible(true);
|
||||
|
||||
if (environment instanceof Visualizable) {
|
||||
if(environment instanceof Visualizable) {
|
||||
environmentFrame = new JFrame() {
|
||||
{
|
||||
add(((Visualizable) environment).visualize());
|
||||
|
@ -86,9 +89,21 @@ public class View<A extends Enum> implements LearningView{
|
|||
setVisible(true);
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
private void initQTableFrame(){
|
||||
qTableFrame = new QTableFrame<>(learning.getStateActionTable());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void updateQTable() {
|
||||
qTableFrame.refreshQTable();
|
||||
}
|
||||
|
||||
public void showQTableFrame(){
|
||||
updateQTable();
|
||||
qTableFrame.setVisible(true);
|
||||
}
|
||||
|
||||
private void initLearningInfoPanel() {
|
||||
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
|
||||
|
@ -109,32 +124,21 @@ public class View<A extends Enum> implements LearningView{
|
|||
rewardChartPanel.setPreferredSize(new Dimension(300, 300));
|
||||
}
|
||||
|
||||
public void showState(Visualizable state) {
|
||||
new JFrame() {
|
||||
{
|
||||
JComponent stateComponent = state.visualize();
|
||||
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
|
||||
add(stateComponent);
|
||||
setVisible(true);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public void updateRewardGraph(final List<Double> rewardHistory) {
|
||||
List<Integer> xValues;
|
||||
List<Double> yValues;
|
||||
if(learningInfoPanel.isLast100Selected()){
|
||||
if(learningInfoPanel.isLast100Selected()) {
|
||||
yValues = new CopyOnWriteArrayList<>(rewardHistory.subList(rewardHistory.size() - Math.min(rewardHistory.size(), 100), rewardHistory.size()));
|
||||
xValues = new CopyOnWriteArrayList<>();
|
||||
for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i <rewardHistory.size(); ++i){
|
||||
for(int i = rewardHistory.size() - Math.min(rewardHistory.size(), 100); i < rewardHistory.size(); ++i) {
|
||||
xValues.add(i);
|
||||
}
|
||||
}else{
|
||||
if(learningInfoPanel.isSmoothenGraphSelected()){
|
||||
} else {
|
||||
if(learningInfoPanel.isSmoothenGraphSelected()) {
|
||||
Pair<List<Integer>, List<Double>> XYvalues = smoothenGraph(rewardHistory);
|
||||
xValues = XYvalues.getValue0();
|
||||
yValues = XYvalues.getValue1();
|
||||
}else{
|
||||
xValues = XYvalues.getKey();
|
||||
yValues = XYvalues.getValue();
|
||||
} else {
|
||||
xValues = null;
|
||||
yValues = rewardHistory;
|
||||
}
|
||||
|
@ -145,37 +149,37 @@ public class View<A extends Enum> implements LearningView{
|
|||
rewardChartPanel.repaint();
|
||||
}
|
||||
|
||||
private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original){
|
||||
private Pair<List<Integer>, List<Double>> smoothenGraph(List<Double> original) {
|
||||
int totalXPoints = 100;
|
||||
|
||||
List<Integer> xValues = new CopyOnWriteArrayList<>();
|
||||
List<Double> tmp = new CopyOnWriteArrayList<>();
|
||||
int meanBatch = original.size() / totalXPoints;
|
||||
if(meanBatch < 1){
|
||||
if(meanBatch < 1) {
|
||||
meanBatch = 1;
|
||||
}
|
||||
|
||||
int idx = 0;
|
||||
int batchIdx = 0;
|
||||
double batchSum = 0;
|
||||
for(Double x: original) {
|
||||
for(Double x : original) {
|
||||
++idx;
|
||||
batchSum += x;
|
||||
if (idx == 1 || ++batchIdx % meanBatch == 0) {
|
||||
if(idx == 1 || ++batchIdx % meanBatch == 0) {
|
||||
tmp.add(batchSum / meanBatch);
|
||||
xValues.add(idx);
|
||||
batchSum = 0;
|
||||
}
|
||||
}
|
||||
return new Pair<>(xValues, tmp);
|
||||
return new ImmutablePair<>(xValues, tmp);
|
||||
}
|
||||
|
||||
public void updateLearningInfoPanel() {
|
||||
this.learningInfoPanel.refreshLabels();
|
||||
}
|
||||
|
||||
public void repaintEnvironment(){
|
||||
if (environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) {
|
||||
public void repaintEnvironment() {
|
||||
if(environmentFrame != null && learningInfoPanel.isDrawEnvironmentSelected()) {
|
||||
environmentFrame.repaint();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,4 +12,5 @@ public interface ViewListener {
|
|||
void onLearnMoreEpisodes(int nrOfEpisodes);
|
||||
void onLoadState(String fileName);
|
||||
void onSaveState(String fileName);
|
||||
void onShowQTable();
|
||||
}
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package evironment.jumpingDino;
|
||||
|
||||
public class Config {
|
||||
public static final int FRAME_WIDTH = 1280;
|
||||
public static final int FRAME_HEIGHT = 720;
|
||||
public static final int FRAME_WIDTH = 800;
|
||||
public static final int FRAME_HEIGHT = 300;
|
||||
public static final int GROUND_Y = 50;
|
||||
public static final int DINO_STARTING_X = 50;
|
||||
public static final int DINO_SIZE = 50;
|
||||
|
|
|
@ -1,16 +1,20 @@
|
|||
package evironment.jumpingDino;
|
||||
|
||||
import core.State;
|
||||
import core.gui.Visualizable;
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.Getter;
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
import java.io.Serializable;
|
||||
import java.util.Objects;
|
||||
|
||||
@AllArgsConstructor
|
||||
@Getter
|
||||
public class DinoState implements State, Serializable {
|
||||
public class DinoState implements State, Serializable, Visualizable {
|
||||
private int xDistanceToObstacle;
|
||||
protected final double scale = 0.5;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
|
@ -31,4 +35,31 @@ public class DinoState implements State, Serializable {
|
|||
public int hashCode() {
|
||||
return Objects.hash(xDistanceToObstacle);
|
||||
}
|
||||
|
||||
@Override
|
||||
public JComponent visualize() {
|
||||
return new JComponent() {
|
||||
{
|
||||
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
|
||||
setVisible(true);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void paintComponent(Graphics g) {
|
||||
super.paintComponents(g);
|
||||
drawObjects(g);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public void drawObjects(Graphics g){
|
||||
g.setColor(Color.BLACK);
|
||||
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
|
||||
|
||||
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
|
||||
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
|
||||
|
||||
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
package evironment.jumpingDino;
|
||||
|
||||
import core.gui.Visualizable;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.awt.*;
|
||||
import java.util.Objects;
|
||||
|
||||
@Getter
|
||||
public class DinoStateWithSpeed extends DinoState{
|
||||
public class DinoStateWithSpeed extends DinoState implements Visualizable {
|
||||
private int obstacleSpeed;
|
||||
|
||||
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
|
||||
|
@ -33,4 +35,10 @@ public class DinoStateWithSpeed extends DinoState{
|
|||
public int hashCode() {
|
||||
return Objects.hash(super.hashCode(), getObstacleSpeed());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void drawObjects(Graphics g) {
|
||||
super.drawObjects(g);
|
||||
g.drawString("Speed: " + obstacleSpeed, (int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)) );
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ public class JumpingDino {
|
|||
Method.MC_ONPOLICY_EGREEDY,
|
||||
DinoAction.values());
|
||||
|
||||
rl.setDelay(0);
|
||||
rl.setDelay(100);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.15f);
|
||||
rl.setNrOfEpisodes(100000);
|
||||
|
|
Loading…
Reference in New Issue