split GUI parts from controller into sub class

This commit is contained in:
Jan Löwenstrom 2019-12-31 14:43:40 +01:00
parent 195722e98f
commit 518683b676
5 changed files with 128 additions and 62 deletions

View File

@ -10,6 +10,7 @@ import core.gui.View;
import core.listener.LearningListener; import core.listener.LearningListener;
import core.listener.ViewListener; import core.listener.ViewListener;
import core.policy.EpsilonPolicy; import core.policy.EpsilonPolicy;
import lombok.Setter;
import javax.swing.*; import javax.swing.*;
import java.io.*; import java.io.*;
@ -18,24 +19,29 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
public class RLController<A extends Enum> implements ViewListener, LearningListener { public class RLController<A extends Enum> implements ViewListener, LearningListener {
private final String folderPrefix = "learningStates" + File.separator; protected final String folderPrefix = "learningStates" + File.separator;
private Environment<A> environment; protected Environment<A> environment;
private DiscreteActionSpace<A> discreteActionSpace; protected DiscreteActionSpace<A> discreteActionSpace;
private Method method; protected Method method;
private int delay = LearningConfig.DEFAULT_DELAY; @Setter
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR; protected int delay = LearningConfig.DEFAULT_DELAY;
private float epsilon = LearningConfig.DEFAULT_EPSILON; @Setter
private Learning<A> learning; protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
private LearningView learningView; @Setter
private boolean fastLearning; protected float epsilon = LearningConfig.DEFAULT_EPSILON;
private List<Double> latestRewardsHistory; protected Learning<A> learning;
private int nrOfEpisodes; protected boolean fastLearning;
private int prevDelay; protected List<Double> latestRewardsHistory;
@Setter
protected int nrOfEpisodes;
protected int prevDelay;
protected volatile boolean printNextEpisode;
public RLController(Environment<A> env, Method method, A... actions){ public RLController(Environment<A> env, Method method, A... actions){
setEnvironment(env); setEnvironment(env);
setMethod(method); setMethod(method);
setAllowedActions(actions); setAllowedActions(actions);
printNextEpisode = true;
} }
public void start(){ public void start(){
@ -48,20 +54,29 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
default: default:
throw new IllegalArgumentException("Undefined method"); throw new IllegalArgumentException("Undefined method");
} }
System.out.println("Initialized learning: " + learning.getClass());
initGUI(); initListeners();
System.out.println("Set listeners");
initLearning(); initLearning();
} }
private void initGUI(){ protected void initListeners(){
SwingUtilities.invokeLater(()->{
learningView = new View<>(learning, environment, this);
learning.addListener(this); learning.addListener(this);
}); new Thread(() -> {
while (true){
printNextEpisode = true;
try {
Thread.sleep(30*1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}).start();
} }
private void initLearning(){ private void initLearning(){
if(learning instanceof EpisodicLearning){ if(learning instanceof EpisodicLearning){
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes); ((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{ }else{
learning.learn(); learning.learn();
@ -78,7 +93,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}else{ }else{
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
} }
learningView.updateLearningInfoPanel();
} }
@Override @Override
@ -88,10 +102,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
try { try {
fis = new FileInputStream(fileName); fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis); in = new ObjectInputStream(fis);
System.out.println("interrup" + Thread.currentThread().getId()); System.out.println("interrupt" + Thread.currentThread().getId());
learning.interruptLearning(); learning.interruptLearning();
learning.load(in); learning.load(in);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
in.close(); in.close();
} catch (IOException | ClassNotFoundException e) { } catch (IOException | ClassNotFoundException e) {
e.printStackTrace(); e.printStackTrace();
@ -117,7 +130,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
public void onEpsilonChange(float epsilon) { public void onEpsilonChange(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy){ if(learning.getPolicy() instanceof EpsilonPolicy){
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon); ((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}else{ }else{
System.out.println("Trying to call inEpsilonChange on non-epsilon policy"); System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
} }
@ -128,9 +140,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
changeLearningDelay(delay); changeLearningDelay(delay);
} }
private void changeLearningDelay(int delay){ protected void changeLearningDelay(int delay){
learning.setDelay(delay); learning.setDelay(delay);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
} }
@Override @Override
@ -153,19 +164,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
@Override @Override
public void onLearningEnd() { public void onLearningEnd() {
SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory)); System.out.println("Learning finished");
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : "")); onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
}
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
SwingUtilities.invokeLater(() ->{
if(!fastLearning){
learningView.updateRewardGraph(latestRewardsHistory);
}
learningView.updateLearningInfoPanel();
});
} }
@Override @Override
@ -174,12 +174,19 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
} }
@Override @Override
public void onStepEnd() { public void onEpisodeEnd(List<Double> rewardHistory) {
if(!fastLearning){ latestRewardsHistory = rewardHistory;
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment()); if(printNextEpisode){
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
printNextEpisode = false;
} }
} }
@Override
public void onStepEnd() {
}
/************************************************* /*************************************************
** SETTERS ** ** SETTERS **
@ -205,19 +212,4 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
} }
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
} }
public void setDelay(int delay){
this.delay = delay;
}
public void setEpisodes(int nrOfEpisodes){
this.nrOfEpisodes = nrOfEpisodes;
}
public void setDiscountFactor(float discountFactor){
this.discountFactor = discountFactor;
}
public void setEpsilon(float epsilon){
this.epsilon = epsilon;
}
} }

View File

@ -0,0 +1,73 @@
package core.controller;
import core.Environment;
import core.algo.Method;
import core.gui.LearningView;
import core.gui.View;
import javax.swing.*;
import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> {
private LearningView learningView;
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
super(env, method, actions);
}
@Override
protected void initListeners() {
SwingUtilities.invokeLater(() -> {
learningView = new View<>(learning, environment, this);
learning.addListener(this);
});
}
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes) {
super.onLearnMoreEpisodes(nrOfEpisodes);
learningView.updateLearningInfoPanel();
}
@Override
public void onLoadState(String fileName) {
super.onLoadState(fileName);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onEpsilonChange(float epsilon) {
super.onEpsilonChange(epsilon);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
protected void changeLearningDelay(int delay) {
super.changeLearningDelay(delay);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onLearningEnd() {
super.onLearningEnd();
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
super.onEpisodeEnd(rewardHistory);
SwingUtilities.invokeLater(() -> {
if (!fastLearning) {
learningView.updateRewardGraph(latestRewardsHistory);
}
learningView.updateLearningInfoPanel();
});
}
@Override
public void onStepEnd() {
if (!fastLearning) {
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
}
}
}

View File

@ -3,6 +3,7 @@ package example;
import core.RNG; import core.RNG;
import core.algo.Method; import core.algo.Method;
import core.controller.RLController; import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld; import evironment.jumpingDino.DinoWorld;
@ -10,15 +11,15 @@ public class JumpingDino {
public static void main(String[] args) { public static void main(String[] args) {
RNG.setSeed(55); RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<>( RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorld(true, true), new DinoWorld(true, true),
Method.MC_ONPOLICY_EGREEDY, Method.MC_ONPOLICY_EGREEDY,
DinoAction.values()); DinoAction.values());
rl.setDelay(200); rl.setDelay(0);
rl.setDiscountFactor(1f); rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f); rl.setEpsilon(0.15f);
rl.setEpisodes(5000); rl.setNrOfEpisodes(100000);
rl.start(); rl.start();
} }

View File

@ -16,7 +16,7 @@ public class RunningAnt {
AntAction.values()); AntAction.values());
rl.setDelay(200); rl.setDelay(200);
rl.setEpisodes(10000); rl.setNrOfEpisodes(10000);
rl.setDiscountFactor(1f); rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f); rl.setEpsilon(0.15f);