diff --git a/.gitignore b/.gitignore index 25d1e1c..3e25165 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +learningStates/* +!learningStates/.gitkeep + .idea/refo.iml .idea/misc.xml .idea/modules.xml diff --git a/learningStates/.gitkeep b/learningStates/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index ceb3756..3a56363 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -2,59 +2,53 @@ package core.algo; import core.DiscreteActionSpace; import core.Environment; +import core.StepResult; import core.listener.LearningListener; +import lombok.Getter; import lombok.Setter; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + public abstract class EpisodicLearning extends Learning implements Episodic { @Setter protected int currentEpisode; - protected int episodesToLearn; + protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0); + @Getter protected volatile int episodePerSecond; protected int episodeSumCurrentSecond; - private volatile boolean measureEpisodeBenchMark; + protected double sumOfRewards; + protected List> episode = new ArrayList<>(); public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { super(environment, actionSpace, discountFactor, delay); + initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { super(environment, actionSpace, discountFactor); + initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, int delay) { super(environment, actionSpace, delay); + initBenchMarking(); } public EpisodicLearning(Environment environment, 
DiscreteActionSpace actionSpace) { super(environment, actionSpace); + initBenchMarking(); } - protected void dispatchEpisodeEnd(double recentSumOfRewards){ - ++episodeSumCurrentSecond; - if(rewardHistory.size() > 10000){ - rewardHistory.clear(); - } - rewardHistory.add(recentSumOfRewards); - for(LearningListener l: learningListeners) { - l.onEpisodeEnd(rewardHistory); - } - } + protected abstract void nextEpisode(); - protected void dispatchEpisodeStart(){ - for(LearningListener l: learningListeners){ - l.onEpisodeStart(); - } - } - - @Override - public void learn(){ - learn(0); - } - - public void learn(int nrOfEpisodes){ - measureEpisodeBenchMark = true; + private void initBenchMarking(){ new Thread(()->{ - while(measureEpisodeBenchMark){ + while (true){ episodePerSecond = episodeSumCurrentSecond; episodeSumCurrentSecond = 0; try { @@ -64,24 +58,89 @@ public abstract class EpisodicLearning extends Learning imple } } }).start(); - episodesToLearn += nrOfEpisodes; - dispatchLearningStart(); - for(int i=0; i < nrOfEpisodes; ++i){ - nextEpisode(); - } - dispatchLearningEnd(); - measureEpisodeBenchMark = false; } - protected abstract void nextEpisode(); + protected void dispatchEpisodeEnd(){ + ++episodeSumCurrentSecond; + if(rewardHistory.size() > 10000){ + rewardHistory.clear(); + } + rewardHistory.add(sumOfRewards); + for(LearningListener l: learningListeners) { + l.onEpisodeEnd(rewardHistory); + } + } + + protected void dispatchEpisodeStart(){ + ++currentEpisode; + episodesToLearn.decrementAndGet(); + for(LearningListener l: learningListeners){ + l.onEpisodeStart(); + } + } @Override - public int getCurrentEpisode(){ - return currentEpisode; + public void learn(){ + // TODO remove or learn with default episode number + } + + private void startLearning(){ + learningExecutor.submit(()->{ + dispatchLearningStart(); + while(episodesToLearn.get() > 0){ + dispatchEpisodeStart(); + nextEpisode(); + dispatchEpisodeEnd(); + } + synchronized (this){ + dispatchLearningEnd(); 
+ notifyAll(); + } + }); + } + + /** + * Stopping the while loop by setting episodesToLearn to 0. + * The current episode can not be interrupted, so the sleep delay + * is removed and the calling thread has to wait until the + * current episode is done. + * Resetting the delay afterwards. + */ + @Override + public synchronized void interruptLearning(){ + episodesToLearn.set(0); + int prevDelay = delay; + delay = 0; + while(currentlyLearning) { + try { + wait(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + delay = prevDelay; + } + + public synchronized void learn(int nrOfEpisodes){ + boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0; + if(!isLearning) + startLearning(); } @Override public int getEpisodesToGo(){ - return episodesToLearn - currentEpisode; + return episodesToLearn.get(); + } + + @Override + public synchronized void save(ObjectOutputStream oos) throws IOException { + super.save(oos); + oos.writeInt(currentEpisode); + } + + @Override + public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { + super.load(ois); + currentEpisode = ois.readInt(); } } diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index 8c589d2..c63ef43 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -9,10 +9,16 @@ import core.policy.Policy; import lombok.Getter; import lombok.Setter; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * @@ -30,14 +36,18 @@ public abstract class Learning{ @Setter protected int delay; protected List rewardHistory; + protected ExecutorService learningExecutor; + protected volatile boolean currentlyLearning; 
public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { this.environment = environment; this.actionSpace = actionSpace; this.discountFactor = discountFactor; this.delay = delay; + currentlyLearning = false; learningListeners = new HashSet<>(); rewardHistory = new CopyOnWriteArrayList<>(); + learningExecutor = Executors.newSingleThreadExecutor(); } public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { @@ -52,7 +62,6 @@ public abstract class Learning{ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY); } - public abstract void learn(); public void addListener(LearningListener learningListener) { @@ -66,15 +75,31 @@ public abstract class Learning{ } protected void dispatchLearningStart() { + currentlyLearning = true; for (LearningListener l : learningListeners) { l.onLearningStart(); } } protected void dispatchLearningEnd() { + currentlyLearning = false; for (LearningListener l : learningListeners) { l.onLearningEnd(); } } + public synchronized void interruptLearning(){ + //TODO: for non episodic learning + } + + + public void save(ObjectOutputStream oos) throws IOException { + oos.writeObject(rewardHistory); + oos.writeObject(stateActionTable); + } + + public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { + rewardHistory = (List) ois.readObject(); + stateActionTable = (StateActionTable) ois.readObject(); + } } diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java index 6f468c0..e9b02c9 100644 --- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -5,6 +5,9 @@ import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; import javafx.util.Pair; +import java.io.IOException; +import java.io.ObjectInputStream; +import 
java.io.ObjectOutputStream; import java.util.*; /** @@ -44,19 +47,16 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay); } - @Override public void nextEpisode() { - ++currentEpisode; - List> episode = new ArrayList<>(); + episode = new ArrayList<>(); State state = environment.reset(); - dispatchEpisodeStart(); try { Thread.sleep(delay); } catch (InterruptedException e) { e.printStackTrace(); } - double sumOfRewards = 0; + sumOfRewards = 0; StepResultEnvironment envResult = null; while(envResult == null || !envResult.isDone()){ Map actionValues = stateActionTable.getActionValues(state); @@ -76,7 +76,6 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< dispatchStepEnd(); } - dispatchEpisodeEnd(sumOfRewards); // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); Set> stateActionPairs = new LinkedHashSet<>(); @@ -115,4 +114,18 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning< public int getEpisodesPerSecond(){ return episodePerSecond; } + + @Override + public void save(ObjectOutputStream oos) throws IOException { + super.save(oos); + oos.writeObject(returnSum); + oos.writeObject(returnCount); + } + + @Override + public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { + super.load(ois); + returnSum = (Map, Double>) ois.readObject(); + returnCount = (Map, Integer>) ois.readObject(); + } } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 4cafe4e..0fb589a 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -18,6 +18,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; public class RLController implements ViewListener, LearningListener { + private final String folderPrefix = "learningStates" 
+ File.separator; private Environment environment; private DiscreteActionSpace discreteActionSpace; private Method method; @@ -26,15 +27,12 @@ public class RLController implements ViewListener, LearningListe private float epsilon = LearningConfig.DEFAULT_EPSILON; private Learning learning; private LearningView learningView; - private ExecutorService learningExecutor; - private boolean currentlyLearning; private boolean fastLearning; private List latestRewardsHistory; private int nrOfEpisodes; private int prevDelay; public RLController(Environment env, Method method, A... actions){ - learningExecutor = Executors.newSingleThreadExecutor(); setEnvironment(env); setMethod(method); setAllowedActions(actions); @@ -64,9 +62,9 @@ public class RLController implements ViewListener, LearningListe private void initLearning(){ if(learning instanceof EpisodicLearning){ - learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); + ((EpisodicLearning) learning).learn(nrOfEpisodes); }else{ - learningExecutor.submit(()->learning.learn()); + learning.learn(); } } @@ -75,27 +73,25 @@ public class RLController implements ViewListener, LearningListe *************************************************/ @Override public void onLearnMoreEpisodes(int nrOfEpisodes){ - if(!currentlyLearning){ - if(learning instanceof EpisodicLearning){ - learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); - }else{ - throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); - } + if(learning instanceof EpisodicLearning){ + ((EpisodicLearning) learning).learn(nrOfEpisodes); + }else{ + throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); } + learningView.updateLearningInfoPanel(); } @Override public void onLoadState(String fileName) { FileInputStream fis; - ObjectInput in; + ObjectInputStream in; try { fis = new FileInputStream(fileName); in = new ObjectInputStream(fis); - SaveState saveState = 
(SaveState) in.readObject(); - learning.setStateActionTable(saveState.getStateActionTable()); - if(learning instanceof EpisodicLearning){ - ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode()); - } + System.out.println("interrup" + Thread.currentThread().getId()); + learning.interruptLearning(); + learning.load(in); + SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel()); in.close(); } catch (IOException | ClassNotFoundException e) { e.printStackTrace(); @@ -107,15 +103,10 @@ public class RLController implements ViewListener, LearningListe FileOutputStream fos; ObjectOutputStream out; try{ - fos = new FileOutputStream(fileName); + fos = new FileOutputStream(folderPrefix + fileName); out = new ObjectOutputStream(fos); - int currentEpisode; - if(learning instanceof EpisodicLearning){ - currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode(); - }else{ - currentEpisode = 0; - } - out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode)); + learning.interruptLearning(); + learning.save(out); out.close(); }catch (IOException e){ e.printStackTrace(); @@ -158,13 +149,12 @@ public class RLController implements ViewListener, LearningListe *************************************************/ @Override public void onLearningStart() { - currentlyLearning = true; } @Override public void onLearningEnd() { - currentlyLearning = false; SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory)); + onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? 
"e " + ((EpisodicLearning) learning).getCurrentEpisode() : "")); } @Override @@ -192,7 +182,7 @@ public class RLController implements ViewListener, LearningListe /************************************************* - ** SETTER ** + ** SETTERS ** *************************************************/ private void setEnvironment(Environment environment){ diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java index 3f40524..9ca4fbf 100644 --- a/src/main/java/evironment/jumpingDino/DinoWorld.java +++ b/src/main/java/evironment/jumpingDino/DinoWorld.java @@ -14,12 +14,20 @@ import java.awt.*; public class DinoWorld implements Environment, Visualizable { private Dino dino; private Obstacle currentObstacle; + private boolean randomObstacleSpeed; + private boolean randomObstacleDistance; - public DinoWorld(){ + public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){ + this.randomObstacleSpeed = randomObstacleSpeed; + this.randomObstacleDistance = randomObstacleDistance; dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); spawnNewObstacle(); } + public DinoWorld(){ + this(false, false); + } + private boolean ranIntoObstacle(){ Obstacle o = currentObstacle; Dino p = dino; @@ -32,6 +40,7 @@ public class DinoWorld implements Environment, Visualizable { return xAxis && yAxis; } + private int getDistanceToObstacle(){ return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE; } @@ -57,8 +66,27 @@ public class DinoWorld implements Environment, Visualizable { return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, ""); } + + private void spawnNewObstacle(){ - currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK); + int dx; + int xSpawn; + + 
if(randomObstacleSpeed){ + dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED); + }else{ + dx = -Config.OBSTACLE_SPEED; + } + + if(randomObstacleDistance){ + // spawning a random 0.5 to 1.5 frame widths further right, outside of the screen + xSpawn = (int)((Math.random() + 0.5) * Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE); + }else{ + // instantly respawning just beyond the right screen border + xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE; + } + + currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK); } private void spawnDino(){ diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index bddc8b2..2e0a447 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -11,7 +11,7 @@ public class JumpingDino { RNG.setSeed(55); RLController rl = new RLController<>( - new DinoWorld(), + new DinoWorld(true, true), Method.MC_ONPOLICY_EGREEDY, DinoAction.values());