enhance save/load feature and change thread handling
- Saving Monte Carlo state did not include returnSum and returnCount, so the state was wrong after loading. Learning, EpisodicLearning and MonteCarlo now all override custom save and load methods, each calling super() and then writing the fields that must be restored at runtime.
- Moved generic episodic behaviour from MonteCarlo to the abstract top-level class.
- Use an AtomicInteger for episodesToLearn.
- Moved learning-thread handling from the controller to the model; Learning now owns one extra learning thread (a single-threaded executor).
- Added a feature to use custom speed and distance for DinoWorld obstacles.
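The AtomicInteger makes "add more episodes, and start the worker only if it was idle" a single atomic step; with the previous plain int field, the UI thread and the learning thread could interleave between read and write and either lose requested episodes or start a second episode loop. A minimal sketch of the idiom, with names as in the diff below:

    public synchronized void learn(int nrOfEpisodes) {
        // getAndAdd returns the previous value atomically:
        // 0 means no episode loop is currently running.
        boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
        if (!isLearning) {
            startLearning(); // submits the episode loop to the single-thread executor
        }
    }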
parent 64355e0b93
commit 195722e98f
.gitignore
@@ -1,3 +1,6 @@
+learningStates/*
+!learningStates/.gitkeep
+
 .idea/refo.iml
 .idea/misc.xml
 .idea/modules.xml
EpisodicLearning.java
@@ -2,59 +2,53 @@ package core.algo;
 
 import core.DiscreteActionSpace;
 import core.Environment;
 import core.StepResult;
 import core.listener.LearningListener;
 import lombok.Getter;
 import lombok.Setter;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 
 public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
     @Setter
     protected int currentEpisode;
-    protected int episodesToLearn;
+    protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
     @Getter
     protected volatile int episodePerSecond;
     protected int episodeSumCurrentSecond;
-    private volatile boolean measureEpisodeBenchMark;
+    protected double sumOfRewards;
+    protected List<StepResult<A>> episode = new ArrayList<>();
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
         super(environment, actionSpace, discountFactor, delay);
+        initBenchMarking();
     }
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
         super(environment, actionSpace, discountFactor);
+        initBenchMarking();
     }
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         super(environment, actionSpace, delay);
+        initBenchMarking();
     }
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
         super(environment, actionSpace);
+        initBenchMarking();
     }
 
-    protected void dispatchEpisodeEnd(double recentSumOfRewards){
-        ++episodeSumCurrentSecond;
-        if(rewardHistory.size() > 10000){
-            rewardHistory.clear();
-        }
-        rewardHistory.add(recentSumOfRewards);
-        for(LearningListener l: learningListeners) {
-            l.onEpisodeEnd(rewardHistory);
-        }
-    }
-    protected abstract void nextEpisode();
-
-    protected void dispatchEpisodeStart(){
-        for(LearningListener l: learningListeners){
-            l.onEpisodeStart();
-        }
-    }
-
-    @Override
-    public void learn(){
-        learn(0);
-    }
-
-    public void learn(int nrOfEpisodes){
-        measureEpisodeBenchMark = true;
+    private void initBenchMarking(){
         new Thread(()->{
-            while(measureEpisodeBenchMark){
+            while (true){
                 episodePerSecond = episodeSumCurrentSecond;
                 episodeSumCurrentSecond = 0;
                 try {
@@ -64,24 +58,89 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
                 }
             }
         }).start();
-        episodesToLearn += nrOfEpisodes;
-        dispatchLearningStart();
-        for(int i=0; i < nrOfEpisodes; ++i){
-            nextEpisode();
-        }
-        dispatchLearningEnd();
-        measureEpisodeBenchMark = false;
     }
 
+    protected abstract void nextEpisode();
+
+    protected void dispatchEpisodeEnd(){
+        ++episodeSumCurrentSecond;
+        if(rewardHistory.size() > 10000){
+            rewardHistory.clear();
+        }
+        rewardHistory.add(sumOfRewards);
+        for(LearningListener l: learningListeners) {
+            l.onEpisodeEnd(rewardHistory);
+        }
+    }
+
+    protected void dispatchEpisodeStart(){
+        ++currentEpisode;
+        episodesToLearn.decrementAndGet();
+        for(LearningListener l: learningListeners){
+            l.onEpisodeStart();
+        }
+    }
+
     @Override
-    public int getCurrentEpisode(){
-        return currentEpisode;
+    public void learn(){
+        // TODO remove or learn with default episode number
     }
 
+    private void startLearning(){
+        learningExecutor.submit(()->{
+            dispatchLearningStart();
+            while(episodesToLearn.get() > 0){
+                dispatchEpisodeStart();
+                nextEpisode();
+                dispatchEpisodeEnd();
+            }
+            synchronized (this){
+                dispatchLearningEnd();
+                notifyAll();
+            }
+        });
+    }
+
+    /**
+     * Stops the episode loop by setting episodesToLearn to 0.
+     * The current episode cannot be interrupted, so the sleep delay
+     * is removed and the calling thread has to wait until the
+     * current episode is done. The delay is reset afterwards.
+     */
+    @Override
+    public synchronized void interruptLearning(){
+        episodesToLearn.set(0);
+        int prevDelay = delay;
+        delay = 0;
+        while(currentlyLearning) {
+            try {
+                wait();
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+        }
+        delay = prevDelay;
+    }
+
+    public synchronized void learn(int nrOfEpisodes){
+        boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
+        if(!isLearning)
+            startLearning();
+    }
+
     @Override
     public int getEpisodesToGo(){
-        return episodesToLearn - currentEpisode;
+        return episodesToLearn.get();
     }
+
+    @Override
+    public synchronized void save(ObjectOutputStream oos) throws IOException {
+        super.save(oos);
+        oos.writeInt(currentEpisode);
+    }
+
+    @Override
+    public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+        super.load(ois);
+        currentEpisode = ois.readInt();
+    }
 }
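From the caller's perspective the new API behaves as follows (an illustrative sketch, not code from this commit; the example method and its local names are hypothetical):

    // Illustrative only: repeated learn() calls accumulate episodes in the model.
    void drive(EpisodicLearning<DinoAction> learning) {
        learning.learn(100);          // counter was 0, so the episode loop is submitted
        learning.learn(50);           // loop already running, 50 episodes added atomically
        learning.interruptLearning(); // zeroes the counter, then blocks until the
                                      // running episode finishes and notifyAll() fires
    }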
Learning.java
@@ -9,10 +9,16 @@ import core.policy.Policy;
 import lombok.Getter;
 import lombok.Setter;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.io.Serializable;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 /**
  *
@@ -30,14 +36,18 @@ public abstract class Learning<A extends Enum>{
     @Setter
     protected int delay;
     protected List<Double> rewardHistory;
+    protected ExecutorService learningExecutor;
+    protected boolean currentlyLearning;
 
     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
         this.environment = environment;
         this.actionSpace = actionSpace;
         this.discountFactor = discountFactor;
         this.delay = delay;
+        currentlyLearning = false;
         learningListeners = new HashSet<>();
         rewardHistory = new CopyOnWriteArrayList<>();
+        learningExecutor = Executors.newSingleThreadExecutor();
     }
 
     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
@@ -52,7 +62,6 @@ public abstract class Learning<A extends Enum>{
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
     }
 
-
     public abstract void learn();
 
     public void addListener(LearningListener learningListener) {
@@ -66,15 +75,31 @@ public abstract class Learning<A extends Enum>{
     }
 
     protected void dispatchLearningStart() {
+        currentlyLearning = true;
         for (LearningListener l : learningListeners) {
             l.onLearningStart();
         }
     }
 
     protected void dispatchLearningEnd() {
+        currentlyLearning = false;
         for (LearningListener l : learningListeners) {
             l.onLearningEnd();
         }
     }
+
+    public synchronized void interruptLearning(){
+        //TODO: for non episodic learning
+    }
+
+    public void save(ObjectOutputStream oos) throws IOException {
+        oos.writeObject(rewardHistory);
+        oos.writeObject(stateActionTable);
+    }
+
+    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+        rewardHistory = (List<Double>) ois.readObject();
+        stateActionTable = (StateActionTable<A>) ois.readObject();
+    }
 }
MonteCarloOnPolicyEGreedy.java
@@ -5,6 +5,9 @@ import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
 import javafx.util.Pair;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.*;
 
 /**
@@ -44,19 +47,16 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
 
-
     @Override
     public void nextEpisode() {
-        ++currentEpisode;
-        List<StepResult<A>> episode = new ArrayList<>();
+        episode = new ArrayList<>();
         State state = environment.reset();
-        dispatchEpisodeStart();
         try {
             Thread.sleep(delay);
         } catch (InterruptedException e) {
             e.printStackTrace();
         }
-        double sumOfRewards = 0;
+        sumOfRewards = 0;
         StepResultEnvironment envResult = null;
         while(envResult == null || !envResult.isDone()){
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
@@ -76,7 +76,6 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
             dispatchStepEnd();
         }
 
-        dispatchEpisodeEnd(sumOfRewards);
         // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
         Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
 
@@ -115,4 +114,18 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
     public int getEpisodesPerSecond(){
         return episodePerSecond;
     }
+
+    @Override
+    public void save(ObjectOutputStream oos) throws IOException {
+        super.save(oos);
+        oos.writeObject(returnSum);
+        oos.writeObject(returnCount);
+    }
+
+    @Override
+    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+        super.load(ois);
+        returnSum = (Map<Pair<State, A>, Double>) ois.readObject();
+        returnCount = (Map<Pair<State, A>, Integer>) ois.readObject();
+    }
 }
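Because every override appends its own fields after calling super, the serialized stream has a fixed layout that the load() chain must mirror; summarized from the diffs above:

    // Order of fields written by MonteCarloOnPolicyEGreedy.save():
    //   Learning.save:                  rewardHistory, stateActionTable  (writeObject)
    //   EpisodicLearning.save:          currentEpisode                   (writeInt)
    //   MonteCarloOnPolicyEGreedy.save: returnSum, returnCount           (writeObject)
    // load() reads back in exactly this order through the same super() chain.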
RLController.java
@@ -18,6 +18,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
 public class RLController<A extends Enum> implements ViewListener, LearningListener {
+    private final String folderPrefix = "learningStates" + File.separator;
     private Environment<A> environment;
     private DiscreteActionSpace<A> discreteActionSpace;
     private Method method;
@@ -26,15 +27,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
     private float epsilon = LearningConfig.DEFAULT_EPSILON;
     private Learning<A> learning;
     private LearningView learningView;
-    private ExecutorService learningExecutor;
-    private boolean currentlyLearning;
     private boolean fastLearning;
     private List<Double> latestRewardsHistory;
     private int nrOfEpisodes;
     private int prevDelay;
 
     public RLController(Environment<A> env, Method method, A... actions){
-        learningExecutor = Executors.newSingleThreadExecutor();
         setEnvironment(env);
         setMethod(method);
         setAllowedActions(actions);
@@ -64,9 +62,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
 
     private void initLearning(){
         if(learning instanceof EpisodicLearning){
-            learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
+            ((EpisodicLearning) learning).learn(nrOfEpisodes);
         }else{
-            learningExecutor.submit(()->learning.learn());
+            learning.learn();
         }
     }
 
@@ -75,27 +73,25 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
      *************************************************/
     @Override
     public void onLearnMoreEpisodes(int nrOfEpisodes){
-        if(!currentlyLearning){
-            if(learning instanceof EpisodicLearning){
-                learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
-            }else{
-                throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
-            }
+        if(learning instanceof EpisodicLearning){
+            ((EpisodicLearning) learning).learn(nrOfEpisodes);
+        }else{
+            throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
         }
         learningView.updateLearningInfoPanel();
     }
 
     @Override
     public void onLoadState(String fileName) {
         FileInputStream fis;
-        ObjectInput in;
+        ObjectInputStream in;
         try {
             fis = new FileInputStream(fileName);
             in = new ObjectInputStream(fis);
-            SaveState<A> saveState = (SaveState<A>) in.readObject();
-            learning.setStateActionTable(saveState.getStateActionTable());
-            if(learning instanceof EpisodicLearning){
-                ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode());
-            }
+            System.out.println("interrup" + Thread.currentThread().getId());
+            learning.interruptLearning();
+            learning.load(in);
+            SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
             in.close();
         } catch (IOException | ClassNotFoundException e) {
             e.printStackTrace();
@@ -107,15 +103,10 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
         FileOutputStream fos;
         ObjectOutputStream out;
         try{
-            fos = new FileOutputStream(fileName);
+            fos = new FileOutputStream(folderPrefix + fileName);
             out = new ObjectOutputStream(fos);
-            int currentEpisode;
-            if(learning instanceof EpisodicLearning){
-                currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode();
-            }else{
-                currentEpisode = 0;
-            }
-            out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode));
+            learning.interruptLearning();
+            learning.save(out);
             out.close();
         }catch (IOException e){
             e.printStackTrace();
@@ -158,13 +149,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
      *************************************************/
     @Override
     public void onLearningStart() {
-        currentlyLearning = true;
     }
 
     @Override
     public void onLearningEnd() {
-        currentlyLearning = false;
         SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
+        onSaveState(method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
     }
 
     @Override
@@ -192,7 +182,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
 
 
     /*************************************************
-     **                  SETTER                     **
+     **                  SETTERS                    **
      *************************************************/
 
     private void setEnvironment(Environment<A> environment){
DinoWorld.java
@@ -14,12 +14,20 @@ import java.awt.*;
 public class DinoWorld implements Environment<DinoAction>, Visualizable {
     private Dino dino;
     private Obstacle currentObstacle;
+    private boolean randomObstacleSpeed;
+    private boolean randomObstacleDistance;
 
-    public DinoWorld(){
+    public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){
+        this.randomObstacleSpeed = randomObstacleSpeed;
+        this.randomObstacleDistance = randomObstacleDistance;
         dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
         spawnNewObstacle();
     }
 
+    public DinoWorld(){
+        this(false, false);
+    }
+
     private boolean ranIntoObstacle(){
         Obstacle o = currentObstacle;
         Dino p = dino;
@@ -32,6 +40,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
 
         return xAxis && yAxis;
     }
+
     private int getDistanceToObstacle(){
         return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE;
     }
@@ -57,8 +66,27 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
 
         return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, "");
     }
 
-
     private void spawnNewObstacle(){
-        currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK);
+        int dx;
+        int xSpawn;
+
+        if(randomObstacleSpeed){
+            dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED);
+        }else{
+            dx = -Config.OBSTACLE_SPEED;
+        }
+
+        if(randomObstacleDistance){
+            // randomly spawning further to the right, outside of the screen
+            xSpawn = (int)((Math.random() + 0.5) * Config.FRAME_WIDTH) + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
+        }else{
+            // instantly respawning just outside the right screen border
+            xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
+        }
+
+        currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK);
     }
 
     private void spawnDino(){
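For reference, the value ranges implied by the randomized spawn above (a worked note, assuming Math.random() yields values in [0, 1)):

    // randomObstacleSpeed:    dx in (-1.5 * OBSTACLE_SPEED, -0.5 * OBSTACLE_SPEED]
    // randomObstacleDistance: xSpawn in [1.5 * FRAME_WIDTH, 2.5 * FRAME_WIDTH) + OBSTACLE_SIZE,
    //                         i.e. 0.5 to 1.5 screen widths beyond the right border.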
JumpingDino.java
@@ -11,7 +11,7 @@ public class JumpingDino {
         RNG.setSeed(55);
 
         RLController<DinoAction> rl = new RLController<>(
-                new DinoWorld(),
+                new DinoWorld(true, true),
                 Method.MC_ONPOLICY_EGREEDY,
                 DinoAction.values());
 