Apply threading changes to master branch and clean up for the tagged version

- leaves out the testing and epsilon-testing code
This commit is contained in:
Jan Löwenstrom 2020-03-05 11:49:51 +01:00
parent 6613e23c7c
commit cffec63dc6
6 changed files with 40 additions and 77 deletions

View File

@ -50,7 +50,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void initBenchMarking(){ private void initBenchMarking(){
new Thread(()->{ new Thread(()->{
while (true){ while (currentlyLearning){
episodePerSecond = episodeSumCurrentSecond; episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0; episodeSumCurrentSecond = 0;
try { try {
@ -62,7 +62,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
}).start(); }).start();
} }
protected void dispatchEpisodeEnd(){ private void dispatchEpisodeEnd(){
++episodeSumCurrentSecond; ++episodeSumCurrentSecond;
if(rewardHistory.size() > 10000){ if(rewardHistory.size() > 10000){
rewardHistory.clear(); rewardHistory.clear();
@ -75,18 +75,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchEpisodeStart(){ protected void dispatchEpisodeStart(){
++currentEpisode; ++currentEpisode;
/*
2f 0.02 => 100
1.5f 0.02 => 75
1.4f 0.02 => fail
1.5f 0.1 => 16 !
*/
if(this.policy instanceof EpsilonGreedyPolicy){
float ep = 1.5f/(float)currentEpisode;
if(ep < 0.10) ep = 0;
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
System.out.println(ep);
}
episodesToLearn.decrementAndGet(); episodesToLearn.decrementAndGet();
for(LearningListener l: learningListeners){ for(LearningListener l: learningListeners){
l.onEpisodeStart(); l.onEpisodeStart();
@ -97,31 +85,24 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchStepEnd() { protected void dispatchStepEnd() {
super.dispatchStepEnd(); super.dispatchStepEnd();
timestamp++; timestamp++;
// TODO: more sophisticated way to check convergence
if(timestamp > 300000){
System.out.println("converged after: " + currentEpisode + " episode!");
interruptLearning();
}
}
@Override
public void learn(){
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
} }
private void startLearning(){ private void startLearning(){
learningExecutor.submit(()->{ dispatchLearningStart();
dispatchLearningStart(); System.out.println(episodesToLearn.get());
while(episodesToLearn.get() > 0){ while(episodesToLearn.get() > 0){
dispatchEpisodeStart(); dispatchEpisodeStart();
nextEpisode(); nextEpisode();
dispatchEpisodeEnd(); dispatchEpisodeEnd();
} }
synchronized (this){ synchronized (this){
dispatchLearningEnd(); dispatchLearningEnd();
notifyAll(); notifyAll();
} }
}); }
public void learnMoreEpisodes(int nrOfEpisodes){
episodesToLearn.addAndGet(nrOfEpisodes);
} }
/** /**
@ -146,8 +127,14 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
delay = prevDelay; delay = prevDelay;
} }
@Override
public void learn(){
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
}
public synchronized void learn(int nrOfEpisodes){ public synchronized void learn(int nrOfEpisodes){
boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0; boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
System.out.println(isLearning);
if(!isLearning) if(!isLearning)
startLearning(); startLearning();
} }

View File

@ -42,8 +42,7 @@ public abstract class Learning<A extends Enum>{
@Setter @Setter
protected int delay; protected int delay;
protected List<Double> rewardHistory; protected List<Double> rewardHistory;
protected ExecutorService learningExecutor; protected volatile boolean currentlyLearning;
protected boolean currentlyLearning;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) { public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
this.environment = environment; this.environment = environment;
@ -53,7 +52,6 @@ public abstract class Learning<A extends Enum>{
currentlyLearning = false; currentlyLearning = false;
learningListeners = new HashSet<>(); learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>(); rewardHistory = new CopyOnWriteArrayList<>();
learningExecutor = Executors.newSingleThreadExecutor();
} }
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) { public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
@ -89,8 +87,6 @@ public abstract class Learning<A extends Enum>{
protected void dispatchLearningEnd() { protected void dispatchLearningEnd() {
currentlyLearning = false; currentlyLearning = false;
System.out.println("Checksum: " + checkSum);
System.out.println("Reward Checksum: " + rewardCheckSum);
for (LearningListener l : learningListeners) { for (LearningListener l : learningListeners) {
l.onLearningEnd(); l.onLearningEnd();
} }

View File

@ -83,7 +83,7 @@ public class RLController<A extends Enum> implements LearningListener {
private void initLearning() { private void initLearning() {
if(learning instanceof EpisodicLearning) { if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes"); System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes); ((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
} else { } else {
learning.learn(); learning.learn();
} }
@ -95,7 +95,13 @@ public class RLController<A extends Enum> implements LearningListener {
protected void learnMoreEpisodes(int nrOfEpisodes) { protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) { if(learning instanceof EpisodicLearning) {
((EpisodicLearning) learning).learn(nrOfEpisodes); if(learning.isCurrentlyLearning()){
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
new Thread(() -> {
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
}).start();
}
} else { } else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!"); throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
} }
@ -169,8 +175,8 @@ public class RLController<A extends Enum> implements LearningListener {
public void onEpisodeEnd(List<Double> rewardHistory) { public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory; latestRewardsHistory = rewardHistory;
if(printNextEpisode) { if(printNextEpisode) {
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1)); System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond()); System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
printNextEpisode = false; printNextEpisode = false;
} }
} }

View File

@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel {
viewListener.onFastLearnChange(fastLearning); viewListener.onFastLearnChange(fastLearning);
}); });
smoothGraphCheckbox = new JCheckBox("Smoothen Graph"); smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
smoothGraphCheckbox.setSelected(false); smoothGraphCheckbox.setSelected(true);
last100Checkbox = new JCheckBox("Only show last 100 Rewards"); last100Checkbox = new JCheckBox("Only show last 100 Rewards");
last100Checkbox.setSelected(true); last100Checkbox.setSelected(false);
drawEnvironmentCheckbox = new JCheckBox("Update Environment"); drawEnvironmentCheckbox = new JCheckBox("Update Environment");
drawEnvironmentCheckbox.setSelected(true); drawEnvironmentCheckbox.setSelected(true);

View File

@ -1,27 +0,0 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
/**
 * Sampling driver for the jumping-dino example.
 *
 * <p>Starts ten consecutive learning runs. Before each run the shared
 * {@code RNG} is re-seeded with the same value (55) and a fresh
 * {@code RLController} is configured with identical parameters.
 */
public class DinoSampling {

    /**
     * Entry point: performs the ten sampling runs one after another.
     *
     * @param args command-line arguments (not read)
     */
    public static void main(String[] args) {
        int run = 0;
        while (run < 10) {
            sampleOnce();
            run++;
        }
    }

    /**
     * Re-seeds the RNG, builds one controller with the fixed sampling
     * configuration and starts it.
     */
    private static void sampleOnce() {
        // Same seed for every run, set before the environment is constructed.
        RNG.setSeed(55);

        final RLController<DinoAction> controller = new RLController<>(
                new DinoWorld(false, false),
                Method.MC_CONTROL_FIRST_VISIT,
                DinoAction.values());

        // Setter order is kept as in the original configuration sequence.
        controller.setDelay(0);
        controller.setDiscountFactor(1f);
        controller.setEpsilon(0.15f);
        controller.setLearningRate(1f);
        controller.setNrOfEpisodes(400);

        controller.start();
    }
}

View File

@ -3,6 +3,7 @@ package example;
import core.RNG; import core.RNG;
import core.algo.Method; import core.algo.Method;
import core.controller.RLController; import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld; import evironment.jumpingDino.DinoWorld;
@ -10,16 +11,16 @@ public class JumpingDino {
public static void main(String[] args) { public static void main(String[] args) {
RNG.setSeed(55); RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<>( RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorld(false, false), new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT, Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values()); DinoAction.values());
rl.setDelay(0); rl.setDelay(100);
rl.setDiscountFactor(1f); rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f); rl.setEpsilon(0.15f);
rl.setLearningRate(1f); rl.setLearningRate(1f);
rl.setNrOfEpisodes(400); rl.setNrOfEpisodes(10000);
rl.start(); rl.start();
} }
} }