Change RLController initialization process and make action space iterable
- no fake builder pattern anymore; moved the required fields into the constructor - add serialVersionUID - action space extends the Iterable interface to simplify looping over all actions (instead of returning the actual list)
This commit is contained in:
		
							parent
							
								
									5a4e380faf
								
							
						
					
					
						commit
						b2c3854b3a
					
				|  | @ -62,7 +62,7 @@ public class DeterministicStateActionTable<A extends Enum> implements StateActio | |||
| 
 | ||||
|     private Map<A, Double> createDefaultActionValues(){ | ||||
|         final Map<A, Double> defaultActionValues = new LinkedHashMap<>(); | ||||
|         for(A action: discreteActionSpace.getAllActions()){ | ||||
|         for(A action: discreteActionSpace){ | ||||
|             defaultActionValues.put(action, DEFAULT_VALUE); | ||||
|         } | ||||
|         return defaultActionValues; | ||||
|  |  | |||
|  | @ -1,10 +1,7 @@ | |||
| package core; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| public interface DiscreteActionSpace<A extends Enum> { | ||||
| public interface DiscreteActionSpace<A extends Enum> extends Iterable<A> { | ||||
|     int getNumberOfActions(); | ||||
|     void addAction(A a); | ||||
|     void addActions(A... as); | ||||
|     List<A> getAllActions(); | ||||
| } | ||||
|  |  | |||
|  | @ -1,11 +1,10 @@ | |||
| package core; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| import java.util.*; | ||||
| 
 | ||||
| public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable { | ||||
| public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A>, Serializable{ | ||||
|     private static final long serialVersionUID = 1L; | ||||
|     private List<A> actions; | ||||
| 
 | ||||
|     public ListDiscreteActionSpace(){ | ||||
|  | @ -27,13 +26,13 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp | |||
|         actions.addAll(Arrays.asList(as)); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public List<A> getAllActions() { | ||||
|         return actions; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int getNumberOfActions(){ | ||||
|         return actions.size(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Iterator<A> iterator() { | ||||
|         return actions.iterator(); | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -8,6 +8,7 @@ import java.io.Serializable; | |||
| @AllArgsConstructor | ||||
| @Getter | ||||
| public class SaveState<A extends Enum> implements Serializable { | ||||
|     private static final long serialVersionUID = 1L; | ||||
|     private StateActionTable<A> stateActionTable; | ||||
|     private int currentEpisode; | ||||
| } | ||||
|  |  | |||
|  | @ -3,12 +3,10 @@ package core.algo; | |||
| import core.DiscreteActionSpace; | ||||
| import core.Environment; | ||||
| import core.listener.LearningListener; | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| 
 | ||||
| public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic { | ||||
|     @Setter | ||||
|     @Getter | ||||
|     protected int currentEpisode; | ||||
|     protected int episodesToLearn; | ||||
|     protected volatile int episodePerSecond; | ||||
|  |  | |||
|  | @ -9,14 +9,13 @@ import core.policy.Policy; | |||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.HashSet; | ||||
| import java.util.List; | ||||
| import java.util.Set; | ||||
| import java.util.concurrent.CopyOnWriteArrayList; | ||||
| 
 | ||||
| @Getter | ||||
| public abstract class Learning<A extends Enum> implements Serializable { | ||||
| public abstract class Learning<A extends Enum>{ | ||||
|     protected Policy<A> policy; | ||||
|     protected DiscreteActionSpace<A> actionSpace; | ||||
|     @Setter | ||||
|  |  | |||
|  | @ -18,31 +18,29 @@ import java.util.concurrent.ExecutorService; | |||
| import java.util.concurrent.Executors; | ||||
| 
 | ||||
| public class RLController<A extends Enum> implements ViewListener, LearningListener { | ||||
|     protected Environment<A> environment; | ||||
|     protected Learning<A> learning; | ||||
|     protected DiscreteActionSpace<A> discreteActionSpace; | ||||
|     protected LearningView learningView; | ||||
|     private int nrOfEpisodes; | ||||
|     private Environment<A> environment; | ||||
|     private DiscreteActionSpace<A> discreteActionSpace; | ||||
|     private Method method; | ||||
|     private int prevDelay; | ||||
|     private int delay = LearningConfig.DEFAULT_DELAY; | ||||
|     private float discountFactor = LearningConfig.DEFAULT_DISCOUNT_FACTOR; | ||||
|     private float epsilon = LearningConfig.DEFAULT_EPSILON; | ||||
|     private boolean fastLearning; | ||||
|     private boolean currentlyLearning; | ||||
|     private Learning<A> learning; | ||||
|     private LearningView learningView; | ||||
|     private ExecutorService learningExecutor; | ||||
|     private boolean currentlyLearning; | ||||
|     private boolean fastLearning; | ||||
|     private List<Double> latestRewardsHistory; | ||||
|     private int nrOfEpisodes; | ||||
|     private int prevDelay; | ||||
| 
 | ||||
|     public RLController(){ | ||||
|     public RLController(Environment<A> env, Method method, A... actions){ | ||||
|         learningExecutor = Executors.newSingleThreadExecutor(); | ||||
|         setEnvironment(env); | ||||
|         setMethod(method); | ||||
|         setAllowedActions(actions); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public void start(){ | ||||
|         if(environment == null || discreteActionSpace == null || method == null){ | ||||
|             throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()"); | ||||
|         } | ||||
| 
 | ||||
|         switch (method){ | ||||
|             case MC_ONPOLICY_EGREEDY: | ||||
|                 learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay); | ||||
|  | @ -50,13 +48,21 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe | |||
|             case TD_ONPOLICY: | ||||
|                 break; | ||||
|             default: | ||||
|                 throw new RuntimeException("Undefined method"); | ||||
|                 throw new IllegalArgumentException("Undefined method"); | ||||
|         } | ||||
| 
 | ||||
|         initGUI(); | ||||
|         initLearning(); | ||||
|     } | ||||
| 
 | ||||
|     private void initGUI(){ | ||||
|         SwingUtilities.invokeLater(()->{ | ||||
|             learningView = new View<>(learning, environment, this); | ||||
|             learning.addListener(this); | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     private void initLearning(){ | ||||
|         if(learning instanceof EpisodicLearning){ | ||||
|             learningExecutor.submit(()->((EpisodicLearning) learning).learn(nrOfEpisodes)); | ||||
|         }else{ | ||||
|  | @ -65,7 +71,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe | |||
|     } | ||||
| 
 | ||||
|     /************************************************* | ||||
|      *                VIEW LISTENERS                 * | ||||
|      **                VIEW LISTENERS               ** | ||||
|      *************************************************/ | ||||
|     @Override | ||||
|     public void onLearnMoreEpisodes(int nrOfEpisodes){ | ||||
|  | @ -148,7 +154,7 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe | |||
|     } | ||||
| 
 | ||||
|     /************************************************* | ||||
|      *              LEARNING LISTENERS               * | ||||
|      **              LEARNING LISTENERS             ** | ||||
|      *************************************************/ | ||||
|     @Override | ||||
|     public void onLearningStart() { | ||||
|  | @ -185,37 +191,43 @@ public class RLController<A extends Enum> implements ViewListener, LearningListe | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /************************************************* | ||||
|      **                   SETTER                    ** | ||||
|      *************************************************/ | ||||
| 
 | ||||
|     public RLController<A> setMethod(Method method){ | ||||
|         this.method = method; | ||||
|         return this; | ||||
|     private void setEnvironment(Environment<A> environment){ | ||||
|         if(environment == null){ | ||||
|             throw new IllegalArgumentException("Environment cannot be null"); | ||||
|         } | ||||
|     public RLController<A> setEnvironment(Environment<A> environment){ | ||||
|         this.environment = environment; | ||||
|         return this; | ||||
|     } | ||||
|     @SafeVarargs | ||||
|     public final RLController<A> setAllowedActions(A... actions){ | ||||
| 
 | ||||
|     private void setMethod(Method method){ | ||||
|         if(method == null){ | ||||
|             throw new IllegalArgumentException("Method cannot be null"); | ||||
|         } | ||||
|         this.method = method; | ||||
|     } | ||||
| 
 | ||||
|     private void setAllowedActions(A[] actions){ | ||||
|         if(actions == null || actions.length == 0){ | ||||
|             throw new IllegalArgumentException("There has to be at least one action"); | ||||
|         } | ||||
|         this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     public RLController<A> setDelay(int delay){ | ||||
|     public void setDelay(int delay){ | ||||
|         this.delay = delay; | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     public RLController<A> setEpisodes(int nrOfEpisodes){ | ||||
|     public void setEpisodes(int nrOfEpisodes){ | ||||
|         this.nrOfEpisodes = nrOfEpisodes; | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     public RLController<A> setDiscountFactor(float discountFactor){ | ||||
|     public void setDiscountFactor(float discountFactor){ | ||||
|         this.discountFactor = discountFactor; | ||||
|         return this; | ||||
|     } | ||||
|     public RLController<A> setEpsilon(float epsilon){ | ||||
|     public void setEpsilon(float epsilon){ | ||||
|         this.epsilon = epsilon; | ||||
|         return this; | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -10,14 +10,16 @@ public class JumpingDino { | |||
|     public static void main(String[] args) { | ||||
|         RNG.setSeed(55); | ||||
| 
 | ||||
|         RLController<DinoAction> rl = new RLController<DinoAction>() | ||||
|                 .setEnvironment(new DinoWorld()) | ||||
|                 .setAllowedActions(DinoAction.values()) | ||||
|                 .setMethod(Method.MC_ONPOLICY_EGREEDY) | ||||
|                 .setDiscountFactor(1f) | ||||
|                 .setEpsilon(0.15f) | ||||
|                 .setDelay(200) | ||||
|                 .setEpisodes(100000); | ||||
|         RLController<DinoAction> rl = new RLController<>( | ||||
|                 new DinoWorld(), | ||||
|                 Method.MC_ONPOLICY_EGREEDY, | ||||
|                 DinoAction.values()); | ||||
| 
 | ||||
|         rl.setDelay(200); | ||||
|         rl.setDiscountFactor(1f); | ||||
|         rl.setEpsilon(0.15f); | ||||
|         rl.setEpisodes(5000); | ||||
| 
 | ||||
|         rl.start(); | ||||
|     } | ||||
| } | ||||
|  |  | |||
|  | @ -10,12 +10,16 @@ public class RunningAnt { | |||
|     public static void main(String[] args) { | ||||
|         RNG.setSeed(123); | ||||
| 
 | ||||
|         RLController<AntAction> rl = new RLController<AntAction>() | ||||
|                             .setEnvironment(new AntWorld(3,3,0.1)) | ||||
|                             .setAllowedActions(AntAction.values()) | ||||
|                             .setMethod(Method.MC_ONPOLICY_EGREEDY) | ||||
|                             .setDelay(200) | ||||
|                             .setEpisodes(100000); | ||||
|         RLController<AntAction> rl = new RLController<>( | ||||
|                 new AntWorld(3, 3, 0.1), | ||||
|                 Method.MC_ONPOLICY_EGREEDY, | ||||
|                 AntAction.values()); | ||||
| 
 | ||||
|         rl.setDelay(200); | ||||
|         rl.setEpisodes(10000); | ||||
|         rl.setDiscountFactor(1f); | ||||
|         rl.setEpsilon(0.15f); | ||||
| 
 | ||||
|         rl.start(); | ||||
|     } | ||||
| } | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue