From 55d8bbf5dc9bd01d5368f23a23c3bfb81bcda5c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?= Date: Mon, 9 Dec 2019 23:21:48 +0100 Subject: [PATCH] add Random-, Greedy- and EGreedy-Policy and first implementation of Monte Carlo method - fixed bug regarding wrong generation of hashCode. hashCodes need to be equal across equal objects. Compute hashCode on final states once and return this value instead of computing it every time .hashCode() gets called. --- src/main/java/core/Action.java | 4 + src/main/java/core/Environment.java | 3 +- src/main/java/core/StateActionHashTable.java | 7 +- src/main/java/core/StepResult.java | 9 +-- src/main/java/core/StepResultEnvironment.java | 15 ++++ src/main/java/core/algo/Learning.java | 27 +++++++ .../algo/MC/MonteCarloOnPolicyEGreedy.java | 75 +++++++++++++++++++ .../algo/TD/TemporalDifferenceOnPolicy.java | 4 + .../java/core/policy/EpsilonGreedyPolicy.java | 35 +++++++++ src/main/java/core/policy/GreedyPolicy.java | 32 ++++++++ src/main/java/core/policy/Policy.java | 7 ++ src/main/java/core/policy/RandomPolicy.java | 17 +++++ src/main/java/evironment/antGame/Ant.java | 7 +- .../java/evironment/antGame/AntAgent.java | 3 +- .../java/evironment/antGame/AntState.java | 37 ++++++--- .../java/evironment/antGame/AntWorld.java | 55 ++++++-------- src/main/java/evironment/antGame/Reward.java | 16 ++-- .../evironment/antGame/gui/MainFrame.java | 6 +- 18 files changed, 290 insertions(+), 69 deletions(-) create mode 100644 src/main/java/core/Action.java create mode 100644 src/main/java/core/StepResultEnvironment.java create mode 100644 src/main/java/core/algo/Learning.java create mode 100644 src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java create mode 100644 src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java create mode 100644 src/main/java/core/policy/EpsilonGreedyPolicy.java create mode 100644 src/main/java/core/policy/GreedyPolicy.java create mode 100644 src/main/java/core/policy/Policy.java create 
mode 100644 src/main/java/core/policy/RandomPolicy.java diff --git a/src/main/java/core/Action.java b/src/main/java/core/Action.java new file mode 100644 index 0000000..f1630e1 --- /dev/null +++ b/src/main/java/core/Action.java @@ -0,0 +1,4 @@ +package core; + +public interface Action { +} diff --git a/src/main/java/core/Environment.java b/src/main/java/core/Environment.java index a93d1bb..e3fc91f 100644 --- a/src/main/java/core/Environment.java +++ b/src/main/java/core/Environment.java @@ -1,5 +1,6 @@ package core; public interface Environment { - StepResult step(A action); + StepResultEnvironment step(A action); + State reset(); } diff --git a/src/main/java/core/StateActionHashTable.java b/src/main/java/core/StateActionHashTable.java index 66b6ae0..dd981d0 100644 --- a/src/main/java/core/StateActionHashTable.java +++ b/src/main/java/core/StateActionHashTable.java @@ -55,11 +55,10 @@ public class StateActionHashTable implements StateActionTable @Override public Map getActionValues(State state) { - Map actionValues = table.get(state); - if(actionValues == null){ - actionValues = createDefaultActionValues(); + if(table.get(state) == null){ + table.put(state, createDefaultActionValues()); } - return actionValues; + return table.get(state); } public static void main(String[] args) { diff --git a/src/main/java/core/StepResult.java b/src/main/java/core/StepResult.java index 4ed1586..7de2756 100644 --- a/src/main/java/core/StepResult.java +++ b/src/main/java/core/StepResult.java @@ -2,14 +2,11 @@ package core; import lombok.AllArgsConstructor; import lombok.Getter; -import lombok.Setter; -@Getter -@Setter @AllArgsConstructor -public class StepResult { +@Getter +public class StepResult { private State state; + private A action; private double reward; - private boolean done; - private String info; } diff --git a/src/main/java/core/StepResultEnvironment.java b/src/main/java/core/StepResultEnvironment.java new file mode 100644 index 0000000..b1d1c06 --- /dev/null +++ 
b/src/main/java/core/StepResultEnvironment.java @@ -0,0 +1,15 @@ +package core; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@AllArgsConstructor +public class StepResultEnvironment { + private State state; + private double reward; + private boolean done; + private String info; +} diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java new file mode 100644 index 0000000..8d71a78 --- /dev/null +++ b/src/main/java/core/algo/Learning.java @@ -0,0 +1,27 @@ +package core.algo; + +import core.DiscreteActionSpace; +import core.Environment; +import core.StateActionTable; +import core.policy.Policy; + +public abstract class Learning { + protected Policy policy; + protected DiscreteActionSpace actionSpace; + protected StateActionTable stateActionTable; + protected Environment environment; + protected float discountFactor; + protected float epsilon; + + public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon){ + this.environment = environment; + this.actionSpace = actionSpace; + this.discountFactor = discountFactor; + this.epsilon = epsilon; + } + public Learning(Environment environment, DiscreteActionSpace actionSpace){ + this(environment, actionSpace, 1.0f, 0.1f); + } + + public abstract void learn(int nrOfEpisodes, int delay); +} diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java new file mode 100644 index 0000000..55c70ed --- /dev/null +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -0,0 +1,75 @@ +package core.algo.MC; + +import core.*; +import core.algo.Learning; +import core.policy.EpsilonGreedyPolicy; +import javafx.util.Pair; + +import java.util.*; + +public class MonteCarloOnPolicyEGreedy extends Learning { + + public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace) { + super(environment, 
actionSpace); + discountFactor = 1f; + this.policy = new EpsilonGreedyPolicy<>(0.1f); + this.stateActionTable = new StateActionHashTable<>(actionSpace); + } + + @Override + public void learn(int nrOfEpisodes, int delay) { + + Map, Double> returnSum = new HashMap<>(); + Map, Integer> returnCount = new HashMap<>(); + + for(int i = 0; i < nrOfEpisodes; ++i) { + + List> episode = new ArrayList<>(); + State state = environment.reset(); + for(int j=0; j < 100; ++j){ + Map actionValues = stateActionTable.getActionValues(state); + A chosenAction = policy.chooseAction(actionValues); + StepResultEnvironment envResult = environment.step(chosenAction); + State nextState = envResult.getState(); + episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); + + if(envResult.isDone()) break; + + state = nextState; + + try { + Thread.sleep(10); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + Set> stateActionPairs = new HashSet<>(); + + for(StepResult sr: episode){ + stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); + } + + for(Pair stateActionPair: stateActionPairs){ + int firstOccurenceIndex = 0; + // find first occurance of state action pair + for(StepResult sr: episode){ + if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){ + break; + } + firstOccurenceIndex++; + } + + double G = 0; + for(int l = firstOccurenceIndex; l < episode.size(); ++l){ + G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex)); + } + // slick trick to add G to the entry. 
+ // if the key does not exists, it will create a new entry with G as default value + returnSum.merge(stateActionPair, G, Double::sum); + returnCount.merge(stateActionPair, 1, Integer::sum); + stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair)); + } + } + } +} diff --git a/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java new file mode 100644 index 0000000..33716f8 --- /dev/null +++ b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java @@ -0,0 +1,4 @@ +package core.algo.TD; + +public class TemporalDifferenceOnPolicy { +} diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java new file mode 100644 index 0000000..0e8d448 --- /dev/null +++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java @@ -0,0 +1,35 @@ +package core.policy; + +import core.RNG; + +import java.util.Map; + +/** + * To prevent the agent from getting stuck only using the "best" action + * according to the current learning history, this policy + * will take random action with the probability of epsilon. 
+ * (random action space includes the best action as well) + * + * @param Discrete Action Enum + */ +public class EpsilonGreedyPolicy implements Policy{ + private float epsilon; + private RandomPolicy randomPolicy; + private GreedyPolicy greedyPolicy; + + public EpsilonGreedyPolicy(float epsilon){ + this.epsilon = epsilon; + randomPolicy = new RandomPolicy<>(); + greedyPolicy = new GreedyPolicy<>(); + } + @Override + public A chooseAction(Map actionValues) { + if(RNG.getRandom().nextFloat() < epsilon){ + // Take random action + return randomPolicy.chooseAction(actionValues); + }else{ + // Take the action with the highest value + return greedyPolicy.chooseAction(actionValues); + } + } +} diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java new file mode 100644 index 0000000..15b06f7 --- /dev/null +++ b/src/main/java/core/policy/GreedyPolicy.java @@ -0,0 +1,32 @@ +package core.policy; + +import core.RNG; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class GreedyPolicy implements Policy { + + @Override + public A chooseAction(Map actionValues) { + if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set"); + + Double highestValueAction = null; + + List equalHigh = new ArrayList<>(); + + for(Map.Entry actionValue : actionValues.entrySet()){ + System.out.println(actionValue.getKey()+ " " + actionValue.getValue() ); + if(highestValueAction == null || highestValueAction < actionValue.getValue()){ + highestValueAction = actionValue.getValue(); + equalHigh.clear(); + equalHigh.add(actionValue.getKey()); + }else if(highestValueAction.equals(actionValue.getValue())){ + equalHigh.add(actionValue.getKey()); + } + } + + return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size())); + } +} diff --git a/src/main/java/core/policy/Policy.java b/src/main/java/core/policy/Policy.java new file mode 100644 index 0000000..fcb9d04 --- /dev/null +++ 
b/src/main/java/core/policy/Policy.java @@ -0,0 +1,7 @@ +package core.policy; + +import java.util.Map; + +public interface Policy { + A chooseAction(Map actionValues); +} diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java new file mode 100644 index 0000000..1f8f086 --- /dev/null +++ b/src/main/java/core/policy/RandomPolicy.java @@ -0,0 +1,17 @@ +package core.policy; + +import core.RNG; +import java.util.Map; + +public class RandomPolicy implements Policy{ + @Override + public A chooseAction(Map actionValues) { + int idx = RNG.getRandom().nextInt(actionValues.size()); + int i = 0; + for(A action : actionValues.keySet()){ + if(i++ == idx) return action; + } + + return null; + } +} diff --git a/src/main/java/evironment/antGame/Ant.java b/src/main/java/evironment/antGame/Ant.java index 40beb2c..4bd4599 100644 --- a/src/main/java/evironment/antGame/Ant.java +++ b/src/main/java/evironment/antGame/Ant.java @@ -7,7 +7,6 @@ import lombok.Setter; import java.awt.*; -@AllArgsConstructor @Getter @Setter public class Ant { @@ -15,10 +14,14 @@ public class Ant { @Setter(AccessLevel.NONE) private Point pos; private int points; - private boolean spawned; @Getter(AccessLevel.NONE) private boolean hasFood; + public Ant(){ + pos = new Point(); + points = 0; + hasFood = false; + } public boolean hasFood(){ return hasFood; } diff --git a/src/main/java/evironment/antGame/AntAgent.java b/src/main/java/evironment/antGame/AntAgent.java index 4e4d28b..055ff87 100644 --- a/src/main/java/evironment/antGame/AntAgent.java +++ b/src/main/java/evironment/antGame/AntAgent.java @@ -9,7 +9,6 @@ public class AntAgent { public AntAgent(int width, int height){ knownWorld = new Cell[width][height]; - initUnknownWorld(); } /** @@ -24,7 +23,7 @@ public class AntAgent { return new AntState(knownWorld, observation.getPos(), observation.hasFood()); } - private void initUnknownWorld(){ + public void initUnknownWorld(){ for(int x = 0; x < knownWorld.length; 
++x){ for(int y = 0; y < knownWorld[x].length; ++y){ knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN); diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java index 6d4a77b..d1a31b7 100644 --- a/src/main/java/evironment/antGame/AntState.java +++ b/src/main/java/evironment/antGame/AntState.java @@ -11,17 +11,39 @@ import java.util.Arrays; * and therefor has to be deep copied */ public class AntState implements State { - private Cell[][] knownWorld; - private Point pos; - private boolean hasFood; - + private final Cell[][] knownWorld; + private final Point pos; + private final boolean hasFood; + private final int computedHash; public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){ this.knownWorld = deepCopyCellGrid(knownWorld); this.pos = deepCopyAntPosition(antPosition); this.hasFood = hasFood; + computedHash = computeHash(); } + private int computeHash(){ + int hash = 7; + int prime = 31; + + int unknown = 0; + int diff = 0; + for (int i = 0; i < knownWorld.length; i++) { + for (int j = 0; j < knownWorld[i].length; j++) { + if(knownWorld[i][j].getType() == CellType.UNKNOWN){ + unknown += 1; + }else{ + diff +=1; + } + } + } + hash = prime * hash + unknown; + hash = prime * hash * diff; + hash = prime * hash + (hasFood ? 1:0); + hash = prime * hash + pos.hashCode(); + return hash; + } private Cell[][] deepCopyCellGrid(Cell[][] toCopy){ Cell[][] cells = new Cell[toCopy.length][toCopy[0].length]; for (int i = 0; i < cells.length; i++) { @@ -45,12 +67,7 @@ public class AntState implements State { //TODO: make this a utility function to generate hash Code based upon 2 prime numbers @Override public int hashCode(){ - int hash = 7; - int prime = 31; - hash = prime * hash + Arrays.hashCode(knownWorld); - hash = prime * hash + (hasFood ? 
1:0); - hash = prime * hash + pos.hashCode(); - return hash; + return computedHash; } @Override diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java index ef305e0..ff48eb4 100644 --- a/src/main/java/evironment/antGame/AntWorld.java +++ b/src/main/java/evironment/antGame/AntWorld.java @@ -1,10 +1,11 @@ package evironment.antGame; import core.*; +import core.algo.Learning; +import core.algo.MC.MonteCarloOnPolicyEGreedy; import evironment.antGame.gui.MainFrame; -import javax.swing.*; import java.awt.*; public class AntWorld implements Environment{ @@ -39,9 +40,8 @@ public class AntWorld implements Environment{ public AntWorld(int width, int height, double foodDensity){ grid = new Grid(width, height, foodDensity); antAgent = new AntAgent(width, height); - myAnt = new Ant(new Point(-1,-1), 0, false, false); + myAnt = new Ant(); gui = new MainFrame(this, antAgent); - tick = 0; maxEpisodeTicks = 1000; reset(); } @@ -55,23 +55,13 @@ public class AntWorld implements Environment{ } @Override - public StepResult step(AntAction action){ + public StepResultEnvironment step(AntAction action){ AntObservation observation; State newState; double reward = 0; String info = ""; boolean done = false; - if(!myAnt.isSpawned()){ - myAnt.setSpawned(true); - myAnt.getPos().setLocation(grid.getStartPoint()); - observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood()); - newState = antAgent.feedObservation(observation); - reward = 0.0; - ++tick; - return new StepResult(newState, reward, false, "Just spawned on the map"); - } - Cell currentCell = grid.getCell(myAnt.getPos()); Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y); boolean stayOnCell = true; @@ -107,7 +97,7 @@ public class AntWorld implements Environment{ // Ant successfully picks up food currentCell.setFood(currentCell.getFood() - 1); myAnt.setHasFood(true); - reward = Reward.FOOD_DROP_DOWN_SUCCESS; + reward = 
Reward.FOOD_PICK_UP_SUCCESS; } break; case DROP_DOWN: @@ -169,24 +159,30 @@ public class AntWorld implements Environment{ if(++tick == maxEpisodeTicks){ done = true; } - return new StepResult(newState, reward, done, info); + + StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info); + getGui().update(action, result); + return result; } private boolean isInGrid(Point pos){ - return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight(); + return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight(); } private boolean hitObstacle(Point pos){ return grid.getCell(pos).getType() == CellType.OBSTACLE; } - public void reset() { + public State reset() { RNG.reseed(); grid.initRandomWorld(); - myAnt.getPos().setLocation(-1,-1); + antAgent.initUnknownWorld(); + tick = 0; + myAnt.getPos().setLocation(grid.getStartPoint()); myAnt.setPoints(0); myAnt.setHasFood(false); - myAnt.setSpawned(false); + AntObservation observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood()); + return antAgent.feedObservation(observation); } public void setMaxEpisodeLength(int maxTicks){ @@ -207,21 +203,14 @@ public class AntWorld implements Environment{ public Ant getAnt(){ return myAnt; } + public static void main(String[] args) { RNG.setSeed(1993); - AntWorld a = new AntWorld(10, 10, 0.1); - ListDiscreteActionSpace actionSpace = - new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT); - for(int i = 0; i< 1000; ++i){ - AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction())); - StepResult step = a.step(selectedAction); - SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step)); - try { - Thread.sleep(100); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } + Learning monteCarlo = new MonteCarloOnPolicyEGreedy<>( + new AntWorld(3, 3, 0.1), + new 
ListDiscreteActionSpace<>(AntAction.values()) + ); + monteCarlo.learn(100,5); } } diff --git a/src/main/java/evironment/antGame/Reward.java b/src/main/java/evironment/antGame/Reward.java index b1fee95..855c6ff 100644 --- a/src/main/java/evironment/antGame/Reward.java +++ b/src/main/java/evironment/antGame/Reward.java @@ -1,17 +1,17 @@ package evironment.antGame; public class Reward { - public static final double FOOD_PICK_UP_SUCCESS = 1; - public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000; - public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000; + public static final double FOOD_PICK_UP_SUCCESS = 0; + public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0; + public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0; - public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000; - public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000; + public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0; + public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0; public static final double FOOD_DROP_DOWN_SUCCESS = 1000; - public static final double UNKNOWN_FIELD_EXPLORED = 1; + public static final double UNKNOWN_FIELD_EXPLORED = 0; - public static final double RAN_INTO_WALL = -100; - public static final double RAN_INTO_OBSTACLE = -100; + public static final double RAN_INTO_WALL = 0; + public static final double RAN_INTO_OBSTACLE = 0; } diff --git a/src/main/java/evironment/antGame/gui/MainFrame.java b/src/main/java/evironment/antGame/gui/MainFrame.java index 079bfa8..c299d78 100644 --- a/src/main/java/evironment/antGame/gui/MainFrame.java +++ b/src/main/java/evironment/antGame/gui/MainFrame.java @@ -1,6 +1,6 @@ package evironment.antGame.gui; -import core.StepResult; +import core.StepResultEnvironment; import evironment.antGame.AntAction; import evironment.antGame.AntAgent; import evironment.antGame.AntWorld; @@ -33,9 +33,9 @@ public class MainFrame extends JFrame { setVisible(true); } - public void update(AntAction 
lastAction, StepResult stepResult){ + public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){ historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ", - antWorld.getTick(), lastAction.toString(), stepResult.getReward(), stepResult.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood())); + antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood())); repaint(); }