diff --git a/src/main/java/core/Action.java b/src/main/java/core/Action.java new file mode 100644 index 0000000..f1630e1 --- /dev/null +++ b/src/main/java/core/Action.java @@ -0,0 +1,4 @@ +package core; + +public interface Action { +} diff --git a/src/main/java/core/Environment.java b/src/main/java/core/Environment.java index a93d1bb..e3fc91f 100644 --- a/src/main/java/core/Environment.java +++ b/src/main/java/core/Environment.java @@ -1,5 +1,6 @@ package core; public interface Environment { - StepResult step(A action); + StepResultEnvironment step(A action); + State reset(); } diff --git a/src/main/java/core/StateActionHashTable.java b/src/main/java/core/StateActionHashTable.java index 66b6ae0..dd981d0 100644 --- a/src/main/java/core/StateActionHashTable.java +++ b/src/main/java/core/StateActionHashTable.java @@ -55,11 +55,10 @@ public class StateActionHashTable implements StateActionTable @Override public Map getActionValues(State state) { - Map actionValues = table.get(state); - if(actionValues == null){ - actionValues = createDefaultActionValues(); + if(table.get(state) == null){ + table.put(state, createDefaultActionValues()); } - return actionValues; + return table.get(state); } public static void main(String[] args) { diff --git a/src/main/java/core/StepResult.java b/src/main/java/core/StepResult.java index 4ed1586..7de2756 100644 --- a/src/main/java/core/StepResult.java +++ b/src/main/java/core/StepResult.java @@ -2,14 +2,11 @@ package core; import lombok.AllArgsConstructor; import lombok.Getter; -import lombok.Setter; -@Getter -@Setter @AllArgsConstructor -public class StepResult { +@Getter +public class StepResult { private State state; + private A action; private double reward; - private boolean done; - private String info; } diff --git a/src/main/java/core/StepResultEnvironment.java b/src/main/java/core/StepResultEnvironment.java new file mode 100644 index 0000000..b1d1c06 --- /dev/null +++ b/src/main/java/core/StepResultEnvironment.java @@ -0,0 +1,15 
@@ +package core; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +@AllArgsConstructor +public class StepResultEnvironment { + private State state; + private double reward; + private boolean done; + private String info; +} diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java new file mode 100644 index 0000000..8d71a78 --- /dev/null +++ b/src/main/java/core/algo/Learning.java @@ -0,0 +1,27 @@ +package core.algo; + +import core.DiscreteActionSpace; +import core.Environment; +import core.StateActionTable; +import core.policy.Policy; + +public abstract class Learning { + protected Policy policy; + protected DiscreteActionSpace actionSpace; + protected StateActionTable stateActionTable; + protected Environment environment; + protected float discountFactor; + protected float epsilon; + + public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon){ + this.environment = environment; + this.actionSpace = actionSpace; + this.discountFactor = discountFactor; + this.epsilon = epsilon; + } + public Learning(Environment environment, DiscreteActionSpace actionSpace){ + this(environment, actionSpace, 1.0f, 0.1f); + } + + public abstract void learn(int nrOfEpisodes, int delay); +} diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java new file mode 100644 index 0000000..55c70ed --- /dev/null +++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java @@ -0,0 +1,75 @@ +package core.algo.MC; + +import core.*; +import core.algo.Learning; +import core.policy.EpsilonGreedyPolicy; +import javafx.util.Pair; + +import java.util.*; + +public class MonteCarloOnPolicyEGreedy extends Learning { + + public MonteCarloOnPolicyEGreedy(Environment environment, DiscreteActionSpace actionSpace) { + super(environment, actionSpace); + discountFactor = 1f; + this.policy = new 
EpsilonGreedyPolicy<>(0.1f); + this.stateActionTable = new StateActionHashTable<>(actionSpace); + } + + @Override + public void learn(int nrOfEpisodes, int delay) { + + Map, Double> returnSum = new HashMap<>(); + Map, Integer> returnCount = new HashMap<>(); + + for(int i = 0; i < nrOfEpisodes; ++i) { + + List> episode = new ArrayList<>(); + State state = environment.reset(); + for(int j=0; j < 100; ++j){ + Map actionValues = stateActionTable.getActionValues(state); + A chosenAction = policy.chooseAction(actionValues); + StepResultEnvironment envResult = environment.step(chosenAction); + State nextState = envResult.getState(); + episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); + + if(envResult.isDone()) break; + + state = nextState; + + try { + Thread.sleep(10); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + Set> stateActionPairs = new HashSet<>(); + + for(StepResult sr: episode){ + stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); + } + + for(Pair stateActionPair: stateActionPairs){ + int firstOccurenceIndex = 0; + // find first occurrence of state action pair + for(StepResult sr: episode){ + if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){ + break; + } + firstOccurenceIndex++; + } + + double G = 0; + for(int l = firstOccurenceIndex; l < episode.size(); ++l){ + G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex)); + } + // slick trick to add G to the entry.
+ // if the key does not exist, it will create a new entry with G as default value + returnSum.merge(stateActionPair, G, Double::sum); + returnCount.merge(stateActionPair, 1, Integer::sum); + stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair)); + } + } + } +} diff --git a/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java new file mode 100644 index 0000000..33716f8 --- /dev/null +++ b/src/main/java/core/algo/TD/TemporalDifferenceOnPolicy.java @@ -0,0 +1,4 @@ +package core.algo.TD; + +public class TemporalDifferenceOnPolicy { +} diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java new file mode 100644 index 0000000..0e8d448 --- /dev/null +++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java @@ -0,0 +1,35 @@ +package core.policy; + +import core.RNG; + +import java.util.Map; + +/** + * To prevent the agent from getting stuck only using the "best" action + * according to the current learning history, this policy + * will take random action with the probability of epsilon.
+ * (random action space includes the best action as well) + * + * @param Discrete Action Enum + */ +public class EpsilonGreedyPolicy implements Policy{ + private float epsilon; + private RandomPolicy randomPolicy; + private GreedyPolicy greedyPolicy; + + public EpsilonGreedyPolicy(float epsilon){ + this.epsilon = epsilon; + randomPolicy = new RandomPolicy<>(); + greedyPolicy = new GreedyPolicy<>(); + } + @Override + public A chooseAction(Map actionValues) { + if(RNG.getRandom().nextFloat() < epsilon){ + // Take random action + return randomPolicy.chooseAction(actionValues); + }else{ + // Take the action with the highest value + return greedyPolicy.chooseAction(actionValues); + } + } +} diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java new file mode 100644 index 0000000..15b06f7 --- /dev/null +++ b/src/main/java/core/policy/GreedyPolicy.java @@ -0,0 +1,32 @@ +package core.policy; + +import core.RNG; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +public class GreedyPolicy implements Policy { + + @Override + public A chooseAction(Map actionValues) { + if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set"); + + Double highestValueAction = null; + + List equalHigh = new ArrayList<>(); + + for(Map.Entry actionValue : actionValues.entrySet()){ + System.out.println(actionValue.getKey()+ " " + actionValue.getValue() ); + if(highestValueAction == null || highestValueAction < actionValue.getValue()){ + highestValueAction = actionValue.getValue(); + equalHigh.clear(); + equalHigh.add(actionValue.getKey()); + }else if(highestValueAction.equals(actionValue.getValue())){ + equalHigh.add(actionValue.getKey()); + } + } + + return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size())); + } +} diff --git a/src/main/java/core/policy/Policy.java b/src/main/java/core/policy/Policy.java new file mode 100644 index 0000000..fcb9d04 --- /dev/null +++ 
b/src/main/java/core/policy/Policy.java @@ -0,0 +1,7 @@ +package core.policy; + +import java.util.Map; + +public interface Policy { + A chooseAction(Map actionValues); +} diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java new file mode 100644 index 0000000..1f8f086 --- /dev/null +++ b/src/main/java/core/policy/RandomPolicy.java @@ -0,0 +1,17 @@ +package core.policy; + +import core.RNG; +import java.util.Map; + +public class RandomPolicy implements Policy{ + @Override + public A chooseAction(Map actionValues) { + int idx = RNG.getRandom().nextInt(actionValues.size()); + int i = 0; + for(A action : actionValues.keySet()){ + if(i++ == idx) return action; + } + + return null; + } +} diff --git a/src/main/java/evironment/antGame/Ant.java b/src/main/java/evironment/antGame/Ant.java index 40beb2c..4bd4599 100644 --- a/src/main/java/evironment/antGame/Ant.java +++ b/src/main/java/evironment/antGame/Ant.java @@ -7,7 +7,6 @@ import lombok.Setter; import java.awt.*; -@AllArgsConstructor @Getter @Setter public class Ant { @@ -15,10 +14,14 @@ public class Ant { @Setter(AccessLevel.NONE) private Point pos; private int points; - private boolean spawned; @Getter(AccessLevel.NONE) private boolean hasFood; + public Ant(){ + pos = new Point(); + points = 0; + hasFood = false; + } public boolean hasFood(){ return hasFood; } diff --git a/src/main/java/evironment/antGame/AntAgent.java b/src/main/java/evironment/antGame/AntAgent.java index 4e4d28b..055ff87 100644 --- a/src/main/java/evironment/antGame/AntAgent.java +++ b/src/main/java/evironment/antGame/AntAgent.java @@ -9,7 +9,6 @@ public class AntAgent { public AntAgent(int width, int height){ knownWorld = new Cell[width][height]; - initUnknownWorld(); } /** @@ -24,7 +23,7 @@ public class AntAgent { return new AntState(knownWorld, observation.getPos(), observation.hasFood()); } - private void initUnknownWorld(){ + public void initUnknownWorld(){ for(int x = 0; x < knownWorld.length; 
++x){ for(int y = 0; y < knownWorld[x].length; ++y){ knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN); diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java index 6d4a77b..d1a31b7 100644 --- a/src/main/java/evironment/antGame/AntState.java +++ b/src/main/java/evironment/antGame/AntState.java @@ -11,17 +11,39 @@ import java.util.Arrays; * and therefor has to be deep copied */ public class AntState implements State { - private Cell[][] knownWorld; - private Point pos; - private boolean hasFood; - + private final Cell[][] knownWorld; + private final Point pos; + private final boolean hasFood; + private final int computedHash; public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){ this.knownWorld = deepCopyCellGrid(knownWorld); this.pos = deepCopyAntPosition(antPosition); this.hasFood = hasFood; + computedHash = computeHash(); } + private int computeHash(){ + int hash = 7; + int prime = 31; + + int unknown = 0; + int diff = 0; + for (int i = 0; i < knownWorld.length; i++) { + for (int j = 0; j < knownWorld[i].length; j++) { + if(knownWorld[i][j].getType() == CellType.UNKNOWN){ + unknown += 1; + }else{ + diff +=1; + } + } + } + hash = prime * hash + unknown; + hash = prime * hash + diff; + hash = prime * hash + (hasFood ? 1:0); + hash = prime * hash + pos.hashCode(); + return hash; + } private Cell[][] deepCopyCellGrid(Cell[][] toCopy){ Cell[][] cells = new Cell[toCopy.length][toCopy[0].length]; for (int i = 0; i < cells.length; i++) { @@ -45,12 +67,7 @@ public class AntState implements State { //TODO: make this a utility function to generate hash Code based upon 2 prime numbers @Override public int hashCode(){ - int hash = 7; - int prime = 31; - hash = prime * hash + Arrays.hashCode(knownWorld); - hash = prime * hash + (hasFood ?
1:0); - hash = prime * hash + pos.hashCode(); - return hash; + return computedHash; } @Override diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java index ef305e0..ff48eb4 100644 --- a/src/main/java/evironment/antGame/AntWorld.java +++ b/src/main/java/evironment/antGame/AntWorld.java @@ -1,10 +1,11 @@ package evironment.antGame; import core.*; +import core.algo.Learning; +import core.algo.MC.MonteCarloOnPolicyEGreedy; import evironment.antGame.gui.MainFrame; -import javax.swing.*; import java.awt.*; public class AntWorld implements Environment{ @@ -39,9 +40,8 @@ public class AntWorld implements Environment{ public AntWorld(int width, int height, double foodDensity){ grid = new Grid(width, height, foodDensity); antAgent = new AntAgent(width, height); - myAnt = new Ant(new Point(-1,-1), 0, false, false); + myAnt = new Ant(); gui = new MainFrame(this, antAgent); - tick = 0; maxEpisodeTicks = 1000; reset(); } @@ -55,23 +55,13 @@ public class AntWorld implements Environment{ } @Override - public StepResult step(AntAction action){ + public StepResultEnvironment step(AntAction action){ AntObservation observation; State newState; double reward = 0; String info = ""; boolean done = false; - if(!myAnt.isSpawned()){ - myAnt.setSpawned(true); - myAnt.getPos().setLocation(grid.getStartPoint()); - observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood()); - newState = antAgent.feedObservation(observation); - reward = 0.0; - ++tick; - return new StepResult(newState, reward, false, "Just spawned on the map"); - } - Cell currentCell = grid.getCell(myAnt.getPos()); Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y); boolean stayOnCell = true; @@ -107,7 +97,7 @@ public class AntWorld implements Environment{ // Ant successfully picks up food currentCell.setFood(currentCell.getFood() - 1); myAnt.setHasFood(true); - reward = Reward.FOOD_DROP_DOWN_SUCCESS; + reward = 
Reward.FOOD_PICK_UP_SUCCESS; } break; case DROP_DOWN: @@ -169,24 +159,30 @@ public class AntWorld implements Environment{ if(++tick == maxEpisodeTicks){ done = true; } - return new StepResult(newState, reward, done, info); + + StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info); + getGui().update(action, result); + return result; } private boolean isInGrid(Point pos){ - return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight(); + return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight(); } private boolean hitObstacle(Point pos){ return grid.getCell(pos).getType() == CellType.OBSTACLE; } - public void reset() { + public State reset() { RNG.reseed(); grid.initRandomWorld(); - myAnt.getPos().setLocation(-1,-1); + antAgent.initUnknownWorld(); + tick = 0; + myAnt.getPos().setLocation(grid.getStartPoint()); myAnt.setPoints(0); myAnt.setHasFood(false); - myAnt.setSpawned(false); + AntObservation observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood()); + return antAgent.feedObservation(observation); } public void setMaxEpisodeLength(int maxTicks){ @@ -207,21 +203,14 @@ public class AntWorld implements Environment{ public Ant getAnt(){ return myAnt; } + public static void main(String[] args) { RNG.setSeed(1993); - AntWorld a = new AntWorld(10, 10, 0.1); - ListDiscreteActionSpace actionSpace = - new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT); - for(int i = 0; i< 1000; ++i){ - AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction())); - StepResult step = a.step(selectedAction); - SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step)); - try { - Thread.sleep(100); - } catch (InterruptedException e) { - e.printStackTrace(); - } - } + Learning monteCarlo = new MonteCarloOnPolicyEGreedy<>( + new AntWorld(3, 3, 0.1), + new 
ListDiscreteActionSpace<>(AntAction.values()) + ); + monteCarlo.learn(100,5); } } diff --git a/src/main/java/evironment/antGame/Reward.java b/src/main/java/evironment/antGame/Reward.java index b1fee95..855c6ff 100644 --- a/src/main/java/evironment/antGame/Reward.java +++ b/src/main/java/evironment/antGame/Reward.java @@ -1,17 +1,17 @@ package evironment.antGame; public class Reward { - public static final double FOOD_PICK_UP_SUCCESS = 1; - public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000; - public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000; + public static final double FOOD_PICK_UP_SUCCESS = 0; + public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0; + public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0; - public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000; - public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000; + public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0; + public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0; public static final double FOOD_DROP_DOWN_SUCCESS = 1000; - public static final double UNKNOWN_FIELD_EXPLORED = 1; + public static final double UNKNOWN_FIELD_EXPLORED = 0; - public static final double RAN_INTO_WALL = -100; - public static final double RAN_INTO_OBSTACLE = -100; + public static final double RAN_INTO_WALL = 0; + public static final double RAN_INTO_OBSTACLE = 0; } diff --git a/src/main/java/evironment/antGame/gui/MainFrame.java b/src/main/java/evironment/antGame/gui/MainFrame.java index 079bfa8..c299d78 100644 --- a/src/main/java/evironment/antGame/gui/MainFrame.java +++ b/src/main/java/evironment/antGame/gui/MainFrame.java @@ -1,6 +1,6 @@ package evironment.antGame.gui; -import core.StepResult; +import core.StepResultEnvironment; import evironment.antGame.AntAction; import evironment.antGame.AntAgent; import evironment.antGame.AntWorld; @@ -33,9 +33,9 @@ public class MainFrame extends JFrame { setVisible(true); } - public void update(AntAction 
lastAction, StepResult stepResult){ + public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){ historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ", - antWorld.getTick(), lastAction.toString(), stepResult.getReward(), stepResult.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood())); + antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood())); repaint(); }