package core.algo;

import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.StateActionTable;
import core.listener.LearningListener;
import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CopyOnWriteArraySet;

/**
 * Base class for learning algorithms that act in an {@link Environment} through a
 * {@link DiscreteActionSpace}.
 *
 * <p>Holds the state shared by concrete algorithms (policy, state-action table, discount
 * factor, delay, reward history, step/episode counters) and manages the registered
 * {@link LearningListener}s, which are notified when learning starts and ends and after
 * every step.
 *
 * <p>NOTE(review): the previous Javadoc documented an action-type parameter ("discrete action
 * type for a specific environment") that this class does not declare; confirm whether the
 * class is meant to be generic over its action type.
 */
@Getter
public abstract class Learning {
    // TODO: temp testing -> extract to dedicated test
    protected int checkSum;
    protected int rewardCheckSum;

    /** Current discrete timestamp t. */
    protected int timestamp;
    protected int currentEpisode;
    protected Policy policy;
    protected DiscreteActionSpace actionSpace;
    @Setter
    protected StateActionTable stateActionTable;
    protected Environment environment;
    protected float discountFactor;
    /**
     * Registered listeners. A copy-on-write set is used so that a dispatch loop iterates over a
     * stable snapshot: adding a listener from another thread, or from inside a listener
     * callback, cannot cause a {@code ConcurrentModificationException}.
     */
    protected Set<LearningListener> learningListeners;
    @Setter
    protected int delay;
    // Element type is not declared here; thread-safe list because it is created as a
    // CopyOnWriteArrayList below.
    protected List rewardHistory;
    /**
     * {@code true} between {@link #dispatchLearningStart()} and {@link #dispatchLearningEnd()}.
     * Volatile so that updates are visible to other threads.
     */
    protected volatile boolean currentlyLearning;

    /**
     * Creates a learner.
     *
     * @param environment    the environment to learn in
     * @param actionSpace    the discrete action space available to the agent
     * @param discountFactor the discount factor
     * @param delay          a delay value stored for use by subclasses (presumably a pause
     *                       between steps — confirm units)
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
        this.environment = environment;
        this.actionSpace = actionSpace;
        this.discountFactor = discountFactor;
        this.delay = delay;
        currentlyLearning = false;
        learningListeners = new CopyOnWriteArraySet<>();
        rewardHistory = new CopyOnWriteArrayList<>();
    }

    /** Creates a learner using {@code LearningConfig.DEFAULT_DELAY}. */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
        this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
    }

    /** Creates a learner using {@code LearningConfig.DEFAULT_DISCOUNT_FACTOR}. */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay) {
        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
    }

    /**
     * Creates a learner using {@code LearningConfig.DEFAULT_DISCOUNT_FACTOR} and
     * {@code LearningConfig.DEFAULT_DELAY}.
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace) {
        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
    }

    /** Runs the learning algorithm; implemented by concrete subclasses. */
    public abstract void learn();

    /**
     * Registers a listener to be notified of learning events. Adding a listener that is
     * already registered has no effect.
     *
     * @param learningListener the listener to add
     * @throws NullPointerException if {@code learningListener} is {@code null} (it would
     *                              otherwise fail later, during dispatch)
     */
    public void addListener(LearningListener learningListener) {
        learningListeners.add(Objects.requireNonNull(learningListener, "learningListener"));
    }

    /** Notifies every registered listener that a step has ended. */
    protected void dispatchStepEnd() {
        for (LearningListener l : learningListeners) {
            l.onStepEnd();
        }
    }

    /** Marks learning as running, then notifies every registered listener that it started. */
    protected void dispatchLearningStart() {
        currentlyLearning = true;
        for (LearningListener l : learningListeners) {
            l.onLearningStart();
        }
    }

    /** Marks learning as stopped, then notifies every registered listener that it ended. */
    protected void dispatchLearningEnd() {
        currentlyLearning = false;
        for (LearningListener l : learningListeners) {
            l.onLearningEnd();
        }
    }

    /** Placeholder for interrupting a running learning process; currently does nothing. */
    public synchronized void interruptLearning() {
        //TODO: for non episodic learning
    }

    /**
     * Writes the reward history to {@code oos}.
     *
     * <p>NOTE(review): the state-action table is NOT written here (the call is commented out),
     * yet {@link #load(ObjectInputStream)} reads one after the reward history. Unless every
     * subclass writes it after calling {@code super.save}, loading a stream produced by this
     * method alone will fail when reading the table. Confirm and make the two symmetric.
     *
     * @param oos the stream to write to
     * @throws IOException if writing fails
     */
    public void save(ObjectOutputStream oos) throws IOException {
        oos.writeObject(rewardHistory);
        // oos.writeObject(stateActionTable);
    }

    /**
     * Reads the reward history and then the state-action table from {@code ois}, replacing
     * the current values. See the NOTE on {@link #save(ObjectOutputStream)} about the
     * mismatch between what is written and what is read.
     *
     * @param ois the stream to read from
     * @throws IOException            if reading fails
     * @throws ClassNotFoundException if a serialized class cannot be found
     */
    public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
        rewardHistory = (List) ois.readObject();
        stateActionTable = (StateActionTable) ois.readObject();
    }
}