apply threading changes to master branch and clean up for tag version

- no testing or epsilon testing stuff
This commit is contained in:
Jan Löwenstrom 2020-03-05 11:49:51 +01:00
parent 6613e23c7c
commit cffec63dc6
6 changed files with 40 additions and 77 deletions

View File

@ -50,7 +50,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void initBenchMarking(){
new Thread(()->{
while (true){
while (currentlyLearning){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@ -62,7 +62,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
}).start();
}
protected void dispatchEpisodeEnd(){
private void dispatchEpisodeEnd(){
++episodeSumCurrentSecond;
if(rewardHistory.size() > 10000){
rewardHistory.clear();
@ -75,18 +75,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchEpisodeStart(){
++currentEpisode;
/*
2f 0.02 => 100
1.5f 0.02 => 75
1.4f 0.02 => fail
1.5f 0.1 => 16 !
*/
if(this.policy instanceof EpsilonGreedyPolicy){
float ep = 1.5f/(float)currentEpisode;
if(ep < 0.10) ep = 0;
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
System.out.println(ep);
}
episodesToLearn.decrementAndGet();
for(LearningListener l: learningListeners){
l.onEpisodeStart();
@ -97,31 +85,24 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
// TODO: more sophisticated way to check convergence
if(timestamp > 300000){
System.out.println("converged after: " + currentEpisode + " episode!");
interruptLearning();
}
}
@Override
public void learn(){
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
}
private void startLearning(){
learningExecutor.submit(()->{
dispatchLearningStart();
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
nextEpisode();
dispatchEpisodeEnd();
}
synchronized (this){
dispatchLearningEnd();
notifyAll();
}
});
dispatchLearningStart();
System.out.println(episodesToLearn.get());
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
nextEpisode();
dispatchEpisodeEnd();
}
synchronized (this){
dispatchLearningEnd();
notifyAll();
}
}
public void learnMoreEpisodes(int nrOfEpisodes){
episodesToLearn.addAndGet(nrOfEpisodes);
}
/**
@ -146,8 +127,14 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
delay = prevDelay;
}
@Override
public void learn(){
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
}
public synchronized void learn(int nrOfEpisodes){
boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
System.out.println(isLearning);
if(!isLearning)
startLearning();
}

View File

@ -42,8 +42,7 @@ public abstract class Learning<A extends Enum>{
@Setter
protected int delay;
protected List<Double> rewardHistory;
protected ExecutorService learningExecutor;
protected boolean currentlyLearning;
protected volatile boolean currentlyLearning;
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
this.environment = environment;
@ -53,7 +52,6 @@ public abstract class Learning<A extends Enum>{
currentlyLearning = false;
learningListeners = new HashSet<>();
rewardHistory = new CopyOnWriteArrayList<>();
learningExecutor = Executors.newSingleThreadExecutor();
}
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
@ -89,8 +87,6 @@ public abstract class Learning<A extends Enum>{
protected void dispatchLearningEnd() {
currentlyLearning = false;
System.out.println("Checksum: " + checkSum);
System.out.println("Reward Checksum: " + rewardCheckSum);
for (LearningListener l : learningListeners) {
l.onLearningEnd();
}

View File

@ -83,7 +83,7 @@ public class RLController<A extends Enum> implements LearningListener {
private void initLearning() {
if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes);
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
} else {
learning.learn();
}
@ -95,7 +95,13 @@ public class RLController<A extends Enum> implements LearningListener {
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
((EpisodicLearning) learning).learn(nrOfEpisodes);
if(learning.isCurrentlyLearning()){
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
new Thread(() -> {
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
}).start();
}
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
@ -169,8 +175,8 @@ public class RLController<A extends Enum> implements LearningListener {
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode) {
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
printNextEpisode = false;
}
}

View File

@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel {
viewListener.onFastLearnChange(fastLearning);
});
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
smoothGraphCheckbox.setSelected(false);
smoothGraphCheckbox.setSelected(true);
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
last100Checkbox.setSelected(true);
last100Checkbox.setSelected(false);
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
drawEnvironmentCheckbox.setSelected(true);

View File

@ -1,27 +0,0 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
public class DinoSampling {
public static void main(String[] args) {
for (int i = 0; i < 10 ; i++) {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<>(
new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(400);
rl.start();
}
}
}

View File

@ -3,6 +3,7 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
@ -10,16 +11,16 @@ public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<>(
RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorld(false, false),
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(0);
rl.setDelay(100);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(400);
rl.setNrOfEpisodes(10000);
rl.start();
}
}