Add epsilon convergence test and remove unnecessary multithreaded learning

This commit is contained in:
Jan Löwenstrom 2020-03-03 02:52:39 +01:00
parent 6613e23c7c
commit 9b54b72a25
8 changed files with 167 additions and 64 deletions

BIN
Rplot.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.5 KiB

20
convergence.txt Normal file
View File

@ -0,0 +1,20 @@
0.05,300,2426,159,339,537,264,1223,298,314,295,177,343,272,256,214,105,87,2851,395,4537,9444,1,7466,25,1371,1020,112,5354,358,48,181,2327,112,21,682,504,2593,637,179,348,215,1294,1222,177,4454,974,1,24,4844,95,961,24,900,424,72,81,572,464,1280,1104,118,345,1748,506,2558,92,1977,804,1884,640,351,3124,2849,2360,708,143,177,209,69,1479,602,2108,2864,223,928,190,5142,79,774,540,1882,966,226,2262,295,2267,1536,103,244,2202
0.1,76,346,1058,131,17,773,58,168,78,3976,3294,548,271,22,135,543,408,441,33,988,911,337,82,1368,1107,3645,6,64,783,388,110,1096,189,347,36,1115,46,79,76,385,400,1455,333,1259,224,417,17,75,1095,334,6,711,150,28,510,54,759,1,931,156,1176,221,218,404,661,221,103,44,154,56,116,3316,348,338,130,257,1243,278,150,488,65,1063,33,77,17,827,51,674,885,374,1354,1245,1391,87,70,420,47,118,484,51
0.15,245,125,2100,41,73,69,75,262,43,91,48,10,80,2380,441,29,446,590,112,13,1496,50,131,633,262,851,738,43,167,95,1107,392,278,133,359,376,354,146,63,304,99,311,102,1047,123,298,757,158,5,33,1118,1727,169,182,265,24,39,62,53,442,6,50,127,523,729,192,19,268,351,1321,284,618,77,251,54,258,957,17,338,104,53,91,20,902,139,391,140,77,40,176,1066,7,1642,8,75,87,624,247,56,36
0.2,160,24,239,85,142,27,94,138,554,348,104,266,3,164,59,37,28,885,629,14,22,193,306,288,55,841,95,75,146,19,349,110,167,3,14,166,71,61,554,244,362,341,47,96,229,101,968,878,785,1,20,82,6,30,396,172,3,170,244,71,750,86,137,76,71,273,177,72,12,150,470,31,879,766,472,1958,10,420,16,19,170,82,575,110,32,493,249,12,106,996,297,184,111,127,1641,148,73,46,42,433
0.25,103,256,99,321,122,56,266,185,27,55,58,107,174,769,280,433,800,51,166,38,9,48,533,186,893,519,46,213,647,310,109,62,376,382,91,9,186,45,41,38,84,55,1725,128,161,693,325,270,38,1,588,953,26,201,494,79,235,224,65,260,57,843,111,357,526,59,328,893,372,826,26,506,125,168,414,71,24,22,144,270,530,140,261,297,13,181,1257,219,675,135,148,61,322,105,52,108,25,51,348,86
0.3,179,82,59,15,42,10,13,69,121,30,240,264,93,41,27,537,32,462,145,20,270,98,25,105,78,81,155,411,75,94,120,229,44,183,41,122,281,344,11,314,30,223,200,42,26,52,12,85,799,521,187,44,420,17,109,323,51,169,114,6,133,20,558,85,6,252,54,23,245,139,22,140,320,177,91,773,92,87,26,60,380,226,198,198,178,453,541,210,237,7,158,70,128,483,101,854,75,217,156,121
0.35,156,101,203,72,359,238,53,72,236,179,229,75,135,82,44,220,58,90,239,82,98,294,31,209,51,6,158,39,18,339,164,23,24,189,182,80,80,104,552,204,179,186,141,87,125,122,451,13,337,244,16,81,41,73,93,43,120,279,184,280,90,10,54,71,18,53,315,427,67,65,54,8,182,71,227,21,293,175,101,16,365,138,9,93,404,274,17,922,198,90,196,17,341,99,98,243,173,184,218,74
0.40,45,30,405,378,48,534,78,5,107,180,71,35,833,313,48,140,68,125,215,105,104,152,41,83,90,14,105,65,127,16,116,234,99,25,392,54,311,139,122,36,164,124,135,113,61,254,91,137,23,201,92,278,88,100,153,91,96,53,36,309,217,8,31,169,17,236,304,223,37,214,56,658,63,115,3,13,69,52,76,33,225,81,59,2,66,84,2,52,67,8,509,51,139,82,51,198,5,22,71,80
0.45,69,89,236,173,91,327,198,117,141,47,284,40,54,43,10,56,337,56,35,7,48,57,102,146,10,149,25,178,26,91,49,238,62,12,60,3,35,59,27,117,18,152,180,3,129,24,18,71,23,126,138,25,146,374,434,270,38,54,105,82,22,82,35,189,120,29,448,56,49,208,405,120,223,23,5,90,112,195,47,129,53,83,48,81,77,12,91,159,219,22,120,3,383,1,36,340,109,231,11,68
0.50,905,53,32,251,57,219,103,87,145,32,174,150,33,75,61,69,78,49,22,1,121,248,45,45,242,40,115,257,58,113,283,157,152,53,634,260,23,101,77,7,139,32,47,218,66,32,68,46,182,2,65,204,175,91,3,102,124,96,283,331,3,63,78,32,1,34,48,68,54,192,91,28,1,31,18,326,273,13,166,415,36,20,80,40,14,64,5,17,29,115,115,209,209,28,47,3,34,48,97,32
0.55,363,334,78,36,344,270,80,67,115,27,228,94,56,122,100,44,41,85,266,53,158,179,120,31,121,108,50,73,50,154,83,6,328,133,151,24,129,117,39,48,92,9,103,172,24,160,87,99,35,231,243,56,17,52,161,136,24,38,20,42,122,51,53,231,295,52,37,51,190,84,13,12,48,46,121,27,169,219,170,97,86,12,117,54,35,127,55,3,108,199,8,3,30,83,266,65,248,296,73,192
0.60,144,14,172,55,264,46,56,59,290,256,211,97,90,29,242,90,16,31,25,59,75,74,66,240,155,6,57,53,282,190,61,129,37,23,3,260,361,22,57,125,43,102,36,59,36,81,148,40,39,62,32,232,158,2,1,63,414,38,194,137,51,183,387,15,37,75,40,69,313,23,271,54,178,50,11,47,25,70,273,58,52,11,21,36,92,174,144,47,120,38,66,84,195,342,99,186,164,64,221,15
0.65,73,139,22,146,3,111,16,12,55,51,129,164,59,21,1,198,93,81,119,40,168,51,88,2,49,32,112,146,80,183,49,40,36,231,29,52,102,22,193,160,1,28,21,55,136,90,141,315,154,103,42,2,146,40,46,51,60,24,95,13,215,12,37,250,146,64,8,59,215,120,200,28,139,21,64,23,61,5,181,161,157,74,261,97,39,68,125,444,118,198,67,33,17,194,30,254,79,94,37,210
0.70,74,103,14,73,123,115,128,93,12,54,1,34,78,294,77,27,5,17,46,31,255,154,83,59,66,152,125,103,28,94,73,79,110,290,3,100,208,42,28,193,120,278,170,141,169,45,143,121,145,127,91,97,126,287,80,97,79,83,102,109,241,280,247,26,90,84,104,78,99,218,17,256,380,135,18,17,121,72,119,50,35,46,54,219,55,113,185,178,140,55,121,155,126,4,8,115,40,13,98,12
0.75,9,91,100,130,203,86,54,44,1,52,51,43,116,135,121,13,482,59,2,74,98,18,85,137,57,106,116,161,60,38,7,46,54,228,204,120,97,405,175,74,199,151,35,42,371,93,197,17,94,128,191,301,201,55,106,126,178,116,369,37,35,71,26,102,153,62,34,47,14,89,102,124,98,152,12,117,136,1,146,72,219,38,61,106,34,22,27,56,120,48,6,6,29,57,98,9,208,39,390,147
0.80,42,45,81,110,48,158,63,40,60,106,52,90,129,105,81,193,54,67,164,79,302,86,237,160,204,213,18,50,37,85,168,148,84,213,13,25,185,25,8,20,69,78,312,62,255,50,248,144,223,3,34,161,132,138,113,39,60,83,65,20,94,96,196,51,30,58,147,22,84,102,109,189,79,124,196,58,147,69,33,130,82,10,290,77,39,48,76,158,202,60,178,20,132,39,19,85,186,108,35,222
0.85,74,12,115,22,86,14,503,244,15,89,81,186,17,79,133,87,4,98,68,30,105,214,33,46,9,105,173,309,117,147,54,84,27,293,347,136,17,57,120,54,235,98,90,155,71,274,4,327,128,460,93,122,196,98,148,51,125,113,72,325,70,202,39,186,42,37,31,45,154,20,80,61,24,32,66,127,43,85,90,146,27,146,78,60,166,213,67,199,119,21,158,128,147,16,37,28,2,145,282,40
0.90,31,147,81,271,124,26,142,48,49,103,61,114,62,188,135,220,56,100,107,267,4,372,165,227,26,226,180,241,50,83,39,217,480,208,95,58,55,95,28,66,370,32,37,10,11,292,85,470,173,185,1,331,116,593,47,365,140,88,189,73,17,15,250,100,483,210,65,72,89,52,48,53,47,164,94,4,141,143,140,263,50,135,1,152,67,68,59,166,114,11,41,224,26,40,81,224,132,49,439,41
0.95,183,71,44,60,53,174,226,160,83,147,83,203,83,146,331,218,2,111,17,152,238,127,187,108,66,55,88,51,49,193,158,249,132,276,383,215,365,31,136,181,284,21,92,120,55,107,340,296,135,13,76,293,301,193,119,57,23,232,23,119,76,137,68,135,118,42,67,202,155,156,57,178,259,221,22,102,69,97,75,164,83,137,46,165,80,33,279,118,65,30,4,134,6,28,161,23,16,403,172,66
1.00,264,486,133,12,484,71,319,98,96,110,251,22,221,201,39,6,28,384,280,105,41,226,298,135,341,117,125,432,180,89,73,146,199,17,70,354,264,50,137,147,343,393,332,59,128,44,21,569,123,285,110,153,15,168,142,361,216,331,22,338,132,124,371,28,4,77,125,466,48,231,254,231,8,287,96,114,162,201,154,247,220,133,123,2,175,71,213,280,74,25,343,480,70,206,22,23,92,107,204,91

