add Random-, Greedy and EGreedy-Policy and first implementation of monte carlo method

- fixed bug regarding wrong generation of hashCode. hashCodes needs to be equal across equal objects. Compute hashCode on final states once and return this value instead of computing it every time .hashCode() gets called.
-
This commit is contained in:
Jan Löwenstrom 2019-12-09 23:21:48 +01:00
parent 0100f2e82a
commit 55d8bbf5dc
18 changed files with 290 additions and 69 deletions

View File

@ -0,0 +1,4 @@
package core;
public interface Action {
}

View File

@ -1,5 +1,6 @@
package core; package core;
public interface Environment<A extends Enum> { public interface Environment<A extends Enum> {
StepResult step(A action); StepResultEnvironment step(A action);
State reset();
} }

View File

@ -55,11 +55,10 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable<A>
@Override @Override
public Map<A, Double> getActionValues(State state) { public Map<A, Double> getActionValues(State state) {
Map<A, Double> actionValues = table.get(state); if(table.get(state) == null){
if(actionValues == null){ table.put(state, createDefaultActionValues());
actionValues = createDefaultActionValues();
} }
return actionValues; return table.get(state);
} }
public static void main(String[] args) { public static void main(String[] args) {

View File

@ -2,14 +2,11 @@ package core;
import lombok.AllArgsConstructor; import lombok.AllArgsConstructor;
import lombok.Getter; import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
@AllArgsConstructor @AllArgsConstructor
public class StepResult { @Getter
public class StepResult<A extends Enum> {
private State state; private State state;
private A action;
private double reward; private double reward;
private boolean done;
private String info;
} }

View File

@ -0,0 +1,15 @@
package core;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
@AllArgsConstructor
public class StepResultEnvironment {
private State state;
private double reward;
private boolean done;
private String info;
}

View File

@ -0,0 +1,27 @@
package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.StateActionTable;
import core.policy.Policy;
public abstract class Learning<A extends Enum> {
protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace;
protected StateActionTable<A> stateActionTable;
protected Environment<A> environment;
protected float discountFactor;
protected float epsilon;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.epsilon = epsilon;
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
this(environment, actionSpace, 1.0f, 0.1f);
}
public abstract void learn(int nrOfEpisodes, int delay);
}

View File

@ -0,0 +1,75 @@
package core.algo.MC;
import core.*;
import core.algo.Learning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
import java.util.*;
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
super(environment, actionSpace);
discountFactor = 1f;
this.policy = new EpsilonGreedyPolicy<>(0.1f);
this.stateActionTable = new StateActionHashTable<>(actionSpace);
}
@Override
public void learn(int nrOfEpisodes, int delay) {
Map<Pair<State, A>, Double> returnSum = new HashMap<>();
Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
for(int i = 0; i < nrOfEpisodes; ++i) {
List<StepResult<A>> episode = new ArrayList<>();
State state = environment.reset();
for(int j=0; j < 100; ++j){
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
StepResultEnvironment envResult = environment.step(chosenAction);
State nextState = envResult.getState();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
if(envResult.isDone()) break;
state = nextState;
try {
Thread.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
Set<Pair<State, A>> stateActionPairs = new HashSet<>();
for(StepResult<A> sr: episode){
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
}
for(Pair<State, A> stateActionPair: stateActionPairs){
int firstOccurenceIndex = 0;
// find first occurance of state action pair
for(StepResult<A> sr: episode){
if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){
break;
}
firstOccurenceIndex++;
}
double G = 0;
for(int l = firstOccurenceIndex; l < episode.size(); ++l){
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
}
// slick trick to add G to the entry.
// if the key does not exists, it will create a new entry with G as default value
returnSum.merge(stateActionPair, G, Double::sum);
returnCount.merge(stateActionPair, 1, Integer::sum);
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
}
}
}
}

View File

@ -0,0 +1,4 @@
package core.algo.TD;
public class TemporalDifferenceOnPolicy {
}

View File

@ -0,0 +1,35 @@
package core.policy;
import core.RNG;
import java.util.Map;
/**
* To prevent the agent from getting stuck only using the "best" action
* according to the current learning history, this policy
* will take random action with the probability of epsilon.
* (random action space includes the best action as well)
*
* @param <A> Discrete Action Enum
*/
public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
private float epsilon;
private RandomPolicy<A> randomPolicy;
private GreedyPolicy<A> greedyPolicy;
public EpsilonGreedyPolicy(float epsilon){
this.epsilon = epsilon;
randomPolicy = new RandomPolicy<>();
greedyPolicy = new GreedyPolicy<>();
}
@Override
public A chooseAction(Map<A, Double> actionValues) {
if(RNG.getRandom().nextFloat() < epsilon){
// Take random action
return randomPolicy.chooseAction(actionValues);
}else{
// Take the action with the highest value
return greedyPolicy.chooseAction(actionValues);
}
}
}

View File

@ -0,0 +1,32 @@
package core.policy;
import core.RNG;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class GreedyPolicy<A extends Enum> implements Policy<A> {
@Override
public A chooseAction(Map<A, Double> actionValues) {
if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");
Double highestValueAction = null;
List<A> equalHigh = new ArrayList<>();
for(Map.Entry<A, Double> actionValue : actionValues.entrySet()){
System.out.println(actionValue.getKey()+ " " + actionValue.getValue() );
if(highestValueAction == null || highestValueAction < actionValue.getValue()){
highestValueAction = actionValue.getValue();
equalHigh.clear();
equalHigh.add(actionValue.getKey());
}else if(highestValueAction.equals(actionValue.getValue())){
equalHigh.add(actionValue.getKey());
}
}
return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
}
}

View File

@ -0,0 +1,7 @@
package core.policy;
import java.util.Map;
public interface Policy<A extends Enum> {
A chooseAction(Map<A, Double> actionValues);
}

View File

@ -0,0 +1,17 @@
package core.policy;
import core.RNG;
import java.util.Map;
public class RandomPolicy<A extends Enum> implements Policy<A>{
@Override
public A chooseAction(Map<A, Double> actionValues) {
int idx = RNG.getRandom().nextInt(actionValues.size());
int i = 0;
for(A action : actionValues.keySet()){
if(i++ == idx) return action;
}
return null;
}
}

View File

