diff --git a/.gitignore b/.gitignore
index 25d1e1c..3e25165 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+learningStates/*
+!learningStates/.gitkeep
+
.idea/refo.iml
.idea/misc.xml
.idea/modules.xml
diff --git a/learningStates/.gitkeep b/learningStates/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index ceb3756..3a56363 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -2,59 +2,53 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
+import core.StepResult;
import core.listener.LearningListener;
+import lombok.Getter;
import lombok.Setter;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
public abstract class EpisodicLearning extends Learning implements Episodic {
@Setter
protected int currentEpisode;
- protected int episodesToLearn;
+ protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
+ @Getter
protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond;
- private volatile boolean measureEpisodeBenchMark;
+ protected double sumOfRewards;
+ protected List> episode = new ArrayList<>();
public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
+ initBenchMarking();
}
public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
super(environment, actionSpace, discountFactor);
+ initBenchMarking();
}
public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, int delay) {
super(environment, actionSpace, delay);
+ initBenchMarking();
}
public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace) {
super(environment, actionSpace);
+ initBenchMarking();
}
- protected void dispatchEpisodeEnd(double recentSumOfRewards){
- ++episodeSumCurrentSecond;
- if(rewardHistory.size() > 10000){
- rewardHistory.clear();
- }
- rewardHistory.add(recentSumOfRewards);
- for(LearningListener l: learningListeners) {
- l.onEpisodeEnd(rewardHistory);
- }
- }
+ protected abstract void nextEpisode();
- protected void dispatchEpisodeStart(){
- for(LearningListener l: learningListeners){
- l.onEpisodeStart();
- }
- }
-
- @Override
- public void learn(){
- learn(0);
- }
-
- public void learn(int nrOfEpisodes){
- measureEpisodeBenchMark = true;
+ private void initBenchMarking(){
new Thread(()->{
- while(measureEpisodeBenchMark){
+ while (true){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@@ -64,24 +58,89 @@ public abstract class EpisodicLearning extends Learning imple
}
}
}).start();
- episodesToLearn += nrOfEpisodes;
- dispatchLearningStart();
- for(int i=0; i < nrOfEpisodes; ++i){
- nextEpisode();
- }
- dispatchLearningEnd();
- measureEpisodeBenchMark = false;
}
- protected abstract void nextEpisode();
+ protected void dispatchEpisodeEnd(){
+ ++episodeSumCurrentSecond;
+ if(rewardHistory.size() > 10000){
+ rewardHistory.clear();
+ }
+ rewardHistory.add(sumOfRewards);
+ for(LearningListener l: learningListeners) {
+ l.onEpisodeEnd(rewardHistory);
+ }
+ }
+
+ protected void dispatchEpisodeStart(){
+ ++currentEpisode;
+ episodesToLearn.decrementAndGet();
+ for(LearningListener l: learningListeners){
+ l.onEpisodeStart();
+ }
+ }
@Override
- public int getCurrentEpisode(){
- return currentEpisode;
+ public void learn(){
+ // TODO remove or learn with default episode number
+ }
+
+ private void startLearning(){
+ learningExecutor.submit(()->{
+ dispatchLearningStart();
+ while(episodesToLearn.get() > 0){
+ dispatchEpisodeStart();
+ nextEpisode();
+ dispatchEpisodeEnd();
+ }
+ synchronized (this){
+ dispatchLearningEnd();
+ notifyAll();
+ }
+ });
+ }
+
+ /**
+ * Stopping the while loop by setting episodesToLearn to 0.
+ * The current episode can not be interrupted, so the sleep delay
+ * is removed and the calling thread has to wait until the
+ * current episode is done.
+ * Resetting the delay afterwards.
+ */
+ @Override
+ public synchronized void interruptLearning(){
+ episodesToLearn.set(0);
+ int prevDelay = delay;
+ delay = 0;
+ while(currentlyLearning) {
+ try {
+ wait();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ delay = prevDelay;
+ }
+
+ public synchronized void learn(int nrOfEpisodes){
+ boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
+ if(!isLearning)
+ startLearning();
}
@Override
public int getEpisodesToGo(){
- return episodesToLearn - currentEpisode;
+ return episodesToLearn.get();
+ }
+
+ @Override
+ public synchronized void save(ObjectOutputStream oos) throws IOException {
+ super.save(oos);
+ oos.writeInt(currentEpisode);
+ }
+
+ @Override
+ public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+ super.load(ois);
+ currentEpisode = ois.readInt();
}
}
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index 8c589d2..c63ef43 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -9,10 +9,16 @@ import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
/**
*
@@ -30,14 +36,18 @@ public abstract class Learning{
@Setter
protected int delay;
protected List rewardHistory;
+ protected ExecutorService learningExecutor;
+ protected boolean currentlyLearning;
public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.delay = delay;
+ currentlyLearning = false;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
+ learningExecutor = Executors.newSingleThreadExecutor();
}
public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
@@ -52,7 +62,6 @@ public abstract class Learning{
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
}
-
public abstract void learn();
public void addListener(LearningListener learningListener) {
@@ -66,15 +75,31 @@ public abstract class Learning{
}
protected void dispatchLearningStart() {
+ currentlyLearning = true;
for (LearningListener l : learningListeners) {
l.onLearningStart();
}
}
protected void dispatchLearningEnd() {
+ currentlyLearning = false;
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}
}
+ public synchronized void interruptLearning(){
+ //TODO: for non episodic learning
+ }
+
+
+ public void save(ObjectOutputStream oos) throws IOException {
+ oos.writeObject(rewardHistory);
+ oos.writeObject(stateActionTable);
+ }
+
+ public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+ rewardHistory = (List) ois.readObject();
+ stateActionTable = (StateActionTable) ois.readObject();
+ }
}
diff --git a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
index 6f468c0..e9b02c9 100644
--- a/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
+++ b/src/main/java/core/algo/MC/MonteCarloOnPolicyEGreedy.java
@@ -5,6 +5,9 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
import java.util.*;
/**
@@ -44,19 +47,16 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
-
@Override
public void nextEpisode() {
- ++currentEpisode;
- List> episode = new ArrayList<>();
+ episode = new ArrayList<>();
State state = environment.reset();
- dispatchEpisodeStart();
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
- double sumOfRewards = 0;
+ sumOfRewards = 0;
StepResultEnvironment envResult = null;
while(envResult == null || !envResult.isDone()){
Map actionValues = stateActionTable.getActionValues(state);
@@ -76,7 +76,6 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
dispatchStepEnd();
}
- dispatchEpisodeEnd(sumOfRewards);
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set> stateActionPairs = new LinkedHashSet<>();
@@ -115,4 +114,18 @@ public class MonteCarloOnPolicyEGreedy extends EpisodicLearning<
public int getEpisodesPerSecond(){
return episodePerSecond;
}
+
+ @Override
+ public void save(ObjectOutputStream oos) throws IOException {
+ super.save(oos);
+ oos.writeObject(returnSum);
+ oos.writeObject(returnCount);
+ }
+
+ @Override
+ public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+ super.load(ois);
+ returnSum = (Map, Double>) ois.readObject();
+ returnCount = (Map, Integer>) ois.readObject();
+ }
}
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 4cafe4e..0fb589a 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -18,6 +18,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController implements ViewListener, LearningListener {
+ private final String folderPrefix = "learningStates" + File.separator;
private Environment environment;
private DiscreteActionSpace discreteActionSpace;
private Method method;
@@ -26,15 +27,12 @@ public class RLController implements ViewListener, LearningListe
private float epsilon = LearningConfig.DEFAULT_EPSILON;
private Learning learning;
private LearningView learningView;
- private ExecutorService learningExecutor;
- private boolean currentlyLearning;
private boolean fastLearning;
private List latestRewardsHistory;
private int nrOfEpisodes;
private int prevDelay;
public RLController(Environment env, Method method, A... actions){
- learningExecutor = Executors.newSingleThreadExecutor();
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
@@ -64,9 +62,9 @@ public class RLController implements ViewListener, LearningListe
private void initLearning(){
if(learning instanceof EpisodicLearning){
- learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
+ ((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{
- learningExecutor.submit(()->learning.learn());
+ learning.learn();
}
}
@@ -75,27 +73,25 @@ public class RLController implements ViewListener, LearningListe
*************************************************/
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes){
- if(!currentlyLearning){
- if(learning instanceof EpisodicLearning){
- learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
- }else{
- throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
- }
+ if(learning instanceof EpisodicLearning){
+ ((EpisodicLearning) learning).learn(nrOfEpisodes);
+ }else{
+ throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
+ learningView.updateLearningInfoPanel();
}
@Override
public void onLoadState(String fileName) {
FileInputStream fis;
- ObjectInput in;
+ ObjectInputStream in;
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
- SaveState saveState = (SaveState) in.readObject();
- learning.setStateActionTable(saveState.getStateActionTable());
- if(learning instanceof EpisodicLearning){
- ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode());
- }
+ System.out.println("interrup" + Thread.currentThread().getId());
+ learning.interruptLearning();
+ learning.load(in);
+ SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
in.close();
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
@@ -107,15 +103,10 @@ public class RLController implements ViewListener, LearningListe
FileOutputStream fos;
ObjectOutputStream out;
try{
- fos = new FileOutputStream(fileName);
+ fos = new FileOutputStream(folderPrefix + fileName);
out = new ObjectOutputStream(fos);
- int currentEpisode;
- if(learning instanceof EpisodicLearning){
- currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode();
- }else{
- currentEpisode = 0;
- }
- out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode));
+ learning.interruptLearning();
+ learning.save(out);
out.close();
}catch (IOException e){
e.printStackTrace();
@@ -158,13 +149,12 @@ public class RLController implements ViewListener, LearningListe
*************************************************/
@Override
public void onLearningStart() {
- currentlyLearning = true;
}
@Override
public void onLearningEnd() {
- currentlyLearning = false;
SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
+ onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
}
@Override
@@ -192,7 +182,7 @@ public class RLController implements ViewListener, LearningListe
/*************************************************
- ** SETTER **
+ ** SETTERS **
*************************************************/
private void setEnvironment(Environment environment){
diff --git a/src/main/java/evironment/jumpingDino/DinoWorld.java b/src/main/java/evironment/jumpingDino/DinoWorld.java
index 3f40524..9ca4fbf 100644
--- a/src/main/java/evironment/jumpingDino/DinoWorld.java
+++ b/src/main/java/evironment/jumpingDino/DinoWorld.java
@@ -14,12 +14,20 @@ import java.awt.*;
public class DinoWorld implements Environment, Visualizable {
private Dino dino;
private Obstacle currentObstacle;
+ private boolean randomObstacleSpeed;
+ private boolean randomObstacleDistance;
- public DinoWorld(){
+ public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){
+ this.randomObstacleSpeed = randomObstacleSpeed;
+ this.randomObstacleDistance = randomObstacleDistance;
dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
spawnNewObstacle();
}
+ public DinoWorld(){
+ this(false, false);
+ }
+
private boolean ranIntoObstacle(){
Obstacle o = currentObstacle;
Dino p = dino;
@@ -32,6 +40,7 @@ public class DinoWorld implements Environment, Visualizable {
return xAxis && yAxis;
}
+
private int getDistanceToObstacle(){
return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE;
}
@@ -57,8 +66,27 @@ public class DinoWorld implements Environment, Visualizable {
return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, "");
}
+
+
private void spawnNewObstacle(){
- currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK);
+ int dx;
+ int xSpawn;
+
+ if(randomObstacleSpeed){
+ dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED);
+ }else{
+ dx = -Config.OBSTACLE_SPEED;
+ }
+
+ if(randomObstacleDistance){
+ // randomly spawning more right outside of the screen
+ xSpawn = (int)(Math.random() + 0.5 * Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE);
+ }else{
+ // instantly respawning on the left screen border
+ xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
+ }
+
+ currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK);
}
private void spawnDino(){
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index bddc8b2..2e0a447 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -11,7 +11,7 @@ public class JumpingDino {
RNG.setSeed(55);
RLController rl = new RLController<>(
- new DinoWorld(),
+ new DinoWorld(true, true),
Method.MC_ONPOLICY_EGREEDY,
DinoAction.values());