Change RLController initialization and make the action space iterable
- no fake builder pattern anymore; the required fields are now constructor parameters - add serialVersionUID - the action space extends the Iterable interface to simplify looping over all actions (instead of returning the actual list)
This commit is contained in:
parent
5a4e380faf
commit
b2c3854b3a
|
@ -62,7 +62,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
|||
|
||||
private Map<A, Double> createDefaultActionValues(){
|
||||
final Map<A, Double> defaultActionValues = new LinkedHashMap<>();
|
||||
for(A action: discreteActionSpace.getAllActions()){
|
||||
for(A action: discreteActionSpace){
|
||||
defaultActionValues.put(action, DEFAULT_VALUE);
|
||||
}
|
||||
return defaultActionValues;
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
package core;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface DiscreteActionSpace<A extends Enum> {
|
||||
public interface DiscreteActionSpace<A extends Enum> extends Iterable<A> {
|
||||
int getNumberOfActions();
|
||||
void addAction(A a);
|
||||
void addActions(A... as);
|
||||
List<A> getAllActions();
|
||||
}
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
package core;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable {
|
||||
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable{
|
||||
private static final long serialVersionUID = 1L;
|
||||
private List<A> actions;
|
||||
|
||||
public ListDiscreteActionSpace(){
|
||||
|
@ -27,13 +26,13 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
|
|||
actions.addAll(Arrays.asList(as));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<A> getAllActions() {
|
||||
return actions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumberOfActions(){
|
||||
return actions.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<A> iterator() {
|
||||
return actions.iterator();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import java.io.Serializable;
|
|||
@AllArgsConstructor
|
||||
@Getter
|
||||
public class SaveState<A extends Enum> implements Serializable {
|
||||
private static final long serialVersionUID = 1L;
|
||||
private StateActionTable<A> stateActionTable;
|
||||
private int currentEpisode;
|
||||
}
|
||||
|
|
|
@ -3,12 +3,10 @@ package core.algo;
|
|||
import core.DiscreteActionSpace;
|
||||
import core.Environment;
|
||||
import core.listener.LearningListener;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
|
||||
@Setter
|
||||
@Getter
|
||||
protected int currentEpisode;
|
||||
protected int episodesToLearn;
|
||||
protected volatile int episodePerSecond;
|
||||
|
|
|
@ -9,14 +9,13 @@ import core.policy.Policy;
|
|||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
|
||||
@Getter
|
||||
public abstract class Learning<A extends Enum> implements Serializable {
|
||||
public abstract class Learning<A extends Enum>{
|
||||
protected Policy<A> policy;
|
||||
protected DiscreteActionSpace<A> actionSpace;
|
||||
@Setter
|
||||
|
|
|
@ -18,31 +18,29 @@ import java.util.concurrent.ExecutorService;
|
|||
import java.util.concurrent.Executors;
|
||||
|
||||
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
||||
protected Environment<A> environment;
|
||||
protected Learning<A> learning;
|
||||
protected DiscreteActionSpace<A> discreteActionSpace;
|
||||
protected LearningView learningView;
|
||||
private int nrOfEpisodes;
|
||||
private Environment<A> environment;
|
||||
private DiscreteActionSpace<A> discreteActionSpace;
|
||||
private Method method;
|
||||
private int prevDelay;
|
||||
private int delay = LearningConfig.DEFAULT_DELAY;
|
||||
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
|
||||
private float epsilon = LearningConfig.DEFAULT_EPSILON;
|
||||
private boolean fastLearning;
|
||||
private boolean currentlyLearning;
|
||||
private Learning<A> learning;
|
||||
private LearningView learningView;
|
||||
private ExecutorService learningExecutor;
|
||||
private boolean currentlyLearning;
|
||||
private boolean fastLearning;
|
||||
private List<Double> latestRewardsHistory;
|
||||
private int nrOfEpisodes;
|
||||
private int prevDelay;
|
||||
|
||||
public RLController(){
|
||||
public RLController(Environment<A> env, Method method, A... actions){
|
||||
learningExecutor = Executors.newSingleThreadExecutor();
|
||||
setEnvironment(env);
|
||||
setMethod(method);
|
||||
setAllowedActions(actions);
|
||||
}
|
||||
|
||||
|
||||
public void start(){
|
||||
if(environment == null || discreteActionSpace == null || method == null){
|
||||
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
|
||||
}
|
||||
|
||||
switch (method){
|
||||
case MC_ONPOLICY_EGREEDY:
|
||||
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
|
||||
|
@ -50,13 +48,21 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
case TD_ONPOLICY:
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException("Undefined method");
|
||||
throw new IllegalArgumentException("Undefined method");
|
||||
}
|
||||
|
||||
initGUI();
|
||||
initLearning();
|
||||
}
|
||||
|
||||
private void initGUI(){
|
||||
SwingUtilities.invokeLater(()->{
|
||||
learningView = new View<>(learning, environment, this);
|
||||
learning.addListener(this);
|
||||
});
|
||||
}
|
||||
|
||||
private void initLearning(){
|
||||
if(learning instanceof EpisodicLearning){
|
||||
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
|
||||
}else{
|
||||
|
@ -65,7 +71,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}
|
||||
|
||||
/*************************************************
|
||||
* VIEW LISTENERS *
|
||||
** VIEW LISTENERS **
|
||||
*************************************************/
|
||||
@Override
|
||||
public void onLearnMoreEpisodes(int nrOfEpisodes){
|
||||
|
@ -148,7 +154,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}
|
||||
|
||||
/*************************************************
|
||||
* LEARNING LISTENERS *
|
||||
** LEARNING LISTENERS **
|
||||
*************************************************/
|
||||
@Override
|
||||
public void onLearningStart() {
|
||||
|
@ -185,37 +191,43 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
|||
}
|
||||
|
||||
|
||||
/*************************************************
|
||||
** SETTER **
|
||||
*************************************************/
|
||||
|
||||
public RLController<A> setMethod(Method method){
|
||||
this.method = method;
|
||||
return this;
|
||||
}
|
||||
public RLController<A> setEnvironment(Environment<A> environment){
|
||||
private void setEnvironment(Environment<A> environment){
|
||||
if(environment == null){
|
||||
throw new IllegalArgumentException("Environment cannot be null");
|
||||
}
|
||||
this.environment = environment;
|
||||
return this;
|
||||
}
|
||||
@SafeVarargs
|
||||
public final RLController<A> setAllowedActions(A... actions){
|
||||
|
||||
private void setMethod(Method method){
|
||||
if(method == null){
|
||||
throw new IllegalArgumentException("Method cannot be null");
|
||||
}
|
||||
this.method = method;
|
||||
}
|
||||
|
||||
private void setAllowedActions(A[] actions){
|
||||
if(actions == null || actions.length == 0){
|
||||
throw new IllegalArgumentException("There has to be at least one action");
|
||||
}
|
||||
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
||||
return this;
|
||||
}
|
||||
|
||||
public RLController<A> setDelay(int delay){
|
||||
public void setDelay(int delay){
|
||||
this.delay = delay;
|
||||
return this;
|
||||
}
|
||||
|
||||
public RLController<A> setEpisodes(int nrOfEpisodes){
|
||||
public void setEpisodes(int nrOfEpisodes){
|
||||
this.nrOfEpisodes = nrOfEpisodes;
|
||||
return this;
|
||||
}
|
||||
|
||||
public RLController<A> setDiscountFactor(float discountFactor){
|
||||
public void setDiscountFactor(float discountFactor){
|
||||
this.discountFactor = discountFactor;
|
||||
return this;
|
||||
}
|
||||
public RLController<A> setEpsilon(float epsilon){
|
||||
public void setEpsilon(float epsilon){
|
||||
this.epsilon = epsilon;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,14 +10,16 @@ public class JumpingDino {
|
|||
public static void main(String[] args) {
|
||||
RNG.setSeed(55);
|
||||
|
||||
RLController<DinoAction> rl = new RLController<DinoAction>()
|
||||
.setEnvironment(new DinoWorld())
|
||||
.setAllowedActions(DinoAction.values())
|
||||
.setMethod(Method.MC_ONPOLICY_EGREEDY)
|
||||
.setDiscountFactor(1f)
|
||||
.setEpsilon(0.15f)
|
||||
.setDelay(200)
|
||||
.setEpisodes(100000);
|
||||
RLController<DinoAction> rl = new RLController<>(
|
||||
new DinoWorld(),
|
||||
Method.MC_ONPOLICY_EGREEDY,
|
||||
DinoAction.values());
|
||||
|
||||
rl.setDelay(200);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.15f);
|
||||
rl.setEpisodes(5000);
|
||||
|
||||
rl.start();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,12 +10,16 @@ public class RunningAnt {
|
|||
public static void main(String[] args) {
|
||||
RNG.setSeed(123);
|
||||
|
||||
RLController<AntAction> rl = new RLController<AntAction>()
|
||||
.setEnvironment(new AntWorld(3,3,0.1))
|
||||
.setAllowedActions(AntAction.values())
|
||||
.setMethod(Method.MC_ONPOLICY_EGREEDY)
|
||||
.setDelay(200)
|
||||
.setEpisodes(100000);
|
||||
RLController<AntAction> rl = new RLController<>(
|
||||
new AntWorld(3, 3, 0.1),
|
||||
Method.MC_ONPOLICY_EGREEDY,
|
||||
AntAction.values());
|
||||
|
||||
rl.setDelay(200);
|
||||
rl.setEpisodes(10000);
|
||||
rl.setDiscountFactor(1f);
|
||||
rl.setEpsilon(0.15f);
|
||||
|
||||
rl.start();
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue