From 195722e98f9b9488e8412cb120e5090c47e1c576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Sun, 29 Dec 2019 01:12:11 +0100 Subject: [PATCH] enhance save/load feature and change thread handling - saving monte carlo did not include returnSum and returnCount, so the state would be wrong after loading. Learning, EpisodicLearning and MonteCarlo classes are all overriding custom save and load methods, calling super() each time but including fields that are necessary to replace at runtime. - moved generic episodic behaviour from monteCarlo to abstract top level class - using AtomicInteger for episodesToLearn - moved learning-Thread-handling from controller to model. Learning got one extra Learning thread. - add feature to use custom speed and distance for dino world obstacles --- .gitignore | 3 + learningStates/.gitkeep | 0 src/main/java/core/algo/EpisodicLearning.java | 133 +++++++++++++----- src/main/java/core/algo/Learning.java | 27 +++- .../algo/MC/MonteCarloOnPolicyEGreedy.java | 25 +++- .../java/core/controller/RLController.java | 46 +++--- .../evironment/jumpingDino/DinoWorld.java | 32 ++++- src/main/java/example/JumpingDino.java | 2 +- 8 files changed, 193 insertions(+), 75 deletions(-) create mode 100644 learningStates/.gitkeep diff --git a/.gitignore b/.gitignore index 25d1e1c..3e25165 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +learningStates/* +!learningStates/.gitkeep + .idea/refo.iml .idea/misc.xml .idea/modules.xml diff --git a/learningStates/.gitkeep b/learningStates/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index ceb3756..3a56363 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -2,59 +2,53 @@ package core.algo; import core.DiscreteActionSpace; import core.Environment; +import core.StepResult; import core.listener.LearningListener;
+import lombok.Getter; import lombok.Setter; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + public abstract class EpisodicLearning extends Learning implements Episodic { @Setter protected int currentEpisode; - protected int episodesToLearn; + protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0); + @Getter protected volatile int episodePerSecond; protected int episodeSumCurrentSecond; - private volatile boolean measureEpisodeBenchMark; + protected double sumOfRewards; + protected List> episode = new ArrayList<>(); public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { super(environment, actionSpace, discountFactor, delay); + initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { super(environment, actionSpace, discountFactor); + initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, int delay) { super(environment, actionSpace, delay); + initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace) { super(environment, actionSpace); + initBenchMarking(); } - protected void dispatchEpisodeEnd(double recentSumOfRewards){ - ++episodeSumCurrentSecond; - if(rewardHistory.size() > 10000){ - rewardHistory.clear(); - } - rewardHistory.add(recentSumOfRewards); - for(LearningListener l: learningListeners) { - l.onEpisodeEnd(rewardHistory); - } - } + protected abstract void nextEpisode(); - protected void dispatchEpisodeStart(){ - for(LearningListener l: learningListeners){ - l.onEpisodeStart(); - } - } - - @Override - public void learn(){ - learn(0); - } - - public void learn(int nrOfEpisodes){ - measureEpisodeBenchMark = true; + private void initBenchMarking(){ new Thread(()->{ - 
while(measureEpisodeBenchMark){ + while (true){ episodePerSecond = episodeSumCurrentSecond; episodeSumCurrentSecond = 0; try { @@ -64,24 +58,89 @@ public abstract class EpisodicLearning extends Learning imple } } }).start(); - episodesToLearn += nrOfEpisodes; - dispatchLearningStart(); - for(int i=0; i < nrOfEpisodes; ++i){ - nextEpisode(); - } - dispatchLearningEnd(); - measureEpisodeBenchMark = false; } - protected abstract void nextEpisode(); + protected void dispatchEpisodeEnd(){ + ++episodeSumCurrentSecond; + if(rewardHistory.size() > 10000){ + rewardHistory.clear(); + } + rewardHistory.add(sumOfRewards); + for(LearningListener l: learningListeners) { + l.onEpisodeEnd(rewardHistory); + } + } + + protected void dispatchEpisodeStart(){ + ++currentEpisode; + episodesToLearn.decrementAndGet(); + for(LearningListener l: learningListeners){ + l.onEpisodeStart(); + } + } @Override - public int getCurrentEpisode(){ - return currentEpisode; + public void learn(){ + // TODO remove or learn with default episode number + } + + private void startLearning(){ + learningExecutor.submit(()->{ + dispatchLearningStart(); + while(episodesToLearn.get() > 0){ + dispatchEpisodeStart(); + nextEpisode(); + dispatchEpisodeEnd(); + } + synchronized (this){ + dispatchLearningEnd(); + notifyAll(); + } + }); + } + + /** + * Stopping the while loop by setting episodesToLearn to 0. + * The current episode can not be interrupted, so the sleep delay + * is removed and the calling thread has to wait until the + * current episode is done. + * Resetting the delay afterwards. 
+ */ + @Override + public synchronized void interruptLearning(){ + episodesToLearn.set(0); + int prevDelay = delay; + delay = 0; + while(currentlyLearning) { + try { + wait(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + delay = prevDelay; + } + + public synchronized void learn(int nrOfEpisodes){ + boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0; + if(!isLearning) + startLearning(); } @Override public int getEpisodesToGo(){ - return episodesToLearn - currentEpisode; + return episodesToLearn.get(); + } + + @Override + public synchronized void save(ObjectOutputStream oos) throws IOException { + super.save(oos); + oos.writeInt(currentEpisode); + } + + @Override + public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { + super.load(ois); + currentEpisode = ois.readInt(); } } diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index 8c589d2..c63ef43 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -9,10 +9,16 @@ import core.policy.Policy; import lombok.Getter; import lombok.Setter; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * @@ -30,14 +36,18 @@ public abstract class Learning{ @Setter protected int delay; protected List rewardHistory; + protected ExecutorService learningExecutor; + protected volatile boolean currentlyLearning; public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { this.environment = environment; this.actionSpace = actionSpace; this.discountFactor = discountFactor; this.delay = delay; + currentlyLearning = false; learningListeners = new 
HashSet<>(); rewardHistory = new CopyOnWriteArrayList<>(); + learningExecutor = Executors.newSingleThreadExecutor(); } public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { @@ -52,7 +62,6 @@ public abstract class Learning{ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY); } - public abstract void learn(); public void addListener(LearningListener learningListener) { @@ -66,15 +75,31 @@ public abstract class Learning{ } protected void dispatchLearningStart() { + currentlyLearning = true; for (LearningListener l : learningListeners) { l.onLearningStart(); } } protected void dispatchLearningEnd() { + currentlyLearning = false; for (LearningListener l : learningListeners) { l.onLearningEnd(); } } + public synchronized void interruptLearning(){ + //TODO: for non episodic learning + } + + + public void save(ObjectOutputStream oos) throws IOException { + oos.writeObject(rewardHistory); + oos.writeObject(stateActionTable); + } + + public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { + rewardHistory = (List) ois.readObject(); + stateActionTable = (StateActionTable) ois.readObject(); + } } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index 6f468c0..e9b02c9 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -5,6 +5,9 @@ import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; import javafx.util.Pair; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; import java.util.*; /** @@ -44,19 +47,16 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay); } - @Override public void nextEpisode() { - 
++currentEpisode; - List> episode = new ArrayList<>(); + episode = new ArrayList<>(); State state = environment.reset(); - dispatchEpisodeStart(); try { Thread.sleep(delay); } catch (InterruptedException e) { e.printStackTrace(); } - double sumOfRewards = 0; + sumOfRewards = 0; StepResultEnvironment envResult = null; while(envResult == null || !envResult.isDone()){ Map actionValues = stateActionTable.getActionValues(state); @@ -76,7 +76,6 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< dispatchStepEnd(); } - dispatchEpisodeEnd(sumOfRewards); // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); Set> stateActionPairs = new LinkedHashSet<>(); @@ -115,4 +114,18 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< public int getEpisodesPerSecond(){ return episodePerSecond; } + + @Override + public void save(ObjectOutputStream oos) throws IOException { + super.save(oos); + oos.writeObject(returnSum); + oos.writeObject(returnCount); + } + + @Override + public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { + super.load(ois); + returnSum = (Map, Double>) ois.readObject(); + returnCount = (Map, Integer>) ois.readObject(); + } } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 4cafe4e..0fb589a 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -18,6 +18,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; public class RLController implements ViewListener, LearningListener { + private final String folderPrefix = "learningStates" + File.separator; private Environment environment; private DiscreteActionSpace discreteActionSpace; private Method method; @@ -26,15 +27,12 @@ public class RLController implements ViewListener, LearningListe private float epsilon = LearningConfig.DEFAULT_EPSILON; private Learning learning; 
private LearningView learningView; - private ExecutorService learningExecutor; - private boolean currentlyLearning; private boolean fastLearning; private List latestRewardsHistory; private int nrOfEpisodes; private int prevDelay; public RLController(Environment env, Method method, A... actions){ - learningExecutor = Executors.newSingleThreadExecutor(); setEnvironment(env); setMethod(method); setAllowedActions(actions); @@ -64,9 +62,9 @@ public class RLController implements ViewListener, LearningListe private void initLearning(){ if(learning instanceof EpisodicLearning){ - learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); + ((EpisodicLearning) learning).learn(nrOfEpisodes); }else{ - learningExecutor.submit(()->learning.learn()); + learning.learn(); } } @@ -75,27 +73,25 @@ public class RLController implements ViewListener, LearningListe *************************************************/ @Override public void onLearnMoreEpisodes(int nrOfEpisodes){ - if(!currentlyLearning){ - if(learning instanceof EpisodicLearning){ - learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); - }else{ - throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); - } + if(learning instanceof EpisodicLearning){ + ((EpisodicLearning) learning).learn(nrOfEpisodes); + }else{ + throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); } + learningView.updateLearningInfoPanel(); } @Override public void onLoadState(String fileName) { FileInputStream fis; - ObjectInput in; + ObjectInputStream in; try { fis = new FileInputStream(fileName); in = new ObjectInputStream(fis); - SaveState saveState = (SaveState) in.readObject(); - learning.setStateActionTable(saveState.getStateActionTable()); - if(learning instanceof EpisodicLearning){ - ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode()); - } + System.out.println("interrup" + Thread.currentThread().getId()); + 
learning.interruptLearning(); + learning.load(in); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); in.close(); } catch (IOException | ClassNotFoundException e) { e.printStackTrace(); @@ -107,15 +103,10 @@ public class RLController implements ViewListener, LearningListe FileOutputStream fos; ObjectOutputStream out; try{ - fos = new FileOutputStream(fileName); + fos = new FileOutputStream(folderPrefix + fileName); out = new ObjectOutputStream(fos); - int currentEpisode; - if(learning instanceof EpisodicLearning){ - currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode(); - }else{ - currentEpisode = 0; - } - out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode)); + learning.interruptLearning(); + learning.save(out); out.close(); }catch (IOException e){ e.printStackTrace(); @@ -158,13 +149,12 @@ public class RLController implements ViewListener, LearningListe *************************************************/ @Override public void onLearningStart() { - currentlyLearning = true; } @Override public void onLearningEnd() { - currentlyLearning = false; SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory)); + onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? 
"e " + ((EpisodicLearning) learning).getCurrentEpisode() : "")); } @Override @@ -192,7 +182,7 @@ public class RLController implements ViewListener, LearningListe /************************************************* - ** SETTER ** + ** SETTERS ** *************************************************/ private void setEnvironment(Environment environment){ diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java index 3f40524..9ca4fbf 100644 --- a/src/main/java/evironment/jumpingDino/DinoWorld.java +++ b/src/main/java/evironment/jumpingDino/DinoWorld.java @@ -14,12 +14,20 @@ import java.awt.*; public class DinoWorld implements Environment, Visualizable { private Dino dino; private Obstacle currentObstacle; + private boolean randomObstacleSpeed; + private boolean randomObstacleDistance; - public DinoWorld(){ + public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){ + this.randomObstacleSpeed = randomObstacleSpeed; + this.randomObstacleDistance = randomObstacleDistance; dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); spawnNewObstacle(); } + public DinoWorld(){ + this(false, false); + } + private boolean ranIntoObstacle(){ Obstacle o = currentObstacle; Dino p = dino; @@ -32,6 +40,7 @@ public class DinoWorld implements Environment, Visualizable { return xAxis && yAxis; } + private int getDistanceToObstacle(){ return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE; } @@ -57,8 +66,27 @@ public class DinoWorld implements Environment, Visualizable { return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, ""); } + + private void spawnNewObstacle(){ - currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK); + int dx; + int xSpawn; + + 
if(randomObstacleSpeed){ + dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED); + }else{ + dx = -Config.OBSTACLE_SPEED; + } + + if(randomObstacleDistance){ + // randomly spawning more right outside of the screen + xSpawn = (int)((Math.random() + 0.5) * Config.FRAME_WIDTH) + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE; + }else{ + // instantly respawning on the left screen border + xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE; + } + + currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK); } private void spawnDino(){ diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index bddc8b2..2e0a447 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -11,7 +11,7 @@ public class JumpingDino { RNG.setSeed(55); RLController rl = new RLController<>( - new DinoWorld(), + new DinoWorld(true, true), Method.MC_ONPOLICY_EGREEDY, DinoAction.values());