diff --git a/src/main/java/core/DeterministicStateActionTable.java b/src/main/java/core/DeterministicStateActionTable.java
index 6c9d55d..e04c520 100644
--- a/src/main/java/core/DeterministicStateActionTable.java
+++ b/src/main/java/core/DeterministicStateActionTable.java
@@ -62,7 +62,7 @@ public class DeterministicStateActionTable implements StateActio
private Map createDefaultActionValues(){
final Map defaultActionValues = new LinkedHashMap<>();
- for(A action: discreteActionSpace.getAllActions()){
+ for(A action: discreteActionSpace){
defaultActionValues.put(action, DEFAULT_VALUE);
}
return defaultActionValues;
diff --git a/src/main/java/core/DiscreteActionSpace.java b/src/main/java/core/DiscreteActionSpace.java
index a5caf2b..91683f9 100644
--- a/src/main/java/core/DiscreteActionSpace.java
+++ b/src/main/java/core/DiscreteActionSpace.java
@@ -1,10 +1,7 @@
package core;
-import java.util.List;
-
-public interface DiscreteActionSpace {
+public interface DiscreteActionSpace extends Iterable {
int getNumberOfActions();
void addAction(A a);
void addActions(A... as);
- List getAllActions();
}
diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java
index 3e1b862..b456751 100644
--- a/src/main/java/core/ListDiscreteActionSpace.java
+++ b/src/main/java/core/ListDiscreteActionSpace.java
@@ -1,11 +1,10 @@
package core;
import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
+import java.util.*;
-public class ListDiscreteActionSpace implements DiscreteActionSpace, Serializable {
+public class ListDiscreteActionSpace implements DiscreteActionSpace, Serializable{
+ private static final long serialVersionUID = 1L;
private List actions;
public ListDiscreteActionSpace(){
@@ -27,13 +26,13 @@ public class ListDiscreteActionSpace implements DiscreteActionSp
actions.addAll(Arrays.asList(as));
}
- @Override
- public List getAllActions() {
- return actions;
- }
-
@Override
public int getNumberOfActions(){
return actions.size();
}
+
+ @Override
+ public Iterator iterator() {
+ return actions.iterator();
+ }
}
diff --git a/src/main/java/core/SaveState.java b/src/main/java/core/SaveState.java
index 140ac7a..8f0b976 100644
--- a/src/main/java/core/SaveState.java
+++ b/src/main/java/core/SaveState.java
@@ -8,6 +8,7 @@ import java.io.Serializable;
@AllArgsConstructor
@Getter
public class SaveState implements Serializable {
+ private static final long serialVersionUID = 1L;
private StateActionTable stateActionTable;
private int currentEpisode;
}
diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java
index 9211224..ceb3756 100644
--- a/src/main/java/core/algo/EpisodicLearning.java
+++ b/src/main/java/core/algo/EpisodicLearning.java
@@ -3,12 +3,10 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.listener.LearningListener;
-import lombok.Getter;
import lombok.Setter;
public abstract class EpisodicLearning extends Learning implements Episodic {
@Setter
- @Getter
protected int currentEpisode;
protected int episodesToLearn;
protected volatile int episodePerSecond;
diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java
index a9c73c7..9583696 100644
--- a/src/main/java/core/algo/Learning.java
+++ b/src/main/java/core/algo/Learning.java
@@ -9,14 +9,13 @@ import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
-import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
@Getter
-public abstract class Learning implements Serializable {
+public abstract class Learning{
protected Policy policy;
protected DiscreteActionSpace actionSpace;
@Setter
diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java
index b733807..4cafe4e 100644
--- a/src/main/java/core/controller/RLController.java
+++ b/src/main/java/core/controller/RLController.java
@@ -18,31 +18,29 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController implements ViewListener, LearningListener {
- protected Environment environment;
- protected Learning learning;
- protected DiscreteActionSpace discreteActionSpace;
- protected LearningView learningView;
- private int nrOfEpisodes;
+ private Environment environment;
+ private DiscreteActionSpace discreteActionSpace;
private Method method;
- private int prevDelay;
private int delay = LearningConfig.DEFAULT_DELAY;
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
private float epsilon = LearningConfig.DEFAULT_EPSILON;
- private boolean fastLearning;
- private boolean currentlyLearning;
+ private Learning learning;
+ private LearningView learningView;
private ExecutorService learningExecutor;
+ private boolean currentlyLearning;
+ private boolean fastLearning;
private List latestRewardsHistory;
+ private int nrOfEpisodes;
+ private int prevDelay;
- public RLController(){
+ public RLController(Environment env, Method method, A... actions){
learningExecutor = Executors.newSingleThreadExecutor();
+ setEnvironment(env);
+ setMethod(method);
+ setAllowedActions(actions);
}
-
public void start(){
- if(environment == null || discreteActionSpace == null || method == null){
- throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
- }
-
switch (method){
case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
@@ -50,13 +48,21 @@ public class RLController implements ViewListener, LearningListe
case TD_ONPOLICY:
break;
default:
- throw new RuntimeException("Undefined method");
+ throw new IllegalArgumentException("Undefined method");
}
+
+ initGUI();
+ initLearning();
+ }
+
+ private void initGUI(){
SwingUtilities.invokeLater(()->{
learningView = new View<>(learning, environment, this);
learning.addListener(this);
});
+ }
+ private void initLearning(){
if(learning instanceof EpisodicLearning){
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
}else{
@@ -65,7 +71,7 @@ public class RLController implements ViewListener, LearningListe
}
/*************************************************
- * VIEW LISTENERS *
+ ** VIEW LISTENERS **
*************************************************/
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes){
@@ -148,7 +154,7 @@ public class RLController implements ViewListener, LearningListe
}
/*************************************************
- * LEARNING LISTENERS *
+ ** LEARNING LISTENERS **
*************************************************/
@Override
public void onLearningStart() {
@@ -185,37 +191,43 @@ public class RLController implements ViewListener, LearningListe
}
+ /*************************************************
+ ** SETTERS **
+ *************************************************/
- public RLController setMethod(Method method){
- this.method = method;
- return this;
- }
- public RLController setEnvironment(Environment environment){
+ private void setEnvironment(Environment environment){
+ if(environment == null){
+ throw new IllegalArgumentException("Environment cannot be null");
+ }
this.environment = environment;
- return this;
}
- @SafeVarargs
- public final RLController setAllowedActions(A... actions){
+
+ private void setMethod(Method method){
+ if(method == null){
+ throw new IllegalArgumentException("Method cannot be null");
+ }
+ this.method = method;
+ }
+
+ private void setAllowedActions(A[] actions){
+ if(actions == null || actions.length == 0){
+ throw new IllegalArgumentException("There has to be at least one action");
+ }
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
- return this;
}
- public RLController setDelay(int delay){
+ public void setDelay(int delay){
this.delay = delay;
- return this;
}
- public RLController setEpisodes(int nrOfEpisodes){
+ public void setEpisodes(int nrOfEpisodes){
this.nrOfEpisodes = nrOfEpisodes;
- return this;
}
- public RLController setDiscountFactor(float discountFactor){
+ public void setDiscountFactor(float discountFactor){
this.discountFactor = discountFactor;
- return this;
}
- public RLController setEpsilon(float epsilon){
+ public void setEpsilon(float epsilon){
this.epsilon = epsilon;
- return this;
}
}
diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java
index 87b1ceb..bddc8b2 100644
--- a/src/main/java/example/JumpingDino.java
+++ b/src/main/java/example/JumpingDino.java
@@ -10,14 +10,16 @@ public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
- RLController rl = new RLController()
- .setEnvironment(new DinoWorld())
- .setAllowedActions(DinoAction.values())
- .setMethod(Method.MC_ONPOLICY_EGREEDY)
- .setDiscountFactor(1f)
- .setEpsilon(0.15f)
- .setDelay(200)
- .setEpisodes(100000);
+ RLController rl = new RLController<>(
+ new DinoWorld(),
+ Method.MC_ONPOLICY_EGREEDY,
+ DinoAction.values());
+
+ rl.setDelay(200);
+ rl.setDiscountFactor(1f);
+ rl.setEpsilon(0.15f);
+ rl.setEpisodes(5000);
+
rl.start();
}
}
diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java
index 307f107..bb2fa2e 100644
--- a/src/main/java/example/RunningAnt.java
+++ b/src/main/java/example/RunningAnt.java
@@ -10,12 +10,16 @@ public class RunningAnt {
public static void main(String[] args) {
RNG.setSeed(123);
- RLController rl = new RLController()
- .setEnvironment(new AntWorld(3,3,0.1))
- .setAllowedActions(AntAction.values())
- .setMethod(Method.MC_ONPOLICY_EGREEDY)
- .setDelay(200)
- .setEpisodes(100000);
+ RLController rl = new RLController<>(
+ new AntWorld(3, 3, 0.1),
+ Method.MC_ONPOLICY_EGREEDY,
+ AntAction.values());
+
+ rl.setDelay(200);
+ rl.setEpisodes(10000);
+ rl.setDiscountFactor(1f);
+ rl.setEpsilon(0.15f);
+
rl.start();
}
}