From db9b62236c49b0aa7d0e8e4b54ee5ba739c8397f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20L=C3=B6wenstrom?=
Date: Sun, 8 Dec 2019 16:03:00 +0100
Subject: [PATCH] add logic to handle ant action and compute rewards

- ant world will handle and compute action received by the agent
- first try to convert observations to markov states
- improved .equals() methods
---
 .idea/misc.xml                                |   2 +-
 .idea/modules.xml                             |   8 ++
 .idea/refo.iml                                |  21 ++++
 src/main/java/core/DiscreteAction.java        |  12 +-
 .../java/evironment/antGame/AntAgent.java     |  17 ++-
 .../evironment/antGame/AntObservation.java    |  13 +-
 .../java/evironment/antGame/AntState.java     |  70 ++++++++++-
 .../java/evironment/antGame/AntWorld.java     | 115 +++++++++++++++++-
 src/main/java/evironment/antGame/Cell.java    |  19 ++++
 src/main/java/evironment/antGame/Grid.java    |  12 ++
 src/main/java/evironment/antGame/Reward.java  |  17 +++
 11 files changed, 279 insertions(+), 27 deletions(-)
 create mode 100644 .idea/modules.xml
 create mode 100644 .idea/refo.iml
 create mode 100644 src/main/java/evironment/antGame/Reward.java

diff --git a/.idea/misc.xml b/.idea/misc.xml
index 29af3ee..bc8d0a3 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -1,7 +1,7 @@
-
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..4acc2aa
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/refo.iml b/.idea/refo.iml
new file mode 100644
index 0000000..2530107
--- /dev/null
+++ b/.idea/refo.iml
@@ -0,0 +1,21 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/src/main/java/core/DiscreteAction.java b/src/main/java/core/DiscreteAction.java
index bb90c9d..52a8c0a 100644
--- a/src/main/java/core/DiscreteAction.java
+++ b/src/main/java/core/DiscreteAction.java
@@ -28,13 +28,9 @@ public class DiscreteAction implements Action{
 
     @Override
     public boolean equals(Object obj) {
-        if (this == obj)
-            return true;
-        if (obj == null)
-            return false;
-        if (getClass() != obj.getClass())
-            return false;
-
-        return getIndex() == ((DiscreteAction) obj).getIndex();
+        if(obj instanceof DiscreteAction){
+            return getIndex() == ((DiscreteAction) obj).getIndex();
+        }
+        return super.equals(obj);
     }
 }
