diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index 559bfc6..b45ea50 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -50,7 +50,7 @@ public abstract class EpisodicLearning extends Learning imple private void initBenchMarking(){ new Thread(()->{ - while (true){ + while (currentlyLearning){ episodePerSecond = episodeSumCurrentSecond; episodeSumCurrentSecond = 0; try { @@ -62,7 +62,7 @@ public abstract class EpisodicLearning extends Learning imple }).start(); } - protected void dispatchEpisodeEnd(){ + private void dispatchEpisodeEnd(){ ++episodeSumCurrentSecond; if(rewardHistory.size() > 10000){ rewardHistory.clear(); @@ -75,18 +75,6 @@ public abstract class EpisodicLearning extends Learning imple protected void dispatchEpisodeStart(){ ++currentEpisode; - /* - 2f 0.02 => 100 - 1.5f 0.02 => 75 - 1.4f 0.02 => fail - 1.5f 0.1 => 16 ! - */ - if(this.policy instanceof EpsilonGreedyPolicy){ - float ep = 1.5f/(float)currentEpisode; - if(ep < 0.10) ep = 0; - ((EpsilonGreedyPolicy) this.policy).setEpsilon(ep); - System.out.println(ep); - } episodesToLearn.decrementAndGet(); for(LearningListener l: learningListeners){ l.onEpisodeStart(); @@ -97,31 +85,24 @@ public abstract class EpisodicLearning extends Learning imple protected void dispatchStepEnd() { super.dispatchStepEnd(); timestamp++; - // TODO: more sophisticated way to check convergence - if(timestamp > 300000){ - System.out.println("converged after: " + currentEpisode + " episode!"); - interruptLearning(); - } - } - - @Override - public void learn(){ - learn(LearningConfig.DEFAULT_NR_OF_EPISODES); } private void startLearning(){ - learningExecutor.submit(()->{ - dispatchLearningStart(); - while(episodesToLearn.get() > 0){ - dispatchEpisodeStart(); - nextEpisode(); - dispatchEpisodeEnd(); - } - synchronized (this){ - dispatchLearningEnd(); - notifyAll(); - } - }); + dispatchLearningStart(); + System.out.println(episodesToLearn.get()); + while(episodesToLearn.get() > 0){ + dispatchEpisodeStart(); + nextEpisode(); + dispatchEpisodeEnd(); + } + synchronized (this){ + dispatchLearningEnd(); + notifyAll(); + } + } + + public void learnMoreEpisodes(int nrOfEpisodes){ + episodesToLearn.addAndGet(nrOfEpisodes); } /** @@ -146,8 +127,14 @@ public abstract class EpisodicLearning extends Learning imple delay = prevDelay; } + @Override + public void learn(){ + learn(LearningConfig.DEFAULT_NR_OF_EPISODES); + } + public synchronized void learn(int nrOfEpisodes){ boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0; + System.out.println(isLearning); if(!isLearning) startLearning(); } diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index fbba9ef..1bb5207 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -42,8 +42,7 @@ public abstract class Learning{ @Setter protected int delay; protected List rewardHistory; - protected ExecutorService learningExecutor; - protected boolean currentlyLearning; + protected volatile boolean currentlyLearning; public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { this.environment = environment; @@ -53,7 +52,6 @@ public abstract class Learning{ currentlyLearning = false; learningListeners = new HashSet<>(); rewardHistory = new CopyOnWriteArrayList<>(); - learningExecutor = Executors.newSingleThreadExecutor(); } public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { @@ -89,8 +87,6 @@ public abstract class Learning{ protected void dispatchLearningEnd() { currentlyLearning = false; - System.out.println("Checksum: " + checkSum); - System.out.println("Reward Checksum: " + rewardCheckSum); for (LearningListener l : learningListeners) { l.onLearningEnd(); } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index d855ab1..837e427 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -83,7 +83,7 @@ public class RLController implements LearningListener { private void initLearning() { if(learning instanceof EpisodicLearning) { System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes"); - ((EpisodicLearning) learning).learn(nrOfEpisodes); + ((EpisodicLearning) learning).learn(nrOfEpisodes); } else { learning.learn(); } @@ -95,7 +95,13 @@ public class RLController implements LearningListener { protected void learnMoreEpisodes(int nrOfEpisodes) { if(learning instanceof EpisodicLearning) { - ((EpisodicLearning) learning).learn(nrOfEpisodes); + if(learning.isCurrentlyLearning()){ + ((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes); + }else{ + new Thread(() -> { + ((EpisodicLearning) learning).learn(nrOfEpisodes); + }).start(); + } } else { throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); } @@ -169,8 +175,8 @@ public class RLController implements LearningListener { public void onEpisodeEnd(List rewardHistory) { latestRewardsHistory = rewardHistory; if(printNextEpisode) { - System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1)); - System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); + System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1)); + System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); printNextEpisode = false; } } diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java index 4959cee..ed2ebad 100644 --- a/src/main/java/core/gui/LearningInfoPanel.java +++ b/src/main/java/core/gui/LearningInfoPanel.java @@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel { viewListener.onFastLearnChange(fastLearning); }); smoothGraphCheckbox = new JCheckBox("Smoothen Graph"); - smoothGraphCheckbox.setSelected(false); + smoothGraphCheckbox.setSelected(true); last100Checkbox = new JCheckBox("Only show last 100 Rewards"); - last100Checkbox.setSelected(true); + last100Checkbox.setSelected(false); drawEnvironmentCheckbox = new JCheckBox("Update Environment"); drawEnvironmentCheckbox.setSelected(true); diff --git a/src/main/java/example/DinoSampling.java b/src/main/java/example/DinoSampling.java deleted file mode 100644 index 09d5c5b..0000000 --- a/src/main/java/example/DinoSampling.java +++ /dev/null @@ -1,27 +0,0 @@ -package example; - -import core.RNG; -import core.algo.Method; -import core.controller.RLController; -import evironment.jumpingDino.DinoAction; -import evironment.jumpingDino.DinoWorld; - -public class DinoSampling { - public static void main(String[] args) { - for (int i = 0; i < 10 ; i++) { - RNG.setSeed(55); - - RLController rl = new RLController<>( - new DinoWorld(false, false), - Method.MC_CONTROL_FIRST_VISIT, - DinoAction.values()); - - rl.setDelay(0); - rl.setDiscountFactor(1f); - rl.setEpsilon(0.15f); - rl.setLearningRate(1f); - rl.setNrOfEpisodes(400); - rl.start(); - } - } -} diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index 41d5290..ce81753 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -3,6 +3,7 @@ package example; import core.RNG; import core.algo.Method; import core.controller.RLController; +import core.controller.RLControllerGUI; import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoWorld; @@ -10,16 +11,16 @@ public class JumpingDino { public static void main(String[] args) { RNG.setSeed(55); - RLController rl = new RLController<>( + RLController rl = new RLControllerGUI<>( new DinoWorld(false, false), Method.MC_CONTROL_FIRST_VISIT, DinoAction.values()); - rl.setDelay(0); + rl.setDelay(100); rl.setDiscountFactor(1f); rl.setEpsilon(0.15f); rl.setLearningRate(1f); - rl.setNrOfEpisodes(400); + rl.setNrOfEpisodes(10000); rl.start(); } }