View File

@ -9,9 +9,13 @@ import core.policy.EpsilonGreedyPolicy;
import lombok.Getter;
import lombok.Setter;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
@ -25,7 +29,8 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected int episodeSumCurrentSecond;
protected double sumOfRewards;
protected List<StepResult<A>> episode = new ArrayList<>();
protected int timestampCurrentEpisode = 0;
protected boolean converged;
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
initBenchMarking();
@ -50,7 +55,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void initBenchMarking(){
new Thread(()->{
while (true){
while (currentlyLearning){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@ -62,7 +67,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
}).start();
}
protected void dispatchEpisodeEnd(){
private void dispatchEpisodeEnd(){
++episodeSumCurrentSecond;
if(rewardHistory.size() > 10000){
rewardHistory.clear();
@ -73,20 +78,20 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
}
}
protected void dispatchEpisodeStart(){
private void dispatchEpisodeStart(){
++currentEpisode;
/*
2f 0.02 => 100
1.5f 0.02 => 75
1.4f 0.02 => fail
1.5f 0.1 => 16 !
*/
if(this.policy instanceof EpsilonGreedyPolicy){
float ep = 1.5f/(float)currentEpisode;
if(ep < 0.10) ep = 0;
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
System.out.println(ep);
}
2f 0.02 => 100
1.5f 0.02 => 75
1.4f 0.02 => fail
1.5f 0.1 => 16 !
*/
// if(this.policy instanceof EpsilonGreedyPolicy){
// float ep = 2f/(float)currentEpisode;
// if(ep < 0.02) ep = 0;
// ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
// System.out.println(ep);
// }
episodesToLearn.decrementAndGet();
for(LearningListener l: learningListeners){
l.onEpisodeStart();
@ -97,10 +102,20 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
timestampCurrentEpisode++;
// TODO: more sophisticated way to check convergence
if(timestamp > 300000){
System.out.println("converged after: " + currentEpisode + " episode!");
interruptLearning();
if(timestampCurrentEpisode > 300000){
converged = true;
// t
File file = new File("convergence.txt");
try {
Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("converged after: " + currentEpisode/2 + " episode!");
episodesToLearn.set(0);
dispatchLearningEnd();
}
}
@ -110,20 +125,23 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
}
private void startLearning(){
learningExecutor.submit(()->{
dispatchLearningStart();
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
nextEpisode();
dispatchEpisodeEnd();
}
synchronized (this){
dispatchLearningEnd();
notifyAll();
}
});
dispatchLearningStart();
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
timestampCurrentEpisode = 0;
nextEpisode();
dispatchEpisodeEnd();
}
synchronized (this){
dispatchLearningEnd();
notifyAll();
}
}
public void learnMoreEpisodes(int nrOfEpisodes){
episodesToLearn.addAndGet(nrOfEpisodes);
}
/**
* Stopping the while loop by setting episodesToLearn to 0.
* The current episode can not be interrupted, so the sleep delay

View File

@ -42,8 +42,7 @@ public abstract class Learning<A extends Enum>{
@Setter
protected int delay;
protected List<Double> rewardHistory;
protected ExecutorService learningExecutor;
protected boolean currentlyLearning;
protected volatile boolean currentlyLearning;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
this.environment = environment;
@ -53,7 +52,6 @@ public abstract class Learning<A extends Enum>{
currentlyLearning = false;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
learningExecutor = Executors.newSingleThreadExecutor();
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
@ -89,8 +87,6 @@ public abstract class Learning<A extends Enum>{
protected void dispatchLearningEnd() {
currentlyLearning = false;
System.out.println("Checksum: " + checkSum);
System.out.println("Reward Checksum: " + rewardCheckSum);
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}

View File

@ -3,12 +3,17 @@ package core.algo.mc;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.*;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.*;
/**
@ -35,8 +40,16 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
// t
private float epsilon;
// t
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
// t
this.epsilon = epsilon;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
@ -58,12 +71,16 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
}
sumOfRewards = 0;
StepResultEnvironment envResult = null;
//TODO extract to learning
int timestamp = 0;
while(envResult == null || !envResult.isDone()) {
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
checkSum += chosenAction.ordinal();
A chosenAction;
if(currentEpisode % 2 == 1){
chosenAction = greedyPolicy.chooseAction(actionValues);
}else{
chosenAction = policy.chooseAction(actionValues);
}
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
sumOfRewards += envResult.getReward();
@ -79,6 +96,11 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
}
timestamp++;
dispatchStepEnd();
if(converged) return;
}
if(currentEpisode % 2 == 1){
return;
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);

View File

@ -67,17 +67,17 @@ public class RLController<A extends Enum> implements LearningListener {
}
protected void initListeners() {
learning.addListener(this);
new Thread(() -> {
while(true) {
printNextEpisode = true;
try {
Thread.sleep(30 * 1000);
} catch (InterruptedException e) {
e.printStackTrace();
learning.addListener(this);
new Thread(() -> {
while(learning.isCurrentlyLearning()) {
printNextEpisode = true;
try {
Thread.sleep(30 * 1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
}).start();
}).start();
}
private void initLearning() {
@ -95,7 +95,13 @@ public class RLController<A extends Enum> implements LearningListener {
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
((EpisodicLearning) learning).learn(nrOfEpisodes);
if(learning.isCurrentlyLearning()){
((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
new Thread(() -> {
((EpisodicLearning) learning).learn(nrOfEpisodes);
}).start();
}
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}

View File

@ -3,23 +3,49 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
File file = new File("convergence.txt");
for(float f = 0.05f; f <=1.003 ; f+=0.05f){
try {
Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
for(int i = 1; i <= 100; i++) {
System.out.println("seed: " + i *13);
RNG.setSeed(i *13);
RLController<DinoAction> rl = new RLController<>(
new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
RLController<DinoAction> rl = new RLController<>(
new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(400);
rl.start();
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(20000);
rl.start();
}
try {
Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
}
System.out.println("kek");
}
}

View File

@ -0,0 +1,15 @@
Method:
Epsilon = k / currentEpisode
set to 0 if Epsilon < b
k = 1.5
b = 0.1 => conv. 16
k = 1.5
b = 0.02 => conv. 75
k = 1.4
b = 0.02 => fail
k = 2.0
b = 0.02 => conv. 100