diff --git a/Rplot.png b/Rplot.png new file mode 100644 index 0000000..dc1e2e4 Binary files /dev/null and b/Rplot.png differ diff --git a/convergence.txt b/convergence.txt new file mode 100644 index 0000000..ba6b6c3 --- /dev/null +++ b/convergence.txt @@ -0,0 +1,20 @@ +0.05,300,2426,159,339,537,264,1223,298,314,295,177,343,272,256,214,105,87,2851,395,4537,9444,1,7466,25,1371,1020,112,5354,358,48,181,2327,112,21,682,504,2593,637,179,348,215,1294,1222,177,4454,974,1,24,4844,95,961,24,900,424,72,81,572,464,1280,1104,118,345,1748,506,2558,92,1977,804,1884,640,351,3124,2849,2360,708,143,177,209,69,1479,602,2108,2864,223,928,190,5142,79,774,540,1882,966,226,2262,295,2267,1536,103,244,2202 +0.1,76,346,1058,131,17,773,58,168,78,3976,3294,548,271,22,135,543,408,441,33,988,911,337,82,1368,1107,3645,6,64,783,388,110,1096,189,347,36,1115,46,79,76,385,400,1455,333,1259,224,417,17,75,1095,334,6,711,150,28,510,54,759,1,931,156,1176,221,218,404,661,221,103,44,154,56,116,3316,348,338,130,257,1243,278,150,488,65,1063,33,77,17,827,51,674,885,374,1354,1245,1391,87,70,420,47,118,484,51 +0.15,245,125,2100,41,73,69,75,262,43,91,48,10,80,2380,441,29,446,590,112,13,1496,50,131,633,262,851,738,43,167,95,1107,392,278,133,359,376,354,146,63,304,99,311,102,1047,123,298,757,158,5,33,1118,1727,169,182,265,24,39,62,53,442,6,50,127,523,729,192,19,268,351,1321,284,618,77,251,54,258,957,17,338,104,53,91,20,902,139,391,140,77,40,176,1066,7,1642,8,75,87,624,247,56,36 +0.2,160,24,239,85,142,27,94,138,554,348,104,266,3,164,59,37,28,885,629,14,22,193,306,288,55,841,95,75,146,19,349,110,167,3,14,166,71,61,554,244,362,341,47,96,229,101,968,878,785,1,20,82,6,30,396,172,3,170,244,71,750,86,137,76,71,273,177,72,12,150,470,31,879,766,472,1958,10,420,16,19,170,82,575,110,32,493,249,12,106,996,297,184,111,127,1641,148,73,46,42,433 +0.25,103,256,99,321,122,56,266,185,27,55,58,107,174,769,280,433,800,51,166,38,9,48,533,186,893,519,46,213,647,310,109,62,376,382,91,9,186,45,41,38,84,55,1725,128,161,693,325,270,38,1,588,953,26,201,494,79,235,224,65,260,57,843,111,357,526,59,328,893,372,826,26,506,125,168,414,71,24,22,144,270,530,140,261,297,13,181,1257,219,675,135,148,61,322,105,52,108,25,51,348,86 +0.3,179,82,59,15,42,10,13,69,121,30,240,264,93,41,27,537,32,462,145,20,270,98,25,105,78,81,155,411,75,94,120,229,44,183,41,122,281,344,11,314,30,223,200,42,26,52,12,85,799,521,187,44,420,17,109,323,51,169,114,6,133,20,558,85,6,252,54,23,245,139,22,140,320,177,91,773,92,87,26,60,380,226,198,198,178,453,541,210,237,7,158,70,128,483,101,854,75,217,156,121 +0.35,156,101,203,72,359,238,53,72,236,179,229,75,135,82,44,220,58,90,239,82,98,294,31,209,51,6,158,39,18,339,164,23,24,189,182,80,80,104,552,204,179,186,141,87,125,122,451,13,337,244,16,81,41,73,93,43,120,279,184,280,90,10,54,71,18,53,315,427,67,65,54,8,182,71,227,21,293,175,101,16,365,138,9,93,404,274,17,922,198,90,196,17,341,99,98,243,173,184,218,74 +0.40,45,30,405,378,48,534,78,5,107,180,71,35,833,313,48,140,68,125,215,105,104,152,41,83,90,14,105,65,127,16,116,234,99,25,392,54,311,139,122,36,164,124,135,113,61,254,91,137,23,201,92,278,88,100,153,91,96,53,36,309,217,8,31,169,17,236,304,223,37,214,56,658,63,115,3,13,69,52,76,33,225,81,59,2,66,84,2,52,67,8,509,51,139,82,51,198,5,22,71,80 +0.45,69,89,236,173,91,327,198,117,141,47,284,40,54,43,10,56,337,56,35,7,48,57,102,146,10,149,25,178,26,91,49,238,62,12,60,3,35,59,27,117,18,152,180,3,129,24,18,71,23,126,138,25,146,374,434,270,38,54,105,82,22,82,35,189,120,29,448,56,49,208,405,120,223,23,5,90,112,195,47,129,53,83,48,81,77,12,91,159,219,22,120,3,383,1,36,340,109,231,11,68 +0.50,905,53,32,251,57,219,103,87,145,32,174,150,33,75,61,69,78,49,22,1,121,248,45,45,242,40,115,257,58,113,283,157,152,53,634,260,23,101,77,7,139,32,47,218,66,32,68,46,182,2,65,204,175,91,3,102,124,96,283,331,3,63,78,32,1,34,48,68,54,192,91,28,1,31,18,326,273,13,166,415,36,20,80,40,14,64,5,17,29,115,115,209,209,28,47,3,34,48,97,32 +0.55,363,334,78,36,344,270,80,67,115,27,228,94,56,122,100,44,41,85,266,53,158,179,120,31,121,108,50,73,50,154,83,6,328,133,151,24,129,117,39,48,92,9,103,172,24,160,87,99,35,231,243,56,17,52,161,136,24,38,20,42,122,51,53,231,295,52,37,51,190,84,13,12,48,46,121,27,169,219,170,97,86,12,117,54,35,127,55,3,108,199,8,3,30,83,266,65,248,296,73,192 +0.60,144,14,172,55,264,46,56,59,290,256,211,97,90,29,242,90,16,31,25,59,75,74,66,240,155,6,57,53,282,190,61,129,37,23,3,260,361,22,57,125,43,102,36,59,36,81,148,40,39,62,32,232,158,2,1,63,414,38,194,137,51,183,387,15,37,75,40,69,313,23,271,54,178,50,11,47,25,70,273,58,52,11,21,36,92,174,144,47,120,38,66,84,195,342,99,186,164,64,221,15 +0.65,73,139,22,146,3,111,16,12,55,51,129,164,59,21,1,198,93,81,119,40,168,51,88,2,49,32,112,146,80,183,49,40,36,231,29,52,102,22,193,160,1,28,21,55,136,90,141,315,154,103,42,2,146,40,46,51,60,24,95,13,215,12,37,250,146,64,8,59,215,120,200,28,139,21,64,23,61,5,181,161,157,74,261,97,39,68,125,444,118,198,67,33,17,194,30,254,79,94,37,210 +0.70,74,103,14,73,123,115,128,93,12,54,1,34,78,294,77,27,5,17,46,31,255,154,83,59,66,152,125,103,28,94,73,79,110,290,3,100,208,42,28,193,120,278,170,141,169,45,143,121,145,127,91,97,126,287,80,97,79,83,102,109,241,280,247,26,90,84,104,78,99,218,17,256,380,135,18,17,121,72,119,50,35,46,54,219,55,113,185,178,140,55,121,155,126,4,8,115,40,13,98,12 +0.75,9,91,100,130,203,86,54,44,1,52,51,43,116,135,121,13,482,59,2,74,98,18,85,137,57,106,116,161,60,38,7,46,54,228,204,120,97,405,175,74,199,151,35,42,371,93,197,17,94,128,191,301,201,55,106,126,178,116,369,37,35,71,26,102,153,62,34,47,14,89,102,124,98,152,12,117,136,1,146,72,219,38,61,106,34,22,27,56,120,48,6,6,29,57,98,9,208,39,390,147 +0.80,42,45,81,110,48,158,63,40,60,106,52,90,129,105,81,193,54,67,164,79,302,86,237,160,204,213,18,50,37,85,168,148,84,213,13,25,185,25,8,20,69,78,312,62,255,50,248,144,223,3,34,161,132,138,113,39,60,83,65,20,94,96,196,51,30,58,147,22,84,102,109,189,79,124,196,58,147,69,33,130,82,10,290,77,39,48,76,158,202,60,178,20,132,39,19,85,186,108,35,222 +0.85,74,12,115,22,86,14,503,244,15,89,81,186,17,79,133,87,4,98,68,30,105,214,33,46,9,105,173,309,117,147,54,84,27,293,347,136,17,57,120,54,235,98,90,155,71,274,4,327,128,460,93,122,196,98,148,51,125,113,72,325,70,202,39,186,42,37,31,45,154,20,80,61,24,32,66,127,43,85,90,146,27,146,78,60,166,213,67,199,119,21,158,128,147,16,37,28,2,145,282,40 +0.90,31,147,81,271,124,26,142,48,49,103,61,114,62,188,135,220,56,100,107,267,4,372,165,227,26,226,180,241,50,83,39,217,480,208,95,58,55,95,28,66,370,32,37,10,11,292,85,470,173,185,1,331,116,593,47,365,140,88,189,73,17,15,250,100,483,210,65,72,89,52,48,53,47,164,94,4,141,143,140,263,50,135,1,152,67,68,59,166,114,11,41,224,26,40,81,224,132,49,439,41 +0.95,183,71,44,60,53,174,226,160,83,147,83,203,83,146,331,218,2,111,17,152,238,127,187,108,66,55,88,51,49,193,158,249,132,276,383,215,365,31,136,181,284,21,92,120,55,107,340,296,135,13,76,293,301,193,119,57,23,232,23,119,76,137,68,135,118,42,67,202,155,156,57,178,259,221,22,102,69,97,75,164,83,137,46,165,80,33,279,118,65,30,4,134,6,28,161,23,16,403,172,66 +1.00,264,486,133,12,484,71,319,98,96,110,251,22,221,201,39,6,28,384,280,105,41,226,298,135,341,117,125,432,180,89,73,146,199,17,70,354,264,50,137,147,343,393,332,59,128,44,21,569,123,285,110,153,15,168,142,361,216,331,22,338,132,124,371,28,4,77,125,466,48,231,254,231,8,287,96,114,162,201,154,247,220,133,123,2,175,71,213,280,74,25,343,480,70,206,22,23,92,107,204,91 diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index 559bfc6..669c931 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -9,9 +9,13 @@ import core.policy.EpsilonGreedyPolicy; import lombok.Getter; import lombok.Setter; +import java.io.File; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; @@ -25,7 +29,8 @@ public abstract class EpisodicLearning extends Learning imple protected int episodeSumCurrentSecond; protected double sumOfRewards; protected List> episode = new ArrayList<>(); - + protected int timestampCurrentEpisode = 0; + protected boolean converged; public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { super(environment, actionSpace, discountFactor, delay); initBenchMarking(); @@ -50,7 +55,7 @@ public abstract class EpisodicLearning extends Learning imple private void initBenchMarking(){ new Thread(()->{ - while (true){ + while (currentlyLearning){ episodePerSecond = episodeSumCurrentSecond; episodeSumCurrentSecond = 0; try { @@ -62,7 +67,7 @@ public abstract class EpisodicLearning extends Learning imple }).start(); } - protected void dispatchEpisodeEnd(){ + private void dispatchEpisodeEnd(){ ++episodeSumCurrentSecond; if(rewardHistory.size() > 10000){ rewardHistory.clear(); @@ -73,20 +78,20 @@ public abstract class EpisodicLearning extends Learning imple } } - protected void dispatchEpisodeStart(){ + private void dispatchEpisodeStart(){ ++currentEpisode; /* - 2f 0.02 => 100 - 1.5f 0.02 => 75 - 1.4f 0.02 => fail - 1.5f 0.1 => 16 ! - */ - if(this.policy instanceof EpsilonGreedyPolicy){ - float ep = 1.5f/(float)currentEpisode; - if(ep < 0.10) ep = 0; - ((EpsilonGreedyPolicy) this.policy).setEpsilon(ep); - System.out.println(ep); - } + 2f 0.02 => 100 + 1.5f 0.02 => 75 + 1.4f 0.02 => fail + 1.5f 0.1 => 16 ! + */ +// if(this.policy instanceof EpsilonGreedyPolicy){ +// float ep = 2f/(float)currentEpisode; +// if(ep < 0.02) ep = 0; +// ((EpsilonGreedyPolicy) this.policy).setEpsilon(ep); +// System.out.println(ep); +// } episodesToLearn.decrementAndGet(); for(LearningListener l: learningListeners){ l.onEpisodeStart(); @@ -97,10 +102,20 @@ public abstract class EpisodicLearning extends Learning imple protected void dispatchStepEnd() { super.dispatchStepEnd(); timestamp++; + timestampCurrentEpisode++; // TODO: more sophisticated way to check convergence - if(timestamp > 300000){ - System.out.println("converged after: " + currentEpisode + " episode!"); - interruptLearning(); + if(timestampCurrentEpisode > 300000){ + converged = true; + // t + File file = new File("convergence.txt"); + try { + Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND); + } catch (IOException e) { + e.printStackTrace(); + } + System.out.println("converged after: " + currentEpisode/2 + " episode!"); + episodesToLearn.set(0); + dispatchLearningEnd(); } } @@ -110,20 +125,23 @@ public abstract class EpisodicLearning extends Learning imple } private void startLearning(){ - learningExecutor.submit(()->{ - dispatchLearningStart(); - while(episodesToLearn.get() > 0){ - dispatchEpisodeStart(); - nextEpisode(); - dispatchEpisodeEnd(); - } - synchronized (this){ - dispatchLearningEnd(); - notifyAll(); - } - }); + dispatchLearningStart(); + while(episodesToLearn.get() > 0){ + + dispatchEpisodeStart(); + timestampCurrentEpisode = 0; + nextEpisode(); + dispatchEpisodeEnd(); + } + synchronized (this){ + dispatchLearningEnd(); + notifyAll(); + } } + public void learnMoreEpisodes(int nrOfEpisodes){ + episodesToLearn.addAndGet(nrOfEpisodes); + } /** * Stopping the while loop by setting episodesToLearn to 0. * The current episode can not be interrupted, so the sleep delay diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index fbba9ef..1bb5207 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -42,8 +42,7 @@ public abstract class Learning{ @Setter protected int delay; protected List rewardHistory; - protected ExecutorService learningExecutor; - protected boolean currentlyLearning; + protected volatile boolean currentlyLearning; public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { this.environment = environment; @@ -53,7 +52,6 @@ public abstract class Learning{ currentlyLearning = false; learningListeners = new HashSet<>(); rewardHistory = new CopyOnWriteArrayList<>(); - learningExecutor = Executors.newSingleThreadExecutor(); } public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { @@ -89,8 +87,6 @@ public abstract class Learning{ protected void dispatchLearningEnd() { currentlyLearning = false; - System.out.println("Checksum: " + checkSum); - System.out.println("Reward Checksum: " + rewardCheckSum); for (LearningListener l : learningListeners) { l.onLearningEnd(); } diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java index 50e9ce2..5bc93d3 100644 --- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java +++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java @@ -3,12 +3,17 @@ package core.algo.mc; import core.*; import core.algo.EpisodicLearning; import core.policy.EpsilonGreedyPolicy; +import core.policy.GreedyPolicy; +import core.policy.Policy; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; +import java.io.*; +import java.net.URI; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; import java.util.*; /** @@ -35,8 +40,16 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic private Map, Double> returnSum; private Map, Integer> returnCount; + // t + private float epsilon; + // t + private Policy greedyPolicy = new GreedyPolicy<>(); + + public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) { super(environment, actionSpace, discountFactor, delay); + // t + this.epsilon = epsilon; this.policy = new EpsilonGreedyPolicy<>(epsilon); this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace); returnSum = new HashMap<>(); @@ -58,12 +71,16 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic } sumOfRewards = 0; StepResultEnvironment envResult = null; - //TODO extract to learning - int timestamp = 0; + while(envResult == null || !envResult.isDone()) { Map actionValues = stateActionTable.getActionValues(state); - A chosenAction = policy.chooseAction(actionValues); - checkSum += chosenAction.ordinal(); + A chosenAction; + if(currentEpisode % 2 == 1){ + chosenAction = greedyPolicy.chooseAction(actionValues); + }else{ + chosenAction = policy.chooseAction(actionValues); + } + envResult = environment.step(chosenAction); State nextState = envResult.getState(); sumOfRewards += envResult.getReward(); @@ -79,6 +96,11 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic } timestamp++; dispatchStepEnd(); + if(converged) return; + } + + if(currentEpisode % 2 == 1){ + return; } // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index d855ab1..fc33239 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -67,17 +67,17 @@ public class RLController implements LearningListener { } protected void initListeners() { - learning.addListener(this); - new Thread(() -> { - while(true) { - printNextEpisode = true; - try { - Thread.sleep(30 * 1000); - } catch (InterruptedException e) { - e.printStackTrace(); + learning.addListener(this); + new Thread(() -> { + while(learning.isCurrentlyLearning()) { + printNextEpisode = true; + try { + Thread.sleep(30 * 1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } } - } - }).start(); + }).start(); } private void initLearning() { @@ -95,7 +95,13 @@ public class RLController implements LearningListener { protected void learnMoreEpisodes(int nrOfEpisodes) { if(learning instanceof EpisodicLearning) { - ((EpisodicLearning) learning).learn(nrOfEpisodes); + if(learning.isCurrentlyLearning()){ + ((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes); + }else{ + new Thread(() -> { + ((EpisodicLearning) learning).learn(nrOfEpisodes); + }).start(); + } } else { throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); } diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index 41d5290..13efb4e 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -3,23 +3,49 @@ package example; import core.RNG; import core.algo.Method; import core.controller.RLController; +import core.controller.RLControllerGUI; import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoWorld; +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; + public class JumpingDino { public static void main(String[] args) { - RNG.setSeed(55); + File file = new File("convergence.txt"); + for(float f = 0.05f; f <=1.003 ; f+=0.05f){ + try { + Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND); + } catch (IOException e) { + e.printStackTrace(); + } + for(int i = 1; i <= 100; i++) { + System.out.println("seed: " + i *13); + RNG.setSeed(i *13); - RLController rl = new RLController<>( - new DinoWorld(false, false), - Method.MC_CONTROL_FIRST_VISIT, - DinoAction.values()); + RLController rl = new RLController<>( + new DinoWorld(false, false), + Method.MC_CONTROL_FIRST_VISIT, + DinoAction.values()); - rl.setDelay(0); - rl.setDiscountFactor(1f); - rl.setEpsilon(0.15f); - rl.setLearningRate(1f); - rl.setNrOfEpisodes(400); - rl.start(); + + rl.setDelay(0); + rl.setDiscountFactor(1f); + rl.setEpsilon(f); + rl.setLearningRate(1f); + rl.setNrOfEpisodes(20000); + rl.start(); + + } + try { + Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND); + } catch (IOException e) { + e.printStackTrace(); + } + } + System.out.println("kek"); } } diff --git a/src/main/java/example/Results b/src/main/java/example/Results new file mode 100644 index 0000000..ba42fc7 --- /dev/null +++ b/src/main/java/example/Results @@ -0,0 +1,15 @@ +Method: +Epsilon = k / currentEpisode +set to 0 if Epsilon < b + +k = 1.5 +b = 0.1 => conv. 16 + +k = 1.5 +b = 0.02 => 75 + +k = 1.4 +b = 0.02 => fail + +k = 2.0 +b = 0.02 => conv. 100