enhance save/load feature and change thread handling

- saving Monte Carlo state did not include returnSum and returnCount, so the state would be wrong after loading. Learning, EpisodicLearning and MonteCarlo now each override save and load, calling super() first and then writing only the fields that have to be restored at runtime (see the sketch after this list).
- moved generic episodic behaviour from MonteCarlo up into the abstract EpisodicLearning class
- episodesToLearn is now an AtomicInteger, so learn() calls issued while the loop is running safely top up the remaining episode budget
- moved learning-thread handling from the controller to the model. Learning now owns one extra learning thread (a single-threaded executor).
- added an option for random speed and spawn distance of DinoWorld obstacles
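
The layered save/load pattern, reduced to a minimal self-contained sketch (a plain String key stands in for the real Pair<State, A> keys and the field sets are trimmed; only the ordering discipline matters): each override calls super first, then appends its own fields, and load() must read everything back in exactly the order save() wrote it.

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.Map;

abstract class Learning {
    protected Map<String, Double> stateActionTable = new HashMap<>();

    public void save(ObjectOutputStream oos) throws IOException {
        oos.writeObject(stateActionTable); // base class writes its fields first
    }

    @SuppressWarnings("unchecked")
    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
        stateActionTable = (Map<String, Double>) ois.readObject(); // same order as save()
    }
}

abstract class EpisodicLearning extends Learning {
    protected int currentEpisode;

    @Override
    public void save(ObjectOutputStream oos) throws IOException {
        super.save(oos);              // chain to the parent first...
        oos.writeInt(currentEpisode); // ...then append this level's fields
    }

    @Override
    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
        super.load(ois);
        currentEpisode = ois.readInt();
    }
}

class MonteCarlo extends EpisodicLearning {
    protected Map<String, Double> returnSum = new HashMap<>();
    protected Map<String, Integer> returnCount = new HashMap<>();

    @Override
    public void save(ObjectOutputStream oos) throws IOException {
        super.save(oos);
        oos.writeObject(returnSum);   // the two maps that were missing before:
        oos.writeObject(returnCount); // without them a loaded agent restarts its return averages
    }

    @Override
    @SuppressWarnings("unchecked")
    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
        super.load(ois);
        returnSum = (Map<String, Double>) ois.readObject();
        returnCount = (Map<String, Integer>) ois.readObject();
    }
}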
Jan Löwenstrom 2019-12-29 01:12:11 +01:00
parent 64355e0b93
commit 195722e98f
8 changed files with 193 additions and 75 deletions

.gitignore

@@ -1,3 +1,6 @@
+ learningStates/*
+ !learningStates/.gitkeep
.idea/refo.iml
.idea/misc.xml
.idea/modules.xml

learningStates/.gitkeep (new empty file)

core/algo/EpisodicLearning.java

@@ -2,59 +2,53 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
+ import core.StepResult;
import core.listener.LearningListener;
import lombok.Getter;
import lombok.Setter;
+ import java.io.IOException;
+ import java.io.ObjectInputStream;
+ import java.io.ObjectOutputStream;
+ import java.util.ArrayList;
import java.util.List;
+ import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
@Setter
protected int currentEpisode;
- protected int episodesToLearn;
+ protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter
protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond;
- private volatile boolean measureEpisodeBenchMark;
+ protected double sumOfRewards;
+ protected List<StepResult<A>> episode = new ArrayList<>();
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
+ initBenchMarking();
}
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
super(environment, actionSpace, discountFactor);
+ initBenchMarking();
}
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
super(environment, actionSpace, delay);
+ initBenchMarking();
}
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
super(environment, actionSpace);
+ initBenchMarking();
}
- protected void dispatchEpisodeEnd(double recentSumOfRewards){
- ++episodeSumCurrentSecond;
- if(rewardHistory.size() > 10000){
- rewardHistory.clear();
- }
- rewardHistory.add(recentSumOfRewards);
- for(LearningListener l: learningListeners) {
- l.onEpisodeEnd(rewardHistory);
- }
- }
- protected abstract void nextEpisode();
- protected void dispatchEpisodeStart(){
- for(LearningListener l: learningListeners){
- l.onEpisodeStart();
- }
- }
- @Override
- public void learn(){
- learn(0);
- }
- public void learn(int nrOfEpisodes){
- measureEpisodeBenchMark = true;
+ private void initBenchMarking(){
new Thread(()->{
- while(measureEpisodeBenchMark){
+ while (true){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@@ -64,24 +58,89 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
}
}
}).start();
- episodesToLearn += nrOfEpisodes;
- dispatchLearningStart();
- for(int i=0; i < nrOfEpisodes; ++i){
- nextEpisode();
- }
- dispatchLearningEnd();
- measureEpisodeBenchMark = false;
}
+ protected abstract void nextEpisode();
+ protected void dispatchEpisodeEnd(){
+ ++episodeSumCurrentSecond;
+ if(rewardHistory.size() > 10000){
+ rewardHistory.clear();
+ }
+ rewardHistory.add(sumOfRewards);
+ for(LearningListener l: learningListeners) {
+ l.onEpisodeEnd(rewardHistory);
+ }
+ }
+ protected void dispatchEpisodeStart(){
+ ++currentEpisode;
+ episodesToLearn.decrementAndGet();
+ for(LearningListener l: learningListeners){
+ l.onEpisodeStart();
+ }
+ }
@Override
public int getCurrentEpisode(){
return currentEpisode;
}
+ @Override
+ public void learn(){
+ // TODO remove or learn with default episode number
+ }
+ private void startLearning(){
+ learningExecutor.submit(()->{
+ dispatchLearningStart();
+ while(episodesToLearn.get() > 0){
+ dispatchEpisodeStart();
+ nextEpisode();
+ dispatchEpisodeEnd();
+ }
+ synchronized (this){
+ dispatchLearningEnd();
+ notifyAll();
+ }
+ });
+ }
+ /**
+ * Stops the learning loop by setting episodesToLearn to 0.
+ * The episode currently in progress cannot be interrupted, so the
+ * step delay is set to 0 and the calling thread waits until that
+ * episode is done. The delay is restored afterwards.
+ */
+ @Override
+ public synchronized void interruptLearning(){
+ episodesToLearn.set(0);
+ int prevDelay = delay;
+ delay = 0;
+ while(currentlyLearning) {
+ try {
+ wait();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ delay = prevDelay;
+ }
+ public synchronized void learn(int nrOfEpisodes){
+ boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
+ if(!isLearning)
+ startLearning();
+ }
@Override
public int getEpisodesToGo(){
- return episodesToLearn - currentEpisode;
+ return episodesToLearn.get();
}
+ @Override
+ public synchronized void save(ObjectOutputStream oos) throws IOException {
+ super.save(oos);
+ oos.writeInt(currentEpisode);
+ }
+ @Override
+ public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+ super.load(ois);
+ currentEpisode = ois.readInt();
+ }
}
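
The new threading contract above, restated as a self-contained model (this is not the project's class; the names mirror the diff, and runOneEpisode() stands in for nextEpisode() plus the listener dispatches): learn(n) submits the loop to the single-threaded executor only when the previous budget was zero, later calls merely top the budget up, and interruptLearning() zeroes the budget and blocks on the monitor until the running episode finishes.

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicInteger;

class EpisodeLoop {
    private final AtomicInteger episodesToLearn = new AtomicInteger(0);
    private final ExecutorService learningExecutor = Executors.newSingleThreadExecutor();
    private volatile boolean currentlyLearning;

    // Budget was zero -> start the loop; otherwise just add to the budget.
    public synchronized void learn(int nrOfEpisodes) {
        boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
        if (!isLearning) startLearning();
    }

    private void startLearning() {
        learningExecutor.submit(() -> {
            currentlyLearning = true;
            while (episodesToLearn.get() > 0) {
                episodesToLearn.decrementAndGet();
                runOneEpisode();
            }
            synchronized (this) {
                currentlyLearning = false;
                notifyAll(); // wake a thread blocked in interruptLearning()
            }
        });
    }

    // Zero the budget, then wait for the episode that is already running.
    public synchronized void interruptLearning() throws InterruptedException {
        episodesToLearn.set(0);
        while (currentlyLearning) {
            wait(); // released by the notifyAll() above
        }
    }

    private void runOneEpisode() { /* nextEpisode() plus listener dispatches in the real class */ }
}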

core/algo/Learning.java

@@ -9,10 +9,16 @@ import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
import java.io.IOException;
+ import java.io.ObjectInputStream;
+ import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
+ import java.util.concurrent.ExecutorService;
+ import java.util.concurrent.Executors;
/**
*
@@ -30,14 +36,18 @@ public abstract class Learning<A extends Enum>{
@Setter
protected int delay;
protected List<Double> rewardHistory;
+ protected ExecutorService learningExecutor;
+ protected boolean currentlyLearning;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
this.environment = environment;
this.actionSpace = actionSpace;
this.discountFactor = discountFactor;
this.delay = delay;
+ currentlyLearning = false;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
+ learningExecutor = Executors.newSingleThreadExecutor();
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
@@ -52,7 +62,6 @@ public abstract class Learning<A extends Enum>{
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
}
public abstract void learn();
public void addListener(LearningListener learningListener) {
@@ -66,15 +75,31 @@ public abstract class Learning<A extends Enum>{
}
protected void dispatchLearningStart() {
+ currentlyLearning = true;
for (LearningListener l : learningListeners) {
l.onLearningStart();
}
}
protected void dispatchLearningEnd() {
+ currentlyLearning = false;
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}
}
+ public synchronized void interruptLearning(){
+ //TODO: for non-episodic learning
+ }
+ public void save(ObjectOutputStream oos) throws IOException {
+ oos.writeObject(rewardHistory);
+ oos.writeObject(stateActionTable);
+ }
+ public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+ rewardHistory = (List<Double>) ois.readObject();
+ stateActionTable = (StateActionTable<A>) ois.readObject();
+ }
}

MonteCarloOnPolicyEGreedy.java

@@ -5,6 +5,9 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import javafx.util.Pair;
+ import java.io.IOException;
+ import java.io.ObjectInputStream;
+ import java.io.ObjectOutputStream;
import java.util.*;
/**
@@ -44,19 +47,16 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
}
@Override
public void nextEpisode() {
- ++currentEpisode;
- List<StepResult<A>> episode = new ArrayList<>();
+ episode = new ArrayList<>();
State state = environment.reset();
- dispatchEpisodeStart();
try {
Thread.sleep(delay);
} catch (InterruptedException e) {
e.printStackTrace();
}
- double sumOfRewards = 0;
+ sumOfRewards = 0;
StepResultEnvironment envResult = null;
while(envResult == null || !envResult.isDone()){
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
@@ -76,7 +76,6 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
dispatchStepEnd();
}
- dispatchEpisodeEnd(sumOfRewards);
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
@@ -115,4 +114,18 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
public int getEpisodesPerSecond(){
return episodePerSecond;
}
+ @Override
+ public void save(ObjectOutputStream oos) throws IOException {
+ super.save(oos);
+ oos.writeObject(returnSum);
+ oos.writeObject(returnCount);
+ }
+ @Override
+ public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+ super.load(ois);
+ returnSum = (Map<Pair<State, A>, Double>) ois.readObject();
+ returnCount = (Map<Pair<State, A>, Integer>) ois.readObject();
+ }
}
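
The reason returnSum and returnCount had to join the save/load chain: Monte Carlo estimates Q(s,a) as the running average of all returns observed for that pair, and these two maps are the average's state. A minimal sketch of that bookkeeping, with the Pair<State, A> key simplified to a String:

import java.util.HashMap;
import java.util.Map;

class ReturnAverager {
    private final Map<String, Double> returnSum = new HashMap<>();
    private final Map<String, Integer> returnCount = new HashMap<>();

    // Called once per first-visited state-action pair per episode with the return G.
    double update(String stateAction, double g) {
        returnSum.merge(stateAction, g, Double::sum);
        returnCount.merge(stateAction, 1, Integer::sum);
        // Q(s,a) = mean of all returns seen so far. If these maps are not
        // persisted, a loaded agent keeps its old Q-values but re-averages
        // from scratch, so the next episodes skew the estimates.
        return returnSum.get(stateAction) / returnCount.get(stateAction);
    }
}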

RLController.java

@@ -18,6 +18,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController<A extends Enum> implements ViewListener, LearningListener {
+ private final String folderPrefix = "learningStates" + File.separator;
private Environment<A> environment;
private DiscreteActionSpace<A> discreteActionSpace;
private Method method;
@@ -26,15 +27,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
private float epsilon = LearningConfig.DEFAULT_EPSILON;
private Learning<A> learning;
private LearningView learningView;
- private ExecutorService learningExecutor;
- private boolean currentlyLearning;
private boolean fastLearning;
private List<Double> latestRewardsHistory;
private int nrOfEpisodes;
private int prevDelay;
public RLController(Environment<A> env, Method method, A... actions){
- learningExecutor = Executors.newSingleThreadExecutor();
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
@@ -64,9 +62,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
private void initLearning(){
if(learning instanceof EpisodicLearning){
- learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
+ ((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{
- learningExecutor.submit(()->learning.learn());
+ learning.learn();
}
}
@@ -75,27 +73,25 @@
*************************************************/
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes){
- if(!currentlyLearning){
- if(learning instanceof EpisodicLearning){
- learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
- }else{
- throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
- }
+ if(learning instanceof EpisodicLearning){
+ ((EpisodicLearning) learning).learn(nrOfEpisodes);
+ }else{
+ throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
+ }
learningView.updateLearningInfoPanel();
}
@Override
public void onLoadState(String fileName) {
FileInputStream fis;
- ObjectInput in;
+ ObjectInputStream in;
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
- SaveState<A> saveState = (SaveState<A>) in.readObject();
- learning.setStateActionTable(saveState.getStateActionTable());
- if(learning instanceof EpisodicLearning){
- ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode());
- }
+ System.out.println("interrup" + Thread.currentThread().getId());
+ learning.interruptLearning();
+ learning.load(in);
+ SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
in.close();
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
@@ -107,15 +103,10 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
FileOutputStream fos;
ObjectOutputStream out;
try{
- fos = new FileOutputStream(fileName);
+ fos = new FileOutputStream(folderPrefix + fileName);
out = new ObjectOutputStream(fos);
- int currentEpisode;
- if(learning instanceof EpisodicLearning){
- currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode();
- }else{
- currentEpisode = 0;
- }
- out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode));
+ learning.interruptLearning();
+ learning.save(out);
out.close();
}catch (IOException e){
e.printStackTrace();
@@ -158,13 +149,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
*************************************************/
@Override
public void onLearningStart() {
- currentlyLearning = true;
}
@Override
public void onLearningEnd() {
- currentlyLearning = false;
SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
+ onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
}
@Override
@@ -192,7 +182,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
/*************************************************
- ** SETTER **
+ ** SETTERS **
*************************************************/
private void setEnvironment(Environment<A> environment){

DinoWorld.java

@@ -14,12 +14,20 @@ import java.awt.*;
public class DinoWorld implements Environment<DinoAction>, Visualizable {
private Dino dino;
private Obstacle currentObstacle;
+ private boolean randomObstacleSpeed;
+ private boolean randomObstacleDistance;
- public DinoWorld(){
+ public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){
+ this.randomObstacleSpeed = randomObstacleSpeed;
+ this.randomObstacleDistance = randomObstacleDistance;
dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
spawnNewObstacle();
}
+ public DinoWorld(){
+ this(false, false);
+ }
private boolean ranIntoObstacle(){
Obstacle o = currentObstacle;
Dino p = dino;
@@ -32,6 +40,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
return xAxis && yAxis;
}
private int getDistanceToObstacle(){
return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE;
}
@@ -57,8 +66,27 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, "");
}
private void spawnNewObstacle(){
- currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK);
+ int dx;
+ int xSpawn;
+ if(randomObstacleSpeed){
+ dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED);
+ }else{
+ dx = -Config.OBSTACLE_SPEED;
+ }
+ if(randomObstacleDistance){
+ // spawn randomly half to one-and-a-half screen widths beyond the right border
+ xSpawn = (int)((Math.random() + 0.5) * Config.FRAME_WIDTH) + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
+ }else{
+ // respawn immediately at the right screen border
+ xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
+ }
+ currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK);
}
private void spawnDino(){

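Reading the randomized spawn parameters above: Math.random() + 0.5 is uniform in [0.5, 1.5), so the obstacle speed varies between 0.5 and 1.5 times Config.OBSTACLE_SPEED, and the spawn point lands between half and one-and-a-half screen widths beyond the right border. A quick check of those ranges, using hypothetical constants rather than the repo's actual Config values:

public class SpawnRangeDemo {
    // Hypothetical stand-ins for Config.OBSTACLE_SPEED etc., not the repo's values.
    static final int OBSTACLE_SPEED = 10, FRAME_WIDTH = 800, OBSTACLE_SIZE = 40;

    public static void main(String[] args) {
        for (int i = 0; i < 5; i++) {
            int dx = -(int) ((Math.random() + 0.5) * OBSTACLE_SPEED);   // roughly -15..-5
            int xSpawn = (int) ((Math.random() + 0.5) * FRAME_WIDTH)
                    + FRAME_WIDTH + OBSTACLE_SIZE;                       // roughly 1240..2039
            System.out.printf("dx=%d, xSpawn=%d%n", dx, xSpawn);
        }
    }
}
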
JumpingDino.java

@@ -11,7 +11,7 @@ public class JumpingDino {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<>(
- new DinoWorld(),
+ new DinoWorld(true, true),
Method.MC_ONPOLICY_EGREEDY,
DinoAction.values());