add TD algorithms and start adapting to continuous tasks

- add Q-Learning and SARSA
- more config variables
This commit is contained in:
Jan Löwenstrom 2020-02-17 13:56:55 +01:00
parent f4f1f7bd37
commit 77898f4e5a
16 changed files with 222 additions and 63 deletions

View File

@@ -0,0 +1,5 @@
<component name="ProjectCodeStyleConfiguration">
<state>
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state>
</component>

View File

@@ -2,6 +2,9 @@ package core;
public class LearningConfig {
public static final int DEFAULT_DELAY = 30;
public static final int DEFAULT_NR_OF_EPISODES = 10000;
public static final float DEFAULT_EPSILON = 0.1f;
public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
// Learning rate
public static final float DEFAULT_ALPHA = 0.9f;
}

View File

@@ -2,6 +2,7 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.StepResult;
import core.listener.LearningListener;
import lombok.Getter;
@@ -16,7 +17,7 @@ import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
@Setter
protected int currentEpisode;
protected int currentEpisode = 0;
protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter
protected volatile int episodePerSecond;
@@ -81,7 +82,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
@Override
public void learn(){
// TODO remove or learn with default episode number
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
}
private void startLearning(){
@@ -132,6 +133,15 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
return episodesToLearn.get();
}
public int getCurrentEpisode() {
return currentEpisode;
}
public int getEpisodesPerSecond() {
return episodePerSecond;
}
@Override
public synchronized void save(ObjectOutputStream oos) throws IOException {
super.save(oos);

View File

@@ -30,21 +30,20 @@ import java.util.*;
*
* @param <A>
*/
public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
super(environment, actionSpace, discountFactor, delay);
currentEpisode = 0;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
returnCount = new HashMap<>();
}
public MonteCarloOnPolicyEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
@@ -59,7 +58,7 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
}
sumOfRewards = 0;
StepResultEnvironment envResult = null;
while(envResult == null || !envResult.isDone()){
while(envResult == null || !envResult.isDone()) {
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
envResult = environment.step(chosenAction);
@@ -77,26 +76,26 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
dispatchStepEnd();
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
for (StepResult<A> sr : episode) {
for(StepResult<A> sr : episode) {
stateActionPairs.add(new ImmutablePair<>(sr.getState(), sr.getAction()));
}
//System.out.println("stateActionPairs " + stateActionPairs.size());
for (Pair<State, A> stateActionPair : stateActionPairs) {
for(Pair<State, A> stateActionPair : stateActionPairs) {
int firstOccurenceIndex = 0;
// find first occurrence of the state-action pair
for (StepResult<A> sr : episode) {
if (stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
for(StepResult<A> sr : episode) {
if(stateActionPair.getKey().equals(sr.getState()) && stateActionPair.getValue().equals(sr.getAction())) {
break;
}
firstOccurenceIndex++;
}
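// G is the discounted return following the first occurrence of the state-action pair:
// G = R_{t+1} + gamma * R_{t+2} + gamma^2 * R_{t+3} + ...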
double G = 0;
for (int l = firstOccurenceIndex; l < episode.size(); ++l) {
for(int l = firstOccurenceIndex; l < episode.size(); ++l) {
G += episode.get(l).getReward() * (Math.pow(discountFactor, l - firstOccurenceIndex));
}
// slick trick to add G to the entry.
@@ -107,16 +106,6 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<
}
}
@Override
public int getCurrentEpisode() {
return currentEpisode;
}
@Override
public int getEpisodesPerSecond(){
return episodePerSecond;
}
@Override
public void save(ObjectOutputStream oos) throws IOException {
super.save(oos);

View File

@@ -5,5 +5,5 @@ package core.algo;
* which RL-algorithm should be used.
*/
public enum Method {
MC_ONPOLICY_EGREEDY, TD_ONPOLICY
MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
}

View File

@@ -0,0 +1,68 @@
package core.algo.td;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import java.util.Map;
public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
private float alpha;
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
super(environment, actionSpace, discountFactor, delay);
alpha = learningRate;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
}
public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
}
@Override
protected void nextEpisode() {
State state = environment.reset();
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = null;
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
// Take a step
envResult = environment.step(action);
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
// Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action);
// max_a Q(S', a)
// Use an internal greedy policy as a helper to determine the highest action-value in the next state
double highestValueNextState = stateActionTable.getActionValues(nextState).get(greedyPolicy.chooseAction(stateActionTable.getActionValues(nextState)));
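// Q-learning (off-policy TD control) update:
// Q(S, A) <- Q(S, A) + alpha * (R + gamma * max_a Q(S', a) - Q(S, A))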
double updatedQValue = currentQValue + alpha * (reward + discountFactor * highestValueNextState - currentQValue);
stateActionTable.setValue(state, action, updatedQValue);
state = nextState;
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
dispatchStepEnd();
}
}
}

View File

@@ -0,0 +1,68 @@
package core.algo.td;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import java.util.Map;
public class SARSA<A extends Enum> extends EpisodicLearning<A> {
private float alpha;
public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
super(environment, actionSpace, discountFactor, delay);
alpha = learningRate;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
}
public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_ALPHA, delay);
}
@Override
protected void nextEpisode() {
State state = environment.reset();
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
StepResultEnvironment envResult = null;
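// SARSA is on-policy: the initial action is chosen by the epsilon-greedy policy before the loop,
// and every update uses the action that is actually taken in the next state.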
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
// Take a step
envResult = environment.step(action);
sumOfRewards += envResult.getReward();
State nextState = envResult.getState();
// Pick next action
actionValues = stateActionTable.getActionValues(nextState);
A nextAction = policy.chooseAction(actionValues);
// TD update
// target = reward + gamma * Q(nextState, nextAction)
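// SARSA (on-policy TD control) update:
// Q(S, A) <- Q(S, A) + alpha * (R + gamma * Q(S', A') - Q(S, A))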
double currentQValue = stateActionTable.getActionValues(state).get(action);
double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
double reward = envResult.getReward();
double updatedQValue = currentQValue + alpha * (reward + discountFactor * nextQValue - currentQValue);
stateActionTable.setValue(state, action, updatedQValue);
state = nextState;
action = nextAction;
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
dispatchStepEnd();
}
}
}

View File

@@ -1,4 +0,0 @@
package core.algo.TD;
public class TemporalDifferenceOnPolicy {
}

View File

@@ -1,18 +1,19 @@
package core.controller;
import core.*;
import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.ListDiscreteActionSpace;
import core.algo.EpisodicLearning;
import core.algo.Learning;
import core.algo.Method;
import core.algo.mc.MonteCarloOnPolicyEGreedy;
import core.gui.LearningView;
import core.gui.View;
import core.algo.mc.MonteCarloControlEGreedy;
import core.algo.td.QLearningOffPolicyTDControl;
import core.algo.td.SARSA;
import core.listener.LearningListener;
import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
import lombok.Setter;
import javax.swing.*;
import java.io.*;
import java.util.List;
@@ -27,6 +28,8 @@ public class RLController<A extends Enum> implements LearningListener {
@Setter
protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
@Setter
protected float learningRate = LearningConfig.DEFAULT_ALPHA;
@Setter
protected float epsilon = LearningConfig.DEFAULT_EPSILON;
protected Learning<A> learning;
protected boolean fastLearning;
@@ -45,10 +48,14 @@ public class RLController<A extends Enum> implements LearningListener {
public void start() {
switch(method) {
case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
case MC_CONTROL_EGREEDY:
learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
break;
case TD_ONPOLICY:
case SARSA_EPISODIC:
learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
break;
case Q_LEARNING_OFF_POLICY_CONTROL:
learning = new QLearningOffPolicyTDControl<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
break;
default:
throw new IllegalArgumentException("Undefined method");

View File

@@ -30,7 +30,6 @@ public class StateActionRow<A extends Enum> extends JTextArea {
protected void refreshLabels(){
if(state == null || actionValues == null) return;
System.out.println("refreshing");
StringBuilder sb = new StringBuilder(state.toString()).append("\n");
for(Map.Entry<A, Double> actionValue: actionValues.entrySet()){
sb.append("\t").append(actionValue.getKey()).append("\t").append(actionValue.getValue()).append("\n");

View File

@@ -29,7 +29,6 @@ public class AntState implements State, Visualizable {
private int computeHash() {
int hash = 7;
int prime = 31;
int unknown = 0;
int diff = 0;
for (Cell[] cells : knownWorld) {

View File

@@ -28,9 +28,11 @@ public class Dino extends RenderObject {
@Override
public void tick(){
// reached max jump height
if(y + dy < Config.FRAME_HEIGHT - Config.GROUND_Y -Config.OBSTACLE_SIZE - Config.MAX_JUMP_HEIGHT){
int topOfDino = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
if(y + dy <= topOfDino - Config.MAX_JUMP_HEIGHT) {
fall();
}else if(y + dy >= Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE){
} else if(y + dy >= topOfDino) {
inJump = false;
dy = 0;
y = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;

View File

@@ -56,18 +56,28 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
dino.jump();
}
for(int i= 0; i < 5; ++i){
dino.tick();
currentObstacle.tick();
if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
spawnNewObstacle();
}
comp.repaint();
if(ranIntoObstacle()){
done = true;
break;
}
// for(int i= 0; i < 5; ++i){
// dino.tick();
// currentObstacle.tick();
// if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
// spawnNewObstacle();
// }
// comp.repaint();
// if(ranIntoObstacle()){
// done = true;
// break;
// }
// }
dino.tick();
currentObstacle.tick();
if(currentObstacle.getX() < -Config.OBSTACLE_SIZE) {
spawnNewObstacle();
}
if(ranIntoObstacle()) {
reward = 0;
done = true;
}
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, "");
}

View File

@@ -19,7 +19,7 @@ public class DinoWorldComponent extends JComponent {
protected void paintComponent(Graphics g) {
super.paintComponent(g);
g.setColor(Color.BLACK);
g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, Config.FRAME_WIDTH, 2);
g.fillRect(0, Config.FRAME_HEIGHT - Config.GROUND_Y, getWidth(), 2);
dinoWorld.getDino().render(g);
dinoWorld.getCurrentObstacle().render(g);

View File

@@ -12,15 +12,17 @@ public class JumpingDino {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorld(true, true),
Method.MC_ONPOLICY_EGREEDY,
new DinoWorld(false, false),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
DinoAction.values());
rl.setDelay(100);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setNrOfEpisodes(100000);
rl.setDelay(10);
rl.setDiscountFactor(0.8f);
rl.setEpsilon(0.1f);
rl.setLearningRate(0.5f);
rl.setNrOfEpisodes(10000);
rl.start();
}
}

View File

@@ -3,16 +3,17 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorld;
public class RunningAnt {
public static void main(String[] args) {
RNG.setSeed(123);
RNG.setSeed(56);
RLController<AntAction> rl = new RLController<>(
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorld(3, 3, 0.1),
Method.MC_ONPOLICY_EGREEDY,
Method.MC_CONTROL_EGREEDY,
AntAction.values());
rl.setDelay(200);