add dino jumping environment, deterministic/reproducible behaviour and save-and-load feature

- add feature to save and load learning progress (Q-Table) and current episode count
- episode end is now decided purely by the environment instead of the Monte Carlo algorithm capping it at 10 actions
- using LinkedHashMap in all locations to ensure deterministic behaviour
- fixed major RNG issue to reproduce algorithmic behaviour
- clearing rewardHistory to keep only the last 10k rewards
- added google dino jump environment
This commit is contained in:
Jan Löwenstrom 2019-12-22 23:33:56 +01:00
parent b1246f62cc
commit 5a4e380faf
24 changed files with 415 additions and 56 deletions

View File

@ -1,20 +1,19 @@
package core; package core;
import evironment.antGame.AntAction; import java.io.Serializable;
import java.util.LinkedHashMap;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
/** /**
* Premise: All states have the complete action space * Premise: All states have the complete action space
*/ */
public class StateActionHashTable<A extends Enum> implements StateActionTable<A> { public class DeterministicStateActionTable<A extends Enum> implements StateActionTable<A>, Serializable {
private final Map<State, Map<A, Double>> table; private final Map<State, Map<A, Double>> table;
private DiscreteActionSpace<A> discreteActionSpace; private DiscreteActionSpace<A> discreteActionSpace;
public StateActionHashTable(DiscreteActionSpace<A> discreteActionSpace){ public DeterministicStateActionTable(DiscreteActionSpace<A> discreteActionSpace){
table = new HashMap<>(); table = new LinkedHashMap<>();
this.discreteActionSpace = discreteActionSpace; this.discreteActionSpace = discreteActionSpace;
} }
@ -61,19 +60,15 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable<A>
return table.get(state); return table.get(state);
} }
public static void main(String[] args) {
DiscreteActionSpace<AntAction> da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP);
StateActionTable sat = new StateActionHashTable<>(da);
State t = new State() {
};
System.out.println(sat.getActionValues(t));
}
private Map<A, Double> createDefaultActionValues(){ private Map<A, Double> createDefaultActionValues(){
final Map<A, Double> defaultActionValues = new HashMap<>(); final Map<A, Double> defaultActionValues = new LinkedHashMap<>();
for(A action: discreteActionSpace.getAllActions()){ for(A action: discreteActionSpace.getAllActions()){
defaultActionValues.put(action, DEFAULT_VALUE); defaultActionValues.put(action, DEFAULT_VALUE);
} }
return defaultActionValues; return defaultActionValues;
} }
@Override
public int getStateCount(){
return table.size();
}
} }

View File

@ -1,7 +1,7 @@
package core; package core;
public class LearningConfig { public class LearningConfig {
public static final int DEFAULT_DELAY = 1; public static final int DEFAULT_DELAY = 30;
public static final float DEFAULT_EPSILON = 0.1f; public static final float DEFAULT_EPSILON = 0.1f;
public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f; public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
} }

View File

