add Random-, Greedy and EGreedy-Policy and first implementation of monte carlo method

- fixed bug regarding wrong generation of hashCode. hashCodes needs to be equal across equal objects. Compute hashCode on final states once and return this value instead of computing it every time .hashCode() gets called.
-
This commit is contained in:
Jan Löwenstrom 2019-12-09 23:21:48 +01:00
parent 0100f2e82a
commit 55d8bbf5dc
18 changed files with 290 additions and 69 deletions

View File

@ -0,0 +1,4 @@
package core;
public interface Action {
}

View File

@ -1,5 +1,6 @@
package core;
public interface Environment<A extends Enum> {
StepResult step(A action);
StepResultEnvironment step(A action);
State reset();
}

View File

@ -55,11 +55,10 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable<A>
@Override
public Map<A, Double> getActionValues(State state) {
Map<A, Double> actionValues = table.get(state);
if(actionValues == null){
actionValues = createDefaultActionValues();
if(table.get(state) == null){
table.put(state, createDefaultActionValues());
}
return actionValues;
return table.get(state);
}
public static void main(String[] args) {

View File

@ -2,14 +2,11 @@ package core;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
@AllArgsConstructor
public class StepResult {
@Getter
public class StepResult<A extends Enum> {
private State state;
private A action;
private double reward;
private boolean done;
private String info;
}

View File

@ -0,0 +1,15 @@
package core;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
@AllArgsConstructor
public class StepResultEnvironment {
private State state;
private double reward;
private boolean done;
private String info;
}

View File

@ -0,0 +1,27 @@
package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.StateActionTable;
import core.policy.Policy;
public abstract class Learning<A extends Enum> {
protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace;
protected StateActionTable<A> stateActionTable;
protected Environment<A> environment;
protected float discountFactor;
protected float epsilon;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.epsilon = epsilon;
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
this(environment, actionSpace, 1.0f, 0.1f);
}
public abstract void learn(int nrOfEpisodes, int delay);
}

View File

@ -0,0 +1,75 @@
package core.algo.MC;
import core.*;
import core.algo.Learning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
import java.util.*;
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
super(environment, actionSpace);
discountFactor = 1f;
this.policy = new EpsilonGreedyPolicy<>(0.1f);
this.stateActionTable = new StateActionHashTable<>(actionSpace);
}
@Override
public void learn(int nrOfEpisodes, int delay) {
Map<Pair<State, A>, Double> returnSum = new HashMap<>();
Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
for(int i = 0; i < nrOfEpisodes; ++i) {
List<StepResult<A>> episode = new ArrayList<>();
State state = environment.reset();
for(int j=0; j < 100; ++j){
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
StepResultEnvironment envResult = environment.step(chosenAction);
State nextState = envResult.getState();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
if(envResult.isDone()) break;
state = nextState;
try {
Thread.sleep(10);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
Set<Pair<State, A>> stateActionPairs = new HashSet<>();
for(StepResult<A> sr: episode){
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
}
for(Pair<State, A> stateActionPair: stateActionPairs){
int firstOccurenceIndex = 0;
// find first occurance of state action pair
for(StepResult<A> sr: episode){
if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){
break;
}
firstOccurenceIndex++;
}
double G = 0;
for(int l = firstOccurenceIndex; l < episode.size(); ++l){
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
}
// slick trick to add G to the entry.
// if the key does not exists, it will create a new entry with G as default value
returnSum.merge(stateActionPair, G, Double::sum);
returnCount.merge(stateActionPair, 1, Integer::sum);
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
}
}
}
}

View File

@ -0,0 +1,4 @@
package core.algo.TD;
public class TemporalDifferenceOnPolicy {
}

View File

@ -0,0 +1,35 @@
package core.policy;
import core.RNG;
import java.util.Map;
/**
* To prevent the agent from getting stuck only using the "best" action
* according to the current learning history, this policy
* will take random action with the probability of epsilon.
* (random action space includes the best action as well)
*
* @param <A> Discrete Action Enum
*/
public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
private float epsilon;
private RandomPolicy<A> randomPolicy;
private GreedyPolicy<A> greedyPolicy;
public EpsilonGreedyPolicy(float epsilon){
this.epsilon = epsilon;
randomPolicy = new RandomPolicy<>();
greedyPolicy = new GreedyPolicy<>();
}
@Override
public A chooseAction(Map<A, Double> actionValues) {
if(RNG.getRandom().nextFloat() < epsilon){
// Take random action
return randomPolicy.chooseAction(actionValues);
}else{
// Take the action with the highest value
return greedyPolicy.chooseAction(actionValues);
}
}
}

