modify q Learning to sample results and update R script

This commit is contained in:
parent eca0d8db4d
commit 328fc85214
@@ -1,7 +1,28 @@
# Libraries
library(ggplot2)
library(matrixStats)
# file.choose()
ta <- as.matrix(read.table(file.choose(), sep=",", header = FALSE))
ta <- t(ta)
dim(ta)
head(ta)

# Create dummy data
data <- data.frame(
  y=ta[,1],
  y2=ta[,2],
  y3=ta[,3],
  y4=ta[,4],
  x=seq(1, length(ta))
)
ggplot(data, aes(x)) +
  geom_line(aes(y = y, colour = "var0")) +
  geom_line(aes(y = y2, colour = "var1")) +
  geom_line(aes(y = y3, colour = "var2")) +
  geom_line(aes(y = y4, colour = "var3")) +
  scale_x_log10( breaks=c(1,5,10,15,20,50,100,200), limits=c(1,200) )

plot(ta, x=x*1000, log="x", type="o")

convergence <- read.csv(file.choose(), header=FALSE, row.names=1)

sds <- rowSds(sapply(convergence[,-1], `length<-`, max(lengths(convergence[,-1]))), na.rm=TRUE)
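Two notes on the script above. First, x=seq(1, length(ta)) counts every cell of the matrix rather than its rows, so seq_len(nrow(ta)) is more likely the intended episode index, and plot(ta, x=x*1000, ...) relies on an x that only exists inside the data frame. Second, the rowSds() line uses the standard idiom for NA-padding runs of unequal length before taking the per-index standard deviation across runs. A minimal sketch of that idiom with hypothetical toy data (runs is not part of the commit):

    library(matrixStats)
    runs <- list(c(1, 2, 3), c(4, 5), c(6))                  # three runs, unequal lengths
    padded <- sapply(runs, `length<-`, max(lengths(runs)))   # NA-padded matrix, one column per run
    rowSds(padded, na.rm = TRUE)                             # SD across runs at each index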
@@ -15,18 +15,26 @@ import java.util.Random;
  */
 public class RNG {
     private static Random rng;
+    private static Random rngEnv;
     private static int seed = 123;
     static {
         rng = new Random();
         rng.setSeed(seed);
+        rngEnv = new Random();
+        rngEnv.setSeed(seed);
     }

     public static Random getRandom() {
         return rng;
     }
+    public static Random getRandomEnv() {
+        return rngEnv;
+    }

     public static void setSeed(int seed){
         RNG.seed = seed;
         rng.setSeed(seed);
+        rngEnv = new Random();
+        rngEnv.setSeed(seed);
     }
 }
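RNG now maintains two deterministic streams: rng for agent-side draws (epsilon-greedy exploration) and rngEnv for environment randomness (start cell, food and obstacle placement in the hunks below). With one shared stream, the number of draws the agent happens to make would shift every later environment draw, so greedy and exploring runs would see different worlds. A minimal usage sketch (the values are illustrative only):

    // Both streams start from the same seed but advance independently,
    // so agent draws never perturb the environment's random sequence.
    RNG.setSeed(42);
    int exploreRoll = RNG.getRandom().nextInt(100);    // agent draw
    int foodCell = RNG.getRandomEnv().nextInt(64);     // environment draw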
@@ -105,8 +105,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
         timestamp++;
         timestampCurrentEpisode++;
         // TODO: more sophisticated way to check convergence
         if(timestampCurrentEpisode > 30000000){
-            if(false){
-                // t
+            converged = true;
             File file = new File(DinoSampling.FILE_NAME);
             try {
@@ -114,9 +113,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
             } catch (IOException e) {
                 e.printStackTrace();
             }
-            System.out.println("converged after: " + currentEpisode/2 + " episode!");
             episodesToLearn.set(0);
             dispatchLearningEnd();
             // System.out.println("converged after: " + currentEpisode/2 + " episode!");
-            }
         }
@@ -40,16 +40,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
-
-    // t
     private float epsilon;
-    // t
-    private Policy<A> greedyPolicy = new GreedyPolicy<>();
-
-
     public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
-        // t
         this.epsilon = epsilon;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
@@ -74,12 +67,7 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic

         while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
-            A chosenAction;
-            if(currentEpisode % 2 == 1){
-                chosenAction = greedyPolicy.chooseAction(actionValues);
-            }else{
-                chosenAction = policy.chooseAction(actionValues);
-            }
+            A chosenAction = policy.chooseAction(actionValues);

             envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
@@ -96,12 +84,9 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
             }
             timestamp++;
             dispatchStepEnd();
+            if(converged) return;
         }

-        if(currentEpisode % 2 == 1){
-            return;
-        }
-

         // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
         Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
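The removed branches implemented an alternating evaluation scheme: every odd episode acted purely greedily as a measurement episode and skipped the return updates (hence the removed early return, and the currentEpisode/2 in the "converged after" message above), while even episodes explored with the epsilon-greedy policy. A sketch of the removed pattern, for reference (names as in the class above):

    // Removed alternating-evaluation scheme: odd episodes measured the
    // greedy policy, even episodes explored and updated.
    Policy<A> acting = (currentEpisode % 2 == 1) ? greedyPolicy : policy;
    A chosenAction = acting.chooseAction(actionValues);

After this commit every episode both explores and updates, and the new if(converged) return; bails out of the step loop once convergence is flagged.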
@@ -5,7 +5,15 @@ import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
 import core.policy.GreedyPolicy;
 import core.policy.Policy;
+import evironment.antGame.Reward;
+import example.ContinuousAnt;
+import example.DinoSampling;

+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
 import java.util.Map;

 public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
@@ -37,25 +45,43 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin


         sumOfRewards = 0;
+        int timestampTilFood = 0;
+        int rewardsPer1000 = 0;
+        int foodCollected = 0;
         while(envResult == null || !envResult.isDone()) {
             actionValues = stateActionTable.getActionValues(state);
-            A action;
-            if(currentEpisode % 2 == 0){
-                action = greedyPolicy.chooseAction(actionValues);
-            }else{
-                action = policy.chooseAction(actionValues);
-            }
+            if(converged) return;
+            A action = policy.chooseAction(actionValues);

             // Take a step
             envResult = environment.step(action);
             double reward = envResult.getReward();
             State nextState = envResult.getState();
             sumOfRewards += reward;
-            if(currentEpisode % 2 == 0){
-                state = nextState;
-                dispatchStepEnd();
-                continue;
-            }
+            rewardsPer1000 += reward;
+            timestampTilFood++;
+
+            if(foodCollected == 10000){
+                File file = new File(ContinuousAnt.FILE_NAME);
+                try {
+                    Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
+                return;
+            }
+            if(reward == Reward.FOOD_DROP_DOWN_SUCCESS){
+                foodCollected++;
+                File file = new File(ContinuousAnt.FILE_NAME);
+                try {
+                    Files.writeString(Path.of(file.getPath()), timestampTilFood + ",", StandardOpenOption.APPEND);
+                } catch (IOException e) {
+                    e.printStackTrace();
+                }
+                timestampTilFood = 0;
+                rewardsPer1000 = 0;
+            }

             // Q Update
             double currentQValue = stateActionTable.getActionValues(state).get(action);
             // maxQ(S', a);
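The new instrumentation appends one comma-separated value per collected food item, the number of steps since the previous FOOD_DROP_DOWN_SUCCESS, to ContinuousAnt.FILE_NAME, and ends the run with a newline after 10000 items; these are presumably the rows the R script above reads back with read.csv. Note that Files.writeString with StandardOpenOption requires Java 11+. A minimal sketch of the same append idiom in isolation (the literal values are illustrative):

    // Append one measurement, then terminate the row.
    // CREATE avoids a NoSuchFileException if the file was never created.
    Path log = Path.of("converge05.txt");
    Files.writeString(log, 123 + ",", StandardOpenOption.CREATE, StandardOpenOption.APPEND);
    Files.writeString(log, "\n", StandardOpenOption.APPEND);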
@@ -11,7 +11,6 @@ import java.util.Map;

 public class SARSA<A extends Enum> extends EpisodicLearning<A> {
     private float alpha;
-    private Policy<A> greedyPolicy = new GreedyPolicy<>();

     public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
         super(environment, actionSpace, discountFactor, delay);
@@ -35,18 +34,13 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {

         StepResultEnvironment envResult = null;
         Map<A, Double> actionValues = stateActionTable.getActionValues(state);
-        A action;
-        if(currentEpisode % 2 == 1){
-            action = greedyPolicy.chooseAction(actionValues);
-        }else{
-            action = policy.chooseAction(actionValues);
-        }
+        A action = policy.chooseAction(actionValues);

-        //A action = policy.chooseAction(actionValues);

         sumOfRewards = 0;
         while(envResult == null || !envResult.isDone()) {

+            if(converged) return;
             // Take a step
             envResult = environment.step(action);
             sumOfRewards += envResult.getReward();
@@ -56,19 +50,8 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
             // Pick next action
             actionValues = stateActionTable.getActionValues(nextState);

-            A nextAction;
-            if(currentEpisode % 2 == 1){
-                nextAction = greedyPolicy.chooseAction(actionValues);
-            }else{
-                nextAction = policy.chooseAction(actionValues);
-            }
-            //A nextAction = policy.chooseAction(actionValues);
-            if(currentEpisode % 2 == 1){
-                state = nextState;
-                action = nextAction;
-                dispatchStepEnd();
-                continue;
-            }
+            A nextAction = policy.chooseAction(actionValues);

             // td update
             // target = reward + gamma * Q(nextState, nextAction)
             double currentQValue = stateActionTable.getActionValues(state).get(action);
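For reference, the td update sketched in the comments above is the standard on-policy SARSA rule, with learning rate alpha and discount factor gamma:

    // Q(s,a) <- Q(s,a) + alpha * (reward + gamma * Q(s',a') - Q(s,a))

Unlike Q-learning's max over next actions, the target uses the next action a' the policy actually chose, which after this commit is always the epsilon-greedy one.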
@@ -29,7 +29,7 @@ public class Grid {
                 initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
             }
         }
-        start = new Point(RNG.getRandom().nextInt(width), RNG.getRandom().nextInt(height));
+        start = new Point(RNG.getRandomEnv().nextInt(width), RNG.getRandomEnv().nextInt(height));
         initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
         spawnNewFood(initialGrid);
         spawnObstacles();
@@ -58,8 +58,8 @@ public class Grid {
         Point potFood = new Point(0, 0);
         CellType potFieldType;
         while(!foodSpawned) {
-            potFood.x = RNG.getRandom().nextInt(width);
-            potFood.y = RNG.getRandom().nextInt(height);
+            potFood.x = RNG.getRandomEnv().nextInt(width);
+            potFood.y = RNG.getRandomEnv().nextInt(height);
             potFieldType = grid[potFood.x][potFood.y].getType();
             if(potFieldType != CellType.START && grid[potFood.x][potFood.y].getFood() == 0 && potFieldType != CellType.OBSTACLE) {
                 grid[potFood.x][potFood.y].setFood(1);
@@ -31,7 +31,7 @@ public class DinoWorldAdvanced extends DinoWorld{
     protected void spawnNewObstacle() {
         int dx;
         int xSpawn;
-        double ran = RNG.getRandom().nextDouble();
+        double ran = RNG.getRandomEnv().nextDouble();
         if(ran < 0.25){
             dx = -(int) (0.35 * Config.OBSTACLE_SPEED);
         }else if(ran < 0.5){
@@ -41,7 +41,7 @@ public class DinoWorldAdvanced extends DinoWorld{
         } else{
             dx = -(int) (3.5 * Config.OBSTACLE_SPEED);
         }
-        double ran2 = RNG.getRandom().nextDouble();
+        double ran2 = RNG.getRandomEnv().nextDouble();
         if(ran2 < 0.25) {
             // randomly spawning more right outside of the screen
             xSpawn = Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
@@ -8,19 +8,28 @@ import evironment.antGame.AntAction;
 import evironment.antGame.AntWorldContinuous;
 import evironment.antGame.AntWorldContinuousOriginalState;

+import java.io.File;
+import java.io.IOException;

 public class ContinuousAnt {
+    public static final String FILE_NAME = "converge05.txt";
     public static void main(String[] args) {
+        File file = new File(FILE_NAME);
+        try {
+            file.createNewFile();
+        } catch (IOException e) {
+            e.printStackTrace();
+        }
         RNG.setSeed(56);
-        RLController<AntAction> rl = new RLControllerGUI<>(
-                new AntWorldContinuousOriginalState(8, 8),
+        RLController<AntAction> rl = new RLController<>(
+                new AntWorldContinuous(8, 8),
                 Method.Q_LEARNING_OFF_POLICY_CONTROL,
                 AntAction.values());

-        rl.setDelay(200);
-        rl.setNrOfEpisodes(10000);
-        rl.setDiscountFactor(0.95f);
-        rl.setEpsilon(0.15f);
-
+        rl.setDelay(0);
+        rl.setNrOfEpisodes(1);
+        rl.setDiscountFactor(0.7f);
+        rl.setLearningRate(0.2f);
+        rl.setEpsilon(0.5f);
         rl.start();
     }
 }
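The example switches from the Swing-based RLControllerGUI on AntWorldContinuousOriginalState to a headless RLController on AntWorldContinuous: no delay, a single long episode, discount 0.7, learning rate 0.2, epsilon 0.5. In other words, a batch sampling run that fills converge05.txt for the R script above rather than an interactive session. A plausible follow-up (an assumption, not part of the commit) is to load that file in R:

    convergence <- read.csv(file.choose(), header=FALSE, row.names=1)  # pick converge05.txt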