From 7db5a2af3b92f2702aeff9ec32d3e1496d6a3025 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?=
Date: Fri, 20 Dec 2019 16:51:09 +0100
Subject: [PATCH] fix RNG seeding, add EpsilonPolicy interface and move
 rewardHistory to model instead of view

- only set the seed of the RNG once at the beginning and never reseed it
  afterwards. The initial AntWorld is deep copied and used as a blueprint
  for resetting the world, instead of reseeding and generating the same
  pseudo-random world again. Reseeding the RNG influenced action selection,
  so the agent would always choose the same trajectory (see the sketches
  below the diffstat).
- instanceof is used to determine whether the policy has an epsilon
  parameter, and the view adapts to this, only showing the epsilon slider
  if the policy has one
---
 src/main/java/core/algo/Learning.java         | 26 +++++++++++------
 .../algo/MC/MonteCarloOnPolicyEGreedy.java    | 17 +++++++----
 .../java/core/controller/RLController.java    | 28 +++++++++++++------
 src/main/java/core/gui/LearningInfoPanel.java | 26 +++++++++++------
 src/main/java/core/gui/View.java              | 12 ++++----
 .../java/core/listener/LearningListener.java  |  4 ++-
 .../java/core/policy/EpsilonGreedyPolicy.java |  8 +++++-
 src/main/java/core/policy/EpsilonPolicy.java  |  6 ++++
 src/main/java/core/policy/GreedyPolicy.java   |  4 +--
 src/main/java/core/policy/RandomPolicy.java   |  1 +
 .../java/evironment/antGame/AntState.java     | 16 ++---------
 .../java/evironment/antGame/AntWorld.java     |  9 ++++--
 src/main/java/evironment/antGame/Grid.java    | 14 +++++++---
 src/main/java/evironment/antGame/Reward.java  | 17 ++++++-----
 src/main/java/evironment/antGame/Util.java    | 14 ++++++++++
 src/main/java/example/RunningAnt.java         |  2 +-
 16 files changed, 130 insertions(+), 74 deletions(-)
 create mode 100644 src/main/java/core/policy/EpsilonPolicy.java
 create mode 100644 src/main/java/evironment/antGame/Util.java
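Illustrative sketches (cover notes, ignored by git am; not part of this
patch). Class names RNG, Policy and EpsilonPolicy follow this repository;
SeedOnceDemo, randomGrid, deepCopy and buildEpsilonSlider are made-up
helpers for illustration only.

Seed once, reset by deep copy: reseeding before every episode replays the
identical random stream, so epsilon-greedy "exploration" picks the same
actions and the agent is pinned to one trajectory. Copying a stored
blueprint resets the world without touching the RNG state:

    import java.util.Random;

    public class SeedOnceDemo {
        private static final Random RNG = new Random(42); // seeded exactly once

        public static void main(String[] args) {
            int[][] blueprint = randomGrid(3, 3, 0.1); // RNG consumed at startup only
            int[][] episode1 = deepCopy(blueprint);    // fresh, identical world...
            int[][] episode2 = deepCopy(blueprint);    // ...while the RNG keeps advancing
        }

        static int[][] randomGrid(int w, int h, double foodDensity) {
            int[][] g = new int[w][h];
            for (int x = 0; x < w; x++)
                for (int y = 0; y < h; y++)
                    g[x][y] = RNG.nextDouble() < foodDensity ? 1 : 0; // 1 = food cell
            return g;
        }

        static int[][] deepCopy(int[][] src) {
            int[][] copy = new int[src.length][];
            for (int i = 0; i < src.length; i++) copy[i] = src[i].clone();
            return copy;
        }
    }

View adaptation via instanceof: only policies implementing EpsilonPolicy
carry an epsilon, so the slider is built conditionally (this mirrors the
LearningInfoPanel change below; javax.swing.JSlider is assumed imported):

    static JSlider buildEpsilonSlider(Policy<?> policy) {
        if (!(policy instanceof EpsilonPolicy)) {
            return null; // no epsilon parameter: the view omits the slider
        }
        float epsilon = ((EpsilonPolicy<?>) policy).getEpsilon();
        return new JSlider(0, 100, (int) (epsilon * 100));
    }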
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index 58a285a..a825bd5 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -10,8 +10,11 @@ import lombok.Getter;
 import lombok.Setter;
 
 import javax.swing.*;
+import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
+import java.util.concurrent.CopyOnWriteArrayList;
 
 @Getter
 public abstract class Learning<A extends Enum> {
@@ -20,38 +23,43 @@ public abstract class Learning<A extends Enum> {
     protected StateActionTable<A> stateActionTable;
     protected Environment<A> environment;
     protected float discountFactor;
-    @Setter
-    protected float epsilon;
     protected Set<LearningListener> learningListeners;
     @Setter
     protected int delay;
+    private List<Double> rewardHistory;
 
-    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){
         this.environment = environment;
         this.actionSpace = actionSpace;
         this.discountFactor = discountFactor;
-        this.epsilon = epsilon;
         this.delay = delay;
         learningListeners = new HashSet<>();
+        rewardHistory = new CopyOnWriteArrayList<>();
     }
 
-    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
-        this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor){
+        this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
+    }
+
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay){
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
     }
 
     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
-        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
     }
 
+
     public abstract void learn(int nrOfEpisodes);
 
     public void addListener(LearningListener learningListener){
         learningListeners.add(learningListener);
     }
 
-    protected void dispatchEpisodeEnd(double sum){
+    protected void dispatchEpisodeEnd(double recentSumOfRewards){
+        rewardHistory.add(recentSumOfRewards);
         for(LearningListener l: learningListeners) {
-            l.onEpisodeEnd(sum);
+            l.onEpisodeEnd(rewardHistory);
         }
     }
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
index d608b80..1bc1f11 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -4,6 +4,8 @@ import core.*;
 import core.algo.Learning;
 import core.policy.EpsilonGreedyPolicy;
 import javafx.util.Pair;
+import lombok.Setter;
+
 import java.util.*;
 
 /**
@@ -26,13 +28,18 @@ import java.util.*;
  */
 public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
 
-    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
-        super(environment, actionSpace);
-        discountFactor = 1f;
-        this.policy = new EpsilonGreedyPolicy<>(0.1f);
-        this.stateActionTable = new StateActionHashTable<>(actionSpace);
+    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+        super(environment, actionSpace, discountFactor, delay);
+
+        this.policy = new EpsilonGreedyPolicy<>(epsilon);
+        this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
     }
 
+    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
+    }
+
+
     @Override
     public void learn(int nrOfEpisodes) {
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index c65e62d..80b4375 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -7,9 +7,9 @@ import core.algo.Learning;
 import core.algo.Method;
 import core.algo.mc.MonteCarloOnPolicyEGreedy;
 import core.gui.View;
+import core.policy.EpsilonPolicy;
 
 import javax.swing.*;
-import java.util.Optional;
 
 public class RLController<A extends Enum> implements ViewListener{
     protected Environment<A> environment;
@@ -30,28 +30,38 @@ public class RLController<A extends Enum> implements ViewListener{
 
         switch (method){
             case MC_ONPOLICY_EGREEDY:
-                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
+                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay);
                 break;
             case TD_ONPOLICY:
                 break;
             default:
                 throw new RuntimeException("Undefined method");
         }
 
-        SwingUtilities.invokeLater(() ->{
-            view = new View<>(learning, this);
-            learning.addListener(view);
-        });
+        /*
+        not using SwingUtilities here on purpose to ensure the view is fully
+        initialized and can be passed as LearningListener.
+        */
+        view = new View<>(learning, this);
+        learning.addListener(view);
 
         learning.learn(nrOfEpisodes);
     }
-
+
     @Override
     public void onEpsilonChange(float epsilon) {
-        learning.setEpsilon(epsilon);
-        SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+        if(learning.getPolicy() instanceof EpsilonPolicy){
+            ((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
+            SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+        }else{
+            System.out.println("Trying to call onEpsilonChange on a non-epsilon policy");
+        }
     }
 
     @Override
     public void onDelayChange(int delay) {
+        learning.setDelay(delay);
+        SwingUtilities.invokeLater(() -> {
+            view.updateLearningInfoPanel();
+        });
     }
 
     public RLController<A> setMethod(Method method){
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
index cbbd6ef..8c67589 100644
--- a/src/main/java/core/gui/LearningInfoPanel.java
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -2,6 +2,7 @@ package core.gui;
 
 import core.algo.Learning;
 import core.controller.ViewListener;
+import core.policy.EpsilonPolicy;
 
 import javax.swing.*;
 
@@ -11,6 +12,7 @@ public class LearningInfoPanel extends JPanel {
     private JLabel discountLabel;
     private JLabel epsilonLabel;
     private JSlider epsilonSlider;
+    private JLabel delayLabel;
     private JSlider delaySlider;
 
     public LearningInfoPanel(Learning learning, ViewListener viewListener){
@@ -19,12 +21,19 @@ public class LearningInfoPanel extends JPanel {
         policyLabel = new JLabel();
         discountLabel = new JLabel();
         epsilonLabel = new JLabel();
-        epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
-        epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+        delayLabel = new JLabel();
+        delaySlider = new JSlider(0, 1000, learning.getDelay());
+        delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
         add(policyLabel);
         add(discountLabel);
-        add(epsilonLabel);
-        add(epsilonSlider);
+        if(learning.getPolicy() instanceof EpsilonPolicy){
+            epsilonSlider = new JSlider(0, 100, (int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
+            epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+            add(epsilonLabel);
+            add(epsilonSlider);
+        }
+        add(delayLabel);
+        add(delaySlider);
         refreshLabels();
         setVisible(true);
     }
@@ -32,10 +41,9 @@ public class LearningInfoPanel extends JPanel {
     public void refreshLabels(){
         policyLabel.setText("Policy: " + learning.getPolicy().getClass());
         discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
-        epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
-    }
-
-    protected JSlider getEpsilonSlider(){
-        return epsilonSlider;
+        if(learning.getPolicy() instanceof EpsilonPolicy){
+            epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
+        }
+        delayLabel.setText("Delay (ms): " + learning.getDelay());
     }
 }
diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java
index 8939ac2..d031451 100644
--- a/src/main/java/core/gui/View.java
+++ b/src/main/java/core/gui/View.java
@@ -23,12 +23,10 @@ public class View<A extends Enum> implements LearningListener {
     private JFrame mainFrame;
     private XChartPanel<XYChart> rewardChartPanel;
     private ViewListener viewListener;
-    private List<Double> rewardHistory;
 
     public View(Learning<A> learning, ViewListener viewListener){
         this.learning = learning;
        this.viewListener = viewListener;
-        rewardHistory = new ArrayList<>();
         this.initMainFrame();
     }
 
@@ -78,8 +76,7 @@ public class View<A extends Enum> implements LearningListener {
         };
     }
 
-    public void updateRewardGraph(double recentReward){
-        rewardHistory.add(recentReward);
+    public void updateRewardGraph(List<Double> rewardHistory){
         chart.updateXYSeries("randomWalk", null, rewardHistory, null);
         rewardChartPanel.revalidate();
         rewardChartPanel.repaint();
@@ -89,10 +86,11 @@ public class View<A extends Enum> implements LearningListener {
         this.learningInfoPanel.refreshLabels();
     }
 
-    @Override
-    public void onEpisodeEnd(double sumOfRewards) {
-        SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
+    public void onEpisodeEnd(List<Double> rewardHistory) {
+        SwingUtilities.invokeLater(()->{
+            updateRewardGraph(rewardHistory);
+        });
     }
 
     @Override
diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java
index 5a9d287..4147897 100644
--- a/src/main/java/core/listener/LearningListener.java
+++ b/src/main/java/core/listener/LearningListener.java
@@ -1,6 +1,8 @@
 package core.listener;
 
+import java.util.List;
+
 public interface LearningListener{
-    void onEpisodeEnd(double sumOfRewards);
+    void onEpisodeEnd(List<Double> rewardHistory);
 
     void onEpisodeStart();
 }
diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java
index 0e8d448..1288aed 100644
--- a/src/main/java/core/policy/EpsilonGreedyPolicy.java
+++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java
@@ -1,6 +1,8 @@
 package core.policy;
 
 import core.RNG;
+import lombok.Getter;
+import lombok.Setter;
 
 import java.util.Map;
 
@@ -12,7 +14,9 @@ import java.util.Map;
  *
  * @param <A> Discrete Action Enum
  */
-public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
+public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
+    @Setter
+    @Getter
     private float epsilon;
     private RandomPolicy<A> randomPolicy;
     private GreedyPolicy<A> greedyPolicy;
@@ -22,8 +26,10 @@ public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
         randomPolicy = new RandomPolicy<>();
         greedyPolicy = new GreedyPolicy<>();
     }
+
     @Override
     public A chooseAction(Map<A, Double> actionValues) {
+        System.out.println("current epsilon " + epsilon);
         if(RNG.getRandom().nextFloat() < epsilon){
             // Take random action
             return randomPolicy.chooseAction(actionValues);
diff --git a/src/main/java/core/policy/EpsilonPolicy.java b/src/main/java/core/policy/EpsilonPolicy.java
new file mode 100644
index 0000000..76bff45
--- /dev/null
+++ b/src/main/java/core/policy/EpsilonPolicy.java
@@ -0,0 +1,6 @@
+package core.policy;
+
+public interface EpsilonPolicy<A extends Enum> extends Policy<A> {
+    float getEpsilon();
+    void setEpsilon(float epsilon);
+}
diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java
index a727db3..6ff7739 100644
--- a/src/main/java/core/policy/GreedyPolicy.java
+++ b/src/main/java/core/policy/GreedyPolicy.java
@@ -1,7 +1,5 @@
 package core.policy;
 
-import core.RNG;
-
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -13,7 +11,7 @@ public class GreedyPolicy<A extends Enum> implements Policy<A> {
     public A chooseAction(Map<A, Double> actionValues) {
         if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");
 
-        Double highestValueAction = null;
+        Double highestValueAction = null;
         List<A> equalHigh = new ArrayList<>();
 
diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java
index 1f8f086..094b41c 100644
--- a/src/main/java/core/policy/RandomPolicy.java
+++ b/src/main/java/core/policy/RandomPolicy.java
@@ -7,6 +7,7 @@ public class RandomPolicy<A extends Enum> implements Policy<A>{
     @Override
     public A chooseAction(Map<A, Double> actionValues) {
         int idx = RNG.getRandom().nextInt(actionValues.size());
+        System.out.println("selected action " + idx);
         int i = 0;
         for(A action : actionValues.keySet()){
             if(i++ == idx) return action;
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index 8c5bda7..ee8d347 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -20,13 +20,13 @@ public class AntState implements State, Visualizable {
     private final int computedHash;
 
     public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
-        this.knownWorld = deepCopyCellGrid(knownWorld);
+        this.knownWorld = Util.deepCopyCellGrid(knownWorld);
         this.pos = deepCopyAntPosition(antPosition);
         this.hasFood = hasFood;
         computedHash = computeHash();
     }
 
-    private int computeHash(){
+    private int computeHash() {
         int hash = 7;
         int prime = 31;
 
@@ -43,20 +43,10 @@ public class AntState implements State, Visualizable {
         }
         hash = prime * hash + unknown;
         hash = prime * hash * diff;
-        hash = prime * hash + (hasFood ? 1:0);
+        hash = prime * hash + (hasFood ? 1 : 0);
         hash = prime * hash + pos.hashCode();
         return hash;
     }
-    private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
-        Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
-        for (int i = 0; i < cells.length; i++) {
-            for (int j = 0; j < cells[i].length; j++) {
-                // calling copy constructor of Cell class
-                cells[i][j] = new Cell(toCopy[i][j]);
-            }
-        }
-        return cells;
-    }
 
     private Point deepCopyAntPosition(Point toCopy){
         return new Point(toCopy.x,toCopy.y);
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index d68c597..e512ad7 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -151,9 +151,12 @@ public class AntWorld implements Environment<AntAction>{
             done = grid.isAllFoodCollected();
         }
 
+
+        /*
         if(!done){
             reward = -1;
         }
+        */
 
         if(++tick == maxEpisodeTicks){
             done = true;
         }
@@ -172,8 +175,7 @@ public class AntWorld implements Environment<AntAction>{
     }
 
     public State reset() {
-        RNG.reseed();
-        grid.initRandomWorld();
+        grid.resetWorld();
         antAgent.initUnknownWorld();
         tick = 0;
         myAnt.getPos().setLocation(grid.getStartPoint());
@@ -207,7 +209,8 @@ public class AntWorld implements Environment<AntAction>{
 
         Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
                 new AntWorld(3, 3, 0.1),
-                new ListDiscreteActionSpace<>(AntAction.values())
+                new ListDiscreteActionSpace<>(AntAction.values()),
+                5
         );
         monteCarlo.learn(20000);
     }
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index 618f8ab..dced49a 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -10,31 +10,37 @@ public class Grid {
     private double foodDensity;
     private Point start;
     private Cell[][] grid;
+    private Cell[][] initialGrid;
 
     public Grid(int width, int height, double foodDensity){
         this.width = width;
         this.height = height;
         this.foodDensity = foodDensity;
-        grid = new Cell[width][height];
+        initialGrid = new Cell[width][height];
+        initRandomWorld();
     }
 
     public Grid(int width, int height){
         this(width, height, 0);
     }
 
+    public void resetWorld(){
+        grid = Util.deepCopyCellGrid(initialGrid);
+    }
+
     public void initRandomWorld(){
         for(int x = 0; x < width; ++x){
             for(int y = 0; y < height; ++y){
                 if( RNG.getRandom().nextDouble() < foodDensity){
-                    grid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
+                    initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
                 }else{
-                    grid[x][y] = new Cell(new Point(x,y), CellType.FREE);
+                    initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE);
                 }
             }
         }
         start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
-        grid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
+        initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
     }
 
     public Point getStartPoint(){
diff --git a/src/main/java/evironment/antGame/Reward.java b/src/main/java/evironment/antGame/Reward.java
index 9a6926f..62f294a 100644
--- a/src/main/java/evironment/antGame/Reward.java
+++ b/src/main/java/evironment/antGame/Reward.java
@@ -1,17 +1,16 @@
 package evironment.antGame;
 
 public class Reward {
-    public static final double FOOD_PICK_UP_SUCCESS = 0;
-    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
-    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
+    public static final double FOOD_PICK_UP_SUCCESS = 1;
+    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
+    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;
 
-    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
-    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
+    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1;
+    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1;
     public static final double FOOD_DROP_DOWN_SUCCESS = 1;
 
-    public static final double UNKNOWN_FIELD_EXPLORED = 0;
-
-    public static final double RAN_INTO_WALL = 0;
-    public static final double RAN_INTO_OBSTACLE = 0;
+    public static final double UNKNOWN_FIELD_EXPLORED = 1;
 
+    public static final double RAN_INTO_WALL = -1;
+    public static final double RAN_INTO_OBSTACLE = -1;
 }
diff --git a/src/main/java/evironment/antGame/Util.java b/src/main/java/evironment/antGame/Util.java
new file mode 100644
index 0000000..504b460
--- /dev/null
+++ b/src/main/java/evironment/antGame/Util.java
@@ -0,0 +1,14 @@
+package evironment.antGame;
+
+public class Util {
+    public static Cell[][] deepCopyCellGrid(Cell[][] toCopy){
+        Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
+        for (int i = 0; i < cells.length; i++) {
+            for (int j = 0; j < cells[i].length; j++) {
+                // calling copy constructor of Cell class
+                cells[i][j] = new Cell(toCopy[i][j]);
+            }
+        }
+        return cells;
+    }
+}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index 19311d0..dc22cc0 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -15,7 +15,7 @@ public class RunningAnt {
                 .setAllowedActions(AntAction.values())
                 .setMethod(Method.MC_ONPOLICY_EGREEDY)
                 .setDelay(10)
-                .setEpisodes(1000);
+                .setEpisodes(10000);
 
         rl.start();
     }
 }