package core.algo; import core.DiscreteActionSpace; import core.Environment; import core.LearningConfig; import core.StepResult; import core.listener.LearningListener; import example.DinoSampling; import lombok.Getter; import lombok.Setter; import java.io.File; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardOpenOption; import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; public abstract class EpisodicLearning extends Learning implements Episodic { @Setter protected int currentEpisode = 0; protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0); @Getter protected volatile int episodePerSecond; protected int episodeSumCurrentSecond; protected double sumOfRewards; protected List> episode = new ArrayList<>(); protected int timestampCurrentEpisode = 0; protected boolean converged; public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor, int delay) { super(environment, actionSpace, discountFactor, delay); initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, float discountFactor) { super(environment, actionSpace, discountFactor); initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace, int delay) { super(environment, actionSpace, delay); initBenchMarking(); } public EpisodicLearning(Environment environment, DiscreteActionSpace actionSpace) { super(environment, actionSpace); initBenchMarking(); } protected abstract void nextEpisode(); private void initBenchMarking(){ new Thread(()->{ while (currentlyLearning){ episodePerSecond = episodeSumCurrentSecond; episodeSumCurrentSecond = 0; try { Thread.sleep(1000); } catch (InterruptedException e) { e.printStackTrace(); } } }).start(); } private void dispatchEpisodeEnd(){ ++episodeSumCurrentSecond; if(rewardHistory.size() > 10000){ rewardHistory.clear(); } rewardHistory.add(sumOfRewards); for(LearningListener l: learningListeners) { l.onEpisodeEnd(rewardHistory); } } private void dispatchEpisodeStart(){ ++currentEpisode; episodesToLearn.decrementAndGet(); for(LearningListener l: learningListeners){ l.onEpisodeStart(); } } @Override protected void dispatchStepEnd() { super.dispatchStepEnd(); timestamp++; timestampCurrentEpisode++; // TODO: more sophisticated way to check convergence if(timestampCurrentEpisode > 50000) { converged = true; // t File file = new File(DinoSampling.FILE_NAME); try { Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND); } catch (IOException e) { e.printStackTrace(); } System.out.println("converged after: " + currentEpisode/2 + " episode!"); episodesToLearn.set(0); dispatchLearningEnd(); } } @Override public void learn(){ learn(LearningConfig.DEFAULT_NR_OF_EPISODES); } private void startLearning(){ dispatchLearningStart(); while(episodesToLearn.get() > 0){ dispatchEpisodeStart(); timestampCurrentEpisode = 0; nextEpisode(); dispatchEpisodeEnd(); } synchronized (this){ dispatchLearningEnd(); notifyAll(); } } public void learnMoreEpisodes(int nrOfEpisodes){ episodesToLearn.addAndGet(nrOfEpisodes); } /** * Stopping the while loop by setting episodesToLearn to 0. * The current episode can not be interrupted, so the sleep delay * is removed and the calling thread has to wait until the * current episode is done. * Resetting the delay afterwards. */ @Override public synchronized void interruptLearning(){ episodesToLearn.set(0); int prevDelay = delay; delay = 0; while(currentlyLearning) { try { wait(); } catch (InterruptedException e) { e.printStackTrace(); } } delay = prevDelay; } public synchronized void learn(int nrOfEpisodes){ boolean isLearning = episodesToLearn.getAndAdd(nrOfEpisodes) != 0; if(!isLearning) startLearning(); } @Override public int getEpisodesToGo(){ return episodesToLearn.get(); } public int getCurrentEpisode() { return currentEpisode; } public int getEpisodesPerSecond() { return episodePerSecond; } @Override public synchronized void save(ObjectOutputStream oos) throws IOException { super.save(oos); oos.writeInt(currentEpisode); } @Override public synchronized void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { super.load(ois); currentEpisode = ois.readInt(); } }