Modify Q-learning to sample results and update R script

This commit is contained in:
Jan Löwenstrom 2020-03-28 12:35:33 +01:00
parent eca0d8db4d
commit 328fc85214
9 changed files with 97 additions and 68 deletions

View File

@@ -1,7 +1,28 @@
# Libraries
library(ggplot2)
library(matrixStats)
# file.choose()
# First input: sampled results, read as a numeric matrix
ta <- as.matrix(read.table(file.choose(), sep=",", header = FALSE))
# Transpose so that each recorded series ends up in its own column (plotted as var0..var3 below)
ta <- t(ta)
dim(ta)
head(ta)
# Collect the four series into one data frame for ggplot
data <- data.frame(
y=ta[,1],
y2=ta[,2],
y3=ta[,3],
y4=ta[,4],
x=seq(1, nrow(ta)) # one index per row; length(ta) would count every cell of the matrix
)
ggplot(data, aes(x)) +
geom_line(aes(y = y, colour = "var0")) +
geom_line(aes(y = y2, colour = "var1")) +
geom_line(aes(y = y3, colour = "var2")) +
geom_line(aes(y = y4, colour = "var3")) +
scale_x_log10( breaks=c(1,5,10,15,20,50,100,200), limits=c(1,200) )
# Quick base-R check of the first series on a log x-axis (indices scaled by 1000)
plot(data$x * 1000, data$y, log="x", type="o")
# Second input: one comma-separated row of samples per run (written by QLearningOffPolicyTDControl below)
convergence <- read.csv(file.choose(), header=FALSE, row.names=1)
# Pad the entries to equal length with NA and take row-wise standard deviations (rowSds from matrixStats)
sds <- rowSds(sapply(convergence[,-1], `length<-`, max(lengths(convergence[,-1]))), na.rm=TRUE)

View File

@@ -15,18 +15,26 @@ import java.util.Random;
*/
public class RNG {
private static Random rng; // stream used by the learning agent (e.g. exploration)
private static Random rngEnv; // separate stream used by the environment (world generation)
private static int seed = 123;
static {
rng = new Random();
rng.setSeed(seed);
rngEnv = new Random();
rngEnv.setSeed(seed);
}
public static Random getRandom() {
return rng;
}
public static Random getRandomEnv() {
return rngEnv;
}
public static void setSeed(int seed){
RNG.seed = seed;
rng.setSeed(seed);
rngEnv = new Random();
rngEnv.setSeed(seed);
}
}
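Splitting the class into two generators keeps the agent's exploration draws and the environment's world-generation draws on independent streams, so a run stays reproducible even when the agent consumes a different number of random values. A minimal usage sketch, assuming the convention introduced in this commit (agent code on getRandom(), environment code on getRandomEnv()); the snippet itself is illustrative and not part of the commit:
RNG.setSeed(56);
// environment stream: the world layout is identical across runs with the same seed
int startX = RNG.getRandomEnv().nextInt(8);
int startY = RNG.getRandomEnv().nextInt(8);
// agent stream: exploration can vary without shifting the world layout above
boolean explore = RNG.getRandom().nextDouble() < 0.15;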

View File

@@ -105,8 +105,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
timestamp++;
timestampCurrentEpisode++;
// TODO: more sophisticated way to check convergence
if(timestampCurrentEpisode > 30000000){
converged = true;
if(false){
// t
File file = new File(DinoSampling.FILE_NAME);
try {
@@ -114,9 +113,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("converged after: " + currentEpisode/2 + " episode!");
episodesToLearn.set(0);
dispatchLearningEnd();
// System.out.println("converged after: " + currentEpisode/2 + " episode!");
}
}
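The hard 30,000,000-timestep cutoff is switched off here via if(false), and the TODO above still asks for a better convergence test. A hedged sketch of one option, not part of this commit: declare convergence once the per-episode reward has stopped moving over a sliding window. All names below (recentRewards, WINDOW, TOLERANCE, hasConverged) are hypothetical additions to EpisodicLearning; the sketch needs java.util.ArrayDeque, java.util.Collections and java.util.Deque.
private final Deque<Double> recentRewards = new ArrayDeque<>();
private static final int WINDOW = 50;          // episodes to look back over
private static final double TOLERANCE = 1e-3;  // allowed spread within the window

private boolean hasConverged(double episodeReward) {
    recentRewards.addLast(episodeReward);
    if (recentRewards.size() > WINDOW) {
        recentRewards.removeFirst();
    }
    if (recentRewards.size() < WINDOW) {
        return false; // not enough episodes observed yet
    }
    // converged once the reward curve has flattened out within the window
    return Collections.max(recentRewards) - Collections.min(recentRewards) < TOLERANCE;
}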

View File

@@ -40,16 +40,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
// t
private float epsilon;
// t
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
// t
this.epsilon = epsilon;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
@@ -74,12 +67,7 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
while(envResult == null || !envResult.isDone()) {
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction;
if(currentEpisode % 2 == 1){
chosenAction = greedyPolicy.chooseAction(actionValues);
}else{
chosenAction = policy.chooseAction(actionValues);
}
A chosenAction = policy.chooseAction(actionValues);
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
@@ -96,12 +84,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
}
timestamp++;
dispatchStepEnd();
if(converged) return;
}
if(currentEpisode % 2 == 1){
return;
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
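For context, the returnSum and returnCount maps declared at the top of this class presumably implement plain sample averaging of first-visit returns: Q(s,a) is the mean of every return G observed on a first visit of (s,a). A self-contained sketch of that bookkeeping, with a String key standing in for Pair<State, A> (class and variable names here are illustrative, not this repository's API):
import java.util.HashMap;
import java.util.Map;

public class FirstVisitAverageDemo {
    public static void main(String[] args) {
        Map<String, Double> returnSum = new HashMap<>();
        Map<String, Integer> returnCount = new HashMap<>();
        String sa = "someState#MOVE_LEFT";            // stand-in for Pair<State, A>
        double[] firstVisitReturns = {2.0, 4.0, 3.0}; // return G from three episodes
        for (double G : firstVisitReturns) {
            returnSum.merge(sa, G, Double::sum);
            returnCount.merge(sa, 1, Integer::sum);
        }
        // Q(s,a) = average first-visit return = (2 + 4 + 3) / 3 = 3.0
        double qValue = returnSum.get(sa) / returnCount.get(sa);
        System.out.println(qValue);
    }
}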

View File

@@ -5,7 +5,15 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import evironment.antGame.Reward;
import example.ContinuousAnt;
import example.DinoSampling;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map;
public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
@@ -37,25 +45,43 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
sumOfRewards = 0;
int timestampTilFood = 0;
int rewardsPer1000 = 0;
int foodCollected = 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
A action;
if(currentEpisode % 2 == 0){
action = greedyPolicy.chooseAction(actionValues);
}else{
action = policy.chooseAction(actionValues);
}
if(converged) return;
A action = policy.chooseAction(actionValues);
// Take a step
envResult = environment.step(action);
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
if(currentEpisode % 2 == 0){
state = nextState;
dispatchStepEnd();
continue;
rewardsPer1000+=reward;
timestampTilFood++;
if(foodCollected == 10000){ // stop sampling after 10,000 food deliveries and end this sample row with a newline
File file = new File(ContinuousAnt.FILE_NAME);
try {
Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
return;
}
if(reward == Reward.FOOD_DROP_DOWN_SUCCESS){ // food delivered: record the timesteps needed since the last delivery
foodCollected++;
File file = new File(ContinuousAnt.FILE_NAME);
try {
Files.writeString(Path.of(file.getPath()), timestampTilFood + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
timestampTilFood = 0;
rewardsPer1000 = 0;
}
// Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action);
// maxQ(S', a);
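The hunk cuts off right at the "// maxQ(S', a)" comment; what follows is presumably the standard off-policy TD(0) (Q-learning) backup, Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)). A minimal scalar sketch, using the hyperparameters set later in ContinuousAnt (alpha = 0.2, gamma = 0.7) purely as example numbers; none of these names are fields of this class:
double alpha = 0.2, gamma = 0.7;
double currentQ = 0.5;                                 // Q(s, a) before the update
double reward = 1.0;                                   // r from environment.step(action)
double maxNextQ = Math.max(0.8, Math.max(0.3, 0.1));   // max over Q(s', a') = 0.8
double target = reward + gamma * maxNextQ;             // 1.0 + 0.7 * 0.8 = 1.56
double updatedQ = currentQ + alpha * (target - currentQ); // 0.5 + 0.2 * 1.06 = 0.712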

View File

@@ -11,7 +11,6 @@ import java.util.Map;
public class SARSA<A extends Enum> extends EpisodicLearning<A> {
private float alpha;
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
super(environment, actionSpace, discountFactor, delay);
@@ -35,18 +34,13 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A action;
if(currentEpisode % 2 == 1){
action = greedyPolicy.chooseAction(actionValues);
}else{
action = policy.chooseAction(actionValues);
}
A action = policy.chooseAction(actionValues);
//A action = policy.chooseAction(actionValues);
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
if(converged) return;
// Take a step
envResult = environment.step(action);
sumOfRewards += envResult.getReward();
@@ -56,19 +50,8 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
// Pick next action
actionValues = stateActionTable.getActionValues(nextState);
A nextAction;
if(currentEpisode % 2 == 1){
nextAction = greedyPolicy.chooseAction(actionValues);
}else{
nextAction = policy.chooseAction(actionValues);
}
//A nextAction = policy.chooseAction(actionValues);
if(currentEpisode % 2 == 1){
state = nextState;
action = nextAction;
dispatchStepEnd();
continue;
}
A nextAction = policy.chooseAction(actionValues);
// td update
// target = reward + gamma * Q(nextState, nextAction)
double currentQValue = stateActionTable.getActionValues(state).get(action);
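For comparison with the Q-learning file above: the "// target = reward + gamma * Q(nextState, nextAction)" comment is the on-policy SARSA backup, which bootstraps from the action actually chosen next instead of the greedy maximum. A scalar sketch with illustrative numbers (alpha and gamma are stand-ins, not fields of this class):
double alpha = 0.2, gamma = 0.7;
double currentQ = 0.5;                    // Q(state, action)
double reward = 1.0;
double nextQ = 0.3;                       // Q(nextState, nextAction) for the epsilon-greedy pick
double target = reward + gamma * nextQ;   // 1.0 + 0.7 * 0.3 = 1.21
double updatedQ = currentQ + alpha * (target - currentQ); // 0.5 + 0.2 * 0.71 = 0.642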

View File

@@ -29,7 +29,7 @@ public class Grid {
initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
}
}
start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
start = new Point(RNG.getRandomEnv().nextInt(width), RNG.getRandomEnv().nextInt(height));
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
spawnNewFood(initialGrid);
spawnObstacles();
@@ -58,8 +58,8 @@ public class Grid {
Point potFood = new Point(0, 0);
CellType potFieldType;
while(!foodSpawned) {
potFood.x = RNG.getRandom().nextInt(width);
potFood.y = RNG.getRandom().nextInt(height);
potFood.x = RNG.getRandomEnv().nextInt(width);
potFood.y = RNG.getRandomEnv().nextInt(height);
potFieldType = grid[potFood.x][potFood.y].getType();
if(potFieldType != CellType.START && grid[potFood.x][potFood.y].getFood() == 0 && potFieldType != CellType.OBSTACLE) {
grid[potFood.x][potFood.y].setFood(1);

View File

@@ -31,7 +31,7 @@ public class DinoWorldAdvanced extends DinoWorld{
protected void spawnNewObstacle() {
int dx;
int xSpawn;
double ran = RNG.getRandom().nextDouble();
double ran = RNG.getRandomEnv().nextDouble();
if(ran < 0.25){
dx = -(int) (0.35 * Config.OBSTACLE_SPEED);
}else if(ran < 0.5){
@@ -41,7 +41,7 @@ public class DinoWorldAdvanced extends DinoWorld{
} else{
dx = -(int) (3.5 * Config.OBSTACLE_SPEED);
}
double ran2 = RNG.getRandom().nextDouble();
double ran2 = RNG.getRandomEnv().nextDouble();
if(ran2 < 0.25) {
// randomly spawning more right outside of the screen
xSpawn = Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;

View File

@@ -8,19 +8,28 @@ import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
import evironment.antGame.AntWorldContinuousOriginalState;
import java.io.File;
import java.io.IOException;
public class ContinuousAnt {
public static final String FILE_NAME = "converge05.txt";
public static void main(String[] args) {
File file = new File(FILE_NAME);
try {
file.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
RNG.setSeed(56);
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorldContinuousOriginalState(8, 8),
RLController<AntAction> rl = new RLController<>(
new AntWorldContinuous(8, 8),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());
rl.setDelay(200);
rl.setNrOfEpisodes(10000);
rl.setDiscountFactor(0.95f);
rl.setEpsilon(0.15f);
rl.setDelay(0);
rl.setNrOfEpisodes(1);
rl.setDiscountFactor(0.7f);
rl.setLearningRate(0.2f);
rl.setEpsilon(0.5f);
rl.start();
}
}
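The R script at the top of this commit reads several comma-separated sample rows, and each run of this main method contributes one row to converge05.txt (QLearningOffPolicyTDControl appends the newline after 10,000 food deliveries). One hypothetical way to produce multiple rows, not part of this commit, is to sweep over seeds; whether several RLController runs can share one JVM, and whether start() blocks until the run finishes, are assumptions here:
// Hypothetical seed sweep: one sample row per run, same imports as the main method above.
int[] seeds = {56, 57, 58, 59};
for (int seed : seeds) {
    RNG.setSeed(seed);
    RLController<AntAction> run = new RLController<>(
            new AntWorldContinuous(8, 8),
            Method.Q_LEARNING_OFF_POLICY_CONTROL,
            AntAction.values());
    run.setDelay(0);
    run.setNrOfEpisodes(1);
    run.setDiscountFactor(0.7f);
    run.setLearningRate(0.2f);
    run.setEpsilon(0.5f);
    run.start();
}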