View File

@ -0,0 +1,32 @@
package core.policy;
import core.RNG;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
public class GreedyPolicy<A extends Enum> implements Policy<A> {
@Override
public A chooseAction(Map<A, Double> actionValues) {
if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");
Double highestValueAction = null;
List<A> equalHigh = new ArrayList<>();
for(Map.Entry<A, Double> actionValue : actionValues.entrySet()){
System.out.println(actionValue.getKey()+ " " + actionValue.getValue() );
if(highestValueAction == null || highestValueAction < actionValue.getValue()){
highestValueAction = actionValue.getValue();
equalHigh.clear();
equalHigh.add(actionValue.getKey());
}else if(highestValueAction.equals(actionValue.getValue())){
equalHigh.add(actionValue.getKey());
}
}
return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
}
}

View File

@ -0,0 +1,7 @@
package core.policy;
import java.util.Map;
public interface Policy<A extends Enum> {
A chooseAction(Map<A, Double> actionValues);
}

View File

@ -0,0 +1,17 @@
package core.policy;
import core.RNG;
import java.util.Map;
public class RandomPolicy<A extends Enum> implements Policy<A>{
@Override
public A chooseAction(Map<A, Double> actionValues) {
int idx = RNG.getRandom().nextInt(actionValues.size());
int i = 0;
for(A action : actionValues.keySet()){
if(i++ == idx) return action;
}
return null;
}
}

View File