@ -1,10 +1,11 @@
package core; package core;
import java.io.Serializable;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A> { public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable {
private List<A> actions; private List<A> actions;
public ListDiscreteActionSpace(){ public ListDiscreteActionSpace(){

View File

@ -1,12 +1,14 @@
package core; package core;
import java.security.SecureRandom;
import java.util.Random; import java.util.Random;
public class RNG { public class RNG {
private static Random rng; private static SecureRandom rng;
private static int seed = 123; private static int seed = 123;
static { static {
rng = new Random(seed); rng = new SecureRandom();
rng.setSeed(seed);
} }
public static Random getRandom() { public static Random getRandom() {

View File

@ -0,0 +1,13 @@
package core;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.io.Serializable;
@AllArgsConstructor
@Getter
public class SaveState<A extends Enum> implements Serializable {
private StateActionTable<A> stateActionTable;
private int currentEpisode;
}

View File

@ -7,6 +7,6 @@ public interface StateActionTable<A extends Enum> {
double getValue(State state, A action); double getValue(State state, A action);
void setValue(State state, A action, double value); void setValue(State state, A action, double value);
int getStateCount();
Map<A, Double> getActionValues(State state); Map<A, Double> getActionValues(State state);
} }

View File

@ -3,13 +3,17 @@ package core.algo;
import core.DiscreteActionSpace; import core.DiscreteActionSpace;
import core.Environment; import core.Environment;
import core.listener.LearningListener; import core.listener.LearningListener;
import lombok.Getter;
import lombok.Setter;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic{ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
@Setter
@Getter
protected int currentEpisode; protected int currentEpisode;
protected int episodesToLearn; protected int episodesToLearn;
protected volatile int episodePerSecond; protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond; protected int episodeSumCurrentSecond;
private volatile boolean meseaureEpisodeBenchMark; private volatile boolean measureEpisodeBenchMark;
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) { public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay); super(environment, actionSpace, discountFactor, delay);
@ -29,6 +33,9 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchEpisodeEnd(double recentSumOfRewards){ protected void dispatchEpisodeEnd(double recentSumOfRewards){
++episodeSumCurrentSecond; ++episodeSumCurrentSecond;
if(rewardHistory.size() > 10000){
rewardHistory.clear();
}
rewardHistory.add(recentSumOfRewards); rewardHistory.add(recentSumOfRewards);
for(LearningListener l: learningListeners) { for(LearningListener l: learningListeners) {
l.onEpisodeEnd(rewardHistory); l.onEpisodeEnd(rewardHistory);
@ -47,9 +54,9 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
} }
public void learn(int nrOfEpisodes){ public void learn(int nrOfEpisodes){
meseaureEpisodeBenchMark = true; measureEpisodeBenchMark = true;
new Thread(()->{ new Thread(()->{
while(meseaureEpisodeBenchMark){ while(measureEpisodeBenchMark){
episodePerSecond = episodeSumCurrentSecond; episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0; episodeSumCurrentSecond = 0;
try { try {
@ -65,7 +72,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
nextEpisode(); nextEpisode();
} }
dispatchLearningEnd(); dispatchLearningEnd();
meseaureEpisodeBenchMark = false; measureEpisodeBenchMark = false;
} }
protected abstract void nextEpisode(); protected abstract void nextEpisode();

View File

@ -9,15 +9,17 @@ import core.policy.Policy;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.io.Serializable;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
@Getter @Getter
public abstract class Learning<A extends Enum> { public abstract class Learning<A extends Enum> implements Serializable {
protected Policy<A> policy; protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace; protected DiscreteActionSpace<A> actionSpace;
@Setter
protected StateActionTable<A> stateActionTable; protected StateActionTable<A> stateActionTable;
protected Environment<A> environment; protected Environment<A> environment;
protected float discountFactor; protected float discountFactor;
@ -26,7 +28,7 @@ public abstract class Learning<A extends Enum> {
protected int delay; protected int delay;
protected List<Double> rewardHistory; protected List<Double> rewardHistory;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay){ public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
this.environment = environment; this.environment = environment;
this.actionSpace = actionSpace; this.actionSpace = actionSpace;
this.discountFactor = discountFactor; this.discountFactor = discountFactor;
@ -35,39 +37,41 @@ public abstract class Learning<A extends Enum> {
rewardHistory = new CopyOnWriteArrayList<>(); rewardHistory = new CopyOnWriteArrayList<>();
} }
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor){ public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY); this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
} }
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay){ public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay); this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
} }
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){ public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY); this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
} }
public abstract void learn(); public abstract void learn();
public void addListener(LearningListener learningListener){ public void addListener(LearningListener learningListener) {
learningListeners.add(learningListener); learningListeners.add(learningListener);
} }
protected void dispatchStepEnd(){ protected void dispatchStepEnd() {
for(LearningListener l: learningListeners){ for (LearningListener l : learningListeners) {
l.onStepEnd(); l.onStepEnd();
} }
} }
protected void dispatchLearningStart(){ protected void dispatchLearningStart() {
for(LearningListener l: learningListeners){ for (LearningListener l : learningListeners) {
l.onLearningStart(); l.onLearningStart();
} }
} }
protected void dispatchLearningEnd(){ protected void dispatchLearningEnd() {
for(LearningListener l: learningListeners){ for (LearningListener l : learningListeners) {
l.onLearningEnd(); l.onLearningEnd();
} }
} }
} }

View File

@ -35,7 +35,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
super(environment, actionSpace, discountFactor, delay); super(environment, actionSpace, discountFactor, delay);
currentEpisode = 0; currentEpisode = 0;
this.policy = new EpsilonGreedyPolicy<>(epsilon); this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new StateActionHashTable<>(this.actionSpace); this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>(); returnSum = new HashMap<>();
returnCount = new HashMap<>(); returnCount = new HashMap<>();
} }
@ -57,16 +57,15 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
e.printStackTrace(); e.printStackTrace();
} }
double sumOfRewards = 0; double sumOfRewards = 0;
for (int j = 0; j < 10; ++j) { StepResultEnvironment envResult = null;
while(envResult == null || !envResult.isDone()){
Map<A, Double> actionValues = stateActionTable.getActionValues(state); Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues); A chosenAction = policy.chooseAction(actionValues);
StepResultEnvironment envResult = environment.step(chosenAction); envResult = environment.step(chosenAction);
State nextState = envResult.getState(); State nextState = envResult.getState();
sumOfRewards += envResult.getReward(); sumOfRewards += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
if (envResult.isDone()) break;
state = nextState; state = nextState;
try { try {
@ -78,13 +77,13 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
} }
dispatchEpisodeEnd(sumOfRewards); dispatchEpisodeEnd(sumOfRewards);
System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set<Pair<State, A>> stateActionPairs = new HashSet<>(); Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
for (StepResult<A> sr : episode) { for (StepResult<A> sr : episode) {
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction())); stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
} }
System.out.println("stateActionPairs " + stateActionPairs.size()); //System.out.println("stateActionPairs " + stateActionPairs.size());
for (Pair<State, A> stateActionPair : stateActionPairs) { for (Pair<State, A> stateActionPair : stateActionPairs) {
int firstOccurenceIndex = 0; int firstOccurenceIndex = 0;
// find first occurance of state action pair // find first occurance of state action pair

View File

@ -1,8 +1,6 @@
package core.controller; package core.controller;
import core.DiscreteActionSpace; import core.*;
import core.Environment;
import core.ListDiscreteActionSpace;
import core.algo.EpisodicLearning; import core.algo.EpisodicLearning;
import core.algo.Learning; import core.algo.Learning;
import core.algo.Method; import core.algo.Method;
@ -14,6 +12,7 @@ import core.listener.ViewListener;
import core.policy.EpsilonPolicy; import core.policy.EpsilonPolicy;
import javax.swing.*; import javax.swing.*;
import java.io.*;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
@ -23,10 +22,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
protected Learning<A> learning; protected Learning<A> learning;
protected DiscreteActionSpace<A> discreteActionSpace; protected DiscreteActionSpace<A> discreteActionSpace;
protected LearningView learningView; protected LearningView learningView;
private int delay;
private int nrOfEpisodes; private int nrOfEpisodes;
private Method method; private Method method;
private int prevDelay; private int prevDelay;
private int delay = LearningConfig.DEFAULT_DELAY;
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
private float epsilon = LearningConfig.DEFAULT_EPSILON;
private boolean fastLearning; private boolean fastLearning;
private boolean currentlyLearning; private boolean currentlyLearning;
private ExecutorService learningExecutor; private ExecutorService learningExecutor;
@ -36,6 +37,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
learningExecutor = Executors.newSingleThreadExecutor(); learningExecutor = Executors.newSingleThreadExecutor();
} }
public void start(){ public void start(){
if(environment == null || discreteActionSpace == null || method == null){ if(environment == null || discreteActionSpace == null || method == null){
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()"); throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
@ -43,7 +45,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
switch (method){ switch (method){
case MC_ONPOLICY_EGREEDY: case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay); learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break; break;
case TD_ONPOLICY: case TD_ONPOLICY:
break; break;
@ -76,6 +78,44 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
} }
} }
@Override
public void onLoadState(String fileName) {
FileInputStream fis;
ObjectInput in;
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
SaveState<A> saveState = (SaveState<A>) in.readObject();
learning.setStateActionTable(saveState.getStateActionTable());
if(learning instanceof EpisodicLearning){
((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode());
}
in.close();
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
}
}
@Override
public void onSaveState(String fileName) {
FileOutputStream fos;
ObjectOutputStream out;
try{
fos = new FileOutputStream(fileName);
out = new ObjectOutputStream(fos);
int currentEpisode;
if(learning instanceof EpisodicLearning){
currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode();
}else{
currentEpisode = 0;
}
out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode));
out.close();
}catch (IOException e){
e.printStackTrace();
}
}
@Override @Override
public void onEpsilonChange(float epsilon) { public void onEpsilonChange(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy){ if(learning.getPolicy() instanceof EpsilonPolicy){
@ -169,4 +209,13 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
this.nrOfEpisodes = nrOfEpisodes; this.nrOfEpisodes = nrOfEpisodes;
return this; return this;
} }
public RLController<A> setDiscountFactor(float discountFactor){
this.discountFactor = discountFactor;
return this;
}
public RLController<A> setEpsilon(float epsilon){
this.epsilon = epsilon;
return this;
}
} }

