enhance save/load feature and change thread handling

- Saving a Monte Carlo learner did not include returnSum and returnCount, so its state was wrong after loading. Learning, EpisodicLearning and MonteCarlo now each override save and load, call super() and append the fields that have to be restored at runtime.
- moved generic episodic behaviour from MonteCarlo into the abstract top-level class
- episodesToLearn is now an AtomicInteger
- moved learning-thread handling from the controller into the model; Learning owns a dedicated learning thread
- added a feature to use custom speed and distance for the DinoWorld obstacles
parent 64355e0b93
commit 195722e98f
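
Before the per-file diffs, a rough sketch of the caller-side contract these changes establish: the model now owns the learning thread, so a controller only queues episodes, interrupts, and hands the learner a stream to persist itself. The snippet below is illustrative only; the wrapper class and file name are made up and not part of the commit.

    import core.algo.EpisodicLearning;

    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.ObjectOutputStream;

    // Hedged sketch of how a caller drives the learner after this commit.
    class CallerSketch {
        static void drive(EpisodicLearning<?> learning) throws IOException {
            learning.learn(1000);          // queues 1000 episodes and starts the model-owned loop
            learning.learn(500);           // called while running: only increments the episode counter

            learning.interruptLearning();  // blocks until the episode in progress finishes

            // example path only; the controller normally picks the name
            try (ObjectOutputStream oos =
                         new ObjectOutputStream(new FileOutputStream("learningStates/example.state"))) {
                learning.save(oos);        // Learning, then EpisodicLearning, then the subclass write their fields
            }
        }
    }
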
.gitignore
@@ -1,3 +1,6 @@
+learningStates/*
+!learningStates/.gitkeep
+
 .idea/refo.iml
 .idea/misc.xml
 .idea/modules.xml

EpisodicLearning.java
@@ -2,59 +2,53 @@ package core.algo;
 
 import core.DiscreteActionSpace;
 import core.Environment;
+import core.StepResult;
 import core.listener.LearningListener;
+import lombok.Getter;
 import lombok.Setter;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
 public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
     @Setter
     protected int currentEpisode;
-    protected int episodesToLearn;
+    protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
+    @Getter
     protected volatile int episodePerSecond;
     protected int episodeSumCurrentSecond;
-    private volatile boolean measureEpisodeBenchMark;
+    protected double sumOfRewards;
+    protected List<StepResult<A>> episode = new ArrayList<>();
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
         super(environment, actionSpace, discountFactor, delay);
+        initBenchMarking();
     }
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
         super(environment, actionSpace, discountFactor);
+        initBenchMarking();
     }
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         super(environment, actionSpace, delay);
+        initBenchMarking();
     }
 
     public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace) {
         super(environment, actionSpace);
+        initBenchMarking();
     }
 
-    protected void dispatchEpisodeEnd(double recentSumOfRewards){
-        ++episodeSumCurrentSecond;
-        if(rewardHistory.size() > 10000){
-            rewardHistory.clear();
-        }
-        rewardHistory.add(recentSumOfRewards);
-        for(LearningListener l: learningListeners) {
-            l.onEpisodeEnd(rewardHistory);
-        }
-    }
+    protected abstract void nextEpisode();
 
-    protected void dispatchEpisodeStart(){
-        for(LearningListener l: learningListeners){
-            l.onEpisodeStart();
-        }
-    }
-
-    @Override
-    public void learn(){
-        learn(0);
-    }
-
-    public void learn(int nrOfEpisodes){
-        measureEpisodeBenchMark = true;
+    private void initBenchMarking(){
         new Thread(()->{
-            while(measureEpisodeBenchMark){
+            while (true){
                 episodePerSecond = episodeSumCurrentSecond;
                 episodeSumCurrentSecond = 0;
                 try {
@@ -64,24 +58,89 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
                 }
             }
         }).start();
-        episodesToLearn += nrOfEpisodes;
-        dispatchLearningStart();
-        for(int i=0; i < nrOfEpisodes; ++i){
-            nextEpisode();
-        }
-        dispatchLearningEnd();
-        measureEpisodeBenchMark = false;
     }
 
-    protected abstract void nextEpisode();
+    protected void dispatchEpisodeEnd(){
+        ++episodeSumCurrentSecond;
+        if(rewardHistory.size() > 10000){
+            rewardHistory.clear();
+        }
+        rewardHistory.add(sumOfRewards);
+        for(LearningListener l: learningListeners) {
+            l.onEpisodeEnd(rewardHistory);
+        }
+    }
+
+    protected void dispatchEpisodeStart(){
+        ++currentEpisode;
+        episodesToLearn.decrementAndGet();
+        for(LearningListener l: learningListeners){
+            l.onEpisodeStart();
+        }
+    }
 
     @Override
-    public int getCurrentEpisode(){
-        return currentEpisode;
+    public void learn(){
+        // TODO remove or learn with default episode number
+    }
+
+    private void startLearning(){
+        learningExecutor.submit(()->{
+            dispatchLearningStart();
+            while(episodesToLearn.get() > 0){
+                dispatchEpisodeStart();
+                nextEpisode();
+                dispatchEpisodeEnd();
+            }
+            synchronized (this){
+                dispatchLearningEnd();
+                notifyAll();
+            }
+        });
+    }
+
+    /**
+     * Stopping the while loop by setting episodesToLearn to 0.
+     * The current episode can not be interrupted, so the sleep delay
+     * is removed and the calling thread has to wait until the
+     * current episode is done.
+     * Resetting the delay afterwards.
+     */
+    @Override
+    public synchronized void interruptLearning(){
+        episodesToLearn.set(0);
+        int prevDelay = delay;
+        delay = 0;
+        while(currentlyLearning) {
+            try {
+                wait();
+            } catch (InterruptedException e) {
+                e.printStackTrace();
+            }
+        }
+        delay = prevDelay;
+    }
+
+    public synchronized void learn(int nrOfEpisodes){
+        boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
+        if(!isLearning)
+            startLearning();
     }
 
     @Override
     public int getEpisodesToGo(){
-        return episodesToLearn - currentEpisode;
+        return episodesToLearn.get();
+    }
+
+    @Override
+    public synchronized void save(ObjectOutputStream oos) throws IOException {
+        super.save(oos);
+        oos.writeInt(currentEpisode);
+    }
+
+    @Override
+    public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+        super.load(ois);
+        currentEpisode = ois.readInt();
     }
 }

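The heart of the new thread handling is the pairing of the model-owned single-thread executor with the atomic episode counter: learn(n) adds n to episodesToLearn and only submits the loop when the previous value was zero, so repeated calls from the UI can never start a second loop, and interruptLearning() simply drains the counter and waits for the running episode. A minimal, self-contained sketch of that counting scheme (simplified names, not the project's code):

    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.atomic.AtomicInteger;

    class EpisodeLoopSketch {
        private final AtomicInteger episodesToGo = new AtomicInteger(0);
        private final ExecutorService executor = Executors.newSingleThreadExecutor();

        synchronized void learn(int nrOfEpisodes) {
            // getAndAdd returns the previous value: 0 means no loop is running yet
            boolean alreadyRunning = episodesToGo.getAndAdd(nrOfEpisodes) != 0;
            if (!alreadyRunning) {
                executor.submit(this::runLoop);
            }
        }

        private void runLoop() {
            while (episodesToGo.get() > 0) {
                episodesToGo.decrementAndGet();
                // ... run one episode here ...
            }
            synchronized (this) {
                notifyAll(); // wakes a thread blocked in an interrupt-style wait()
            }
        }
    }
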
Learning.java
@@ -9,10 +9,16 @@ import core.policy.Policy;
 import lombok.Getter;
 import lombok.Setter;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
 import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 
 /**
  *
@@ -30,14 +36,18 @@ public abstract class Learning<A extends Enum>{
     @Setter
     protected int delay;
     protected List<Double> rewardHistory;
+    protected ExecutorService learningExecutor;
+    protected boolean currentlyLearning;
 
     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
         this.environment = environment;
         this.actionSpace = actionSpace;
         this.discountFactor = discountFactor;
         this.delay = delay;
+        currentlyLearning = false;
         learningListeners = new HashSet<>();
         rewardHistory = new CopyOnWriteArrayList<>();
+        learningExecutor = Executors.newSingleThreadExecutor();
     }
 
     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
@@ -52,7 +62,6 @@ public abstract class Learning<A extends Enum>{
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
     }
 
-
     public abstract void learn();
 
     public void addListener(LearningListener learningListener) {
@@ -66,15 +75,31 @@ public abstract class Learning<A extends Enum>{
     }
 
     protected void dispatchLearningStart() {
+        currentlyLearning = true;
         for (LearningListener l : learningListeners) {
             l.onLearningStart();
         }
     }
 
     protected void dispatchLearningEnd() {
+        currentlyLearning = false;
         for (LearningListener l : learningListeners) {
             l.onLearningEnd();
         }
     }
 
+    public synchronized void interruptLearning(){
+        //TODO: for non episodic learning
+    }
+
+
+    public void save(ObjectOutputStream oos) throws IOException {
+        oos.writeObject(rewardHistory);
+        oos.writeObject(stateActionTable);
+    }
+
+    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+        rewardHistory = (List<Double>) ois.readObject();
+        stateActionTable = (StateActionTable<A>) ois.readObject();
+    }
 }

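Learning now carries the base half of the serialization contract: it writes rewardHistory and the state-action table, and every subclass is expected to call super first and then append its own runtime fields, reading them back in exactly the same order on load. A standalone sketch of that contract (hypothetical classes, not part of the commit):

    import java.io.IOException;
    import java.io.ObjectInputStream;
    import java.io.ObjectOutputStream;
    import java.util.ArrayList;
    import java.util.List;

    class BaseState {
        protected List<Double> rewardHistory = new ArrayList<>();

        public void save(ObjectOutputStream oos) throws IOException {
            oos.writeObject(rewardHistory);
        }

        @SuppressWarnings("unchecked")
        public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
            rewardHistory = (List<Double>) ois.readObject();
        }
    }

    class DerivedState extends BaseState {
        protected int currentEpisode;

        @Override
        public void save(ObjectOutputStream oos) throws IOException {
            super.save(oos);               // parent fields first
            oos.writeInt(currentEpisode);  // then the subclass's own fields
        }

        @Override
        public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
            super.load(ois);
            currentEpisode = ois.readInt();
        }
    }
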
MonteCarloOnPolicyEGreedy.java
@@ -5,6 +5,9 @@ import core.algo.EpisodicLearning;
 import core.policy.EpsilonGreedyPolicy;
 import javafx.util.Pair;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
 import java.util.*;
 
 /**
@@ -44,19 +47,16 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }
 
-
     @Override
     public void nextEpisode() {
-        ++currentEpisode;
-        List<StepResult<A>> episode = new ArrayList<>();
+        episode = new ArrayList<>();
         State state = environment.reset();
-        dispatchEpisodeStart();
         try {
             Thread.sleep(delay);
         } catch (InterruptedException e) {
             e.printStackTrace();
         }
-        double sumOfRewards = 0;
+        sumOfRewards = 0;
         StepResultEnvironment envResult = null;
         while(envResult == null || !envResult.isDone()){
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
@@ -76,7 +76,6 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
             dispatchStepEnd();
         }
 
-        dispatchEpisodeEnd(sumOfRewards);
         // System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
         Set<Pair<State, A>> stateActionPairs = new LinkedHashSet<>();
 
@@ -115,4 +114,18 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends EpisodicLearning<A> {
     public int getEpisodesPerSecond(){
         return episodePerSecond;
     }
+
+    @Override
+    public void save(ObjectOutputStream oos) throws IOException {
+        super.save(oos);
+        oos.writeObject(returnSum);
+        oos.writeObject(returnCount);
+    }
+
+    @Override
+    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
+        super.load(ois);
+        returnSum = (Map<Pair<State, A>, Double>) ois.readObject();
+        returnCount = (Map<Pair<State, A>, Integer>) ois.readObject();
+    }
 }

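This is the fix the commit message leads with: returnSum and returnCount are part of the Monte Carlo learner's state because the action-value estimate is the running mean of all returns observed for a state-action pair, so dropping the maps on load would restart every average's weight at zero. A simplified sketch of that bookkeeping (String keys instead of Pair<State, A>, assumed names, not the project's code):

    import java.util.HashMap;
    import java.util.Map;

    class ReturnAverages {
        private final Map<String, Double> returnSum = new HashMap<>();
        private final Map<String, Integer> returnCount = new HashMap<>();

        // Accumulates one observed return and yields the running mean,
        // i.e. the value that ends up in the state-action table.
        double recordReturn(String stateActionKey, double episodeReturn) {
            double sum = returnSum.merge(stateActionKey, episodeReturn, Double::sum);
            int count = returnCount.merge(stateActionKey, 1, Integer::sum);
            return sum / count;
        }
    }
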
RLController.java
@@ -18,6 +18,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 
 public class RLController<A extends Enum> implements ViewListener, LearningListener {
+    private final String folderPrefix = "learningStates" + File.separator;
     private Environment<A> environment;
     private DiscreteActionSpace<A> discreteActionSpace;
     private Method method;
@@ -26,15 +27,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
     private float epsilon = LearningConfig.DEFAULT_EPSILON;
     private Learning<A> learning;
     private LearningView learningView;
-    private ExecutorService learningExecutor;
-    private boolean currentlyLearning;
     private boolean fastLearning;
     private List<Double> latestRewardsHistory;
     private int nrOfEpisodes;
     private int prevDelay;
 
     public RLController(Environment<A> env, Method method, A... actions){
-        learningExecutor = Executors.newSingleThreadExecutor();
         setEnvironment(env);
         setMethod(method);
         setAllowedActions(actions);
@@ -64,9 +62,9 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
 
     private void initLearning(){
         if(learning instanceof EpisodicLearning){
-            learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
+            ((EpisodicLearning) learning).learn(nrOfEpisodes);
         }else{
-            learningExecutor.submit(()->learning.learn());
+            learning.learn();
         }
     }
 
@@ -75,27 +73,25 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
     *************************************************/
     @Override
     public void onLearnMoreEpisodes(int nrOfEpisodes){
-        if(!currentlyLearning){
-            if(learning instanceof EpisodicLearning){
-                learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
-            }else{
-                throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
-            }
+        if(learning instanceof EpisodicLearning){
+            ((EpisodicLearning) learning).learn(nrOfEpisodes);
+        }else{
+            throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
         }
+        learningView.updateLearningInfoPanel();
     }
 
     @Override
     public void onLoadState(String fileName) {
         FileInputStream fis;
-        ObjectInput in;
+        ObjectInputStream in;
         try {
             fis = new FileInputStream(fileName);
             in = new ObjectInputStream(fis);
-            SaveState<A> saveState = (SaveState<A>) in.readObject();
-            learning.setStateActionTable(saveState.getStateActionTable());
-            if(learning instanceof EpisodicLearning){
-                ((EpisodicLearning) learning).setCurrentEpisode(saveState.getCurrentEpisode());
-            }
+            System.out.println("interrup" + Thread.currentThread().getId());
+            learning.interruptLearning();
+            learning.load(in);
+            SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
             in.close();
         } catch (IOException | ClassNotFoundException e) {
             e.printStackTrace();
@@ -107,15 +103,10 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
         FileOutputStream fos;
         ObjectOutputStream out;
         try{
-            fos = new FileOutputStream(fileName);
+            fos = new FileOutputStream(folderPrefix + fileName);
             out = new ObjectOutputStream(fos);
-            int currentEpisode;
-            if(learning instanceof EpisodicLearning){
-                currentEpisode = ((EpisodicLearning) learning).getCurrentEpisode();
-            }else{
-                currentEpisode = 0;
-            }
-            out.writeObject(new SaveState<>(learning.getStateActionTable(), currentEpisode));
+            learning.interruptLearning();
+            learning.save(out);
             out.close();
         }catch (IOException e){
             e.printStackTrace();
@@ -158,13 +149,12 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
     *************************************************/
     @Override
     public void onLearningStart() {
-        currentlyLearning = true;
     }
 
     @Override
     public void onLearningEnd() {
-        currentlyLearning = false;
         SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
+        onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
     }
 
     @Override
@@ -192,7 +182,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListener {
 
 
     /*************************************************
-    **                  SETTER                      **
+    **                  SETTERS                     **
     *************************************************/
 
     private void setEnvironment(Environment<A> environment){

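With persistence moved into the model, the controller's remaining job is picking a file name: onLearningEnd() now auto-saves through onSaveState(...), which prefixes the generated name with the learningStates folder. A small illustration of the resulting path (example values only, not part of the commit):

    import java.io.File;

    class SaveNameExample {
        // method name + unix seconds, plus "e " and the episode count for episodic learners
        static String buildName(String methodName, long unixSeconds, int currentEpisode) {
            return methodName + unixSeconds + "e " + currentEpisode;
        }

        public static void main(String[] args) {
            String name = buildName("MC_ONPOLICY_EGREEDY", System.currentTimeMillis() / 1000, 2500);
            String path = "learningStates" + File.separator + name; // folderPrefix + fileName
            System.out.println(path);
        }
    }
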
DinoWorld.java
@@ -14,12 +14,20 @@ import java.awt.*;
 public class DinoWorld implements Environment<DinoAction>, Visualizable {
     private Dino dino;
     private Obstacle currentObstacle;
+    private boolean randomObstacleSpeed;
+    private boolean randomObstacleDistance;
 
-    public DinoWorld(){
+    public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){
+        this.randomObstacleSpeed = randomObstacleSpeed;
+        this.randomObstacleDistance = randomObstacleDistance;
         dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
         spawnNewObstacle();
     }
 
+    public DinoWorld(){
+        this(false, false);
+    }
+
     private boolean ranIntoObstacle(){
         Obstacle o = currentObstacle;
         Dino p = dino;
@@ -32,6 +40,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
 
         return xAxis && yAxis;
     }
+
     private int getDistanceToObstacle(){
         return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE;
     }
@@ -57,8 +66,27 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
 
         return new StepResultEnvironment(new DinoState(getDistanceToObstacle()), reward, done, "");
     }
 
 
     private void spawnNewObstacle(){
-        currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, Config.FRAME_WIDTH + Config.OBSTACLE_SIZE, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, -Config.OBSTACLE_SPEED, 0, Color.BLACK);
+        int dx;
+        int xSpawn;
+
+        if(randomObstacleSpeed){
+            dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED);
+        }else{
+            dx = -Config.OBSTACLE_SPEED;
+        }
+
+        if(randomObstacleDistance){
+            // randomly spawning more right outside of the screen
+            xSpawn = (int)(Math.random() + 0.5 * Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE);
+        }else{
+            // instantly respawning on the left screen border
+            xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
+        }
+
+        currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK);
     }
 
     private void spawnDino(){

JumpingDino.java
@@ -11,7 +11,7 @@ public class JumpingDino {
         RNG.setSeed(55);
 
         RLController<DinoAction> rl = new RLController<>(
-                new DinoWorld(),
+                new DinoWorld(true, true),
                 Method.MC_ONPOLICY_EGREEDY,
                 DinoAction.values());
 