add TD algorithms and start adapting to continuous tasks

- add Q-Learning and SARSA
- more config variables
parent f4f1f7bd37
commit 77898f4e5a
@@ -0,0 +1,5 @@
+<component name="ProjectCodeStyleConfiguration">
+  <state>
+    <option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
+  </state>
+</component>
@@ -2,6 +2,9 @@ package core;
 
 public class LearningConfig {
     public static final int DEFAULT_DELAY = 30;
+    public static final int DEFAULT_NR_OF_EPISODES = 10000;
     public static final float DEFAULT_EPSILON = 0.1f;
     public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
+    // Learning rate
+    public static final float DEFAULT_ALPHA = 0.9f;
 }
@@ -2,6 +2,7 @@ package core.algo;
 
 import core.DiscreteActionSpace;
 import core.Environment;
+import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
 import lombok.Getter;
@@ -16,7 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger;
 
 public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
     @Setter
-    protected int currentEpisode;
+    protected int currentEpisode = 0;
     protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
     @Getter
     protected volatile int episodePerSecond;
@@ -81,7 +82,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
 
     @Override
     public void learn(){
-        // TODO remove or learn with default episode number
+        learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
     }
 
     private void startLearning(){
@@ -132,6 +133,15 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
         return episodesToLearn.get();
     }
 
+    public int getCurrentEpisode() {
+        return currentEpisode;
+    }
+
+    public int getEpisodesPerSecond() {
+        return episodePerSecond;
+    }
+
+
     @Override
     public synchronized void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
@@ -30,21 +30,20 @@ import java.util.*;
  *
  * @param <A>
  */
-public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
 
     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
 
-    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
-        currentEpisode = 0;
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
         returnSum = new HashMap<>();
         returnCount = new HashMap<>();
     }
 
-    public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
 
@@ -107,16 +106,6 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
         }
     }
 
-    @Override
-    public int getCurrentEpisode() {
-        return currentEpisode;
-    }
-
-    @Override
-    public int getEpisodesPerSecond(){
-        return episodePerSecond;
-    }
-
     @Override
     public void save(ObjectOutputStream oos) throws IOException {
         super.save(oos);
@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_ONPOLICY_EGREEDY, TD_ONPOLICY
+    MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
 }
@@ -0,0 +1,68 @@
+package core.algo.td;
+
+import core.*;
+import core.algo.EpisodicLearning;
+import core.policy.EpsilonGreedyPolicy;
+import core.policy.GreedyPolicy;
+import core.policy.Policy;
+
+import java.util.Map;
+
+public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
+    private float alpha;
+    private Policy<A> greedyPolicy = new GreedyPolicy<>();
+
+    public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
+        super(environment, actionSpace, discountFactor, delay);
+        alpha = learningRate;
+        this.policy = new EpsilonGreedyPolicy<>(epsilon);
+        this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
+    }
+
+    public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
+    }
+
+    @Override
+    protected void nextEpisode() {
+        State state = environment.reset();
+        try {
+            Thread.sleep(delay);
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+
+        StepResultEnvironment envResult = null;
+        Map<A, Double> actionValues = null;
+
+        sumOfRewards = 0;
+        while(envResult == null || !envResult.isDone()) {
+            actionValues = stateActionTable.getActionValues(state);
+            A action = policy.chooseAction(actionValues);
+
+            // Take a step
+            envResult = environment.step(action);
+            double reward = envResult.getReward();
+            State nextState = envResult.getState();
+            sumOfRewards += reward;
+
+            // Q update
+            double currentQValue = stateActionTable.getActionValues(state).get(action);
+            // maxQ(S', a): the internal greedy policy is used as a helper to determine the highest action-value in the next state
+            double highestValueNextState = stateActionTable.getActionValues(nextState).get(greedyPolicy.chooseAction(stateActionTable.getActionValues(nextState)));
+
+            double updatedQValue = currentQValue + alpha * (reward + discountFactor * highestValueNextState - currentQValue);
+            stateActionTable.setValue(state, action, updatedQValue);
+
+            state = nextState;
+            try {
+                Thread.sleep(delay);
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+            dispatchStepEnd();
+        }
+    }
+}
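For reference, the loop above is the standard tabular Q-learning update (off-policy TD control): actions are taken with the epsilon-greedy behaviour policy, while the update target uses the greedy action in the next state, selected via the internal GreedyPolicy helper:

    Q(S, A) <- Q(S, A) + alpha * (R + gamma * max_a Q(S', a) - Q(S, A))

where alpha is the learning rate and gamma the discount factor passed to the constructor.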
@@ -0,0 +1,68 @@
+package core.algo.td;
+
+import core.*;
+import core.algo.EpisodicLearning;
+import core.policy.EpsilonGreedyPolicy;
+
+import java.util.Map;
+
+public class SARSA<A extends Enum> extends EpisodicLearning<A> {
+    private float alpha;
+
+    public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
+        super(environment, actionSpace, discountFactor, delay);
+        alpha = learningRate;
+        this.policy = new EpsilonGreedyPolicy<>(epsilon);
+        this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
+    }
+
+    public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
+    }
+
+    @Override
+    protected void nextEpisode() {
+        State state = environment.reset();
+        try {
+            Thread.sleep(delay);
+        } catch (InterruptedException e) {
+            e.printStackTrace();
+        }
+
+        StepResultEnvironment envResult = null;
+        Map<A, Double> actionValues = stateActionTable.getActionValues(state);
+        A action = policy.chooseAction(actionValues);
+
+        sumOfRewards = 0;
+        while(envResult == null || !envResult.isDone()) {
+            // Take a step
+            envResult = environment.step(action);
+            sumOfRewards += envResult.getReward();
+
+            State nextState = envResult.getState();
+
+            // Pick next action
+            actionValues = stateActionTable.getActionValues(nextState);
+            A nextAction = policy.chooseAction(actionValues);
+
+            // TD update
+            // target = reward + gamma * Q(nextState, nextAction)
+            double currentQValue = stateActionTable.getActionValues(state).get(action);
+            double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
+            double reward = envResult.getReward();
+            double updatedQValue = currentQValue + alpha * (reward + discountFactor * nextQValue - currentQValue);
+            stateActionTable.setValue(state, action, updatedQValue);
+
+            state = nextState;
+            action = nextAction;
+
+            try {
+                Thread.sleep(delay);
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+            dispatchStepEnd();
+        }
+    }
+}
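For comparison, SARSA is the on-policy counterpart: the next action A' is chosen by the same epsilon-greedy policy that is being followed, is used in the update target, and is then actually executed in the next step:

    Q(S, A) <- Q(S, A) + alpha * (R + gamma * Q(S', A') - Q(S, A))

so the only difference to the Q-learning class above is the target term, Q(S', A') instead of max_a Q(S', a).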
@@ -1,4 +0,0 @@
-package core.algo.TD;
-
-public class TemporalDifferenceOnPolicy {
-}
@@ -1,18 +1,19 @@
 package core.controller;
 
-import core.*;
+import core.DiscreteActionSpace;
+import core.Environment;
+import core.LearningConfig;
+import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloOnPolicyEGreedy;
-import core.gui.LearningView;
-import core.gui.View;
+import core.algo.mc.MonteCarloControlEGreedy;
+import core.algo.td.QLearningOffPolicyTDControl;
+import core.algo.td.SARSA;
 import core.listener.LearningListener;
-import core.listener.ViewListener;
 import core.policy.EpsilonPolicy;
 import lombok.Setter;
 
-import javax.swing.*;
 import java.io.*;
 import java.util.List;
 
@@ -27,6 +28,8 @@ public class RLController<A extends Enum> implements LearningListener {
     @Setter
     protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
     @Setter
+    protected float learningRate = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
+    @Setter
     protected float epsilon = LearningConfig.DEFAULT_EPSILON;
     protected Learning<A> learning;
     protected boolean fastLearning;
@@ -45,10 +48,14 @@ public class RLController<A extends Enum> implements LearningListener {
 
     public void start() {
         switch(method) {
-            case MC_ONPOLICY_EGREEDY:
-                learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+            case MC_CONTROL_EGREEDY:
+                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
-            case TD_ONPOLICY:
+            case SARSA_EPISODIC:
+                learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
+                break;
+            case Q_LEARNING_OFF_POLICY_CONTROL:
+                learning = new QLearningOffPolicyTDControl<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
                 break;
             default:
                 throw new IllegalArgumentException("Undefined method");
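To show how the new switch cases are reached, here is a minimal caller sketch. It only assumes the constructor and setters visible in the hunks of this commit (RLController, Method.SARSA_EPISODIC, setLearningRate, AntWorld/AntAction); the class name RunningAntSarsa and the chosen parameter values are illustrative and not part of the commit.

    package example;

    import core.algo.Method;
    import core.controller.RLController;
    import evironment.antGame.AntAction;
    import evironment.antGame.AntWorld;

    // Hypothetical example class; mirrors RunningAnt but selects the new SARSA_EPISODIC method.
    public class RunningAntSarsa {
        public static void main(String[] args) {
            RLController<AntAction> rl = new RLController<>(
                    new AntWorld(3, 3, 0.1),
                    Method.SARSA_EPISODIC,
                    AntAction.values());

            rl.setDelay(0);              // no artificial delay between steps
            rl.setDiscountFactor(0.9f);  // gamma
            rl.setEpsilon(0.1f);         // exploration rate of the epsilon-greedy policy
            rl.setLearningRate(0.5f);    // alpha, forwarded to the SARSA constructor
            rl.setNrOfEpisodes(10000);
            rl.start();
        }
    }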
@@ -30,7 +30,6 @@ public class StateActionRow<A extends Enum> extends JTextArea {
 
     protected void refreshLabels(){
         if(state == null || actionValues == null) return;
-        System.out.println("refreshing");
         StringBuilder sb = new StringBuilder(state.toString()).append("\n");
         for(Map.Entry<A, Double> actionValue: actionValues.entrySet()){
             sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");
@@ -29,7 +29,6 @@ public class AntState implements State, Visualizable {
     private int computeHash() {
         int hash = 7;
         int prime = 31;
-
         int unknown = 0;
         int diff = 0;
         for (Cell[] cells : knownWorld) {
@@ -28,9 +28,11 @@ public class Dino extends RenderObject {
     @Override
     public void tick(){
         // reached max jump height
-        if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y -Config.OBSTACLE_SIZE - Config.MAX_JUMP_HEIGHT){
+        int topOfDino = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
+
+        if(y + dy <= topOfDino - Config.MAX_JUMP_HEIGHT) {
             fall();
-        }else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){
+        } else if(y + dy >= topOfDino) {
             inJump = false;
             dy = 0;
             y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
@@ -56,18 +56,28 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
             dino.jump();
         }
 
-        for(int i= 0; i < 5; ++i){
+//        for(int i= 0; i < 5; ++i){
+//            dino.tick();
+//            currentObstacle.tick();
+//            if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
+//                spawnNewObstacle();
+//            }
+//            comp.repaint();
+//            if(ranIntoObstacle()){
+//                done = true;
+//                break;
+//            }
+//        }
         dino.tick();
         currentObstacle.tick();
         if(currentObstacle.getX() < -Config.OBSTACLE_SIZE) {
             spawnNewObstacle();
         }
-        comp.repaint();
         if(ranIntoObstacle()) {
+            reward = 0;
             done = true;
-            break;
-        }
         }
 
         return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, "");
     }
@@ -19,7 +19,7 @@ public class DinoWorldComponent extends JComponent {
     protected void paintComponent(Graphics g) {
         super.paintComponent(g);
         g.setColor(Color.BLACK);
-        g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2);
+        g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, getWidth(), 2);
 
         dinoWorld.getDino().render(g);
         dinoWorld.getCurrentObstacle().render(g);
@@ -12,15 +12,17 @@ public class JumpingDino {
         RNG.setSeed(55);
 
         RLController<DinoAction> rl = new RLControllerGUI<>(
-                new DinoWorld(true, true),
-                Method.MC_ONPOLICY_EGREEDY,
+                new DinoWorld(false, false),
+                Method.Q_LEARNING_OFF_POLICY_CONTROL,
                 DinoAction.values());
 
-        rl.setDelay(100);
-        rl.setDiscountFactor(1f);
-        rl.setEpsilon(0.15f);
-        rl.setNrOfEpisodes(100000);
+        rl.setDelay(10);
+        rl.setDiscountFactor(0.8f);
+        rl.setEpsilon(0.1f);
+        rl.setLearningRate(0.5f);
+        rl.setNrOfEpisodes(10000);
         rl.start();
 
 
     }
 }
@@ -3,16 +3,17 @@ package example;
 import core.RNG;
 import core.algo.Method;
 import core.controller.RLController;
+import core.controller.RLControllerGUI;
 import evironment.antGame.AntAction;
 import evironment.antGame.AntWorld;
 
 public class RunningAnt {
     public static void main(String[] args) {
-        RNG.setSeed(123);
+        RNG.setSeed(56);
 
-        RLController<AntAction> rl = new RLController<>(
+        RLController<AntAction> rl = new RLControllerGUI<>(
                 new AntWorld(3, 3, 0.1),
-                Method.MC_ONPOLICY_EGREEDY,
+                Method.MC_CONTROL_EGREEDY,
                 AntAction.values());
 
         rl.setDelay(200);