add logic to handle ant action and compute rewards

- ant world will handle and compute action received by the agent
- first try to convert observations to markov states
- improved .equals() methods
This commit is contained in:
Jan Löwenstrom 2019-12-08 16:03:00 +01:00
parent ec67ce60c9
commit db9b62236c
11 changed files with 274 additions and 27 deletions

View File

@ -1,7 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ExternalStorageConfigurationManager" enabled="true" />
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" project-jdk-name="11" project-jdk-type="JavaSDK">
<component name="ProjectRootManager" version="2" languageLevel="JDK_1_8" project-jdk-name="1.8" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/refo.iml" filepath="$PROJECT_DIR$/.idea/refo.iml" />
</modules>
</component>
</project>

21
.idea/refo.iml Normal file
View File

@ -0,0 +1,21 @@
<?xml version="1.0" encoding="UTF-8"?>
<module external.linked.project.id="refo" external.linked.project.path="$MODULE_DIR$" external.root.project.path="$MODULE_DIR$" external.system.id="GRADLE" external.system.module.group="net.lwenstrom.jan" external.system.module.version="1.0-SNAPSHOT" version="4">
<component name="NewModuleRootManager">
<output url="file://$MODULE_DIR$/build/classes/java/main" />
<output-test url="file://$MODULE_DIR$/build/classes/java/test" />
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/.gradle" />
<excludeFolder url="file://$MODULE_DIR$/build" />
</content>
<content url="file://$MODULE_DIR$/build/generated/sources/annotationProcessor/java/main">
<sourceFolder url="file://$MODULE_DIR$/build/generated/sources/annotationProcessor/java/main" isTestSource="false" generated="true" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
<orderEntry type="library" scope="PROVIDED" name="Gradle: org.projectlombok:lombok:1.18.10" level="project" />
<orderEntry type="library" scope="TEST" name="Gradle: junit:junit:4.12" level="project" />
<orderEntry type="library" scope="TEST" name="Gradle: org.hamcrest:hamcrest-core:1.3" level="project" />
</component>
</module>

View File

@ -28,13 +28,9 @@ public class DiscreteAction<A extends Enum> implements Action{
/**
 * Two discrete actions are considered equal when they have the exact same
 * runtime class and carry the same index.
 *
 * @param obj the object to compare with
 * @return true if {@code obj} is the same kind of action with the same index
 */
@Override
public boolean equals(Object obj) {
    if (this == obj) {
        return true;
    }
    boolean sameClass = obj != null && getClass() == obj.getClass();
    if (!sameClass) {
        return false;
    }
    // obj has the same runtime class as this instance, so the cast cannot fail
    return getIndex() == ((DiscreteAction) obj).getIndex();
}
}

View File

