diff --git a/src/main/java/core/StateActionHashTable.java b/src/main/java/core/DeterministicStateActionTable.java
similarity index 74%
rename from src/main/java/core/StateActionHashTable.java
rename to src/main/java/core/DeterministicStateActionTable.java
index dd981d0..6c9d55d 100644
--- a/src/main/java/core/StateActionHashTable.java
+++ b/src/main/java/core/DeterministicStateActionTable.java
@@ -1,20 +1,19 @@
package core;
-import evironment.antGame.AntAction;
-
-import java.util.HashMap;
+import java.io.Serializable;
+import java.util.LinkedHashMap;
import java.util.Map;
/**
* Premise: All states have the complete action space
*/
-public class StateActionHashTable implements StateActionTable {
+public class DeterministicStateActionTable implements StateActionTable, Serializable {
private final Map> table;
private DiscreteActionSpace discreteActionSpace;
- public StateActionHashTable(DiscreteActionSpace discreteActionSpace){
- table = new HashMap<>();
+ public DeterministicStateActionTable(DiscreteActionSpace discreteActionSpace){
+ table = new LinkedHashMap<>();
this.discreteActionSpace = discreteActionSpace;
}
@@ -61,19 +60,15 @@ public class StateActionHashTable implements StateActionTable
return table.get(state);
}
- public static void main(String[] args) {
- DiscreteActionSpace da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP);
- StateActionTable sat = new StateActionHashTable<>(da);
- State t = new State() {
- };
-
- System.out.println(sat.getActionValues(t));
- }
private Map createDefaultActionValues(){
- final Map defaultActionValues = new HashMap<>();
+ final Map defaultActionValues = new LinkedHashMap<>(); // insertion order -> deterministic action iteration across runs
for(A action: discreteActionSpace.getAllActions()){
defaultActionValues.put(action, DEFAULT_VALUE);
}
return defaultActionValues;
}
+ @Override
+ public int getStateCount(){
+ return table.size(); // number of distinct states encountered so far
+ }
}
diff --git a/src/main/java/core/LearningConfig.java b/src/main/java/core/LearningConfig.java
index 916de16..a4ecb68 100644
--- a/src/main/java/core/LearningConfig.java
+++ b/src/main/java/core/LearningConfig.java
@@ -1,7 +1,10 @@
package core;
public class LearningConfig {
- public static final int DEFAULT_DELAY = 1;
+    /** Default pause (ms) between steps so the GUI can keep up. */
+    public static final int DEFAULT_DELAY = 30;
+    /** Default exploration rate for epsilon-greedy policies. */
public static final float DEFAULT_EPSILON = 0.1f;
+    /** Default discount factor; 1.0f means undiscounted returns. */
public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
}
diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java
index 42de87a..3e1b862 100644
--- a/src/main/java/core/ListDiscreteActionSpace.java
+++ b/src/main/java/core/ListDiscreteActionSpace.java
@@ -1,10 +1,11 @@
package core;
+import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-public class ListDiscreteActionSpace implements DiscreteActionSpace {
+public class ListDiscreteActionSpace implements DiscreteActionSpace, Serializable {
private List actions;
public ListDiscreteActionSpace(){
diff --git a/src/main/java/core/RNG.java b/src/main/java/core/RNG.java
index 2fe8929..daf03b4 100644
--- a/src/main/java/core/RNG.java
+++ b/src/main/java/core/RNG.java
@@ -1,12 +1,22 @@
package core;
+import java.security.SecureRandom;
import java.util.Random;
public class RNG {
- private static Random rng;
+    private static SecureRandom rng;
private static int seed = 123;
static {
- rng = new Random(seed);
+        // Request SHA1PRNG explicitly: it is fully deterministic once seeded,
+        // so seeded experiment runs stay reproducible. The previous
+        // new SecureRandom() + setSeed() only SUPPLEMENTS the self-seeded
+        // entropy of the platform default (e.g. NativePRNG) and silently
+        // loses reproducibility despite the fixed seed.
+        try {
+            rng = SecureRandom.getInstance("SHA1PRNG");
+        } catch (java.security.NoSuchAlgorithmException e) {
+            throw new IllegalStateException("SHA1PRNG not available", e);
+        }
+        rng.setSeed(seed);
}
public static Random getRandom() {
diff --git a/src/main/java/core/SaveState.java b/src/main/java/core/SaveState.java
new file mode 100644
index 0000000..140ac7a
--- /dev/null
+++ b/src/main/java/core/SaveState.java
@@ -0,0 +1,22 @@
+package core;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.io.Serializable;
+
+/**
+ * Snapshot of a learning run used for save/load: the learned state-action
+ * table plus the episode counter (0 for non-episodic learners).
+ *
+ * @param <A> action type of the saved state-action table
+ */
+@AllArgsConstructor
+@Getter
+public class SaveState<A> implements Serializable {
+    // Explicit version UID so saved files survive compatible refactorings.
+    private static final long serialVersionUID = 1L;
+
+    private StateActionTable<A> stateActionTable;
+    private int currentEpisode;
+}
diff --git a/src/main/java/core/StateActionTable.java b/src/main/java/core/StateActionTable.java
index 851ffba..3f863ba 100644
--- a/src/main/java/core/StateActionTable.java
+++ b/src/main/java/core/StateActionTable.java
@@ -7,6 +7,6 @@ public interface StateActionTable {
double getValue(State state, A action);
void setValue(State state, A action, double value);
-
+ int getStateCount();
Map getActionValues(State state);
}
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index f9aa580..9211224 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -3,13 +3,17 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.listener.LearningListener;
+import lombok.Getter;
+import lombok.Setter;
-public abstract class EpisodicLearning extends Learning implements Episodic{
+public abstract class EpisodicLearning extends Learning implements Episodic {
+ @Setter
+ @Getter
protected int currentEpisode;
protected int episodesToLearn;
protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond;
- private volatile boolean meseaureEpisodeBenchMark;
+ private volatile boolean measureEpisodeBenchMark;
public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
@@ -29,6 +33,9 @@ public abstract class EpisodicLearning extends Learning imple
protected void dispatchEpisodeEnd(double recentSumOfRewards){
++episodeSumCurrentSecond;
+        // Bound the history without erasing it: drop the oldest entry once the
+        // cap is reached (clear() wiped the whole reward curve shown to listeners).
+        if(rewardHistory.size() > 10000){ rewardHistory.remove(0); }
rewardHistory.add(recentSumOfRewards);
for(LearningListener l: learningListeners) {
l.onEpisodeEnd(rewardHistory);
@@ -47,9 +54,9 @@ public abstract class EpisodicLearning extends Learning imple
}
public void learn(int nrOfEpisodes){
- meseaureEpisodeBenchMark = true;
+ measureEpisodeBenchMark = true;
new Thread(()->{
- while(meseaureEpisodeBenchMark){
+ while(measureEpisodeBenchMark){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@@ -65,7 +72,7 @@ public abstract class EpisodicLearning extends Learning imple
nextEpisode();
}
dispatchLearningEnd();
- meseaureEpisodeBenchMark = false;
+ measureEpisodeBenchMark = false;
}
protected abstract void nextEpisode();
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index 1999f46..a9c73c7 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -9,15 +9,17 @@ import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
+import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
@Getter
-public abstract class Learning {
+public abstract class Learning implements Serializable {
protected Policy policy;
protected DiscreteActionSpace actionSpace;
+ @Setter
protected StateActionTable stateActionTable;
protected Environment environment;
protected float discountFactor;
@@ -26,7 +28,7 @@ public abstract class Learning {
protected int delay;
protected List rewardHistory;
- public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay){
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
@@ -35,39 +37,41 @@ public abstract class Learning {
rewardHistory = new CopyOnWriteArrayList<>();
}
- public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor){
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
}
- public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay){
+ public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
}
- public Learning(Environment environment, DiscreteActionSpace actionSpace){
+ public Learning(Environment environment, DiscreteActionSpace actionSpace) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
}
+
public abstract void learn();
- public void addListener(LearningListener learningListener){
+ public void addListener(LearningListener learningListener) {
learningListeners.add(learningListener);
}
- protected void dispatchStepEnd(){
- for(LearningListener l: learningListeners){
+ protected void dispatchStepEnd() {
+ for (LearningListener l : learningListeners) {
l.onStepEnd();
}
}
- protected void dispatchLearningStart(){
- for(LearningListener l: learningListeners){
+ protected void dispatchLearningStart() {
+ for (LearningListener l : learningListeners) {
l.onLearningStart();
}
}
- protected void dispatchLearningEnd(){
- for(LearningListener l: learningListeners){
+ protected void dispatchLearningEnd() {
+ for (LearningListener l : learningListeners) {
l.onLearningEnd();
}
}
+
}
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
index 544fd1b..6f468c0 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -35,7 +35,7 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
super(environment, actionSpace, discountFactor, delay);
currentEpisode = 0;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
- this.stateActionTable = new StateActionHashTable<>(this.actionSpace);
+ this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
returnCount = new HashMap<>();
}
@@ -57,16 +57,15 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
e.printStackTrace();
}
double sumOfRewards = 0;
- for (int j = 0; j < 10; ++j) {
+ StepResultEnvironment envResult = null;
+ while(envResult == null || !envResult.isDone()){
Map actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
- StepResultEnvironment envResult = environment.step(chosenAction);
+ envResult = environment.step(chosenAction);
State nextState = envResult.getState();
sumOfRewards += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
- if (envResult.isDone()) break;
-
state = nextState;
try {
@@ -78,13 +77,13 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
}
dispatchEpisodeEnd(sumOfRewards);
- System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
- Set> stateActionPairs = new HashSet<>();
+ // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
+ Set> stateActionPairs = new LinkedHashSet<>();
for (StepResult sr : episode) {
stateActionPairs.add(new Pair<>(sr.getState(), sr.getAction()));
}
- System.out.println("stateActionPairs " + stateActionPairs.size());
+ //System.out.println("stateActionPairs " + stateActionPairs.size());
for (Pair stateActionPair : stateActionPairs) {
int firstOccurenceIndex = 0;
// find first occurance of state action pair
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 4a149e5..b733807 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -1,8 +1,6 @@
package core.controller;
-import core.DiscreteActionSpace;
-import core.Environment;
-import core.ListDiscreteActionSpace;
+import core.*;
import core.algo.EpisodicLearning;
import core.algo.Learning;
import core.algo.Method;
@@ -14,6 +12,7 @@ import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import javax.swing.*;
+import java.io.*;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -23,10 +22,12 @@ public class RLController implements ViewListener, LearningListe
protected Learning learning;
protected DiscreteActionSpace discreteActionSpace;
protected LearningView learningView;
- private int delay;
private int nrOfEpisodes;
private Method method;
private int prevDelay;
+ private int delay = LearningConfig.DEFAULT_DELAY;
+ private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
+ private float epsilon = LearningConfig.DEFAULT_EPSILON;
private boolean fastLearning;
private boolean currentlyLearning;
private ExecutorService learningExecutor;
@@ -36,6 +37,7 @@ public class RLController implements ViewListener, LearningListe
learningExecutor = Executors.newSingleThreadExecutor();
}
+
public void start(){
if(environment == null || discreteActionSpace == null || method == null){
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
@@ -43,7 +45,7 @@ public class RLController implements ViewListener, LearningListe
switch (method){
case MC_ONPOLICY_EGREEDY:
- learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, delay);
+ learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
case TD_ONPOLICY:
break;
@@ -76,6 +78,44 @@ public class RLController implements ViewListener, LearningListe
}
}
+    /**
+     * Restores a previously saved learning state (state-action table and,
+     * for episodic learners, the current episode) from the given file.
+     * try-with-resources closes the stream even when reading fails.
+     */
+    @Override
+    public void onLoadState(String fileName) {
+        try (ObjectInput in = new ObjectInputStream(new FileInputStream(fileName))) {
+            SaveState saveState = (SaveState) in.readObject();
+            learning.setStateActionTable(saveState.getStateActionTable());
+            if (learning instanceof EpisodicLearning) {
+                ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode());
+            }
+        } catch (IOException | ClassNotFoundException e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Serializes the current learning state (state-action table and current
+     * episode; 0 for non-episodic learners) to the given file.
+     */
+    @Override
+    public void onSaveState(String fileName) {
+        try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName))) {
+            int currentEpisode;
+            if (learning instanceof EpisodicLearning) {
+                currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode();
+            } else {
+                currentEpisode = 0;
+            }
+            // Stream is closed by try-with-resources even if writing fails.
+            out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode));
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
+    }
+
@Override
public void onEpsilonChange(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy){
@@ -169,4 +209,13 @@ public class RLController implements ViewListener, LearningListe
this.nrOfEpisodes = nrOfEpisodes;
return this;
}
+
+ public RLController setDiscountFactor(float discountFactor){
+ this.discountFactor = discountFactor;
+ return this;
+ }
+ public RLController setEpsilon(float epsilon){
+ this.epsilon = epsilon;
+ return this;
+ }
}
diff --git a/src/main/java/core/gui/View.java b/src/main/java/core/gui/View.java
index d89cc97..5ad77db 100644
--- a/src/main/java/core/gui/View.java
+++ b/src/main/java/core/gui/View.java
@@ -11,6 +11,8 @@ import org.knowm.xchart.XYChart;
import javax.swing.*;
import java.awt.*;
+import java.awt.event.ActionEvent;
+import java.io.File;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
@@ -26,6 +28,8 @@ public class View implements LearningView{
private JFrame environmentFrame;
private XChartPanel rewardChartPanel;
private ViewListener viewListener;
+ private JMenuBar menuBar;
+ private JMenu fileMenu;
public View(Learning learning, Environment environment, ViewListener viewListener) {
this.learning = learning;
@@ -38,7 +42,32 @@ public class View implements LearningView{
mainFrame = new JFrame();
mainFrame.setPreferredSize(new Dimension(1280, 720));
mainFrame.setLayout(new BorderLayout());
+ menuBar = new JMenuBar();
+ fileMenu = new JMenu("File");
+ menuBar.add(fileMenu);
+ fileMenu.add(new JMenuItem(new AbstractAction("Load") {
+ @Override
+ public void actionPerformed(ActionEvent e) {
+ final JFileChooser fc = new JFileChooser();
+ fc.setCurrentDirectory(new File(System.getProperty("user.dir")));
+ int returnVal = fc.showOpenDialog(mainFrame);
+ if (returnVal == JFileChooser.APPROVE_OPTION) {
+ viewListener.onLoadState(fc.getSelectedFile().toString());
+ }
+ }
+ }));
+
+ fileMenu.add(new JMenuItem(new AbstractAction("Save") {
+ @Override
+ public void actionPerformed(ActionEvent e) {
+ String fileName = JOptionPane.showInputDialog("Enter file name", "save");
+ if(fileName != null){
+ viewListener.onSaveState(fileName);
+ }
+ }
+ }));
+ mainFrame.setJMenuBar(menuBar);
initLearningInfoPanel();
initRewardChart();
diff --git a/src/main/java/core/listener/ViewListener.java b/src/main/java/core/listener/ViewListener.java
index dbf01d4..abd0004 100644
--- a/src/main/java/core/listener/ViewListener.java
+++ b/src/main/java/core/listener/ViewListener.java
@@ -5,4 +5,6 @@ public interface ViewListener {
void onDelayChange(int delay);
void onFastLearnChange(boolean isFastLearn);
void onLearnMoreEpisodes(int nrOfEpisodes);
+ void onLoadState(String fileName);
+ void onSaveState(String fileName);
}
diff --git a/src/main/java/core/policy/EpsilonGreedyPolicy.java b/src/main/java/core/policy/EpsilonGreedyPolicy.java
index 550889f..cfe506c 100644
--- a/src/main/java/core/policy/EpsilonGreedyPolicy.java
+++ b/src/main/java/core/policy/EpsilonGreedyPolicy.java
@@ -29,7 +29,8 @@ public class EpsilonGreedyPolicy implements EpsilonPolicy{
@Override
public A chooseAction(Map actionValues) {
- if(RNG.getRandom().nextFloat() < epsilon){
+ float f = RNG.getRandom().nextFloat();
+ if(f < epsilon){
// Take random action
return randomPolicy.chooseAction(actionValues);
}else{
diff --git a/src/main/java/core/policy/GreedyPolicy.java b/src/main/java/core/policy/GreedyPolicy.java
index 6ff7739..30d901f 100644
--- a/src/main/java/core/policy/GreedyPolicy.java
+++ b/src/main/java/core/policy/GreedyPolicy.java
@@ -1,9 +1,10 @@
package core.policy;
+import core.RNG;
+
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
-import java.util.Random;
public class GreedyPolicy implements Policy {
@@ -26,6 +27,6 @@ public class GreedyPolicy implements Policy {
}
}
- return equalHigh.get(new Random().nextInt(equalHigh.size()));
+ return equalHigh.get(RNG.getRandom().nextInt(equalHigh.size()));
}
}
diff --git a/src/main/java/core/policy/RandomPolicy.java b/src/main/java/core/policy/RandomPolicy.java
index 094b41c..ea99dd5 100644
--- a/src/main/java/core/policy/RandomPolicy.java
+++ b/src/main/java/core/policy/RandomPolicy.java
@@ -1,18 +1,17 @@
package core.policy;
import core.RNG;
+
import java.util.Map;
public class RandomPolicy implements Policy{
@Override
public A chooseAction(Map actionValues) {
int idx = RNG.getRandom().nextInt(actionValues.size());
- System.out.println("selected action " + idx);
int i = 0;
for(A action : actionValues.keySet()){
- if(i++ == idx) return action;
+ if(i++ == idx) return action;
}
-
return null;
}
}
diff --git a/src/main/java/evironment/jumpingDino/Config.java b/src/main/java/evironment/jumpingDino/Config.java
new file mode 100644
index 0000000..29af520
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/Config.java
@@ -0,0 +1,21 @@
+package evironment.jumpingDino;
+
+/** Tunable constants for the jumping-dino environment (pixels / pixels-per-tick). */
+public class Config {
+    private Config() {} // static constants only; not instantiable
+
+    // Window dimensions.
+    public static final int FRAME_WIDTH = 1280;
+    public static final int FRAME_HEIGHT = 720;
+    // Height of the ground strip, measured from the bottom of the frame.
+    public static final int GROUND_Y = 50;
+    // Dino placement and size.
+    public static final int DINO_STARTING_X = 50;
+    public static final int DINO_SIZE = 50;
+    // Obstacle size and horizontal scroll speed.
+    public static final int OBSTACLE_SIZE = 60;
+    public static final int OBSTACLE_SPEED = 30;
+    // Vertical speed while jumping and the maximum jump height.
+    public static final int DINO_JUMP_SPEED = 20;
+    public static final int MAX_JUMP_HEIGHT = 200;
+}
diff --git a/src/main/java/evironment/jumpingDino/Dino.java b/src/main/java/evironment/jumpingDino/Dino.java
new file mode 100644
index 0000000..0c29ce3
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/Dino.java
@@ -0,0 +1,47 @@
+package evironment.jumpingDino;
+
+import lombok.Getter;
+
+import java.awt.*;
+
+/**
+ * The player-controlled dino. jump() starts an upward movement; tick()
+ * reverses it at the apex and snaps the dino back onto the ground.
+ */
+public class Dino extends RenderObject {
+    @Getter
+    private boolean inJump;
+
+    public Dino(int size, int x, int y, int dx, int dy, Color color) {
+        super(size, x, y, dx, dy, color);
+    }
+
+    public void jump(){
+        if(!inJump){
+            dy = -Config.DINO_JUMP_SPEED;
+            inJump = true;
+        }
+    }
+
+    private void fall(){
+        if(inJump){
+            dy = Config.DINO_JUMP_SPEED;
+        }
+    }
+
+    @Override
+    public void tick(){
+        // Apex is measured from the dino's standing position (ground minus
+        // DINO_SIZE). The previous check used OBSTACLE_SIZE here, overshooting
+        // MAX_JUMP_HEIGHT by the size difference.
+        if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - Config.MAX_JUMP_HEIGHT){
+            fall();
+        }else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){
+            // Landed: snap to ground level and stop vertical movement.
+            inJump = false;
+            dy = 0;
+            y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
+        }
+        super.tick();
+    }
+}
diff --git a/src/main/java/evironment/jumpingDino/DinoAction.java b/src/main/java/evironment/jumpingDino/DinoAction.java
new file mode 100644
index 0000000..8adc627
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/DinoAction.java
@@ -0,0 +1,9 @@
+package evironment.jumpingDino;
+
+/** The dino's discrete action space. */
+public enum DinoAction {
+    /** Start a jump (Dino.jump() ignores this while already airborne). */
+    JUMP,
+    /** Do nothing this tick. */
+    NOTHING,
+}
diff --git a/src/main/java/evironment/jumpingDino/DinoState.java b/src/main/java/evironment/jumpingDino/DinoState.java
new file mode 100644
index 0000000..0f59dc1
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/DinoState.java
@@ -0,0 +1,36 @@
+package evironment.jumpingDino;
+
+import core.State;
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.io.Serializable;
+
+/**
+ * Observable state of the dino world: the horizontal distance between the
+ * dino and the current obstacle. Equality and hashing use only this value,
+ * so states with the same distance collapse to one table entry.
+ */
+@AllArgsConstructor
+@Getter
+public class DinoState implements State, Serializable {
+    private int xDistanceToObstacle;
+
+    @Override
+    public String toString() {
+        return String.valueOf(xDistanceToObstacle);
+    }
+
+    @Override
+    public int hashCode() {
+        return xDistanceToObstacle;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (!(obj instanceof DinoState)) {
+            return false;
+        }
+        return ((DinoState) obj).getXDistanceToObstacle() == xDistanceToObstacle;
+    }
+}
diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java
new file mode 100644
index 0000000..3f40524
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/DinoWorld.java
@@ -0,0 +1,91 @@
+package evironment.jumpingDino;
+
+import core.Environment;
+import core.State;
+import core.StepResultEnvironment;
+import core.gui.Visualizable;
+import evironment.jumpingDino.gui.DinoWorldComponent;
+import lombok.Getter;
+
+import javax.swing.*;
+import java.awt.*;
+
+/**
+ * Endless-runner environment: a dino must jump over a single obstacle
+ * scrolling from right to left. The observable state is the horizontal
+ * distance from the dino's front edge to the obstacle; episode ends on impact.
+ */
+@Getter
+public class DinoWorld implements Environment, Visualizable {
+    private Dino dino;
+    private Obstacle currentObstacle;
+
+    public DinoWorld(){
+        dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
+        spawnNewObstacle();
+    }
+
+    // Axis-aligned bounding-box overlap between dino and obstacle.
+    private boolean ranIntoObstacle(){
+        Obstacle o = currentObstacle;
+        Dino p = dino;
+
+        boolean xAxis = (o.getX() <= p.getX() && p.getX() < o.getX() + Config.OBSTACLE_SIZE)
+                || (o.getX() <= p.getX() + Config.DINO_SIZE && p.getX() + Config.DINO_SIZE < o.getX() + Config.OBSTACLE_SIZE);
+
+        boolean yAxis = (o.getY() <= p.getY() && p.getY() < o.getY() + Config.OBSTACLE_SIZE)
+                || (o.getY() <= p.getY() + Config.DINO_SIZE && p.getY() + Config.DINO_SIZE < o.getY() + Config.OBSTACLE_SIZE);
+
+        return xAxis && yAxis;
+    }
+
+    // Gap between the dino's front edge (x + DINO_SIZE) and the obstacle.
+    // Parentheses matter: the previous "o.getX() - dino.getX() + DINO_SIZE"
+    // ADDED the dino's size to the distance instead of subtracting it.
+    private int getDistanceToObstacle(){
+        return currentObstacle.getX() - (dino.getX() + Config.DINO_SIZE);
+    }
+
+    @Override
+    public StepResultEnvironment step(DinoAction action) {
+        boolean done = false;
+        int reward = 1; // +1 for every survived tick
+
+        if(action == DinoAction.JUMP){
+            dino.jump();
+        }
+
+        dino.tick();
+        currentObstacle.tick();
+        // Obstacle fully left the screen -> respawn on the right edge.
+        if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
+            spawnNewObstacle();
+        }
+
+        if(ranIntoObstacle()){
+            done = true;
+        }
+
+        return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, "");
+    }
+
+    private void spawnNewObstacle(){
+        currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK);
+    }
+
+    private void spawnDino(){
+        dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
+    }
+
+    @Override
+    public State reset() {
+        spawnDino();
+        spawnNewObstacle();
+        return new DinoState(getDistanceToObstacle());
+    }
+
+    @Override
+    public JComponent visualize() {
+        return new DinoWorldComponent(this);
+    }
+}
diff --git a/src/main/java/evironment/jumpingDino/Obstacle.java b/src/main/java/evironment/jumpingDino/Obstacle.java
new file mode 100644
index 0000000..0d01c4b
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/Obstacle.java
@@ -0,0 +1,14 @@
+package evironment.jumpingDino;
+
+import java.awt.*;
+
+/**
+ * The scrolling obstacle the dino has to jump over. All movement and
+ * rendering behavior is inherited from {@link RenderObject}.
+ */
+public class Obstacle extends RenderObject {
+
+    public Obstacle(int size, int x, int y, int dx, int dy, Color color) {
+        super(size, x, y, dx, dy, color);
+    }
+}
diff --git a/src/main/java/evironment/jumpingDino/RenderObject.java b/src/main/java/evironment/jumpingDino/RenderObject.java
new file mode 100644
index 0000000..3b1430b
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/RenderObject.java
@@ -0,0 +1,33 @@
+package evironment.jumpingDino;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+
+import java.awt.*;
+
+/**
+ * Base class for drawable, moving squares in the dino world. (x, y) is the
+ * top-left corner; (dx, dy) is the displacement applied on every tick.
+ */
+@AllArgsConstructor
+@Getter
+public abstract class RenderObject {
+    protected int size;
+    protected int x;
+    protected int y;
+    protected int dx;
+    protected int dy;
+    protected Color color;
+
+    /** Draws this object as a filled square of its color. */
+    public void render(Graphics g){
+        g.setColor(color);
+        g.fillRect(x, y, size, size);
+    }
+
+    /** Advances the position by one tick's displacement. */
+    public void tick(){
+        x += dx;
+        y += dy;
+    }
+}
diff --git a/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
new file mode 100644
index 0000000..461934f
--- /dev/null
+++ b/src/main/java/evironment/jumpingDino/gui/DinoWorldComponent.java
@@ -0,0 +1,32 @@
+package evironment.jumpingDino.gui;
+
+import evironment.jumpingDino.Config;
+import evironment.jumpingDino.DinoWorld;
+
+import javax.swing.*;
+import java.awt.*;
+
+/**
+ * Swing component that renders a {@link DinoWorld}: the ground line,
+ * the dino and the current obstacle.
+ */
+public class DinoWorldComponent extends JComponent {
+    private final DinoWorld dinoWorld;
+
+    public DinoWorldComponent(DinoWorld dinoWorld){
+        this.dinoWorld = dinoWorld;
+        setPreferredSize(new Dimension(Config.FRAME_WIDTH, Config.FRAME_HEIGHT));
+        setVisible(true);
+    }
+
+    @Override
+    protected void paintComponent(Graphics g) {
+        super.paintComponent(g);
+        // Ground line.
+        g.setColor(Color.BLACK);
+        g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2);
+
+        dinoWorld.getDino().render(g);
+        dinoWorld.getCurrentObstacle().render(g);
+    }
+}
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
new file mode 100644
index 0000000..87b1ceb
--- /dev/null
+++ b/src/main/java/example/JumpingDino.java
@@ -0,0 +1,28 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+
+/**
+ * Example entry point: trains an on-policy epsilon-greedy Monte-Carlo agent
+ * on the jumping-dino environment.
+ */
+public class JumpingDino {
+    public static void main(String[] args) {
+        // Seed the shared RNG so runs are repeatable.
+        RNG.setSeed(55);
+
+        RLController rl = new RLController()
+                .setEnvironment(new DinoWorld())
+                .setAllowedActions(DinoAction.values())
+                .setMethod(Method.MC_ONPOLICY_EGREEDY)
+                .setDiscountFactor(1f)
+                .setEpsilon(0.15f)
+                .setDelay(200)
+                .setEpisodes(100000);
+        rl.start();
+    }
+}