diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index 58a285a..a825bd5 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -10,8 +10,11 @@ import lombok.Getter;
 import lombok.Setter;
 
 import javax.swing.*;
+import java.util.ArrayList;
 import java.util.HashSet;
+import java.util.List;
 import java.util.Set;
+import java.util.concurrent.CopyOnWriteArrayList;
 
 @Getter
 public abstract class Learning<A extends Enum> {
@@ -20,38 +23,43 @@ public abstract class Learning<A extends Enum> {
     protected StateActionTable<A> stateActionTable;
     protected Environment<A> environment;
     protected float discountFactor;
-    @Setter
-    protected float epsilon;
     protected Set<LearningListener> learningListeners;
     @Setter
     protected int delay;
+    private List<Double> rewardHistory;
 
-    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){
         this.environment = environment;
         this.actionSpace = actionSpace;
         this.discountFactor = discountFactor;
-        this.epsilon = epsilon;
         this.delay = delay;
         learningListeners = new HashSet<>();
+        rewardHistory = new CopyOnWriteArrayList<>();
     }
 
-    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
-        this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor){
+        this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
+    }
+
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay){
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
     }
 
     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
-        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
     }
 
+    public abstract void learn(int nrOfEpisodes);
+
     public void addListener(LearningListener learningListener){
         learningListeners.add(learningListener);
     }
 
-    protected void dispatchEpisodeEnd(double sum){
+    protected void dispatchEpisodeEnd(double recentSumOfRewards){
+        rewardHistory.add(recentSumOfRewards);
         for(LearningListener l: learningListeners) {
-            l.onEpisodeEnd(sum);
+            l.onEpisodeEnd(rewardHistory);
        }
    }
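Side note on the CopyOnWriteArrayList choice above: the learning thread appends to rewardHistory once per episode while the Swing EDT iterates over the same list when redrawing the chart; a plain ArrayList would risk a ConcurrentModificationException. A minimal, self-contained sketch of that interaction (class and thread names are illustrative, not from the repo):

    import java.util.List;
    import java.util.concurrent.CopyOnWriteArrayList;

    public class RewardHistoryDemo {
        public static void main(String[] args) throws InterruptedException {
            List<Double> rewardHistory = new CopyOnWriteArrayList<>();

            // Writer: stands in for the learning thread appending one entry per episode.
            Thread learner = new Thread(() -> {
                for (int episode = 0; episode < 1000; episode++) {
                    rewardHistory.add(Math.random());
                }
            });

            // Reader: stands in for the Swing EDT redrawing the chart. Iteration runs
            // on a snapshot, so a concurrent add() cannot throw
            // ConcurrentModificationException.
            Thread painter = new Thread(() -> {
                for (int i = 0; i < 100; i++) {
                    double sum = 0;
                    for (double r : rewardHistory) sum += r;
                    System.out.println("points so far: " + rewardHistory.size() + ", sum: " + sum);
                }
            });

            learner.start();
            painter.start();
            learner.join();
            painter.join();
        }
    }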
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
index d608b80..1bc1f11 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -4,6 +4,8 @@ import core.*;
 import core.algo.Learning;
 import core.policy.EpsilonGreedyPolicy;
 import javafx.util.Pair;
+import lombok.Setter;
+
 import java.util.*;
 
 /**
@@ -26,13 +28,18 @@ import java.util.*;
  */
 public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
 
-    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
-        super(environment, actionSpace);
-        discountFactor = 1f;
-        this.policy = new EpsilonGreedyPolicy<>(0.1f);
-        this.stateActionTable = new StateActionHashTable<>(actionSpace);
+    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+        super(environment, actionSpace, discountFactor, delay);
+
+        this.policy = new EpsilonGreedyPolicy<>(epsilon);
+        this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
     }
 
+    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
+    }
+
+    @Override
     public void learn(int nrOfEpisodes) {
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index c65e62d..80b4375 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -7,9 +7,9 @@ import core.algo.Learning;
 import core.algo.Method;
 import core.algo.mc.MonteCarloOnPolicyEGreedy;
 import core.gui.View;
+import core.policy.EpsilonPolicy;
 
 import javax.swing.*;
-import java.util.Optional;
 
 public class RLController<A extends Enum> implements ViewListener{
     protected Environment<A> environment;
@@ -30,28 +30,38 @@ public class RLController<A extends Enum> implements ViewListener{
         switch (method){
             case MC_ONPOLICY_EGREEDY:
-                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
+                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay);
                 break;
             case TD_ONPOLICY:
                 break;
             default:
                 throw new RuntimeException("Undefined method");
         }
 
-        SwingUtilities.invokeLater(() ->{
-            view = new View<>(learning, this);
-            learning.addListener(view);
-        });
+        /*
+            not using SwingUtilities here on purpose to ensure the view is fully
+            initialized and can be passed as a LearningListener
+        */
+        view = new View<>(learning, this);
+        learning.addListener(view);
 
         learning.learn(nrOfEpisodes);
     }
-
+
     @Override
     public void onEpsilonChange(float epsilon) {
-        learning.setEpsilon(epsilon);
-        SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+        if(learning.getPolicy() instanceof EpsilonPolicy){
+            ((EpsilonPolicy) learning.getPolicy()).setEpsilon(epsilon);
+            SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+        }else{
+            System.out.println("Trying to call onEpsilonChange on a non-epsilon policy");
+        }
     }
 
     @Override
     public void onDelayChange(int delay) {
+        learning.setDelay(delay);
+        SwingUtilities.invokeLater(() -> {
+            view.updateLearningInfoPanel();
+        });
     }
 
     public RLController<A> setMethod(Method method){
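The instanceof check above is the key pattern: the controller only forwards slider changes to policies that actually expose an epsilon parameter. A compilable sketch of the same pattern, using hypothetical stand-ins for the repo's Policy and EpsilonPolicy interfaces:

    import java.util.Map;

    // Policy/EpsilonPolicy here mirror the repo's interfaces but are local stand-ins,
    // so the sketch compiles on its own.
    public class EpsilonCheckDemo {
        interface Policy<A> { A chooseAction(Map<A, Double> actionValues); }
        interface EpsilonPolicy<A> extends Policy<A> {
            float getEpsilon();
            void setEpsilon(float epsilon);
        }

        static <A> void onEpsilonChange(Policy<A> policy, float epsilon) {
            // Only epsilon-based policies expose the knob; anything else ignores the slider.
            if (policy instanceof EpsilonPolicy) {
                ((EpsilonPolicy<A>) policy).setEpsilon(epsilon);
            } else {
                System.out.println("policy has no epsilon parameter, ignoring change to " + epsilon);
            }
        }

        public static void main(String[] args) {
            Policy<String> greedy = actionValues -> "noop";   // no epsilon to set
            onEpsilonChange(greedy, 0.2f);                    // takes the warning branch
        }
    }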
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
index cbbd6ef..8c67589 100644
--- a/src/main/java/core/gui/LearningInfoPanel.java
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -2,6 +2,7 @@ package core.gui;
 
 import core.algo.Learning;
 import core.controller.ViewListener;
+import core.policy.EpsilonPolicy;
 
 import javax.swing.*;
 
@@ -11,6 +12,7 @@ public class LearningInfoPanel extends JPanel {
     private JLabel discountLabel;
     private JLabel epsilonLabel;
     private JSlider epsilonSlider;
+    private JLabel delayLabel;
     private JSlider delaySlider;
 
     public LearningInfoPanel(Learning learning, ViewListener viewListener){
@@ -19,12 +21,19 @@ public class LearningInfoPanel extends JPanel {
         policyLabel = new JLabel();
         discountLabel = new JLabel();
         epsilonLabel = new JLabel();
-        epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
-        epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+        delayLabel = new JLabel();
+        delaySlider = new JSlider(0, 1000, learning.getDelay());
+        delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
 
         add(policyLabel);
         add(discountLabel);
-        add(epsilonLabel);
-        add(epsilonSlider);
+        if(learning.getPolicy() instanceof EpsilonPolicy){
+            epsilonSlider = new JSlider(0, 100, (int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
+            epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+            add(epsilonLabel);
+            add(epsilonSlider);
+        }
+        add(delayLabel);
+        add(delaySlider);
 
         refreshLabels();
         setVisible(true);
     }
@@ -32,10 +41,9 @@ public class LearningInfoPanel extends JPanel {
     public void refreshLabels(){
         policyLabel.setText("Policy: " + learning.getPolicy().getClass());
         discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
-        epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
-    }
-
-    protected JSlider getEpsilonSlider(){
-        return epsilonSlider;
+        if(learning.getPolicy() instanceof EpsilonPolicy){
+            epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy) learning.getPolicy()).getEpsilon());
+        }
+        delayLabel.setText("Delay (ms): " + learning.getDelay());
     }
 }
diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java
index 8939ac2..d031451 100644
--- a/src/main/java/core/gui/View.java
+++ b/src/main/java/core/gui/View.java
@@ -23,12 +23,10 @@ public class View<A extends Enum> implements LearningListener {
     private JFrame mainFrame;
     private XChartPanel<XYChart> rewardChartPanel;
     private ViewListener viewListener;
-    private List<Double> rewardHistory;
 
     public View(Learning<A> learning, ViewListener viewListener){
         this.learning = learning;
         this.viewListener = viewListener;
-        rewardHistory = new ArrayList<>();
         this.initMainFrame();
     }
 
@@ -78,8 +76,7 @@ public class View<A extends Enum> implements LearningListener {
         };
     }
 
-    public void updateRewardGraph(double recentReward){
-        rewardHistory.add(recentReward);
+    public void updateRewardGraph(List<Double> rewardHistory){
         chart.updateXYSeries("randomWalk", null, rewardHistory, null);
         rewardChartPanel.revalidate();
         rewardChartPanel.repaint();
@@ -89,10 +86,11 @@ public class View<A extends Enum> implements LearningListener {
     public void updateLearningInfoPanel(){
         this.learningInfoPanel.refreshLabels();
     }
 
-    @Override
-    public void onEpisodeEnd(double sumOfRewards) {
-        SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
+    public void onEpisodeEnd(List<Double> rewardHistory) {
+        SwingUtilities.invokeLater(()->{
+            updateRewardGraph(rewardHistory);
+        });
     }
 
     @Override
diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java
index 5a9d287..4147897 100644
--- a/src/main/java/core/listener/LearningListener.java
+++ b/src/main/java/core/listener/LearningListener.java
@@ -1,6 +1,8 @@
 package core.listener;
 
+import java.util.List;
+
 public interface LearningListener{
-    void onEpisodeEnd(double sumOfRewards);
+    void onEpisodeEnd(List<Double> rewardHistory);
 
     void onEpisodeStart();
 }
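With this signature change every listener now receives the full reward history on each episode instead of a single value, so a late-registered observer (or a chart redrawn from scratch) never misses earlier episodes. A minimal sketch of an alternative listener under that contract — ConsoleListener is hypothetical; only the interface body comes from the patch:

    import java.util.Arrays;
    import java.util.List;

    interface LearningListener {
        void onEpisodeEnd(List<Double> rewardHistory);
        void onEpisodeStart();
    }

    // Pure logging listener: reads the latest entry of the history it is handed.
    public class ConsoleListener implements LearningListener {
        @Override
        public void onEpisodeEnd(List<Double> rewardHistory) {
            int episode = rewardHistory.size();
            System.out.println("episode " + episode + " ended, sum of rewards: "
                    + rewardHistory.get(episode - 1));
        }

        @Override
        public void onEpisodeStart() {
            // nothing to log at episode start
        }

        public static void main(String[] args) {
            new ConsoleListener().onEpisodeEnd(Arrays.asList(-3.0, 1.5, 4.0));
        }
    }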
diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java
index 0e8d448..1288aed 100644
--- a/src/main/java/core/policy/EpsilonGreedyPolicy.java
+++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java
@@ -1,6 +1,8 @@
 package core.policy;
 
 import core.RNG;
+import lombok.Getter;
+import lombok.Setter;
 
 import java.util.Map;
 
@@ -12,7 +14,9 @@ import java.util.Map;
  *
  * @param <A> Discrete Action Enum
  */
-public class EpsilonGreedyPolicy<A> implements Policy<A>{
+public class EpsilonGreedyPolicy<A> implements EpsilonPolicy<A>{
+    @Setter
+    @Getter
     private float epsilon;
     private RandomPolicy<A> randomPolicy;
     private GreedyPolicy<A> greedyPolicy;
@@ -22,8 +26,10 @@ public class EpsilonGreedyPolicy<A> implements Policy<A>{
         randomPolicy = new RandomPolicy<>();
         greedyPolicy = new GreedyPolicy<>();
     }
+
     @Override
     public A chooseAction(Map<A, Double> actionValues) {
+        System.out.println("current epsilon " + epsilon);
         if(RNG.getRandom().nextFloat() < epsilon){
             // Take random action
             return randomPolicy.chooseAction(actionValues);
diff --git a/src/main/java/core/policy/EpsilonPolicy.java b/src/main/java/core/policy/EpsilonPolicy.java
new file mode 100644
index 0000000..76bff45
--- /dev/null
+++ b/src/main/java/core/policy/EpsilonPolicy.java
@@ -0,0 +1,6 @@
+package core.policy;
+
+public interface EpsilonPolicy<A> extends Policy<A> {
+    float getEpsilon();
+    void setEpsilon(float epsilon);
+}
diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java
index a727db3..6ff7739 100644
--- a/src/main/java/core/policy/GreedyPolicy.java
+++ b/src/main/java/core/policy/GreedyPolicy.java
@@ -1,7 +1,5 @@
 package core.policy;
 
-import core.RNG;
-
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -13,7 +11,7 @@ public class GreedyPolicy<A> implements Policy<A> {
     public A chooseAction(Map<A, Double> actionValues) {
         if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");
 
-        Double highestValueAction = null; 
+        Double highestValueAction = null;
         List<A> equalHigh = new ArrayList<>();
diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java
index 1f8f086..094b41c 100644
--- a/src/main/java/core/policy/RandomPolicy.java
+++ b/src/main/java/core/policy/RandomPolicy.java
@@ -7,6 +7,7 @@ public class RandomPolicy<A> implements Policy<A>{
     @Override
     public A chooseAction(Map<A, Double> actionValues) {
         int idx = RNG.getRandom().nextInt(actionValues.size());
+        System.out.println("selected action " + idx);
         int i = 0;
         for(A action : actionValues.keySet()){
             if(i++ == idx) return action;
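For reference, the decision rule these three classes implement together: with probability epsilon act uniformly at random, otherwise act greedily on the current action values. A self-contained sketch (the enum and its values are made up; the repo's GreedyPolicy additionally breaks ties at random, which this one-liner does not):

    import java.util.LinkedHashMap;
    import java.util.Map;
    import java.util.Random;

    public class EpsilonGreedyDemo {
        enum Move { UP, DOWN, LEFT, RIGHT }

        public static void main(String[] args) {
            Random rng = new Random(42);
            float epsilon = 0.1f;

            Map<Move, Double> actionValues = new LinkedHashMap<>();
            actionValues.put(Move.UP, 0.3);
            actionValues.put(Move.DOWN, 0.7);
            actionValues.put(Move.LEFT, 0.1);
            actionValues.put(Move.RIGHT, 0.7);

            Move chosen;
            if (rng.nextFloat() < epsilon) {
                // explore: uniform random action
                Move[] moves = actionValues.keySet().toArray(new Move[0]);
                chosen = moves[rng.nextInt(moves.length)];
            } else {
                // exploit: greedy action (first of the tied maxima here)
                chosen = actionValues.entrySet().stream()
                        .max(Map.Entry.comparingByValue())
                        .get().getKey();
            }
            System.out.println("chosen action: " + chosen);
        }
    }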
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index 8c5bda7..ee8d347 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -20,13 +20,13 @@ public class AntState implements State, Visualizable {
     private final int computedHash;
 
     public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
-        this.knownWorld = deepCopyCellGrid(knownWorld);
+        this.knownWorld = Util.deepCopyCellGrid(knownWorld);
         this.pos = deepCopyAntPosition(antPosition);
         this.hasFood = hasFood;
         computedHash = computeHash();
     }
 
-    private int computeHash(){
+    private int computeHash() {
         int hash = 7;
         int prime = 31;
 
@@ -43,20 +43,10 @@ public class AntState implements State, Visualizable {
         }
         hash = prime * hash + unknown;
         hash = prime * hash * diff;
-        hash = prime * hash + (hasFood ? 1:0);
+        hash = prime * hash + (hasFood ? 1 : 0);
         hash = prime * hash + pos.hashCode();
         return hash;
     }
 
-    private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
-        Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
-        for (int i = 0; i < cells.length; i++) {
-            for (int j = 0; j < cells[i].length; j++) {
-                // calling copy constructor of Cell class
-                cells[i][j] = new Cell(toCopy[i][j]);
-            }
-        }
-        return cells;
-    }
 
     private Point deepCopyAntPosition(Point toCopy){
         return new Point(toCopy.x,toCopy.y);
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index d68c597..e512ad7 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -151,9 +151,12 @@ public class AntWorld implements Environment<AntAction>{
             done = grid.isAllFoodCollected();
         }
 
+
+        /*
         if(!done){
             reward = -1;
         }
+        */
 
         if(++tick == maxEpisodeTicks){
             done = true;
         }
@@ -172,8 +175,7 @@ public class AntWorld implements Environment<AntAction>{
     }
 
     public State reset() {
-        RNG.reseed();
-        grid.initRandomWorld();
+        grid.resetWorld();
         antAgent.initUnknownWorld();
         tick = 0;
         myAnt.getPos().setLocation(grid.getStartPoint());
@@ -207,7 +209,8 @@ public class AntWorld implements Environment<AntAction>{
         Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
                 new AntWorld(3, 3, 0.1),
-                new ListDiscreteActionSpace<>(AntAction.values())
+                new ListDiscreteActionSpace<>(AntAction.values()),
+                5
         );
         monteCarlo.learn(20000);
     }
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index 618f8ab..dced49a 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -10,31 +10,37 @@ public class Grid {
     private double foodDensity;
     private Point start;
     private Cell[][] grid;
+    private Cell[][] initialGrid;
 
     public Grid(int width, int height, double foodDensity){
         this.width = width;
         this.height = height;
         this.foodDensity = foodDensity;
-        grid = new Cell[width][height];
+        initialGrid = new Cell[width][height];
+        initRandomWorld();
     }
 
     public Grid(int width, int height){
         this(width, height, 0);
     }
 
+    public void resetWorld(){
+        grid = Util.deepCopyCellGrid(initialGrid);
+    }
+
     public void initRandomWorld(){
         for(int x = 0; x < width; ++x){
             for(int y = 0; y < height; ++y){
                 if( RNG.getRandom().nextDouble() < foodDensity){
-                    grid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
+                    initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
                 }else{
-                    grid[x][y] = new Cell(new Point(x,y), CellType.FREE);
+                    initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE);
                }
            }
        }
        start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
-       grid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
+       initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
     }
 
     public Point getStartPoint(){
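The deep copy in resetWorld() matters because Cell is mutable: cloning only the outer array would share Cell instances between grid and initialGrid, so food picked up in one episode would stay gone after a reset. A sketch with a stand-in Cell (hypothetical, mirroring the copy constructor the patch relies on) showing the difference:

    public class DeepCopyDemo {
        static class Cell {
            int food;
            Cell(int food) { this.food = food; }
            Cell(Cell other) { this.food = other.food; }   // copy constructor, as in the patch
        }

        static Cell[][] deepCopy(Cell[][] src) {
            Cell[][] out = new Cell[src.length][src[0].length];
            for (int i = 0; i < src.length; i++)
                for (int j = 0; j < src[i].length; j++)
                    out[i][j] = new Cell(src[i][j]);
            return out;
        }

        public static void main(String[] args) {
            Cell[][] initial = { { new Cell(1) } };

            Cell[][] shallow = initial.clone();       // copies the outer array only...
            shallow[0][0].food = 0;                   // ...so this mutates the shared Cell
            System.out.println(initial[0][0].food);   // 0 -- template corrupted

            initial[0][0].food = 1;                   // restore the template
            Cell[][] deep = deepCopy(initial);
            deep[0][0].food = 0;
            System.out.println(initial[0][0].food);   // 1 -- template intact
        }
    }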
diff --git a/src/main/java/evironment/antGame/Reward.java b/src/main/java/evironment/antGame/Reward.java
index 9a6926f..62f294a 100644
--- a/src/main/java/evironment/antGame/Reward.java
+++ b/src/main/java/evironment/antGame/Reward.java
@@ -1,17 +1,16 @@
 package evironment.antGame;
 
 public class Reward {
-    public static final double FOOD_PICK_UP_SUCCESS = 0;
-    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
-    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
+    public static final double FOOD_PICK_UP_SUCCESS = 1;
+    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
+    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;
 
-    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
-    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
+    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1;
+    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1;
     public static final double FOOD_DROP_DOWN_SUCCESS = 1;
 
-    public static final double UNKNOWN_FIELD_EXPLORED = 0;
-
-    public static final double RAN_INTO_WALL = 0;
-    public static final double RAN_INTO_OBSTACLE = 0;
+    public static final double UNKNOWN_FIELD_EXPLORED = 1;
+    public static final double RAN_INTO_WALL = -1;
+    public static final double RAN_INTO_OBSTACLE = -1;
 }
diff --git a/src/main/java/evironment/antGame/Util.java b/src/main/java/evironment/antGame/Util.java
new file mode 100644
index 0000000..504b460
--- /dev/null
+++ b/src/main/java/evironment/antGame/Util.java
@@ -0,0 +1,14 @@
+package evironment.antGame;
+
+public class Util {
+    public static Cell[][] deepCopyCellGrid(Cell[][] toCopy){
+        Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
+        for (int i = 0; i < cells.length; i++) {
+            for (int j = 0; j < cells[i].length; j++) {
+                // calling copy constructor of Cell class
+                cells[i][j] = new Cell(toCopy[i][j]);
+            }
+        }
+        return cells;
+    }
+}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index 19311d0..dc22cc0 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -15,7 +15,7 @@ public class RunningAnt {
                 .setAllowedActions(AntAction.values())
                 .setMethod(Method.MC_ONPOLICY_EGREEDY)
                 .setDelay(10)
-                .setEpisodes(1000);
+                .setEpisodes(10000);
 
         rl.start();
     }
 }
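Reading the new Reward constants together: exploration and successful pick-up/drop-off now pay +1, failed actions and collisions pay -1, and the flat -1 per step was dropped (commented out in AntWorld above). What the Monte Carlo learner estimates from these rewards is the discounted return G = r_0 + γ·r_1 + γ²·r_2 + …; a tiny sketch of that backward accumulation, with assumed sample values:

    public class ReturnDemo {
        public static void main(String[] args) {
            double gamma = 0.9;                       // the discountFactor passed to the learner
            double[] rewards = { 1, -1, 1, 1 };       // e.g. explored, hit wall, explored, picked up food

            double g = 0;
            for (int t = rewards.length - 1; t >= 0; t--) {
                g = rewards[t] + gamma * g;           // G_t = r_t + gamma * G_{t+1}
            }
            System.out.println("discounted return: " + g);
        }
    }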