diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index 58a285a..a825bd5 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -10,8 +10,11 @@ import lombok.Getter;
import lombok.Setter;
import javax.swing.*;
+import java.util.ArrayList;
import java.util.HashSet;
+import java.util.List;
import java.util.Set;
+import java.util.concurrent.CopyOnWriteArrayList;
@Getter
public abstract class Learning {
@@ -20,38 +23,43 @@ public abstract class Learning {
protected StateActionTable stateActionTable;
protected Environment environment;
protected float discountFactor;
- @Setter
- protected float epsilon;
protected Set learningListeners;
@Setter
protected int delay;
+ private List<Double> rewardHistory;
- public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay){
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay){
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
- this.epsilon = epsilon;
this.delay = delay;
learningListeners = new HashSet<>();
+ rewardHistory = new CopyOnWriteArrayList<>();
}
- public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon){
- this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor){
+ this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
+ }
+
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay){
+ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
}
public Learning(Environment environment, DiscreteActionSpace actionSpace){
- this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
+ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
}
+
public abstract void learn(int nrOfEpisodes);
public void addListener(LearningListener learningListener){
learningListeners.add(learningListener);
}
- protected void dispatchEpisodeEnd(double sum){
+ protected void dispatchEpisodeEnd(double recentSumOfRewards){
+ rewardHistory.add(recentSumOfRewards);
for(LearningListener l: learningListeners) {
- l.onEpisodeEnd(sum);
+ l.onEpisodeEnd(rewardHistory);
}
}
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
index d608b80..1bc1f11 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -4,6 +4,8 @@ import core.*;
import core.algo.Learning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
+import lombok.Setter;
+
import java.util.*;
/**
@@ -26,13 +28,18 @@ import java.util.*;
*/
public class MonteCarloOnPolicyEGreedy extends Learning {
- public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace) {
- super(environment, actionSpace);
- discountFactor = 1f;
- this.policy = new EpsilonGreedyPolicy<>(0.1f);
- this.stateActionTable = new StateActionHashTable<>(actionSpace);
+ public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
+ super(environment, actionSpace, discountFactor, delay);
+
+ this.policy = new EpsilonGreedyPolicy<>(epsilon);
+ this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
}
+ public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace, int delay) {
+ this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
+ }
+
+
@Override
public void learn(int nrOfEpisodes) {
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index c65e62d..80b4375 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -7,9 +7,9 @@ import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import core.gui.View;
+import core.policy.EpsilonPolicy;
import javax.swing.*;
-import java.util.Optional;
public class RLController implements ViewListener{
protected Environment environment;
@@ -30,28 +30,38 @@ public class RLController implements ViewListener{
switch (method){
case MC_ONPOLICY_EGREEDY:
- learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
+ learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay);
break;
case TD_ONPOLICY:
break;
default:
throw new RuntimeException("Undefined method");
}
- SwingUtilities.invokeLater(() ->{
- view = new View<>(learning, this);
- learning.addListener(view);
- });
+ /*
+ not using SwingUtilities here on purpose to ensure the view is fully
+ initialized and can be passed as LearningListener.
+ */
+ view = new View<>(learning, this);
+ learning.addListener(view);
learning.learn(nrOfEpisodes);
}
-
+
@Override
public void onEpsilonChange(float epsilon) {
- learning.setEpsilon(epsilon);
- SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+ if(learning.getPolicy() instanceof EpsilonPolicy){
+ ((EpsilonPolicy) learning.getPolicy()).setEpsilon(epsilon);
+ SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+ }else{
+ System.out.println("Trying to call onEpsilonChange on non-epsilon policy");
+ }
}
@Override
public void onDelayChange(int delay) {
+ learning.setDelay(delay);
+ SwingUtilities.invokeLater(() -> {
+ view.updateLearningInfoPanel();
+ });
}
public RLController setMethod(Method method){
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
index cbbd6ef..8c67589 100644
--- a/src/main/java/core/gui/LearningInfoPanel.java
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -2,6 +2,7 @@ package core.gui;
import core.algo.Learning;
import core.controller.ViewListener;
+import core.policy.EpsilonPolicy;
import javax.swing.*;
@@ -11,6 +12,7 @@ public class LearningInfoPanel extends JPanel {
private JLabel discountLabel;
private JLabel epsilonLabel;
private JSlider epsilonSlider;
+ private JLabel delayLabel;
private JSlider delaySlider;
public LearningInfoPanel(Learning learning, ViewListener viewListener){
@@ -19,12 +21,19 @@ public class LearningInfoPanel extends JPanel {
policyLabel = new JLabel();
discountLabel = new JLabel();
epsilonLabel = new JLabel();
- epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
- epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+ delayLabel = new JLabel();
+ delaySlider = new JSlider(0,1000, learning.getDelay());
+ delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel);
add(discountLabel);
- add(epsilonLabel);
- add(epsilonSlider);
+ if(learning.getPolicy() instanceof EpsilonPolicy){
+ epsilonSlider = new JSlider(0, 100, (int) (((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
+ epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+ add(epsilonLabel);
+ add(epsilonSlider);
+ }
+ add(delayLabel);
+ add(delaySlider);
refreshLabels();
setVisible(true);
}
@@ -32,10 +41,9 @@ public class LearningInfoPanel extends JPanel {
public void refreshLabels(){
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
- epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
- }
-
- protected JSlider getEpsilonSlider(){
- return epsilonSlider;
+ if(learning.getPolicy() instanceof EpsilonPolicy){
+ epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon());
+ }
+ delayLabel.setText("Delay (ms): " + learning.getDelay());
}
}
diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java
index 8939ac2..d031451 100644
--- a/src/main/java/core/gui/View.java
+++ b/src/main/java/core/gui/View.java
@@ -23,12 +23,10 @@ public class View implements LearningListener {
private JFrame mainFrame;
private XChartPanel rewardChartPanel;
private ViewListener viewListener;
- private List rewardHistory;
public View(Learning learning, ViewListener viewListener){
this.learning = learning;
this.viewListener = viewListener;
- rewardHistory = new ArrayList<>();
this.initMainFrame();
}
@@ -78,8 +76,7 @@ public class View implements LearningListener {
};
}
- public void updateRewardGraph(double recentReward){
- rewardHistory.add(recentReward);
+ public void updateRewardGraph(List rewardHistory){
chart.updateXYSeries("randomWalk", null, rewardHistory, null);
rewardChartPanel.revalidate();
rewardChartPanel.repaint();
@@ -89,10 +86,11 @@ public class View implements LearningListener {
this.learningInfoPanel.refreshLabels();
}
-
@Override
- public void onEpisodeEnd(double sumOfRewards) {
- SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
+ public void onEpisodeEnd(List rewardHistory) {
+ SwingUtilities.invokeLater(()->{
+ updateRewardGraph(rewardHistory);
+ });
}
@Override
diff --git a/src/main/java/core/listener/LearningListener.java b/src/main/java/core/listener/LearningListener.java
index 5a9d287..4147897 100644
--- a/src/main/java/core/listener/LearningListener.java
+++ b/src/main/java/core/listener/LearningListener.java
@@ -1,6 +1,8 @@
package core.listener;
+import java.util.List;
+
public interface LearningListener{
- void onEpisodeEnd(double sumOfRewards);
+ void onEpisodeEnd(List<Double> rewardHistory);
void onEpisodeStart();
}
diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java
index 0e8d448..1288aed 100644
--- a/src/main/java/core/policy/EpsilonGreedyPolicy.java
+++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java
@@ -1,6 +1,8 @@
package core.policy;
import core.RNG;
+import lombok.Getter;
+import lombok.Setter;
import java.util.Map;
@@ -12,7 +14,9 @@ import java.util.Map;
*
* @param Discrete Action Enum
*/
-public class EpsilonGreedyPolicy implements Policy{
+public class EpsilonGreedyPolicy implements EpsilonPolicy{
+ @Setter
+ @Getter
private float epsilon;
private RandomPolicy randomPolicy;
private GreedyPolicy greedyPolicy;
@@ -22,8 +26,10 @@ public class EpsilonGreedyPolicy implements Policy{
randomPolicy = new RandomPolicy<>();
greedyPolicy = new GreedyPolicy<>();
}
+
@Override
public A chooseAction(Map actionValues) {
+ System.out.println("current epsilon " + epsilon);
if(RNG.getRandom().nextFloat() < epsilon){
// Take random action
return randomPolicy.chooseAction(actionValues);
diff --git a/src/main/java/core/policy/EpsilonPolicy.java b/src/main/java/core/policy/EpsilonPolicy.java
new file mode 100644
index 0000000..76bff45
--- /dev/null
+++ b/src/main/java/core/policy/EpsilonPolicy.java
@@ -0,0 +1,6 @@
+package core.policy;
+
+public interface EpsilonPolicy extends Policy {
+ float getEpsilon();
+ void setEpsilon(float epsilon);
+}
diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java
index a727db3..6ff7739 100644
--- a/src/main/java/core/policy/GreedyPolicy.java
+++ b/src/main/java/core/policy/GreedyPolicy.java
@@ -1,7 +1,5 @@
package core.policy;
-import core.RNG;
-
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -13,7 +11,7 @@ public class GreedyPolicy implements Policy {
public A chooseAction(Map actionValues) {
if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");
- Double highestValueAction = null;
+ Double highestValueAction = null;
List equalHigh = new ArrayList<>();
diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java
index 1f8f086..094b41c 100644
--- a/src/main/java/core/policy/RandomPolicy.java
+++ b/src/main/java/core/policy/RandomPolicy.java
@@ -7,6 +7,7 @@ public class RandomPolicy implements Policy{
@Override
public A chooseAction(Map actionValues) {
int idx = RNG.getRandom().nextInt(actionValues.size());
+ System.out.println("selected action " + idx);
int i = 0;
for(A action : actionValues.keySet()){
if(i++ == idx) return action;
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index 8c5bda7..ee8d347 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -20,13 +20,13 @@ public class AntState implements State, Visualizable {
private final int computedHash;
public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
- this.knownWorld = deepCopyCellGrid(knownWorld);
+ this.knownWorld = Util.deepCopyCellGrid(knownWorld);
this.pos = deepCopyAntPosition(antPosition);
this.hasFood = hasFood;
computedHash = computeHash();
}
- private int computeHash(){
+ private int computeHash() {
int hash = 7;
int prime = 31;
@@ -43,20 +43,10 @@ public class AntState implements State, Visualizable {
}
hash = prime * hash + unknown;
hash = prime * hash * diff;
- hash = prime * hash + (hasFood ? 1:0);
+ hash = prime * hash + (hasFood ? 1 : 0);
hash = prime * hash + pos.hashCode();
return hash;
}
- private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
- Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
- for (int i = 0; i < cells.length; i++) {
- for (int j = 0; j < cells[i].length; j++) {
- // calling copy constructor of Cell class
- cells[i][j] = new Cell(toCopy[i][j]);
- }
- }
- return cells;
- }
private Point deepCopyAntPosition(Point toCopy){
return new Point(toCopy.x,toCopy.y);
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index d68c597..e512ad7 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -151,9 +151,12 @@ public class AntWorld implements Environment{
done = grid.isAllFoodCollected();
}
+
+ /*
if(!done){
reward = -1;
}
+ */
if(++tick == maxEpisodeTicks){
done = true;
}
@@ -172,8 +175,7 @@ public class AntWorld implements Environment{
}
public State reset() {
- RNG.reseed();
- grid.initRandomWorld();
+ grid.resetWorld();
antAgent.initUnknownWorld();
tick = 0;
myAnt.getPos().setLocation(grid.getStartPoint());
@@ -207,7 +209,8 @@ public class AntWorld implements Environment{
Learning monteCarlo = new MonteCarloOnPolicyEGreedy<>(
new AntWorld(3, 3, 0.1),
- new ListDiscreteActionSpace<>(AntAction.values())
+ new ListDiscreteActionSpace<>(AntAction.values()),
+ 5
);
monteCarlo.learn(20000);
}
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index 618f8ab..dced49a 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -10,31 +10,37 @@ public class Grid {
private double foodDensity;
private Point start;
private Cell[][] grid;
+ private Cell[][] initialGrid;
public Grid(int width, int height, double foodDensity){
this.width = width;
this.height = height;
this.foodDensity = foodDensity;
-
grid = new Cell[width][height];
+ initialGrid = new Cell[width][height];
+ initRandomWorld();
}
public Grid(int width, int height){
this(width, height, 0);
}
+ public void resetWorld(){
+ grid = Util.deepCopyCellGrid(initialGrid);
+ }
+
public void initRandomWorld(){
for(int x = 0; x < width; ++x){
for(int y = 0; y < height; ++y){
if( RNG.getRandom().nextDouble() < foodDensity){
- grid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
+ initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
}else{
- grid[x][y] = new Cell(new Point(x,y), CellType.FREE);
+ initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE);
}
}
}
start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
- grid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
+ initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
}
public Point getStartPoint(){
diff --git a/src/main/java/evironment/antGame/Reward.java b/src/main/java/evironment/antGame/Reward.java
index 9a6926f..62f294a 100644
--- a/src/main/java/evironment/antGame/Reward.java
+++ b/src/main/java/evironment/antGame/Reward.java
@@ -1,17 +1,16 @@
package evironment.antGame;
public class Reward {
- public static final double FOOD_PICK_UP_SUCCESS = 0;
- public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
- public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
+ public static final double FOOD_PICK_UP_SUCCESS = 1;
+ public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
+ public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;
- public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
- public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
+ public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1;
+ public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1;
public static final double FOOD_DROP_DOWN_SUCCESS = 1;
- public static final double UNKNOWN_FIELD_EXPLORED = 0;
-
- public static final double RAN_INTO_WALL = 0;
- public static final double RAN_INTO_OBSTACLE = 0;
+ public static final double UNKNOWN_FIELD_EXPLORED = 1;
+ public static final double RAN_INTO_WALL = -1;
+ public static final double RAN_INTO_OBSTACLE = -1;
}
diff --git a/src/main/java/evironment/antGame/Util.java b/src/main/java/evironment/antGame/Util.java
new file mode 100644
index 0000000..504b460
--- /dev/null
+++ b/src/main/java/evironment/antGame/Util.java
@@ -0,0 +1,14 @@
+package evironment.antGame;
+
+public class Util {
+ public static Cell[][] deepCopyCellGrid(Cell[][] toCopy){
+ Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
+ for (int i = 0; i < cells.length; i++) {
+ for (int j = 0; j < cells[i].length; j++) {
+ // calling copy constructor of Cell class
+ cells[i][j] = new Cell(toCopy[i][j]);
+ }
+ }
+ return cells;
+ }
+}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index 19311d0..dc22cc0 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -15,7 +15,7 @@ public class RunningAnt {
.setAllowedActions(AntAction.values())
.setMethod(Method.MC_ONPOLICY_EGREEDY)
.setDelay(10)
- .setEpisodes(1000);
+ .setEpisodes(10000);
rl.start();
}
}