View File

@ -11,6 +11,8 @@ import org.knowm.xchart.XYChart;
import javax.swing.*; import javax.swing.*;
import java.awt.*; import java.awt.*;
import java.awt.event.ActionEvent;
import java.io.File;
import java.util.List; import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
@ -26,6 +28,8 @@ public class View<A extends Enum> implements LearningView{
private JFrame environmentFrame; private JFrame environmentFrame;
private XChartPanel<XYChart> rewardChartPanel; private XChartPanel<XYChart> rewardChartPanel;
private ViewListener viewListener; private ViewListener viewListener;
private JMenuBar menuBar;
private JMenu fileMenu;
public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) { public View(Learning<A> learning, Environment<A> environment, ViewListener viewListener) {
this.learning = learning; this.learning = learning;
@ -38,7 +42,32 @@ public class View<A extends Enum> implements LearningView{
mainFrame = new JFrame(); mainFrame = new JFrame();
mainFrame.setPreferredSize(new Dimension(1280, 720)); mainFrame.setPreferredSize(new Dimension(1280, 720));
mainFrame.setLayout(new BorderLayout()); mainFrame.setLayout(new BorderLayout());
menuBar = new JMenuBar();
fileMenu = new JMenu("File");
menuBar.add(fileMenu);
fileMenu.add(new JMenuItem(new AbstractAction("Load") {
@Override
public void actionPerformed(ActionEvent e) {
final JFileChooser fc = new JFileChooser();
fc.setCurrentDirectory(new File(System.getProperty("user.dir")));
int returnVal = fc.showOpenDialog(mainFrame);
if (returnVal == JFileChooser.APPROVE_OPTION) {
viewListener.onLoadState(fc.getSelectedFile().toString());
}
}
}));
fileMenu.add(new JMenuItem(new AbstractAction("Save") {
@Override
public void actionPerformed(ActionEvent e) {
String fileName = JOptionPane.showInputDialog("Enter file name", "save");
if(fileName != null){
viewListener.onSaveState(fileName);
}
}
}));
mainFrame.setJMenuBar(menuBar);
initLearningInfoPanel(); initLearningInfoPanel();
initRewardChart(); initRewardChart();

View File

@ -5,4 +5,6 @@ public interface ViewListener {
void onDelayChange(int delay); void onDelayChange(int delay);
void onFastLearnChange(boolean isFastLearn); void onFastLearnChange(boolean isFastLearn);
void onLearnMoreEpisodes(int nrOfEpisodes); void onLearnMoreEpisodes(int nrOfEpisodes);
void onLoadState(String fileName);
void onSaveState(String fileName);
} }

View File

@ -29,7 +29,8 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
@Override @Override
public A chooseAction(Map<A, Double> actionValues) { public A chooseAction(Map<A, Double> actionValues) {
if(RNG.getRandom().nextFloat() < epsilon){ float f = RNG.getRandom().nextFloat();
if(f < epsilon){
// Take random action // Take random action
return randomPolicy.chooseAction(actionValues); return randomPolicy.chooseAction(actionValues);
}else{ }else{

View File

@ -1,9 +1,10 @@
package core.policy; package core.policy;
import core.RNG;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Random;
public class GreedyPolicy<A extends Enum> implements Policy<A> { public class GreedyPolicy<A extends Enum> implements Policy<A> {
@ -26,6 +27,6 @@ public class GreedyPolicy<A extends Enum> implements Policy<A> {
} }
} }
return equalHigh.get(new Random().nextInt(equalHigh.size())); return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
} }
} }

View File

@ -1,18 +1,17 @@
package core.policy; package core.policy;
import core.RNG; import core.RNG;
import java.util.Map; import java.util.Map;
public class RandomPolicy<A extends Enum> implements Policy<A>{ public class RandomPolicy<A extends Enum> implements Policy<A>{
@Override @Override
public A chooseAction(Map<A, Double> actionValues) { public A chooseAction(Map<A, Double> actionValues) {
int idx = RNG.getRandom().nextInt(actionValues.size()); int idx = RNG.getRandom().nextInt(actionValues.size());
System.out.println("selected action " + idx);
int i = 0; int i = 0;
for(A action : actionValues.keySet()){ for(A action : actionValues.keySet()){
if(i++ == idx) return action; if(i++ == idx) return action;
} }
return null; return null;
} }
} }

View File

@ -0,0 +1,13 @@
package evironment.jumpingDino;
/**
 * Central tuning constants for the jumping-dino environment.
 *
 * All values are in pixels (or pixels per tick for speeds). The ground line
 * sits GROUND_Y pixels above the bottom edge of the FRAME_HEIGHT-tall frame;
 * the dino and obstacle are axis-aligned squares of DINO_SIZE and
 * OBSTACLE_SIZE respectively.
 */
public final class Config {
    public static final int FRAME_WIDTH = 1280;
    public static final int FRAME_HEIGHT = 720;
    // height of the ground strip measured from the bottom of the frame
    public static final int GROUND_Y = 50;
    public static final int DINO_STARTING_X = 50;
    public static final int DINO_SIZE = 50;
    public static final int OBSTACLE_SIZE = 60;
    // horizontal scroll speed of obstacles (moving leftwards)
    public static final int OBSTACLE_SPEED = 30;
    // vertical speed applied while jumping/falling
    public static final int DINO_JUMP_SPEED = 20;
    // maximum height of a jump arc above the dino's resting position
    public static final int MAX_JUMP_HEIGHT = 200;

    // constants holder — prevent instantiation
    private Config() {
    }
}

View File

@ -0,0 +1,40 @@
package evironment.jumpingDino;
import lombok.Getter;
import java.awt.*;
// Player-controlled dinosaur. Extends RenderObject with simple jump physics:
// jump() starts an upward movement, tick() flips to falling at the apex and
// snaps the dino back onto the ground line when it lands.
public class Dino extends RenderObject {
    // true while a jump is in progress; jump() is ignored until landing
    @Getter
    private boolean inJump;

    public Dino(int size, int x, int y, int dx, int dy, Color color) {
        super(size, x, y, dx, dy, color);
    }

    // Start a jump by giving the dino upward velocity (negative y is up in
    // AWT coordinates). No-op while airborne, so the arc cannot be extended.
    public void jump(){
        if(!inJump){
            dy = -Config.DINO_JUMP_SPEED;
            inJump = true;
        }
    }

    // Reverse vertical velocity to begin the descent; only meaningful mid-jump.
    private void fall(){
        if(inJump){
            dy = Config.DINO_JUMP_SPEED;
        }
    }

    @Override
    public void tick(){
        // reached max jump height
        // NOTE(review): the apex threshold uses OBSTACLE_SIZE in its ground-offset
        // term while the landing check below uses DINO_SIZE — possibly intentional
        // (obstacle-relative apex), but confirm the asymmetry is wanted.
        if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y -Config.OBSTACLE_SIZE - Config.MAX_JUMP_HEIGHT){
            fall();
        }else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){
            // about to touch or pass the ground: land exactly on it and stop
            inJump = false;
            dy = 0;
            y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
        }
        super.tick();
    }
}

