Add Random-, Greedy- and EpsilonGreedy-Policy and a first implementation of the Monte Carlo method
- Fixed a bug in hashCode generation: hash codes must be equal across equal objects. The hash is now computed once from the final state fields and that cached value is returned, instead of recomputing it on every .hashCode() call.
This commit is contained in:
parent
0100f2e82a
commit
55d8bbf5dc
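As a general illustration of the hashCode fix described in the message above (a sketch with made-up names, not code from this repository): an immutable value object can compute its hash once from the same final fields that equals() compares, then return the cached value.

// Illustrative sketch only (hypothetical class): an immutable key type whose hash
// is computed once from the same final fields that equals() compares.
public final class GridKey {
    private final int x;
    private final int y;
    private final int cachedHash;

    public GridKey(int x, int y) {
        this.x = x;
        this.y = y;
        this.cachedHash = 31 * x + y; // computed once, valid because x and y never change
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof GridKey)) return false;
        GridKey other = (GridKey) o;
        return x == other.x && y == other.y;
    }

    @Override
    public int hashCode() {
        return cachedHash; // equal objects always report the same value
    }
}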
@@ -0,0 +1,4 @@
+package core;
+
+public interface Action {
+}
@@ -1,5 +1,6 @@
 package core;

 public interface Environment<A extends Enum> {
-    StepResult step(A action);
+    StepResultEnvironment step(A action);
     State reset();
 }
@@ -55,11 +55,10 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable<A>

     @Override
     public Map<A, Double> getActionValues(State state) {
-        Map<A, Double> actionValues = table.get(state);
-        if(actionValues == null){
-            actionValues = createDefaultActionValues();
+        if(table.get(state) == null){
+            table.put(state, createDefaultActionValues());
         }
-        return actionValues;
+        return table.get(state);
     }

     public static void main(String[] args) {
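The rewritten getActionValues() above lazily inserts a default action-value map on first access. An equivalent, more compact form (a sketch, not part of this commit) uses Map.computeIfAbsent; table and createDefaultActionValues() are the members of the class shown above:

// Sketch: same lazy-initialization behaviour in a single call.
public Map<A, Double> getActionValues(State state) {
    return table.computeIfAbsent(state, s -> createDefaultActionValues());
}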
@@ -2,14 +2,11 @@ package core;

 import lombok.AllArgsConstructor;
 import lombok.Getter;
-import lombok.Setter;

-@Getter
-@Setter
 @AllArgsConstructor
-public class StepResult {
+@Getter
+public class StepResult<A extends Enum> {
     private State state;
+    private A action;
     private double reward;
-    private boolean done;
-    private String info;
 }
@@ -0,0 +1,15 @@
+package core;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
+
+@Getter
+@Setter
+@AllArgsConstructor
+public class StepResultEnvironment {
+    private State state;
+    private double reward;
+    private boolean done;
+    private String info;
+}
@@ -0,0 +1,27 @@
+package core.algo;
+
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.StateActionTable;
+import core.policy.Policy;
+
+public abstract class Learning<A extends Enum> {
+    protected Policy<A> policy;
+    protected DiscreteActionSpace<A> actionSpace;
+    protected StateActionTable<A> stateActionTable;
+    protected Environment<A> environment;
+    protected float discountFactor;
+    protected float epsilon;
+
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
+        this.environment = environment;
+        this.actionSpace = actionSpace;
+        this.discountFactor = discountFactor;
+        this.epsilon = epsilon;
+    }
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
+        this(environment, actionSpace, 1.0f, 0.1f);
+    }
+
+    public abstract void learn(int nrOfEpisodes, int delay);
+}
@@ -0,0 +1,75 @@
+package core.algo.MC;
+
+import core.*;
+import core.algo.Learning;
+import core.policy.EpsilonGreedyPolicy;
+import javafx.util.Pair;
+
+import java.util.*;
+
+public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
+
+    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
+        super(environment, actionSpace);
+        discountFactor = 1f;
+        this.policy = new EpsilonGreedyPolicy<>(0.1f);
+        this.stateActionTable = new StateActionHashTable<>(actionSpace);
+    }
+
+    @Override
+    public void learn(int nrOfEpisodes, int delay) {
+
+        Map<Pair<State, A>, Double> returnSum = new HashMap<>();
+        Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
+
+        for(int i = 0; i < nrOfEpisodes; ++i) {
+
+            List<StepResult<A>> episode = new ArrayList<>();
+            State state = environment.reset();
+            for(int j=0; j < 100; ++j){
+                Map<A, Double> actionValues = stateActionTable.getActionValues(state);
+                A chosenAction = policy.chooseAction(actionValues);
+                StepResultEnvironment envResult = environment.step(chosenAction);
+                State nextState = envResult.getState();
+                episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
+
+                if(envResult.isDone()) break;
+
+                state = nextState;
+
+                try {
+                    Thread.sleep(10);
+                } catch (InterruptedException e) {
+                    e.printStackTrace();
+                }
+            }
+
+            Set<Pair<State, A>> stateActionPairs = new HashSet<>();
+
+            for(StepResult<A> sr: episode){
+                stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
+            }
+
+            for(Pair<State, A> stateActionPair: stateActionPairs){
+                int firstOccurenceIndex = 0;
+                // find first occurance of state action pair
+                for(StepResult<A> sr: episode){
+                    if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){
+                        break;
+                    }
+                    firstOccurenceIndex++;
+                }
+
+                double G = 0;
+                for(int l = firstOccurenceIndex; l < episode.size(); ++l){
+                    G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
+                }
+                // slick trick to add G to the entry.
+                // if the key does not exists, it will create a new entry with G as default value
+                returnSum.merge(stateActionPair, G, Double::sum);
+                returnCount.merge(stateActionPair, 1, Integer::sum);
+                stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
+            }
+        }
+    }
+}
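The learn() method above implements first-visit Monte Carlo: for each (state, action) pair the return G = r_t + γ·r_{t+1} + γ²·r_{t+2} + … is accumulated from the pair's first occurrence in the episode, and the action value is the average of those returns across episodes. The same average can be maintained incrementally as Q <- Q + (G - Q) / N instead of keeping a separate sum map. A sketch under the same assumptions (it presumes getActionValues() always holds an entry for the chosen action, as createDefaultActionValues() suggests; not part of this commit), replacing the last three lines of the inner loop:

// Sketch: incremental average, replacing the returnSum bookkeeping above.
// Uses the same returnCount map and stateActionTable as in the class shown.
int n = returnCount.merge(stateActionPair, 1, Integer::sum);
double oldQ = stateActionTable.getActionValues(stateActionPair.getKey()).get(stateActionPair.getValue());
stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), oldQ + (G - oldQ) / n);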
@@ -0,0 +1,4 @@
+package core.algo.TD;
+
+public class TemporalDifferenceOnPolicy {
+}
@@ -0,0 +1,35 @@
+package core.policy;
+
+import core.RNG;
+
+import java.util.Map;
+
+/**
+ * To prevent the agent from getting stuck only using the "best" action
+ * according to the current learning history, this policy
+ * will take random action with the probability of epsilon.
+ * (random action space includes the best action as well)
+ *
+ * @param <A> Discrete Action Enum
+ */
+public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
+    private float epsilon;
+    private RandomPolicy<A> randomPolicy;
+    private GreedyPolicy<A> greedyPolicy;
+
+    public EpsilonGreedyPolicy(float epsilon){
+        this.epsilon = epsilon;
+        randomPolicy = new RandomPolicy<>();
+        greedyPolicy = new GreedyPolicy<>();
+    }
+    @Override
+    public A chooseAction(Map<A, Double> actionValues) {
+        if(RNG.getRandom().nextFloat() < epsilon){
+            // Take random action
+            return randomPolicy.chooseAction(actionValues);
+        }else{
+            // Take the action with the highest value
+            return greedyPolicy.chooseAction(actionValues);
+        }
+    }
+}
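A minimal usage sketch of the policy above (the Move enum and its values are invented for this example and are not part of the commit):

// Sketch: epsilon = 0.1 -> greedy choice with probability 0.9, uniform random otherwise.
enum Move { LEFT, RIGHT }

Map<Move, Double> values = new HashMap<>();
values.put(Move.LEFT, 0.2);
values.put(Move.RIGHT, 0.8);

Policy<Move> policy = new EpsilonGreedyPolicy<>(0.1f);
Move next = policy.chooseAction(values); // mostly RIGHT, occasionally LEFT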
@@ -0,0 +1,32 @@
+package core.policy;
+
+import core.RNG;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+public class GreedyPolicy<A extends Enum> implements Policy<A> {
+
+    @Override
+    public A chooseAction(Map<A, Double> actionValues) {
+        if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");
+
+        Double highestValueAction = null;
+
+        List<A> equalHigh = new ArrayList<>();
+
+        for(Map.Entry<A, Double> actionValue : actionValues.entrySet()){
+            System.out.println(actionValue.getKey()+ " " + actionValue.getValue() );
+            if(highestValueAction == null || highestValueAction < actionValue.getValue()){
+                highestValueAction = actionValue.getValue();
+                equalHigh.clear();
+                equalHigh.add(actionValue.getKey());
+            }else if(highestValueAction.equals(actionValue.getValue())){
+                equalHigh.add(actionValue.getKey());
+            }
+        }
+
+        return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
+    }
+}
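GreedyPolicy breaks ties uniformly at random, which matters at the start of training when every action still has its default value. Reusing the hypothetical Move enum from the sketch above (again, not part of the commit):

// Sketch: with equal values, every action is returned with equal probability.
Map<Move, Double> untrained = new HashMap<>();
untrained.put(Move.LEFT, 0.0);
untrained.put(Move.RIGHT, 0.0);
Move pick = new GreedyPolicy<Move>().chooseAction(untrained); // LEFT or RIGHT, 50/50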
@@ -0,0 +1,7 @@
+package core.policy;
+
+import java.util.Map;
+
+public interface Policy<A extends Enum> {
+    A chooseAction(Map<A, Double> actionValues);
+}
@@ -0,0 +1,17 @@
+package core.policy;
+
+import core.RNG;
+import java.util.Map;
+
+public class RandomPolicy<A extends Enum> implements Policy<A>{
+    @Override
+    public A chooseAction(Map<A, Double> actionValues) {
+        int idx = RNG.getRandom().nextInt(actionValues.size());
+        int i = 0;
+        for(A action : actionValues.keySet()){
+            if(i++ == idx) return action;
+        }
+
+        return null;
+    }
+}
@@ -7,7 +7,6 @@ import lombok.Setter;

 import java.awt.*;

-@AllArgsConstructor
 @Getter
 @Setter
 public class Ant {
@@ -15,10 +14,14 @@ public class Ant {
     @Setter(AccessLevel.NONE)
     private Point pos;
     private int points;
     private boolean spawned;
     @Getter(AccessLevel.NONE)
     private boolean hasFood;

+    public Ant(){
+        pos = new Point();
+        points = 0;
+        hasFood = false;
+    }
     public boolean hasFood(){
         return hasFood;
     }
@@ -9,7 +9,6 @@ public class AntAgent {

     public AntAgent(int width, int height){
         knownWorld = new Cell[width][height];
-        initUnknownWorld();
    }

    /**
|
|||
return new AntState(knownWorld, observation.getPos(), observation.hasFood());
|
||||
}
|
||||
|
||||
private void initUnknownWorld(){
|
||||
public void initUnknownWorld(){
|
||||
for(int x = 0; x < knownWorld.length; ++x){
|
||||
for(int y = 0; y < knownWorld[x].length; ++y){
|
||||
knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN);
|
||||
|
|
|
@@ -11,17 +11,39 @@ import java.util.Arrays;
  * and therefor has to be deep copied
  */
 public class AntState implements State {
-    private Cell[][] knownWorld;
-    private Point pos;
-    private boolean hasFood;
+    private final Cell[][] knownWorld;
+    private final Point pos;
+    private final boolean hasFood;
+    private final int computedHash;

     public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
         this.knownWorld = deepCopyCellGrid(knownWorld);
         this.pos = deepCopyAntPosition(antPosition);
         this.hasFood = hasFood;
+        computedHash = computeHash();
     }

+    private int computeHash(){
+        int hash = 7;
+        int prime = 31;
+
+        int unknown = 0;
+        int diff = 0;
+        for (int i = 0; i < knownWorld.length; i++) {
+            for (int j = 0; j < knownWorld[i].length; j++) {
+                if(knownWorld[i][j].getType() == CellType.UNKNOWN){
+                    unknown += 1;
+                }else{
+                    diff +=1;
+                }
+            }
+        }
+        hash = prime * hash + unknown;
+        hash = prime * hash * diff;
+        hash = prime * hash + (hasFood ? 1:0);
+        hash = prime * hash + pos.hashCode();
+        return hash;
+    }
     private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
         Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
         for (int i = 0; i < cells.length; i++) {
@@ -45,12 +67,7 @@ public class AntState implements State {
     //TODO: make this a utility function to generate hash Code based upon 2 prime numbers
     @Override
     public int hashCode(){
-        int hash = 7;
-        int prime = 31;
-        hash = prime * hash + Arrays.hashCode(knownWorld);
-        hash = prime * hash + (hasFood ? 1:0);
-        hash = prime * hash + pos.hashCode();
-        return hash;
+        return computedHash;
     }

     @Override
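Note that computeHash() above only counts how many cells are UNKNOWN versus explored, so two states that differ only in cell contents share a hash value (legal under the hashCode contract, but it degrades hash-table lookups). A content-sensitive variant is sketched below; it assumes Cell overrides hashCode()/equals() consistently, which this commit does not show, and is not part of the commit:

// Sketch only: cache a hash derived from the full grid contents.
// Assumes Cell implements hashCode()/equals(); Arrays.deepHashCode then hashes
// cell values rather than array object identities.
private int computeHash() {
    int hash = 7;
    final int prime = 31;
    hash = prime * hash + Arrays.deepHashCode(knownWorld);
    hash = prime * hash + (hasFood ? 1 : 0);
    hash = prime * hash + pos.hashCode();
    return hash;
}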
@@ -1,10 +1,11 @@
 package evironment.antGame;

 import core.*;
+import core.algo.Learning;
+import core.algo.MC.MonteCarloOnPolicyEGreedy;
 import evironment.antGame.gui.MainFrame;

-
 import javax.swing.*;
 import java.awt.*;

 public class AntWorld implements Environment<AntAction>{
@@ -39,9 +40,8 @@ public class AntWorld implements Environment<AntAction>{
     public AntWorld(int width, int height, double foodDensity){
         grid = new Grid(width, height, foodDensity);
         antAgent = new AntAgent(width, height);
-        myAnt = new Ant(new Point(-1,-1), 0, false, false);
+        myAnt = new Ant();
         gui = new MainFrame(this, antAgent);
         tick = 0;
         maxEpisodeTicks = 1000;
-        reset();
     }
@@ -55,23 +55,13 @@ public class AntWorld implements Environment<AntAction>{
     }

     @Override
-    public StepResult step(AntAction action){
+    public StepResultEnvironment step(AntAction action){
         AntObservation observation;
         State newState;
         double reward = 0;
         String info = "";
         boolean done = false;

-        if(!myAnt.isSpawned()){
-            myAnt.setSpawned(true);
-            myAnt.getPos().setLocation(grid.getStartPoint());
-            observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
-            newState = antAgent.feedObservation(observation);
-            reward = 0.0;
-            ++tick;
-            return new StepResult(newState, reward, false, "Just spawned on the map");
-        }
-
         Cell currentCell = grid.getCell(myAnt.getPos());
         Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y);
         boolean stayOnCell = true;
@@ -107,7 +97,7 @@ public class AntWorld implements Environment<AntAction>{
                     // Ant successfully picks up food
                     currentCell.setFood(currentCell.getFood() - 1);
                     myAnt.setHasFood(true);
-                    reward = Reward.FOOD_DROP_DOWN_SUCCESS;
+                    reward = Reward.FOOD_PICK_UP_SUCCESS;
                 }
                 break;
             case DROP_DOWN:
@@ -169,24 +159,30 @@ public class AntWorld implements Environment<AntAction>{
         if(++tick == maxEpisodeTicks){
             done = true;
         }
-        return new StepResult(newState, reward, done, info);
+
+        StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
+        getGui().update(action, result);
+        return result;
     }

     private boolean isInGrid(Point pos){
-        return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight();
+        return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight();
     }

     private boolean hitObstacle(Point pos){
         return grid.getCell(pos).getType() == CellType.OBSTACLE;
     }

-    public void reset() {
+    public State reset() {
         RNG.reseed();
         grid.initRandomWorld();
-        myAnt.getPos().setLocation(-1,-1);
+        antAgent.initUnknownWorld();
         tick = 0;
+        myAnt.getPos().setLocation(grid.getStartPoint());
         myAnt.setPoints(0);
         myAnt.setHasFood(false);
         myAnt.setSpawned(false);
+        AntObservation observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
+        return antAgent.feedObservation(observation);
     }

     public void setMaxEpisodeLength(int maxTicks){
@@ -207,21 +203,14 @@ public class AntWorld implements Environment<AntAction>{
     public Ant getAnt(){
         return myAnt;
     }

     public static void main(String[] args) {
-        RNG.setSeed(1993);
-        AntWorld a = new AntWorld(10, 10, 0.1);
-        ListDiscreteActionSpace<AntAction> actionSpace =
-                new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
-
-        for(int i = 0; i< 1000; ++i){
-            AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
-            StepResult step = a.step(selectedAction);
-            SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
-            try {
-                Thread.sleep(100);
-            } catch (InterruptedException e) {
-                e.printStackTrace();
-            }
-        }
+        Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
+                new AntWorld(3, 3, 0.1),
+                new ListDiscreteActionSpace<>(AntAction.values())
+        );
+        monteCarlo.learn(100,5);
     }
 }
@@ -1,17 +1,17 @@
 package evironment.antGame;

 public class Reward {
-    public static final double FOOD_PICK_UP_SUCCESS = 1;
-    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000;
-    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000;
+    public static final double FOOD_PICK_UP_SUCCESS = 0;
+    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
+    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;

-    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000;
-    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000;
+    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
+    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
     public static final double FOOD_DROP_DOWN_SUCCESS = 1000;

-    public static final double UNKNOWN_FIELD_EXPLORED = 1;
+    public static final double UNKNOWN_FIELD_EXPLORED = 0;

-    public static final double RAN_INTO_WALL = -100;
-    public static final double RAN_INTO_OBSTACLE = -100;
+    public static final double RAN_INTO_WALL = 0;
+    public static final double RAN_INTO_OBSTACLE = 0;

 }
@@ -1,6 +1,6 @@
 package evironment.antGame.gui;

-import core.StepResult;
+import core.StepResultEnvironment;
 import evironment.antGame.AntAction;
 import evironment.antGame.AntAgent;
 import evironment.antGame.AntWorld;
@@ -33,9 +33,9 @@ public class MainFrame extends JFrame {
         setVisible(true);
     }

-    public void update(AntAction lastAction, StepResult stepResult){
+    public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
         historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
-                antWorld.getTick(), lastAction.toString(), stepResult.getReward(), stepResult.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
+                antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));

         repaint();
     }