diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 559bfc6..b45ea50 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -50,7 +50,7 @@ public abstract class EpisodicLearning extends Learning imple
private void initBenchMarking(){
new Thread(()->{
- while (true){
+ while (currentlyLearning){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@@ -62,7 +62,7 @@ public abstract class EpisodicLearning extends Learning imple
}).start();
}
- protected void dispatchEpisodeEnd(){
+ private void dispatchEpisodeEnd(){
++episodeSumCurrentSecond;
if(rewardHistory.size() > 10000){
rewardHistory.clear();
@@ -75,18 +75,6 @@ public abstract class EpisodicLearning extends Learning imple
protected void dispatchEpisodeStart(){
++currentEpisode;
- /*
- 2f 0.02 => 100
- 1.5f 0.02 => 75
- 1.4f 0.02 => fail
- 1.5f 0.1 => 16 !
- */
- if(this.policy instanceof EpsilonGreedyPolicy){
- float ep = 1.5f/(float)currentEpisode;
- if(ep < 0.10) ep = 0;
- ((EpsilonGreedyPolicy) this.policy).setEpsilon(ep);
- System.out.println(ep);
- }
episodesToLearn.decrementAndGet();
for(LearningListener l: learningListeners){
l.onEpisodeStart();
@@ -97,31 +85,28 @@ public abstract class EpisodicLearning extends Learning imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
- // TODO: more sophisticated way to check convergence
- if(timestamp > 300000){
- System.out.println("converged after: " + currentEpisode + " episode!");
- interruptLearning();
- }
- }
-
- @Override
- public void learn(){
- learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
}
private void startLearning(){
- learningExecutor.submit(()->{
- dispatchLearningStart();
- while(episodesToLearn.get() > 0){
- dispatchEpisodeStart();
- nextEpisode();
- dispatchEpisodeEnd();
- }
- synchronized (this){
- dispatchLearningEnd();
- notifyAll();
- }
- });
+ dispatchLearningStart();
+ while(episodesToLearn.get() > 0){
+ dispatchEpisodeStart();
+ nextEpisode();
+ dispatchEpisodeEnd();
+ }
+ synchronized (this){
+ dispatchLearningEnd();
+ notifyAll();
+ }
+ }
+
+ /**
+ * Adds further episodes to the pending-episode counter without starting the
+ * learning loop; a loop that is already running picks them up via its loop condition.
+ * @param nrOfEpisodes number of episodes to add
+ */
+ public void learnMoreEpisodes(int nrOfEpisodes){
+ episodesToLearn.addAndGet(nrOfEpisodes);
}
/**
@@ -146,8 +127,13 @@ public abstract class EpisodicLearning extends Learning imple
delay = prevDelay;
}
+ @Override
+ public void learn(){
+ learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
+ }
+
public synchronized void learn(int nrOfEpisodes){
boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
if(!isLearning)
startLearning();
}
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index fbba9ef..1bb5207 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -42,8 +42,7 @@ public abstract class Learning{
@Setter
protected int delay;
protected List rewardHistory;
- protected ExecutorService learningExecutor;
- protected boolean currentlyLearning;
+ protected volatile boolean currentlyLearning;
public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
this.environment = environment;
@@ -53,7 +52,6 @@ public abstract class Learning{
currentlyLearning = false;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
- learningExecutor = Executors.newSingleThreadExecutor();
}
public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
@@ -89,8 +87,6 @@ public abstract class Learning{
protected void dispatchLearningEnd() {
currentlyLearning = false;
- System.out.println("Checksum: " + checkSum);
- System.out.println("Reward Checksum: " + rewardCheckSum);
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index d855ab1..837e427 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -83,7 +83,7 @@ public class RLController implements LearningListener {
private void initLearning() {
if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
- ((EpisodicLearning) learning).learn(nrOfEpisodes);
+ ((EpisodicLearning) learning).learn(nrOfEpisodes);
} else {
learning.learn();
}
@@ -95,7 +95,13 @@ public class RLController implements LearningListener {
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
- ((EpisodicLearning) learning).learn(nrOfEpisodes);
+ if(learning.isCurrentlyLearning()){
+ ((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes);
+ }else{
+ new Thread(() -> {
+ ((EpisodicLearning) learning).learn(nrOfEpisodes);
+ }).start();
+ }
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
@@ -169,8 +175,8 @@ public class RLController implements LearningListener {
public void onEpisodeEnd(List rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode) {
- System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
- System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
+ System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
+ System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
printNextEpisode = false;
}
}
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
index 4959cee..ed2ebad 100644
--- a/src/main/java/core/gui/LearningInfoPanel.java
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel {
viewListener.onFastLearnChange(fastLearning);
});
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
- smoothGraphCheckbox.setSelected(false);
+ smoothGraphCheckbox.setSelected(true);
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
- last100Checkbox.setSelected(true);
+ last100Checkbox.setSelected(false);
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
drawEnvironmentCheckbox.setSelected(true);
diff --git a/src/main/java/example/DinoSampling.java b/src/main/java/example/DinoSampling.java
deleted file mode 100644
index 09d5c5b..0000000
--- a/src/main/java/example/DinoSampling.java
+++ /dev/null
@@ -1,27 +0,0 @@
-package example;
-
-import core.RNG;
-import core.algo.Method;
-import core.controller.RLController;
-import evironment.jumpingDino.DinoAction;
-import evironment.jumpingDino.DinoWorld;
-
-public class DinoSampling {
- public static void main(String[] args) {
- for (int i = 0; i < 10 ; i++) {
- RNG.setSeed(55);
-
- RLController rl = new RLController<>(
- new DinoWorld(false, false),
- Method.MC_CONTROL_FIRST_VISIT,
- DinoAction.values());
-
- rl.setDelay(0);
- rl.setDiscountFactor(1f);
- rl.setEpsilon(0.15f);
- rl.setLearningRate(1f);
- rl.setNrOfEpisodes(400);
- rl.start();
- }
- }
-}
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index 41d5290..ce81753 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -3,6 +3,7 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
+import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
@@ -10,16 +11,16 @@ public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
- RLController rl = new RLController<>(
+ RLController rl = new RLControllerGUI<>(
new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
- rl.setDelay(0);
+ rl.setDelay(100);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setLearningRate(1f);
- rl.setNrOfEpisodes(400);
+ rl.setNrOfEpisodes(10000);
rl.start();
}
}