package core.algo;

import core.DiscreteActionSpace;
import core.Environment;
import core.LearningConfig;
import core.StateActionTable;
import core.listener.LearningListener;
import core.policy.Policy;
import lombok.Getter;
import lombok.Setter;

import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Abstract base for learning algorithms.
 *
 * <p>Holds the environment, the action space, a discount factor, a delay value and a
 * state-action table, and keeps a set of {@link LearningListener}s that subclasses can
 * notify through the {@code dispatch*} helpers. Lombok's class-level {@code @Getter}
 * generates a getter for every field; setters exist only for the fields annotated with
 * {@code @Setter}.
 *
 * <p>Concrete algorithms implement {@link #learn()}.
 */
@Getter
public abstract class Learning implements Serializable {

    /** Policy in use. Not assigned by the constructors of this class. */
    protected Policy policy;

    /** Action space passed in at construction time. */
    protected DiscreteActionSpace actionSpace;

    /** State-action table. Not assigned by the constructors of this class; settable. */
    @Setter
    protected StateActionTable stateActionTable;

    /** Environment passed in at construction time. */
    protected Environment environment;

    /** Discount factor passed in at construction time (or the configured default). */
    protected float discountFactor;

    /**
     * Listeners notified by the {@code dispatch*} methods.
     *
     * <p>Typed as {@code Set<LearningListener>}: with a raw {@code Set} the for-each loops
     * over {@code LearningListener} in the dispatch methods do not compile.
     */
    protected Set<LearningListener> learningListeners;

    /**
     * Delay value passed in at construction time (or the configured default); settable.
     * NOTE(review): unit and exact use are decided by subclasses - not visible here.
     */
    @Setter
    protected int delay;

    /**
     * History of rewards, backed by a {@link CopyOnWriteArrayList}.
     * NOTE(review): left as a raw type because the element type is not visible in this
     * class; parameterize once the type used by subclasses is confirmed.
     */
    protected List rewardHistory;

    /**
     * Creates a learning instance with every setting given explicitly.
     *
     * @param environment    environment to store
     * @param actionSpace    action space to store
     * @param discountFactor discount factor to store
     * @param delay          delay value to store
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) {
        this.environment = environment;
        this.actionSpace = actionSpace;
        this.discountFactor = discountFactor;
        this.delay = delay;
        learningListeners = new HashSet<>();
        rewardHistory = new CopyOnWriteArrayList<>();
    }

    /**
     * Creates a learning instance using {@link LearningConfig#DEFAULT_DELAY} as delay.
     *
     * @param environment    environment to store
     * @param actionSpace    action space to store
     * @param discountFactor discount factor to store
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) {
        this(environment, actionSpace, discountFactor, LearningConfig.DEFAULT_DELAY);
    }

    /**
     * Creates a learning instance using {@link LearningConfig#DEFAULT_DISCOUNT_FACTOR}
     * as discount factor.
     *
     * @param environment environment to store
     * @param actionSpace action space to store
     * @param delay       delay value to store
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace, int delay) {
        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, delay);
    }

    /**
     * Creates a learning instance using the default discount factor and default delay
     * from {@link LearningConfig}.
     *
     * @param environment environment to store
     * @param actionSpace action space to store
     */
    public Learning(Environment environment, DiscreteActionSpace actionSpace) {
        this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_DELAY);
    }

    /** Runs the learning algorithm; implemented by concrete subclasses. */
    public abstract void learn();

    /**
     * Registers a listener. Because listeners are kept in a set, adding a listener that
     * is already present (per {@code equals}/{@code hashCode}) has no effect.
     *
     * @param learningListener listener to register
     */
    public void addListener(LearningListener learningListener) {
        learningListeners.add(learningListener);
    }

    /** Calls {@code onStepEnd()} on every registered listener. */
    protected void dispatchStepEnd() {
        for (LearningListener l : learningListeners) {
            l.onStepEnd();
        }
    }

    /** Calls {@code onLearningStart()} on every registered listener. */
    protected void dispatchLearningStart() {
        for (LearningListener l : learningListeners) {
            l.onLearningStart();
        }
    }

    /** Calls {@code onLearningEnd()} on every registered listener. */
    protected void dispatchLearningEnd() {
        for (LearningListener l : learningListeners) {
            l.onLearningEnd();
        }
    }
}