From 5b82e7965da93c5d062bff60cb98b809fa87105f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Sun, 5 Apr 2020 12:29:44 +0200 Subject: [PATCH] rename MC class and improve specific analysis of antGame examples --- ...edy.java => MonteCarloControlEGreedy.java} | 10 ++-- .../algo/td/QLearningOffPolicyTDControl.java | 17 +----- .../java/core/controller/RLController.java | 6 +-- src/main/java/example/ContinuousAnt.java | 26 +++++----- src/main/java/example/DinoSampling.java | 11 ++-- src/main/java/example/Results | 15 ------ src/main/java/example/RunningAnt.java | 2 +- src/main/java/example/Test.java | 52 ------------------- 8 files changed, 28 insertions(+), 111 deletions(-) rename src/main/java/core/algo/mc/{MonteCarloControlFirstVisitEGreedy.java => MonteCarloControlEGreedy.java} (88%) delete mode 100644 src/main/java/example/Results delete mode 100644 src/main/java/example/Test.java diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java similarity index 88% rename from src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java rename to src/main/java/core/algo/mc/MonteCarloControlEGreedy.java index f00fa4b..c0f78d8 100644 --- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java +++ b/src/main/java/core/algo/mc/MonteCarloControlEGreedy.java @@ -12,19 +12,19 @@ import java.io.ObjectOutputStream; import java.util.*; /** - * Includes both variants of Monte-Carlo methods + * Includes both variants of Monte-Carlo methods * Default method is First-Visit. * Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true. 
* @param */ -public class MonteCarloControlFirstVisitEGreedy extends EpisodicLearning { +public class MonteCarloControlEGreedy extends EpisodicLearning { private Map, Double> returnSum; private Map, Integer> returnCount; private boolean isEveryVisit; - public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) { + public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) { super(environment, actionSpace, discountFactor, delay); isEveryVisit = useEveryVisit; this.policy = new EpsilonGreedyPolicy<>(epsilon); @@ -33,11 +33,11 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic returnCount = new HashMap<>(); } - public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) { + public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) { this(environment, actionSpace, discountFactor, epsilon, delay, false); } - public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) { + public MonteCarloControlEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) { this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay); } diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java index fe0c5cf..2da1c60 100644 --- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java +++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java @@ -45,9 +45,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin sumOfRewards = 0; int timestampTilFood = 0; - int rewardsPer1000 = 0; int 
foodCollected = 0; - int iterations = 0; int foodTimestampsTotal= 0; while(envResult == null || !envResult.isDone()) { actionValues = stateActionTable.getActionValues(state); @@ -58,20 +56,8 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin double reward = envResult.getReward(); State nextState = envResult.getState(); sumOfRewards += reward; - rewardsPer1000+=reward; timestampTilFood++; - /* if(iterations == 100){ - File file = new File(ContinuousAnt.FILE_NAME); - try { - Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND); - } catch (IOException e) { - e.printStackTrace(); - } - return; - }*/ - - if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) { foodCollected++; foodTimestampsTotal += timestampTilFood; @@ -95,7 +81,7 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.05f); } if(foodCollected == 4000){ - System.out.println("final 0 expl"); + System.out.println("Reached 0 exploration"); ((EpsilonGreedyPolicy) this.policy).setEpsilon(0.00f); } if(foodCollected == 15000){ @@ -106,7 +92,6 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin } return; } - iterations++; timestampTilFood = 0; } diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index 895102a..e4df7e0 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace; import core.algo.EpisodicLearning; import core.algo.Learning; import core.algo.Method; -import core.algo.mc.MonteCarloControlFirstVisitEGreedy; +import core.algo.mc.MonteCarloControlEGreedy; import core.algo.td.QLearningOffPolicyTDControl; import core.algo.td.SARSA; import core.listener.LearningListener; @@ -49,10 +49,10 @@ public class RLController implements LearningListener { public void start() { switch(method) { case MC_CONTROL_FIRST_VISIT: - learning = new 
MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); + learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); break; case MC_CONTROL_EVERY_VISIT: - learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true); + learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true); break; case SARSA_ON_POLICY_CONTROL: diff --git a/src/main/java/example/ContinuousAnt.java b/src/main/java/example/ContinuousAnt.java index 7a49b1d..e9d7ace 100644 --- a/src/main/java/example/ContinuousAnt.java +++ b/src/main/java/example/ContinuousAnt.java @@ -11,30 +11,28 @@ import java.io.File; import java.io.IOException; public class ContinuousAnt { - public static final String FILE_NAME = "converge22.txt"; + public static final String FILE_NAME = "converge.txt"; + public static void main(String[] args) { - int i = 4+4+4+6+6+6+8+10+12+14+14+16+16+16+18+18+18+20+20+20+22+22+22+24+24+24+24+26+26+26+26+26+28+28+28+28+28+30+30+30+30+32+32+32+34+34+34+36+36+38+40+42; - System.out.println(i/52f); File file = new File(FILE_NAME); try { file.createNewFile(); } catch (IOException e) { e.printStackTrace(); } - RNG.setSeed(13); + RNG.setSeed(13); RLController rl = new RLControllerGUI<>( - new AntWorldContinuous(8, 8), - Method.Q_LEARNING_OFF_POLICY_CONTROL, - AntAction.values()); + new AntWorldContinuous(8, 8), + Method.Q_LEARNING_OFF_POLICY_CONTROL, + AntAction.values()); rl.setDelay(20); - rl.setNrOfEpisodes(1); - //0.99 0.9 0.5 - //0.99 0.95 0.9 0.7 0.5 0.3 0.1 + rl.setNrOfEpisodes(1); + // 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99 rl.setDiscountFactor(0.05f); - // 0.1, 0.3, 0.5, 0.7 0.9 - rl.setLearningRate(0.9f); - rl.setEpsilon(0.2f); - rl.start(); + // 0.1, 0.3, 0.5, 0.7 0.9 + rl.setLearningRate(0.9f); + rl.setEpsilon(0.2f); + rl.start(); } diff --git 
a/src/main/java/example/DinoSampling.java b/src/main/java/example/DinoSampling.java index c4bbde7..6c748ac 100644 --- a/src/main/java/example/DinoSampling.java +++ b/src/main/java/example/DinoSampling.java @@ -14,8 +14,8 @@ import java.nio.file.Path; import java.nio.file.StandardOpenOption; public class DinoSampling { - public static final float f =0.05f; public static final String FILE_NAME = "converge.txt"; + public static void main(String[] args) { File file = new File(FILE_NAME); try { @@ -23,15 +23,16 @@ public class DinoSampling { } catch (IOException e) { e.printStackTrace(); } - for(float f = 0.05f; f <=1.003 ; f+=0.05f) { + for(float f = 0.05f; f <= 1.003; f += 0.05f) { try { Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND); } catch (IOException e) { e.printStackTrace(); } - for (int i = 1; i <= 100; i++) { - System.out.println("seed: " + i * 13); - RNG.setSeed(i * 13); + for(int i = 1; i <= 100; i++) { + int seed = i * 13; + System.out.println("seed: " + seed); + RNG.setSeed(seed); RLController rl = new RLControllerGUI<>( new DinoWorld(), diff --git a/src/main/java/example/Results b/src/main/java/example/Results deleted file mode 100644 index ba42fc7..0000000 --- a/src/main/java/example/Results +++ /dev/null @@ -1,15 +0,0 @@ -Method: -Epsilon = k / currentEpisode -set to 0 if Epsilon < b - -k = 1.5 -b = 0.1 => conv. 16 - -k = 1.5 -b = 0.02 => 75 - -k = 1.4 -b = 0.02 => fail - -k = 2.0 -b = 0.02 => conv. 
100 diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java index d01461c..eddea17 100644 --- a/src/main/java/example/RunningAnt.java +++ b/src/main/java/example/RunningAnt.java @@ -19,8 +19,8 @@ public class RunningAnt { rl.setDelay(200); rl.setNrOfEpisodes(10000); rl.setDiscountFactor(0.9f); + rl.setLearningRate(0.9f); rl.setEpsilon(0.15f); - rl.start(); } } diff --git a/src/main/java/example/Test.java b/src/main/java/example/Test.java deleted file mode 100644 index c7ee691..0000000 --- a/src/main/java/example/Test.java +++ /dev/null @@ -1,52 +0,0 @@ -package example; - -public class Test { - interface Drawable{ - void draw(); - } - interface State{ - int getInt(); - } - - static class A implements Drawable, State{ - private int k; - public A(int a){ - k = a; - } - @Override - public void draw() { - System.out.println("draw " + k); - } - - @Override - public int getInt() { - System.out.println("getInt" + k); - return k; - } - } - - static class B implements State{ - @Override - public int getInt() { - return 0; - } - } - - public static void main(String[] args) { - State state = new A(24); - State state2 = new B(); - state.getInt(); - - System.out.println(state2 instanceof Drawable); - drawState(state2); - } - - static void drawState(State s){ - if(s instanceof Drawable){ - Drawable d = (Drawable) s; - d.draw(); - }else{ - System.out.println("invalid"); - } - } -}