add dino jumping environment, deterministic/reproducible behaviour and save-and-load feature

- add a feature to save and load learning progress (Q-Table) and the current episode count
- episode end is now decided purely by the environment instead of the Monte Carlo algorithm capping it at 10 actions
- use LinkedHashMap in all locations to ensure deterministic behaviour
- fix a major RNG issue so algorithmic behaviour can be reproduced
- clear rewardHistory so that only the last 10k rewards are kept
- add Google dino jump environment
parent b1246f62cc
commit 5a4e380faf
StateActionHashTable.java → DeterministicStateActionTable.java
@@ -1,20 +1,19 @@
 package core;
 
-import evironment.antGame.AntAction;
-import java.util.HashMap;
+import java.io.Serializable;
+import java.util.LinkedHashMap;
 import java.util.Map;
 
 /**
  * Premise: All states have the complete action space
  */
-public class StateActionHashTable<A extends Enum> implements StateActionTable<A> {
+public class DeterministicStateActionTable<A extends Enum> implements StateActionTable<A>, Serializable {
 
     private final Map<State, Map<A, Double>> table;
     private DiscreteActionSpace<A> discreteActionSpace;
 
-    public StateActionHashTable(DiscreteActionSpace<A> discreteActionSpace){
-        table = new HashMap<>();
+    public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){
+        table = new LinkedHashMap<>();
         this.discreteActionSpace = discreteActionSpace;
     }
 
@@ -61,19 +60,15 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable<A>
         return table.get(state);
     }
 
-    public static void main(String[] args) {
-        DiscreteActionSpace<AntAction> da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP);
-        StateActionTable sat = new StateActionHashTable<>(da);
-        State t = new State() {
-        };
-
-        System.out.println(sat.getActionValues(t));
-    }
     private Map<A, Double> createDefaultActionValues(){
-        final Map<A, Double> defaultActionValues = new HashMap<>();
+        final Map<A, Double> defaultActionValues = new LinkedHashMap<>();
         for(A action: discreteActionSpace.getAllActions()){
             defaultActionValues.put(action, DEFAULT_VALUE);
         }
         return defaultActionValues;
     }
+
+    @Override
+    public int getStateCount(){
+        return table.size();
+    }
 }
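
Why this matters: LinkedHashMap iterates in insertion order, while HashMap's iteration order depends on hash buckets, so any action selection that walks the table only becomes run-to-run stable with the former. A minimal standalone sketch (class and action names are illustrative, not part of this commit):

import java.util.LinkedHashMap;
import java.util.Map;

public class IterationOrderDemo {
    public static void main(String[] args) {
        // Insertion order is preserved, so every run walks the action
        // values in the same sequence.
        Map<String, Double> actionValues = new LinkedHashMap<>();
        actionValues.put("JUMP", 0.0);
        actionValues.put("NOTHING", 0.0);
        // Always prints JUMP before NOTHING; a plain HashMap gives no such guarantee.
        actionValues.forEach((action, value) -> System.out.println(action + " -> " + value));
    }
}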
LearningConfig.java
@@ -1,7 +1,7 @@
 package core;
 
 public class LearningConfig {
-    public static final int DEFAULT_DELAY = 1;
+    public static final int DEFAULT_DELAY = 30;
     public static final float DEFAULT_EPSILON = 0.1f;
     public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
 }
ListDiscreteActionSpace.java
@@ -1,10 +1,11 @@
 package core;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
-public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A> {
+public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable {
     private List<A> actions;
 
     public ListDiscreteActionSpace(){
RNG.java
@@ -1,12 +1,14 @@
 package core;
 
+import java.security.SecureRandom;
 import java.util.Random;
 
 public class RNG {
-    private static Random rng;
+    private static SecureRandom rng;
     private static int seed = 123;
 
     static {
-        rng = new Random(seed);
+        rng = new SecureRandom();
+        rng.setSeed(seed);
     }
 
     public static Random getRandom() {
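
Whether this seeding actually makes runs reproducible depends on the SecureRandom provider: SHA1PRNG yields a deterministic sequence when seeded before first use, while the platform default may mix in its own entropy. A small self-contained check (not from this repo; the explicit provider name is an assumption about the intended behaviour):

import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;

public class SeedCheck {
    public static void main(String[] args) throws NoSuchAlgorithmException {
        // Two SHA1PRNG instances seeded identically replay the same sequence.
        SecureRandom a = SecureRandom.getInstance("SHA1PRNG");
        SecureRandom b = SecureRandom.getInstance("SHA1PRNG");
        a.setSeed(123);
        b.setSeed(123);
        for (int i = 0; i < 3; i++) {
            System.out.println(a.nextInt(100) == b.nextInt(100)); // prints true three times
        }
    }
}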
SaveState.java (new file)
@@ -0,0 +1,13 @@
+package core;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.io.Serializable;
+
+@AllArgsConstructor
+@Getter
+public class SaveState<A extends Enum> implements Serializable {
+    private StateActionTable<A> stateActionTable;
+    private int currentEpisode;
+}
StateActionTable.java
@@ -7,6 +7,6 @@ public interface StateActionTable<A extends Enum> {
 
     double getValue(State state, A action);
     void setValue(State state, A action, double value);
+    int getStateCount();
     Map<A, Double> getActionValues(State state);
 }
EpisodicLearning.java
@@ -3,13 +3,17 @@ package core.algo;
 import core.DiscreteActionSpace;
 import core.Environment;
 import core.listener.LearningListener;
+import lombok.Getter;
+import lombok.Setter;
 
 public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
+    @Setter
+    @Getter
     protected int currentEpisode;
     protected int episodesToLearn;
     protected volatile int episodePerSecond;
     protected int episodeSumCurrentSecond;
-    private volatile boolean meseaureEpisodeBenchMark;
+    private volatile boolean measureEpisodeBenchMark;
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
         super(environment, actionSpace, discountFactor, delay);
@@ -29,6 +33,9 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
 
     protected void dispatchEpisodeEnd(double recentSumOfRewards){
         ++episodeSumCurrentSecond;
+        if(rewardHistory.size() > 10000){
+            rewardHistory.clear();
+        }
         rewardHistory.add(recentSumOfRewards);
         for(LearningListener l: learningListeners) {
             l.onEpisodeEnd(rewardHistory);
@@ -47,9 +54,9 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
     }
 
     public void learn(int nrOfEpisodes){
-        meseaureEpisodeBenchMark = true;
+        measureEpisodeBenchMark = true;
         new Thread(()->{
-            while(meseaureEpisodeBenchMark){
+            while(measureEpisodeBenchMark){
                 episodePerSecond = episodeSumCurrentSecond;
                 episodeSumCurrentSecond = 0;
                 try {
@@ -65,7 +72,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
             nextEpisode();
         }
         dispatchLearningEnd();
-        meseaureEpisodeBenchMark = false;
+        measureEpisodeBenchMark = false;
     }
 
     protected abstract void nextEpisode();
Learning.java
@@ -9,15 +9,17 @@ import core.policy.Policy;
 import lombok.Getter;
 import lombok.Setter;
 
+import java.io.Serializable;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CopyOnWriteArrayList;
 
 @Getter
-public abstract class Learning<A extends Enum> {
+public abstract class Learning<A extends Enum> implements Serializable {
     protected Policy<A> policy;
     protected DiscreteActionSpace<A> actionSpace;
+    @Setter
     protected StateActionTable<A> stateActionTable;
     protected Environment<A> environment;
     protected float discountFactor;
@@ -47,6 +49,7 @@ public abstract class Learning<A extends Enum> {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
     }
 
+
     public abstract void learn();
 
     public void addListener(LearningListener learningListener) {
@@ -70,4 +73,5 @@ public abstract class Learning<A extends Enum> {
             l.onLearningEnd();
         }
     }
+
 }
MonteCarloOnPolicyEGreedy.java
@@ -35,7 +35,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
         super(environment, actionSpace, discountFactor, delay);
         currentEpisode = 0;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
-        this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
+        this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
         returnCount = new HashMap<>();
     }
@@ -57,16 +57,15 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
             e.printStackTrace();
         }
         double sumOfRewards = 0;
-        for (int j = 0; j < 10; ++j) {
+        StepResultEnvironment envResult = null;
+        while(envResult == null || !envResult.isDone()){
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
             A chosenAction = policy.chooseAction(actionValues);
-            StepResultEnvironment envResult = environment.step(chosenAction);
+            envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
             sumOfRewards += envResult.getReward();
             episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
 
-            if (envResult.isDone()) break;
-
             state = nextState;
 
             try {
@@ -78,13 +77,13 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
         }
 
         dispatchEpisodeEnd(sumOfRewards);
-        System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
-        Set<Pair<State, A>> stateActionPairs = new HashSet<>();
+        // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
+        Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
 
         for (StepResult<A> sr : episode) {
             stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
         }
-        System.out.println("stateActionPairs " + stateActionPairs.size());
+        //System.out.println("stateActionPairs " + stateActionPairs.size());
         for (Pair<State, A> stateActionPair : stateActionPairs) {
             int firstOccurenceIndex = 0;
             // find first occurance of state action pair
RLController.java
@@ -1,8 +1,6 @@
 package core.controller;
 
-import core.DiscreteActionSpace;
-import core.Environment;
-import core.ListDiscreteActionSpace;
+import core.*;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
@@ -14,6 +12,7 @@ import core.listener.ViewListener;
 import core.policy.EpsilonPolicy;
 
 import javax.swing.*;
+import java.io.*;
 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
@@ -23,10 +22,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
     protected Learning<A> learning;
     protected DiscreteActionSpace<A> discreteActionSpace;
     protected LearningView learningView;
-    private int delay;
     private int nrOfEpisodes;
     private Method method;
    private int prevDelay;
+    private int delay = LearningConfig.DEFAULT_DELAY;
+    private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
+    private float epsilon = LearningConfig.DEFAULT_EPSILON;
     private boolean fastLearning;
     private boolean currentlyLearning;
     private ExecutorService learningExecutor;
@@ -36,6 +37,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
         learningExecutor = Executors.newSingleThreadExecutor();
     }
 
+
     public void start(){
         if(environment == null || discreteActionSpace == null || method == null){
             throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
@@ -43,7 +45,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
 
         switch (method){
             case MC_ONPOLICY_EGREEDY:
-                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay);
+                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
             case TD_ONPOLICY:
                 break;
@@ -76,6 +78,44 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
         }
     }
 
+    @Override
+    public void onLoadState(String fileName) {
+        FileInputStream fis;
+        ObjectInput in;
+        try {
+            fis = new FileInputStream(fileName);
+            in = new ObjectInputStream(fis);
+            SaveState<A> saveState = (SaveState<A>) in.readObject();
+            learning.setStateActionTable(saveState.getStateActionTable());
+            if(learning instanceof EpisodicLearning){
+                ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode());
+            }
+            in.close();
+        } catch (IOException | ClassNotFoundException e) {
+            e.printStackTrace();
+        }
+    }
+
+    @Override
+    public void onSaveState(String fileName) {
+        FileOutputStream fos;
+        ObjectOutputStream out;
+        try{
+            fos = new FileOutputStream(fileName);
+            out = new ObjectOutputStream(fos);
+            int currentEpisode;
+            if(learning instanceof EpisodicLearning){
+                currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode();
+            }else{
+                currentEpisode = 0;
+            }
+            out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode));
+            out.close();
+        }catch (IOException e){
+            e.printStackTrace();
+        }
+    }
+
     @Override
     public void onEpsilonChange(float epsilon) {
         if(learning.getPolicy() instanceof EpsilonPolicy){
@@ -169,4 +209,13 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
         this.nrOfEpisodes = nrOfEpisodes;
         return this;
     }
+
+    public RLController<A> setDiscountFactor(float discountFactor){
+        this.discountFactor = discountFactor;
+        return this;
+    }
+    public RLController<A> setEpsilon(float epsilon){
+        this.epsilon = epsilon;
+        return this;
+    }
 }
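
The two handlers above are plain Java object serialization of the new SaveState. A minimal sketch of the same round trip outside the controller (file name and episode number are invented; assumes the varargs ListDiscreteActionSpace constructor seen in the removed main() method still exists):

import core.*;
import evironment.jumpingDino.DinoAction;

import java.io.*;

public class SaveStateRoundTrip {
    public static void main(String[] args) throws IOException, ClassNotFoundException {
        DiscreteActionSpace<DinoAction> actions =
                new ListDiscreteActionSpace<>(DinoAction.JUMP, DinoAction.NOTHING);
        StateActionTable<DinoAction> table = new DeterministicStateActionTable<>(actions);

        // Bundle Q-table and episode counter, then serialize, as onSaveState does.
        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream("progress.sav"))) {
            out.writeObject(new SaveState<>(table, 4200));
        }

        // Restore, as onLoadState does.
        try (ObjectInputStream in = new ObjectInputStream(new FileInputStream("progress.sav"))) {
            SaveState<DinoAction> restored = (SaveState<DinoAction>) in.readObject();
            System.out.println("resuming at episode " + restored.getCurrentEpisode());
        }
    }
}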
View.java
@@ -11,6 +11,8 @@ import org.knowm.xchart.XYChart;
 
 import javax.swing.*;
 import java.awt.*;
+import java.awt.event.ActionEvent;
+import java.io.File;
 import java.util.List;
 import java.util.concurrent.CopyOnWriteArrayList;
 
@@ -26,6 +28,8 @@ public class View<A extends Enum> implements LearningView{
     private JFrame environmentFrame;
     private XChartPanel<XYChart> rewardChartPanel;
     private ViewListener viewListener;
+    private JMenuBar menuBar;
+    private JMenu fileMenu;
 
     public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) {
         this.learning = learning;
@@ -38,7 +42,32 @@ public class View<A extends Enum> implements LearningView{
         mainFrame = new JFrame();
         mainFrame.setPreferredSize(new Dimension(1280, 720));
         mainFrame.setLayout(new BorderLayout());
+        menuBar = new JMenuBar();
+        fileMenu = new JMenu("File");
+        menuBar.add(fileMenu);
+        fileMenu.add(new JMenuItem(new AbstractAction("Load") {
+            @Override
+            public void actionPerformed(ActionEvent e) {
+                final JFileChooser fc = new JFileChooser();
+                fc.setCurrentDirectory(new File(System.getProperty("user.dir")));
+                int returnVal = fc.showOpenDialog(mainFrame);
+
+                if (returnVal == JFileChooser.APPROVE_OPTION) {
+                    viewListener.onLoadState(fc.getSelectedFile().toString());
+                }
+            }
+        }));
+
+        fileMenu.add(new JMenuItem(new AbstractAction("Save") {
+            @Override
+            public void actionPerformed(ActionEvent e) {
+                String fileName = JOptionPane.showInputDialog("Enter file name", "save");
+                if(fileName != null){
+                    viewListener.onSaveState(fileName);
+                }
+            }
+        }));
+        mainFrame.setJMenuBar(menuBar);
         initLearningInfoPanel();
         initRewardChart();
ViewListener.java
@@ -5,4 +5,6 @@ public interface ViewListener {
     void onDelayChange(int delay);
     void onFastLearnChange(boolean isFastLearn);
     void onLearnMoreEpisodes(int nrOfEpisodes);
+    void onLoadState(String fileName);
+    void onSaveState(String fileName);
 }
EpsilonGreedyPolicy.java
@@ -29,7 +29,8 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
 
     @Override
     public A chooseAction(Map<A, Double> actionValues) {
-        if(RNG.getRandom().nextFloat() < epsilon){
+        float f = RNG.getRandom().nextFloat();
+        if(f < epsilon){
             // Take random action
             return randomPolicy.chooseAction(actionValues);
         }else{
GreedyPolicy.java
@@ -1,9 +1,10 @@
 package core.policy;
 
+import core.RNG;
+
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
-import java.util.Random;
 
 public class GreedyPolicy<A extends Enum> implements Policy<A> {
 
@@ -26,6 +27,6 @@ public class GreedyPolicy<A extends Enum> implements Policy<A> {
             }
         }
 
-        return equalHigh.get(new Random().nextInt(equalHigh.size()));
+        return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
     }
 }
RandomPolicy.java
@@ -1,18 +1,17 @@
 package core.policy;
 
 import core.RNG;
 
 import java.util.Map;
 
 public class RandomPolicy<A extends Enum> implements Policy<A>{
     @Override
     public A chooseAction(Map<A, Double> actionValues) {
         int idx = RNG.getRandom().nextInt(actionValues.size());
-        System.out.println("selected action " + idx);
         int i = 0;
         for(A action : actionValues.keySet()){
             if(i++ == idx) return action;
         }
 
         return null;
     }
 }
Config.java (new file)
@@ -0,0 +1,13 @@
+package evironment.jumpingDino;
+
+public class Config {
+    public static final int FRAME_WIDTH = 1280;
+    public static final int FRAME_HEIGHT = 720;
+    public static final int GROUND_Y = 50;
+    public static final int DINO_STARTING_X = 50;
+    public static final int DINO_SIZE = 50;
+    public static final int OBSTACLE_SIZE = 60;
+    public static final int OBSTACLE_SPEED = 30;
+    public static final int DINO_JUMP_SPEED = 20;
+    public static final int MAX_JUMP_HEIGHT = 200;
+}
Dino.java (new file)
@@ -0,0 +1,40 @@
+package evironment.jumpingDino;
+
+import lombok.Getter;
+
+import java.awt.*;
+
+public class Dino extends RenderObject {
+    @Getter
+    private boolean inJump;
+
+    public Dino(int size, int x, int y, int dx, int dy, Color color) {
+        super(size, x, y, dx, dy, color);
+    }
+
+    public void jump(){
+        if(!inJump){
+            dy = -Config.DINO_JUMP_SPEED;
+            inJump = true;
+        }
+    }
+
+    private void fall(){
+        if(inJump){
+            dy = Config.DINO_JUMP_SPEED;
+        }
+    }
+
+    @Override
+    public void tick(){
+        // reached max jump height
+        if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - Config.MAX_JUMP_HEIGHT){
+            fall();
+        }else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){
+            inJump = false;
+            dy = 0;
+            y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
+        }
+        super.tick();
+    }
+}
DinoAction.java (new file)
@@ -0,0 +1,6 @@
+package evironment.jumpingDino;
+
+public enum DinoAction {
+    JUMP,
+    NOTHING,
+}
DinoState.java (new file)
@@ -0,0 +1,32 @@
+package evironment.jumpingDino;
+
+import core.State;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.io.Serializable;
+
+@AllArgsConstructor
+@Getter
+public class DinoState implements State, Serializable {
+    private int xDistanceToObstacle;
+
+    @Override
+    public String toString() {
+        return Integer.toString(xDistanceToObstacle);
+    }
+
+    @Override
+    public int hashCode() {
+        return this.xDistanceToObstacle;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if(obj instanceof DinoState){
+            DinoState toCompare = (DinoState) obj;
+            return toCompare.getXDistanceToObstacle() == this.xDistanceToObstacle;
+        }
+        return super.equals(obj);
+    }
+}
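
DinoState reduces the whole observation to one int and bases equals and hashCode on it, so two observations at the same distance share a single Q-table row. A quick illustration (the distance value is invented):

import evironment.jumpingDino.DinoState;

import java.util.LinkedHashMap;
import java.util.Map;

public class StateKeyDemo {
    public static void main(String[] args) {
        Map<DinoState, Double> q = new LinkedHashMap<>();
        q.put(new DinoState(120), 0.5);
        // A fresh object with the same distance hits the same entry.
        System.out.println(q.get(new DinoState(120))); // prints 0.5
    }
}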
DinoWorld.java (new file)
@@ -0,0 +1,78 @@
+package evironment.jumpingDino;
+
+import core.Environment;
+import core.State;
+import core.StepResultEnvironment;
+import core.gui.Visualizable;
+import evironment.jumpingDino.gui.DinoWorldComponent;
+import lombok.Getter;
+
+import javax.swing.*;
+import java.awt.*;
+
+@Getter
+public class DinoWorld implements Environment<DinoAction>, Visualizable {
+    private Dino dino;
+    private Obstacle currentObstacle;
+
+    public DinoWorld(){
+        dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
+        spawnNewObstacle();
+    }
+
+    private boolean ranIntoObstacle(){
+        Obstacle o = currentObstacle;
+        Dino p = dino;
+
+        boolean xAxis = (o.getX() <= p.getX() && p.getX() < o.getX() + Config.OBSTACLE_SIZE)
+                || (o.getX() <= p.getX() + Config.DINO_SIZE && p.getX() + Config.DINO_SIZE < o.getX() + Config.OBSTACLE_SIZE);
+
+        boolean yAxis = (o.getY() <= p.getY() && p.getY() < o.getY() + Config.OBSTACLE_SIZE)
+                || (o.getY() <= p.getY() + Config.DINO_SIZE && p.getY() + Config.DINO_SIZE < o.getY() + Config.OBSTACLE_SIZE);
+
+        return xAxis && yAxis;
+    }
+
+    private int getDistanceToObstacle(){
+        return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE;
+    }
+
+    @Override
+    public StepResultEnvironment step(DinoAction action) {
+        boolean done = false;
+        int reward = 1;
+
+        if(action == DinoAction.JUMP){
+            dino.jump();
+        }
+
+        dino.tick();
+        currentObstacle.tick();
+        if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
+            spawnNewObstacle();
+        }
+
+        if(ranIntoObstacle()){
+            done = true;
+        }
+
+        return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, "");
+    }
+
+    private void spawnNewObstacle(){
+        currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK);
+    }
+
+    private void spawnDino(){
+        dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
+    }
+
+    @Override
+    public State reset() {
+        spawnDino();
+        spawnNewObstacle();
+        return new DinoState(getDistanceToObstacle());
+    }
+
+    @Override
+    public JComponent visualize() {
+        return new DinoWorldComponent(this);
+    }
+}
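
DinoWorld is what lets the episode end be decided by the environment: step ticks the world and sets done on collision, and the Monte Carlo loop above runs until that flag. A hand-driven rollout without any learning (a sketch, not part of this commit):

import core.StepResultEnvironment;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;

public class ManualRollout {
    public static void main(String[] args) {
        DinoWorld env = new DinoWorld();
        env.reset();
        StepResultEnvironment result = null;
        int steps = 0;
        // Never jump; the episode ends at the first collision.
        while (result == null || !result.isDone()) {
            result = env.step(DinoAction.NOTHING);
            steps++;
        }
        System.out.println("collided after " + steps + " steps");
    }
}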
Obstacle.java (new file)
@@ -0,0 +1,10 @@
+package evironment.jumpingDino;
+
+import java.awt.*;
+
+public class Obstacle extends RenderObject {
+
+    public Obstacle(int size, int x, int y, int dx, int dy, Color color) {
+        super(size, x, y, dx, dy, color);
+    }
+}
RenderObject.java (new file)
@@ -0,0 +1,28 @@
+package evironment.jumpingDino;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.awt.*;
+
+
+@AllArgsConstructor
+@Getter
+public abstract class RenderObject {
+    protected int size;
+    protected int x;
+    protected int y;
+    protected int dx;
+    protected int dy;
+    protected Color color;
+
+    public void render(Graphics g){
+        g.setColor(color);
+        g.fillRect(x, y, size, size);
+    }
+
+    public void tick(){
+        y += dy;
+        x += dx;
+    }
+}
DinoWorldComponent.java (new file)
@@ -0,0 +1,27 @@
+package evironment.jumpingDino.gui;
+
+import evironment.jumpingDino.Config;
+import evironment.jumpingDino.DinoWorld;
+
+import javax.swing.*;
+import java.awt.*;
+
+public class DinoWorldComponent extends JComponent {
+    private DinoWorld dinoWorld;
+
+    public DinoWorldComponent(DinoWorld dinoWorld){
+        this.dinoWorld = dinoWorld;
+        setPreferredSize(new Dimension(Config.FRAME_WIDTH, Config.FRAME_HEIGHT));
+        setVisible(true);
+    }
+
+    @Override
+    protected void paintComponent(Graphics g) {
+        super.paintComponent(g);
+        g.setColor(Color.BLACK);
+        g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2);
+
+        dinoWorld.getDino().render(g);
+        dinoWorld.getCurrentObstacle().render(g);
+    }
+}
JumpingDino.java (new file)
@@ -0,0 +1,23 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+
+public class JumpingDino {
+    public static void main(String[] args) {
+        RNG.setSeed(55);
+
+        RLController<DinoAction> rl = new RLController<DinoAction>()
+                .setEnvironment(new DinoWorld())
+                .setAllowedActions(DinoAction.values())
+                .setMethod(Method.MC_ONPOLICY_EGREEDY)
+                .setDiscountFactor(1f)
+                .setEpsilon(0.15f)
+                .setDelay(200)
+                .setEpisodes(100000);
+        rl.start();
+    }
+}