change RL-Controller initialization process and action space iterable
- no fake builder pattern anymore, moved needed fields into constructor - add serializeUID - action space extends iterable interface to simplify looping over all actions (and not returning the actual list)
This commit is contained in:
parent
5a4e380faf
commit
b2c3854b3a
|
@ -62,7 +62,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio
|
||||||
|
|
||||||
private Map<A, Double> createDefaultActionValues(){
|
private Map<A, Double> createDefaultActionValues(){
|
||||||
final Map<A, Double> defaultActionValues = new LinkedHashMap<>();
|
final Map<A, Double> defaultActionValues = new LinkedHashMap<>();
|
||||||
for(A action: discreteActionSpace.getAllActions()){
|
for(A action: discreteActionSpace){
|
||||||
defaultActionValues.put(action, DEFAULT_VALUE);
|
defaultActionValues.put(action, DEFAULT_VALUE);
|
||||||
}
|
}
|
||||||
return defaultActionValues;
|
return defaultActionValues;
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
package core;
|
package core;
|
||||||
|
|
||||||
import java.util.List;
|
public interface DiscreteActionSpace<A extends Enum> extends Iterable<A> {
|
||||||
|
|
||||||
public interface DiscreteActionSpace<A extends Enum> {
|
|
||||||
int getNumberOfActions();
|
int getNumberOfActions();
|
||||||
void addAction(A a);
|
void addAction(A a);
|
||||||
void addActions(A... as);
|
void addActions(A... as);
|
||||||
List<A> getAllActions();
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
package core;
|
package core;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.ArrayList;
|
import java.util.*;
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable{
|
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable{
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
private List<A> actions;
|
private List<A> actions;
|
||||||
|
|
||||||
public ListDiscreteActionSpace(){
|
public ListDiscreteActionSpace(){
|
||||||
|
@ -27,13 +26,13 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
|
||||||
actions.addAll(Arrays.asList(as));
|
actions.addAll(Arrays.asList(as));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<A> getAllActions() {
|
|
||||||
return actions;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getNumberOfActions(){
|
public int getNumberOfActions(){
|
||||||
return actions.size();
|
return actions.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Iterator<A> iterator() {
|
||||||
|
return actions.iterator();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import java.io.Serializable;
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
@Getter
|
@Getter
|
||||||
public class SaveState<A extends Enum> implements Serializable {
|
public class SaveState<A extends Enum> implements Serializable {
|
||||||
|
private static final long serialVersionUID = 1L;
|
||||||
private StateActionTable<A> stateActionTable;
|
private StateActionTable<A> stateActionTable;
|
||||||
private int currentEpisode;
|
private int currentEpisode;
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,12 +3,10 @@ package core.algo;
|
||||||
import core.DiscreteActionSpace;
|
import core.DiscreteActionSpace;
|
||||||
import core.Environment;
|
import core.Environment;
|
||||||
import core.listener.LearningListener;
|
import core.listener.LearningListener;
|
||||||
import lombok.Getter;
|
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
|
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
|
||||||
@Setter
|
@Setter
|
||||||
@Getter
|
|
||||||
protected int currentEpisode;
|
protected int currentEpisode;
|
||||||
protected int episodesToLearn;
|
protected int episodesToLearn;
|
||||||
protected volatile int episodePerSecond;
|
protected volatile int episodePerSecond;
|
||||||
|
|
|
@ -9,14 +9,13 @@ import core.policy.Policy;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
|
||||||
@Getter
|
@Getter
|
||||||
public abstract class Learning<A extends Enum> implements Serializable {
|
public abstract class Learning<A extends Enum>{
|
||||||
protected Policy<A> policy;
|
protected Policy<A> policy;
|
||||||
protected DiscreteActionSpace<A> actionSpace;
|
protected DiscreteActionSpace<A> actionSpace;
|
||||||
@Setter
|
@Setter
|
||||||
|
|
|
@ -18,31 +18,29 @@ import java.util.concurrent.ExecutorService;
|
||||||
import java.util.concurrent.Executors;
|
import java.util.concurrent.Executors;
|
||||||
|
|
||||||
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
public class RLController<A extends Enum> implements ViewListener, LearningListener {
|
||||||
protected Environment<A> environment;
|
private Environment<A> environment;
|
||||||
protected Learning<A> learning;
|
private DiscreteActionSpace<A> discreteActionSpace;
|
||||||
protected DiscreteActionSpace<A> discreteActionSpace;
|
|
||||||
protected LearningView learningView;
|
|
||||||
private int nrOfEpisodes;
|
|
||||||
private Method method;
|
private Method method;
|
||||||
private int prevDelay;
|
|
||||||
private int delay = LearningConfig.DEFAULT_DELAY;
|
private int delay = LearningConfig.DEFAULT_DELAY;
|
||||||
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
|
private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR;
|
||||||
private float epsilon = LearningConfig.DEFAULT_EPSILON;
|
private float epsilon = LearningConfig.DEFAULT_EPSILON;
|
||||||
private boolean fastLearning;
|
private Learning<A> learning;
|
||||||
private boolean currentlyLearning;
|
private LearningView learningView;
|
||||||
private ExecutorService learningExecutor;
|
private ExecutorService learningExecutor;
|
||||||
|
private boolean currentlyLearning;
|
||||||
|
private boolean fastLearning;
|
||||||
private List<Double> latestRewardsHistory;
|
private List<Double> latestRewardsHistory;
|
||||||
|
private int nrOfEpisodes;
|
||||||
|
private int prevDelay;
|
||||||
|
|
||||||
public RLController(){
|
public RLController(Environment<A> env, Method method, A... actions){
|
||||||
learningExecutor = Executors.newSingleThreadExecutor();
|
learningExecutor = Executors.newSingleThreadExecutor();
|
||||||
|
setEnvironment(env);
|
||||||
|
setMethod(method);
|
||||||
|
setAllowedActions(actions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
public void start(){
|
public void start(){
|
||||||
if(environment == null || discreteActionSpace == null || method == null){
|
|
||||||
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
|
|
||||||
}
|
|
||||||
|
|
||||||
switch (method){
|
switch (method){
|
||||||
case MC_ONPOLICY_EGREEDY:
|
case MC_ONPOLICY_EGREEDY:
|
||||||
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
|
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
|
||||||
|
@ -50,13 +48,21 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
case TD_ONPOLICY:
|
case TD_ONPOLICY:
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
throw new RuntimeException("Undefined method");
|
throw new IllegalArgumentException("Undefined method");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
initGUI();
|
||||||
|
initLearning();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initGUI(){
|
||||||
SwingUtilities.invokeLater(()->{
|
SwingUtilities.invokeLater(()->{
|
||||||
learningView = new View<>(learning, environment, this);
|
learningView = new View<>(learning, environment, this);
|
||||||
learning.addListener(this);
|
learning.addListener(this);
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initLearning(){
|
||||||
if(learning instanceof EpisodicLearning){
|
if(learning instanceof EpisodicLearning){
|
||||||
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
|
learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes));
|
||||||
}else{
|
}else{
|
||||||
|
@ -65,7 +71,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
|
|
||||||
/*************************************************
|
/*************************************************
|
||||||
* VIEW LISTENERS *
|
** VIEW LISTENERS **
|
||||||
*************************************************/
|
*************************************************/
|
||||||
@Override
|
@Override
|
||||||
public void onLearnMoreEpisodes(int nrOfEpisodes){
|
public void onLearnMoreEpisodes(int nrOfEpisodes){
|
||||||
|
@ -148,7 +154,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
|
|
||||||
/*************************************************
|
/*************************************************
|
||||||
* LEARNING LISTENERS *
|
** LEARNING LISTENERS **
|
||||||
*************************************************/
|
*************************************************/
|
||||||
@Override
|
@Override
|
||||||
public void onLearningStart() {
|
public void onLearningStart() {
|
||||||
|
@ -185,37 +191,43 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*************************************************
|
||||||
|
** SETTER **
|
||||||
|
*************************************************/
|
||||||
|
|
||||||
public RLController<A> setMethod(Method method){
|
private void setEnvironment(Environment<A> environment){
|
||||||
this.method = method;
|
if(environment == null){
|
||||||
return this;
|
throw new IllegalArgumentException("Environment cannot be null");
|
||||||
}
|
}
|
||||||
public RLController<A> setEnvironment(Environment<A> environment){
|
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
@SafeVarargs
|
|
||||||
public final RLController<A> setAllowedActions(A... actions){
|
private void setMethod(Method method){
|
||||||
|
if(method == null){
|
||||||
|
throw new IllegalArgumentException("Method cannot be null");
|
||||||
|
}
|
||||||
|
this.method = method;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setAllowedActions(A[] actions){
|
||||||
|
if(actions == null || actions.length == 0){
|
||||||
|
throw new IllegalArgumentException("There has to be at least one action");
|
||||||
|
}
|
||||||
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public RLController<A> setDelay(int delay){
|
public void setDelay(int delay){
|
||||||
this.delay = delay;
|
this.delay = delay;
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public RLController<A> setEpisodes(int nrOfEpisodes){
|
public void setEpisodes(int nrOfEpisodes){
|
||||||
this.nrOfEpisodes = nrOfEpisodes;
|
this.nrOfEpisodes = nrOfEpisodes;
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public RLController<A> setDiscountFactor(float discountFactor){
|
public void setDiscountFactor(float discountFactor){
|
||||||
this.discountFactor = discountFactor;
|
this.discountFactor = discountFactor;
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
public RLController<A> setEpsilon(float epsilon){
|
public void setEpsilon(float epsilon){
|
||||||
this.epsilon = epsilon;
|
this.epsilon = epsilon;
|
||||||
return this;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,14 +10,16 @@ public class JumpingDino {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
RNG.setSeed(55);
|
RNG.setSeed(55);
|
||||||
|
|
||||||
RLController<DinoAction> rl = new RLController<DinoAction>()
|
RLController<DinoAction> rl = new RLController<>(
|
||||||
.setEnvironment(new DinoWorld())
|
new DinoWorld(),
|
||||||
.setAllowedActions(DinoAction.values())
|
Method.MC_ONPOLICY_EGREEDY,
|
||||||
.setMethod(Method.MC_ONPOLICY_EGREEDY)
|
DinoAction.values());
|
||||||
.setDiscountFactor(1f)
|
|
||||||
.setEpsilon(0.15f)
|
rl.setDelay(200);
|
||||||
.setDelay(200)
|
rl.setDiscountFactor(1f);
|
||||||
.setEpisodes(100000);
|
rl.setEpsilon(0.15f);
|
||||||
|
rl.setEpisodes(5000);
|
||||||
|
|
||||||
rl.start();
|
rl.start();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,12 +10,16 @@ public class RunningAnt {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
RNG.setSeed(123);
|
RNG.setSeed(123);
|
||||||
|
|
||||||
RLController<AntAction> rl = new RLController<AntAction>()
|
RLController<AntAction> rl = new RLController<>(
|
||||||
.setEnvironment(new AntWorld(3,3,0.1))
|
new AntWorld(3, 3, 0.1),
|
||||||
.setAllowedActions(AntAction.values())
|
Method.MC_ONPOLICY_EGREEDY,
|
||||||
.setMethod(Method.MC_ONPOLICY_EGREEDY)
|
AntAction.values());
|
||||||
.setDelay(200)
|
|
||||||
.setEpisodes(100000);
|
rl.setDelay(200);
|
||||||
|
rl.setEpisodes(10000);
|
||||||
|
rl.setDiscountFactor(1f);
|
||||||
|
rl.setEpsilon(0.15f);
|
||||||
|
|
||||||
rl.start();
|
rl.start();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue