removed unnecessary stuff from sampling branches
This commit is contained in:
parent 0300f3b1fd
commit bbccef1e71
@@ -16,12 +16,12 @@ import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;

public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
    private volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
    private int episodeSumCurrentSecond;
    @Setter
    protected int currentEpisode = 0;
    protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
    @Getter
    protected volatile int episodePerSecond;
    protected int episodeSumCurrentSecond;
    protected double sumOfRewards;
    protected List<StepResult<A>> episode = new ArrayList<>();
@@ -84,7 +84,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
    protected void dispatchStepEnd() {
        super.dispatchStepEnd();
        timestamp++;
        timestampCurrentEpisode++;
    }

    @Override
@@ -95,9 +94,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
    private void startLearning(){
        dispatchLearningStart();
        while(episodesToLearn.get() > 0){
            dispatchEpisodeStart();
            timestampCurrentEpisode = 0;
            nextEpisode();
            dispatchEpisodeEnd();
        }
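For reference, the loop above is the driver shared by all episodic learners: it keeps sampling episodes while the AtomicInteger counter is positive. Below is a minimal sketch of how the counter and the loop could fit together, assuming a hypothetical learnMoreEpisodes(int) entry point and that one episode is consumed per iteration; neither detail is visible in this hunk.

    // Sketch only: learnMoreEpisodes and the decrement are assumptions, not part of the diff.
    public void learnMoreEpisodes(int nrOfEpisodes) {
        episodesToLearn.addAndGet(nrOfEpisodes);    // request more episodes; startLearning() consumes them
    }

    private void startLearning() {
        dispatchLearningStart();
        while(episodesToLearn.get() > 0) {
            dispatchEpisodeStart();
            timestampCurrentEpisode = 0;            // per-episode step counter, reset at every episode start
            nextEpisode();                          // subclass-specific sampling (MC, SARSA, Q-learning, ...)
            episodesToLearn.decrementAndGet();      // assumption: one episode consumed per loop pass
            dispatchEpisodeEnd();
        }
    }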
@@ -23,10 +23,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
 */
@Getter
public abstract class Learning<A extends Enum>{
    // TODO: temp testing -> extract to dedicated test
    protected int checkSum;
    protected int rewardCheckSum;

    // current discrete timestamp t
    protected int timestamp;
    protected int currentEpisode;
@@ -60,7 +60,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
            envResult = environment.step(chosenAction);
            State nextState = envResult.getState();
            sumOfRewards += envResult.getReward();
            rewardCheckSum += envResult.getReward();
            episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));

            state = nextState;
@@ -74,8 +73,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
            dispatchStepEnd();
        }

        // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
        HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();
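The LinkedHashMap above collects, for every (state, action) pair, the timestep indices at which it occurs in the recorded episode; first-visit Monte Carlo control then backs up the return from each pair's first occurrence. A sketch of that backup under stated assumptions (Pair exposing getKey()/getValue(), a setValue method on stateActionTable, and a learningRate step in place of plain averaging; none of this is shown in the hunk):

        // Sketch of the first-visit Monte Carlo backup that typically follows.
        // Assumptions: Pair has getKey()/getValue(), stateActionTable has setValue(...),
        // and discountFactor / learningRate fields exist on the learner.
        for(Map.Entry<Pair<State, A>, List<Integer>> entry : stateActionPairs.entrySet()) {
            State s = entry.getKey().getKey();
            A a = entry.getKey().getValue();
            int firstVisit = entry.getValue().get(0);          // index of the first visit in this episode

            double returnG = 0;                                // discounted return following the first visit
            for(int t = episode.size() - 1; t >= firstVisit; t--) {
                returnG = episode.get(t).getReward() + discountFactor * returnG;
            }

            double oldQ = stateActionTable.getActionValues(s).get(a);
            stateActionTable.setValue(s, a, oldQ + learningRate * (returnG - oldQ));
        }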
@@ -5,18 +5,12 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import evironment.antGame.Reward;
import example.ContinuousAnt;

import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map;

public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
    private float alpha;

    private Policy<A> greedyPolicy = new GreedyPolicy<>();

    public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
@@ -42,11 +36,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
        StepResultEnvironment envResult = null;
        Map<A, Double> actionValues = null;

        sumOfRewards = 0;
        int timestampTilFood = 0;
        int foodCollected = 0;
        int foodTimestampsTotal = 0;
        while(envResult == null || !envResult.isDone()) {
            actionValues = stateActionTable.getActionValues(state);
            A action = policy.chooseAction(actionValues);
@@ -56,44 +46,6 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
            double reward = envResult.getReward();
            State nextState = envResult.getState();
            sumOfRewards += reward;
            timestampTilFood++;

            if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
                foodCollected++;
                foodTimestampsTotal += timestampTilFood;
                File file = new File(ContinuousAnt.FILE_NAME);
                if(foodCollected % 1000 == 0) {
                    System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode);
                    try {
                        Files.writeString(Path.of(file.getPath()), foodTimestampsTotal / 1000f + ",", StandardOpenOption.APPEND);
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                    foodTimestampsTotal = 0;
                }
                if(foodCollected == 1000){
                    ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.15f);
                }
                if(foodCollected == 2000){
                    ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.10f);
                }
                if(foodCollected == 3000){
                    ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.05f);
                }
                if(foodCollected == 4000){
                    System.out.println("Reached 0 exploration");
                    ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.00f);
                }
                if(foodCollected == 15000){
                    try {
                        Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                    return;
                }
                timestampTilFood = 0;
            }

            // Q Update
            double currentQValue = stateActionTable.getActionValues(state).get(action);
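The hunk ends right at the "// Q Update" marker; the off-policy TD(0) target that follows is not part of the diff. A sketch of that update under stated assumptions (alpha acting as the learning rate, a discountFactor field, a setValue method on stateActionTable, and greedyPolicy supplying the argmax action for the next state):

            // Sketch of the Q-learning (off-policy TD control) update referenced by "// Q Update".
            // Assumptions: discountFactor field and stateActionTable.setValue(...) exist.
            double currentQValue = stateActionTable.getActionValues(state).get(action);

            Map<A, Double> nextActionValues = stateActionTable.getActionValues(nextState);
            A greedyNextAction = greedyPolicy.chooseAction(nextActionValues);   // bootstrap from the greedy action
            double tdTarget = reward + discountFactor * nextActionValues.get(greedyNextAction);

            double updatedQValue = currentQValue + alpha * (tdTarget - currentQValue);
            stateActionTable.setValue(state, action, updatedQValue);

            state = nextState;
            dispatchStepEnd();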
@@ -3,8 +3,6 @@ package core.algo.td;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;

import java.util.Map;
@@ -35,10 +33,8 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
        StepResultEnvironment envResult = null;
        Map<A, Double> actionValues = stateActionTable.getActionValues(state);
        A action = policy.chooseAction(actionValues);

        //A action = policy.chooseAction(actionValues);

        sumOfRewards = 0;

        while(envResult == null || !envResult.isDone()) {

            // Take a step
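For contrast with the Q-learning hunk above: SARSA is on-policy, so the bootstrap uses the action actually chosen by the behaviour policy in the next state. A sketch of the loop body this setup leads into, with the same hedges as before (learningRate, discountFactor and setValue are assumed names, not shown in the diff):

            // Sketch of the on-policy SARSA(0) step; names not visible in the diff are assumptions.
            envResult = environment.step(action);
            double reward = envResult.getReward();
            State nextState = envResult.getState();
            sumOfRewards += reward;

            // On-policy: the next action comes from the same (e.g. epsilon-greedy) policy being learned.
            A nextAction = policy.chooseAction(stateActionTable.getActionValues(nextState));

            double currentQ = stateActionTable.getActionValues(state).get(action);
            double nextQ = stateActionTable.getActionValues(nextState).get(nextAction);
            stateActionTable.setValue(state, action,
                    currentQ + learningRate * (reward + discountFactor * nextQ - currentQ));

            state = nextState;
            action = nextAction;
            dispatchStepEnd();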
@@ -7,31 +7,18 @@ import core.controller.RLControllerGUI;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;

import java.io.File;
import java.io.IOException;

public class ContinuousAnt {
    public static final String FILE_NAME = "converge.txt";

    public static void main(String[] args) {
        File file = new File(FILE_NAME);
        try {
            file.createNewFile();
        } catch (IOException e) {
            e.printStackTrace();
        }
        RNG.setSeed(13, true);
        RLController<AntAction> rl = new RLControllerGUI<>(
                new AntWorldContinuous(8, 8),
                Method.Q_LEARNING_OFF_POLICY_CONTROL,
                AntAction.values());
        rl.setDelay(20);
        rl.setDelay(200);
        rl.setNrOfEpisodes(1);
        // 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
        rl.setDiscountFactor(0.05f);
        // 0.1, 0.3, 0.5, 0.7 0.9
        rl.setDiscountFactor(0.3f);
        rl.setLearningRate(0.9f);
        rl.setEpsilon(0.2f);
        rl.setEpsilon(0.15f);
        rl.start();
    }
}
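rl.setEpsilon(0.15f) above configures the exploration rate of the EpsilonGreedyPolicy used during training, the same policy whose epsilon is lowered stepwise in the Q-learning hunk. A minimal sketch of what such a policy does when asked for an action, assuming RNG exposes a java.util.Random-style generator (the real implementation in core.policy is not part of this diff):

    // Sketch of an epsilon-greedy action choice; RNG.getRandom() is an assumed accessor.
    public A chooseAction(Map<A, Double> actionValues) {
        if(RNG.getRandom().nextDouble() < epsilon) {
            // explore: uniform random action
            List<A> actions = new ArrayList<>(actionValues.keySet());
            return actions.get(RNG.getRandom().nextInt(actions.size()));
        }
        // exploit: action with the highest estimated value
        return actionValues.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .get()
                .getKey();
    }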
@@ -3,20 +3,20 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced;

public class JumpingDino {
    public static void main(String[] args) {
        RNG.setSeed(29);

        RLController<DinoAction> rl = new RLController<>(
        RLController<DinoAction> rl = new RLControllerGUI<>(
                new DinoWorldAdvanced(),
                Method.MC_CONTROL_FIRST_VISIT,
                DinoAction.values());

        rl.setDelay(0);
        rl.setDelay(200);
        rl.setDiscountFactor(9f);
        rl.setEpsilon(0.05f);
        rl.setLearningRate(0.8f);