fix RNG seeding, add extended interface EpsilonPolicy and move rewardHistory to model instead of view

- only set the seed of the RNG once at the beginning and never reseed it afterwards. The initial AntWorld is deep copied and used as a blueprint for resetting the world, instead of reseeding and regenerating the pseudo-random world. Reseeding the RNG influenced action selection so that the agent always chose the same trajectory (see the Grid/Util sketch below).
- instanceof is used to determine whether the policy has an epsilon parameter, and the view adapts to this, only showing the epsilon slider if the policy has one (see the EpsilonPolicy sketch below).
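
A minimal sketch of the blueprint-reset pattern, assuming the Cell, CellType, Util and RNG types shown in the diffs below and java.awt.Point for positions; the BlueprintGrid class is illustrative only and omits details such as placing the start cell, which the real Grid handles.

package evironment.antGame;

import core.RNG;
import java.awt.Point;

public class BlueprintGrid {
    private Cell[][] grid;
    private final Cell[][] initialGrid;

    public BlueprintGrid(int width, int height, double foodDensity) {
        initialGrid = new Cell[width][height];
        // draw from the globally seeded RNG exactly once, when the world is first generated
        for (int x = 0; x < width; ++x) {
            for (int y = 0; y < height; ++y) {
                initialGrid[x][y] = RNG.getRandom().nextDouble() < foodDensity
                        ? new Cell(new Point(x, y), CellType.FREE, 1)
                        : new Cell(new Point(x, y), CellType.FREE);
            }
        }
        grid = Util.deepCopyCellGrid(initialGrid);
    }

    public void resetWorld() {
        // restore the identical world from the blueprint without touching the RNG state
        grid = Util.deepCopyCellGrid(initialGrid);
    }
}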
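
And a minimal sketch of the instanceof dispatch, assuming the existing core.policy.Policy interface and the EpsilonPolicy interface added in this commit; the applyEpsilon helper is illustrative and not part of the repository (the real logic lives in RLController.onEpsilonChange and LearningInfoPanel).

package core.controller;

import core.policy.EpsilonPolicy;
import core.policy.Policy;

public class EpsilonDispatchExample {
    @SuppressWarnings("unchecked")
    static <A extends Enum> void applyEpsilon(Policy<A> policy, float epsilon) {
        if (policy instanceof EpsilonPolicy) {
            // only epsilon-based policies (e.g. EpsilonGreedyPolicy) expose the setter
            ((EpsilonPolicy<A>) policy).setEpsilon(epsilon);
        } else {
            // e.g. GreedyPolicy or RandomPolicy has no epsilon; the view hides the slider instead
            System.out.println("Trying to set epsilon on a non-epsilon policy");
        }
    }
}

The same check drives the view: LearningInfoPanel only constructs the epsilon slider and label when learning.getPolicy() is an EpsilonPolicy.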
Jan Löwenstrom 2019-12-20 16:51:09 +01:00
parent e0160ca1df
commit 7db5a2af3b
16 changed files with 130 additions and 74 deletions

View File

@ -10,8 +10,11 @@ import lombok.Getter;
import lombok.Setter;
import javax.swing.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
@Getter
public abstract class Learning<A extends Enum> {
@ -20,38 +23,43 @@ public abstract class Learning<A extends Enum> {
protected StateActionTable<A> stateActionTable;
protected Environment<A> environment;
protected float discountFactor;
@Setter
protected float epsilon;
protected Set<LearningListener> learningListeners;
@Setter
protected int delay;
private List<Double> rewardHistory;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.epsilon = epsilon;
this.delay = delay;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor){
this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay){
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
}
public abstract void learn(int nrOfEpisodes);
public void addListener(LearningListener learningListener){
learningListeners.add(learningListener);
}
protected void dispatchEpisodeEnd(double sum){
protected void dispatchEpisodeEnd(double recentSumOfRewards){
rewardHistory.add(recentSumOfRewards);
for(LearningListener l: learningListeners) {
l.onEpisodeEnd(sum);
l.onEpisodeEnd(rewardHistory);
}
}

View File

@ -4,6 +4,8 @@ import core.*;
import core.algo.Learning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
import lombok.Setter;
import java.util.*;
/**
@ -26,13 +28,18 @@ import java.util.*;
*/
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
super(environment, actionSpace);
discountFactor = 1f;
this.policy = new EpsilonGreedyPolicy<>(0.1f);
this.stateActionTable = new StateActionHashTable<>(actionSpace);
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
}
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
@Override
public void learn(int nrOfEpisodes) {

View File

@ -7,9 +7,9 @@ import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import core.gui.View;
import core.policy.EpsilonPolicy;
import javax.swing.*;
import java.util.Optional;
public class RLController<A extends Enum> implements ViewListener{
protected Environment<A> environment;
@ -30,28 +30,38 @@ public class RLController<A extends Enum> implements ViewListener{
switch (method){
case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay);
break;
case TD_ONPOLICY:
break;
default:
throw new RuntimeException("Undefined method");
}
SwingUtilities.invokeLater(() ->{
view = new View<>(learning, this);
learning.addListener(view);
});
/*
not using SwingUtilities here on purpose to ensure the view is fully
initialized and can be passed as LearningListener.
*/
view = new View<>(learning, this);
learning.addListener(view);
learning.learn(nrOfEpisodes);
}
@Override
public void onEpsilonChange(float epsilon) {
learning.setEpsilon(epsilon);
SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
if(learning.getPolicy() instanceof EpsilonPolicy){
((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
}else{
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
}
@Override
public void onDelayChange(int delay) {
learning.setDelay(delay);
SwingUtilities.invokeLater(() -> {
view.updateLearningInfoPanel();
});
}
public RLController<A> setMethod(Method method){

View File

@ -2,6 +2,7 @@ package core.gui;
import core.algo.Learning;
import core.controller.ViewListener;
import core.policy.EpsilonPolicy;
import javax.swing.*;
@ -11,6 +12,7 @@ public class LearningInfoPanel extends JPanel {
private JLabel discountLabel;
private JLabel epsilonLabel;
private JSlider epsilonSlider;
private JLabel delayLabel;
private JSlider delaySlider;
public LearningInfoPanel(Learning learning, ViewListener viewListener){
@ -19,12 +21,19 @@ public class LearningInfoPanel extends JPanel {
policyLabel = new JLabel();
discountLabel = new JLabel();
epsilonLabel = new JLabel();
epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
delayLabel = new JLabel();
delaySlider = new JSlider(0,1000, learning.getDelay());
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel);
add(discountLabel);
add(epsilonLabel);
add(epsilonSlider);
if(learning.getPolicy() instanceof EpsilonPolicy){
epsilonSlider = new JSlider(0, 100, (int)(((EpsilonPolicy) learning.getPolicy()).getEpsilon() * 100));
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
add(epsilonLabel);
add(epsilonSlider);
}
add(delayLabel);
add(delaySlider);
refreshLabels();
setVisible(true);
}
@ -32,10 +41,9 @@ public class LearningInfoPanel extends JPanel {
public void refreshLabels(){
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
}
protected JSlider getEpsilonSlider(){
return epsilonSlider;
if(learning.getPolicy() instanceof EpsilonPolicy){
epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon());
}
delayLabel.setText("Delay (ms): " + learning.getDelay());
}
}

View File

@ -23,12 +23,10 @@ public class View<A extends Enum> implements LearningListener {
private JFrame mainFrame;
private XChartPanel<XYChart> rewardChartPanel;
private ViewListener viewListener;
private List<Double> rewardHistory;
public View(Learning<A> learning, ViewListener viewListener){
this.learning = learning;
this.viewListener = viewListener;
rewardHistory = new ArrayList<>();
this.initMainFrame();
}
@ -78,8 +76,7 @@ public class View<A extends Enum> implements LearningListener {
};
}
public void updateRewardGraph(double recentReward){
rewardHistory.add(recentReward);
public void updateRewardGraph(List<Double> rewardHistory){
chart.updateXYSeries("randomWalk", null, rewardHistory, null);
rewardChartPanel.revalidate();
rewardChartPanel.repaint();
@ -89,10 +86,11 @@ public class View<A extends Enum> implements LearningListener {
this.learningInfoPanel.refreshLabels();
}
@Override
public void onEpisodeEnd(double sumOfRewards) {
SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
public void onEpisodeEnd(List<Double> rewardHistory) {
SwingUtilities.invokeLater(()->{
updateRewardGraph(rewardHistory);
});
}
@Override

View File

@ -1,6 +1,8 @@
package core.listener;
import java.util.List;
public interface LearningListener{
void onEpisodeEnd(double sumOfRewards);
void onEpisodeEnd(List<Double> rewardHistory);
void onEpisodeStart();
}

View File

@ -1,6 +1,8 @@
package core.policy;
import core.RNG;
import lombok.Getter;
import lombok.Setter;
import java.util.Map;
@ -12,7 +14,9 @@ import java.util.Map;
*
* @param <A> Discrete Action Enum
*/
public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
@Setter
@Getter
private float epsilon;
private RandomPolicy<A> randomPolicy;
private GreedyPolicy<A> greedyPolicy;
@ -22,8 +26,10 @@ public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
randomPolicy = new RandomPolicy<>();
greedyPolicy = new GreedyPolicy<>();
}
@Override
public A chooseAction(Map<A, Double> actionValues) {
System.out.println("current epsilon " + epsilon);
if(RNG.getRandom().nextFloat() < epsilon){
// Take random action
return randomPolicy.chooseAction(actionValues);

View File

@ -0,0 +1,6 @@
package core.policy;
public interface EpsilonPolicy<A extends Enum> extends Policy<A> {
float getEpsilon();
void setEpsilon(float epsilon);
}

View File

@ -1,7 +1,5 @@
package core.policy;
import core.RNG;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@ -13,7 +11,7 @@ public class GreedyPolicy<A extends Enum> implements Policy<A> {
public A chooseAction(Map<A, Double> actionValues) {
if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");
Double highestValueAction = null;
Double highestValueAction = null;
List<A> equalHigh = new ArrayList<>();

View File

@ -7,6 +7,7 @@ public class RandomPolicy<A extends Enum> implements Policy<A>{
@Override
public A chooseAction(Map<A, Double> actionValues) {
int idx = RNG.getRandom().nextInt(actionValues.size());
System.out.println("selected action " + idx);
int i = 0;
for(A action : actionValues.keySet()){
if(i++ == idx) return action;

View File

@ -20,13 +20,13 @@ public class AntState implements State, Visualizable {
private final int computedHash;
public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
this.knownWorld = deepCopyCellGrid(knownWorld);
this.knownWorld = Util.deepCopyCellGrid(knownWorld);
this.pos = deepCopyAntPosition(antPosition);
this.hasFood = hasFood;
computedHash = computeHash();
}
private int computeHash(){
private int computeHash() {
int hash = 7;
int prime = 31;
@ -43,20 +43,10 @@ public class AntState implements State, Visualizable {
}
hash = prime * hash + unknown;
hash = prime * hash * diff;
hash = prime * hash + (hasFood ? 1:0);
hash = prime * hash + (hasFood ? 1 : 0);
hash = prime * hash + pos.hashCode();
return hash;
}
private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
for (int i = 0; i < cells.length; i++) {
for (int j = 0; j < cells[i].length; j++) {
// calling copy constructor of Cell class
cells[i][j] = new Cell(toCopy[i][j]);
}
}
return cells;
}
private Point deepCopyAntPosition(Point toCopy){
return new Point(toCopy.x,toCopy.y);

View File

@ -151,9 +151,12 @@ public class AntWorld implements Environment<AntAction>{
done = grid.isAllFoodCollected();
}
/*
if(!done){
reward = -1;
}
*/
if(++tick == maxEpisodeTicks){
done = true;
}
@ -172,8 +175,7 @@ public class AntWorld implements Environment<AntAction>{
}
public State reset() {
RNG.reseed();
grid.initRandomWorld();
grid.resetWorld();
antAgent.initUnknownWorld();
tick = 0;
myAnt.getPos().setLocation(grid.getStartPoint());
@ -207,7 +209,8 @@ public class AntWorld implements Environment<AntAction>{
Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
new AntWorld(3, 3, 0.1),
new ListDiscreteActionSpace<>(AntAction.values())
new ListDiscreteActionSpace<>(AntAction.values()),
5
);
monteCarlo.learn(20000);
}

View File

@ -10,31 +10,37 @@ public class Grid {
private double foodDensity;
private Point start;
private Cell[][] grid;
private Cell[][] initialGrid;
public Grid(int width, int height, double foodDensity){
this.width = width;
this.height = height;
this.foodDensity = foodDensity;
grid = new Cell[width][height];
initialGrid = new Cell[width][height];
initRandomWorld();
}
public Grid(int width, int height){
this(width, height, 0);
}
public void resetWorld(){
grid = Util.deepCopyCellGrid(initialGrid);
}
public void initRandomWorld(){
for(int x = 0; x < width; ++x){
for(int y = 0; y < height; ++y){
if( RNG.getRandom().nextDouble() < foodDensity){
grid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
}else{
grid[x][y] = new Cell(new Point(x,y), CellType.FREE);
initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE);
}
}
}
start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
grid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
}
public Point getStartPoint(){

View File

@ -1,17 +1,16 @@
package evironment.antGame;
public class Reward {
public static final double FOOD_PICK_UP_SUCCESS = 0;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
public static final double FOOD_PICK_UP_SUCCESS = 1;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;
public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1;
public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1;
public static final double FOOD_DROP_DOWN_SUCCESS = 1;
public static final double UNKNOWN_FIELD_EXPLORED = 0;
public static final double RAN_INTO_WALL = 0;
public static final double RAN_INTO_OBSTACLE = 0;
public static final double UNKNOWN_FIELD_EXPLORED = 1;
public static final double RAN_INTO_WALL = -1;
public static final double RAN_INTO_OBSTACLE = -1;
}

View File

@ -0,0 +1,14 @@
package evironment.antGame;
public class Util {
public static Cell[][] deepCopyCellGrid(Cell[][] toCopy){
Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
for (int i = 0; i < cells.length; i++) {
for (int j = 0; j < cells[i].length; j++) {
// calling copy constructor of Cell class
cells[i][j] = new Cell(toCopy[i][j]);
}
}
return cells;
}
}

View File

@ -15,7 +15,7 @@ public class RunningAnt {
.setAllowedActions(AntAction.values())
.setMethod(Method.MC_ONPOLICY_EGREEDY)
.setDelay(10)
.setEpisodes(1000);
.setEpisodes(10000);
rl.start();
}
}