diff --git a/src/main/java/core/DeterministicStateActionTable.java b/src/main/java/core/DeterministicStateActionTable.java index 6c9d55d..e04c520 100644 --- a/src/main/java/core/DeterministicStateActionTable.java +++ b/src/main/java/core/DeterministicStateActionTable.java @@ -62,7 +62,7 @@ public class DeterministicStateActionTable implements StateActio private Map createDefaultActionValues(){ final Map defaultActionValues = new LinkedHashMap<>(); - for(A action: discreteActionSpace.getAllActions()){ + for(A action: discreteActionSpace){ defaultActionValues.put(action, DEFAULT_VALUE); } return defaultActionValues; diff --git a/src/main/java/core/DiscreteActionSpace.java b/src/main/java/core/DiscreteActionSpace.java index a5caf2b..91683f9 100644 --- a/src/main/java/core/DiscreteActionSpace.java +++ b/src/main/java/core/DiscreteActionSpace.java @@ -1,10 +1,7 @@ package core; -import java.util.List; - -public interface DiscreteActionSpace { +public interface DiscreteActionSpace extends Iterable { int getNumberOfActions(); void addAction(A a); void addActions(A... as); - List getAllActions(); } diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java index 3e1b862..b456751 100644 --- a/src/main/java/core/ListDiscreteActionSpace.java +++ b/src/main/java/core/ListDiscreteActionSpace.java @@ -1,11 +1,10 @@ package core; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; -public class ListDiscreteActionSpace implements DiscreteActionSpace, Serializable { +public class ListDiscreteActionSpace implements DiscreteActionSpace, Serializable{ + private static final long serialVersionUID = 1L; private List actions; public ListDiscreteActionSpace(){ @@ -27,13 +26,13 @@ public class ListDiscreteActionSpace implements DiscreteActionSp actions.addAll(Arrays.asList(as)); } - @Override - public List getAllActions() { - return actions; - } - @Override public int getNumberOfActions(){ return actions.size(); } + + @Override + public Iterator iterator() { + return actions.iterator(); + } } diff --git a/src/main/java/core/SaveState.java b/src/main/java/core/SaveState.java index 140ac7a..8f0b976 100644 --- a/src/main/java/core/SaveState.java +++ b/src/main/java/core/SaveState.java @@ -8,6 +8,7 @@ import java.io.Serializable; @AllArgsConstructor @Getter public class SaveState implements Serializable { + private static final long serialVersionUID = 1L; private StateActionTable stateActionTable; private int currentEpisode; } diff --git a/src/main/java/core/algo/EpisodicLearning.java b/src/main/java/core/algo/EpisodicLearning.java index 9211224..ceb3756 100644 --- a/src/main/java/core/algo/EpisodicLearning.java +++ b/src/main/java/core/algo/EpisodicLearning.java @@ -3,12 +3,10 @@ package core.algo; import core.DiscreteActionSpace; import core.Environment; import core.listener.LearningListener; -import lombok.Getter; import lombok.Setter; public abstract class EpisodicLearning extends Learning implements Episodic { @Setter - @Getter protected int currentEpisode; protected int episodesToLearn; protected volatile int episodePerSecond; diff --git a/src/main/java/core/algo/Learning.java b/src/main/java/core/algo/Learning.java index a9c73c7..9583696 100644 --- a/src/main/java/core/algo/Learning.java +++ b/src/main/java/core/algo/Learning.java @@ -9,14 +9,13 @@ import core.policy.Policy; import lombok.Getter; import lombok.Setter; -import java.io.Serializable; import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.concurrent.CopyOnWriteArrayList; @Getter -public abstract class Learning implements Serializable { +public abstract class Learning{ protected Policy policy; protected DiscreteActionSpace actionSpace; @Setter diff --git a/src/main/java/core/controller/RLController.java b/src/main/java/core/controller/RLController.java index b733807..4cafe4e 100644 --- a/src/main/java/core/controller/RLController.java +++ b/src/main/java/core/controller/RLController.java @@ -18,31 +18,29 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; public class RLController implements ViewListener, LearningListener { - protected Environment environment; - protected Learning learning; - protected DiscreteActionSpace discreteActionSpace; - protected LearningView learningView; - private int nrOfEpisodes; + private Environment environment; + private DiscreteActionSpace discreteActionSpace; private Method method; - private int prevDelay; private int delay = LearningConfig.DEFAULT_DELAY; private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR; private float epsilon = LearningConfig.DEFAULT_EPSILON; - private boolean fastLearning; - private boolean currentlyLearning; + private Learning learning; + private LearningView learningView; private ExecutorService learningExecutor; + private boolean currentlyLearning; + private boolean fastLearning; private List latestRewardsHistory; + private int nrOfEpisodes; + private int prevDelay; - public RLController(){ + public RLController(Environment env, Method method, A... actions){ learningExecutor = Executors.newSingleThreadExecutor(); + setEnvironment(env); + setMethod(method); + setAllowedActions(actions); } - public void start(){ - if(environment == null || discreteActionSpace == null || method == null){ - throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()"); - } - switch (method){ case MC_ONPOLICY_EGREEDY: learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); @@ -50,13 +48,21 @@ public class RLController implements ViewListener, LearningListe case TD_ONPOLICY: break; default: - throw new RuntimeException("Undefined method"); + throw new IllegalArgumentException("Undefined method"); } + + initGUI(); + initLearning(); + } + + private void initGUI(){ SwingUtilities.invokeLater(()->{ learningView = new View<>(learning, environment, this); learning.addListener(this); }); + } + private void initLearning(){ if(learning instanceof EpisodicLearning){ learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); }else{ @@ -65,7 +71,7 @@ public class RLController implements ViewListener, LearningListe } /************************************************* - * VIEW LISTENERS * + ** VIEW LISTENERS ** *************************************************/ @Override public void onLearnMoreEpisodes(int nrOfEpisodes){ @@ -148,7 +154,7 @@ public class RLController implements ViewListener, LearningListe } /************************************************* - * LEARNING LISTENERS * + ** LEARNING LISTENERS ** *************************************************/ @Override public void onLearningStart() { @@ -185,37 +191,43 @@ public class RLController implements ViewListener, LearningListe } + /************************************************* + ** SETTER ** + *************************************************/ - public RLController setMethod(Method method){ - this.method = method; - return this; - } - public RLController setEnvironment(Environment environment){ + private void setEnvironment(Environment environment){ + if(environment == null){ + throw new IllegalArgumentException("Environment cannot be null"); + } this.environment = environment; - return this; } - @SafeVarargs - public final RLController setAllowedActions(A... actions){ + + private void setMethod(Method method){ + if(method == null){ + throw new IllegalArgumentException("Method cannot be null"); + } + this.method = method; + } + + private void setAllowedActions(A[] actions){ + if(actions == null || actions.length == 0){ + throw new IllegalArgumentException("There has to be at least one action"); + } this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); - return this; } - public RLController setDelay(int delay){ + public void setDelay(int delay){ this.delay = delay; - return this; } - public RLController setEpisodes(int nrOfEpisodes){ + public void setEpisodes(int nrOfEpisodes){ this.nrOfEpisodes = nrOfEpisodes; - return this; } - public RLController setDiscountFactor(float discountFactor){ + public void setDiscountFactor(float discountFactor){ this.discountFactor = discountFactor; - return this; } - public RLController setEpsilon(float epsilon){ + public void setEpsilon(float epsilon){ this.epsilon = epsilon; - return this; } } diff --git a/src/main/java/example/JumpingDino.java b/src/main/java/example/JumpingDino.java index 87b1ceb..bddc8b2 100644 --- a/src/main/java/example/JumpingDino.java +++ b/src/main/java/example/JumpingDino.java @@ -10,14 +10,16 @@ public class JumpingDino { public static void main(String[] args) { RNG.setSeed(55); - RLController rl = new RLController() - .setEnvironment(new DinoWorld()) - .setAllowedActions(DinoAction.values()) - .setMethod(Method.MC_ONPOLICY_EGREEDY) - .setDiscountFactor(1f) - .setEpsilon(0.15f) - .setDelay(200) - .setEpisodes(100000); + RLController rl = new RLController<>( + new DinoWorld(), + Method.MC_ONPOLICY_EGREEDY, + DinoAction.values()); + + rl.setDelay(200); + rl.setDiscountFactor(1f); + rl.setEpsilon(0.15f); + rl.setEpisodes(5000); + rl.start(); } } diff --git a/src/main/java/example/RunningAnt.java b/src/main/java/example/RunningAnt.java index 307f107..bb2fa2e 100644 --- a/src/main/java/example/RunningAnt.java +++ b/src/main/java/example/RunningAnt.java @@ -10,12 +10,16 @@ public class RunningAnt { public static void main(String[] args) { RNG.setSeed(123); - RLController rl = new RLController() - .setEnvironment(new AntWorld(3,3,0.1)) - .setAllowedActions(AntAction.values()) - .setMethod(Method.MC_ONPOLICY_EGREEDY) - .setDelay(200) - .setEpisodes(100000); + RLController rl = new RLController<>( + new AntWorld(3, 3, 0.1), + Method.MC_ONPOLICY_EGREEDY, + AntAction.values()); + + rl.setDelay(200); + rl.setEpisodes(10000); + rl.setDiscountFactor(1f); + rl.setEpsilon(0.15f); + rl.start(); } }