Apply threading changes to master branch and clean up for the tagged version
- contains no testing or epsilon-testing code
This commit is contained in:
parent
6613e23c7c
commit
cffec63dc6
|
@ -50,7 +50,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
||||||
|
|
||||||
private void initBenchMarking(){
|
private void initBenchMarking(){
|
||||||
new Thread(()->{
|
new Thread(()->{
|
||||||
while (true){
|
while (currentlyLearning){
|
||||||
episodePerSecond = episodeSumCurrentSecond;
|
episodePerSecond = episodeSumCurrentSecond;
|
||||||
episodeSumCurrentSecond = 0;
|
episodeSumCurrentSecond = 0;
|
||||||
try {
|
try {
|
||||||
|
@ -62,7 +62,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
||||||
}).start();
|
}).start();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void dispatchEpisodeEnd(){
|
private void dispatchEpisodeEnd(){
|
||||||
++episodeSumCurrentSecond;
|
++episodeSumCurrentSecond;
|
||||||
if(rewardHistory.size() > 10000){
|
if(rewardHistory.size() > 10000){
|
||||||
rewardHistory.clear();
|
rewardHistory.clear();
|
||||||
|
@ -75,18 +75,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
||||||
|
|
||||||
protected void dispatchEpisodeStart(){
|
protected void dispatchEpisodeStart(){
|
||||||
++currentEpisode;
|
++currentEpisode;
|
||||||
/*
|
|
||||||
2f 0.02 => 100
|
|
||||||
1.5f 0.02 => 75
|
|
||||||
1.4f 0.02 => fail
|
|
||||||
1.5f 0.1 => 16 !
|
|
||||||
*/
|
|
||||||
if(this.policy instanceof EpsilonGreedyPolicy){
|
|
||||||
float ep = 1.5f/(float)currentEpisode;
|
|
||||||
if(ep < 0.10) ep = 0;
|
|
||||||
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
|
|
||||||
System.out.println(ep);
|
|
||||||
}
|
|
||||||
episodesToLearn.decrementAndGet();
|
episodesToLearn.decrementAndGet();
|
||||||
for(LearningListener l: learningListeners){
|
for(LearningListener l: learningListeners){
|
||||||
l.onEpisodeStart();
|
l.onEpisodeStart();
|
||||||
|
@ -97,31 +85,24 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
||||||
protected void dispatchStepEnd() {
|
protected void dispatchStepEnd() {
|
||||||
super.dispatchStepEnd();
|
super.dispatchStepEnd();
|
||||||
timestamp++;
|
timestamp++;
|
||||||
// TODO: more sophisticated way to check convergence
|
|
||||||
if(timestamp > 300000){
|
|
||||||
System.out.println("converged after: " + currentEpisode + " episode!");
|
|
||||||
interruptLearning();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void learn(){
|
|
||||||
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void startLearning(){
|
private void startLearning(){
|
||||||
learningExecutor.submit(()->{
|
dispatchLearningStart();
|
||||||
dispatchLearningStart();
|
System.out.println(episodesToLearn.get());
|
||||||
while(episodesToLearn.get() > 0){
|
while(episodesToLearn.get() > 0){
|
||||||
dispatchEpisodeStart();
|
dispatchEpisodeStart();
|
||||||
nextEpisode();
|
nextEpisode();
|
||||||
dispatchEpisodeEnd();
|
dispatchEpisodeEnd();
|
||||||
}
|
}
|
||||||
synchronized (this){
|
synchronized (this){
|
||||||
dispatchLearningEnd();
|
dispatchLearningEnd();
|
||||||
notifyAll();
|
notifyAll();
|
||||||
}
|
}
|
||||||
});
|
}
|
||||||
|
|
||||||
|
public void learnMoreEpisodes(int nrOfEpisodes){
|
||||||
|
episodesToLearn.addAndGet(nrOfEpisodes);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -146,8 +127,14 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
|
||||||
delay = prevDelay;
|
delay = prevDelay;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void learn(){
|
||||||
|
learn(LearningConfig.DEFAULT_NR_OF_EPISODES);
|
||||||
|
}
|
||||||
|
|
||||||
public synchronized void learn(int nrOfEpisodes){
|
public synchronized void learn(int nrOfEpisodes){
|
||||||
boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
|
boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0;
|
||||||
|
System.out.println(isLearning);
|
||||||
if(!isLearning)
|
if(!isLearning)
|
||||||
startLearning();
|
startLearning();
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,8 +42,7 @@ public abstract class Learning<A extends Enum>{
|
||||||
@Setter
|
@Setter
|
||||||
protected int delay;
|
protected int delay;
|
||||||
protected List<Double> rewardHistory;
|
protected List<Double> rewardHistory;
|
||||||
protected ExecutorService learningExecutor;
|
protected volatile boolean currentlyLearning;
|
||||||
protected boolean currentlyLearning;
|
|
||||||
|
|
||||||
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
|
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
|
@ -53,7 +52,6 @@ public abstract class Learning<A extends Enum>{
|
||||||
currentlyLearning = false;
|
currentlyLearning = false;
|
||||||
learningListeners = new HashSet<>();
|
learningListeners = new HashSet<>();
|
||||||
rewardHistory = new CopyOnWriteArrayList<>();
|
rewardHistory = new CopyOnWriteArrayList<>();
|
||||||
learningExecutor = Executors.newSingleThreadExecutor();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
|
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor) {
|
||||||
|
@ -89,8 +87,6 @@ public abstract class Learning<A extends Enum>{
|
||||||
|
|
||||||
protected void dispatchLearningEnd() {
|
protected void dispatchLearningEnd() {
|
||||||
currentlyLearning = false;
|
currentlyLearning = false;
|
||||||
System.out.println("Checksum: " + checkSum);
|
|
||||||
System.out.println("Reward Checksum: " + rewardCheckSum);
|
|
||||||
for (LearningListener l : learningListeners) {
|
for (LearningListener l : learningListeners) {
|
||||||
l.onLearningEnd();
|
l.onLearningEnd();
|
||||||
}
|
}
|
||||||
|
|
|
@ -83,7 +83,7 @@ public class RLController<A extends Enum> implements LearningListener {
|
||||||
private void initLearning() {
|
private void initLearning() {
|
||||||
if(learning instanceof EpisodicLearning) {
|
if(learning instanceof EpisodicLearning) {
|
||||||
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
|
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
|
||||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
|
||||||
} else {
|
} else {
|
||||||
learning.learn();
|
learning.learn();
|
||||||
}
|
}
|
||||||
|
@ -95,7 +95,13 @@ public class RLController<A extends Enum> implements LearningListener {
|
||||||
|
|
||||||
protected void learnMoreEpisodes(int nrOfEpisodes) {
|
protected void learnMoreEpisodes(int nrOfEpisodes) {
|
||||||
if(learning instanceof EpisodicLearning) {
|
if(learning instanceof EpisodicLearning) {
|
||||||
((EpisodicLearning) learning).learn(nrOfEpisodes);
|
if(learning.isCurrentlyLearning()){
|
||||||
|
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
|
||||||
|
}else{
|
||||||
|
new Thread(() -> {
|
||||||
|
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
|
||||||
|
}).start();
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
|
||||||
}
|
}
|
||||||
|
@ -169,8 +175,8 @@ public class RLController<A extends Enum> implements LearningListener {
|
||||||
public void onEpisodeEnd(List<Double> rewardHistory) {
|
public void onEpisodeEnd(List<Double> rewardHistory) {
|
||||||
latestRewardsHistory = rewardHistory;
|
latestRewardsHistory = rewardHistory;
|
||||||
if(printNextEpisode) {
|
if(printNextEpisode) {
|
||||||
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
System.out.println("Episode " + ((EpisodicLearning<A>) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
|
||||||
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
|
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
|
||||||
printNextEpisode = false;
|
printNextEpisode = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel {
|
||||||
viewListener.onFastLearnChange(fastLearning);
|
viewListener.onFastLearnChange(fastLearning);
|
||||||
});
|
});
|
||||||
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
|
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
|
||||||
smoothGraphCheckbox.setSelected(false);
|
smoothGraphCheckbox.setSelected(true);
|
||||||
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
|
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
|
||||||
last100Checkbox.setSelected(true);
|
last100Checkbox.setSelected(false);
|
||||||
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
|
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
|
||||||
drawEnvironmentCheckbox.setSelected(true);
|
drawEnvironmentCheckbox.setSelected(true);
|
||||||
|
|
||||||
|
|
|
@ -1,27 +0,0 @@
|
||||||
package example;
|
|
||||||
|
|
||||||
import core.RNG;
|
|
||||||
import core.algo.Method;
|
|
||||||
import core.controller.RLController;
|
|
||||||
import evironment.jumpingDino.DinoAction;
|
|
||||||
import evironment.jumpingDino.DinoWorld;
|
|
||||||
|
|
||||||
public class DinoSampling {
|
|
||||||
public static void main(String[] args) {
|
|
||||||
for (int i = 0; i < 10 ; i++) {
|
|
||||||
RNG.setSeed(55);
|
|
||||||
|
|
||||||
RLController<DinoAction> rl = new RLController<>(
|
|
||||||
new DinoWorld(false, false),
|
|
||||||
Method.MC_CONTROL_FIRST_VISIT,
|
|
||||||
DinoAction.values());
|
|
||||||
|
|
||||||
rl.setDelay(0);
|
|
||||||
rl.setDiscountFactor(1f);
|
|
||||||
rl.setEpsilon(0.15f);
|
|
||||||
rl.setLearningRate(1f);
|
|
||||||
rl.setNrOfEpisodes(400);
|
|
||||||
rl.start();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,6 +3,7 @@ package example;
|
||||||
import core.RNG;
|
import core.RNG;
|
||||||
import core.algo.Method;
|
import core.algo.Method;
|
||||||
import core.controller.RLController;
|
import core.controller.RLController;
|
||||||
|
import core.controller.RLControllerGUI;
|
||||||
import evironment.jumpingDino.DinoAction;
|
import evironment.jumpingDino.DinoAction;
|
||||||
import evironment.jumpingDino.DinoWorld;
|
import evironment.jumpingDino.DinoWorld;
|
||||||
|
|
||||||
|
@ -10,16 +11,16 @@ public class JumpingDino {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
RNG.setSeed(55);
|
RNG.setSeed(55);
|
||||||
|
|
||||||
RLController<DinoAction> rl = new RLController<>(
|
RLController<DinoAction> rl = new RLControllerGUI<>(
|
||||||
new DinoWorld(false, false),
|
new DinoWorld(false, false),
|
||||||
Method.MC_CONTROL_FIRST_VISIT,
|
Method.MC_CONTROL_FIRST_VISIT,
|
||||||
DinoAction.values());
|
DinoAction.values());
|
||||||
|
|
||||||
rl.setDelay(0);
|
rl.setDelay(100);
|
||||||
rl.setDiscountFactor(1f);
|
rl.setDiscountFactor(1f);
|
||||||
rl.setEpsilon(0.15f);
|
rl.setEpsilon(0.15f);
|
||||||
rl.setLearningRate(1f);
|
rl.setLearningRate(1f);
|
||||||
rl.setNrOfEpisodes(400);
|
rl.setNrOfEpisodes(10000);
|
||||||
rl.start();
|
rl.start();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue