Rename MC class and improve the analysis of the antGame examples

Jan Löwenstrom 2020-04-05 12:29:44 +02:00
parent 4402d70467
commit 5b82e7965d
8 changed files with 28 additions and 111 deletions

View File

@@ -12,19 +12,19 @@ import java.io.ObjectOutputStream;
 import java.util.*;
 /**
- * Includes both variants of Monte-Carlo methods
+ * Includes both! variants of Monte-Carlo methods
  * Default method is First-Visit.
  * Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
  * @param <A>
  */
-public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
 private Map<Pair<State, A>, Double> returnSum;
 private Map<Pair<State, A>, Integer> returnCount;
 private boolean isEveryVisit;
-public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
+public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
 super(environment, actionSpace, discountFactor, delay);
 isEveryVisit = useEveryVisit;
 this.policy = new EpsilonGreedyPolicy<>(epsilon);
@@ -33,11 +33,11 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
 returnCount = new HashMap<>();
 }
-public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
 this(environment, actionSpace, discountFactor, epsilon, delay, false);
 }
-public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
 this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
 }
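
For orientation, here is a minimal, self-contained sketch of the first-visit vs. every-visit update this renamed class covers; the class and helper names below are illustrative, not the project's API:

    import java.util.*;

    /** Illustrative sketch only: one episode's Monte-Carlo return update,
     *  toggling between first-visit and every-visit like the flag above. */
    class MonteCarloUpdateSketch {
        record StateAction(String state, String action) {}

        static void updateReturns(List<StateAction> steps, List<Double> rewards,
                                  double discountFactor, boolean everyVisit,
                                  Map<StateAction, Double> returnSum,
                                  Map<StateAction, Integer> returnCount) {
            double g = 0.0; // discounted return, built up backwards
            // For first-visit, iterating backwards and overwriting leaves the
            // return of the earliest (first) occurrence of each (s,a).
            Map<StateAction, Double> firstVisitReturn = new HashMap<>();
            for (int t = steps.size() - 1; t >= 0; t--) {
                g = discountFactor * g + rewards.get(t);
                StateAction sa = steps.get(t);
                if (everyVisit) {
                    returnSum.merge(sa, g, Double::sum);    // every occurrence counts
                    returnCount.merge(sa, 1, Integer::sum);
                } else {
                    firstVisitReturn.put(sa, g);
                }
            }
            if (!everyVisit) {
                firstVisitReturn.forEach((sa, ret) -> {
                    returnSum.merge(sa, ret, Double::sum);
                    returnCount.merge(sa, 1, Integer::sum);
                });
            }
            // The Q(s,a) estimate is then returnSum(s,a) / returnCount(s,a).
        }
    }

Since the two variants share the same returnSum/returnCount bookkeeping seen in the diff, only the visit-counting rule changes, which is what makes a single class with a flag reasonable.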

View File

@@ -45,9 +45,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
 sumOfRewards = 0;
 int timestampTilFood = 0;
-int rewardsPer1000 = 0;
 int foodCollected = 0;
-int iterations = 0;
 int foodTimestampsTotal= 0;
 while(envResult == null || !envResult.isDone()) {
 actionValues = stateActionTable.getActionValues(state);
@@ -58,20 +56,8 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
 double reward = envResult.getReward();
 State nextState = envResult.getState();
 sumOfRewards += reward;
-rewardsPer1000+=reward;
 timestampTilFood++;
-/* if(iterations == 100){
-File file = new File(ContinuousAnt.FILE_NAME);
-try {
-Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
-} catch (IOException e) {
-e.printStackTrace();
-}
-return;
-}*/
 if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
 foodCollected++;
 foodTimestampsTotal += timestampTilFood;
@@ -95,7 +81,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
 ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.05f);
 }
 if(foodCollected == 4000){
-System.out.println("final 0 expl");
+System.out.println("Reached 0 exploration");
 ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.00f);
 }
 if(foodCollected == 15000){
@@ -106,7 +92,6 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
 }
 return;
 }
-iterations++;
 timestampTilFood = 0;
 }
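
The last two hunks step exploration down as food gets collected. A compact sketch of that schedule follows; only the 4000 -> 0.0 step is fully visible in the diff, so the earlier milestone and the starting value are assumed placeholders:

    // Illustrative staged epsilon schedule; milestone 1000 and start value 0.2f
    // are assumptions, only foodCollected == 4000 -> 0.0f appears in the diff.
    final class EpsilonScheduleSketch {
        static float epsilonFor(int foodCollected) {
            if (foodCollected >= 4000) return 0.00f; // "Reached 0 exploration"
            if (foodCollected >= 1000) return 0.05f; // assumed earlier milestone
            return 0.2f;                             // assumed starting epsilon
        }
    }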

View File

@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
+import core.algo.mc.MonteCarloControlEGreedy;
 import core.algo.td.QLearningOffPolicyTDControl;
 import core.algo.td.SARSA;
 import core.listener.LearningListener;
@@ -49,10 +49,10 @@ public class RLController<A extends Enum> implements LearningListener {
 public void start() {
 switch(method) {
 case MC_CONTROL_FIRST_VISIT:
-learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
 break;
 case MC_CONTROL_EVERY_VISIT:
-learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
 break;
 case SARSA_ON_POLICY_CONTROL:

View File

@@ -11,30 +11,28 @@ import java.io.File;
 import java.io.IOException;
 public class ContinuousAnt {
-public static final String FILE_NAME = "converge22.txt";
+public static final String FILE_NAME = "converge.txt";
 public static void main(String[] args) {
-int i = 4+4+4+6+6+6+8+10+12+14+14+16+16+16+18+18+18+20+20+20+22+22+22+24+24+24+24+26+26+26+26+26+28+28+28+28+28+30+30+30+30+32+32+32+34+34+34+36+36+38+40+42;
-System.out.println(i/52f);
 File file = new File(FILE_NAME);
 try {
 file.createNewFile();
 } catch (IOException e) {
 e.printStackTrace();
 }
-RNG.setSeed(13);
+RNG.setSeed(13);
 RLController<AntAction> rl = new RLControllerGUI<>(
-new AntWorldContinuous(8, 8),
-Method.Q_LEARNING_OFF_POLICY_CONTROL,
-AntAction.values());
+new AntWorldContinuous(8, 8),
+Method.Q_LEARNING_OFF_POLICY_CONTROL,
+AntAction.values());
 rl.setDelay(20);
-rl.setNrOfEpisodes(1);
-//0.99 0.9 0.5
-//0.99 0.95 0.9 0.7 0.5 0.3 0.1
+rl.setNrOfEpisodes(1);
+// 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
 rl.setDiscountFactor(0.05f);
-// 0.1, 0.3, 0.5, 0.7 0.9
-rl.setLearningRate(0.9f);
-rl.setEpsilon(0.2f);
-rl.start();
+// 0.1, 0.3, 0.5, 0.7 0.9
+rl.setLearningRate(0.9f);
+rl.setEpsilon(0.2f);
+rl.start();
 }
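
The value lists in the comments read like a manual sweep grid. A hypothetical loop that automates it with the controller API visible above; the sweep itself is not part of this commit, and it assumes rl.start() returns before the next run:

    // Hypothetical sweep over the commented candidate values above.
    public static void main(String[] args) {
        float[] discountFactors = {0.05f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f, 0.95f, 0.99f};
        float[] learningRates   = {0.1f, 0.3f, 0.5f, 0.7f, 0.9f};
        for (float gamma : discountFactors) {
            for (float alpha : learningRates) {
                RNG.setSeed(13); // fixed seed keeps runs comparable
                RLController<AntAction> rl = new RLControllerGUI<>(
                        new AntWorldContinuous(8, 8),
                        Method.Q_LEARNING_OFF_POLICY_CONTROL,
                        AntAction.values());
                rl.setDelay(20);
                rl.setNrOfEpisodes(1);
                rl.setDiscountFactor(gamma);
                rl.setLearningRate(alpha);
                rl.setEpsilon(0.2f);
                rl.start();
            }
        }
    }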

View File

@@ -14,8 +14,8 @@ import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
 public class DinoSampling {
 public static final float f =0.05f;
 public static final String FILE_NAME = "converge.txt";
 public static void main(String[] args) {
 File file = new File(FILE_NAME);
 try {
@@ -23,15 +23,16 @@ public class DinoSampling {
 } catch (IOException e) {
 e.printStackTrace();
 }
-for(float f = 0.05f; f <=1.003 ; f+=0.05f) {
+for(float f = 0.05f; f <= 1.003; f += 0.05f) {
 try {
 Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
 } catch (IOException e) {
 e.printStackTrace();
 }
-for (int i = 1; i <= 100; i++) {
-System.out.println("seed: " + i * 13);
-RNG.setSeed(i * 13);
+for(int i = 1; i <= 100; i++) {
+int seed = i * 13;
+System.out.println("seed: " + seed);
+RNG.setSeed(seed);
 RLController<DinoAction> rl = new RLControllerGUI<>(
 new DinoWorld(),
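
A side note on the sweep loop above: the upper bound is written as 1.003 rather than 1.0 because repeated float addition of 0.05f accumulates rounding error. A standalone illustration:

    public static void main(String[] args) {
        // After twenty additions of 0.05f the sum is only approximately 1.0f,
        // so "f <= 1.0f" could skip the last sweep value; 1.003 leaves slack.
        float f = 0.0f;
        for (int i = 0; i < 20; i++) {
            f += 0.05f;
        }
        System.out.println(f);         // close to, but not necessarily exactly, 1.0
        System.out.println(f <= 1.0f); // may print false without the slack
    }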

View File

@@ -1,15 +0,0 @@
-Method:
-Epsilon = k / currentEpisode
-set to 0 if Epsilon < b
-k = 1.5
-b = 0.1 => conv. 16
-k = 1.5
-b = 0.02 => 75
-k = 1.4
-b = 0.02 => fail
-k = 2.0
-b = 0.02 => conv. 100
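
The deleted notes test an epsilon decay of k / currentEpisode with a hard floor b. A minimal sketch of that rule (the method name is illustrative):

    // epsilon = k / currentEpisode, snapped to 0 once it falls below b.
    static float decayedEpsilon(float k, float b, int currentEpisode) {
        float epsilon = k / currentEpisode;
        return epsilon < b ? 0.0f : epsilon;
    }
    // With k = 1.5, b = 0.1 the floor triggers at episode 16 (1.5 / 16 ≈ 0.094),
    // which lines up with the "conv. 16" row in the notes.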

View File

@@ -19,8 +19,8 @@ public class RunningAnt {
 rl.setDelay(200);
 rl.setNrOfEpisodes(10000);
 rl.setDiscountFactor(0.9f);
 rl.setLearningRate(0.9f);
 rl.setEpsilon(0.15f);
 rl.start();
 }
 }

View File

@@ -1,52 +0,0 @@
-package example;
-public class Test {
-interface Drawable{
-void draw();
-}
-interface State{
-int getInt();
-}
-static class A implements Drawable, State{
-private int k;
-public A(int a){
-k = a;
-}
-@Override
-public void draw() {
-System.out.println("draw " + k);
-}
-@Override
-public int getInt() {
-System.out.println("getInt" + k);
-return k;
-}
-}
-static class B implements State{
-@Override
-public int getInt() {
-return 0;
-}
-}
-public static void main(String[] args) {
-State state = new A(24);
-State state2 = new B();
-state.getInt();
-System.out.println(state2 instanceof Drawable);
-drawState(state2);
-}
-static void drawState(State s){
-if(s instanceof Drawable){
-Drawable d = (Drawable) s;
-d.draw();
-}else{
-System.out.println("invalid");
-}
-}
-}
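
The deleted scratch file was probing the check-then-cast idiom in drawState. As a general aside (not part of this commit), pattern matching for instanceof in Java 16+ collapses that idiom into a single test-and-bind:

    static void drawState(State s) {
        if (s instanceof Drawable d) { // test and bind in one step
            d.draw();
        } else {
            System.out.println("invalid");
        }
    }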