Modify Q-learning to sample results and update R script

Jan Löwenstrom 2020-03-28 12:35:33 +01:00
parent eca0d8db4d
commit 328fc85214
9 changed files with 97 additions and 68 deletions

View File

@@ -1,7 +1,28 @@
 # Libraries
 library(ggplot2)
 library(matrixStats)
-# file.choose()
+ta <- as.matrix(read.table(file.choose(), sep=",", header = FALSE))
+ta <- t(ta)
+dim(ta)
+head(ta)
+# Create dummy data
+data <- data.frame(
+  y=ta[,1],
+  y2=ta[,2],
+  y3=ta[,3],
+  y4=ta[,4],
+  x=seq(1, length(ta))
+)
+ggplot(data, aes(x)) +
+  geom_line(aes(y = y, colour = "var0")) +
+  geom_line(aes(y = y2, colour = "var1")) +
+  geom_line(aes(y = y3, colour = "var2")) +
+  geom_line(aes(y = y4, colour = "var3")) +
+  scale_x_log10( breaks=c(1,5,10,15,20,50,100,200), limits=c(1,200) )
+plot(ta, x=x*1000, log="x", type="o")
 convergence <- read.csv(file.choose(), header=FALSE, row.names=1)
 sds <- rowSds(sapply(convergence[,-1], `length<-`, max(lengths(convergence[,-1]))), na.rm=TRUE)

View File

@@ -15,18 +15,26 @@ import java.util.Random;
  */
 public class RNG {
     private static Random rng;
+    private static Random rngEnv;
     private static int seed = 123;

     static {
         rng = new Random();
         rng.setSeed(seed);
+        rngEnv = new Random();
+        rngEnv.setSeed(seed);
     }

     public static Random getRandom() {
         return rng;
     }

+    public static Random getRandomEnv() {
+        return rngEnv;
+    }
+
     public static void setSeed(int seed){
         RNG.seed = seed;
         rng.setSeed(seed);
+        rngEnv = new Random();
+        rngEnv.setSeed(seed);
     }
 }
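
Note on the change above: with a single shared Random, every exploration draw made by the policy advances the same sequence the environment uses to place the start cell, food, and obstacles, so two runs with the same seed but different policies can see different worlds. Splitting the generator into rng (agent) and rngEnv (environment) keeps the environment's random stream reproducible. A minimal usage sketch under that assumption; the class and variable names below are illustrative only and not part of the commit:

    import java.util.Random;

    public class RngSplitSketch {
        public static void main(String[] args) {
            RNG.setSeed(56);
            Random agentRng = RNG.getRandom();    // exploration draws made by the policy
            Random envRng = RNG.getRandomEnv();   // environment layout draws

            // Drawing from the agent stream does not advance the environment stream,
            // so the layout values below stay identical across runs with the same seed.
            int exploratoryDraw = agentRng.nextInt(4);
            int startX = envRng.nextInt(8);
            int startY = envRng.nextInt(8);
            System.out.println(exploratoryDraw + " -> start at " + startX + "," + startY);
        }
    }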

View File

@@ -105,8 +105,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
         timestamp++;
         timestampCurrentEpisode++;
         // TODO: more sophisticated way to check convergence
-        if(timestampCurrentEpisode > 30000000){
-            converged = true;
+        if(false){
             // t
             File file = new File(DinoSampling.FILE_NAME);
             try {
@@ -114,9 +113,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
             } catch (IOException e) {
                 e.printStackTrace();
             }
-            System.out.println("converged after: " + currentEpisode/2 + " episode!");
-            episodesToLearn.set(0);
-            dispatchLearningEnd();
+            // System.out.println("converged after: " + currentEpisode/2 + " episode!");
         }
     }

View File

@@ -40,16 +40,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;

-    // t
-    private float epsilon;
-    // t
-    private Policy<A> greedyPolicy = new GreedyPolicy<>();
-
     public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
-        // t
-        this.epsilon = epsilon;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
@@ -74,12 +67,7 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
         while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
-            A chosenAction;
-            if(currentEpisode % 2 == 1){
-                chosenAction = greedyPolicy.chooseAction(actionValues);
-            }else{
-                chosenAction = policy.chooseAction(actionValues);
-            }
+            A chosenAction = policy.chooseAction(actionValues);

             envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
@@ -96,12 +84,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
             }
             timestamp++;
             dispatchStepEnd();
-            if(converged) return;
         }

-        if(currentEpisode % 2 == 1){
-            return;
-        }
         // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
         Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();

View File

@@ -5,7 +5,15 @@ import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
 import core.policy.GreedyPolicy;
 import core.policy.Policy;
+import evironment.antGame.Reward;
+import example.ContinuousAnt;
+import example.DinoSampling;

+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
 import java.util.Map;

 public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
@@ -37,25 +45,43 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
         sumOfRewards = 0;
+        int timestampTilFood = 0;
+        int rewardsPer1000 = 0;
+        int foodCollected = 0;
         while(envResult == null || !envResult.isDone()) {
             actionValues = stateActionTable.getActionValues(state);
-            A action;
-            if(currentEpisode % 2 == 0){
-                action = greedyPolicy.chooseAction(actionValues);
-            }else{
-                action = policy.chooseAction(actionValues);
-            }
-            if(converged) return;
+            A action = policy.chooseAction(actionValues);

             // Take a step
             envResult = environment.step(action);
             double reward = envResult.getReward();
             State nextState = envResult.getState();
             sumOfRewards += reward;
-            if(currentEpisode % 2 == 0){
-                state = nextState;
-                dispatchStepEnd();
-                continue;
-            }
+            rewardsPer1000 += reward;
+            timestampTilFood++;
+
+            if(foodCollected == 10000){
+                File file = new File(ContinuousAnt.FILE_NAME);
+                try {
+                    Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
+                return;
+            }
+            if(reward == Reward.FOOD_DROP_DOWN_SUCCESS){
+                foodCollected++;
+                File file = new File(ContinuousAnt.FILE_NAME);
+                try {
+                    Files.writeString(Path.of(file.getPath()), timestampTilFood + ",", StandardOpenOption.APPEND);
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
+                timestampTilFood = 0;
+                rewardsPer1000 = 0;
+            }

             // Q Update
             double currentQValue = stateActionTable.getActionValues(state).get(action);
             // maxQ(S', a);
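
The sampling logic added above records, for each successful food drop-off, the number of steps taken since the previous one, appending "<steps>," to ContinuousAnt.FILE_NAME, and ends the run with a newline once 10000 items have been collected; the R script at the top of this commit then reads those comma-separated rows. The same try/catch append appears twice, so it could be factored into a small helper. A hypothetical sketch, not part of the commit:

    import java.io.IOException;
    import java.nio.file.Files;
    import java.nio.file.Path;
    import java.nio.file.StandardOpenOption;

    // Hypothetical helper: appends one token to the results file consumed by the R script.
    final class SampleWriter {
        static void append(String fileName, String token) {
            try {
                Files.writeString(Path.of(fileName), token, StandardOpenOption.APPEND);
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

    // Usage inside the training loop (sketch):
    //   SampleWriter.append(ContinuousAnt.FILE_NAME, timestampTilFood + ",");  // per collected food item
    //   SampleWriter.append(ContinuousAnt.FILE_NAME, "\n");                    // row end after 10000 items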

View File

@@ -11,7 +11,6 @@ import java.util.Map;
 public class SARSA<A extends Enum> extends EpisodicLearning<A> {

     private float alpha;
-    private Policy<A> greedyPolicy = new GreedyPolicy<>();

     public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
         super(environment, actionSpace, discountFactor, delay);
@@ -35,18 +34,13 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
         StepResultEnvironment envResult = null;
         Map<A, Double> actionValues = stateActionTable.getActionValues(state);
-        A action;
-        if(currentEpisode % 2 == 1){
-            action = greedyPolicy.chooseAction(actionValues);
-        }else{
-            action = policy.chooseAction(actionValues);
-        }
+        A action = policy.chooseAction(actionValues);
         //A action = policy.chooseAction(actionValues);
         sumOfRewards = 0;

         while(envResult == null || !envResult.isDone()) {
-            if(converged) return;
             // Take a step
             envResult = environment.step(action);
             sumOfRewards += envResult.getReward();
@@ -56,19 +50,8 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {

             // Pick next action
             actionValues = stateActionTable.getActionValues(nextState);
-            A nextAction;
-            if(currentEpisode % 2 == 1){
-                nextAction = greedyPolicy.chooseAction(actionValues);
-            }else{
-                nextAction = policy.chooseAction(actionValues);
-            }
-            //A nextAction = policy.chooseAction(actionValues);
-
-            if(currentEpisode % 2 == 1){
-                state = nextState;
-                action = nextAction;
-                dispatchStepEnd();
-                continue;
-            }
+            A nextAction = policy.chooseAction(actionValues);

             // td update
             // target = reward + gamma * Q(nextState, nextAction)
             double currentQValue = stateActionTable.getActionValues(state).get(action);

View File

@@ -29,7 +29,7 @@ public class Grid {
                 initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
             }
         }
-        start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
+        start = new Point(RNG.getRandomEnv().nextInt(width), RNG.getRandomEnv().nextInt(height));
         initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
         spawnNewFood(initialGrid);
         spawnObstacles();
@@ -58,8 +58,8 @@ public class Grid {
         Point potFood = new Point(0, 0);
         CellType potFieldType;
         while(!foodSpawned) {
-            potFood.x = RNG.getRandom().nextInt(width);
-            potFood.y = RNG.getRandom().nextInt(height);
+            potFood.x = RNG.getRandomEnv().nextInt(width);
+            potFood.y = RNG.getRandomEnv().nextInt(height);
             potFieldType = grid[potFood.x][potFood.y].getType();
             if(potFieldType != CellType.START && grid[potFood.x][potFood.y].getFood() == 0 && potFieldType != CellType.OBSTACLE) {
                 grid[potFood.x][potFood.y].setFood(1);

View File

@@ -31,7 +31,7 @@ public class DinoWorldAdvanced extends DinoWorld{
     protected void spawnNewObstacle() {
         int dx;
         int xSpawn;
-        double ran = RNG.getRandom().nextDouble();
+        double ran = RNG.getRandomEnv().nextDouble();
         if(ran < 0.25){
             dx = -(int) (0.35 * Config.OBSTACLE_SPEED);
         }else if(ran < 0.5){
@@ -41,7 +41,7 @@ public class DinoWorldAdvanced extends DinoWorld{
         } else{
             dx = -(int) (3.5 * Config.OBSTACLE_SPEED);
         }
-        double ran2 = RNG.getRandom().nextDouble();
+        double ran2 = RNG.getRandomEnv().nextDouble();
         if(ran2 < 0.25) {
             // randomly spawning more right outside of the screen
             xSpawn = Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;

View File

@@ -8,19 +8,28 @@ import evironment.antGame.AntAction
 import evironment.antGame.AntWorldContinuous;
 import evironment.antGame.AntWorldContinuousOriginalState;

+import java.io.File;
+import java.io.IOException;
+
 public class ContinuousAnt {
+    public static final String FILE_NAME = "converge05.txt";
+
     public static void main(String[] args) {
+        File file = new File(FILE_NAME);
+        try {
+            file.createNewFile();
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
         RNG.setSeed(56);
-        RLController<AntAction> rl = new RLControllerGUI<>(
-                new AntWorldContinuousOriginalState(8, 8),
+        RLController<AntAction> rl = new RLController<>(
+                new AntWorldContinuous(8, 8),
                 Method.Q_LEARNING_OFF_POLICY_CONTROL,
                 AntAction.values());
-        rl.setDelay(200);
-        rl.setNrOfEpisodes(10000);
-        rl.setDiscountFactor(0.95f);
-        rl.setEpsilon(0.15f);
+        rl.setDelay(0);
+        rl.setNrOfEpisodes(1);
+        rl.setDiscountFactor(0.7f);
+        rl.setLearningRate(0.2f);
+        rl.setEpsilon(0.5f);
         rl.start();
     }
 }