diff --git a/.idea/compiler.xml b/.idea/compiler.xml
index 95a88ae..a1757ae 100644
--- a/.idea/compiler.xml
+++ b/.idea/compiler.xml
@@ -5,4 +5,4 @@
-
+
\ No newline at end of file
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index 0fb589a..e2b9ec3 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -10,6 +10,7 @@ import core.gui.View;
import core.listener.LearningListener;
import core.listener.ViewListener;
import core.policy.EpsilonPolicy;
+import lombok.Setter;
import javax.swing.*;
import java.io.*;
@@ -18,24 +19,29 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController implements ViewListener, LearningListener {
- private final String folderPrefix = "learningStates" + File.separator;
- private Environment environment;
- private DiscreteActionSpace discreteActionSpace;
- private Method method;
- private int delay = LearningConfig.DEFAULT_DELAY;
- private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
- private float epsilon = LearningConfig.DEFAULT_EPSILON;
- private Learning learning;
- private LearningView learningView;
- private boolean fastLearning;
- private List latestRewardsHistory;
- private int nrOfEpisodes;
- private int prevDelay;
+ protected final String folderPrefix = "learningStates" + File.separator;
+ protected Environment environment;
+ protected DiscreteActionSpace discreteActionSpace;
+ protected Method method;
+ @Setter
+ protected int delay = LearningConfig.DEFAULT_DELAY;
+ @Setter
+ protected float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
+ @Setter
+ protected float epsilon = LearningConfig.DEFAULT_EPSILON;
+ protected Learning learning;
+ protected boolean fastLearning;
+ protected List latestRewardsHistory;
+ @Setter
+ protected int nrOfEpisodes;
+ protected int prevDelay;
+ protected volatile boolean printNextEpisode;
public RLController(Environment env, Method method, A... actions){
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
+ printNextEpisode = true;
}
public void start(){
@@ -48,20 +54,29 @@ public class RLController implements ViewListener, LearningListe
default:
throw new IllegalArgumentException("Undefined method");
}
-
- initGUI();
+ System.out.println("Initialized learning: " + learning.getClass());
+ initListeners();
+ System.out.println("Set listeners");
initLearning();
}
- private void initGUI(){
- SwingUtilities.invokeLater(()->{
- learningView = new View<>(learning, environment, this);
- learning.addListener(this);
- });
+ protected void initListeners(){
+ learning.addListener(this);
+ new Thread(() -> {
+ while (true){
+ printNextEpisode = true;
+ try {
+ Thread.sleep(30*1000);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ }).start();
}
private void initLearning(){
if(learning instanceof EpisodicLearning){
+ System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes);
}else{
learning.learn();
@@ -78,7 +93,6 @@ public class RLController implements ViewListener, LearningListe
}else{
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
}
- learningView.updateLearningInfoPanel();
}
@Override
@@ -88,10 +102,9 @@ public class RLController implements ViewListener, LearningListe
try {
fis = new FileInputStream(fileName);
in = new ObjectInputStream(fis);
- System.out.println("interrup" + Thread.currentThread().getId());
+ System.out.println("interrupt" + Thread.currentThread().getId());
learning.interruptLearning();
learning.load(in);
- SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
in.close();
} catch (IOException | ClassNotFoundException e) {
e.printStackTrace();
@@ -117,7 +130,6 @@ public class RLController implements ViewListener, LearningListe
public void onEpsilonChange(float epsilon) {
if(learning.getPolicy() instanceof EpsilonPolicy){
((EpsilonPolicy) learning.getPolicy()).setEpsilon(epsilon);
- SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}else{
System.out.println("Trying to call inEpsilonChange on non-epsilon policy");
}
@@ -128,9 +140,8 @@ public class RLController implements ViewListener, LearningListe
changeLearningDelay(delay);
}
- private void changeLearningDelay(int delay){
+ protected void changeLearningDelay(int delay){
learning.setDelay(delay);
- SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
}
@Override
@@ -153,19 +164,8 @@ public class RLController implements ViewListener, LearningListe
@Override
public void onLearningEnd() {
- SwingUtilities.invokeLater(()-> learningView.updateRewardGraph(latestRewardsHistory));
- onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e " + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
- }
-
- @Override
- public void onEpisodeEnd(List rewardHistory) {
- latestRewardsHistory = rewardHistory;
- SwingUtilities.invokeLater(() ->{
- if(!fastLearning){
- learningView.updateRewardGraph(latestRewardsHistory);
- }
- learningView.updateLearningInfoPanel();
- });
+ System.out.println("Learning finished");
+ onSaveState( method.toString() + System.currentTimeMillis()/1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
}
@Override
@@ -174,12 +174,19 @@ public class RLController implements ViewListener, LearningListe
}
@Override
- public void onStepEnd() {
- if(!fastLearning){
- SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
+ public void onEpisodeEnd(List rewardHistory) {
+ latestRewardsHistory = rewardHistory;
+ if(printNextEpisode){
+ System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size()-1));
+ System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
+ printNextEpisode = false;
}
}
+ @Override
+ public void onStepEnd() {
+ }
+
/*************************************************
** SETTERS **
@@ -205,19 +212,4 @@ public class RLController implements ViewListener, LearningListe
}
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
}
-
- public void setDelay(int delay){
- this.delay = delay;
- }
-
- public void setEpisodes(int nrOfEpisodes){
- this.nrOfEpisodes = nrOfEpisodes;
- }
-
- public void setDiscountFactor(float discountFactor){
- this.discountFactor = discountFactor;
- }
- public void setEpsilon(float epsilon){
- this.epsilon = epsilon;
- }
}
diff --git a/src/main/java/core/controller/RLControllerGUI.java b/src/main/java/core/controller/RLControllerGUI.java
new file mode 100644
index 0000000..567ad54
--- /dev/null
+++ b/src/main/java/core/controller/RLControllerGUI.java
@@ -0,0 +1,73 @@
+package core.controller;
+
+import core.Environment;
+import core.algo.Method;
+import core.gui.LearningView;
+import core.gui.View;
+
+import javax.swing.*;
+import java.util.List;
+
+public class RLControllerGUI extends RLController {
+ private LearningView learningView;
+
+ public RLControllerGUI(Environment env, Method method, A... actions) {
+ super(env, method, actions);
+ }
+
+ @Override
+ protected void initListeners() {
+ SwingUtilities.invokeLater(() -> {
+ learningView = new View<>(learning, environment, this);
+ learning.addListener(this);
+ });
+ }
+
+ @Override
+ public void onLearnMoreEpisodes(int nrOfEpisodes) {
+ super.onLearnMoreEpisodes(nrOfEpisodes);
+ learningView.updateLearningInfoPanel();
+ }
+
+ @Override
+ public void onLoadState(String fileName) {
+ super.onLoadState(fileName);
+ SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
+ }
+
+ @Override
+ public void onEpsilonChange(float epsilon) {
+ super.onEpsilonChange(epsilon);
+ SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
+ }
+
+ @Override
+ protected void changeLearningDelay(int delay) {
+ super.changeLearningDelay(delay);
+ SwingUtilities.invokeLater(() -> learningView.updateLearningInfoPanel());
+ }
+
+ @Override
+ public void onLearningEnd() {
+ super.onLearningEnd();
+ SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
+ }
+
+ @Override
+ public void onEpisodeEnd(List rewardHistory) {
+ super.onEpisodeEnd(rewardHistory);
+ SwingUtilities.invokeLater(() -> {
+ if (!fastLearning) {
+ learningView.updateRewardGraph(latestRewardsHistory);
+ }
+ learningView.updateLearningInfoPanel();
+ });
+ }
+
+ @Override
+ public void onStepEnd() {
+ if (!fastLearning) {
+ SwingUtilities.invokeLater(() -> learningView.repaintEnvironment());
+ }
+ }
+}
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index 2e0a447..7cdd445 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -3,6 +3,7 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
+import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
@@ -10,15 +11,15 @@ public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
- RLController rl = new RLController<>(
+ RLController rl = new RLControllerGUI<>(
new DinoWorld(true, true),
Method.MC_ONPOLICY_EGREEDY,
DinoAction.values());
- rl.setDelay(200);
+ rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
- rl.setEpisodes(5000);
+ rl.setNrOfEpisodes(100000);
rl.start();
}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index bb2fa2e..ade0e92 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -16,7 +16,7 @@ public class RunningAnt {
AntAction.values());
rl.setDelay(200);
- rl.setEpisodes(10000);
+ rl.setNrOfEpisodes(10000);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);