@ -7,7 +7,6 @@ import lombok.Setter;
import java.awt.*; import java.awt.*;
@AllArgsConstructor
@Getter @Getter
@Setter @Setter
public class Ant { public class Ant {
@ -15,10 +14,14 @@ public class Ant {
@Setter(AccessLevel.NONE) @Setter(AccessLevel.NONE)
private Point pos; private Point pos;
private int points; private int points;
private boolean spawned;
@Getter(AccessLevel.NONE) @Getter(AccessLevel.NONE)
private boolean hasFood; private boolean hasFood;
public Ant(){
pos = new Point();
points = 0;
hasFood = false;
}
public boolean hasFood(){ public boolean hasFood(){
return hasFood; return hasFood;
} }

View File

@ -9,7 +9,6 @@ public class AntAgent {
public AntAgent(int width, int height){ public AntAgent(int width, int height){
knownWorld = new Cell[width][height]; knownWorld = new Cell[width][height];
initUnknownWorld();
} }
/** /**
@ -24,7 +23,7 @@ public class AntAgent {
return new AntState(knownWorld, observation.getPos(), observation.hasFood()); return new AntState(knownWorld, observation.getPos(), observation.hasFood());
} }
private void initUnknownWorld(){ public void initUnknownWorld(){
for(int x = 0; x < knownWorld.length; ++x){ for(int x = 0; x < knownWorld.length; ++x){
for(int y = 0; y < knownWorld[x].length; ++y){ for(int y = 0; y < knownWorld[x].length; ++y){
knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN); knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN);

View File

@ -11,17 +11,39 @@ import java.util.Arrays;
* and therefor has to be deep copied * and therefor has to be deep copied
*/ */
public class AntState implements State { public class AntState implements State {
private Cell[][] knownWorld; private final Cell[][] knownWorld;
private Point pos; private final Point pos;
private boolean hasFood; private final boolean hasFood;
private final int computedHash;
public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){ public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
this.knownWorld = deepCopyCellGrid(knownWorld); this.knownWorld = deepCopyCellGrid(knownWorld);
this.pos = deepCopyAntPosition(antPosition); this.pos = deepCopyAntPosition(antPosition);
this.hasFood = hasFood; this.hasFood = hasFood;
computedHash = computeHash();
} }
private int computeHash(){
int hash = 7;
int prime = 31;
int unknown = 0;
int diff = 0;
for (int i = 0; i < knownWorld.length; i++) {
for (int j = 0; j < knownWorld[i].length; j++) {
if(knownWorld[i][j].getType() == CellType.UNKNOWN){
unknown += 1;
}else{
diff +=1;
}
}
}
hash = prime * hash + unknown;
hash = prime * hash * diff;
hash = prime * hash + (hasFood ? 1:0);
hash = prime * hash + pos.hashCode();
return hash;
}
private Cell[][] deepCopyCellGrid(Cell[][] toCopy){ private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
Cell[][] cells = new Cell[toCopy.length][toCopy[0].length]; Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
for (int i = 0; i < cells.length; i++) { for (int i = 0; i < cells.length; i++) {
@ -45,12 +67,7 @@ public class AntState implements State {
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers //TODO: make this a utility function to generate hash Code based upon 2 prime numbers
@Override @Override
public int hashCode(){ public int hashCode(){
int hash = 7; return computedHash;
int prime = 31;
hash = prime * hash + Arrays.hashCode(knownWorld);
hash = prime * hash + (hasFood ? 1:0);
hash = prime * hash + pos.hashCode();
return hash;
} }
@Override @Override

View File

@ -1,10 +1,11 @@
package evironment.antGame; package evironment.antGame;
import core.*; import core.*;
import core.algo.Learning;
import core.algo.MC.MonteCarloOnPolicyEGreedy;
import evironment.antGame.gui.MainFrame; import evironment.antGame.gui.MainFrame;
import javax.swing.*;
import java.awt.*; import java.awt.*;
public class AntWorld implements Environment<AntAction>{ public class AntWorld implements Environment<AntAction>{
@ -39,9 +40,8 @@ public class AntWorld implements Environment<AntAction>{
public AntWorld(int width, int height, double foodDensity){ public AntWorld(int width, int height, double foodDensity){
grid = new Grid(width, height, foodDensity); grid = new Grid(width, height, foodDensity);
antAgent = new AntAgent(width, height); antAgent = new AntAgent(width, height);
myAnt = new Ant(new Point(-1,-1), 0, false, false); myAnt = new Ant();
gui = new MainFrame(this, antAgent); gui = new MainFrame(this, antAgent);
tick = 0;
maxEpisodeTicks = 1000; maxEpisodeTicks = 1000;
reset(); reset();
} }
@ -55,23 +55,13 @@ public class AntWorld implements Environment<AntAction>{
} }
@Override @Override
public StepResult step(AntAction action){ public StepResultEnvironment step(AntAction action){
AntObservation observation; AntObservation observation;
State newState; State newState;
double reward = 0; double reward = 0;
String info = ""; String info = "";
boolean done = false; boolean done = false;
if(!myAnt.isSpawned()){
myAnt.setSpawned(true);
myAnt.getPos().setLocation(grid.getStartPoint());
observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
newState = antAgent.feedObservation(observation);
reward = 0.0;
++tick;
return new StepResult(newState, reward, false, "Just spawned on the map");
}
Cell currentCell = grid.getCell(myAnt.getPos()); Cell currentCell = grid.getCell(myAnt.getPos());
Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y); Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y);
boolean stayOnCell = true; boolean stayOnCell = true;
@ -107,7 +97,7 @@ public class AntWorld implements Environment<AntAction>{
// Ant successfully picks up food // Ant successfully picks up food
currentCell.setFood(currentCell.getFood() - 1); currentCell.setFood(currentCell.getFood() - 1);
myAnt.setHasFood(true); myAnt.setHasFood(true);
reward = Reward.FOOD_DROP_DOWN_SUCCESS; reward = Reward.FOOD_PICK_UP_SUCCESS;
} }
break; break;
case DROP_DOWN: case DROP_DOWN:
@ -169,24 +159,30 @@ public class AntWorld implements Environment<AntAction>{
if(++tick == maxEpisodeTicks){ if(++tick == maxEpisodeTicks){
done = true; done = true;
} }
return new StepResult(newState, reward, done, info);
StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
getGui().update(action, result);
return result;
} }
private boolean isInGrid(Point pos){ private boolean isInGrid(Point pos){
return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight(); return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight();
} }
private boolean hitObstacle(Point pos){ private boolean hitObstacle(Point pos){
return grid.getCell(pos).getType() == CellType.OBSTACLE; return grid.getCell(pos).getType() == CellType.OBSTACLE;
} }
public void reset() { public State reset() {
RNG.reseed(); RNG.reseed();
grid.initRandomWorld(); grid.initRandomWorld();
myAnt.getPos().setLocation(-1,-1); antAgent.initUnknownWorld();
tick = 0;
myAnt.getPos().setLocation(grid.getStartPoint());
myAnt.setPoints(0); myAnt.setPoints(0);
myAnt.setHasFood(false); myAnt.setHasFood(false);
myAnt.setSpawned(false); AntObservation observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
return antAgent.feedObservation(observation);
} }
public void setMaxEpisodeLength(int maxTicks){ public void setMaxEpisodeLength(int maxTicks){
@ -207,21 +203,14 @@ public class AntWorld implements Environment<AntAction>{
public Ant getAnt(){ public Ant getAnt(){
return myAnt; return myAnt;
} }
public static void main(String[] args) { public static void main(String[] args) {
RNG.setSeed(1993); RNG.setSeed(1993);
AntWorld a = new AntWorld(10, 10, 0.1);
ListDiscreteActionSpace<AntAction> actionSpace =
new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
for(int i = 0; i< 1000; ++i){ Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction())); new AntWorld(3, 3, 0.1),
StepResult step = a.step(selectedAction); new ListDiscreteActionSpace<>(AntAction.values())
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step)); );
try { monteCarlo.learn(100,5);
Thread.sleep(100);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
} }
} }

View File

@ -1,17 +1,17 @@
package evironment.antGame; package evironment.antGame;
public class Reward { public class Reward {
public static final double FOOD_PICK_UP_SUCCESS = 1; public static final double FOOD_PICK_UP_SUCCESS = 0;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000; public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000; public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000; public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000; public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
public static final double FOOD_DROP_DOWN_SUCCESS = 1000; public static final double FOOD_DROP_DOWN_SUCCESS = 1000;
public static final double UNKNOWN_FIELD_EXPLORED = 1; public static final double UNKNOWN_FIELD_EXPLORED = 0;
public static final double RAN_INTO_WALL = -100; public static final double RAN_INTO_WALL = 0;
public static final double RAN_INTO_OBSTACLE = -100; public static final double RAN_INTO_OBSTACLE = 0;
} }

View File

@ -1,6 +1,6 @@
package evironment.antGame.gui; package evironment.antGame.gui;
import core.StepResult; import core.StepResultEnvironment;
import evironment.antGame.AntAction; import evironment.antGame.AntAction;
import evironment.antGame.AntAgent; import evironment.antGame.AntAgent;
import evironment.antGame.AntWorld; import evironment.antGame.AntWorld;
@ -33,9 +33,9 @@ public class MainFrame extends JFrame {
setVisible(true); setVisible(true);
} }
public void update(AntAction lastAction, StepResult stepResult){ public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ", historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
antWorld.getTick(), lastAction.toString(), stepResult.getReward(), stepResult.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood())); antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
repaint(); repaint();
} }