diff --git a/src/main/java/evironment/antGame/AntAgent.java b/src/main/java/evironment/antGame/AntAgent.java
index 8cf920f..0bafbc3 100644
--- a/src/main/java/evironment/antGame/AntAgent.java
+++ b/src/main/java/evironment/antGame/AntAgent.java
@@ -4,17 +4,24 @@ package evironment.antGame;
 import java.awt.*;
 
 public class AntAgent {
-    // the brain
+    // the learned representation of the environment (gridWorld)
     private Cell[][] knownWorld;
-    private Point pos;
 
     public AntAgent(int width, int height){
         knownWorld = new Cell[width][height];
         initUnknownWorld();
     }
 
+    /**
+     * Learn from observation received after last action
+     * and generate a markov state.
+     *
+     * @param observation received input from the game host (environment)
+     * @return the current state of the agent
+     */
     public AntState feedObservation(AntObservation observation){
-
+        knownWorld[observation.getPos().x][observation.getPos().y] = observation.getCell();
+        return new AntState(knownWorld, observation.getPos(), observation.hasFood());
     }
 
     private void initUnknownWorld(){
@@ -25,7 +32,7 @@ public class AntAgent {
         }
     }
 
-    public Point getPos(){
-        return pos;
+    public Cell getCell(Point pos){
+        return knownWorld[pos.x][pos.y];
     }
 }
diff --git a/src/main/java/evironment/antGame/AntObservation.java b/src/main/java/evironment/antGame/AntObservation.java
index 57ddea0..d497214 100644
--- a/src/main/java/evironment/antGame/AntObservation.java
+++ b/src/main/java/evironment/antGame/AntObservation.java
@@ -1,15 +1,26 @@
 package evironment.antGame;
 
 import core.Observation;
+import lombok.AccessLevel;
 import lombok.AllArgsConstructor;
 import lombok.Getter;
 import lombok.Setter;
 
 import java.awt.*;
 
+
 @AllArgsConstructor
 @Getter
 @Setter
 public class AntObservation implements Observation {
-    private Cell cell;
+    private Cell cell;
+    private Point pos;
+
+
+    @Getter(AccessLevel.NONE)
+    private boolean hasFood;
+
+    public boolean hasFood(){
+        return hasFood;
+    }
 }
diff --git a/src/main/java/evironment/antGame/AntState.java b/src/main/java/evironment/antGame/AntState.java
index a499c84..6d4a77b 100644
--- a/src/main/java/evironment/antGame/AntState.java
+++ b/src/main/java/evironment/antGame/AntState.java
@@ -2,16 +2,74 @@ package evironment.antGame;
 
 import core.State;
 
-// somewhat the "brain" of the agent, current known setting of the environment
+import java.awt.*;
+import java.util.Arrays;
+
+/**
+ * Markov state of the experienced environment.
+ * Essentially a snapshot of the current Ant Agent
+ * and therefore has to be deep copied
+ */
 public class AntState implements State {
-    private Grid knownGrid;
+    private Cell[][] knownWorld;
+    private Point pos;
+    private boolean hasFood;
 
 
-    public AntState(int width, int height){
-        knownGrid = new Grid(width, height);
+    public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
+        this.knownWorld = deepCopyCellGrid(knownWorld);
+        this.pos = deepCopyAntPosition(antPosition);
+        this.hasFood = hasFood;
     }
 
-    public AntState(){
-        this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT);
+    private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
+        Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
+        for (int i = 0; i < cells.length; i++) {
+            for (int j = 0; j < cells[i].length; j++) {
+                // calling copy constructor of Cell class
+                cells[i][j] = new Cell(toCopy[i][j]);
+            }
+        }
+        return cells;
+    }
+
+    private Point deepCopyAntPosition(Point toCopy){
+        return new Point(toCopy.x,toCopy.y);
+    }
+
+    @Override
+    public String toString(){
+        return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.deepToString(knownWorld));
+    }
+
+    //TODO: make this a utility function to generate hash Code based upon 2 prime numbers
+    @Override
+    public int hashCode(){
+        int hash = 7;
+        int prime = 31;
+        hash = prime * hash + Arrays.deepHashCode(knownWorld);
+        hash = prime * hash + (hasFood ? 1:0);
+        hash = prime * hash + pos.hashCode();
+        return hash;
+    }
+
+    @Override
+    public boolean equals(Object obj){
+        if(obj instanceof AntState){
+            AntState toCompare = (AntState) obj;
+            if(!this.pos.equals(toCompare.pos) || this.hasFood != toCompare.hasFood){
+                return false;
+            }
+            for (int i = 0; i < toCompare.knownWorld.length; i++) {
+                for (int j = 0; j < toCompare.knownWorld[i].length; j++) {
+                    if(!this.knownWorld[i][j].equals(toCompare.knownWorld[i][j])){
+                        return false;
+                    }
+
+                }
+            }
+            return true;
+        }
+        return super.equals(obj);
     }
 }
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index d156751..3675104 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -6,7 +6,7 @@ import java.awt.*;
 
 public class AntWorld {
     /**
-     * 
+     *
      */
     private Grid grid;
     /**
@@ -29,9 +29,14 @@ public class AntWorld {
      */
     private AntAgent antAgent;
 
+    private int tick;
+    private int maxEpisodeTicks;
+
     public AntWorld(int width, int height, double foodDensity){
         grid = new Grid(width, height, foodDensity);
         antAgent = new AntAgent(width, height);
+        tick = 0;
+        maxEpisodeTicks = 1000;
     }
 
     public AntWorld(){
@@ -39,7 +44,7 @@ public class AntWorld {
     }
 
     private static class MyAnt{
-        int x,y;
+        Point pos;
         boolean hasFood;
         boolean spawned;
     }
@@ -47,30 +52,125 @@ public class AntWorld {
     public StepResult step(DiscreteAction action){
         AntObservation observation;
         State newState;
+        double reward = 0;
+        String info = "";
+        boolean done = false;
+
         if(!myAnt.spawned){
-            observation = new AntObservation(grid.getCell(grid.getStartPoint()));
+            myAnt.spawned = true;
+            myAnt.pos = grid.getStartPoint();
+
+            observation = new AntObservation(grid.getCell(myAnt.pos), myAnt.pos, myAnt.hasFood);
             newState = antAgent.feedObservation(observation);
-            return new StepResult(newState, 0.0, false, "Just spawned on the map");
+            reward = 0.0;
+            return new StepResult(newState, reward, false, "Just spawned on the map");
         }
+
+        Cell currentCell = grid.getCell(myAnt.pos);
+        Point potentialNextPos = new Point(myAnt.pos.x, myAnt.pos.y);
+        boolean stayOnCell = true;
+        // flag to enable a check if all food has been collected only fired if food was dropped
+        // on the starting position
+        boolean checkCompletion = false;
+
         switch (action.getValue()) {
             case MOVE_UP:
+                potentialNextPos.y -= 1;
+                stayOnCell = false;
                 break;
             case MOVE_RIGHT:
+                potentialNextPos.x += 1;
+                stayOnCell = false;
                 break;
             case MOVE_DOWN:
+                potentialNextPos.y += 1;
+                stayOnCell = false;
                 break;
             case MOVE_LEFT:
+                potentialNextPos.x -= 1;
+                stayOnCell = false;
                 break;
             case PICK_UP:
+                if(myAnt.hasFood){
+                    // Ant tries to pick up food but can only hold one piece
+                    reward = Reward.FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY;
+                }else if(currentCell.getFood() == 0){
+                    // Ant tries to pick up food on cell that has no food on it
+                    reward = Reward.FOOD_PICK_UP_FAIL_NO_FOOD;
+                }else if(currentCell.getFood() > 0){
+                    // Ant successfully picks up food
+                    currentCell.setFood(currentCell.getFood() - 1);
+                    myAnt.hasFood = true;
+                    reward = Reward.FOOD_PICK_UP_SUCCESS;
+                }
                 break;
             case DROP_DOWN:
+                if(!myAnt.hasFood){
+                    // Ant had no food to drop
+                    reward = Reward.FOOD_DROP_DOWN_FAIL_NO_FOOD;
+                }else{
+                    // Drop food onto the ground
+                    currentCell.setFood(currentCell.getFood() + 1);
+                    myAnt.hasFood = false;
+
+                    // negative reward if the agent drops food on any other field
+                    // than the starting point
+                    if(currentCell.getType() != CellType.START){
+                        reward = Reward.FOOD_DROP_DOWN_FAIL_NOT_START;
+                    }else{
+                        reward = Reward.FOOD_DROP_DOWN_SUCCESS;
+                        checkCompletion = true;
+                    }
+                }
                 break;
             default:
                 throw new RuntimeException(String.format("Action <%s> is not a valid action!", action.toString()));
-                break;
         }
+
+        // movement action was selected
+        if(!stayOnCell){
+            if(!isInGrid(potentialNextPos)){
+                stayOnCell = true;
+                reward = Reward.RAN_INTO_WALL;
+            }else if(hitObstacle(potentialNextPos)){
+                stayOnCell = true;
+                reward = Reward.RAN_INTO_OBSTACLE;
+            }
+        }
+
+        // valid movement
+        if(!stayOnCell){
+            myAnt.pos = potentialNextPos;
+            if(antAgent.getCell(myAnt.pos).getType() == CellType.UNKNOWN){
+                // the ant will move to a cell that was previously unknown
+                reward = Reward.UNKNOWN_FIELD_EXPLORED;
+            }else{
+                reward = 0;
+            }
+        }
+
+        // get observation after action was computed
+        observation = new AntObservation(grid.getCell(myAnt.pos), myAnt.pos, myAnt.hasFood);
+
+        // let the ant agent process the observation to create a valid markov state
         newState = antAgent.feedObservation(observation);
-        return new StepResult(newState, 0.0, false, "");
+
+        if(checkCompletion){
+            done = grid.isAllFoodCollected();
+        }
+
+        if(++tick >= maxEpisodeTicks){
+            done = true;
+        }
+        return new StepResult(newState, reward, done, info);
+    }
+
+    private boolean isInGrid(Point pos){
+        return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight();
+    }
+
+    private boolean hitObstacle(Point pos){
+        return grid.getCell(pos).getType() == CellType.OBSTACLE;
     }
 
     public void reset() {
@@ -79,6 +179,9 @@ public class AntWorld {
         myAnt = new MyAnt();
     }
 
+    public void setMaxEpisodeLength(int maxTicks){
+        this.maxEpisodeTicks = maxTicks;
+    }
     public Point getSpawningPoint(){
         return grid.getStartPoint();
     }
diff --git a/src/main/java/evironment/antGame/Cell.java b/src/main/java/evironment/antGame/Cell.java
index ab726fc..5fccc5d 100644
--- a/src/main/java/evironment/antGame/Cell.java
+++ b/src/main/java/evironment/antGame/Cell.java
@@ -20,8 +20,27 @@ public class Cell {
         food = foodAmount;
     }
 
+    public Cell(Cell c){
+        this.pos = new Point(c.pos.x, c.pos.y);
+        this.food = c.getFood();
+        this.type = c.getType();
+    }
+
     public Cell( Point pos, CellType cellType){
         this(pos, cellType, 0);
     }
 
+    @Override
+    public boolean equals(Object obj){
+        if(obj instanceof Cell){
+            Cell cell = (Cell) obj;
+            return this.type == cell.getType() && this.food == cell.getFood() && this.pos.x == cell.getPos().x && this.pos.y ==cell.getPos().y;
+        }
+        return super.equals(obj);
+    }
+
+    @Override
+    public int hashCode(){
+        return java.util.Objects.hash(type, food, pos.x, pos.y);
+    }
 }
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index ad5f45d..2f9f884 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -41,6 +41,18 @@ public class Grid {
         return start;
     }
 
+    public boolean isAllFoodCollected(){
+        for(int x = 0; x < width; ++x){
+            for(int y = 0; y < height; ++y){
+                if(grid[x][y].getFood() > 0 && grid[x][y].getType() != CellType.START){
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    }
+
     public Cell[][] getGrid(){
         return grid;
     }
diff --git a/src/main/java/evironment/antGame/Reward.java b/src/main/java/evironment/antGame/Reward.java
new file mode 100644
index 0000000..b1fee95
--- /dev/null
+++ b/src/main/java/evironment/antGame/Reward.java
@@ -0,0 +1,17 @@
+package evironment.antGame;
+
+public class Reward {
+    public static final double FOOD_PICK_UP_SUCCESS = 1;
+    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000;
+    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000;
+
+    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000;
+    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000;
+    public static final double FOOD_DROP_DOWN_SUCCESS = 1000;
+
+    public static final double UNKNOWN_FIELD_EXPLORED = 1;
+
+    public static final double RAN_INTO_WALL = -100;
+    public static final double RAN_INTO_OBSTACLE = -100;
+
+}