Fix RNG seeding, add extended interface EpsilonPolicy, move rewardHistory from view to model
- The RNG seed is now set only once at startup and never reseeded afterwards. Instead of reseeding and regenerating a pseudo-random world on every reset, the initially generated AntWorld grid is deep-copied and kept as a blueprint for resetting the world. Reseeding the RNG had influenced action selection, causing the agent to always follow the same trajectory.
- instanceof is used to determine whether the policy has an epsilon parameter; the view adapts to this and only shows the epsilon slider if the policy exposes epsilon.
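The RNG helper referenced throughout the diff (RNG.getRandom(), and the removed RNG.reseed() call) is not itself part of this commit. A minimal sketch of the seed-once pattern described above, assuming a fixed seed constant and taking the package core from the imports below, could look like this:

    package core;

    import java.util.Random;

    // Sketch only: a single Random instance seeded exactly once at class-load time.
    // Because reseed() is no longer called from AntWorld.reset(), episode resets
    // no longer replay the identical pseudo-random sequence during action selection.
    public final class RNG {
        private static final long SEED = 42L; // assumed seed value, not taken from the diff
        private static final Random random = new Random(SEED);

        private RNG() { }

        public static Random getRandom() {
            return random;
        }
    }

World layouts stay reproducible regardless, because Grid now keeps a deep copy of the initially generated cells (initialGrid) and resetWorld() restores from that blueprint instead of regenerating the world.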
This commit is contained in:
parent e0160ca1df
commit 7db5a2af3b
Learning.java

@@ -10,8 +10,11 @@ import lombok.Getter;
import lombok.Setter;

import javax.swing.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;

@Getter
public abstract class Learning<A extends Enum> {

@@ -20,38 +23,43 @@ public abstract class Learning<A extends Enum> {
    protected StateActionTable<A> stateActionTable;
    protected Environment<A> environment;
    protected float discountFactor;
    @Setter
-   protected float epsilon;
    protected Set<LearningListener> learningListeners;
    @Setter
    protected int delay;
+   private List<Double> rewardHistory;

-   public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){
+   public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){
        this.environment = environment;
        this.actionSpace = actionSpace;
        this.discountFactor = discountFactor;
-       this.epsilon = epsilon;
        this.delay = delay;
        learningListeners = new HashSet<>();
+       rewardHistory = new CopyOnWriteArrayList<>();
    }

-   public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
-       this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
+   public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor){
+       this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
    }

+   public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay){
+       this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
+   }

    public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
-       this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
+       this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
    }


    public abstract void learn(int nrOfEpisodes);

    public void addListener(LearningListener learningListener){
        learningListeners.add(learningListener);
    }

-   protected void dispatchEpisodeEnd(double sum){
+   protected void dispatchEpisodeEnd(double recentSumOfRewards){
+       rewardHistory.add(recentSumOfRewards);
        for(LearningListener l: learningListeners) {
-           l.onEpisodeEnd(sum);
+           l.onEpisodeEnd(rewardHistory);
        }
    }

MonteCarloOnPolicyEGreedy.java

@@ -4,6 +4,8 @@ import core.*;
import core.algo.Learning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
import lombok.Setter;

import java.util.*;

/**

@@ -26,13 +28,18 @@
 */
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {

-   public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
-       super(environment, actionSpace);
-       discountFactor = 1f;
-       this.policy = new EpsilonGreedyPolicy<>(0.1f);
-       this.stateActionTable = new StateActionHashTable<>(actionSpace);
+   public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+       super(environment, actionSpace, discountFactor, delay);
+
+       this.policy = new EpsilonGreedyPolicy<>(epsilon);
+       this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
    }

+   public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+       this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
+   }


    @Override
    public void learn(int nrOfEpisodes) {

RLController.java

@@ -7,9 +7,9 @@ import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import core.gui.View;
+import core.policy.EpsilonPolicy;

import javax.swing.*;
import java.util.Optional;

public class RLController<A extends Enum> implements ViewListener{
    protected Environment<A> environment;

@@ -30,28 +30,38 @@ public class RLController<A extends Enum> implements ViewListener{

        switch (method){
            case MC_ONPOLICY_EGREEDY:
-               learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
+               learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay);
                break;
            case TD_ONPOLICY:
                break;
            default:
                throw new RuntimeException("Undefined method");
        }
-       SwingUtilities.invokeLater(() ->{
-           view = new View<>(learning, this);
-           learning.addListener(view);
-       });
+       /*
+        not using SwingUtilities here on purpose to ensure the view is fully
+        initialized and can be passed as LearningListener.
+       */
+       view = new View<>(learning, this);
+       learning.addListener(view);
        learning.learn(nrOfEpisodes);
    }

    @Override
    public void onEpsilonChange(float epsilon) {
-       learning.setEpsilon(epsilon);
-       SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+       if(learning.getPolicy() instanceof EpsilonPolicy){
+           ((EpsilonPolicy<A>) learning.getPolicy()).setEpsilon(epsilon);
+           SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
+       }else{
+           System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
+       }
    }

    @Override
    public void onDelayChange(int delay) {
        learning.setDelay(delay);
        SwingUtilities.invokeLater(() -> {
            view.updateLearningInfoPanel();
        });
    }

    public RLController<A> setMethod(Method method){

LearningInfoPanel.java

@@ -2,6 +2,7 @@ package core.gui;
import core.algo.Learning;
import core.controller.ViewListener;
+import core.policy.EpsilonPolicy;

import javax.swing.*;


@@ -11,6 +12,7 @@ public class LearningInfoPanel extends JPanel {
    private JLabel discountLabel;
    private JLabel epsilonLabel;
    private JSlider epsilonSlider;
    private JLabel delayLabel;
    private JSlider delaySlider;

    public LearningInfoPanel(Learning learning, ViewListener viewListener){

@@ -19,12 +21,19 @@ public class LearningInfoPanel extends JPanel {
        policyLabel = new JLabel();
        discountLabel = new JLabel();
        epsilonLabel = new JLabel();
-       epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
-       epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
        delayLabel = new JLabel();
        delaySlider = new JSlider(0,1000, learning.getDelay());
        delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
        add(policyLabel);
        add(discountLabel);
-       add(epsilonLabel);
-       add(epsilonSlider);
+       if(learning.getPolicy() instanceof EpsilonPolicy){
+           epsilonSlider = new JSlider(0, 100, (int)((EpsilonPolicy)learning.getPolicy()).getEpsilon() * 100);
+           epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
+           add(epsilonLabel);
+           add(epsilonSlider);
+       }
        add(delayLabel);
        add(delaySlider);
        refreshLabels();
        setVisible(true);
    }

@@ -32,10 +41,9 @@ public class LearningInfoPanel extends JPanel {
    public void refreshLabels(){
        policyLabel.setText("Policy: " + learning.getPolicy().getClass());
        discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
-       epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
-   }
-
-   protected JSlider getEpsilonSlider(){
-       return epsilonSlider;
+       if(learning.getPolicy() instanceof EpsilonPolicy){
+           epsilonLabel.setText("Exploration (Epsilon): " + ((EpsilonPolicy)learning.getPolicy()).getEpsilon());
+       }
+       delayLabel.setText("Delay (ms): " + learning.getDelay());
    }
}

View.java

@@ -23,12 +23,10 @@ public class View<A extends Enum> implements LearningListener {
    private JFrame mainFrame;
    private XChartPanel<XYChart> rewardChartPanel;
    private ViewListener viewListener;
-   private List<Double> rewardHistory;

    public View(Learning<A> learning, ViewListener viewListener){
        this.learning = learning;
        this.viewListener = viewListener;
-       rewardHistory = new ArrayList<>();
        this.initMainFrame();
    }

@@ -78,8 +76,7 @@ public class View<A extends Enum> implements LearningListener {
        };
    }

-   public void updateRewardGraph(double recentReward){
-       rewardHistory.add(recentReward);
+   public void updateRewardGraph(List<Double> rewardHistory){
        chart.updateXYSeries("randomWalk", null, rewardHistory, null);
        rewardChartPanel.revalidate();
        rewardChartPanel.repaint();

@@ -89,10 +86,11 @@ public class View<A extends Enum> implements LearningListener {
        this.learningInfoPanel.refreshLabels();
    }


    @Override
-   public void onEpisodeEnd(double sumOfRewards) {
-       SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
+   public void onEpisodeEnd(List<Double> rewardHistory) {
+       SwingUtilities.invokeLater(()->{
+           updateRewardGraph(rewardHistory);
+       });
    }

    @Override

LearningListener.java

@@ -1,6 +1,8 @@
package core.listener;

+import java.util.List;

public interface LearningListener{
-   void onEpisodeEnd(double sumOfRewards);
+   void onEpisodeEnd(List<Double> rewardHistory);
    void onEpisodeStart();
}

EpsilonGreedyPolicy.java

@@ -1,6 +1,8 @@
package core.policy;

import core.RNG;
+import lombok.Getter;
+import lombok.Setter;

import java.util.Map;

@@ -12,7 +14,9 @@ import java.util.Map;
 *
 * @param <A> Discrete Action Enum
 */
-public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
+public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
+   @Setter
+   @Getter
    private float epsilon;
    private RandomPolicy<A> randomPolicy;
    private GreedyPolicy<A> greedyPolicy;

@@ -22,8 +26,10 @@ public class EpsilonGreedyPolicy<A extends Enum> implements Policy<A>{
        randomPolicy = new RandomPolicy<>();
        greedyPolicy = new GreedyPolicy<>();
    }

    @Override
    public A chooseAction(Map<A, Double> actionValues) {
+       System.out.println("current epsilon " + epsilon);
        if(RNG.getRandom().nextFloat() < epsilon){
            // Take random action
            return randomPolicy.chooseAction(actionValues);

EpsilonPolicy.java (new file)

@@ -0,0 +1,6 @@
+package core.policy;
+
+public interface EpsilonPolicy<A extends Enum> extends Policy<A> {
+    float getEpsilon();
+    void setEpsilon(float epsilon);
+}

GreedyPolicy.java

@@ -1,7 +1,5 @@
package core.policy;

-import core.RNG;
-
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

@@ -13,7 +11,7 @@ public class GreedyPolicy<A extends Enum> implements Policy<A> {
    public A chooseAction(Map<A, Double> actionValues) {
        if(actionValues.size() == 0) throw new RuntimeException("Empty actionActionValues set");

-       Double highestValueAction = null;
+       Double highestValueAction = null;

        List<A> equalHigh = new ArrayList<>();

RandomPolicy.java

@@ -7,6 +7,7 @@ public class RandomPolicy<A extends Enum> implements Policy<A>{
    @Override
    public A chooseAction(Map<A, Double> actionValues) {
        int idx = RNG.getRandom().nextInt(actionValues.size());
+       System.out.println("selected action " + idx);
        int i = 0;
        for(A action : actionValues.keySet()){
            if(i++ == idx) return action;

AntState.java

@@ -20,13 +20,13 @@ public class AntState implements State, Visualizable {
    private final int computedHash;

    public AntState(Cell[][] knownWorld, Point antPosition, boolean hasFood){
-       this.knownWorld = deepCopyCellGrid(knownWorld);
+       this.knownWorld = Util.deepCopyCellGrid(knownWorld);
        this.pos = deepCopyAntPosition(antPosition);
        this.hasFood = hasFood;
        computedHash = computeHash();
    }

-   private int computeHash(){
+   private int computeHash() {
        int hash = 7;
        int prime = 31;

@@ -43,20 +43,10 @@ public class AntState implements State, Visualizable {
        }
        hash = prime * hash + unknown;
        hash = prime * hash * diff;
-       hash = prime * hash + (hasFood ? 1:0);
+       hash = prime * hash + (hasFood ? 1 : 0);
        hash = prime * hash + pos.hashCode();
        return hash;
    }
-   private Cell[][] deepCopyCellGrid(Cell[][] toCopy){
-       Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
-       for (int i = 0; i < cells.length; i++) {
-           for (int j = 0; j < cells[i].length; j++) {
-               // calling copy constructor of Cell class
-               cells[i][j] = new Cell(toCopy[i][j]);
-           }
-       }
-       return cells;
-   }

    private Point deepCopyAntPosition(Point toCopy){
        return new Point(toCopy.x,toCopy.y);

AntWorld.java

@@ -151,9 +151,12 @@ public class AntWorld implements Environment<AntAction>{
            done = grid.isAllFoodCollected();
        }


+       /*
        if(!done){
            reward = -1;
        }
+       */
        if(++tick == maxEpisodeTicks){
            done = true;
        }

@@ -172,8 +175,7 @@ public class AntWorld implements Environment<AntAction>{
    }

    public State reset() {
-       RNG.reseed();
-       grid.initRandomWorld();
+       grid.resetWorld();
        antAgent.initUnknownWorld();
        tick = 0;
        myAnt.getPos().setLocation(grid.getStartPoint());

@@ -207,7 +209,8 @@ public class AntWorld implements Environment<AntAction>{

        Learning<AntAction> monteCarlo = new MonteCarloOnPolicyEGreedy<>(
                new AntWorld(3, 3, 0.1),
-               new ListDiscreteActionSpace<>(AntAction.values())
+               new ListDiscreteActionSpace<>(AntAction.values()),
+               5
        );
        monteCarlo.learn(20000);
    }

Grid.java

@@ -10,31 +10,37 @@ public class Grid {
    private double foodDensity;
    private Point start;
    private Cell[][] grid;
+   private Cell[][] initialGrid;

    public Grid(int width, int height, double foodDensity){
        this.width = width;
        this.height = height;
        this.foodDensity = foodDensity;

        grid = new Cell[width][height];
+       initialGrid = new Cell[width][height];
        initRandomWorld();
    }

    public Grid(int width, int height){
        this(width, height, 0);
    }

+   public void resetWorld(){
+       grid = Util.deepCopyCellGrid(initialGrid);
+   }

    public void initRandomWorld(){
        for(int x = 0; x < width; ++x){
            for(int y = 0; y < height; ++y){
                if( RNG.getRandom().nextDouble() < foodDensity){
                    grid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
+                   initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE, 1);
                }else{
                    grid[x][y] = new Cell(new Point(x,y), CellType.FREE);
+                   initialGrid[x][y] = new Cell(new Point(x,y), CellType.FREE);
                }
            }
        }
        start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
        grid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
+       initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
    }

    public Point getStartPoint(){

Reward.java

@@ -1,17 +1,16 @@
package evironment.antGame;

public class Reward {
-   public static final double FOOD_PICK_UP_SUCCESS = 0;
-   public static final double FOOD_PICK_UP_FAIL_NO_FOOD = 0;
-   public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = 0;
+   public static final double FOOD_PICK_UP_SUCCESS = 1;
+   public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
+   public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;

-   public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = 0;
-   public static final double FOOD_DROP_DOWN_FAIL_NOT_START = 0;
+   public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1;
+   public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1;
    public static final double FOOD_DROP_DOWN_SUCCESS = 1;

-   public static final double UNKNOWN_FIELD_EXPLORED = 0;
-
-   public static final double RAN_INTO_WALL = 0;
-   public static final double RAN_INTO_OBSTACLE = 0;
+   public static final double UNKNOWN_FIELD_EXPLORED = 1;
+
+   public static final double RAN_INTO_WALL = -1;
+   public static final double RAN_INTO_OBSTACLE = -1;
}

Util.java (new file)

@@ -0,0 +1,14 @@
+package evironment.antGame;
+
+public class Util {
+    public static Cell[][] deepCopyCellGrid(Cell[][] toCopy){
+        Cell[][] cells = new Cell[toCopy.length][toCopy[0].length];
+        for (int i = 0; i < cells.length; i++) {
+            for (int j = 0; j < cells[i].length; j++) {
+                // calling copy constructor of Cell class
+                cells[i][j] = new Cell(toCopy[i][j]);
+            }
+        }
+        return cells;
+    }
+}

RunningAnt.java

@@ -15,7 +15,7 @@ public class RunningAnt {
            .setAllowedActions(AntAction.values())
            .setMethod(Method.MC_ONPOLICY_EGREEDY)
            .setDelay(10)
-           .setEpisodes(1000);
+           .setEpisodes(10000);
        rl.start();
    }
}