View File

@ -0,0 +1,6 @@
package evironment.jumpingDino;
// Action space of the jumping-dino environment: each step the agent
// either initiates a jump or does nothing.
public enum DinoAction {
    JUMP,
    NOTHING,
}

View File

@ -0,0 +1,32 @@
package evironment.jumpingDino;
import core.State;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.io.Serializable;
/**
 * Observation of the DinoWorld environment: the horizontal distance between
 * the dino and the next obstacle. Two states with equal distance are
 * considered identical, so the Q-table aggregates experience by distance
 * alone. Serializable because states are persisted as part of the
 * save-and-load feature.
 */
@AllArgsConstructor
@Getter
public class DinoState implements State, Serializable {
    // fixed id so previously saved Q-tables stay loadable across recompiles;
    // without it the JVM derives one from the class shape and any change
    // would break deserialization of existing save files
    private static final long serialVersionUID = 1L;

    private int xDistanceToObstacle;

    @Override
    public String toString() {
        return Integer.toString(xDistanceToObstacle);
    }

    // the distance is the entire identity of the state, so it doubles as the hash
    @Override
    public int hashCode() {
        return this.xDistanceToObstacle;
    }

    @Override
    public boolean equals(Object obj) {
        if(obj instanceof DinoState){
            DinoState toCompare = (DinoState) obj;
            return toCompare.getXDistanceToObstacle() == this.xDistanceToObstacle;
        }
        // a non-DinoState can never be equal; the previous fallback to
        // super.equals(obj) was behaviorally identical (identity comparison
        // always false here) but obscured the contract
        return false;
    }
}

