Add Random, Greedy, and EpsilonGreedy policies and a first implementation of the Monte Carlo method
- Fixed a bug in hashCode generation: hash codes need to be equal across equal objects. The hash code of a final (immutable) state is now computed once and cached, and that value is returned instead of being recomputed every time .hashCode() gets called.
This commit is contained in:
parent 0100f2e82a
commit 55d8bbf5dc
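
The hashCode fix described in the commit message follows a simple caching pattern: because all fields of the state are final, the hash can be computed once in the constructor and the cached value returned on every call, so equal objects always report the same hash. A minimal standalone sketch of that pattern (hypothetical CachedState class, not part of this commit; the actual change is in AntState below):

import java.util.Objects;

// Hypothetical example of the caching pattern: all fields are final,
// so the hash is computed once and never changes afterwards.
public final class CachedState {
    private final int x;
    private final int y;
    private final int computedHash; // cached at construction time

    public CachedState(int x, int y) {
        this.x = x;
        this.y = y;
        this.computedHash = Objects.hash(x, y);
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof CachedState)) return false;
        CachedState other = (CachedState) o;
        return x == other.x && y == other.y;
    }

    @Override
    public int hashCode() {
        return computedHash; // equal objects yield equal, stable hash codes
    }
}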

core/Action.java (new file)
@@ -0,0 +1,4 @@
+package core;
+
+public interface Action {
+}

core/Environment.java
@@ -1,5 +1,6 @@
 package core;
 
 public interface Environment<A extends Enum> {
-    StepResult step(A action);
+    StepResultEnvironment step(A action);
+    State reset();
 }

core/StateActionHashTable.java
@@ -55,11 +55,10 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable<A>
 
     @Override
     public Map<A, Double> getActionValues(State state) {
-        Map<A, Double> actionValues = table.get(state);
-        if(actionValues == null){
-            actionValues = createDefaultActionValues();
+        if(table.get(state) == null){
+            table.put(state, createDefaultActionValues());
         }
-        return actionValues;
+        return table.get(state);
     }
 
     public static void main(String[] args) {

core/StepResult.java
@@ -2,14 +2,11 @@ package core;
 
 import lombok.AllArgsConstructor;
 import lombok.Getter;
-import lombok.Setter;
 
-@Getter
-@Setter
 @AllArgsConstructor
-public class StepResult {
+@Getter
+public class StepResult<A extends Enum> {
     private State state;
+    private A action;
     private double reward;
-    private boolean done;
-    private String info;
 }

core/StepResultEnvironment.java (new file)
@@ -0,0 +1,15 @@
+package core;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
+
+@Getter
+@Setter
+@AllArgsConstructor
+public class StepResultEnvironment {
+    private State state;
+    private double reward;
+    private boolean done;
+    private String info;
+}

core/algo/Learning.java (new file)
@@ -0,0 +1,27 @@
+package core.algo;
+
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.StateActionTable;
+import core.policy.Policy;
+
+public abstract class Learning<A extends Enum> {
+    protected Policy<A> policy;
+    protected DiscreteActionSpace<A> actionSpace;
+    protected StateActionTable<A> stateActionTable;
+    protected Environment<A> environment;
+    protected float discountFactor;
+    protected float epsilon;
+
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
+        this.environment = environment;
+        this.actionSpace = actionSpace;
+        this.discountFactor = discountFactor;
+        this.epsilon = epsilon;
+    }
+    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
+        this(environment, actionSpace, 1.0f, 0.1f);
+    }
+
+    public abstract void learn(int nrOfEpisodes, int delay);
+}

core/algo/MC/MonteCarloOnPolicyEGreedy.java (new file)
@@ -0,0 +1,75 @@
+package core.algo.MC;
+
+import core.*;
+import core.algo.Learning;
+import core.policy.EpsilonGreedyPolicy;
+import javafx.util.Pair;
+
+import java.util.*;
+
+public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
+
+    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
+        super(environment, actionSpace);
+        discountFactor = 1f;
+        this.policy = new EpsilonGreedyPolicy<>(0.1f);
+        this.stateActionTable = new StateActionHashTable<>(actionSpace);
+    }
+
+    @Override
+    public void learn(int nrOfEpisodes, int delay) {
+
+        Map<Pair<State, A>, Double> returnSum = new HashMap<>();
+        Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
+
+        for(int i = 0; i < nrOfEpisodes; ++i) {
+
+            List<StepResult<A>> episode = new ArrayList<>();
+            State state = environment.reset();
+            for(int j = 0; j < 100; ++j){
+                Map<A, Double> actionValues = stateActionTable.getActionValues(state);
+                A chosenAction = policy.chooseAction(actionValues);
+                StepResultEnvironment envResult = environment.step(chosenAction);
+                State nextState = envResult.getState();
+                episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
+
+                if(envResult.isDone()) break;
+
+                state = nextState;
+
+                try {
+                    Thread.sleep(10);
+                } catch (InterruptedException e) {
+                    e.printStackTrace();
+                }
+            }
+
+            Set<Pair<State, A>> stateActionPairs = new HashSet<>();
+
+            for(StepResult<A> sr: episode){
+                stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
+            }
+
+            for(Pair<State, A> stateActionPair: stateActionPairs){
+                int firstOccurenceIndex = 0;
+                // find the first occurrence of the state-action pair in the episode
+                for(StepResult<A> sr: episode){
+                    if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())){
+                        break;
+                    }
+                    firstOccurenceIndex++;
+                }
+
+                double G = 0;
+                for(int l = firstOccurenceIndex; l < episode.size(); ++l){
+                    G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
+                }
+                // Map.merge adds G to the existing entry;
+                // if the key does not exist, it creates a new entry with G as the initial value.
+                returnSum.merge(stateActionPair, G, Double::sum);
+                returnCount.merge(stateActionPair, 1, Integer::sum);
+                stateActionTable.setValue(stateActionPair.getKey(), stateActionPair.getValue(), returnSum.get(stateActionPair) / returnCount.get(stateActionPair));
+            }
+        }
+    }
+}
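
The learn() method above is a first-visit Monte Carlo control loop: an episode is generated with the epsilon-greedy policy, and for every state-action pair the discounted return G following its first occurrence is accumulated in returnSum/returnCount, so Q(s,a) becomes the average of the observed first-visit returns. A minimal sketch of just that averaging step, detached from the classes in this commit (hypothetical helper, same idea as the merge calls above):

import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: keeps the running average of first-visit returns per key.
public class ReturnAverager<K> {
    private final Map<K, Double> returnSum = new HashMap<>();
    private final Map<K, Integer> returnCount = new HashMap<>();

    // Record one observed return G for the given state-action key and
    // return the updated Monte Carlo estimate (average of all returns so far).
    public double add(K key, double G) {
        returnSum.merge(key, G, Double::sum);
        returnCount.merge(key, 1, Integer::sum);
        return returnSum.get(key) / returnCount.get(key);
    }
}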

core/algo/TD/TemporalDifferenceOnPolicy.java (new file)
@@ -0,0 +1,4 @@
+package core.algo.TD;
+
+public class TemporalDifferenceOnPolicy {
+}

core/policy/EpsilonGreedyPolicy.java (new file)
@@ -0,0 +1,35 @@
+package core.policy;
+
+import core.RNG;
+
+import java.util.Map;
+
+/**
+ * To prevent the agent from getting stuck using only the "best" action
+ * according to the current learning history, this policy
+ * takes a random action with probability epsilon.
+ * (The random action space includes the best action as well.)
+ *
+ * @param <A> Discrete Action Enum
+ */
+public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
+    private float epsilon;
+    private RandomPolicy<A> randomPolicy;
+    private GreedyPolicy<A> greedyPolicy;
+
+    public EpsilonGreedyPolicy(float epsilon){
+        this.epsilon = epsilon;
+        randomPolicy = new RandomPolicy<>();
+        greedyPolicy = new GreedyPolicy<>();
+    }
+    @Override
+    public A chooseAction(Map<A, Double> actionValues) {
+        if(RNG.getRandom().nextFloat() < epsilon){
+            // Take a random action
+            return randomPolicy.chooseAction(actionValues);
+        }else{
+            // Take the action with the highest value
+            return greedyPolicy.chooseAction(actionValues);
+        }
+    }
+}
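
A small usage sketch for the policy above, assuming a hypothetical two-action enum (LEFT/RIGHT) that is not part of this commit: with epsilon = 0.1, roughly 90% of calls return the greedy action and the rest fall back to a uniformly random pick.

import core.policy.EpsilonGreedyPolicy;

import java.util.HashMap;
import java.util.Map;

public class EpsilonGreedyDemo {
    // Hypothetical action enum for illustration only.
    enum Move { LEFT, RIGHT }

    public static void main(String[] args) {
        Map<Move, Double> actionValues = new HashMap<>();
        actionValues.put(Move.LEFT, 0.2);
        actionValues.put(Move.RIGHT, 0.8);

        EpsilonGreedyPolicy<Move> policy = new EpsilonGreedyPolicy<>(0.1f);
        // Mostly prints RIGHT (the highest-valued action), occasionally a random choice.
        System.out.println(policy.chooseAction(actionValues));
    }
}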

core/policy/GreedyPolicy.java (new file)
@@ -0,0 +1,32 @@
+package core.policy;
+
+import core.RNG;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+public class GreedyPolicy<A extends Enum> implements Policy<A> {
+
+    @Override
+    public A chooseAction(Map<A, Double> actionValues) {
+        if(actionValues.size() == 0) throw new RuntimeException("Empty actionValues set");
+
+        Double highestValueAction = null;
+
+        List<A> equalHigh = new ArrayList<>();
+
+        for(Map.Entry<A, Double> actionValue : actionValues.entrySet()){
+            System.out.println(actionValue.getKey() + " " + actionValue.getValue());
+            if(highestValueAction == null || highestValueAction < actionValue.getValue()){
+                highestValueAction = actionValue.getValue();
+                equalHigh.clear();
+                equalHigh.add(actionValue.getKey());
+            }else if(highestValueAction.equals(actionValue.getValue())){
+                equalHigh.add(actionValue.getKey());
+            }
+        }
+
+        return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
+    }
+}

core/policy/Policy.java (new file)
@@ -0,0 +1,7 @@
+package core.policy;
+
+import java.util.Map;
+
+public interface Policy<A extends Enum> {
+    A chooseAction(Map<A, Double> actionValues);
+}

core/policy/RandomPolicy.java (new file)
@@ -0,0 +1,17 @@
+package core.policy;
+
+import core.RNG;
+import java.util.Map;
+
+public class RandomPolicy<A extends Enum> implements Policy<A>{
+    @Override
+    public A chooseAction(Map<A, Double> actionValues) {
+        int idx = RNG.getRandom().nextInt(actionValues.size());
+        int i = 0;
+        for(A action : actionValues.keySet()){
+            if(i++ == idx) return action;
+        }
+
+        return null;
+    }
+}

evironment/antGame/Ant.java
@@ -7,7 +7,6 @@ import lombok.Setter;
 
 import java.awt.*;
 
-@AllArgsConstructor
 @Getter
 @Setter
 public class Ant {
@@ -15,10 +14,14 @@ public class Ant {
     @Setter(AccessLevel.NONE)
     private Point pos;
     private int points;
-    private boolean spawned;
     @Getter(AccessLevel.NONE)
     private boolean hasFood;
 
+    public Ant(){
+        pos = new Point();
+        points = 0;
+        hasFood = false;
+    }
     public boolean hasFood(){
         return hasFood;
     }

evironment/antGame/AntAgent.java
@@ -9,7 +9,6 @@ public class AntAgent {
 
     public AntAgent(int width, int height){
         knownWorld = new Cell[width][height];
-        initUnknownWorld();
     }
 
     /**
@@ -24,7 +23,7 @@ public class AntAgent {
         return new AntState(knownWorld, observation.getPos(), observation.hasFood());
     }
 
-    private void initUnknownWorld(){
+    public void initUnknownWorld(){
         for(int x = 0; x < knownWorld.length; ++x){
             for(int y = 0; y < knownWorld[x].length; ++y){
                 knownWorld[x][y] = new Cell(new Point(x,y), CellType.UNKNOWN);

evironment/antGame/AntState.java
@@ -11,17 +11,39 @@ import java.util.Arrays;
  * and therefor has to be deep copied
  */
 public class AntState implements State {
-    private Cell[][] knownWorld;
-    private Point pos;
-    private boolean hasFood;
+    private final Cell[][] knownWorld;
+    private final Point pos;
+    private final boolean hasFood;
+    private final int computedHash;
 
     public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
         this.knownWorld = deepCopyCellGrid(knownWorld);
         this.pos = deepCopyAntPosition(antPosition);
         this.hasFood = hasFood;
+        computedHash = computeHash();
     }
 
+    private int computeHash(){
+        int hash = 7;
+        int prime = 31;
+
+        int unknown = 0;
+        int diff = 0;
+        for (int i = 0; i < knownWorld.length; i++) {
+            for (int j = 0; j < knownWorld[i].length; j++) {
+                if(knownWorld[i][j].getType() == CellType.UNKNOWN){
+                    unknown += 1;
+                }else{
+                    diff +=1;
+                }
+            }
+        }
+        hash = prime * hash + unknown;
+        hash = prime * hash * diff;
+        hash = prime * hash + (hasFood ? 1:0);
+        hash = prime * hash + pos.hashCode();
+        return hash;
+    }
    private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
        Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
        for (int i = 0; i < cells.length; i++) {
@@ -45,12 +67,7 @@ public class AntState implements State {
     //TODO: make this a utility function to generate hash Code based upon 2 prime numbers
     @Override
     public int hashCode(){
-        int hash = 7;
-        int prime = 31;
-        hash = prime * hash + Arrays.hashCode(knownWorld);
-        hash = prime * hash + (hasFood ? 1:0);
-        hash = prime * hash + pos.hashCode();
-        return hash;
+        return computedHash;
     }
 
     @Override

evironment/antGame/AntWorld.java
@@ -1,10 +1,11 @@
 package evironment.antGame;
 
 import core.*;
+import core.algo.Learning;
+import core.algo.MC.MonteCarloOnPolicyEGreedy;
 import evironment.antGame.gui.MainFrame;
 
 
-import javax.swing.*;
 import java.awt.*;
 
 public class AntWorld implements Environment<AntAction>{
@@ -39,9 +40,8 @@ public class AntWorld implements Environment<AntAction>{
     public AntWorld(int width, int height, double foodDensity){
         grid = new Grid(width, height, foodDensity);
         antAgent = new AntAgent(width, height);
-        myAnt = new Ant(new Point(-1,-1), 0, false, false);
+        myAnt = new Ant();
         gui = new MainFrame(this, antAgent);
-        tick = 0;
         maxEpisodeTicks = 1000;
         reset();
     }
@@ -55,23 +55,13 @@ public class AntWorld implements Environment<AntAction>{
     }
 
     @Override
-    public StepResult step(AntAction action){
+    public StepResultEnvironment step(AntAction action){
         AntObservation observation;
         State newState;
         double reward = 0;
         String info = "";
         boolean done = false;
 
-        if(!myAnt.isSpawned()){
-            myAnt.setSpawned(true);
-            myAnt.getPos().setLocation(grid.getStartPoint());
-            observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
-            newState = antAgent.feedObservation(observation);
-            reward = 0.0;
-            ++tick;
-            return new StepResult(newState, reward, false, "Just spawned on the map");
-        }
-
         Cell currentCell = grid.getCell(myAnt.getPos());
         Point potentialNextPos = new Point(myAnt.getPos().x, myAnt.getPos().y);
         boolean stayOnCell = true;
@@ -107,7 +97,7 @@ public class AntWorld implements Environment<AntAction>{
                     // Ant successfully picks up food
                     currentCell.setFood(currentCell.getFood() - 1);
                     myAnt.setHasFood(true);
-                    reward = Reward.FOOD_DROP_DOWN_SUCCESS;
+                    reward = Reward.FOOD_PICK_UP_SUCCESS;
                 }
                 break;
             case DROP_DOWN:
@@ -169,24 +159,30 @@ public class AntWorld implements Environment<AntAction>{
         if(++tick == maxEpisodeTicks){
             done = true;
         }
-        return new StepResult(newState, reward, done, info);
+
+        StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
+        getGui().update(action, result);
+        return result;
     }
 
     private boolean isInGrid(Point pos){
-        return pos.x > 0 && pos.x < grid.getWidth() && pos.y > 0 && pos.y < grid.getHeight();
+        return pos.x >= 0 && pos.x < grid.getWidth() && pos.y >= 0 && pos.y < grid.getHeight();
     }
 
     private boolean hitObstacle(Point pos){
         return grid.getCell(pos).getType() == CellType.OBSTACLE;
     }
 
-    public void reset() {
+    public State reset() {
         RNG.reseed();
         grid.initRandomWorld();
-        myAnt.getPos().setLocation(-1,-1);
+        antAgent.initUnknownWorld();
+        tick = 0;
+        myAnt.getPos().setLocation(grid.getStartPoint());
         myAnt.setPoints(0);
         myAnt.setHasFood(false);
-        myAnt.setSpawned(false);
+        AntObservation observation = new AntObservation(grid.getCell(myAnt.getPos()), myAnt.getPos(), myAnt.hasFood());
+        return antAgent.feedObservation(observation);
     }
 
     public void setMaxEpisodeLength(int maxTicks){
@@ -207,21 +203,14 @@ public class AntWorld implements Environment<AntAction>{
     public Ant getAnt(){
         return myAnt;
     }
 
     public static void main(String[] args) {
         RNG.setSeed(1993);
-        AntWorld a = new AntWorld(10, 10, 0.1);
-        ListDiscreteActionSpace<AntAction> actionSpace =
-                new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
 
-        for(int i = 0; i< 1000; ++i){
-            AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
-            StepResult step = a.step(selectedAction);
-            SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
-            try {
-                Thread.sleep(100);
-            } catch (InterruptedException e) {
-                e.printStackTrace();
-            }
-        }
+        Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
+                new AntWorld(3, 3, 0.1),
+                new ListDiscreteActionSpace<>(AntAction.values())
+        );
+        monteCarlo.learn(100,5);
     }
 }

evironment/antGame/Reward.java
@@ -1,17 +1,17 @@
 package evironment.antGame;
 
 public class Reward {
-    public static final double FOOD_PICK_UP_SUCCESS = 1;
-    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1000;
-    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1000;
+    public static final double FOOD_PICK_UP_SUCCESS = 0;
+    public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
+    public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
 
-    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1000;
-    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1000;
+    public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
+    public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
     public static final double FOOD_DROP_DOWN_SUCCESS = 1000;
 
-    public static final double UNKNOWN_FIELD_EXPLORED = 1;
+    public static final double UNKNOWN_FIELD_EXPLORED = 0;
 
-    public static final double RAN_INTO_WALL = -100;
-    public static final double RAN_INTO_OBSTACLE = -100;
+    public static final double RAN_INTO_WALL = 0;
+    public static final double RAN_INTO_OBSTACLE = 0;
 
 }

evironment/antGame/gui/MainFrame.java
@@ -1,6 +1,6 @@
 package evironment.antGame.gui;
 
-import core.StepResult;
+import core.StepResultEnvironment;
 import evironment.antGame.AntAction;
 import evironment.antGame.AntAgent;
 import evironment.antGame.AntWorld;
@@ -33,9 +33,9 @@ public class MainFrame extends JFrame {
         setVisible(true);
     }
 
-    public void update(AntAction lastAction, StepResult stepResult){
+    public void update(AntAction lastAction, StepResultEnvironment stepResultEnvironment){
         historyPanel.addText(String.format("Tick %d: \t Selected action: %s \t Reward: %f \t Info: %s \n totalPoints: %d \t hasFood: %b \t ",
-                antWorld.getTick(), lastAction.toString(), stepResult.getReward(), stepResult.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
+                antWorld.getTick(), lastAction.toString(), stepResultEnvironment.getReward(), stepResultEnvironment.getInfo(), antWorld.getAnt().getPoints(), antWorld.getAnt().hasFood()));
 
         repaint();
     }