From 0e4f52a48e7cd4c3baf0e5128a40a9af5d0c91d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Thu, 27 Feb 2020 15:29:15 +0100 Subject: [PATCH] first epsilon decaying method --- src/main/java/core/algo/EpisodicLearning.java | 24 +++++++++++++++ src/main/java/core/algo/Learning.java | 10 ++++++- src/main/java/core/algo/Method.java | 2 +- .../MonteCarloControlFirstVisitEGreedy.java} | 13 +++++--- .../QLearningOffPolicyTDControl.java | 0 src/main/java/core/algo/{TD => td}/SARSA.java | 2 +- .../java/core/controller/RLController.java | 8 ++--- .../evironment/jumpingDino/DinoWorld.java | 4 +-- src/main/java/example/DinoSampling.java | 27 +++++++++++++++++ src/main/java/example/JumpingDino.java | 15 +++++----- src/main/java/example/RunningAnt.java | 2 +- src/test/java/MCFirstVisit.java | 30 +++++++++++++++++++ 12 files changed, 115 insertions(+), 22 deletions(-) rename src/main/java/core/algo/{MC/MonteCarloControlEGreedy.java => mc/MonteCarloControlFirstVisitEGreedy.java} (87%) rename src/main/java/core/algo/{TD => td}/QLearningOffPolicyTDControl.java (100%) rename src/main/java/core/algo/{TD => td}/SARSA.java (98%) create mode 100644 src/main/java/example/DinoSampling.java create mode 100644 src/test/java/MCFirstVisit.java diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index 4f997d3..559bfc6 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -5,6 +5,7 @@ import core.Environment; import core.LearningConfig; import core.StepResult; import core.listener.LearningListener; +import core.policy.EpsilonGreedyPolicy; import lombok.Getter; import lombok.Setter; @@ -74,12 +75,35 @@ public abstract class EpisodicLearning extends Learning imple protected void dispatchEpisodeStart(){ ++currentEpisode; + /* + 2f 0.02 => 100 + 1.5f 0.02 => 75 + 1.4f 0.02 => fail + 1.5f 0.1 => 16 ! 
+ */ + if(this.policy instanceof EpsilonGreedyPolicy){ + float ep = 1.5f/(float)currentEpisode; + if(ep < 0.10) ep = 0; + ((EpsilonGreedyPolicy) this.policy).setEpsilon(ep); + System.out.println(ep); + } episodesToLearn.decrementAndGet(); for(LearningListener l: learningListeners){ l.onEpisodeStart(); } } + @Override + protected void dispatchStepEnd() { + super.dispatchStepEnd(); + timestamp++; + // TODO: more sophisticated way to check convergence + if(timestamp > 300000){ + System.out.println("converged after: " + currentEpisode + " episode!"); + interruptLearning(); + } + } + @Override public void learn(){ learn(LearningConfig.DEFAULT_NR_OF_EPISODES); diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index c63ef43..fbba9ef 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -12,7 +12,6 @@ import lombok.Setter; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; -import java.io.Serializable; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -26,6 +25,13 @@ import java.util.concurrent.Executors; */ @Getter public abstract class Learning{ + // TODO: temp testing -> extract to dedicated test + protected int checkSum; + protected int rewardCheckSum; + + // current discrete timestamp t + protected int timestamp; + protected int currentEpisode; protected Policy policy; protected DiscreteActionSpace actionSpace; @Setter @@ -83,6 +89,8 @@ public abstract class Learning{ protected void dispatchLearningEnd() { currentlyLearning = false; + System.out.println("Checksum: " + checkSum); + System.out.println("Reward Checksum: " + rewardCheckSum); for (LearningListener l : learningListeners) { l.onLearningEnd(); } diff --git a/src/main/java/core/algo/Method.java b/src/main/java/core/algo/Method.java index 473a889..6372b24 100644 --- a/src/main/java/core/algo/Method.java +++ b/src/main/java/core/algo/Method.java @@ -5,5 +5,5 
@@ package core.algo; * which RL-algorithm should be used. */ public enum Method { - MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL + MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL } diff --git a/src/main/java/core/algo/MC/MonteCarloControlEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java similarity index 87% rename from src/main/java/core/algo/MC/MonteCarloControlEGreedy.java rename to src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java index d7cd028..50e9ce2 100644 --- a/src/main/java/core/algo/MC/MonteCarloControlEGreedy.java +++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java @@ -17,7 +17,7 @@ import java.util.*; * For example: *

* startingState -> MOVE_LEFT : very first state action in the episode i = 1 - * image the agent does not collect the food and drops it to the start, the agent will receive + * imagine the agent does not collect the food and does not drop it onto the start, the agent will receive * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10; *

* BUT image moving left from the starting position will have no impact on the state because @@ -30,12 +30,12 @@ import java.util.*; * * @param */ -public class MonteCarloControlEGreedy extends EpisodicLearning { +public class MonteCarloControlFirstVisitEGreedy extends EpisodicLearning { private Map, Double> returnSum; private Map, Integer> returnCount; - public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) { + public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) { super(environment, actionSpace, discountFactor, delay); this.policy = new EpsilonGreedyPolicy<>(epsilon); this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace); @@ -43,7 +43,7 @@ public class MonteCarloControlEGreedy extends EpisodicLearning(); } - public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) { + public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) { this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay); } @@ -58,12 +58,16 @@ public class MonteCarloControlEGreedy extends EpisodicLearning actionValues = stateActionTable.getActionValues(state); A chosenAction = policy.chooseAction(actionValues); + checkSum += chosenAction.ordinal(); envResult = environment.step(chosenAction); State nextState = envResult.getState(); sumOfRewards += envResult.getReward(); + rewardCheckSum += envResult.getReward(); episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); state = nextState; @@ -73,6 +77,7 @@ public class MonteCarloControlEGreedy extends EpisodicLearning extends EpisodicLearning { actionValues = stateActionTable.getActionValues(nextState); A nextAction = policy.chooseAction(actionValues); - // TD update + // td update // target 
= reward + gamma * Q(nextState, nextAction) double currentQValue = stateActionTable.getActionValues(state).get(action); double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction); diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 5bd3381..d36c1da 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace; import core.algo.EpisodicLearning; import core.algo.Learning; import core.algo.Method; -import core.algo.mc.MonteCarloControlEGreedy; +import core.algo.mc.MonteCarloControlFirstVisitEGreedy; import core.algo.td.QLearningOffPolicyTDControl; import core.algo.td.SARSA; import core.listener.LearningListener; @@ -48,8 +48,8 @@ public class RLController implements LearningListener { public void start() { switch(method) { - case MC_CONTROL_EGREEDY: - learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); + case MC_CONTROL_FIRST_VISIT: + learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); break; case SARSA_EPISODIC: learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay); @@ -115,7 +115,7 @@ public class RLController implements LearningListener { try { fis = new FileInputStream(fileName); in = new ObjectInputStream(fis); - System.out.println("interrup" + Thread.currentThread().getId()); + System.out.println("interrupt" + Thread.currentThread().getId()); learning.interruptLearning(); learning.load(in); in.close(); diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java index 953e2c4..103fc08 100644 --- a/src/main/java/evironment/jumpingDino/DinoWorld.java +++ b/src/main/java/evironment/jumpingDino/DinoWorld.java @@ -50,7 +50,7 @@ public class DinoWorld 
implements Environment, Visualizable { @Override public StepResultEnvironment step(DinoAction action) { boolean done = false; - int reward = 0; + int reward = 1; if(action == DinoAction.JUMP){ dino.jump(); @@ -74,7 +74,7 @@ public class DinoWorld implements Environment, Visualizable { spawnNewObstacle(); } if(ranIntoObstacle()) { - reward = -1; + reward = 0; done = true; } diff --git a/src/main/java/example/DinoSampling.java b/src/main/java/example/DinoSampling.java new file mode 100644 index 0000000..09d5c5b --- /dev/null +++ b/src/main/java/example/DinoSampling.java @@ -0,0 +1,27 @@ +package example; + +import core.RNG; +import core.algo.Method; +import core.controller.RLController; +import evironment.jumpingDino.DinoAction; +import evironment.jumpingDino.DinoWorld; + +public class DinoSampling { + public static void main(String[] args) { + for (int i = 0; i < 10 ; i++) { + RNG.setSeed(55); + + RLController rl = new RLController<>( + new DinoWorld(false, false), + Method.MC_CONTROL_FIRST_VISIT, + DinoAction.values()); + + rl.setDelay(0); + rl.setDiscountFactor(1f); + rl.setEpsilon(0.15f); + rl.setLearningRate(1f); + rl.setNrOfEpisodes(400); + rl.start(); + } + } +} diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index c033316..41d5290 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -3,7 +3,6 @@ package example; import core.RNG; import core.algo.Method; import core.controller.RLController; -import core.controller.RLControllerGUI; import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoWorld; @@ -11,16 +10,16 @@ public class JumpingDino { public static void main(String[] args) { RNG.setSeed(55); - RLController rl = new RLControllerGUI<>( + RLController rl = new RLController<>( new DinoWorld(false, false), - Method.Q_LEARNING_OFF_POLICY_CONTROL, + Method.MC_CONTROL_FIRST_VISIT, DinoAction.values()); - rl.setDelay(1000); - rl.setDiscountFactor(0.9f); - 
rl.setEpsilon(0.1f); - rl.setLearningRate(0.5f); - rl.setNrOfEpisodes(4000000); + rl.setDelay(0); + rl.setDiscountFactor(1f); + rl.setEpsilon(0.15f); + rl.setLearningRate(1f); + rl.setNrOfEpisodes(400); rl.start(); } } diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java index 9b83316..f1d393c 100644 --- a/src/main/java/example/RunningAnt.java +++ b/src/main/java/example/RunningAnt.java @@ -13,7 +13,7 @@ public class RunningAnt { RLController rl = new RLControllerGUI<>( new AntWorld(3, 3, 0.1), - Method.MC_CONTROL_EGREEDY, + Method.MC_CONTROL_FIRST_VISIT, AntAction.values()); rl.setDelay(200); diff --git a/src/test/java/MCFirstVisit.java b/src/test/java/MCFirstVisit.java new file mode 100644 index 0000000..c1cd581 --- /dev/null +++ b/src/test/java/MCFirstVisit.java @@ -0,0 +1,30 @@ +import core.RNG; +import core.algo.Method; +import core.controller.RLController; +import core.controller.RLControllerGUI; +import evironment.jumpingDino.DinoAction; +import evironment.jumpingDino.DinoWorld; +import org.junit.Test; + +public class MCFirstVisit { + + /** + * Test if the action sequence is deterministic + */ + @Test + public void deterministicActionSequence(){ + RNG.setSeed(55); + + RLController rl = new RLControllerGUI<>( + new DinoWorld(false, false), + Method.MC_CONTROL_FIRST_VISIT, + DinoAction.values()); + + rl.setDelay(10); + rl.setDiscountFactor(1f); + rl.setEpsilon(0.1f); + rl.setLearningRate(0.8f); + rl.setNrOfEpisodes(4000000); + rl.start(); + } +}