@ -4,17 +4,24 @@ package evironment.antGame;
import java.awt.*;
public class AntAgent {
// the brain
// the learned representation of the environment (gridWorld)
private Cell[][] knownWorld;
private Point pos;
/**
 * Creates an agent with an internal world representation of the given size
 * and fills it via {@code initUnknownWorld()}.
 *
 * @param width  number of columns of the known world
 * @param height number of rows of the known world
 */
public AntAgent(int width, int height){
    this.knownWorld = new Cell[width][height];
    this.initUnknownWorld();
}
/**
 * Stores the observed cell in the agent's known world at the observed
 * position and builds a markov state from the updated knowledge.
 *
 * @param observation received input from the game host (environment)
 * @return the current state of the agent
 */
public AntState feedObservation(AntObservation observation){
    Cell[] observedColumn = knownWorld[observation.getPos().x];
    observedColumn[observation.getPos().y] = observation.getCell();

    Point observedPos = observation.getPos();
    boolean carriesFood = observation.hasFood();
    return new AntState(knownWorld, observedPos, carriesFood);
}
private void initUnknownWorld(){
@ -25,7 +32,7 @@ public class AntAgent {
}
}
public Point getPos(){
return pos;
/**
 * Looks up what the agent currently knows about the cell at the given position.
 *
 * @param pos position inside the known world
 * @return the known cell stored at {@code pos}
 */
public Cell getCell(Point pos){
    Cell[] column = knownWorld[pos.x];
    return column[pos.y];
}
}

View File

@ -1,15 +1,26 @@
package evironment.antGame;
import core.Observation;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import java.awt.*;
/**
 * Observation passed from the ant world to the ant agent: a cell, a position
 * and a flag telling whether the ant carries food.
 * The all-args constructor, getters and setters are generated by Lombok.
 */
@AllArgsConstructor
@Getter
@Setter
public class AntObservation implements Observation {
// cell that belongs to this observation
private Cell cell;
// position that belongs to this observation
private Point pos;
// no Lombok getter here; the hand-written accessor hasFood() below is used instead
@Getter(AccessLevel.NONE)
private boolean hasFood;
/**
 * @return the value of the hasFood flag
 */
public boolean hasFood(){
return hasFood;
}
}

View File

@ -2,16 +2,74 @@ package evironment.antGame;
import core.State;
// somewhat the "brain" of the agent, current known setting of the environment
import java.awt.*;
import java.util.Arrays;
/**
* Markov state of the experienced environment.
* Essentially a snapshot of the current Ant Agent,
* and therefore has to be deep copied.
*/
public class AntState implements State {
private Grid knownGrid;
private Cell[][] knownWorld;
private Point pos;
private boolean hasFood;
public AntState(int width, int height){
knownGrid = new Grid(width, height);
public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
this.knownWorld = deepCopyCellGrid(knownWorld);
this.pos = deepCopyAntPosition(antPosition);
this.hasFood = hasFood;
}
public AntState(){
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT);
/**
 * Creates a deep copy of the given cell grid by copying every cell with the
 * copy constructor of {@code Cell}.
 * Rows are allocated individually, so an empty grid (length 0) or a grid with
 * rows of different lengths is copied without an index error.
 *
 * @param toCopy the grid to copy
 * @return a new grid holding copies of all cells
 */
private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
    Cell[][] cells = new Cell[toCopy.length][];
    for (int i = 0; i < toCopy.length; i++) {
        cells[i] = new Cell[toCopy[i].length];
        for (int j = 0; j < toCopy[i].length; j++) {
            // calling copy constructor of Cell class
            cells[i][j] = new Cell(toCopy[i][j]);
        }
    }
    return cells;
}
/**
 * Returns a new, independent point with the same coordinates.
 *
 * @param toCopy the position to copy
 * @return a copy of {@code toCopy}
 */
private Point deepCopyAntPosition(Point toCopy){
    // Point's copy constructor takes over x and y of the given point
    return new Point(toCopy);
}
/**
 * Builds a readable description of this state.
 * {@code Arrays.deepToString} is used because {@code knownWorld} is a
 * two-dimensional array; {@code Arrays.toString} would only print the
 * identity strings of the inner arrays instead of the cells.
 *
 * @return position, food flag and known world as text
 */
@Override
public String toString(){
    return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.deepToString(knownWorld));
}
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers
/**
 * Computes a hash code from the content of the state.
 * The cells are hashed by the same attributes that {@code Cell.equals}
 * compares (type, food, x, y). Hashing the grid with {@code Arrays.hashCode}
 * would only mix the identity hashes of the inner arrays, so two states with
 * equal content would end up with different hash codes.
 *
 * @return hash code consistent with {@link #equals(Object)}
 */
@Override
public int hashCode(){
    int hash = 7;
    int prime = 31;

    int worldHash = 1;
    for (Cell[] column : knownWorld) {
        for (Cell cell : column) {
            int cellHash = 0;
            if (cell != null) {
                Point cellPos = cell.getPos();
                cellHash = Arrays.hashCode(new Object[]{cell.getType(), cell.getFood(), cellPos.x, cellPos.y});
            }
            worldHash = prime * worldHash + cellHash;
        }
    }

    hash = prime * hash + worldHash;
    hash = prime * hash + (hasFood ? 1:0);
    hash = prime * hash + pos.hashCode();
    return hash;
}
/**
 * Two states are equal when the ant position, the food flag and every cell
 * of the known world are equal.
 * The grids are compared with {@code Arrays.deepEquals}, which compares the
 * cells via {@code Cell.equals} and also takes the grid dimensions into
 * account, so grids of different size are simply unequal instead of causing
 * an index error or an asymmetric result.
 *
 * @param obj the object to compare with
 * @return true if {@code obj} is an {@code AntState} with the same content
 */
@Override
public boolean equals(Object obj){
    if(this == obj){
        return true;
    }
    if(!(obj instanceof AntState)){
        return false;
    }
    AntState toCompare = (AntState) obj;
    return this.hasFood == toCompare.hasFood
            && this.pos.equals(toCompare.pos)
            && Arrays.deepEquals(this.knownWorld, toCompare.knownWorld);
}
}

View File