View File

@ -0,0 +1,78 @@
package evironment.jumpingDino;
import core.Environment;
import core.State;
import core.StepResultEnvironment;
import core.gui.Visualizable;
import evironment.jumpingDino.gui.DinoWorldComponent;
import lombok.Getter;
import javax.swing.*;
import java.awt.*;
/**
 * Chrome-style "jumping dino" environment: one dino on a ground line and a
 * single obstacle scrolling in from the right. Each survived step yields a
 * reward of 1; the episode ends when the dino collides with the obstacle.
 * The observed state is only the x-distance to the obstacle (see DinoState).
 */
@Getter
public class DinoWorld implements Environment<DinoAction>, Visualizable {
    private Dino dino;
    private Obstacle currentObstacle;

    public DinoWorld(){
        dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
        spawnNewObstacle();
    }

    // Axis-aligned overlap test between the dino and the current obstacle.
    // NOTE(review): only tests whether the dino's edges lie inside the obstacle,
    // so an obstacle entirely inside the dino would be missed. Harmless while
    // OBSTACLE_SIZE >= DINO_SIZE, but worth confirming if sizes change.
    private boolean ranIntoObstacle(){
        Obstacle o = currentObstacle;
        Dino p = dino;
        boolean xAxis = (o.getX() <= p.getX() && p.getX() < o.getX() + Config.OBSTACLE_SIZE)
                || (o.getX() <= p.getX() + Config.DINO_SIZE && p.getX() + Config.DINO_SIZE < o.getX() + Config.OBSTACLE_SIZE);
        boolean yAxis = (o.getY() <= p.getY() && p.getY() < o.getY() + Config.OBSTACLE_SIZE)
                || (o.getY() <= p.getY() + Config.DINO_SIZE && p.getY() + Config.DINO_SIZE < o.getY() + Config.OBSTACLE_SIZE);
        return xAxis && yAxis;
    }

