diff --git a/src/main/java/core/Action.java b/src/main/java/core/Action.java
new file mode 100644
index 0000000..f1630e1
--- /dev/null
+++ b/src/main/java/core/Action.java
@@ -0,0 +1,4 @@
+package core;
+
+public interface Action {
+}
diff --git a/src/main/java/core/Environment.java b/src/main/java/core/Environment.java
index a93d1bb..e3fc91f 100644
--- a/src/main/java/core/Environment.java
+++ b/src/main/java/core/Environment.java
@@ -1,5 +1,7 @@
package core;
public interface Environment<A extends Action> {
- StepResult step(A action);
+ StepResultEnvironment step(A action);
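+ /** Resets the environment to a fresh start state and returns that initial state. */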
+ State reset();
}
diff --git a/src/main/java/core/StateActionHashTable.java b/src/main/java/core/StateActionHashTable.java
index 66b6ae0..dd981d0 100644
--- a/src/main/java/core/StateActionHashTable.java
+++ b/src/main/java/core/StateActionHashTable.java
@@ -55,11 +55,10 @@ public class StateActionHashTable<A extends Action> implements StateActionTable<A>
@Override
public Map<A, Double> getActionValues(State state) {
- Map<A, Double> actionValues = table.get(state);
- if(actionValues == null){
- actionValues = createDefaultActionValues();
+ if(table.get(state) == null){
+ table.put(state, createDefaultActionValues());
}
- return actionValues;
+ return table.get(state);
}
public static void main(String[] args) {
diff --git a/src/main/java/core/StepResult.java b/src/main/java/core/StepResult.java
index 4ed1586..7de2756 100644
--- a/src/main/java/core/StepResult.java
+++ b/src/main/java/core/StepResult.java
@@ -2,14 +2,11 @@ package core;
import lombok.AllArgsConstructor;
import lombok.Getter;
-import lombok.Setter;
-@Getter
-@Setter
@AllArgsConstructor
-public class StepResult {
+@Getter
+public class StepResult<A extends Action> {
private State state;
+ private A action;
private double reward;
- private boolean done;
- private String info;
}
diff --git a/src/main/java/core/StepResultEnvironment.java b/src/main/java/core/StepResultEnvironment.java
new file mode 100644
index 0000000..b1d1c06
--- /dev/null
+++ b/src/main/java/core/StepResultEnvironment.java
@@ -0,0 +1,20 @@
+package core;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
+
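+/**
+ * Raw feedback from a single environment step: successor state, reward,
+ * done flag and a human-readable info string. Unlike {@link StepResult},
+ * it does not record the action that produced it.
+ */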
+@Getter
+@Setter
+@AllArgsConstructor
+public class StepResultEnvironment {
+ private State state;
+ private double reward;
+ private boolean done;
+ private String info;
+}
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
new file mode 100644
index 0000000..8d71a78
--- /dev/null
+++ b/src/main/java/core/algo/Learning.java
@@ -0,0 +1,34 @@
+package core.algo;
+
+import core.Action;
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.StateActionTable;
+import core.policy.Policy;
+
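+/**
+ * Base class for tabular learning algorithms. Bundles the environment,
+ * the discrete action space, the action-selection policy and the
+ * state-action value table; discountFactor weighs future rewards and
+ * epsilon is the exploration rate.
+ */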
+public abstract class Learning<A extends Action> {
+ protected Policy<A> policy;
+ protected DiscreteActionSpace<A> actionSpace;
+ protected StateActionTable<A> stateActionTable;
+ protected Environment<A> environment;
+ protected float discountFactor;
+ protected float epsilon;
+
+ public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
+ this.environment = environment;
+ this.actionSpace = actionSpace;
+ this.discountFactor = discountFactor;
+ this.epsilon = epsilon;
+ }
+ public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
+ this(environment, actionSpace, 1.0f, 0.1f);
+ }
+
+ public abstract void learn(int nrOfEpisodes, int delay);
+}
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
new file mode 100644
index 0000000..55c70ed
--- /dev/null
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -0,0 +1,81 @@
+package core.algo.MC;
+
+import core.*;
+import core.algo.Learning;
+import core.policy.EpsilonGreedyPolicy;
+import javafx.util.Pair;
+
+import java.util.*;
+
+public class MonteCarloOnPolicyEGreedy<A extends Action> extends Learning<A> {
+
+ public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
+ super(environment, actionSpace);
+ discountFactor = 1f;
+ this.policy = new EpsilonGreedyPolicy<>(0.1f);
+ this.stateActionTable = new StateActionHashTable<>(actionSpace);
+ }
+
+ @Override
+ public void learn(int nrOfEpisodes, int delay) {
+
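+ // First-visit Monte Carlo control: sample an episode with the current
+ // epsilon-greedy policy, then move Q(s,a) toward the average discounted
+ // return observed after the first visit of each (state, action) pair.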
+ Map<Pair<State, A>, Double> returnSum = new HashMap<>();
+ Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
+
+ for(int i = 0; i < nrOfEpisodes; ++i) {
+
+ List<StepResult<A>> episode = new ArrayList<>();
+ State state = environment.reset();
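+ // roll out one episode, capped at 100 steps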
+ for(int j=0; j < 100; ++j){
+ Map<A, Double> actionValues = stateActionTable.getActionValues(state);
+ A chosenAction = policy.chooseAction(actionValues);
+ StepResultEnvironment envResult = environment.step(chosenAction);
+ State nextState = envResult.getState();
+ episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
+
+ if(envResult.isDone()) break;
+
+ state = nextState;
+
+ try {
+ Thread.sleep(delay);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+ Set<Pair<State, A>> stateActionPairs = new HashSet<>();
+
+ for(StepResult<A> sr: episode){
+ stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
+ }
+
+ for(Pair<State, A> stateActionPair: stateActionPairs){
+ int firstOccurenceIndex = 0;
+ // find the first occurrence of the state-action pair
+ for(StepResult<A> sr: episode){
+ if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){
+ break;
+ }
+ firstOccurenceIndex++;
+ }
+
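+ // discounted return from the first visit onward:
+ // G = r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ...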
+ double G = 0;
+ for(int l = firstOccurenceIndex; l < episode.size(); ++l){
+ G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
+ }
+ // Map#merge adds G onto the existing sum, or inserts G
+ // as the initial value if the key is absent
+ returnSum.merge(stateActionPair, G, Double::sum);
+ returnCount.merge(stateActionPair, 1, Integer::sum);
+ stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
+ }
+ }
+ }
+}
diff --git a/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java
new file mode 100644
index 0000000..33716f8
--- /dev/null
+++ b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java
@@ -0,0 +1,4 @@
+package core.algo.TD;
+
+public class TemporalDifferenceOnPolicy {
+}
diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java
new file mode 100644
index 0000000..0e8d448
--- /dev/null
+++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java
@@ -0,0 +1,35 @@
+package core.policy;
+
+import core.RNG;
+
+import java.util.Map;
+
+/**
+ * To prevent the agent from getting stuck always using the "best" action
+ * according to the current learning history, this policy
+ * takes a random action with probability epsilon.
+ * (The random choice may pick the best action as well.)
+ *
+ * @param <A> Discrete Action Enum
+ */
+public class EpsilonGreedyPolicy<A> implements Policy<A> {
+ private float epsilon;
+ private RandomPolicy<A> randomPolicy;
+ private GreedyPolicy<A> greedyPolicy;
+
+ public EpsilonGreedyPolicy(float epsilon){
+ this.epsilon = epsilon;
+ randomPolicy = new RandomPolicy<>();
+ greedyPolicy = new GreedyPolicy<>();
+ }
+ @Override
+ public A chooseAction(Map<A, Double> actionValues) {
+ if(RNG.getRandom().nextFloat() < epsilon){
+ // Take random action
+ return randomPolicy.chooseAction(actionValues);
+ }else{
+ // Take the action with the highest value
+ return greedyPolicy.chooseAction(actionValues);
+ }
+ }
+}
diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java
new file mode 100644
index 0000000..15b06f7
--- /dev/null
+++ b/src/main/java/core/policy/GreedyPolicy.java
@@ -0,0 +1,36 @@
+package core.policy;
+
+import core.RNG;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
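+/**
+ * Picks the action with the highest estimated value,
+ * breaking ties uniformly at random.
+ */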
+public class GreedyPolicy<A> implements Policy<A> {
+
+ @Override
+ public A chooseAction(Map<A, Double> actionValues) {
+ if(actionValues.size() == 0) throw new RuntimeException("Empty actionValues map");
+
+ Double highestValueAction = null;
+
+ List<A> equalHigh = new ArrayList<>();
+
+ for(Map.Entry<A, Double> actionValue : actionValues.entrySet()){
+ System.out.println(actionValue.getKey()+ " " + actionValue.getValue() );
+ if(highestValueAction == null || highestValueAction < actionValue.getValue()){
+ highestValueAction = actionValue.getValue();
+ equalHigh.clear();
+ equalHigh.add(actionValue.getKey());
+ }else if(highestValueAction.equals(actionValue.getValue())){
+ equalHigh.add(actionValue.getKey());
+ }
+ }
+
+ return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
+ }
+}
diff --git a/src/main/java/core/policy/Policy.java b/src/main/java/core/policy/Policy.java
new file mode 100644
index 0000000..fcb9d04
--- /dev/null
+++ b/src/main/java/core/policy/Policy.java
@@ -0,0 +1,7 @@
+package core.policy;
+
+import java.util.Map;
+
+public interface Policy<A> {
+ A chooseAction(Map<A, Double> actionValues);
+}
diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java
new file mode 100644
index 0000000..1f8f086
--- /dev/null
+++ b/src/main/java/core/policy/RandomPolicy.java
@@ -0,0 +1,20 @@
+package core.policy;
+
+import core.RNG;
+import java.util.Map;
+
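+/**
+ * Chooses uniformly at random among the available actions.
+ */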
+public class RandomPolicy<A> implements Policy<A> {
+ @Override
+ public A chooseAction(Map<A, Double> actionValues) {
+ int idx = RNG.getRandom().nextInt(actionValues.size());
+ int i = 0;
+ for(A action : actionValues.keySet()){
+ if(i++ == idx) return action;
+ }
+
+ return null;
+ }
+}
diff --git a/src/main/java/evironment/antGame/Ant.java b/src/main/java/evironment/antGame/Ant.java
index 40beb2c..4bd4599 100644
--- a/src/main/java/evironment/antGame/Ant.java
+++ b/src/main/java/evironment/antGame/Ant.java
@@ -7,7 +7,6 @@ import lombok.Setter;
import java.awt.*;
-@AllArgsConstructor
@Getter
@Setter
public class Ant {
@@ -15,10 +14,14 @@ public class Ant {
@Setter(AccessLevel.NONE)
private Point pos;
private int points;
- private boolean spawned;
@Getter(AccessLevel.NONE)
private boolean hasFood;
+ public Ant(){
+ pos = new Point();
+ points = 0;
+ hasFood = false;
+ }
public boolean hasFood(){
return hasFood;
}
diff --git a/src/main/java/evironment/antGame/AntAgent.java b/src/main/java/evironment/antGame/AntAgent.java
index 4e4d28b..055ff87 100644
--- a/src/main/java/evironment/antGame/AntAgent.java
+++ b/src/main/java/evironment/antGame/AntAgent.java
@@ -9,7 +9,6 @@ public class AntAgent {
public AntAgent(int width, int height){
knownWorld = new Cell[width][height];
- initUnknownWorld();
}
/**
@@ -24,7 +23,7 @@ public class AntAgent {
return new AntState(knownWorld, observation.getPos(), observation.hasFood());
}
- private void initUnknownWorld(){
+ public void initUnknownWorld(){
for(int x = 0; x < knownWorld.length; ++x){
for(int y = 0; y < knownWorld[x].length; ++y){
knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN);
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index 6d4a77b..d1a31b7 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -11,17 +11,44 @@
* and therefore has to be deep copied
*/
public class AntState implements State {
- private Cell[][] knownWorld;
- private Point pos;
- private boolean hasFood;
-
+ private final Cell[][] knownWorld;
+ private final Point pos;
+ private final boolean hasFood;
+ private final int computedHash;
public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
this.knownWorld = deepCopyCellGrid(knownWorld);
this.pos = deepCopyAntPosition(antPosition);
this.hasFood = hasFood;
+ computedHash = computeHash();
}
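+ /**
+ * Hashes a coarse summary of the state (number of unknown vs. explored
+ * cells, food flag, ant position) instead of the full grid contents.
+ * Cheaper than hashing every cell, at the cost of more collisions.
+ */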
+ private int computeHash(){
+ int hash = 7;
+ int prime = 31;
+
+ int unknown = 0;
+ int diff = 0;
+ for (int i = 0; i < knownWorld.length; i++) {
+ for (int j = 0; j < knownWorld[i].length; j++) {
+ if(knownWorld[i][j].getType() == CellType.UNKNOWN){
+ unknown += 1;
+ }else{
+ diff +=1;
+ }
+ }
+ }
+ hash = prime * hash + unknown;
+ hash = prime * hash + diff;
+ hash = prime * hash + (hasFood ? 1:0);
+ hash = prime * hash + pos.hashCode();
+ return hash;
+ }
private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
for (int i = 0; i < cells.length; i++) {
@@ -45,12 +67,7 @@ public class AntState implements State {
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers
@Override
public int hashCode(){
- int hash = 7;
- int prime = 31;
- hash = prime * hash + Arrays.hashCode(knownWorld);
- hash = prime * hash + (hasFood ? 1:0);
- hash = prime * hash + pos.hashCode();
- return hash;
+ return computedHash;
}
@Override
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index ef305e0..ff48eb4 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -1,10 +1,11 @@
package evironment.antGame;
import core.*;
+import core.algo.Learning;
+import core.algo.MC.MonteCarloOnPolicyEGreedy;
import evironment.antGame.gui.MainFrame;
-import javax.swing.*;
import java.awt.*;
public class AntWorld implements Environment<AntAction>{
@@ -39,9 +40,8 @@ public class AntWorld implements Environment<AntAction>{
public AntWorld(int width, int height, double foodDensity){
grid = new Grid(width, height, foodDensity);
antAgent = new AntAgent(width, height);
- myAnt = new Ant(new Point(-1,-1), 0, false, false);
+ myAnt = new Ant();
gui = new MainFrame(this, antAgent);
- tick = 0;
maxEpisodeTicks = 1000;
reset();
}
@@ -55,23 +55,13 @@ public class AntWorld implements Environment{
}
@Override
- public StepResult step(AntAction action){
+ public StepResultEnvironment step(AntAction action){
AntObservation observation;
State newState;
double reward = 0;
String info = "";
boolean done = false;
- if(!myAnt.isSpawned()){
- myAnt.setSpawned(true);
- myAnt.getPos().setLocation(grid.getStartPoint());
- observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
- newState = antAgent.feedObservation(observation);
- reward = 0.0;
- ++tick;
- return new StepResult(newState, reward, false, "Just spawned on the map");
- }
-
Cell currentCell = grid.getCell(myAnt.getPos());
Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y);
boolean stayOnCell = true;
@@ -107,7 +97,7 @@ public class AntWorld implements Environment{
// Ant successfully picks up food
currentCell.setFood(currentCell.getFood() - 1);
myAnt.setHasFood(true);
- reward = Reward.FOOD_DROP_DOWN_SUCCESS;
+ reward = Reward.FOOD_PICK_UP_SUCCESS;
}
break;
case DROP_DOWN:
@@ -169,24 +159,30 @@ public class AntWorld implements Environment{
if(++tick == maxEpisodeTicks){
done = true;
}
- return new StepResult(newState, reward, done, info);
+
+ StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
+ getGui().update(action, result);
+ return result;
}
private boolean isInGrid(Point pos){
- return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight();
+ return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight();
}
private boolean hitObstacle(Point pos){
return grid.getCell(pos).getType() == CellType.OBSTACLE;
}
- public void reset() {
+ public State reset() {
RNG.reseed();
grid.initRandomWorld();
- myAnt.getPos().setLocation(-1,-1);
+ antAgent.initUnknownWorld();
+ tick = 0;
+ myAnt.getPos().setLocation(grid.getStartPoint());
myAnt.setPoints(0);
myAnt.setHasFood(false);
- myAnt.setSpawned(false);
+ AntObservation observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
+ return antAgent.feedObservation(observation);
}
public void setMaxEpisodeLength(int maxTicks){
@@ -207,21 +203,14 @@ public class AntWorld implements Environment{
public Ant getAnt(){
return myAnt;
}
+
public static void main(String[] args) {
RNG.setSeed(1993);
- AntWorld a = new AntWorld(10, 10, 0.1);
- ListDiscreteActionSpace<AntAction> actionSpace =
- new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
- for(int i = 0; i< 1000; ++i){
- AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
- StepResult step = a.step(selectedAction);
- SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
- try {
- Thread.sleep(100);
- } catch (InterruptedException e) {
- e.printStackTrace();
- }
- }
+ Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
+ new AntWorld(3, 3, 0.1),
+ new ListDiscreteActionSpace<>(AntAction.values())
+ );
+ monteCarlo.learn(100,5);
}
}
diff --git a/src/main/java/evironment/antGame/Reward.java b/src/main/java/evironment/antGame/Reward.java
index b1fee95..855c6ff 100644
--- a/src/main/java/evironment/antGame/Reward.java
+++ b/src/main/java/evironment/antGame/Reward.java
@@ -1,17 +1,17 @@
package evironment.antGame;
public class Reward {
- public static final double FOOD_PICK_UP_SUCCESS = 1;
- public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000;
- public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000;
+ public static final double FOOD_PICK_UP_SUCCESS = 0;
+ public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
+ public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
- public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000;
- public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000;
+ public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
+ public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
public static final double FOOD_DROP_DOWN_SUCCESS = 1000;
- public static final double UNKNOWN_FIELD_EXPLORED = 1;
+ public static final double UNKNOWN_FIELD_EXPLORED = 0;
- public static final double RAN_INTO_WALL = -100;
- public static final double RAN_INTO_OBSTACLE = -100;
+ public static final double RAN_INTO_WALL = 0;
+ public static final double RAN_INTO_OBSTACLE = 0;
}
diff --git a/src/main/java/evironment/antGame/gui/MainFrame.java b/src/main/java/evironment/antGame/gui/MainFrame.java
index 079bfa8..c299d78 100644
--- a/src/main/java/evironment/antGame/gui/MainFrame.java
+++ b/src/main/java/evironment/antGame/gui/MainFrame.java
@@ -1,6 +1,6 @@
package evironment.antGame.gui;
-import core.StepResult;
+import core.StepResultEnvironment;
import evironment.antGame.AntAction;
import evironment.antGame.AntAgent;
import evironment.antGame.AntWorld;
@@ -33,9 +33,9 @@ public class MainFrame extends JFrame {
setVisible(true);
}
- public void update(AntAction lastAction, StepResult stepResult){
+ public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
- antWorld.getTick(), lastAction.toString(), stepResult.getReward(), stepResult.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
+ antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
repaint();
}