@ -7,7 +7,6 @@ import lombok.Setter;
import java.awt.*;
@AllArgsConstructor
@Getter
@Setter
public class Ant {
@ -15,10 +14,14 @@ public class Ant {
@Setter(AccessLevel.NONE)
private Point pos;
private int points;
private boolean spawned;
@Getter(AccessLevel.NONE)
private boolean hasFood;
public Ant(){
pos = new Point();
points = 0;
hasFood = false;
}
public boolean hasFood(){
return hasFood;
}

View File

@ -9,7 +9,6 @@ public class AntAgent {
public AntAgent(int width, int height){
knownWorld = new Cell[width][height];
initUnknownWorld();
}
/**
@ -24,7 +23,7 @@ public class AntAgent {
return new AntState(knownWorld, observation.getPos(), observation.hasFood());
}
private void initUnknownWorld(){
public void initUnknownWorld(){
for(int x = 0; x < knownWorld.length; ++x){
for(int y = 0; y < knownWorld[x].length; ++y){
knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN);

View File

@ -11,17 +11,39 @@ import java.util.Arrays;
* and therefor has to be deep copied
*/
public class AntState implements State {
private Cell[][] knownWorld;
private Point pos;
private boolean hasFood;
private final Cell[][] knownWorld;
private final Point pos;
private final boolean hasFood;
private final int computedHash;
public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
this.knownWorld = deepCopyCellGrid(knownWorld);
this.pos = deepCopyAntPosition(antPosition);
this.hasFood = hasFood;
computedHash = computeHash();
}
private int computeHash(){
int hash = 7;
int prime = 31;
int unknown = 0;
int diff = 0;
for (int i = 0; i < knownWorld.length; i++) {
for (int j = 0; j < knownWorld[i].length; j++) {
if(knownWorld[i][j].getType() == CellType.UNKNOWN){
unknown += 1;
}else{
diff +=1;
}
}
}
hash = prime * hash + unknown;
hash = prime * hash * diff;
hash = prime * hash + (hasFood ? 1:0);
hash = prime * hash + pos.hashCode();
return hash;
}
private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
for (int i = 0; i < cells.length; i++) {
@ -45,12 +67,7 @@ public class AntState implements State {
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers
@Override
public int hashCode(){
int hash = 7;
int prime = 31;
hash = prime * hash + Arrays.hashCode(knownWorld);
hash = prime * hash + (hasFood ? 1:0);
hash = prime * hash + pos.hashCode();
return hash;
return computedHash;
}
@Override

View File

@ -1,10 +1,11 @@
package evironment.antGame;
import core.*;
import core.algo.Learning;
import core.algo.MC.MonteCarloOnPolicyEGreedy;
import evironment.antGame.gui.MainFrame;
import javax.swing.*;
import java.awt.*;
public class AntWorld implements Environment<AntAction>{
@ -39,9 +40,8 @@ public class AntWorld implements Environment<AntAction>{
public AntWorld(int width, int height, double foodDensity){
grid = new Grid(width, height, foodDensity);
antAgent = new AntAgent(width, height);
myAnt = new Ant(new Point(-1,-1), 0, false, false);
myAnt = new Ant();
gui = new MainFrame(this, antAgent);
tick = 0;
maxEpisodeTicks = 1000;
reset();
}
@ -55,23 +55,13 @@ public class AntWorld implements Environment<AntAction>{
}
@Override
public StepResult step(AntAction action){
public StepResultEnvironment step(AntAction action){
AntObservation observation;
State newState;
double reward = 0;
String info = "";
boolean done = false;
if(!myAnt.isSpawned()){
myAnt.setSpawned(true);
myAnt.getPos().setLocation(grid.getStartPoint());
observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
newState = antAgent.feedObservation(observation);
reward = 0.0;
++tick;
return new StepResult(newState, reward, false, "Just spawned on the map");
}
Cell currentCell = grid.getCell(myAnt.getPos());
Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y);
boolean stayOnCell = true;
@ -107,7 +97,7 @@ public class AntWorld implements Environment<AntAction>{
// Ant successfully picks up food
currentCell.setFood(currentCell.getFood() - 1);
myAnt.setHasFood(true);
reward = Reward.FOOD_DROP_DOWN_SUCCESS;
reward = Reward.FOOD_PICK_UP_SUCCESS;
}
break;
case DROP_DOWN:
@ -169,24 +159,30 @@ public class AntWorld implements Environment<AntAction>{
if(++tick == maxEpisodeTicks){
done = true;
}
return new StepResult(newState, reward, done, info);
StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
getGui().update(action, result);
return result;
}
private boolean isInGrid(Point pos){
return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight();
return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight();
}
private boolean hitObstacle(Point pos){
return grid.getCell(pos).getType() == CellType.OBSTACLE;
}
public void reset() {
public State reset() {
RNG.reseed();
grid.initRandomWorld();
myAnt.getPos().setLocation(-1,-1);
antAgent.initUnknownWorld();
tick = 0;
myAnt.getPos().setLocation(grid.getStartPoint());
myAnt.setPoints(0);
myAnt.setHasFood(false);
myAnt.setSpawned(false);
AntObservation observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
return antAgent.feedObservation(observation);
}
public void setMaxEpisodeLength(int maxTicks){
@ -207,21 +203,14 @@ public class AntWorld implements Environment<AntAction>{
public Ant getAnt(){
return myAnt;
}
public static void main(String[] args) {
RNG.setSeed(1993);
AntWorld a = new AntWorld(10, 10, 0.1);
ListDiscreteActionSpace<AntAction> actionSpace =
new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
for(int i = 0; i< 1000; ++i){
AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
StepResult step = a.step(selectedAction);
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
try {
Thread.sleep(100);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
new AntWorld(3, 3, 0.1),
new ListDiscreteActionSpace<>(AntAction.values())
);
monteCarlo.learn(100,5);
}
}

View File

@ -1,17 +1,17 @@
package evironment.antGame;
public class Reward {
public static final double FOOD_PICK_UP_SUCCESS = 1;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000;
public static final double FOOD_PICK_UP_SUCCESS = 0;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000;
public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000;
public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
public static final double FOOD_DROP_DOWN_SUCCESS = 1000;
public static final double UNKNOWN_FIELD_EXPLORED = 1;
public static final double UNKNOWN_FIELD_EXPLORED = 0;
public static final double RAN_INTO_WALL = -100;
public static final double RAN_INTO_OBSTACLE = -100;
public static final double RAN_INTO_WALL = 0;
public static final double RAN_INTO_OBSTACLE = 0;
}

View File

@ -1,6 +1,6 @@
package evironment.antGame.gui;
import core.StepResult;
import core.StepResultEnvironment;
import evironment.antGame.AntAction;
import evironment.antGame.AntAgent;
import evironment.antGame.AntWorld;
@ -33,9 +33,9 @@ public class MainFrame extends JFrame {
setVisible(true);
}
public void update(AntAction lastAction, StepResult stepResult){
public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
antWorld.getTick(), lastAction.toString(), stepResult.getReward(), stepResult.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
repaint();
}