From 328fc8521407aa6ba9036ad383ec1bcc87dbcf53 Mon Sep 17 00:00:00 2001
From: Jan Löwenstrom
Date: Sat, 28 Mar 2020 12:35:33 +0100
Subject: [PATCH] Modify Q-learning to sample results and update R script

---
 epsilonValues.R                               | 23 ++++-
 src/main/java/core/RNG.java                   |  8 ++++
 src/main/java/core/algo/EpisodicLearning.java |  7 +--
 .../MonteCarloControlFirstVisitEGreedy.java   | 19 +-------
 .../algo/td/QLearningOffPolicyTDControl.java  | 48 ++++++++++++++-----
 src/main/java/core/algo/td/SARSA.java         | 25 ++--------
 src/main/java/evironment/antGame/Grid.java    |  6 +--
 .../jumpingDino/DinoWorldAdvanced.java        |  4 +-
 src/main/java/example/ContinuousAnt.java      | 25 ++++++----
 9 files changed, 97 insertions(+), 68 deletions(-)

diff --git a/epsilonValues.R b/epsilonValues.R
index ed356c1..ff0bd53 100644
--- a/epsilonValues.R
+++ b/epsilonValues.R
@@ -1,7 +1,28 @@
 # Libraries
 library(ggplot2)
 library(matrixStats)
-# file.choose()
+ta <- as.matrix(read.table(file.choose(), sep=",", header = FALSE))
+ta <- t(ta)
+dim(ta)
+head(ta)
+
+# Collect the four sampled series into one data frame for ggplot
+data <- data.frame(
+  y=ta[,1],
+  y2=ta[,2],
+  y3=ta[,3],
+  y4=ta[,4],
+  x=seq(1, nrow(ta)) # one x value per row; length(ta) would count every matrix cell
+)
+ggplot(data, aes(x)) +
+  geom_line(aes(y = y, colour = "var0")) +
+  geom_line(aes(y = y2, colour = "var1")) +
+  geom_line(aes(y = y3, colour = "var2")) +
+  geom_line(aes(y = y4, colour = "var3")) +
+  scale_x_log10( breaks=c(1,5,10,15,20,50,100,200), limits=c(1,200) )
+
+plot(data$x * 1000, data$y, log="x", type="o") # quick base-graphics check of the first series
+
 convergence <- read.csv(file.choose(), header=FALSE, row.names=1)
 sds <- rowSds(sapply(convergence[,-1], `length<-`, max(lengths(convergence[,-1]))), na.rm=TRUE)
diff --git a/src/main/java/core/RNG.java b/src/main/java/core/RNG.java
index b0708af..959ae59 100644
--- a/src/main/java/core/RNG.java
+++ b/src/main/java/core/RNG.java
@@ -15,18 +15,26 @@ import java.util.Random;
  */
 public class RNG {
     private static Random rng;
+    private static Random rngEnv; // separate stream for environment randomness
     private static int seed = 123;
     static {
         rng = new Random();
         rng.setSeed(seed);
+        rngEnv = new Random();
+        rngEnv.setSeed(seed);
     }

     public static Random getRandom() {
         return rng;
     }

+    public static Random getRandomEnv() {
+        return rngEnv;
+    }
     public static void setSeed(int seed){
         RNG.seed = seed;
         rng.setSeed(seed);
+        rngEnv = new Random();
+        rngEnv.setSeed(seed);
     }
 }
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index a14cba0..d7b6655 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -105,8 +105,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
         timestamp++;
         timestampCurrentEpisode++;
         // TODO: more sophisticated way to check convergence
-        if(timestampCurrentEpisode > 30000000){
-            converged = true;
+        if(false){ // convergence check disabled while sampling results
             File file = new File(DinoSampling.FILE_NAME);
             try {
             } catch (IOException e) {
                 e.printStackTrace();
             }
-            System.out.println("converged after: " + currentEpisode/2 + " episode!");
-            episodesToLearn.set(0);
-            dispatchLearningEnd();
+            // System.out.println("converged after: " + currentEpisode/2 + " episode!");
         }
     }
diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
index 5bc93d3..ff3247b 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
@@ -40,16 +40,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;

-    // t
-    private float epsilon;
-    // t
-    private Policy<A> greedyPolicy = new GreedyPolicy<>();
-
     public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
-        // t
-        this.epsilon = epsilon;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
@@ -74,12 +67,7 @@
         while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
-            A chosenAction;
-            if(currentEpisode % 2 == 1){
-                chosenAction = greedyPolicy.chooseAction(actionValues);
-            }else{
-                chosenAction = policy.chooseAction(actionValues);
-            }
+            A chosenAction = policy.chooseAction(actionValues);
             envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
@@ -96,12 +84,9 @@
             }
             timestamp++;
             dispatchStepEnd();
-            if(converged) return;
         }
-        if(currentEpisode % 2 == 1){
-            return;
-        }
+        // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);

         Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
index 69919fa..4adfef3 100644
--- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
+++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
@@ -5,7 +5,15 @@ import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
 import core.policy.GreedyPolicy;
 import core.policy.Policy;
+import evironment.antGame.Reward;
+import example.ContinuousAnt;
+import example.DinoSampling;

+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
 import java.util.Map;

@@ -37,25 +45,43 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
         sumOfRewards = 0;
+        int timestampTilFood = 0;  // steps since the last successful food drop-off
+        double rewardsPer1000 = 0; // reward accumulated since the last drop-off
+        int foodCollected = 0;
         while(envResult == null || !envResult.isDone()) {
             actionValues = stateActionTable.getActionValues(state);
-            A action;
-            if(currentEpisode % 2 == 0){
-                action = greedyPolicy.chooseAction(actionValues);
-            }else{
-                action = policy.chooseAction(actionValues);
-            }
-            if(converged) return;
+            A action = policy.chooseAction(actionValues);

             // Take a step
             envResult = environment.step(action);
             double reward = envResult.getReward();
             State nextState = envResult.getState();
             sumOfRewards += reward;
-            if(currentEpisode % 2 == 0){
-                state = nextState;
-                dispatchStepEnd();
-                continue;
+
+            rewardsPer1000 += reward;
+            timestampTilFood++;
+
+            if(foodCollected == 10000){
+                File file = new File(ContinuousAnt.FILE_NAME);
+                try {
+                    Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
+                return;
             }
+            if(reward == Reward.FOOD_DROP_DOWN_SUCCESS){
+                foodCollected++;
+                File file = new File(ContinuousAnt.FILE_NAME);
+                try {
+                    Files.writeString(Path.of(file.getPath()), timestampTilFood + ",", StandardOpenOption.APPEND);
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
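+                // reset the counters so the next sample measures steps and reward for the following food item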
+                timestampTilFood = 0;
+                rewardsPer1000 = 0;
+            }
+
             // Q update: Q(S,A) <- Q(S,A) + alpha * (R + gamma * maxQ(S',a) - Q(S,A))
             double currentQValue = stateActionTable.getActionValues(state).get(action);
             // maxQ(S', a);
diff --git a/src/main/java/core/algo/td/SARSA.java b/src/main/java/core/algo/td/SARSA.java
index 336516d..b64d59e 100644
--- a/src/main/java/core/algo/td/SARSA.java
+++ b/src/main/java/core/algo/td/SARSA.java
@@ -11,7 +11,6 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {

     private float alpha;
-    private Policy<A> greedyPolicy = new GreedyPolicy<>();

     public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
         super(environment, actionSpace, discountFactor, delay);
@@ -35,18 +34,13 @@
         StepResultEnvironment envResult = null;
         Map<A, Double> actionValues = stateActionTable.getActionValues(state);
-        A action;
-        if(currentEpisode % 2 == 1){
-            action = greedyPolicy.chooseAction(actionValues);
-        }else{
-            action = policy.chooseAction(actionValues);
-        }
+        A action = policy.chooseAction(actionValues);

         sumOfRewards = 0;
         while(envResult == null || !envResult.isDone()) {
-            if(converged) return;
             // Take a step
             envResult = environment.step(action);
             sumOfRewards += envResult.getReward();
@@ -56,19 +50,8 @@
             // Pick next action
             actionValues = stateActionTable.getActionValues(nextState);
-            A nextAction;
-            if(currentEpisode % 2 == 1){
-                nextAction = greedyPolicy.chooseAction(actionValues);
-            }else{
-                nextAction = policy.chooseAction(actionValues);
-            }
-            //A nextAction = policy.chooseAction(actionValues);
-            if(currentEpisode % 2 == 1){
-                state = nextState;
-                action = nextAction;
-                dispatchStepEnd();
-                continue;
-            }
+            A nextAction = policy.chooseAction(actionValues);

             // td update
             // target = reward + gamma * Q(nextState, nextAction)
             double currentQValue = stateActionTable.getActionValues(state).get(action);
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index dd23e78..0393b93 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -29,7 +29,7 @@ public class Grid {
                 initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
             }
         }
-        start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
+        start = new Point(RNG.getRandomEnv().nextInt(width), RNG.getRandomEnv().nextInt(height));
         initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
         spawnNewFood(initialGrid);
         spawnObstacles();
@@ -58,8 +58,8 @@ public class Grid {
         Point potFood = new Point(0, 0);
         CellType potFieldType;
         while(!foodSpawned) {
-            potFood.x = RNG.getRandom().nextInt(width);
-            potFood.y = RNG.getRandom().nextInt(height);
+            potFood.x = RNG.getRandomEnv().nextInt(width);
+            potFood.y = RNG.getRandomEnv().nextInt(height);
             potFieldType = grid[potFood.x][potFood.y].getType();
             if(potFieldType != CellType.START && grid[potFood.x][potFood.y].getFood() == 0 && potFieldType != CellType.OBSTACLE) {
                 grid[potFood.x][potFood.y].setFood(1);
diff --git a/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java b/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java
index a2ae420..7c35579 100644
--- a/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java
+++ b/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java
@@ -31,7 +31,7 @@ public class DinoWorldAdvanced extends DinoWorld{
     protected void spawnNewObstacle() {
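+        // obstacle speed and spawn position are now drawn from the environment RNG stream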
         int dx;
         int xSpawn;
-        double ran = RNG.getRandom().nextDouble();
+        double ran = RNG.getRandomEnv().nextDouble();
         if(ran < 0.25){
             dx = -(int) (0.35 * Config.OBSTACLE_SPEED);
         }else if(ran < 0.5){
@@ -41,7 +41,7 @@
         } else{
             dx = -(int) (3.5 * Config.OBSTACLE_SPEED);
         }
-        double ran2 = RNG.getRandom().nextDouble();
+        double ran2 = RNG.getRandomEnv().nextDouble();
         if(ran2 < 0.25) {
             // randomly spawn farther to the right, outside of the screen
             xSpawn = Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
diff --git a/src/main/java/example/ContinuousAnt.java b/src/main/java/example/ContinuousAnt.java
index 0a13e70..9005956 100644
--- a/src/main/java/example/ContinuousAnt.java
+++ b/src/main/java/example/ContinuousAnt.java
@@ -8,19 +8,28 @@ import evironment.antGame.AntAction;
 import evironment.antGame.AntWorldContinuous;
 import evironment.antGame.AntWorldContinuousOriginalState;

+import java.io.File;
+import java.io.IOException;
+
 public class ContinuousAnt {
+    public static final String FILE_NAME = "converge05.txt"; // one line per run: comma-separated step counts per collected food
     public static void main(String[] args) {
+        File file = new File(FILE_NAME);
+        try {
+            file.createNewFile();
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
         RNG.setSeed(56);
-        RLController<AntAction> rl = new RLControllerGUI<>(
-                new AntWorldContinuousOriginalState(8, 8),
+        RLController<AntAction> rl = new RLController<>(
+                new AntWorldContinuous(8, 8),
                 Method.Q_LEARNING_OFF_POLICY_CONTROL,
                 AntAction.values());
-
-        rl.setDelay(200);
-        rl.setNrOfEpisodes(10000);
-        rl.setDiscountFactor(0.95f);
-        rl.setEpsilon(0.15f);
-
+        rl.setDelay(0);
+        rl.setNrOfEpisodes(1);
+        rl.setDiscountFactor(0.7f);
+        rl.setLearningRate(0.2f);
+        rl.setEpsilon(0.5f);
         rl.start();
     }
 }