removed unnecessary stuff from sampling branches

This commit is contained in:
Jan Löwenstrom 2020-04-05 13:37:38 +02:00
parent 0300f3b1fd
commit bbccef1e71
7 changed files with 10 additions and 85 deletions

View File

@ -16,12 +16,12 @@ import java.util.List;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic { public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
private volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
private int episodeSumCurrentSecond;
@Setter @Setter
protected int currentEpisode = 0; protected int currentEpisode = 0;
protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter @Getter
protected volatile int episodePerSecond; protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond;
protected double sumOfRewards; protected double sumOfRewards;
protected List<StepResult<A>> episode = new ArrayList<>(); protected List<StepResult<A>> episode = new ArrayList<>();
@ -84,7 +84,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchStepEnd() { protected void dispatchStepEnd() {
super.dispatchStepEnd(); super.dispatchStepEnd();
timestamp++; timestamp++;
timestampCurrentEpisode++;
} }
@Override @Override
@ -95,9 +94,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void startLearning(){ private void startLearning(){
dispatchLearningStart(); dispatchLearningStart();
while(episodesToLearn.get() > 0){ while(episodesToLearn.get() > 0){
dispatchEpisodeStart(); dispatchEpisodeStart();
timestampCurrentEpisode = 0;
nextEpisode(); nextEpisode();
dispatchEpisodeEnd(); dispatchEpisodeEnd();
} }

View File

@ -23,10 +23,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
*/ */
@Getter @Getter
public abstract class Learning<A extends Enum>{ public abstract class Learning<A extends Enum>{
// TODO: temp testing -> extract to dedicated test
protected int checkSum;
protected int rewardCheckSum;
// current discrete timestamp t // current discrete timestamp t
protected int timestamp; protected int timestamp;
protected int currentEpisode; protected int currentEpisode;

View File

@ -60,7 +60,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
envResult = environment.step(chosenAction); envResult = environment.step(chosenAction);
State nextState = envResult.getState(); State nextState = envResult.getState();
sumOfRewards += envResult.getReward(); sumOfRewards += envResult.getReward();
rewardCheckSum += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
state = nextState; state = nextState;
@ -74,8 +73,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
dispatchStepEnd(); dispatchStepEnd();
} }
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards); // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>(); HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();

View File

@ -5,18 +5,12 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy; import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy; import core.policy.GreedyPolicy;
import core.policy.Policy; import core.policy.Policy;
import evironment.antGame.Reward;
import example.ContinuousAnt;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map; import java.util.Map;
public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> { public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
private float alpha; private float alpha;
private Policy<A> greedyPolicy = new GreedyPolicy<>(); private Policy<A> greedyPolicy = new GreedyPolicy<>();
public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) { public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
@ -42,11 +36,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
StepResultEnvironment envResult = null; StepResultEnvironment envResult = null;
Map<A, Double> actionValues = null; Map<A, Double> actionValues = null;
sumOfRewards = 0; sumOfRewards = 0;
int timestampTilFood = 0;
int foodCollected = 0;
int foodTimestampsTotal= 0;
while(envResult == null || !envResult.isDone()) { while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state); actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues); A action = policy.chooseAction(actionValues);
@ -56,44 +46,6 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
double reward = envResult.getReward(); double reward = envResult.getReward();
State nextState = envResult.getState(); State nextState = envResult.getState();
sumOfRewards += reward; sumOfRewards += reward;
timestampTilFood++;
if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
foodCollected++;
foodTimestampsTotal += timestampTilFood;
File file = new File(ContinuousAnt.FILE_NAME);
if(foodCollected % 1000 == 0) {
System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode);
try {
Files.writeString(Path.of(file.getPath()), foodTimestampsTotal / 1000f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
foodTimestampsTotal = 0;
}
if(foodCollected == 1000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.15f);
}
if(foodCollected == 2000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.10f);
}
if(foodCollected == 3000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.05f);
}
if(foodCollected == 4000){
System.out.println("Reached 0 exploration");
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.00f);
}
if(foodCollected == 15000){
try {
Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
return;
}
timestampTilFood = 0;
}
// Q Update // Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action); double currentQValue = stateActionTable.getActionValues(state).get(action);

View File

@ -3,8 +3,6 @@ package core.algo.td;
import core.*; import core.*;
import core.algo.EpisodicLearning; import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy; import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import java.util.Map; import java.util.Map;
@ -35,10 +33,8 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
StepResultEnvironment envResult = null; StepResultEnvironment envResult = null;
Map<A, Double> actionValues = stateActionTable.getActionValues(state); Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues); A action = policy.chooseAction(actionValues);
//A action = policy.chooseAction(actionValues);
sumOfRewards = 0; sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) { while(envResult == null || !envResult.isDone()) {
// Take a step // Take a step

View File

@ -7,31 +7,18 @@ import core.controller.RLControllerGUI;
import evironment.antGame.AntAction; import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous; import evironment.antGame.AntWorldContinuous;
import java.io.File;
import java.io.IOException;
public class ContinuousAnt { public class ContinuousAnt {
public static final String FILE_NAME = "converge.txt";
public static void main(String[] args) { public static void main(String[] args) {
File file = new File(FILE_NAME);
try {
file.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
RNG.setSeed(13, true); RNG.setSeed(13, true);
RLController<AntAction> rl = new RLControllerGUI<>( RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorldContinuous(8, 8), new AntWorldContinuous(8, 8),
Method.Q_LEARNING_OFF_POLICY_CONTROL, Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values()); AntAction.values());
rl.setDelay(20); rl.setDelay(200);
rl.setNrOfEpisodes(1); rl.setNrOfEpisodes(1);
// 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99 rl.setDiscountFactor(0.3f);
rl.setDiscountFactor(0.05f);
// 0.1, 0.3, 0.5, 0.7 0.9
rl.setLearningRate(0.9f); rl.setLearningRate(0.9f);
rl.setEpsilon(0.2f); rl.setEpsilon(0.15f);
rl.start(); rl.start();
} }
} }

View File

@ -3,20 +3,20 @@ package example;
import core.RNG; import core.RNG;
import core.algo.Method; import core.algo.Method;
import core.controller.RLController; import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced; import evironment.jumpingDino.DinoWorldAdvanced;
public class JumpingDino { public class JumpingDino {
public static void main(String[] args) { public static void main(String[] args) {
RNG.setSeed(29); RNG.setSeed(29);
RLController<DinoAction> rl = new RLController<>( RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorldAdvanced(), new DinoWorldAdvanced(),
Method.MC_CONTROL_FIRST_VISIT, Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values()); DinoAction.values());
rl.setDelay(0); rl.setDelay(200);
rl.setDiscountFactor(9f); rl.setDiscountFactor(9f);
rl.setEpsilon(0.05f); rl.setEpsilon(0.05f);
rl.setLearningRate(0.8f); rl.setLearningRate(0.8f);