diff --git a/epsilonValues.R b/epsilonValues.R
index ed356c1..ff0bd53 100644
--- a/epsilonValues.R
+++ b/epsilonValues.R
@@ -1,7 +1,28 @@
# Libraries
library(ggplot2)
library(matrixStats)
-# file.choose()
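+# Load the recorded results (CSV chosen interactively) and transpose so each column is one series,
+# apparently one series per epsilon setting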
+ta <- as.matrix(read.table(file.choose(), sep=",", header = FALSE))
+ta <- t(ta)
+dim(ta)
+head(ta)
+
+# Assemble the loaded series into a data frame for plotting
+data <- data.frame(
+ y=ta[,1],
+ y2=ta[,2],
+ y3=ta[,3],
+ y4=ta[,4],
+ x=seq_len(nrow(ta))
+)
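+# Overlay the four series on a log-scaled x axis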
+ggplot(data, aes(x)) +
+ geom_line(aes(y = y, colour = "var0")) +
+ geom_line(aes(y = y2, colour = "var1")) +
+ geom_line(aes(y = y3, colour = "var2")) +
+ geom_line(aes(y = y4, colour = "var3")) +
+ scale_x_log10( breaks=c(1,5,10,15,20,50,100,200), limits=c(1,200) )
+
+# Quick base-graphics check: each column of ta against its index (scaled by 1000) on a log x axis
+matplot(x=data$x*1000, y=ta, log="x", type="o")
+
convergence <- read.csv(file.choose(), header=FALSE, row.names=1)
sds <- rowSds(sapply(convergence[,-1], `length<-`, max(lengths(convergence[,-1]))), na.rm=TRUE)
diff --git a/src/main/java/core/RNG.java b/src/main/java/core/RNG.java
index b0708af..959ae59 100644
--- a/src/main/java/core/RNG.java
+++ b/src/main/java/core/RNG.java
@@ -15,18 +15,26 @@ import java.util.Random;
*/
public class RNG {
private static Random rng;
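+ // Dedicated RNG stream for environment randomness, kept separate from the agent's RNG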
+ private static Random rngEnv;
private static int seed = 123;
static {
rng = new Random();
rng.setSeed(seed);
+ rngEnv = new Random();
+ rngEnv.setSeed(seed);
}
public static Random getRandom() {
return rng;
}
+ public static Random getRandomEnv() {
+ return rngEnv;
+ }
public static void setSeed(int seed){
RNG.seed = seed;
rng.setSeed(seed);
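+ // Recreate the environment stream as well, so both streams restart from the same seed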
+ rngEnv = new Random();
+ rngEnv.setSeed(seed);
}
}
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index a14cba0..d7b6655 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -105,8 +105,7 @@ public abstract class EpisodicLearning extends Learning imple
timestamp++;
timestampCurrentEpisode++;
// TODO: more sophisticated way to check convergence
- if(timestampCurrentEpisode > 30000000){
- converged = true;
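+ // Timestep-based convergence shortcut disabled; learning is no longer terminated from here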
+ if(false){
// t
File file = new File(DinoSampling.FILE_NAME);
try {
@@ -114,9 +113,7 @@ public abstract class EpisodicLearning extends Learning imple
} catch (IOException e) {
e.printStackTrace();
}
- System.out.println("converged after: " + currentEpisode/2 + " episode!");
- episodesToLearn.set(0);
- dispatchLearningEnd();
+ // System.out.println("converged after: " + currentEpisode/2 + " episode!");
}
}
diff --git a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
index 5bc93d3..ff3247b 100644
--- a/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
+++ b/src/main/java/core/algo/mc/MonteCarloControlFirstVisitEGreedy.java
@@ -40,16 +40,9 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
- // t
- private float epsilon;
- // t
- private Policy<A> greedyPolicy = new GreedyPolicy<>();
-
public MonteCarloControlFirstVisitEGreedy(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
- // t
- this.epsilon = epsilon;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
@@ -74,12 +67,7 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
while(envResult == null || !envResult.isDone()) {
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
- A chosenAction;
- if(currentEpisode % 2 == 1){
- chosenAction = greedyPolicy.chooseAction(actionValues);
- }else{
- chosenAction = policy.chooseAction(actionValues);
- }
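+ // Greedy/exploring episode alternation removed: every step now samples from the epsilon-greedy policy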
+ A chosenAction = policy.chooseAction(actionValues);
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
@@ -96,12 +84,9 @@ public class MonteCarloControlFirstVisitEGreedy extends Episodic
}
timestamp++;
dispatchStepEnd();
- if(converged) return;
}
- if(currentEpisode % 2 == 1){
- return;
- }
+
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
diff --git a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
index 69919fa..4adfef3 100644
--- a/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
+++ b/src/main/java/core/algo/td/QLearningOffPolicyTDControl.java
@@ -5,7 +5,15 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
+import evironment.antGame.Reward;
+import example.ContinuousAnt;
+import example.DinoSampling;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
import java.util.Map;
public class QLearningOffPolicyTDControl extends EpisodicLearning {
@@ -37,25 +45,43 @@ public class QLearningOffPolicyTDControl extends EpisodicLearnin
sumOfRewards = 0;
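+ // Bookkeeping for the continuous ant run: steps since the last food drop-off, reward tally, and total food delivered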
+ int timestampTilFood = 0;
+ int rewardsPer1000 = 0;
+ int foodCollected = 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
- A action;
- if(currentEpisode % 2 == 0){
- action = greedyPolicy.chooseAction(actionValues);
- }else{
- action = policy.chooseAction(actionValues);
- }
- if(converged) return;
+ A action = policy.chooseAction(actionValues);
+
// Take a step
envResult = environment.step(action);
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
- if(currentEpisode % 2 == 0){
- state = nextState;
- dispatchStepEnd();
- continue;
+
+ rewardsPer1000+=reward;
+ timestampTilFood++;
+
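+ // After 10000 deliveries, end the current log line and terminate the run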
+ if(foodCollected == 10000){
+ File file = new File(ContinuousAnt.FILE_NAME);
+ try {
+ Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return;
}
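+ // On each successful drop-off, append the step count for this delivery to the log, then reset the counters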
+ if(reward == Reward.FOOD_DROP_DOWN_SUCCESS){
+ foodCollected++;
+ File file = new File(ContinuousAnt.FILE_NAME);
+ try {
+ Files.writeString(Path.of(file.getPath()), timestampTilFood + ",", StandardOpenOption.APPEND);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ timestampTilFood = 0;
+ rewardsPer1000 = 0;
+ }
+
// Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action);
// maxQ(S', a);
diff --git a/src/main/java/core/algo/td/SARSA.java b/src/main/java/core/algo/td/SARSA.java
index 336516d..b64d59e 100644
--- a/src/main/java/core/algo/td/SARSA.java
+++ b/src/main/java/core/algo/td/SARSA.java
@@ -11,7 +11,6 @@ import java.util.Map;
public class SARSA extends EpisodicLearning {
private float alpha;
- private Policy<A> greedyPolicy = new GreedyPolicy<>();
public SARSA(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
super(environment, actionSpace, discountFactor, delay);
@@ -35,18 +34,13 @@ public class SARSA extends EpisodicLearning {
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
- A action;
- if(currentEpisode % 2 == 1){
- action = greedyPolicy.chooseAction(actionValues);
- }else{
- action = policy.chooseAction(actionValues);
- }
+ A action = policy.chooseAction(actionValues);
+
//A action = policy.chooseAction(actionValues);
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
- if(converged) return;
// Take a step
envResult = environment.step(action);
sumOfRewards += envResult.getReward();
@@ -56,19 +50,8 @@ public class SARSA extends EpisodicLearning {
// Pick next action
actionValues = stateActionTable.getActionValues(nextState);
- A nextAction;
- if(currentEpisode % 2 == 1){
- nextAction = greedyPolicy.chooseAction(actionValues);
- }else{
- nextAction = policy.chooseAction(actionValues);
- }
- //A nextAction = policy.chooseAction(actionValues);
- if(currentEpisode % 2 == 1){
- state = nextState;
- action = nextAction;
- dispatchStepEnd();
- continue;
- }
+ A nextAction = policy.chooseAction(actionValues);
+
// td update
// target = reward + gamma * Q(nextState, nextAction)
double currentQValue = stateActionTable.getActionValues(state).get(action);
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index dd23e78..0393b93 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -29,7 +29,7 @@ public class Grid {
initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
}
}
- start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
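+ // Map layout (start and food positions) now draws from the separate environment RNG stream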
+ start = new Point(RNG.getRandomEnv().nextInt(width), RNG.getRandomEnv().nextInt(height));
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
spawnNewFood(initialGrid);
spawnObstacles();
@@ -58,8 +58,8 @@ public class Grid {
Point potFood = new Point(0, 0);
CellType potFieldType;
while(!foodSpawned) {
- potFood.x = RNG.getRandom().nextInt(width);
- potFood.y = RNG.getRandom().nextInt(height);
+ potFood.x = RNG.getRandomEnv().nextInt(width);
+ potFood.y = RNG.getRandomEnv().nextInt(height);
potFieldType = grid[potFood.x][potFood.y].getType();
if(potFieldType != CellType.START && grid[potFood.x][potFood.y].getFood() == 0 && potFieldType != CellType.OBSTACLE) {
grid[potFood.x][potFood.y].setFood(1);
diff --git a/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java b/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java
index a2ae420..7c35579 100644
--- a/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java
+++ b/src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java
@@ -31,7 +31,7 @@ public class DinoWorldAdvanced extends DinoWorld{
protected void spawnNewObstacle() {
int dx;
int xSpawn;
- double ran = RNG.getRandom().nextDouble();
+ double ran = RNG.getRandomEnv().nextDouble();
if(ran < 0.25){
dx = -(int) (0.35 * Config.OBSTACLE_SPEED);
}else if(ran < 0.5){
@@ -41,7 +41,7 @@ public class DinoWorldAdvanced extends DinoWorld{
} else{
dx = -(int) (3.5 * Config.OBSTACLE_SPEED);
}
- double ran2 = RNG.getRandom().nextDouble();
+ double ran2 = RNG.getRandomEnv().nextDouble();
if(ran2 < 0.25) {
// randomly spawning more right outside of the screen
xSpawn = Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
diff --git a/src/main/java/example/ContinuousAnt.java b/src/main/java/example/ContinuousAnt.java
index 0a13e70..9005956 100644
--- a/src/main/java/example/ContinuousAnt.java
+++ b/src/main/java/example/ContinuousAnt.java
@@ -8,19 +8,28 @@ import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
import evironment.antGame.AntWorldContinuousOriginalState;
+import java.io.File;
+import java.io.IOException;
+
public class ContinuousAnt {
+ public static final String FILE_NAME = "converge05.txt";
public static void main(String[] args) {
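+ // Create the log file that the Q-learning run appends its food-delivery step counts to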
+ File file = new File(FILE_NAME);
+ try {
+ file.createNewFile();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
RNG.setSeed(56);
- RLController<AntAction> rl = new RLControllerGUI<>(
- new AntWorldContinuousOriginalState(8, 8),
+ RLController<AntAction> rl = new RLController<>(
+ new AntWorldContinuous(8, 8),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());
-
- rl.setDelay(200);
- rl.setNrOfEpisodes(10000);
- rl.setDiscountFactor(0.95f);
- rl.setEpsilon(0.15f);
-
+ rl.setDelay(0);
+ rl.setNrOfEpisodes(1);
+ rl.setDiscountFactor(0.7f);
+ rl.setLearningRate(0.2f);
+ rl.setEpsilon(0.5f);
rl.start();
}
}