diff --git a/Rplot.png b/Rplot.png
new file mode 100644
index 0000000..dc1e2e4
Binary files /dev/null and b/Rplot.png differ
diff --git a/convergence.txt b/convergence.txt
new file mode 100644
index 0000000..ba6b6c3
--- /dev/null
+++ b/convergence.txt
@@ -0,0 +1,20 @@
+0.05,300,2426,159,339,537,264,1223,298,314,295,177,343,272,256,214,105,87,2851,395,4537,9444,1,7466,25,1371,1020,112,5354,358,48,181,2327,112,21,682,504,2593,637,179,348,215,1294,1222,177,4454,974,1,24,4844,95,961,24,900,424,72,81,572,464,1280,1104,118,345,1748,506,2558,92,1977,804,1884,640,351,3124,2849,2360,708,143,177,209,69,1479,602,2108,2864,223,928,190,5142,79,774,540,1882,966,226,2262,295,2267,1536,103,244,2202
+0.1,76,346,1058,131,17,773,58,168,78,3976,3294,548,271,22,135,543,408,441,33,988,911,337,82,1368,1107,3645,6,64,783,388,110,1096,189,347,36,1115,46,79,76,385,400,1455,333,1259,224,417,17,75,1095,334,6,711,150,28,510,54,759,1,931,156,1176,221,218,404,661,221,103,44,154,56,116,3316,348,338,130,257,1243,278,150,488,65,1063,33,77,17,827,51,674,885,374,1354,1245,1391,87,70,420,47,118,484,51
+0.15,245,125,2100,41,73,69,75,262,43,91,48,10,80,2380,441,29,446,590,112,13,1496,50,131,633,262,851,738,43,167,95,1107,392,278,133,359,376,354,146,63,304,99,311,102,1047,123,298,757,158,5,33,1118,1727,169,182,265,24,39,62,53,442,6,50,127,523,729,192,19,268,351,1321,284,618,77,251,54,258,957,17,338,104,53,91,20,902,139,391,140,77,40,176,1066,7,1642,8,75,87,624,247,56,36
+0.2,160,24,239,85,142,27,94,138,554,348,104,266,3,164,59,37,28,885,629,14,22,193,306,288,55,841,95,75,146,19,349,110,167,3,14,166,71,61,554,244,362,341,47,96,229,101,968,878,785,1,20,82,6,30,396,172,3,170,244,71,750,86,137,76,71,273,177,72,12,150,470,31,879,766,472,1958,10,420,16,19,170,82,575,110,32,493,249,12,106,996,297,184,111,127,1641,148,73,46,42,433
+0.25,103,256,99,321,122,56,266,185,27,55,58,107,174,769,280,433,800,51,166,38,9,48,533,186,893,519,46,213,647,310,109,62,376,382,91,9,186,45,41,38,84,55,1725,128,161,693,325,270,38,1,588,953,26,201,494,79,235,224,65,260,57,843,111,357,526,59,328,893,372,826,26,506,125,168,414,71,24,22,144,270,530,140,261,297,13,181,1257,219,675,135,148,61,322,105,52,108,25,51,348,86
+0.3,179,82,59,15,42,10,13,69,121,30,240,264,93,41,27,537,32,462,145,20,270,98,25,105,78,81,155,411,75,94,120,229,44,183,41,122,281,344,11,314,30,223,200,42,26,52,12,85,799,521,187,44,420,17,109,323,51,169,114,6,133,20,558,85,6,252,54,23,245,139,22,140,320,177,91,773,92,87,26,60,380,226,198,198,178,453,541,210,237,7,158,70,128,483,101,854,75,217,156,121
+0.35,156,101,203,72,359,238,53,72,236,179,229,75,135,82,44,220,58,90,239,82,98,294,31,209,51,6,158,39,18,339,164,23,24,189,182,80,80,104,552,204,179,186,141,87,125,122,451,13,337,244,16,81,41,73,93,43,120,279,184,280,90,10,54,71,18,53,315,427,67,65,54,8,182,71,227,21,293,175,101,16,365,138,9,93,404,274,17,922,198,90,196,17,341,99,98,243,173,184,218,74
+0.40,45,30,405,378,48,534,78,5,107,180,71,35,833,313,48,140,68,125,215,105,104,152,41,83,90,14,105,65,127,16,116,234,99,25,392,54,311,139,122,36,164,124,135,113,61,254,91,137,23,201,92,278,88,100,153,91,96,53,36,309,217,8,31,169,17,236,304,223,37,214,56,658,63,115,3,13,69,52,76,33,225,81,59,2,66,84,2,52,67,8,509,51,139,82,51,198,5,22,71,80
+0.45,69,89,236,173,91,327,198,117,141,47,284,40,54,43,10,56,337,56,35,7,48,57,102,146,10,149,25,178,26,91,49,238,62,12,60,3,35,59,27,117,18,152,180,3,129,24,18,71,23,126,138,25,146,374,434,270,38,54,105,82,22,82,35,189,120,29,448,56,49,208,405,120,223,23,5,90,112,195,47,129,53,83,48,81,77,12,91,159,219,22,120,3,383,1,36,340,109,231,11,68
+0.50,905,53,32,251,57,219,103,87,145,32,174,150,33,75,61,69,78,49,22,1,121,248,45,45,242,40,115,257,58,113,283,157,152,53,634,260,23,101,77,7,139,32,47,218,66,32,68,46,182,2,65,204,175,91,3,102,124,96,283,331,3,63,78,32,1,34,48,68,54,192,91,28,1,31,18,326,273,13,166,415,36,20,80,40,14,64,5,17,29,115,115,209,209,28,47,3,34,48,97,32
+0.55,363,334,78,36,344,270,80,67,115,27,228,94,56,122,100,44,41,85,266,53,158,179,120,31,121,108,50,73,50,154,83,6,328,133,151,24,129,117,39,48,92,9,103,172,24,160,87,99,35,231,243,56,17,52,161,136,24,38,20,42,122,51,53,231,295,52,37,51,190,84,13,12,48,46,121,27,169,219,170,97,86,12,117,54,35,127,55,3,108,199,8,3,30,83,266,65,248,296,73,192
+0.60,144,14,172,55,264,46,56,59,290,256,211,97,90,29,242,90,16,31,25,59,75,74,66,240,155,6,57,53,282,190,61,129,37,23,3,260,361,22,57,125,43,102,36,59,36,81,148,40,39,62,32,232,158,2,1,63,414,38,194,137,51,183,387,15,37,75,40,69,313,23,271,54,178,50,11,47,25,70,273,58,52,11,21,36,92,174,144,47,120,38,66,84,195,342,99,186,164,64,221,15
+0.65,73,139,22,146,3,111,16,12,55,51,129,164,59,21,1,198,93,81,119,40,168,51,88,2,49,32,112,146,80,183,49,40,36,231,29,52,102,22,193,160,1,28,21,55,136,90,141,315,154,103,42,2,146,40,46,51,60,24,95,13,215,12,37,250,146,64,8,59,215,120,200,28,139,21,64,23,61,5,181,161,157,74,261,97,39,68,125,444,118,198,67,33,17,194,30,254,79,94,37,210
+0.70,74,103,14,73,123,115,128,93,12,54,1,34,78,294,77,27,5,17,46,31,255,154,83,59,66,152,125,103,28,94,73,79,110,290,3,100,208,42,28,193,120,278,170,141,169,45,143,121,145,127,91,97,126,287,80,97,79,83,102,109,241,280,247,26,90,84,104,78,99,218,17,256,380,135,18,17,121,72,119,50,35,46,54,219,55,113,185,178,140,55,121,155,126,4,8,115,40,13,98,12
+0.75,9,91,100,130,203,86,54,44,1,52,51,43,116,135,121,13,482,59,2,74,98,18,85,137,57,106,116,161,60,38,7,46,54,228,204,120,97,405,175,74,199,151,35,42,371,93,197,17,94,128,191,301,201,55,106,126,178,116,369,37,35,71,26,102,153,62,34,47,14,89,102,124,98,152,12,117,136,1,146,72,219,38,61,106,34,22,27,56,120,48,6,6,29,57,98,9,208,39,390,147
+0.80,42,45,81,110,48,158,63,40,60,106,52,90,129,105,81,193,54,67,164,79,302,86,237,160,204,213,18,50,37,85,168,148,84,213,13,25,185,25,8,20,69,78,312,62,255,50,248,144,223,3,34,161,132,138,113,39,60,83,65,20,94,96,196,51,30,58,147,22,84,102,109,189,79,124,196,58,147,69,33,130,82,10,290,77,39,48,76,158,202,60,178,20,132,39,19,85,186,108,35,222
+0.85,74,12,115,22,86,14,503,244,15,89,81,186,17,79,133,87,4,98,68,30,105,214,33,46,9,105,173,309,117,147,54,84,27,293,347,136,17,57,120,54,235,98,90,155,71,274,4,327,128,460,93,122,196,98,148,51,125,113,72,325,70,202,39,186,42,37,31,45,154,20,80,61,24,32,66,127,43,85,90,146,27,146,78,60,166,213,67,199,119,21,158,128,147,16,37,28,2,145,282,40
+0.90,31,147,81,271,124,26,142,48,49,103,61,114,62,188,135,220,56,100,107,267,4,372,165,227,26,226,180,241,50,83,39,217,480,208,95,58,55,95,28,66,370,32,37,10,11,292,85,470,173,185,1,331,116,593,47,365,140,88,189,73,17,15,250,100,483,210,65,72,89,52,48,53,47,164,94,4,141,143,140,263,50,135,1,152,67,68,59,166,114,11,41,224,26,40,81,224,132,49,439,41
+0.95,183,71,44,60,53,174,226,160,83,147,83,203,83,146,331,218,2,111,17,152,238,127,187,108,66,55,88,51,49,193,158,249,132,276,383,215,365,31,136,181,284,21,92,120,55,107,340,296,135,13,76,293,301,193,119,57,23,232,23,119,76,137,68,135,118,42,67,202,155,156,57,178,259,221,22,102,69,97,75,164,83,137,46,165,80,33,279,118,65,30,4,134,6,28,161,23,16,403,172,66
+1.00,264,486,133,12,484,71,319,98,96,110,251,22,221,201,39,6,28,384,280,105,41,226,298,135,341,117,125,432,180,89,73,146,199,17,70,354,264,50,137,147,343,393,332,59,128,44,21,569,123,285,110,153,15,168,142,361,216,331,22,338,132,124,371,28,4,77,125,466,48,231,254,231,8,287,96,114,162,201,154,247,220,133,123,2,175,71,213,280,74,25,343,480,70,206,22,23,92,107,204,91
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 559bfc6..669c931 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -9,9 +9,13 @@ import core.policy.EpsilonGreedyPolicy;
import lombok.Getter;
import lombok.Setter;
+import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@@ -25,7 +29,8 @@ public abstract class EpisodicLearning extends Learning imple
protected int episodeSumCurrentSecond;
protected double sumOfRewards;
protected List> episode = new ArrayList<>();
-
+ protected int timestampCurrentEpisode = 0;
+ protected boolean converged;
public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
initBenchMarking();
@@ -50,7 +55,7 @@ public abstract class EpisodicLearning extends Learning imple
private void initBenchMarking(){
new Thread(()->{
- while (true){
+ while (currentlyLearning){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@@ -62,7 +67,7 @@ public abstract class EpisodicLearning extends Learning imple
}).start();
}
- protected void dispatchEpisodeEnd(){
+ private void dispatchEpisodeEnd(){
++episodeSumCurrentSecond;
if(rewardHistory.size() > 10000){
rewardHistory.clear();
@@ -73,20 +78,20 @@ public abstract class EpisodicLearning extends Learning imple
}
}
- protected void dispatchEpisodeStart(){
+ private void dispatchEpisodeStart(){
++currentEpisode;
/*
- 2f 0.02 => 100
- 1.5f 0.02 => 75
- 1.4f 0.02 => fail
- 1.5f 0.1 => 16 !
- */
- if(this.policy instanceof EpsilonGreedyPolicy){
- float ep = 1.5f/(float)currentEpisode;
- if(ep < 0.10) ep = 0;
- ((EpsilonGreedyPolicy) this.policy).setEpsilon(ep);
- System.out.println(ep);
- }
+ 2f 0.02 => 100
+ 1.5f 0.02 => 75
+ 1.4f 0.02 => fail
+ 1.5f 0.1 => 16 !
+ */
+// if(this.policy instanceof EpsilonGreedyPolicy){
+// float ep = 2f/(float)currentEpisode;
+// if(ep < 0.02) ep = 0;
+// ((EpsilonGreedyPolicy) this.policy).setEpsilon(ep);
+// System.out.println(ep);
+// }
episodesToLearn.decrementAndGet();
for(LearningListener l: learningListeners){
l.onEpisodeStart();
@@ -97,10 +102,20 @@ public abstract class EpisodicLearning extends Learning imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
+ timestampCurrentEpisode++;
// TODO: more sophisticated way to check convergence
- if(timestamp > 300000){
- System.out.println("converged after: " + currentEpisode + " episode!");
- interruptLearning();
+ if(timestampCurrentEpisode > 300000){
+ converged = true;
+ // t
+ File file = new File("convergence.txt");
+ try {
+ Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ System.out.println("converged after: " + currentEpisode/2 + " episode!");
+ episodesToLearn.set(0);
+ dispatchLearningEnd();
}
}
@@ -110,20 +125,23 @@ public abstract class EpisodicLearning extends Learning imple
}
private void startLearning(){
- learningExecutor.submit(()->{
- dispatchLearningStart();
- while(episodesToLearn.get() > 0){
- dispatchEpisodeStart();
- nextEpisode();
- dispatchEpisodeEnd();
- }
- synchronized (this){
- dispatchLearningEnd();
- notifyAll();
- }
- });
+ dispatchLearningStart();
+ while(episodesToLearn.get() > 0){
+
+ dispatchEpisodeStart();
+ timestampCurrentEpisode = 0;
+ nextEpisode();
+ dispatchEpisodeEnd();
+ }
+ synchronized (this){
+ dispatchLearningEnd();
+ notifyAll();
+ }
}
+ public void learnMoreEpisodes(int nrOfEpisodes){
+ episodesToLearn.addAndGet(nrOfEpisodes);
+ }
/**
* Stopping the while loop by setting episodesToLearn to 0.
* The current episode can not be interrupted, so the sleep delay
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index fbba9ef..1bb5207 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -42,8 +42,7 @@ public abstract class Learning{
@Setter
protected int delay;
protected List rewardHistory;
- protected ExecutorService learningExecutor;
- protected boolean currentlyLearning;
+ protected volatile boolean currentlyLearning;
public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
this.environment = environment;
@@ -53,7 +52,6 @@ public abstract class Learning{
currentlyLearning = false;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
- learningExecutor = Executors.newSingleThreadExecutor();
}
public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
@@ -89,8 +87,6 @@ public abstract class Learning{
protected void dispatchLearningEnd() {
currentlyLearning = false;
- System.out.println("Checksum: " + checkSum);
- System.out.println("Reward Checksum: " + rewardCheckSum);
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}
diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
index 50e9ce2..5bc93d3 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
@@ -3,12 +3,17 @@ package core.algo.mc;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
+import core.policy.GreedyPolicy;
+import core.policy.Policy;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
-import java.io.IOException;
-import java.io.ObjectInputStream;
-import java.io.ObjectOutputStream;
+import java.io.*;
+import java.net.URI;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
import java.util.*;
/**
@@ -35,8 +40,16 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
private Map, Double> returnSum;
private Map, Integer> returnCount;
+ // t
+ private float epsilon;
+ // t
+ private Policy greedyPolicy = new GreedyPolicy<>();
+
+
public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
+ // t
+ this.epsilon = epsilon;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
@@ -58,12 +71,16 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
}
sumOfRewards = 0;
StepResultEnvironment envResult = null;
- //TODO extract to learning
- int timestamp = 0;
+
while(envResult == null || !envResult.isDone()) {
Map actionValues = stateActionTable.getActionValues(state);
- A chosenAction = policy.chooseAction(actionValues);
- checkSum += chosenAction.ordinal();
+ A chosenAction;
+ if(currentEpisode % 2 == 1){
+ chosenAction = greedyPolicy.chooseAction(actionValues);
+ }else{
+ chosenAction = policy.chooseAction(actionValues);
+ }
+
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
sumOfRewards += envResult.getReward();
@@ -79,6 +96,11 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
}
timestamp++;
dispatchStepEnd();
+ if(converged) return;
+ }
+
+ if(currentEpisode % 2 == 1){
+ return;
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index d855ab1..fc33239 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -67,17 +67,17 @@ public class RLController implements LearningListener {
}
protected void initListeners() {
- learning.addListener(this);
- new Thread(() -> {
- while(true) {
- printNextEpisode = true;
- try {
- Thread.sleep(30 * 1000);
- } catch (InterruptedException e) {
- e.printStackTrace();
+ learning.addListener(this);
+ new Thread(() -> {
+ while(learning.isCurrentlyLearning()) {
+ printNextEpisode = true;
+ try {
+ Thread.sleep(30 * 1000);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
}
- }
- }).start();
+ }).start();
}
private void initLearning() {
@@ -95,7 +95,13 @@ public class RLController implements LearningListener {
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
- ((EpisodicLearning) learning).learn(nrOfEpisodes);
+ if(learning.isCurrentlyLearning()){
+ ((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes);
+ }else{
+ new Thread(() -> {
+ ((EpisodicLearning) learning).learn(nrOfEpisodes);
+ }).start();
+ }
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index 41d5290..13efb4e 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -3,23 +3,49 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
+import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+
public class JumpingDino {
public static void main(String[] args) {
- RNG.setSeed(55);
+ File file = new File("convergence.txt");
+ for(float f = 0.05f; f <=1.003 ; f+=0.05f){
+ try {
+ Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ for(int i = 1; i <= 100; i++) {
+ System.out.println("seed: " + i *13);
+ RNG.setSeed(i *13);
- RLController rl = new RLController<>(
- new DinoWorld(false, false),
- Method.MC_CONTROL_FIRST_VISIT,
- DinoAction.values());
+ RLController rl = new RLController<>(
+ new DinoWorld(false, false),
+ Method.MC_CONTROL_FIRST_VISIT,
+ DinoAction.values());
- rl.setDelay(0);
- rl.setDiscountFactor(1f);
- rl.setEpsilon(0.15f);
- rl.setLearningRate(1f);
- rl.setNrOfEpisodes(400);
- rl.start();
+
+ rl.setDelay(0);
+ rl.setDiscountFactor(1f);
+ rl.setEpsilon(f);
+ rl.setLearningRate(1f);
+ rl.setNrOfEpisodes(20000);
+ rl.start();
+
+ }
+ try {
+ Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ System.out.println("kek");
}
}
diff --git a/src/main/java/example/Results b/src/main/java/example/Results
new file mode 100644
index 0000000..ba42fc7
--- /dev/null
+++ b/src/main/java/example/Results
@@ -0,0 +1,15 @@
+Method: epsilon decay schedule
+Epsilon = k / currentEpisode
+Epsilon is set to 0 once Epsilon < b
+
+k = 1.5
+b = 0.1 => conv. 16
+
+k = 1.5
+b = 0.02 => 75
+
+k = 1.4
+b = 0.02 => fail (did not converge)
+
+k = 2.0
+b = 0.02 => conv. 100