rename MC class and improve specific analysis of antGame examples
parent 4402d70467
commit 5b82e7965d
@@ -12,19 +12,19 @@ import java.io.ObjectOutputStream;
 import java.util.*;
 
 /**
- * Includes both variants of Monte-Carlo methods
+ * Includes both! variants of Monte-Carlo methods
  * Default method is First-Visit.
  * Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
  * @param <A>
  */
-public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
     private boolean isEveryVisit;
 
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
         super(environment, actionSpace, discountFactor, delay);
         isEveryVisit = useEveryVisit;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
@@ -33,11 +33,11 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends Episodic
         returnCount = new HashMap<>();
     }
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         this(environment, actionSpace, discountFactor, epsilon, delay, false);
     }
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
 
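For orientation, a minimal sketch of how the two Monte-Carlo variants are selected through the renamed constructors. Only the constructor signatures come from the hunk above; the helper method, its name, and the numeric parameter values are illustrative assumptions.

// Hypothetical helper, not part of the commit: picks First-Visit or Every-Visit
// using only the constructors shown above. env and actions are assumed to be
// supplied by the caller (e.g. by RLController); 0.9f/0.1f/20 are placeholder values.
static <A extends Enum> EpisodicLearning<A> buildMonteCarlo(
        Environment<A> env, DiscreteActionSpace<A> actions, boolean everyVisit) {
    return everyVisit
            ? new MonteCarloControlEGreedy<>(env, actions, 0.9f, 0.1f, 20, true)  // Every-Visit
            : new MonteCarloControlEGreedy<>(env, actions, 0.9f, 0.1f, 20);       // First-Visit (default)
}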
@@ -45,9 +45,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
 
         sumOfRewards = 0;
         int timestampTilFood = 0;
-        int rewardsPer1000 = 0;
         int foodCollected = 0;
-        int iterations = 0;
         int foodTimestampsTotal= 0;
         while(envResult == null || !envResult.isDone()) {
             actionValues = stateActionTable.getActionValues(state);
@@ -58,20 +56,8 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
             double reward = envResult.getReward();
             State nextState = envResult.getState();
             sumOfRewards += reward;
-            rewardsPer1000+=reward;
             timestampTilFood++;
 
-            /* if(iterations == 100){
-                File file = new File(ContinuousAnt.FILE_NAME);
-                try {
-                    Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
-                } catch (IOException e) {
-                    e.printStackTrace();
-                }
-                return;
-            }*/
-
-
             if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
                 foodCollected++;
                 foodTimestampsTotal += timestampTilFood;
@@ -95,7 +81,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
                     ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.05f);
                 }
                 if(foodCollected == 4000){
-                    System.out.println("final 0 expl");
+                    System.out.println("Reached 0 exploration");
                     ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.00f);
                 }
                 if(foodCollected == 15000){
@@ -106,7 +92,6 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
                     }
                     return;
                 }
-                iterations++;
                 timestampTilFood = 0;
             }
 
@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
+import core.algo.mc.MonteCarloControlEGreedy;
 import core.algo.td.QLearningOffPolicyTDControl;
 import core.algo.td.SARSA;
 import core.listener.LearningListener;
@@ -49,10 +49,10 @@ public class RLController<A extends Enum> implements LearningListener {
     public void start() {
         switch(method) {
             case MC_CONTROL_FIRST_VISIT:
-                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
             case MC_CONTROL_EVERY_VISIT:
-                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
                 break;
 
             case SARSA_ON_POLICY_CONTROL:
@@ -11,30 +11,28 @@ import java.io.File;
 import java.io.IOException;
 
 public class ContinuousAnt {
-    public static final String FILE_NAME = "converge22.txt";
+    public static final String FILE_NAME = "converge.txt";
 
     public static void main(String[] args) {
-        int i = 4+4+4+6+6+6+8+10+12+14+14+16+16+16+18+18+18+20+20+20+22+22+22+24+24+24+24+26+26+26+26+26+28+28+28+28+28+30+30+30+30+32+32+32+34+34+34+36+36+38+40+42;
-        System.out.println(i/52f);
         File file = new File(FILE_NAME);
         try {
             file.createNewFile();
         } catch (IOException e) {
             e.printStackTrace();
         }
-        RNG.setSeed(13);
+        RNG.setSeed(13);
         RLController<AntAction> rl = new RLControllerGUI<>(
-                new AntWorldContinuous(8, 8),
-                Method.Q_LEARNING_OFF_POLICY_CONTROL,
-                AntAction.values());
+                new AntWorldContinuous(8, 8),
+                Method.Q_LEARNING_OFF_POLICY_CONTROL,
+                AntAction.values());
         rl.setDelay(20);
-        rl.setNrOfEpisodes(1);
-        //0.99 0.9 0.5
-        //0.99 0.95 0.9 0.7 0.5 0.3 0.1
+        rl.setNrOfEpisodes(1);
+        // 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
+        rl.setDiscountFactor(0.05f);
-        // 0.1, 0.3, 0.5, 0.7 0.9
-        rl.setLearningRate(0.9f);
-        rl.setEpsilon(0.2f);
-        rl.start();
+        // 0.1, 0.3, 0.5, 0.7 0.9
+        rl.setLearningRate(0.9f);
+        rl.setEpsilon(0.2f);
+        rl.start();
 
 
     }
@@ -14,8 +14,8 @@ import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
 
 public class DinoSampling {
-    public static final float f =0.05f;
+    public static final String FILE_NAME = "converge.txt";
 
     public static void main(String[] args) {
         File file = new File(FILE_NAME);
         try {
@@ -23,15 +23,16 @@ public class DinoSampling {
         } catch (IOException e) {
             e.printStackTrace();
         }
-        for(float f = 0.05f; f <=1.003 ; f+=0.05f) {
+        for(float f = 0.05f; f <= 1.003; f += 0.05f) {
             try {
                 Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
             } catch (IOException e) {
                 e.printStackTrace();
             }
-            for (int i = 1; i <= 100; i++) {
-                System.out.println("seed: " + i * 13);
-                RNG.setSeed(i * 13);
+            for(int i = 1; i <= 100; i++) {
+                int seed = i * 13;
+                System.out.println("seed: " + seed);
+                RNG.setSeed(seed);
 
                 RLController<DinoAction> rl = new RLControllerGUI<>(
                         new DinoWorld(),
@@ -1,15 +0,0 @@
-Method:
-Epsilon = k / currentEpisode
-set to 0 if Epsilon < b
-
-k = 1.5
-b = 0.1 => conv. 16
-
-k = 1.5
-b = 0.02 => 75
-
-k = 1.4
-b = 0.02 => fail
-
-k = 2.0
-b = 0.02 => conv. 100
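The deleted notes above record experiments with an episode-based epsilon decay: epsilon = k / currentEpisode, cut to zero once it drops below a threshold b, with the observed convergence episode for each (k, b) pair. A minimal sketch of that rule, assuming nothing beyond the formula in the notes (the method name is illustrative only):

// Decay rule from the deleted notes: epsilon = k / currentEpisode,
// forced to 0 once it falls below the cutoff b (e.g. k = 1.5, b = 0.1).
static float decayedEpsilon(int currentEpisode, float k, float b) {
    float epsilon = k / currentEpisode;   // hyperbolic decay over episodes
    return epsilon < b ? 0.0f : epsilon;  // stop exploring entirely below the cutoff
}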
@@ -19,8 +19,8 @@ public class RunningAnt {
         rl.setDelay(200);
         rl.setNrOfEpisodes(10000);
         rl.setDiscountFactor(0.9f);
         rl.setLearningRate(0.9f);
         rl.setEpsilon(0.15f);
 
         rl.start();
     }
 }
@@ -1,52 +0,0 @@
-package example;
-
-public class Test {
-    interface Drawable{
-        void draw();
-    }
-    interface State{
-        int getInt();
-    }
-
-    static class A implements Drawable, State{
-        private int k;
-        public A(int a){
-            k = a;
-        }
-        @Override
-        public void draw() {
-            System.out.println("draw " + k);
-        }
-
-        @Override
-        public int getInt() {
-            System.out.println("getInt" + k);
-            return k;
-        }
-    }
-
-    static class B implements State{
-        @Override
-        public int getInt() {
-            return 0;
-        }
-    }
-
-    public static void main(String[] args) {
-        State state = new A(24);
-        State state2 = new B();
-        state.getInt();
-
-        System.out.println(state2 instanceof Drawable);
-        drawState(state2);
-    }
-
-    static void drawState(State s){
-        if(s instanceof Drawable){
-            Drawable d = (Drawable) s;
-            d.draw();
-        }else{
-            System.out.println("invalid");
-        }
-    }
-}