split GUI parts from controller into sub class

This commit is contained in:
Jan Löwenstrom 2019-12-31 14:43:40 +01:00
parent 195722e98f
commit 518683b676
5 changed files with 128 additions and 62 deletions

View File

@ -5,4 +5,4 @@
<profile default="true" name="Default" enabled="true" />
</annotationProcessing>
</component>
</project>
</project>

View File

@ -10,6 +10,7 @@ import core.gui.View;
import core.listener.LearningListener;
import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import lombok.Setter;
import javax.swing.*;
import java.io.*;
@ -18,24 +19,29 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController<A extends Enum> implements ViewListener, LearningListener {
private final String folderPrefix = "learningStates" + File.separator;
private Environment<A> environment;
private DiscreteActionSpace<A> discreteActionSpace;
private Method method;
private int delay = LearningConfig.DEFAULT_DELAY;
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
private float epsilon = LearningConfig.DEFAULT_EPSILON;
private Learning<A> learning;
private LearningView learningView;
private boolean fastLearning;
private List<Double> latestRewardsHistory;
private int nrOfEpisodes;
private int prevDelay;
protected final String folderPrefix = "learningStates" + File.separator;
protected Environment<A> environment;
protected DiscreteActionSpace<A> discreteActionSpace;
protected Method method;
@Setter
protected int delay = LearningConfig.DEFAULT_DELAY;
@Setter
protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
@Setter
protected float epsilon = LearningConfig.DEFAULT_EPSILON;
protected Learning<A> learning;
protected boolean fastLearning;
protected List<Double> latestRewardsHistory;
@Setter
protected int nrOfEpisodes;
protected int prevDelay;
protected volatile boolean printNextEpisode;
public RLController(Environment<A> env, Method method, A... actions){
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
printNextEpisode = true;
}
public void start(){
@ -48,20 +54,29 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
default:
throw new IllegalArgumentException("Undefined method");
}
initGUI();
System.out.println("Initialized learning: " + learning.getClass());
initListeners();
System.out.println("Set listeners");
initLearning();
}
private void initGUI(){
SwingUtilities.invokeLater(()->{
learningView = new View<>(learning, environment, this);
learning.addListener(this);
});
protected void initListeners(){
learning.addListener(this);
new Thread(() -> {
while (true){
printNextEpisode = true;
try {
Thread.sleep(30*1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}).start();
}
private void initLearning(){
if(learning instanceof EpisodicLearning){
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{
learning.learn();
@ -78,7 +93,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}else{
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
learningView.updateLearningInfoPanel();
}
@Override
@ -88,10 +102,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
System.out.println("interrup" + Thread.currentThread().getId());
System.out.println("interrupt" + Thread.currentThread().getId());
learning.interruptLearning();
learning.load(in);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
in.close();
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
@ -117,7 +130,6 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
public void onEpsilonChange(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy){
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}else{
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
@ -128,9 +140,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
changeLearningDelay(delay);
}
private void changeLearningDelay(int delay){
protected void changeLearningDelay(int delay){
learning.setDelay(delay);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
@ -153,19 +164,8 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
@Override
public void onLearningEnd() {
SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
}
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
SwingUtilities.invokeLater(() ->{
if(!fastLearning){
learningView.updateRewardGraph(latestRewardsHistory);
}
learningView.updateLearningInfoPanel();
});
System.out.println("Learning finished");
onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
}
@Override
@ -174,12 +174,19 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}
@Override
public void onStepEnd() {
if(!fastLearning){
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode){
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
printNextEpisode = false;
}
}
@Override
public void onStepEnd() {
}
/*************************************************
** SETTERS **
@ -205,19 +212,4 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
}
public void setDelay(int delay){
this.delay = delay;
}
public void setEpisodes(int nrOfEpisodes){
this.nrOfEpisodes = nrOfEpisodes;
}
public void setDiscountFactor(float discountFactor){
this.discountFactor = discountFactor;
}
public void setEpsilon(float epsilon){
this.epsilon = epsilon;
}
}

View File

@ -0,0 +1,73 @@
package core.controller;
import core.Environment;
import core.algo.Method;
import core.gui.LearningView;
import core.gui.View;
import javax.swing.*;
import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> {
private LearningView learningView;
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
super(env, method, actions);
}
@Override
protected void initListeners() {
SwingUtilities.invokeLater(() -> {
learningView = new View<>(learning, environment, this);
learning.addListener(this);
});
}
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes) {
super.onLearnMoreEpisodes(nrOfEpisodes);
learningView.updateLearningInfoPanel();
}
@Override
public void onLoadState(String fileName) {
super.onLoadState(fileName);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onEpsilonChange(float epsilon) {
super.onEpsilonChange(epsilon);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
protected void changeLearningDelay(int delay) {
super.changeLearningDelay(delay);
SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
public void onLearningEnd() {
super.onLearningEnd();
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
@Override
public void onEpisodeEnd(List<Double> rewardHistory) {
super.onEpisodeEnd(rewardHistory);
SwingUtilities.invokeLater(() -> {
if (!fastLearning) {
learningView.updateRewardGraph(latestRewardsHistory);
}
learningView.updateLearningInfoPanel();
});
}
@Override
public void onStepEnd() {
if (!fastLearning) {
SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
}
}
}

View File

@ -3,6 +3,7 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
@ -10,15 +11,15 @@ public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<>(
RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorld(true, true),
Method.MC_ONPOLICY_EGREEDY,
DinoAction.values());
rl.setDelay(200);
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setEpisodes(5000);
rl.setNrOfEpisodes(100000);
rl.start();
}

View File

@ -16,7 +16,7 @@ public class RunningAnt {
AntAction.values());
rl.setDelay(200);
rl.setEpisodes(10000);
rl.setNrOfEpisodes(10000);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);