rename MC class and improve specific analysis of antGame examples

parent 4402d70467
commit 5b82e7965d
MonteCarloControlFirstVisitEGreedy.java → MonteCarloControlEGreedy.java (renamed)

@@ -12,19 +12,19 @@ import java.io.ObjectOutputStream;
 import java.util.*;
 
 /**
- * Includes both variants of Monte-Carlo methods
+ * Includes both! variants of Monte-Carlo methods
  * Default method is First-Visit.
  * Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
  * @param <A>
  */
-public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
     private boolean isEveryVisit;
 
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
         super(environment, actionSpace, discountFactor, delay);
         isEveryVisit = useEveryVisit;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
@@ -33,11 +33,11 @@ public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {
         returnCount = new HashMap<>();
     }
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         this(environment, actionSpace, discountFactor, epsilon, delay, false);
     }
 
-    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
 
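With the rename, a single MonteCarloControlEGreedy class covers both Monte-Carlo variants: the trailing boolean useEveryVisit selects Every-Visit, and omitting it keeps the First-Visit default, exactly as the two RLController cases further below use it. A minimal usage sketch, assuming a pre-built Environment<AntAction> env and DiscreteActionSpace<AntAction> actions (placeholder names, not from this commit) and example hyperparameter values:

    // First-Visit Monte-Carlo control (default: useEveryVisit omitted, i.e. false)
    MonteCarloControlEGreedy<AntAction> firstVisit =
            new MonteCarloControlEGreedy<>(env, actions, 0.9f, 0.15f, 20);

    // Every-Visit Monte-Carlo control (same class, flag set to true)
    MonteCarloControlEGreedy<AntAction> everyVisit =
            new MonteCarloControlEGreedy<>(env, actions, 0.9f, 0.15f, 20, true);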
QLearningOffPolicyTDControl.java

@@ -45,9 +45,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
 
         sumOfRewards = 0;
         int timestampTilFood = 0;
-        int rewardsPer1000 = 0;
         int foodCollected = 0;
-        int iterations = 0;
         int foodTimestampsTotal= 0;
         while(envResult == null || !envResult.isDone()) {
             actionValues = stateActionTable.getActionValues(state);
@@ -58,20 +56,8 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
             double reward = envResult.getReward();
             State nextState = envResult.getState();
             sumOfRewards += reward;
-            rewardsPer1000+=reward;
             timestampTilFood++;
 
-            /* if(iterations == 100){
-                File file = new File(ContinuousAnt.FILE_NAME);
-                try {
-                    Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
-                } catch (IOException e) {
-                    e.printStackTrace();
-                }
-                return;
-            }*/
-
-
             if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
                 foodCollected++;
                 foodTimestampsTotal += timestampTilFood;
@@ -95,7 +81,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
                 ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.05f);
             }
             if(foodCollected == 4000){
-                System.out.println("final 0 expl");
+                System.out.println("Reached 0 exploration");
                 ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.00f);
             }
             if(foodCollected == 15000){
@@ -106,7 +92,6 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
                 }
                 return;
             }
-            iterations++;
             timestampTilFood = 0;
         }
 
RLController.java

@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
+import core.algo.mc.MonteCarloControlEGreedy;
 import core.algo.td.QLearningOffPolicyTDControl;
 import core.algo.td.SARSA;
 import core.listener.LearningListener;
@@ -49,10 +49,10 @@ public class RLController<A extends Enum> implements LearningListener {
     public void start() {
         switch(method) {
             case MC_CONTROL_FIRST_VISIT:
-                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
             case MC_CONTROL_EVERY_VISIT:
-                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
+                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay, true);
                 break;
 
             case SARSA_ON_POLICY_CONTROL:
ContinuousAnt.java

@@ -11,10 +11,9 @@ import java.io.File;
 import java.io.IOException;
 
 public class ContinuousAnt {
-    public static final String FILE_NAME = "converge22.txt";
+    public static final String FILE_NAME = "converge.txt";
 
     public static void main(String[] args) {
-        int i = 4+4+4+6+6+6+8+10+12+14+14+16+16+16+18+18+18+20+20+20+22+22+22+24+24+24+24+26+26+26+26+26+28+28+28+28+28+30+30+30+30+32+32+32+34+34+34+36+36+38+40+42;
-        System.out.println(i/52f);
         File file = new File(FILE_NAME);
         try {
             file.createNewFile();
@@ -28,8 +27,7 @@ public class ContinuousAnt {
             AntAction.values());
         rl.setDelay(20);
         rl.setNrOfEpisodes(1);
-        //0.99 0.9 0.5
-        //0.99 0.95 0.9 0.7 0.5 0.3 0.1
+        // 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
        rl.setDiscountFactor(0.05f);
         // 0.1, 0.3, 0.5, 0.7 0.9
         rl.setLearningRate(0.9f);
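The replaced comment lists the discount factors tried for the ant analysis, one run per value, with results appended to FILE_NAME. A sketch of how that sweep could be looped rather than edited by hand, using only the setters visible above; buildAntController() is a hypothetical helper standing in for the RLController construction in ContinuousAnt and is not code from this repository:

    // Discount factors from the comment above; each run is otherwise configured as in ContinuousAnt.
    float[] discountFactors = {0.05f, 0.1f, 0.3f, 0.5f, 0.7f, 0.9f, 0.95f, 0.99f};
    for (float gamma : discountFactors) {
        RLController<AntAction> rl = buildAntController(); // hypothetical helper
        rl.setDelay(20);
        rl.setNrOfEpisodes(1);
        rl.setDiscountFactor(gamma);
        rl.setLearningRate(0.9f);
        rl.start();
    }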
DinoSampling.java

@@ -14,8 +14,8 @@ import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
 
 public class DinoSampling {
-    public static final float f =0.05f;
     public static final String FILE_NAME = "converge.txt";
 
     public static void main(String[] args) {
         File file = new File(FILE_NAME);
         try {
@@ -30,8 +30,9 @@ public class DinoSampling {
             e.printStackTrace();
         }
         for(int i = 1; i <= 100; i++) {
-            System.out.println("seed: " + i * 13);
-            RNG.setSeed(i * 13);
+            int seed = i * 13;
+            System.out.println("seed: " + seed);
+            RNG.setSeed(seed);
 
             RLController<DinoAction> rl = new RLControllerGUI<>(
                 new DinoWorld(),
Deleted notes file (epsilon decay experiments):

@@ -1,15 +0,0 @@
-Method:
-Epsilon = k / currentEpisode
-set to 0 if Epsilon < b
-
-k = 1.5
-b = 0.1 => conv. 16
-
-k = 1.5
-b = 0.02 => 75
-
-k = 1.4
-b = 0.02 => fail
-
-k = 2.0
-b = 0.02 => conv. 100
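The deleted notes record an epsilon decay schedule: epsilon = k / currentEpisode, forced to 0 once it drops below a threshold b, together with the episode count needed to converge for each (k, b) pair. A minimal sketch of that schedule; the helper and its call site are illustrative and not code from this repository, only the EpsilonGreedyPolicy.setEpsilon call appears in the diffs above:

    // epsilon = k / currentEpisode, cut to 0 below b; per the notes, k = 1.5, b = 0.1 converged after 16 episodes.
    static float decayedEpsilon(float k, float b, int currentEpisode) {
        float epsilon = k / currentEpisode;
        if (epsilon < b) {
            return 0.0f;                  // stop exploring once the schedule falls below the cutoff
        }
        return Math.min(1.0f, epsilon);   // clamp to a valid exploration rate
    }

    // Illustrative wiring, mirroring the setter used in QLearningOffPolicyTDControl:
    // ((EpsilonGreedyPolicy<A>) policy).setEpsilon(decayedEpsilon(1.5f, 0.1f, currentEpisode));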
RunningAnt.java

@@ -19,8 +19,8 @@ public class RunningAnt {
         rl.setDelay(200);
         rl.setNrOfEpisodes(10000);
         rl.setDiscountFactor(0.9f);
+        rl.setLearningRate(0.9f);
         rl.setEpsilon(0.15f);
 
         rl.start();
     }
 }
example/Test.java (deleted)

@@ -1,52 +0,0 @@
-package example;
-
-public class Test {
-    interface Drawable{
-        void draw();
-    }
-    interface State{
-        int getInt();
-    }
-
-    static class A implements Drawable, State{
-        private int k;
-        public A(int a){
-            k = a;
-        }
-        @Override
-        public void draw() {
-            System.out.println("draw " + k);
-        }
-
-        @Override
-        public int getInt() {
-            System.out.println("getInt" + k);
-            return k;
-        }
-    }
-
-    static class B implements State{
-        @Override
-        public int getInt() {
-            return 0;
-        }
-    }
-
-    public static void main(String[] args) {
-        State state = new A(24);
-        State state2 = new B();
-        state.getInt();
-
-        System.out.println(state2 instanceof Drawable);
-        drawState(state2);
-    }
-
-    static void drawState(State s){
-        if(s instanceof Drawable){
-            Drawable d = (Drawable) s;
-            d.draw();
-        }else{
-            System.out.println("invalid");
-        }
-    }
-}