removed unnecessary stuff from sampling branches

Jan Löwenstrom 2020-04-05 13:37:38 +02:00
parent 0300f3b1fd
commit bbccef1e71
7 changed files with 10 additions and 85 deletions

EpisodicLearning.java

@@ -16,12 +16,12 @@ import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
private volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
private int episodeSumCurrentSecond;
@Setter
protected int currentEpisode = 0;
protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter
protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond;
protected double sumOfRewards;
protected List<StepResult<A>> episode = new ArrayList<>();
@@ -84,7 +84,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
timestampCurrentEpisode++;
}
@Override
@@ -95,9 +94,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void startLearning(){
dispatchLearningStart();
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
timestampCurrentEpisode = 0;
nextEpisode();
dispatchEpisodeEnd();
}
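For context, the loop that survives here drains a thread-safe episode budget: the volatile AtomicInteger lets another thread (e.g. the GUI controller) request more episodes while learning runs. A minimal self-contained sketch of the pattern, assuming the counter is decremented once per episode (the decrement itself is not visible in this hunk):

import java.util.concurrent.atomic.AtomicInteger;

class EpisodeBudgetSketch {
    // Thread-safe episode budget, mirroring the episodesToLearn field above.
    private final AtomicInteger episodesToLearn = new AtomicInteger(0);

    // Called from a controller/GUI thread to request more episodes.
    void requestEpisodes(int n) {
        episodesToLearn.addAndGet(n);
    }

    // Learning-thread loop: run episodes until the budget is drained.
    void startLearning(Runnable nextEpisode) {
        while (episodesToLearn.get() > 0) {
            nextEpisode.run();
            episodesToLearn.decrementAndGet(); // assumed; not shown in the hunk
        }
    }
}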

Learning.java

@@ -23,10 +23,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
*/
@Getter
public abstract class Learning<A extends Enum>{
// TODO: temp testing -> extract to dedicated test
protected int checkSum;
protected int rewardCheckSum;
// current discrete timestamp t
protected int timestamp;
protected int currentEpisode;

MonteCarloControlEGreedy.java

@@ -60,7 +60,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
sumOfRewards += envResult.getReward();
rewardCheckSum += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
state = nextState;
@@ -74,8 +73,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
dispatchStepEnd();
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
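The hunk is cut off where first-visit Monte Carlo control starts collecting returns per (state, action) pair. For reference, discounted returns are computed in one backward sweep over the episode via G_t = r_t + gamma * G_{t+1}; a minimal sketch with illustrative names (only the episode/reward structure is taken from the diff):

import java.util.List;

class FirstVisitReturnSketch {
    // Backward pass: g = reward_t + gamma * g, so every timestep's
    // discounted return falls out of a single O(n) sweep.
    static double[] discountedReturns(List<Double> rewards, double gamma) {
        double[] returns = new double[rewards.size()];
        double g = 0.0;
        for (int t = rewards.size() - 1; t >= 0; t--) {
            g = rewards.get(t) + gamma * g;
            returns[t] = g;
        }
        return returns;
    }
}

First-visit control would then average, for each (state, action) pair, only the return at that pair's first occurrence in the episode, which is presumably what the LinkedHashMap above collects indices for.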

QLearningOffPolicyTDControl.java

@@ -5,18 +5,12 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import evironment.antGame.Reward;
import example.ContinuousAnt;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map;
public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
private float alpha;
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
@@ -42,11 +36,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = null;
sumOfRewards = 0;
int timestampTilFood = 0;
int foodCollected = 0;
int foodTimestampsTotal= 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
@@ -56,44 +46,6 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
timestampTilFood++;
if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
foodCollected++;
foodTimestampsTotal += timestampTilFood;
File file = new File(ContinuousAnt.FILE_NAME);
if(foodCollected % 1000 == 0) {
System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode);
try {
Files.writeString(Path.of(file.getPath()), foodTimestampsTotal / 1000f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
foodTimestampsTotal = 0;
}
if(foodCollected == 1000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.15f);
}
if(foodCollected == 2000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.10f);
}
if(foodCollected == 3000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.05f);
}
if(foodCollected == 4000){
System.out.println("Reached 0 exploration");
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.00f);
}
if(foodCollected == 15000){
try {
Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
return;
}
timestampTilFood = 0;
}
// Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action);
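The hunk is truncated just as the Q update begins. For reference, the off-policy TD(0) backup implied by the class name is Q(s,a) <- Q(s,a) + alpha * (r + gamma * max_a' Q(s',a') - Q(s,a)); a minimal sketch with illustrative names (currentQValue, alpha, and the discount factor correspond to identifiers visible above; the rest is an assumption):

import java.util.Collection;

class QLearningBackupSketch {
    // Off-policy TD(0): back up the greedy value of the next state,
    // independent of the (epsilon-greedy) action actually taken next.
    static double backup(double currentQ, double reward,
                         Collection<Double> nextActionValues,
                         double alpha, double gamma) {
        double maxNextQ = nextActionValues.stream()
                .mapToDouble(Double::doubleValue)
                .max()
                .orElse(0.0); // terminal next state: no action values
        return currentQ + alpha * (reward + gamma * maxNextQ - currentQ);
    }
}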

SARSA.java

@@ -3,8 +3,6 @@ package core.algo.td;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import java.util.Map;
@@ -35,10 +33,8 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
//A action = policy.chooseAction(actionValues);
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
// Take a step
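For contrast with the Q-learning backup above: SARSA is on-policy, so its update backs up Q(s',a') for the action a' that the epsilon-greedy policy actually chooses, rather than the greedy maximum. A minimal sketch, identifiers illustrative:

class SarsaBackupSketch {
    // On-policy TD(0): exploration feeds back into the value estimates,
    // because nextQ belongs to the action the behaviour policy selected.
    static double backup(double currentQ, double reward, double nextQ,
                         double alpha, double gamma) {
        return currentQ + alpha * (reward + gamma * nextQ - currentQ);
    }
}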

ContinuousAnt.java

@@ -7,31 +7,18 @@ import core.controller.RLControllerGUI;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
import java.io.File;
import java.io.IOException;
public class ContinuousAnt {
public static final String FILE_NAME = "converge.txt";
public static void main(String[] args) {
File file = new File(FILE_NAME);
try {
file.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
RNG.setSeed(13, true);
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorldContinuous(8, 8),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());
rl.setDelay(20);
rl.setDelay(200);
rl.setNrOfEpisodes(1);
// 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
rl.setDiscountFactor(0.05f);
// 0.1, 0.3, 0.5, 0.7 0.9
rl.setDiscountFactor(0.3f);
rl.setLearningRate(0.9f);
rl.setEpsilon(0.2f);
rl.setEpsilon(0.15f);
rl.start();
}
}

JumpingDino.java

@@ -3,20 +3,20 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced;
public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(29);
RLController<DinoAction> rl = new RLController<>(
RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorldAdvanced(),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(0);
rl.setDelay(200);
rl.setDiscountFactor(9f);
rl.setEpsilon(0.05f);
rl.setLearningRate(0.8f);