Change RLController initialization process and make the action space iterable

- remove the fake builder pattern; the required fields are now passed to the constructor
- add serialVersionUID
- action space now extends the Iterable interface to simplify looping over all actions (instead of returning the underlying list)
This commit is contained in:
Jan Löwenstrom 2019-12-24 19:38:35 +01:00
parent 5a4e380faf
commit b2c3854b3a
9 changed files with 78 additions and 66 deletions

View File

@ -62,7 +62,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
private Map<A, Double> createDefaultActionValues(){ private Map<A, Double> createDefaultActionValues(){
final Map<A, Double> defaultActionValues = new LinkedHashMap<>(); final Map<A, Double> defaultActionValues = new LinkedHashMap<>();
for(A action: discreteActionSpace.getAllActions()){ for(A action: discreteActionSpace){
defaultActionValues.put(action, DEFAULT_VALUE); defaultActionValues.put(action, DEFAULT_VALUE);
} }
return defaultActionValues; return defaultActionValues;

View File

@ -1,10 +1,7 @@
package core; package core;
import java.util.List; public interface DiscreteActionSpace<A extends Enum> extends Iterable<A> {
public interface DiscreteActionSpace<A extends Enum> {
int getNumberOfActions(); int getNumberOfActions();
void addAction(A a); void addAction(A a);
void addActions(A... as); void addActions(A... as);
List<A> getAllActions();
} }

View File

@ -1,11 +1,10 @@
package core; package core;
import java.io.Serializable; import java.io.Serializable;
import java.util.ArrayList; import java.util.*;
import java.util.Arrays;
import java.util.List;
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable{ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable{
private static final long serialVersionUID = 1L;
private List<A> actions; private List<A> actions;
public ListDiscreteActionSpace(){ public ListDiscreteActionSpace(){
@ -27,13 +26,13 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
actions.addAll(Arrays.asList(as)); actions.addAll(Arrays.asList(as));
} }
@Override
public List<A> getAllActions() {
return actions;
}
@Override @Override
public int getNumberOfActions(){ public int getNumberOfActions(){
return actions.size(); return actions.size();
} }
@Override
public Iterator<A> iterator() {
return actions.iterator();
}
} }

View File

@ -8,6 +8,7 @@ import java.io.Serializable;
@AllArgsConstructor @AllArgsConstructor
@Getter @Getter
public class SaveState<A extends Enum> implements Serializable { public class SaveState<A extends Enum> implements Serializable {
private static final long serialVersionUID = 1L;
private StateActionTable<A> stateActionTable; private StateActionTable<A> stateActionTable;
private int currentEpisode; private int currentEpisode;
} }

View File

@ -3,12 +3,10 @@ package core.algo;
import core.DiscreteActionSpace; import core.DiscreteActionSpace;
import core.Environment; import core.Environment;
import core.listener.LearningListener; import core.listener.LearningListener;
import lombok.Getter;
import lombok.Setter; import lombok.Setter;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic { public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
@Setter @Setter
@Getter
protected int currentEpisode; protected int currentEpisode;
protected int episodesToLearn; protected int episodesToLearn;
protected volatile int episodePerSecond; protected volatile int episodePerSecond;

View File

@ -9,14 +9,13 @@ import core.policy.Policy;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
import java.io.Serializable;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
@Getter @Getter
public abstract class Learning<A extends Enum> implements Serializable { public abstract class Learning<A extends Enum>{
protected Policy<A> policy; protected Policy<A> policy;
protected DiscreteActionSpace<A> actionSpace; protected DiscreteActionSpace<A> actionSpace;
@Setter @Setter

View File

@ -18,31 +18,29 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
public class RLController<A extends Enum> implements ViewListener, LearningListener { public class RLController<A extends Enum> implements ViewListener, LearningListener {
protected Environment<A> environment; private Environment<A> environment;
protected Learning<A> learning; private DiscreteActionSpace<A> discreteActionSpace;
protected DiscreteActionSpace<A> discreteActionSpace;
protected LearningView learningView;
private int nrOfEpisodes;
private Method method; private Method method;
private int prevDelay;
private int delay = LearningConfig.DEFAULT_DELAY; private int delay = LearningConfig.DEFAULT_DELAY;
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR; private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
private float epsilon = LearningConfig.DEFAULT_EPSILON; private float epsilon = LearningConfig.DEFAULT_EPSILON;
private boolean fastLearning; private Learning<A> learning;
private boolean currentlyLearning; private LearningView learningView;
private ExecutorService learningExecutor; private ExecutorService learningExecutor;
private boolean currentlyLearning;
private boolean fastLearning;
private List<Double> latestRewardsHistory; private List<Double> latestRewardsHistory;
private int nrOfEpisodes;
private int prevDelay;
public RLController(){ public RLController(Environment<A> env, Method method, A... actions){
learningExecutor = Executors.newSingleThreadExecutor(); learningExecutor = Executors.newSingleThreadExecutor();
setEnvironment(env);
setMethod(method);
setAllowedActions(actions);
} }
public void start(){ public void start(){
if(environment == null || discreteActionSpace == null || method == null){
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
}
switch (method){ switch (method){
case MC_ONPOLICY_EGREEDY: case MC_ONPOLICY_EGREEDY:
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
@ -50,13 +48,21 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
case TD_ONPOLICY: case TD_ONPOLICY:
break; break;
default: default:
throw new RuntimeException("Undefined method"); throw new IllegalArgumentException("Undefined method");
} }
initGUI();
initLearning();
}
private void initGUI(){
SwingUtilities.invokeLater(()->{ SwingUtilities.invokeLater(()->{
learningView = new View<>(learning, environment, this); learningView = new View<>(learning, environment, this);
learning.addListener(this); learning.addListener(this);
}); });
}
private void initLearning(){
if(learning instanceof EpisodicLearning){ if(learning instanceof EpisodicLearning){
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
}else{ }else{
@ -65,7 +71,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
} }
/************************************************* /*************************************************
* VIEW LISTENERS * ** VIEW LISTENERS **
*************************************************/ *************************************************/
@Override @Override
public void onLearnMoreEpisodes(int nrOfEpisodes){ public void onLearnMoreEpisodes(int nrOfEpisodes){
@ -148,7 +154,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
} }
/************************************************* /*************************************************
* LEARNING LISTENERS * ** LEARNING LISTENERS **
*************************************************/ *************************************************/
@Override @Override
public void onLearningStart() { public void onLearningStart() {
@ -185,37 +191,43 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
} }
/*************************************************
** SETTER **
*************************************************/
public RLController<A> setMethod(Method method){ private void setEnvironment(Environment<A> environment){
this.method = method; if(environment == null){
return this; throw new IllegalArgumentException("Environment cannot be null");
} }
public RLController<A> setEnvironment(Environment<A> environment){
this.environment = environment; this.environment = environment;
return this;
} }
@SafeVarargs
public final RLController<A> setAllowedActions(A... actions){ private void setMethod(Method method){
if(method == null){
throw new IllegalArgumentException("Method cannot be null");
}
this.method = method;
}
private void setAllowedActions(A[] actions){
if(actions == null || actions.length == 0){
throw new IllegalArgumentException("There has to be at least one action");
}
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
return this;
} }
public RLController<A> setDelay(int delay){ public void setDelay(int delay){
this.delay = delay; this.delay = delay;
return this;
} }
public RLController<A> setEpisodes(int nrOfEpisodes){ public void setEpisodes(int nrOfEpisodes){
this.nrOfEpisodes = nrOfEpisodes; this.nrOfEpisodes = nrOfEpisodes;
return this;
} }
public RLController<A> setDiscountFactor(float discountFactor){ public void setDiscountFactor(float discountFactor){
this.discountFactor = discountFactor; this.discountFactor = discountFactor;
return this;
} }
public RLController<A> setEpsilon(float epsilon){ public void setEpsilon(float epsilon){
this.epsilon = epsilon; this.epsilon = epsilon;
return this;
} }
} }

View File

@ -10,14 +10,16 @@ public class JumpingDino {
public static void main(String[] args) { public static void main(String[] args) {
RNG.setSeed(55); RNG.setSeed(55);
RLController<DinoAction> rl = new RLController<DinoAction>() RLController<DinoAction> rl = new RLController<>(
.setEnvironment(new DinoWorld()) new DinoWorld(),
.setAllowedActions(DinoAction.values()) Method.MC_ONPOLICY_EGREEDY,
.setMethod(Method.MC_ONPOLICY_EGREEDY) DinoAction.values());
.setDiscountFactor(1f)
.setEpsilon(0.15f) rl.setDelay(200);
.setDelay(200) rl.setDiscountFactor(1f);
.setEpisodes(100000); rl.setEpsilon(0.15f);
rl.setEpisodes(5000);
rl.start(); rl.start();
} }
} }

View File

@ -10,12 +10,16 @@ public class RunningAnt {
public static void main(String[] args) { public static void main(String[] args) {
RNG.setSeed(123); RNG.setSeed(123);
RLController<AntAction> rl = new RLController<AntAction>() RLController<AntAction> rl = new RLController<>(
.setEnvironment(new AntWorld(3,3,0.1)) new AntWorld(3, 3, 0.1),
.setAllowedActions(AntAction.values()) Method.MC_ONPOLICY_EGREEDY,
.setMethod(Method.MC_ONPOLICY_EGREEDY) AntAction.values());
.setDelay(200)
.setEpisodes(100000); rl.setDelay(200);
rl.setEpisodes(10000);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.15f);
rl.start(); rl.start();
} }
} }