Merge remote-tracking branch 'origin/epsilonTest'
commit 33f896ff40
@@ -5,6 +5,7 @@ import core.Environment;
 import core.LearningConfig;
 import core.StepResult;
 import core.listener.LearningListener;
+import core.policy.EpsilonGreedyPolicy;
 import lombok.Getter;
 import lombok.Setter;

@@ -74,12 +75,35 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple

     protected void dispatchEpisodeStart(){
         ++currentEpisode;
+        /*
+        2f 0.02 => 100
+        1.5f 0.02 => 75
+        1.4f 0.02 => fail
+        1.5f 0.1 => 16 !
+         */
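        // The pairs above appear to record tuning runs: decay numerator / epsilon cutoff =>
        // episodes until convergence (or failure). The code below implements the 1.5f / 0.1
        // variant: epsilon = 1.5 / currentEpisode, forced to 0 (fully greedy) once it drops
        // below 0.1, i.e. from roughly episode 15 onwards.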
+        if(this.policy instanceof EpsilonGreedyPolicy){
+            float ep = 1.5f/(float)currentEpisode;
+            if(ep < 0.10) ep = 0;
+            ((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(ep);
+            System.out.println(ep);
+        }
         episodesToLearn.decrementAndGet();
         for(LearningListener l: learningListeners){
             l.onEpisodeStart();
         }
     }

+    @Override
+    protected void dispatchStepEnd() {
+        super.dispatchStepEnd();
+        timestamp++;
+        // TODO: more sophisticated way to check convergence
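        // (as a stopgap, "converged" here just means a fixed budget of 300000 steps has been exhausted)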
+        if(timestamp > 300000){
+            System.out.println("converged after: " + currentEpisode + " episodes!");
+            interruptLearning();
+        }
+    }

     @Override
     public void learn(){
         learn(LearningConfig.DEFAULT_NR_OF_EPISODES);

@@ -12,7 +12,6 @@ import lombok.Setter;
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
-import java.io.Serializable;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -26,6 +25,13 @@ import java.util.concurrent.Executors;
  */
 @Getter
 public abstract class Learning<A extends Enum>{
+    // TODO: temp testing -> extract to dedicated test
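    // assumption: these sums exist only for the reproducibility check: checkSum adds up the
    // ordinal of every chosen action and rewardCheckSum the received rewards, so two runs with
    // the same RNG seed should print identical values in dispatchLearningEnd()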
+    protected int checkSum;
+    protected int rewardCheckSum;
+
+    // current discrete timestamp t
+    protected int timestamp;
+    protected int currentEpisode;
     protected Policy<A> policy;
     protected DiscreteActionSpace<A> actionSpace;
     @Setter
@@ -83,6 +89,8 @@ public abstract class Learning<A extends Enum>{

     protected void dispatchLearningEnd() {
         currentlyLearning = false;
+        System.out.println("Checksum: " + checkSum);
+        System.out.println("Reward Checksum: " + rewardCheckSum);
         for (LearningListener l : learningListeners) {
             l.onLearningEnd();
         }

@@ -5,5 +5,5 @@ package core.algo;
  * which RL-algorithm should be used.
  */
 public enum Method {
-    MC_CONTROL_EGREEDY, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
+    MC_CONTROL_FIRST_VISIT, SARSA_EPISODIC, Q_LEARNING_OFF_POLICY_CONTROL
 }

@@ -17,7 +17,7 @@ import java.util.*;
  * For example:
  * <p>
  * startingState -> MOVE_LEFT : very first state action in the episode i = 1
- * image the agent does not collect the food and drops it to the start, the agent will receive
+ * imagine the agent does not collect the food and does not drop it onto the start; the agent will receive
  * -1 for every timestamp hence (startingState -> MOVE_LEFT) will get a value of -10;
  * <p>
  * BUT imagine moving left from the starting position will have no impact on the state because
@@ -30,12 +30,12 @@ import java.util.*;
  *
  * @param <A>
  */
-public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A> {
+public class MonteCarloControlFirstVisitEGreedy<A extends Enum> extends EpisodicLearning<A> {

     private Map<Pair<State, A>, Double> returnSum;
     private Map<Pair<State, A>, Integer> returnCount;
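    // first-visit Monte Carlo control estimates Q(s,a) as the average return of the first
    // occurrence of (s,a) in each episode; returnSum and returnCount presumably hold the
    // running numerator and denominator of that average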

-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay) {
         super(environment, actionSpace, discountFactor, delay);
         this.policy = new EpsilonGreedyPolicy<>(epsilon);
         this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
@@ -43,7 +43,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         returnCount = new HashMap<>();
     }

-    public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
+    public MonteCarloControlFirstVisitEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, int delay) {
         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, delay);
     }

@@ -58,12 +58,16 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
         }
         sumOfRewards = 0;
         StepResultEnvironment envResult = null;
+        //TODO extract to learning
+        int timestamp = 0;
         while(envResult == null || !envResult.isDone()) {
             Map<A, Double> actionValues = stateActionTable.getActionValues(state);
             A chosenAction = policy.chooseAction(actionValues);
+            checkSum += chosenAction.ordinal();
             envResult = environment.step(chosenAction);
             State nextState = envResult.getState();
             sumOfRewards += envResult.getReward();
+            rewardCheckSum += envResult.getReward();
             episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));

             state = nextState;
@@ -73,6 +77,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
             } catch (InterruptedException e) {
                 e.printStackTrace();
             }
+            timestamp++;
             dispatchStepEnd();
         }

@@ -46,7 +46,7 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
             actionValues = stateActionTable.getActionValues(nextState);
             A nextAction = policy.chooseAction(actionValues);

-            // TD update
+            // td update
             // target = reward + gamma * Q(nextState, nextAction)
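            // (the lines below fetch Q(s,a) and Q(s',a'); the standard SARSA update that
            // presumably follows is Q(s,a) <- Q(s,a) + learningRate * (target - Q(s,a)))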
             double currentQValue = stateActionTable.getActionValues(state).get(action);
             double nextQValue = stateActionTable.getActionValues(nextState).get(nextAction);
@@ -7,7 +7,7 @@ import core.ListDiscreteActionSpace;
 import core.algo.EpisodicLearning;
 import core.algo.Learning;
 import core.algo.Method;
-import core.algo.mc.MonteCarloControlEGreedy;
+import core.algo.mc.MonteCarloControlFirstVisitEGreedy;
 import core.algo.td.QLearningOffPolicyTDControl;
 import core.algo.td.SARSA;
 import core.listener.LearningListener;
@@ -48,8 +48,8 @@ public class RLController<A extends Enum> implements LearningListener {

     public void start() {
         switch(method) {
-            case MC_CONTROL_EGREEDY:
-                learning = new MonteCarloControlEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
+            case MC_CONTROL_FIRST_VISIT:
+                learning = new MonteCarloControlFirstVisitEGreedy<>(environment, discreteActionSpace, discountFactor, epsilon, delay);
                 break;
             case SARSA_EPISODIC:
                 learning = new SARSA<>(environment, discreteActionSpace, discountFactor, epsilon, learningRate, delay);
@@ -115,7 +115,7 @@ public class RLController<A extends Enum> implements LearningListener {
         try {
             fis = new FileInputStream(fileName);
             in = new ObjectInputStream(fis);
-            System.out.println("interrup" + Thread.currentThread().getId());
+            System.out.println("interrupt" + Thread.currentThread().getId());
             learning.interruptLearning();
             learning.load(in);
             in.close();

@@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
     @Override
     public StepResultEnvironment step(DinoAction action) {
         boolean done = false;
-        int reward = 0;
+        int reward = 1;
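        // reward scheme change: +1 for every step the dino survives and 0 on a collision
        // (see the hunk below), replacing 0 per step and -1 on a crash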

         if(action == DinoAction.JUMP){
             dino.jump();
@@ -74,7 +74,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
             spawnNewObstacle();
         }
         if(ranIntoObstacle()) {
-            reward = -1;
+            reward = 0;
             done = true;
         }

@@ -0,0 +1,27 @@
+package example;
+
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+
+public class DinoSampling {
+    public static void main(String[] args) {
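        // presumably a sampling harness: run the identical seeded experiment ten times so the
        // episode counts and checksums printed at the end can be compared across runs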
+        for (int i = 0; i < 10 ; i++) {
+            RNG.setSeed(55);
+
+            RLController<DinoAction> rl = new RLController<>(
+                    new DinoWorld(false, false),
+                    Method.MC_CONTROL_FIRST_VISIT,
+                    DinoAction.values());
+
+            rl.setDelay(0);
+            rl.setDiscountFactor(1f);
+            rl.setEpsilon(0.15f);
+            rl.setLearningRate(1f);
+            rl.setNrOfEpisodes(400);
+            rl.start();
+        }
+    }
+}
@@ -3,7 +3,6 @@ package example;
 import core.RNG;
 import core.algo.Method;
 import core.controller.RLController;
-import core.controller.RLControllerGUI;
 import evironment.jumpingDino.DinoAction;
 import evironment.jumpingDino.DinoWorld;

@@ -11,16 +10,16 @@ public class JumpingDino {
     public static void main(String[] args) {
         RNG.setSeed(55);

-        RLController<DinoAction> rl = new RLControllerGUI<>(
+        RLController<DinoAction> rl = new RLController<>(
                 new DinoWorld(false, false),
-                Method.Q_LEARNING_OFF_POLICY_CONTROL,
+                Method.MC_CONTROL_FIRST_VISIT,
                 DinoAction.values());

-        rl.setDelay(1000);
-        rl.setDiscountFactor(0.9f);
-        rl.setEpsilon(0.1f);
-        rl.setLearningRate(0.5f);
-        rl.setNrOfEpisodes(4000000);
+        rl.setDelay(0);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.15f);
+        rl.setLearningRate(1f);
+        rl.setNrOfEpisodes(400);
         rl.start();
     }
 }

@@ -13,7 +13,7 @@ public class RunningAnt {

         RLController<AntAction> rl = new RLControllerGUI<>(
                 new AntWorld(3, 3, 0.1),
-                Method.MC_CONTROL_EGREEDY,
+                Method.MC_CONTROL_FIRST_VISIT,
                 AntAction.values());

         rl.setDelay(200);

@@ -0,0 +1,30 @@
+import core.RNG;
+import core.algo.Method;
+import core.controller.RLController;
+import core.controller.RLControllerGUI;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+import org.junit.Test;
+
+public class MCFirstVisit {
+
+    /**
+     * Test if the action sequence is deterministic
+     */
+    @Test
+    public void deterministicActionSequence(){
+        RNG.setSeed(55);
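        // assumption: with the fixed seed the epsilon-greedy action sequence should be
        // reproducible, so the checksums printed in dispatchLearningEnd() are what a real
        // assertion would eventually compare; for now the test only runs the controller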
+
+        RLController<DinoAction> rl = new RLControllerGUI<>(
+                new DinoWorld(false, false),
+                Method.MC_CONTROL_FIRST_VISIT,
+                DinoAction.values());
+
+        rl.setDelay(10);
+        rl.setDiscountFactor(1f);
+        rl.setEpsilon(0.1f);
+        rl.setLearningRate(0.8f);
+        rl.setNrOfEpisodes(4000000);
+        rl.start();
+    }
+}