@ -29,9 +29,14 @@ public class AntWorld {
*/
private AntAgent antAgent;
private int tick;
private int maxEpisodeTicks;
/**
 * Creates a world with a grid of the given size and food density and an
 * agent whose known world has the same dimensions.
 * The tick counter starts at 0 and an episode is limited to 1000 ticks.
 *
 * @param width       width of the grid
 * @param height      height of the grid
 * @param foodDensity food density passed on to the grid
 */
public AntWorld(int width, int height, double foodDensity){
    this.tick = 0;
    this.maxEpisodeTicks = 1000;
    this.grid = new Grid(width, height, foodDensity);
    this.antAgent = new AntAgent(width, height);
}
public AntWorld(){
@ -39,7 +44,7 @@ public class AntWorld {
}
/**
 * Plain data holder for the ant as seen by the world.
 */
private static class MyAnt{
    // NOTE(review): x and y look superseded by pos - confirm before removing
    int x;
    int y;
    // current position of the ant on the grid
    Point pos;
    // true while the ant carries a piece of food
    boolean hasFood;
    // false until the ant was placed on the start point by the first step
    boolean spawned;
}
@ -47,30 +52,125 @@ public class AntWorld {
/**
 * Processes one action of the agent and computes the resulting reward.
 *
 * <p>The first call after the ant was created only places the ant on the
 * start point; the given action is not evaluated and the reward is 0.
 * On every following call the action is applied:
 * <ul>
 *   <li>movement actions move the ant unless the target lies outside of the
 *       grid or on an obstacle, in which case the ant stays on its cell,</li>
 *   <li>PICK_UP takes one piece of food from the current cell if the ant
 *       carries none and the cell holds food,</li>
 *   <li>DROP_DOWN places the carried food on the current cell; only dropping
 *       on the start cell is rewarded and triggers the completion check.</li>
 * </ul>
 * The episode is done when all food has been collected after a successful
 * drop, or when the tick limit is reached.
 *
 * @param action the action selected by the agent
 * @return new markov state, reward, done flag and info text
 * @throws RuntimeException if the action value is not handled
 */
public StepResult step(DiscreteAction<AntAction> action){
    AntObservation observation;
    State newState;
    double reward = 0;
    String info = "";
    boolean done = false;

    if(!myAnt.spawned){
        myAnt.spawned = true;
        myAnt.pos = grid.getStartPoint();
        observation = new AntObservation(grid.getCell(myAnt.pos), myAnt.pos, myAnt.hasFood);
        newState = antAgent.feedObservation(observation);
        return new StepResult(newState, reward, false, "Just spawned on the map");
    }

    Cell currentCell = grid.getCell(myAnt.pos);
    Point potentialNextPos = new Point(myAnt.pos.x, myAnt.pos.y);
    boolean stayOnCell = true;
    // flag to enable a check if all food has been collected; only set if food
    // was dropped on the starting position
    boolean checkCompletion = false;

    switch (action.getValue()) {
        case MOVE_UP:
            potentialNextPos.y -= 1;
            stayOnCell = false;
            break;
        case MOVE_RIGHT:
            potentialNextPos.x += 1;
            stayOnCell = false;
            break;
        case MOVE_DOWN:
            potentialNextPos.y += 1;
            stayOnCell = false;
            break;
        case MOVE_LEFT:
            potentialNextPos.x -= 1;
            stayOnCell = false;
            break;
        case PICK_UP:
            if(myAnt.hasFood){
                // Ant tries to pick up food but can only hold one piece
                reward = Reward.FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY;
            }else if(currentCell.getFood() == 0){
                // Ant tries to pick up food on cell that has no food on it
                reward = Reward.FOOD_PICK_UP_FAIL_NO_FOOD;
            }else if(currentCell.getFood() > 0){
                // Ant successfully picks up food
                currentCell.setFood(currentCell.getFood() - 1);
                myAnt.hasFood = true;
                // picking up has its own reward; the drop reward is reserved
                // for delivering food to the start cell
                reward = Reward.FOOD_PICK_UP_SUCCESS;
            }
            break;
        case DROP_DOWN:
            if(!myAnt.hasFood){
                // Ant had no food to drop
                reward = Reward.FOOD_DROP_DOWN_FAIL_NO_FOOD;
            }else{
                // Drop food onto the ground
                currentCell.setFood(currentCell.getFood() + 1);
                myAnt.hasFood = false;
                // negative reward if the agent drops food on any other field
                // than the starting point
                if(currentCell.getType() != CellType.START){
                    reward = Reward.FOOD_DROP_DOWN_FAIL_NOT_START;
                }else{
                    reward = Reward.FOOD_DROP_DOWN_SUCCESS;
                    checkCompletion = true;
                }
            }
            break;
        default:
            throw new RuntimeException(String.format("Action <%s> is not a valid action!", action.toString()));
    }

    // movement action was selected
    if(!stayOnCell){
        if(!isInGrid(potentialNextPos)){
            stayOnCell = true;
            reward = Reward.RAN_INTO_WALL;
        }else if(hitObstacle(potentialNextPos)){
            stayOnCell = true;
            reward = Reward.RAN_INTO_OBSTACLE;
        }
    }

    // valid movement
    if(!stayOnCell){
        myAnt.pos = potentialNextPos;
        if(antAgent.getCell(myAnt.pos).getType() == CellType.UNKNOWN){
            // the ant will move to a cell that was previously unknown
            reward = Reward.UNKNOWN_FIELD_EXPLORED;
        }else{
            reward = 0;
        }
    }

    // get observation after action was computed
    observation = new AntObservation(grid.getCell(myAnt.pos), myAnt.pos, myAnt.hasFood);

    // let the ant agent process the observation to create a valid markov state
    newState = antAgent.feedObservation(observation);

    if(checkCompletion){
        done = grid.isAllFoodCollected();
    }

    // ">=" instead of "==" so the episode still ends if the limit was lowered
    // below the current tick via setMaxEpisodeLength
    if(++tick >= maxEpisodeTicks){
        done = true;
    }

    return new StepResult(newState, reward, done, info);
}
/**
 * Checks whether the given position lies inside the grid.
 * Valid indices start at 0 (the grid is iterated from 0 up to width/height),
 * therefore the lower bounds are inclusive.
 *
 * @param pos position to check
 * @return true if {@code 0 <= x < width} and {@code 0 <= y < height}
 */
private boolean isInGrid(Point pos){
    boolean xInside = pos.x >= 0 && pos.x < grid.getWidth();
    boolean yInside = pos.y >= 0 && pos.y < grid.getHeight();
    return xInside && yInside;
}
/**
 * Tells whether the grid cell at the given position is an obstacle.
 *
 * @param pos position of the cell to inspect
 * @return true if the cell type is {@code CellType.OBSTACLE}
 */
private boolean hitObstacle(Point pos){
    CellType targetType = grid.getCell(pos).getType();
    return CellType.OBSTACLE == targetType;
}
public void reset() {
@ -79,6 +179,9 @@ public class AntWorld {
myAnt = new MyAnt();
}
/**
 * Sets the number of ticks after which an episode is marked as done.
 *
 * @param maxTicks the new tick limit
 */
public void setMaxEpisodeLength(int maxTicks){
    maxEpisodeTicks = maxTicks;
}
/**
 * @return the start point of the grid, where the ant is placed on its first step
 */
public Point getSpawningPoint(){
    Point spawningPoint = grid.getStartPoint();
    return spawningPoint;
}

View File

@ -20,8 +20,22 @@ public class Cell {
food = foodAmount;
}
/**
 * Copy constructor: takes over type and food of the given cell and creates
 * a new point with the same coordinates.
 *
 * @param c the cell to copy
 */
public Cell(Cell c){
    Point sourcePos = c.pos;
    this.pos = new Point(sourcePos.x, sourcePos.y);
    this.food = c.getFood();
    this.type = c.getType();
}
/**
 * Creates a cell without food by delegating to the three-argument
 * constructor with a food amount of 0.
 *
 * @param pos      position of the cell
 * @param cellType type of the cell
 */
public Cell( Point pos, CellType cellType){
this(pos, cellType, 0);
}
/**
 * Cells are equal when type, food amount and both coordinates match.
 * Any other kind of object is handed to the inherited comparison.
 *
 * @param obj the object to compare with
 * @return true if {@code obj} is a cell with the same type, food and position
 */
@Override
public boolean equals(Object obj){
    if(!(obj instanceof Cell)){
        return super.equals(obj);
    }
    Cell other = (Cell) obj;
    if(this.type != other.getType()){
        return false;
    }
    if(this.food != other.getFood()){
        return false;
    }
    if(this.pos.x != other.getPos().x){
        return false;
    }
    return this.pos.y == other.getPos().y;
}
}

View File

@ -41,6 +41,18 @@ public class Grid {
return start;
}
/**
 * Scans the whole grid for remaining food.
 *
 * @return false as soon as a cell with a food amount greater than 0 is found,
 *         true if no such cell exists
 */
public boolean isAllFoodCollected(){
    for(int col = 0; col < width; col++){
        for(int row = 0; row < height; row++){
            boolean cellHasFood = grid[col][row].getFood() > 0;
            if(cellHasFood){
                return false;
            }
        }
    }
    return true;
}
/**
 * @return the internal cell array of this grid (not a copy)
 */
public Cell[][] getGrid(){
    Cell[][] cells = this.grid;
    return cells;
}

View File

@ -0,0 +1,17 @@
package evironment.antGame;
/**
 * Reward values handed out by the ant world for the outcome of an action.
 * Pure constants holder: the class is final and cannot be instantiated.
 */
public final class Reward {
    // picking up food
    public static final double FOOD_PICK_UP_SUCCESS = 1;
    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000;
    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000;

    // dropping food
    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000;
    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000;
    public static final double FOOD_DROP_DOWN_SUCCESS = 1000;

    // movement
    public static final double UNKNOWN_FIELD_EXPLORED = 1;
    public static final double RAN_INTO_WALL = -100;
    public static final double RAN_INTO_OBSTACLE = -100;

    private Reward() {
        // constants only, no instances
    }
}