apply threading changes to master branch and clean up for tag version
- no testing or epsilon testing stuff
This commit is contained in:
parent
6613e23c7c
commit
cffec63dc6
|
@ -50,7 +50,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
|||
|
||||
private void initBenchMarking(){
|
||||
new Thread(()->{
|
||||
while (true){
|
||||
while (currentlyLearning){
|
||||
episodePerSecond = episodeSumCurrentSecond;
|
||||
episodeSumCurrentSecond = 0;
|
||||
try {
|
||||
|
@ -62,7 +62,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
|||
}).start();
|
||||
}
|
||||
|
||||
protected void dispatchEpisodeEnd(){
|
||||
private void dispatchEpisodeEnd(){
|
||||
++episodeSumCurrentSecond;
|
||||
if(rewardHistory.size() > 10000){
|
||||
rewardHistory.clear();
|
||||
|
@ -75,18 +75,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
|||
|
||||
protected void dispatchEpisodeStart(){
|
||||
++currentEpisode;
|
||||
/*
|
||||
2f 0.02 => 100
|
||||
1.5f 0.02 => 75
|
||||
1.4f 0.02 => fail
|
||||
1.5f 0.1 => 16 !
|
||||
*/
|
||||
if(this.policy instanceof EpsilonGreedyPolicy){
|
||||
float ep = 1.5f/(float)currentEpisode;
|
||||
if(ep < 0.10) ep = 0;
|
||||
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
|
||||
System.out.println(ep);
|
||||
}
|
||||
episodesToLearn.decrementAndGet();
|
||||
for(LearningListener l: learningListeners){
|
||||
l.onEpisodeStart();
|
||||
|
@ -97,31 +85,24 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
|||
protected void dispatchStepEnd() {
|
||||
super.dispatchStepEnd();
|
||||
timestamp++;
|
||||
// TODO: more sophisticated way to check convergence
|
||||
if(timestamp > 300000){
|
||||
System.out.println("converged after: " + currentEpisode + " episode!");
|
||||
interruptLearning();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void learn(){
|
||||
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
|
||||
}
|
||||
|
||||
private void startLearning(){
|
||||
learningExecutor.submit(()->{
|
||||
dispatchLearningStart();
|
||||
while(episodesToLearn.get() > 0){
|
||||
dispatchEpisodeStart();
|
||||
nextEpisode();
|
||||
dispatchEpisodeEnd();
|
||||
}
|
||||
synchronized (this){
|
||||
dispatchLearningEnd();
|
||||
notifyAll();
|
||||
}
|
||||
});
|
||||
dispatchLearningStart();
|
||||
System.out.println(episodesToLearn.get());
|
||||
while(episodesToLearn.get() > 0){
|
||||
dispatchEpisodeStart();
|
||||
nextEpisode();
|
||||
dispatchEpisodeEnd();
|
||||
}
|
||||
synchronized (this){
|
||||
dispatchLearningEnd();
|
||||
notifyAll();
|
||||
}
|
||||
}
|
||||
|
||||
public void learnMoreEpisodes(int nrOfEpisodes){
|
||||
episodesToLearn.addAndGet(nrOfEpisodes);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -146,8 +127,14 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
|||
delay = prevDelay;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void learn(){
|
||||
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
|
||||
}
|
||||
|
||||
public synchronized void learn(int nrOfEpisodes){
|
||||
boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
|
||||
System.out.println(isLearning);
|
||||
if(!isLearning)
|
||||
startLearning();
|
||||
}
|
||||
|
|
|
@ -42,8 +42,7 @@ public abstract class Learning<A extends Enum>{
|
|||
@Setter
|
||||
protected int delay;
|
||||
protected List<Double> rewardHistory;
|
||||
protected ExecutorService learningExecutor;
|
||||
protected boolean currentlyLearning;
|
||||
protected volatile boolean currentlyLearning;
|
||||
|
||||
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
|
||||
this.environment = environment;
|
||||
|
@ -53,7 +52,6 @@ public abstract class Learning<A extends Enum>{
|
|||
currentlyLearning = false;
|
||||
learningListeners = new HashSet<>();
|
||||
rewardHistory = new CopyOnWriteArrayList<>();
|
||||
learningExecutor = Executors.newSingleThreadExecutor();
|
||||
}
|
||||
|
||||
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
|
||||
|
@ -89,8 +87,6 @@ public abstract class Learning<A extends Enum>{
|
|||
|
||||
protected void dispatchLearningEnd() {
|
||||
currentlyLearning = false;
|
||||
System.out.println("Checksum: " + checkSum);
|
||||
System.out.println("Reward Checksum: " + rewardCheckSum);
|
||||
for (LearningListener l : learningListeners) {
|
||||
l.onLearningEnd();
|
||||
}
|
||||
|
|
|
@ -83,7 +83,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
|||
private void initLearning() {
|
||||
if(learning instanceof EpisodicLearning) {
|
||||
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
|
||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
|
||||
} else {
|
||||
learning.learn();
|
||||
}
|
||||
|
@ -95,7 +95,13 @@ public class RLController<A extends Enum> implements LearningListener {
|
|||
|
||||
protected void learnMoreEpisodes(int nrOfEpisodes) {
|
||||
if(learning instanceof EpisodicLearning) {
|
||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
||||
if(learning.isCurrentlyLearning()){
|
||||
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
|
||||
}else{
|
||||
new Thread(() -> {
|
||||
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
|
||||
}).start();
|
||||
}
|
||||
} else {
|
||||
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||
}
|
||||
|
@ -169,8 +175,8 @@ public class RLController<A extends Enum> implements LearningListener {
|
|||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||
latestRewardsHistory = rewardHistory;
|
||||
if(printNextEpisode) {
|
||||
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
||||
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
|
||||
System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
||||
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
|
||||
printNextEpisode = false;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel {
|
|||
viewListener.onFastLearnChange(fastLearning);
|
||||
});
|
||||
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
|
||||
smoothGraphCheckbox.setSelected(false);
|
||||
smoothGraphCheckbox.setSelected(true);
|
||||
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
|
||||
last100Checkbox.setSelected(true);
|
||||
last100Checkbox.setSelected(false);
|
||||
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
|
||||
drawEnvironmentCheckbox.setSelected(true);
|
||||
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
package example;
|
||||
|
||||
import core.RNG;
|
||||
import core.algo.Method;
|
||||
import core.controller.RLController;
|
||||
import evironment.jumpingDino.DinoAction;
|
||||
import evironment.jumpingDino.DinoWorld;
|
||||
|
||||
public class DinoSampling {
|
||||
public static void main(String[] args) {
|
||||
for (int i = 0; i < 10 ; i++) {
|
||||
RNG.setSeed(55);
|
||||
|
||||
RLController<DinoAction> rl = new RLController<>(
|
||||
new DinoWorld(false, false),
|
||||
Method.MC_CONTROL_FIRST_VISIT,
|
||||
DinoAction.values());
|
||||
|
||||
rl.setDelay(0);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.15f);
|
||||
rl.setLearningRate(1f);
|
||||
rl.setNrOfEpisodes(400);
|
||||
rl.start();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -3,6 +3,7 @@ package example;
|
|||
import core.RNG;
|
||||
import core.algo.Method;
|
||||
import core.controller.RLController;
|
||||
import core.controller.RLControllerGUI;
|
||||
import evironment.jumpingDino.DinoAction;
|
||||
import evironment.jumpingDino.DinoWorld;
|
||||
|
||||
|
@ -10,16 +11,16 @@ public class JumpingDino {
|
|||
public static void main(String[] args) {
|
||||
RNG.setSeed(55);
|
||||
|
||||
RLController<DinoAction> rl = new RLController<>(
|
||||
RLController<DinoAction> rl = new RLControllerGUI<>(
|
||||
new DinoWorld(false, false),
|
||||
Method.MC_CONTROL_FIRST_VISIT,
|
||||
DinoAction.values());
|
||||
|
||||
rl.setDelay(0);
|
||||
rl.setDelay(100);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.15f);
|
||||
rl.setLearningRate(1f);
|
||||
rl.setNrOfEpisodes(400);
|
||||
rl.setNrOfEpisodes(10000);
|
||||
rl.start();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue