add logic to handle ant action and compute rewards
- the ant world now handles and evaluates the actions received from the agent - first attempt at converting observations to Markov states - improved .equals() methods
This commit is contained in:
parent
ec67ce60c9
commit
db9b62236c
|
@ -1,7 +1,7 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<project version="4">
|
<project version="4">
|
||||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" project-jdk-name="11" project-jdk-type="JavaSDK">
|
<component name="ProjectRootManager" version="2" languageLevel="JDK_1_8" project-jdk-name="1.8" project-jdk-type="JavaSDK">
|
||||||
<output url="file://$PROJECT_DIR$/out" />
|
<output url="file://$PROJECT_DIR$/out" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/refo.iml" filepath="$PROJECT_DIR$/.idea/refo.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
|
@ -0,0 +1,21 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module external.linked.project.id="refo" external.linked.project.path="$MODULE_DIR$" external.root.project.path="$MODULE_DIR$" external.system.id="GRADLE" external.system.module.group="net.lwenstrom.jan" external.system.module.version="1.0-SNAPSHOT" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<output url="file://$MODULE_DIR$/build/classes/java/main" />
|
||||||
|
<output-test url="file://$MODULE_DIR$/build/classes/java/test" />
|
||||||
|
<exclude-output />
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/.gradle" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/build" />
|
||||||
|
</content>
|
||||||
|
<content url="file://$MODULE_DIR$/build/generated/sources/annotationProcessor/java/main">
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/build/generated/sources/annotationProcessor/java/main" isTestSource="false" generated="true" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
<orderEntry type="library" scope="PROVIDED" name="Gradle: org.projectlombok:lombok:1.18.10" level="project" />
|
||||||
|
<orderEntry type="library" scope="TEST" name="Gradle: junit:junit:4.12" level="project" />
|
||||||
|
<orderEntry type="library" scope="TEST" name="Gradle: org.hamcrest:hamcrest-core:1.3" level="project" />
|
||||||
|
</component>
|
||||||
|
</module>
|
|
@ -28,13 +28,9 @@ public class DiscreteAction<A extends Enum> implements Action{
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public boolean equals(Object obj) {
|
public boolean equals(Object obj) {
|
||||||
if (this == obj)
|
if(obj instanceof DiscreteAction){
|
||||||
return true;
|
return getIndex() == ((DiscreteAction) obj).getIndex();
|
||||||
if (obj == null)
|
}
|
||||||
return false;
|
return super.equals(obj);
|
||||||
if (getClass() != obj.getClass())
|
|
||||||
return false;
|
|
||||||
|
|
||||||
return getIndex() == ((DiscreteAction) obj).getIndex();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,17 +4,24 @@ package evironment.antGame;
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
|
|
||||||
public class AntAgent {
|
public class AntAgent {
|
||||||
// the brain
|
// the learned representation of the environment (gridWorld)
|
||||||
private Cell[][] knownWorld;
|
private Cell[][] knownWorld;
|
||||||
private Point pos;
|
|
||||||
|
|
||||||
public AntAgent(int width, int height){
|
public AntAgent(int width, int height){
|
||||||
knownWorld = new Cell[width][height];
|
knownWorld = new Cell[width][height];
|
||||||
initUnknownWorld();
|
initUnknownWorld();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Learn from observation received after last action
|
||||||
|
* and generate a markov state.
|
||||||
|
*
|
||||||
|
* @param observation received input from the game host (environment)
|
||||||
|
* @return the current state of the agent
|
||||||
|
*/
|
||||||
public AntState feedObservation(AntObservation observation){
|
public AntState feedObservation(AntObservation observation){
|
||||||
|
knownWorld[observation.getPos().x][observation.getPos().y] = observation.getCell();
|
||||||
|
return new AntState(knownWorld, observation.getPos(), observation.hasFood());
|
||||||
}
|
}
|
||||||
|
|
||||||
private void initUnknownWorld(){
|
private void initUnknownWorld(){
|
||||||
|
@ -25,7 +32,7 @@ public class AntAgent {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public Point getPos(){
|
public Cell getCell(Point pos){
|
||||||
return pos;
|
return knownWorld[pos.x][pos.y];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,26 @@
|
||||||
package evironment.antGame;
|
package evironment.antGame;
|
||||||
|
|
||||||
import core.Observation;
|
import core.Observation;
|
||||||
|
import lombok.AccessLevel;
|
||||||
import lombok.AllArgsConstructor;
|
import lombok.AllArgsConstructor;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
|
|
||||||
|
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Getter
|
@Getter
|
||||||
@Setter
|
@Setter
|
||||||
public class AntObservation implements Observation {
|
public class AntObservation implements Observation {
|
||||||
private Cell cell;
|
private Cell cell;
|
||||||
|
private Point pos;
|
||||||
|
|
||||||
|
|
||||||
|
@Getter(AccessLevel.NONE)
|
||||||
|
private boolean hasFood;
|
||||||
|
|
||||||
|
public boolean hasFood(){
|
||||||
|
return hasFood;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,16 +2,74 @@ package evironment.antGame;
|
||||||
|
|
||||||
import core.State;
|
import core.State;
|
||||||
|
|
||||||
// somewhat the "brain" of the agent, current known setting of the environment
|
import java.awt.*;
|
||||||
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Markov state of the experienced environment.
|
||||||
|
* Essentially a snapshot of the current Ant Agent
|
||||||
|
* and therefore has to be deep copied
|
||||||
|
*/
|
||||||
public class AntState implements State {
|
public class AntState implements State {
|
||||||
private Grid knownGrid;
|
private Cell[][] knownWorld;
|
||||||
|
private Point pos;
|
||||||
|
private boolean hasFood;
|
||||||
|
|
||||||
public AntState(int width, int height){
|
|
||||||
knownGrid = new Grid(width, height);
|
|
||||||
|
|
||||||
|
public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
|
||||||
|
this.knownWorld = deepCopyCellGrid(knownWorld);
|
||||||
|
this.pos = deepCopyAntPosition(antPosition);
|
||||||
|
this.hasFood = hasFood;
|
||||||
}
|
}
|
||||||
|
|
||||||
public AntState(){
|
private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
|
||||||
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT);
|
Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
|
||||||
|
for (int i = 0; i < cells.length; i++) {
|
||||||
|
for (int j = 0; j < cells[i].length; j++) {
|
||||||
|
// calling copy constructor of Cell class
|
||||||
|
cells[i][j] = new Cell(toCopy[i][j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cells;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Point deepCopyAntPosition(Point toCopy){
|
||||||
|
return new Point(toCopy.x,toCopy.y);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString(){
|
||||||
|
return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.toString(knownWorld));
|
||||||
|
}
|
||||||
|
|
||||||
|
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers
|
||||||
|
@Override
|
||||||
|
public int hashCode(){
|
||||||
|
int hash = 7;
|
||||||
|
int prime = 31;
|
||||||
|
hash = prime * hash + Arrays.hashCode(knownWorld);
|
||||||
|
hash = prime * hash + (hasFood ? 1:0);
|
||||||
|
hash = prime * hash + pos.hashCode();
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj){
|
||||||
|
if(obj instanceof AntState){
|
||||||
|
AntState toCompare = (AntState) obj;
|
||||||
|
if(!this.pos.equals(toCompare.pos) || !this.hasFood == toCompare.hasFood){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (int i = 0; i < toCompare.knownWorld.length; i++) {
|
||||||
|
for (int j = 0; j < toCompare.knownWorld[i].length; j++) {
|
||||||
|
if(!this.knownWorld[i][j].equals(toCompare.knownWorld[i][j])){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return super.equals(obj);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import java.awt.*;
|
||||||
|
|
||||||
public class AntWorld {
|
public class AntWorld {
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
private Grid grid;
|
private Grid grid;
|
||||||
/**
|
/**
|
||||||
|
@ -29,9 +29,14 @@ public class AntWorld {
|
||||||
*/
|
*/
|
||||||
private AntAgent antAgent;
|
private AntAgent antAgent;
|
||||||
|
|
||||||
|
private int tick;
|
||||||
|
private int maxEpisodeTicks;
|
||||||
|
|
||||||
public AntWorld(int width, int height, double foodDensity){
|
public AntWorld(int width, int height, double foodDensity){
|
||||||
grid = new Grid(width, height, foodDensity);
|
grid = new Grid(width, height, foodDensity);
|
||||||
antAgent = new AntAgent(width, height);
|
antAgent = new AntAgent(width, height);
|
||||||
|
tick = 0;
|
||||||
|
maxEpisodeTicks = 1000;
|
||||||
}
|
}
|
||||||
|
|
||||||
public AntWorld(){
|
public AntWorld(){
|
||||||
|
@ -39,7 +44,7 @@ public class AntWorld {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class MyAnt{
|
private static class MyAnt{
|
||||||
int x,y;
|
Point pos;
|
||||||
boolean hasFood;
|
boolean hasFood;
|
||||||
boolean spawned;
|
boolean spawned;
|
||||||
}
|
}
|
||||||
|
@ -47,30 +52,125 @@ public class AntWorld {
|
||||||
public StepResult step(DiscreteAction<AntAction> action){
|
public StepResult step(DiscreteAction<AntAction> action){
|
||||||
AntObservation observation;
|
AntObservation observation;
|
||||||
State newState;
|
State newState;
|
||||||
|
double reward = 0;
|
||||||
|
String info = "";
|
||||||
|
boolean done = false;
|
||||||
|
|
||||||
if(!myAnt.spawned){
|
if(!myAnt.spawned){
|
||||||
observation = new AntObservation(grid.getCell(grid.getStartPoint()));
|
myAnt.spawned = true;
|
||||||
|
myAnt.pos = grid.getStartPoint();
|
||||||
|
|
||||||
|
observation = new AntObservation(grid.getCell(myAnt.pos), myAnt.pos, myAnt.hasFood);
|
||||||
newState = antAgent.feedObservation(observation);
|
newState = antAgent.feedObservation(observation);
|
||||||
return new StepResult(newState, 0.0, false, "Just spawned on the map");
|
reward = 0.0;
|
||||||
|
return new StepResult(newState, reward, false, "Just spawned on the map");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Cell currentCell = grid.getCell(myAnt.pos);
|
||||||
|
Point potentialNextPos = new Point(myAnt.pos.x, myAnt.pos.y);
|
||||||
|
boolean stayOnCell = true;
|
||||||
|
// flag to enable a check if all food has been collected only fired if food was dropped
|
||||||
|
// on the starting position
|
||||||
|
boolean checkCompletion = false;
|
||||||
|
|
||||||
switch (action.getValue()) {
|
switch (action.getValue()) {
|
||||||
case MOVE_UP:
|
case MOVE_UP:
|
||||||
|
potentialNextPos.y -= 1;
|
||||||
|
stayOnCell = false;
|
||||||
break;
|
break;
|
||||||
case MOVE_RIGHT:
|
case MOVE_RIGHT:
|
||||||
|
potentialNextPos.x += 1;
|
||||||
|
stayOnCell = false;
|
||||||
break;
|
break;
|
||||||
case MOVE_DOWN:
|
case MOVE_DOWN:
|
||||||
|
potentialNextPos.y += 1;
|
||||||
|
stayOnCell = false;
|
||||||
break;
|
break;
|
||||||
case MOVE_LEFT:
|
case MOVE_LEFT:
|
||||||
|
potentialNextPos.x -= 1;
|
||||||
|
stayOnCell = false;
|
||||||
break;
|
break;
|
||||||
case PICK_UP:
|
case PICK_UP:
|
||||||
|
if(myAnt.hasFood){
|
||||||
|
// Ant tries to pick up food but can only hold one piece
|
||||||
|
reward = Reward.FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY;
|
||||||
|
}else if(currentCell.getFood() == 0){
|
||||||
|
// Ant tries to pick up food on cell that has no food on it
|
||||||
|
reward = Reward.FOOD_PICK_UP_FAIL_NO_FOOD;
|
||||||
|
}else if(currentCell.getFood() > 0){
|
||||||
|
// Ant successfully picks up food
|
||||||
|
currentCell.setFood(currentCell.getFood() - 1);
|
||||||
|
myAnt.hasFood = true;
|
||||||
|
reward = Reward.FOOD_DROP_DOWN_SUCCESS;
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
case DROP_DOWN:
|
case DROP_DOWN:
|
||||||
|
if(!myAnt.hasFood){
|
||||||
|
// Ant had no food to drop
|
||||||
|
reward = Reward.FOOD_DROP_DOWN_FAIL_NO_FOOD;
|
||||||
|
}else{
|
||||||
|
// Drop food onto the ground
|
||||||
|
currentCell.setFood(currentCell.getFood() + 1);
|
||||||
|
myAnt.hasFood = false;
|
||||||
|
|
||||||
|
// negative reward if the agent drops food on any other field
|
||||||
|
// than the starting point
|
||||||
|
if(currentCell.getType() != CellType.START){
|
||||||
|
reward = Reward.FOOD_DROP_DOWN_FAIL_NOT_START;
|
||||||
|
}else{
|
||||||
|
reward = Reward.FOOD_DROP_DOWN_SUCCESS;
|
||||||
|
checkCompletion = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException(String.format("Action <%s> is not a valid action!", action.toString()));
|
throw new RuntimeException(String.format("Action <%s> is not a valid action!", action.toString()));
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// movement action was selected
|
||||||
|
if(!stayOnCell){
|
||||||
|
if(!isInGrid(potentialNextPos)){
|
||||||
|
stayOnCell = true;
|
||||||
|
reward = Reward.RAN_INTO_WALL;
|
||||||
|
}else if(hitObstacle(potentialNextPos)){
|
||||||
|
stayOnCell = true;
|
||||||
|
reward = Reward.RAN_INTO_OBSTACLE;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valid movement
|
||||||
|
if(!stayOnCell){
|
||||||
|
myAnt.pos = potentialNextPos;
|
||||||
|
if(antAgent.getCell(myAnt.pos).getType() == CellType.UNKNOWN){
|
||||||
|
// the ant will move to a cell that was previously unknown
|
||||||
|
reward = Reward.UNKNOWN_FIELD_EXPLORED;
|
||||||
|
}else{
|
||||||
|
reward = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get observation after action was computed
|
||||||
|
observation = new AntObservation(grid.getCell(myAnt.pos), myAnt.pos, myAnt.hasFood);
|
||||||
|
|
||||||
|
// let the ant agent process the observation to create a valid markov state
|
||||||
newState = antAgent.feedObservation(observation);
|
newState = antAgent.feedObservation(observation);
|
||||||
return new StepResult(newState, 0.0, false, "");
|
|
||||||
|
if(checkCompletion){
|
||||||
|
done = grid.isAllFoodCollected();
|
||||||
|
}
|
||||||
|
|
||||||
|
if(++tick == maxEpisodeTicks){
|
||||||
|
done = true;
|
||||||
|
}
|
||||||
|
return new StepResult(newState, reward, done, info);
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isInGrid(Point pos){
|
||||||
|
return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight();
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean hitObstacle(Point pos){
|
||||||
|
return grid.getCell(pos).getType() == CellType.OBSTACLE;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void reset() {
|
public void reset() {
|
||||||
|
@ -79,6 +179,9 @@ public class AntWorld {
|
||||||
myAnt = new MyAnt();
|
myAnt = new MyAnt();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setMaxEpisodeLength(int maxTicks){
|
||||||
|
this.maxEpisodeTicks = maxTicks;
|
||||||
|
}
|
||||||
public Point getSpawningPoint(){
|
public Point getSpawningPoint(){
|
||||||
return grid.getStartPoint();
|
return grid.getStartPoint();
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,8 +20,22 @@ public class Cell {
|
||||||
food = foodAmount;
|
food = foodAmount;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public Cell(Cell c){
|
||||||
|
this.pos = new Point(c.pos.x, c.pos.y);
|
||||||
|
this.food = c.getFood();
|
||||||
|
this.type = c.getType();
|
||||||
|
}
|
||||||
|
|
||||||
public Cell( Point pos, CellType cellType){
|
public Cell( Point pos, CellType cellType){
|
||||||
this(pos, cellType, 0);
|
this(pos, cellType, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean equals(Object obj){
|
||||||
|
if(obj instanceof Cell){
|
||||||
|
Cell cell = (Cell) obj;
|
||||||
|
return this.type == cell.getType() && this.food == cell.getFood() && this.pos.x == cell.getPos().x && this.pos.y ==cell.getPos().y;
|
||||||
|
}
|
||||||
|
return super.equals(obj);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,18 @@ public class Grid {
|
||||||
return start;
|
return start;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public boolean isAllFoodCollected(){
|
||||||
|
for(int x = 0; x < width; ++x){
|
||||||
|
for(int y = 0; y < height; ++y){
|
||||||
|
if(grid[x][y].getFood() > 0){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
public Cell[][] getGrid(){
|
public Cell[][] getGrid(){
|
||||||
return grid;
|
return grid;
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
package evironment.antGame;
|
||||||
|
|
||||||
|
/**
 * Reward values handed out by the ant world for the outcome of an action.
 *
 * <p>Pure constants holder; it cannot be instantiated or extended.
 */
public final class Reward {
    /** Food was picked up from a cell that held food. */
    public static final double FOOD_PICK_UP_SUCCESS = 1;
    /** Pick-up attempted on a cell without food. */
    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000;
    /** Pick-up attempted while already carrying food. */
    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000;

    /** Drop attempted without carrying food. */
    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000;
    /** Food was dropped on a cell that is not the start cell. */
    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000;
    /** Food was dropped on the start cell. */
    public static final double FOOD_DROP_DOWN_SUCCESS = 1000;

    /** The ant moved onto a cell that was unknown to the agent. */
    public static final double UNKNOWN_FIELD_EXPLORED = 1;

    /** The ant tried to leave the grid. */
    public static final double RAN_INTO_WALL = -100;
    /** The ant tried to move onto an obstacle cell. */
    public static final double RAN_INTO_OBSTACLE = -100;

    private Reward() {
        // constants only, no instances
    }
}
|
Loading…
Reference in New Issue