From 518683b6763c9f737b8fc63d3fc8d51a04190fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Tue, 31 Dec 2019 14:43:40 +0100 Subject: [PATCH] split GUI parts from controller into sub class --- .idea/compiler.xml | 2 +- .../java/core/controller/RLController.java | 106 ++++++++---------- .../java/core/controller/RLControllerGUI.java | 73 ++++++++++++ src/main/java/example/JumpingDino.java | 7 +- src/main/java/example/RunningAnt.java | 2 +- 5 files changed, 128 insertions(+), 62 deletions(-) create mode 100644 src/main/java/core/controller/RLControllerGUI.java diff --git a/.idea/compiler.xml b/.idea/compiler.xml index 95a88ae..a1757ae 100644 --- a/.idea/compiler.xml +++ b/.idea/compiler.xml @@ -5,4 +5,4 @@ - + \ No newline at end of file diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 0fb589a..e2b9ec3 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -10,6 +10,7 @@ import core.gui.View; import core.listener.LearningListener; import core.listener.ViewListener; import core.policy.EpsilonPolicy; +import lombok.Setter; import javax.swing.*; import java.io.*; @@ -18,24 +19,29 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; public class RLController implements ViewListener, LearningListener { - private final String folderPrefix = "learningStates" + File.separator; - private Environment environment; - private DiscreteActionSpace discreteActionSpace; - private Method method; - private int delay = LearningConfig.DEFAULT_DELAY; - private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR; - private float epsilon = LearningConfig.DEFAULT_EPSILON; - private Learning learning; - private LearningView learningView; - private boolean fastLearning; - private List latestRewardsHistory; - private int nrOfEpisodes; - private int prevDelay; + protected final String folderPrefix = "learningStates" + File.separator; + protected Environment environment; + protected DiscreteActionSpace discreteActionSpace; + protected Method method; + @Setter + protected int delay = LearningConfig.DEFAULT_DELAY; + @Setter + protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR; + @Setter + protected float epsilon = LearningConfig.DEFAULT_EPSILON; + protected Learning learning; + protected boolean fastLearning; + protected List latestRewardsHistory; + @Setter + protected int nrOfEpisodes; + protected int prevDelay; + protected volatile boolean printNextEpisode; public RLController(Environment env, Method method, A... actions){ setEnvironment(env); setMethod(method); setAllowedActions(actions); + printNextEpisode = true; } public void start(){ @@ -48,20 +54,29 @@ public class RLController implements ViewListener, LearningListe default: throw new IllegalArgumentException("Undefined method"); } - - initGUI(); + System.out.println("Initialized learning: " + learning.getClass()); + initListeners(); + System.out.println("Set listeners"); initLearning(); } - private void initGUI(){ - SwingUtilities.invokeLater(()->{ - learningView = new View<>(learning, environment, this); - learning.addListener(this); - }); + protected void initListeners(){ + learning.addListener(this); + new Thread(() -> { + while (true){ + printNextEpisode = true; + try { + Thread.sleep(30*1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + }).start(); } private void initLearning(){ if(learning instanceof EpisodicLearning){ + System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes"); ((EpisodicLearning) learning).learn(nrOfEpisodes); }else{ learning.learn(); @@ -78,7 +93,6 @@ public class RLController implements ViewListener, LearningListe }else{ throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); } - learningView.updateLearningInfoPanel(); } @Override @@ -88,10 +102,9 @@ public class RLController implements ViewListener, LearningListe try { fis = new FileInputStream(fileName); in = new ObjectInputStream(fis); - System.out.println("interrup" + Thread.currentThread().getId()); + System.out.println("interrupt" + Thread.currentThread().getId()); learning.interruptLearning(); learning.load(in); - SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); in.close(); } catch (IOException | ClassNotFoundException e) { e.printStackTrace(); @@ -117,7 +130,6 @@ public class RLController implements ViewListener, LearningListe public void onEpsilonChange(float epsilon) { if(learning.getPolicy() instanceof EpsilonPolicy){ ((EpsilonPolicy) learning.getPolicy()).setEpsilon(epsilon); - SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); }else{ System.out.println("Trying to call inEpsilonChange on non-epsilon policy"); } @@ -128,9 +140,8 @@ public class RLController implements ViewListener, LearningListe changeLearningDelay(delay); } - private void changeLearningDelay(int delay){ + protected void changeLearningDelay(int delay){ learning.setDelay(delay); - SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); } @Override @@ -153,19 +164,8 @@ public class RLController implements ViewListener, LearningListe @Override public void onLearningEnd() { - SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory)); - onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : "")); - } - - @Override - public void onEpisodeEnd(List rewardHistory) { - latestRewardsHistory = rewardHistory; - SwingUtilities.invokeLater(() ->{ - if(!fastLearning){ - learningView.updateRewardGraph(latestRewardsHistory); - } - learningView.updateLearningInfoPanel(); - }); + System.out.println("Learning finished"); + onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : "")); } @Override @@ -174,12 +174,19 @@ public class RLController implements ViewListener, LearningListe } @Override - public void onStepEnd() { - if(!fastLearning){ - SwingUtilities.invokeLater(() -> learningView.repaintEnvironment()); + public void onEpisodeEnd(List rewardHistory) { + latestRewardsHistory = rewardHistory; + if(printNextEpisode){ + System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1)); + System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); + printNextEpisode = false; } } + @Override + public void onStepEnd() { + } + /************************************************* ** SETTERS ** @@ -205,19 +212,4 @@ public class RLController implements ViewListener, LearningListe } this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); } - - public void setDelay(int delay){ - this.delay = delay; - } - - public void setEpisodes(int nrOfEpisodes){ - this.nrOfEpisodes = nrOfEpisodes; - } - - public void setDiscountFactor(float discountFactor){ - this.discountFactor = discountFactor; - } - public void setEpsilon(float epsilon){ - this.epsilon = epsilon; - } } diff --git a/src/main/java/core/controller/RLControllerGUI.java b/src/main/java/core/controller/RLControllerGUI.java new file mode 100644 index 0000000..567ad54 --- /dev/null +++ b/src/main/java/core/controller/RLControllerGUI.java @@ -0,0 +1,73 @@ +package core.controller; + +import core.Environment; +import core.algo.Method; +import core.gui.LearningView; +import core.gui.View; + +import javax.swing.*; +import java.util.List; + +public class RLControllerGUI extends RLController { + private LearningView learningView; + + public RLControllerGUI(Environment env, Method method, A... actions) { + super(env, method, actions); + } + + @Override + protected void initListeners() { + SwingUtilities.invokeLater(() -> { + learningView = new View<>(learning, environment, this); + learning.addListener(this); + }); + } + + @Override + public void onLearnMoreEpisodes(int nrOfEpisodes) { + super.onLearnMoreEpisodes(nrOfEpisodes); + learningView.updateLearningInfoPanel(); + } + + @Override + public void onLoadState(String fileName) { + super.onLoadState(fileName); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); + } + + @Override + public void onEpsilonChange(float epsilon) { + super.onEpsilonChange(epsilon); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); + } + + @Override + protected void changeLearningDelay(int delay) { + super.changeLearningDelay(delay); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); + } + + @Override + public void onLearningEnd() { + super.onLearningEnd(); + SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory)); + } + + @Override + public void onEpisodeEnd(List rewardHistory) { + super.onEpisodeEnd(rewardHistory); + SwingUtilities.invokeLater(() -> { + if (!fastLearning) { + learningView.updateRewardGraph(latestRewardsHistory); + } + learningView.updateLearningInfoPanel(); + }); + } + + @Override + public void onStepEnd() { + if (!fastLearning) { + SwingUtilities.invokeLater(() -> learningView.repaintEnvironment()); + } + } +} diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index 2e0a447..7cdd445 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -3,6 +3,7 @@ package example; import core.RNG; import core.algo.Method; import core.controller.RLController; +import core.controller.RLControllerGUI; import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoWorld; @@ -10,15 +11,15 @@ public class JumpingDino { public static void main(String[] args) { RNG.setSeed(55); - RLController rl = new RLController<>( + RLController rl = new RLControllerGUI<>( new DinoWorld(true, true), Method.MC_ONPOLICY_EGREEDY, DinoAction.values()); - rl.setDelay(200); + rl.setDelay(0); rl.setDiscountFactor(1f); rl.setEpsilon(0.15f); - rl.setEpisodes(5000); + rl.setNrOfEpisodes(100000); rl.start(); } diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java index bb2fa2e..ade0e92 100644 --- a/src/main/java/example/RunningAnt.java +++ b/src/main/java/example/RunningAnt.java @@ -16,7 +16,7 @@ public class RunningAnt { AntAction.values()); rl.setDelay(200); - rl.setEpisodes(10000); + rl.setNrOfEpisodes(10000); rl.setDiscountFactor(1f); rl.setEpsilon(0.15f);