    // NOTE(review): by operator precedence this is (obstacle.x - dino.x) + DINO_SIZE,
    // i.e. left-edge distance plus the dino's width. If the intent was the gap from
    // the dino's right edge, parentheses are missing: o.getX() - (dino.getX() + DINO_SIZE).
    // Either form is a consistent state encoding, so learning still works — confirm intent.
    private int getDistanceToObstacle(){
        return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE;
    }

    /**
     * Advances the world by one tick: applies the chosen action, moves dino
     * and obstacle, recycles the obstacle once it leaves the left edge, and
     * reports done=true on collision. Reward is a constant 1 per step.
     */
    @Override
    public StepResultEnvironment step(DinoAction action) {
        boolean done = false;
        int reward = 1;

        if(action == DinoAction.JUMP){
            dino.jump();
        }
        dino.tick();
        currentObstacle.tick();

        if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
            spawnNewObstacle();
        }

        if(ranIntoObstacle()){
            done = true;
        }

        return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, "");
    }

    // Place a fresh obstacle just beyond the right frame edge, moving left
    // at the configured constant speed.
    private void spawnNewObstacle(){
        currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK);
    }

    // Recreate the dino at its starting position on the ground, stationary.
    private void spawnDino(){
        dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
    }

    // Reset dino and obstacle to their initial configuration and return the
    // resulting initial state.
    @Override
    public State reset() {
        spawnDino();
        spawnNewObstacle();
        return new DinoState(getDistanceToObstacle());
    }

    @Override
    public JComponent visualize() {
        return new DinoWorldComponent(this);
    }
}

