change RL-Controller initialization process and action space iterable

- no fake builder pattern anymore, moved needed fields into constructor
- add serializeUID
- action space extends iterable interface to simplify looping over all actions (and not returning the actual list)
This commit is contained in:
Jan Löwenstrom 2019-12-24 19:38:35 +01:00
parent 5a4e380faf
commit b2c3854b3a
9 changed files with 78 additions and 66 deletions

View File

@ -62,7 +62,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
private Map<A, Double> createDefaultActionValues(){
final Map<A, Double> defaultActionValues = new LinkedHashMap<>();
for(A action: discreteActionSpace.getAllActions()){
for(A action: discreteActionSpace){
defaultActionValues.put(action, DEFAULT_VALUE);
}
return defaultActionValues;

View File

@ -1,10 +1,7 @@
package core;
import java.util.List;
public interface DiscreteActionSpace<A extends Enum> {
public interface DiscreteActionSpace<A extends Enum> extends Iterable<A> {
int getNumberOfActions();
void addAction(A a);
void addActions(A... as);
List<A> getAllActions();
}

View File

@ -1,11 +1,10 @@
package core;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable {
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable{
private static final long serialVersionUID = 1L;
private List<A> actions;
public ListDiscreteActionSpace(){
@ -27,13 +26,13 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
actions.addAll(Arrays.asList(as));
}
@Override
public List<A> getAllActions() {
return actions;
}
@Override
public int getNumberOfActions(){
return actions.size();
}
@Override
public Iterator<A> iterator() {
return actions.iterator();
}
}

View File

@ -8,6 +8,7 @@ import java.io.Serializable;
@AllArgsConstructor
@Getter
public class SaveState<A extends Enum> implements Serializable {
private static final long serialVersionUID = 1L;
private StateActionTable<A> stateActionTable;
private int currentEpisode;
}

View File

@ -3,12 +3,10 @@ package core.algo;
import core.DiscreteActionSpace;
import core.Environment;
import core.listener.LearningListener;
import lombok.Getter;
import lombok.Setter;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
@Setter
@Getter
protected int currentEpisode;
protected int episodesToLearn;
protected volatile int episodePerSecond;

View File

@ -9,14 +9,13 @@ import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
@Getter
public abstract class Learning<A extends Enum> implements Serializable {
public abstract class Learning<A extends Enum>{
protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace;
@Setter

View File

@ -18,31 +18,29 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
public class RLController<A extends Enum> implements ViewListener, LearningListener {
protected Environment<A> environment;
protected Learning<A> learning;
protected DiscreteActionSpace<A> discreteActionSpace;
protected LearningView learningView;
private int nrOfEpisodes;
private Environment<A> environment;
private DiscreteActionSpace<A> discreteActionSpace;
private Method method;
private int prevDelay;
private int delay = LearningConfig.DEFAULT_DELAY;
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
private float epsilon = LearningConfig.DEFAULT_EPSILON;
private boolean fastLearning;
private boolean currentlyLearning;
private Learning<A> learning;
private LearningView learningView;
private ExecutorService learningExecutor;
private boolean currentlyLearning;
private boolean fastLearning;
private List<Double> latestRewardsHistory;
private int nrOfEpisodes;
private int prevDelay;
public RLController(){
public RLController(Environment<A> env, Method method, A... actions){
learningExecutor = Executors.newSingleThreadExecutor();
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
}
public void start(){
if(environment == null || discreteActionSpace == null || method == null){
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
}
switch (method){
case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
@ -50,13 +48,21 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
case TD_ONPOLICY:
break;
default:
throw new RuntimeException("Undefined method");
throw new IllegalArgumentException("Undefined method");
}
initGUI();
initLearning();
}
private void initGUI(){
SwingUtilities.invokeLater(()->{
learningView = new View<>(learning, environment, this);
learning.addListener(this);
});
}
private void initLearning(){
if(learning instanceof EpisodicLearning){
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
}else{
@ -65,7 +71,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}
/*************************************************
* VIEW LISTENERS *
** VIEW LISTENERS **
*************************************************/
@Override
public void onLearnMoreEpisodes(int nrOfEpisodes){
@ -148,7 +154,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}
/*************************************************
* LEARNING LISTENERS *
** LEARNING LISTENERS **
*************************************************/
@Override
public void onLearningStart() {
@ -185,37 +191,43 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
}
/*************************************************
** SETTER **
*************************************************/
public RLController<A> setMethod(Method method){
this.method = method;
return this;
}
public RLController<A> setEnvironment(Environment<A> environment){
private void setEnvironment(Environment<A> environment){
if(environment == null){
throw new IllegalArgumentException("Environment cannot be null");
}
this.environment = environment;
return this;
}
@SafeVarargs
public final RLController<A> setAllowedActions(A... actions){
private void setMethod(Method method){
if(method == null){
throw new IllegalArgumentException("Method cannot be null");
}
this.method = method;
}
private void setAllowedActions(A[] actions){
if(actions == null || actions.length == 0){
throw new IllegalArgumentException("There has to be at least one action");
}
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
return this;
}
public RLController<A> setDelay(int delay){
public void setDelay(int delay){
this.delay = delay;
return this;
}
public RLController<A> setEpisodes(int nrOfEpisodes){
public void setEpisodes(int nrOfEpisodes){
this.nrOfEpisodes = nrOfEpisodes;
return this;
}
public RLController<A> setDiscountFactor(float discountFactor){
public void setDiscountFactor(float discountFactor){
this.discountFactor = discountFactor;
return this;
}
public RLController<A> setEpsilon(float epsilon){
public void setEpsilon(float epsilon){
this.epsilon = epsilon;
return this;
}
}

View File

@ -10,14 +10,16 @@ public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<DinoAction>()
.setEnvironment(new DinoWorld())
.setAllowedActions(DinoAction.values())
.setMethod(Method.MC_ONPOLICY_EGREEDY)
.setDiscountFactor(1f)
.setEpsilon(0.15f)
.setDelay(200)
.setEpisodes(100000);
RLController<DinoAction> rl = new RLController<>(
new DinoWorld(),
Method.MC_ONPOLICY_EGREEDY,
DinoAction.values());
rl.setDelay(200);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.setEpisodes(5000);
rl.start();
}
}

View File

@ -10,12 +10,16 @@ public class RunningAnt {
public static void main(String[] args) {
RNG.setSeed(123);
RLController<AntAction> rl = new RLController<AntAction>()
.setEnvironment(new AntWorld(3,3,0.1))
.setAllowedActions(AntAction.values())
.setMethod(Method.MC_ONPOLICY_EGREEDY)
.setDelay(200)
.setEpisodes(100000);
RLController<AntAction> rl = new RLController<>(
new AntWorld(3, 3, 0.1),
Method.MC_ONPOLICY_EGREEDY,
AntAction.values());
rl.setDelay(200);
rl.setEpisodes(10000);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.start();
}
}