View File

@ -0,0 +1,10 @@
package evironment.jumpingDino;
import java.awt.*;
// A single obstacle for the dino to jump over. A plain RenderObject with no
// extra behavior of its own: its leftward movement comes entirely from the
// (dx, dy) velocity supplied at construction (see DinoWorld.spawnNewObstacle).
public class Obstacle extends RenderObject {
    public Obstacle(int size, int x, int y, int dx, int dy, Color color) {
        super(size, x, y, dx, dy, color);
    }
}

View File

@ -0,0 +1,28 @@
package evironment.jumpingDino;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.awt.*;
// Base class for everything drawn in the dino world: a solid-colored square
// of a given size at position (x, y) that moves by (dx, dy) on every tick.
// Lombok supplies the all-args constructor and the getters; field order is
// significant because @AllArgsConstructor derives the parameter order from it.
@AllArgsConstructor
@Getter
public abstract class RenderObject {
    protected int size;
    protected int x;
    protected int y;
    protected int dx;
    protected int dy;
    protected Color color;

    // Paint this object as a filled square at its current position.
    public void render(Graphics g){
        g.setColor(this.color);
        g.fillRect(this.x, this.y, this.size, this.size);
    }

    // Advance one simulation step by applying the velocity to the position.
    public void tick(){
        this.x += this.dx;
        this.y += this.dy;
    }
}

View File

@ -0,0 +1,27 @@
package evironment.jumpingDino.gui;
import evironment.jumpingDino.Config;
import evironment.jumpingDino.DinoWorld;
import javax.swing.*;
import java.awt.*;
// Swing component that renders the current state of a DinoWorld:
// the ground line, the dino and the active obstacle.
public class DinoWorldComponent extends JComponent {
    private DinoWorld dinoWorld;

    public DinoWorldComponent(DinoWorld dinoWorld){
        this.dinoWorld = dinoWorld;
        setPreferredSize(new Dimension(Config.FRAME_WIDTH, Config.FRAME_HEIGHT));
        setVisible(true);
    }

    @Override
    protected void paintComponent(Graphics g) {
        super.paintComponent(g);
        // ground line: a 2px-tall strip across the full frame width
        g.setColor(Color.BLACK);
        g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2);
        dinoWorld.getDino().render(g);
        dinoWorld.getCurrentObstacle().render(g);
    }
}

View File

@ -0,0 +1,23 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
// Example entry point: trains an on-policy Monte Carlo e-greedy agent on the
// jumping-dino environment.
public class JumpingDino {
    public static void main(String[] args) {
        // fixed seed so repeated runs produce the same learning trajectory
        RNG.setSeed(55);

        RLController<DinoAction> controller = new RLController<DinoAction>()
                .setEnvironment(new DinoWorld())
                .setAllowedActions(DinoAction.values())
                .setMethod(Method.MC_ONPOLICY_EGREEDY)
                .setDiscountFactor(1f)
                .setEpsilon(0.15f)
                .setDelay(200)
                .setEpisodes(100000);
        controller.start();
    }
}