adopt MVC pattern and add real time graph interface
This commit is contained in:
		
							parent
							
								
									7f18a66e98
								
							
						
					
					
						commit
						e0160ca1df
					
				|  | @ -13,6 +13,9 @@ repositories { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| dependencies { | dependencies { | ||||||
|  |     // https://mvnrepository.com/artifact/org.jfree/jfreechart | ||||||
|  |     // https://mvnrepository.com/artifact/org.knowm.xchart/xchart | ||||||
|  |     compile group: 'org.knowm.xchart', name: 'xchart', version: '3.2.2' | ||||||
|     testCompile group: 'junit', name: 'junit', version: '4.12' |     testCompile group: 'junit', name: 'junit', version: '4.12' | ||||||
|     compileOnly 'org.projectlombok:lombok:1.18.10' |     compileOnly 'org.projectlombok:lombok:1.18.10' | ||||||
|     annotationProcessor 'org.projectlombok:lombok:1.18.10' |     annotationProcessor 'org.projectlombok:lombok:1.18.10' | ||||||
|  |  | ||||||
|  | @ -3,7 +3,7 @@ package core; | ||||||
| import java.util.List; | import java.util.List; | ||||||
| 
 | 
 | ||||||
| public interface DiscreteActionSpace<A extends Enum> { | public interface DiscreteActionSpace<A extends Enum> { | ||||||
|     int getNumberOfAction(); |     int getNumberOfActions(); | ||||||
|     void addAction(A a); |     void addAction(A a); | ||||||
|     void addActions(A... as); |     void addActions(A... as); | ||||||
|     List<A> getAllActions(); |     List<A> getAllActions(); | ||||||
|  |  | ||||||
|  | @ -0,0 +1,7 @@ | ||||||
|  | package core; | ||||||
|  | 
 | ||||||
|  | public class LearningConfig { | ||||||
|  |     public static final int DEFAULT_DELAY = 1; | ||||||
|  |     public static final float DEFAULT_EPSILON = 0.1f; | ||||||
|  |     public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f; | ||||||
|  | } | ||||||
|  | @ -32,7 +32,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public int getNumberOfAction(){ |     public int getNumberOfActions(){ | ||||||
|         return actions.size(); |         return actions.size(); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -2,26 +2,62 @@ package core.algo; | ||||||
| 
 | 
 | ||||||
| import core.DiscreteActionSpace; | import core.DiscreteActionSpace; | ||||||
| import core.Environment; | import core.Environment; | ||||||
|  | import core.LearningConfig; | ||||||
| import core.StateActionTable; | import core.StateActionTable; | ||||||
|  | import core.listener.LearningListener; | ||||||
| import core.policy.Policy; | import core.policy.Policy; | ||||||
|  | import lombok.Getter; | ||||||
|  | import lombok.Setter; | ||||||
| 
 | 
 | ||||||
|  | import javax.swing.*; | ||||||
|  | import java.util.HashSet; | ||||||
|  | import java.util.Set; | ||||||
|  | 
 | ||||||
|  | @Getter | ||||||
| public abstract class Learning<A extends Enum> { | public abstract class Learning<A extends Enum> { | ||||||
|     protected Policy<A> policy; |     protected Policy<A> policy; | ||||||
|     protected DiscreteActionSpace<A> actionSpace; |     protected DiscreteActionSpace<A> actionSpace; | ||||||
|     protected StateActionTable<A> stateActionTable; |     protected StateActionTable<A> stateActionTable; | ||||||
|     protected Environment<A> environment; |     protected Environment<A> environment; | ||||||
|     protected float discountFactor; |     protected float discountFactor; | ||||||
|  |     @Setter | ||||||
|     protected float epsilon; |     protected float epsilon; | ||||||
|  |     protected Set<LearningListener> learningListeners; | ||||||
|  |     @Setter | ||||||
|  |     protected int delay; | ||||||
| 
 | 
 | ||||||
|     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){ |     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){ | ||||||
|         this.environment = environment; |         this.environment = environment; | ||||||
|         this.actionSpace = actionSpace; |         this.actionSpace = actionSpace; | ||||||
|         this.discountFactor = discountFactor; |         this.discountFactor = discountFactor; | ||||||
|         this.epsilon = epsilon; |         this.epsilon = epsilon; | ||||||
|     } |         this.delay = delay; | ||||||
|     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){ |         learningListeners = new HashSet<>(); | ||||||
|         this(environment, actionSpace, 1.0f, 0.1f); |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public abstract void learn(int nrOfEpisodes, int delay); |     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){ | ||||||
|  |         this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){ | ||||||
|  |         this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public abstract void learn(int nrOfEpisodes); | ||||||
|  | 
 | ||||||
|  |     public void addListener(LearningListener learningListener){ | ||||||
|  |         learningListeners.add(learningListener); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     protected void dispatchEpisodeEnd(double sum){ | ||||||
|  |         for(LearningListener l: learningListeners) { | ||||||
|  |             l.onEpisodeEnd(sum); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     protected void dispatchEpisodeStart(){ | ||||||
|  |         for(LearningListener l: learningListeners){ | ||||||
|  |             l.onEpisodeStart(); | ||||||
|  |         } | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -34,22 +34,21 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> { | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     @Override |     @Override | ||||||
|     public void learn(int nrOfEpisodes, int delay) { |     public void learn(int nrOfEpisodes) { | ||||||
| 
 | 
 | ||||||
|         Map<Pair<State, A>, Double> returnSum = new HashMap<>(); |         Map<Pair<State, A>, Double> returnSum = new HashMap<>(); | ||||||
|         Map<Pair<State, A>, Integer> returnCount = new HashMap<>(); |         Map<Pair<State, A>, Integer> returnCount = new HashMap<>(); | ||||||
| 
 | 
 | ||||||
|         State startingState = environment.reset(); |  | ||||||
|         for(int i = 0; i < nrOfEpisodes; ++i) { |         for(int i = 0; i < nrOfEpisodes; ++i) { | ||||||
|             List<StepResult<A>> episode = new ArrayList<>(); |             List<StepResult<A>> episode = new ArrayList<>(); | ||||||
|             State state = environment.reset(); |             State state = environment.reset(); | ||||||
|             double rewardSum = 0; |             double sumOfRewards = 0; | ||||||
|             for(int j=0; j < 10; ++j){ |             for(int j=0; j < 10; ++j){ | ||||||
|                 Map<A, Double> actionValues = stateActionTable.getActionValues(state); |                 Map<A, Double> actionValues = stateActionTable.getActionValues(state); | ||||||
|                 A chosenAction = policy.chooseAction(actionValues); |                 A chosenAction = policy.chooseAction(actionValues); | ||||||
|                 StepResultEnvironment envResult = environment.step(chosenAction); |                 StepResultEnvironment envResult = environment.step(chosenAction); | ||||||
|                 State nextState = envResult.getState(); |                 State nextState = envResult.getState(); | ||||||
|                 rewardSum +=  envResult.getReward(); |                 sumOfRewards +=  envResult.getReward(); | ||||||
|                 episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); |                 episode.add(new StepResult<>(state, chosenAction, envResult.getReward())); | ||||||
| 
 | 
 | ||||||
|                 if(envResult.isDone()) break; |                 if(envResult.isDone()) break; | ||||||
|  | @ -57,13 +56,14 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> { | ||||||
|                 state = nextState; |                 state = nextState; | ||||||
| 
 | 
 | ||||||
|                 try { |                 try { | ||||||
|                     Thread.sleep(1); |                     Thread.sleep(delay); | ||||||
|                 } catch (InterruptedException e) { |                 } catch (InterruptedException e) { | ||||||
|                     e.printStackTrace(); |                     e.printStackTrace(); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
| 
 | 
 | ||||||
|             System.out.printf("Episode %d \t Reward: %f \n", i, rewardSum); |             dispatchEpisodeEnd(sumOfRewards); | ||||||
|  |             System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards); | ||||||
|             Set<Pair<State, A>> stateActionPairs = new HashSet<>(); |             Set<Pair<State, A>> stateActionPairs = new HashSet<>(); | ||||||
| 
 | 
 | ||||||
|             for(StepResult<A> sr: episode){ |             for(StepResult<A> sr: episode){ | ||||||
|  |  | ||||||
|  | @ -0,0 +1,5 @@ | ||||||
|  | package core.algo; | ||||||
|  | 
 | ||||||
|  | public enum Method { | ||||||
|  |     MC_ONPOLICY_EGREEDY, TD_ONPOLICY | ||||||
|  | } | ||||||
|  | @ -0,0 +1,81 @@ | ||||||
|  | package core.controller; | ||||||
|  | 
 | ||||||
|  | import core.DiscreteActionSpace; | ||||||
|  | import core.Environment; | ||||||
|  | import core.ListDiscreteActionSpace; | ||||||
|  | import core.algo.Learning; | ||||||
|  | import core.algo.Method; | ||||||
|  | import core.algo.mc.MonteCarloOnPolicyEGreedy; | ||||||
|  | import core.gui.View; | ||||||
|  | 
 | ||||||
|  | import javax.swing.*; | ||||||
|  | import java.util.Optional; | ||||||
|  | 
 | ||||||
|  | public class RLController<A extends Enum> implements ViewListener{ | ||||||
|  |     protected Environment<A> environment; | ||||||
|  |     protected Learning<A> learning; | ||||||
|  |     protected DiscreteActionSpace<A> discreteActionSpace; | ||||||
|  |     protected View<A> view; | ||||||
|  |     private int delay; | ||||||
|  |     private int nrOfEpisodes; | ||||||
|  |     private Method method; | ||||||
|  | 
 | ||||||
|  |     public RLController(){ | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public void start(){ | ||||||
|  |         if(environment == null || discreteActionSpace == null || method == null){ | ||||||
|  |             throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()"); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         switch (method){ | ||||||
|  |             case MC_ONPOLICY_EGREEDY: | ||||||
|  |                 learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace); | ||||||
|  |                 break; | ||||||
|  |             case TD_ONPOLICY: | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 throw new RuntimeException("Undefined method"); | ||||||
|  |         } | ||||||
|  |         SwingUtilities.invokeLater(() ->{ | ||||||
|  |             view = new View<>(learning, this); | ||||||
|  |             learning.addListener(view); | ||||||
|  |         }); | ||||||
|  |         learning.learn(nrOfEpisodes); | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     @Override | ||||||
|  |     public void onEpsilonChange(float epsilon) { | ||||||
|  |         learning.setEpsilon(epsilon); | ||||||
|  |         SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public void onDelayChange(int delay) { | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public RLController<A> setMethod(Method method){ | ||||||
|  |         this.method = method; | ||||||
|  |         return this; | ||||||
|  |     } | ||||||
|  |     public RLController<A> setEnvironment(Environment<A> environment){ | ||||||
|  |         this.environment = environment; | ||||||
|  |         return this; | ||||||
|  |     } | ||||||
|  |     @SafeVarargs | ||||||
|  |     public final RLController<A> setAllowedActions(A... actions){ | ||||||
|  |         this.discreteActionSpace = new ListDiscreteActionSpace<>(actions); | ||||||
|  |         return this; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public RLController<A> setDelay(int delay){ | ||||||
|  |         this.delay = delay; | ||||||
|  |         return this; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public RLController<A> setEpisodes(int nrOfEpisodes){ | ||||||
|  |         this.nrOfEpisodes = nrOfEpisodes; | ||||||
|  |         return this; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | } | ||||||
|  | @ -0,0 +1,6 @@ | ||||||
|  | package core.controller; | ||||||
|  | 
 | ||||||
|  | public interface ViewListener { | ||||||
|  |     void onEpsilonChange(float epsilon); | ||||||
|  |     void onDelayChange(int delay); | ||||||
|  | } | ||||||
|  | @ -0,0 +1,41 @@ | ||||||
|  | package core.gui; | ||||||
|  | 
 | ||||||
|  | import core.algo.Learning; | ||||||
|  | import core.controller.ViewListener; | ||||||
|  | 
 | ||||||
|  | import javax.swing.*; | ||||||
|  | 
 | ||||||
|  | public class LearningInfoPanel extends JPanel { | ||||||
|  |     private Learning learning; | ||||||
|  |     private JLabel policyLabel; | ||||||
|  |     private JLabel discountLabel; | ||||||
|  |     private JLabel epsilonLabel; | ||||||
|  |     private JSlider epsilonSlider; | ||||||
|  |     private JSlider delaySlider; | ||||||
|  | 
 | ||||||
|  |     public LearningInfoPanel(Learning learning, ViewListener viewListener){ | ||||||
|  |         this.learning = learning; | ||||||
|  |         setLayout(new BoxLayout(this, BoxLayout.Y_AXIS)); | ||||||
|  |         policyLabel = new JLabel(); | ||||||
|  |         discountLabel = new JLabel(); | ||||||
|  |         epsilonLabel = new JLabel(); | ||||||
|  |         epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100)); | ||||||
|  |         epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f)); | ||||||
|  |         add(policyLabel); | ||||||
|  |         add(discountLabel); | ||||||
|  |         add(epsilonLabel); | ||||||
|  |         add(epsilonSlider); | ||||||
|  |         refreshLabels(); | ||||||
|  |         setVisible(true); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public void refreshLabels(){ | ||||||
|  |         policyLabel.setText("Policy: " + learning.getPolicy().getClass()); | ||||||
|  |         discountLabel.setText("Discount factor: " + learning.getDiscountFactor()); | ||||||
|  |         epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     protected JSlider getEpsilonSlider(){ | ||||||
|  |         return epsilonSlider; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,102 @@ | ||||||
|  | package core.gui; | ||||||
|  | 
 | ||||||
|  | import core.algo.Learning; | ||||||
|  | import core.controller.ViewListener; | ||||||
|  | import core.listener.LearningListener; | ||||||
|  | import lombok.Getter; | ||||||
|  | import org.knowm.xchart.QuickChart; | ||||||
|  | import org.knowm.xchart.XChartPanel; | ||||||
|  | import org.knowm.xchart.XYChart; | ||||||
|  | 
 | ||||||
|  | import javax.swing.*; | ||||||
|  | import java.awt.*; | ||||||
|  | import java.util.ArrayList; | ||||||
|  | import java.util.List; | ||||||
|  | 
 | ||||||
|  | public class View<A extends Enum> implements LearningListener { | ||||||
|  |     private Learning<A> learning; | ||||||
|  |     @Getter | ||||||
|  |     private XYChart chart; | ||||||
|  |     @Getter | ||||||
|  |     private LearningInfoPanel learningInfoPanel; | ||||||
|  |     @Getter | ||||||
|  |     private JFrame mainFrame; | ||||||
|  |     private XChartPanel<XYChart> rewardChartPanel; | ||||||
|  |     private ViewListener viewListener; | ||||||
|  |     private List<Double> rewardHistory; | ||||||
|  | 
 | ||||||
|  |     public View(Learning<A> learning, ViewListener viewListener){ | ||||||
|  |         this.learning = learning; | ||||||
|  |         this.viewListener = viewListener; | ||||||
|  |         rewardHistory = new ArrayList<>(); | ||||||
|  |         this.initMainFrame(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private void initMainFrame(){ | ||||||
|  |         mainFrame = new JFrame(); | ||||||
|  |         mainFrame.setPreferredSize(new Dimension(1280, 720)); | ||||||
|  |         mainFrame.setLayout(new BorderLayout()); | ||||||
|  | 
 | ||||||
|  |         initLearningInfoPanel(); | ||||||
|  |         initRewardChart(); | ||||||
|  | 
 | ||||||
|  |         mainFrame.add(BorderLayout.WEST, learningInfoPanel); | ||||||
|  |         mainFrame.add(BorderLayout.CENTER, rewardChartPanel); | ||||||
|  | 
 | ||||||
|  |         mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE); | ||||||
|  |         mainFrame.pack(); | ||||||
|  |         mainFrame.setVisible(true); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private void initLearningInfoPanel(){ | ||||||
|  |         learningInfoPanel = new LearningInfoPanel(learning, viewListener); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     private void initRewardChart(){ | ||||||
|  |         chart = | ||||||
|  |                 QuickChart.getChart( | ||||||
|  |                         "Rewards per Episode", | ||||||
|  |                         "Episode", | ||||||
|  |                         "Reward", | ||||||
|  |                         "randomWalk", | ||||||
|  |                         new double[] {0}, | ||||||
|  |                         new double[] {0}); | ||||||
|  |         chart.getStyler().setLegendVisible(true); | ||||||
|  |         chart.getStyler().setXAxisTicksVisible(true); | ||||||
|  |         rewardChartPanel = new XChartPanel<>(chart); | ||||||
|  |         rewardChartPanel.setPreferredSize(new Dimension(300,300)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public void showState(Visualizable state){ | ||||||
|  |         new JFrame(){ | ||||||
|  |             { | ||||||
|  |                 JComponent stateComponent = state.visualize(); | ||||||
|  |                 setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight())); | ||||||
|  |                 add(stateComponent); | ||||||
|  |                 setVisible(true); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public void updateRewardGraph(double recentReward){ | ||||||
|  |         rewardHistory.add(recentReward); | ||||||
|  |         chart.updateXYSeries("randomWalk", null, rewardHistory, null); | ||||||
|  |         rewardChartPanel.revalidate(); | ||||||
|  |         rewardChartPanel.repaint(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public void updateLearningInfoPanel(){ | ||||||
|  |         this.learningInfoPanel.refreshLabels(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public void onEpisodeEnd(double sumOfRewards) { | ||||||
|  |         SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards)); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public void onEpisodeStart() { | ||||||
|  | 
 | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,7 @@ | ||||||
|  | package core.gui; | ||||||
|  | 
 | ||||||
|  | import javax.swing.*; | ||||||
|  | 
 | ||||||
|  | public interface Visualizable { | ||||||
|  |     JComponent visualize(); | ||||||
|  | } | ||||||
|  | @ -0,0 +1,6 @@ | ||||||
|  | package core.listener; | ||||||
|  | 
 | ||||||
|  | public interface LearningListener{ | ||||||
|  |     void onEpisodeEnd(double sumOfRewards); | ||||||
|  |     void onEpisodeStart(); | ||||||
|  | } | ||||||
|  | @ -1,7 +1,10 @@ | ||||||
| package evironment.antGame; | package evironment.antGame; | ||||||
| 
 | 
 | ||||||
| import core.State; | import core.State; | ||||||
|  | import core.gui.Visualizable; | ||||||
|  | import evironment.antGame.gui.CellColor; | ||||||
| 
 | 
 | ||||||
|  | import javax.swing.*; | ||||||
| import java.awt.*; | import java.awt.*; | ||||||
| import java.util.Arrays; | import java.util.Arrays; | ||||||
| 
 | 
 | ||||||
|  | @ -10,7 +13,7 @@ import java.util.Arrays; | ||||||
|  * Essentially a snapshot of the current Ant Agent |  * Essentially a snapshot of the current Ant Agent | ||||||
|  * and therefor has to be deep copied |  * and therefor has to be deep copied | ||||||
|  */ |  */ | ||||||
| public class AntState implements State { | public class AntState implements State, Visualizable { | ||||||
|     private final Cell[][] knownWorld; |     private final Cell[][] knownWorld; | ||||||
|     private final Point pos; |     private final Point pos; | ||||||
|     private final boolean hasFood; |     private final boolean hasFood; | ||||||
|  | @ -29,9 +32,9 @@ public class AntState implements State { | ||||||
| 
 | 
 | ||||||
|         int unknown = 0; |         int unknown = 0; | ||||||
|         int diff = 0; |         int diff = 0; | ||||||
|         for (int i = 0; i < knownWorld.length; i++) { |         for (Cell[] cells : knownWorld) { | ||||||
|             for (int j = 0; j < knownWorld[i].length; j++) { |             for (Cell cell : cells) { | ||||||
|                 if(knownWorld[i][j].getType() == CellType.UNKNOWN){ |                 if (cell.getType() == CellType.UNKNOWN) { | ||||||
|                     unknown += 1; |                     unknown += 1; | ||||||
|                 } else { |                 } else { | ||||||
|                     diff += 1; |                     diff += 1; | ||||||
|  | @ -89,4 +92,62 @@ public class AntState implements State { | ||||||
|         } |         } | ||||||
|         return  super.equals(obj); |         return  super.equals(obj); | ||||||
|     } |     } | ||||||
|  | 
 | ||||||
|  |     @Override | ||||||
|  |     public JComponent visualize() { | ||||||
|  |         return new JScrollPane() { | ||||||
|  |             private int cellSize; | ||||||
|  |             private final int paneWidth = 500; | ||||||
|  |             private final int paneHeight = 500; | ||||||
|  |             private Font font; | ||||||
|  |             { | ||||||
|  |                 setPreferredSize(new Dimension(paneWidth, paneHeight)); | ||||||
|  |                 cellSize = (paneWidth- knownWorld.length) /knownWorld.length; | ||||||
|  |                 font = new Font("plain", Font.BOLD, cellSize); | ||||||
|  |                 JPanel worldPanel = new JPanel(){ | ||||||
|  |                     { | ||||||
|  |                         setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize)); | ||||||
|  |                         setVisible(true); | ||||||
|  | 
 | ||||||
|  |                         addMouseWheelListener(e -> { | ||||||
|  |                             if(e.getWheelRotation() > 0){ | ||||||
|  |                                 cellSize -= 1; | ||||||
|  |                             }else { | ||||||
|  |                                 cellSize += 1; | ||||||
|  |                             } | ||||||
|  |                             font = new Font("plain", Font.BOLD, cellSize); | ||||||
|  |                             setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize)); | ||||||
|  |                             revalidate(); | ||||||
|  |                             repaint(); | ||||||
|  |                         }); | ||||||
|  |                     } | ||||||
|  | 
 | ||||||
|  |                     @Override | ||||||
|  |                     public void paintComponent(Graphics g) { | ||||||
|  |                         super.paintComponent(g); | ||||||
|  |                         for (int i = 0; i < knownWorld.length; i++) { | ||||||
|  |                             for (int j = 0; j < knownWorld[0].length; j++) { | ||||||
|  |                                 g.setColor(Color.BLACK); | ||||||
|  |                                 g.drawRect(i*cellSize, j*cellSize, cellSize, cellSize); | ||||||
|  |                                 g.setColor(CellColor.map.get(knownWorld[i][j].getType())); | ||||||
|  |                                 if(knownWorld[i][j].getFood() > 0){ | ||||||
|  |                                     g.setColor(Color.YELLOW); | ||||||
|  |                                 } | ||||||
|  |                                 g.fillRect(i*cellSize+1, j*cellSize+1, cellSize -1, cellSize-1); | ||||||
|  |                             } | ||||||
|  |                         } | ||||||
|  |                         if(hasFocus()){ | ||||||
|  |                             g.setColor(Color.RED); | ||||||
|  |                         }else { | ||||||
|  |                             g.setColor(Color.BLACK); | ||||||
|  |                         } | ||||||
|  |                         g.setFont(font); | ||||||
|  |                         g.drawString("A", pos.x * cellSize, (pos.y + 1) * cellSize); | ||||||
|  |                     } | ||||||
|  |                 }; | ||||||
|  |                 getViewport().add(worldPanel); | ||||||
|  |                 setVisible(true); | ||||||
|  |             } | ||||||
|  |         }; | ||||||
|  |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -35,21 +35,15 @@ public class AntWorld implements Environment<AntAction>{ | ||||||
| 
 | 
 | ||||||
|     private int tick; |     private int tick; | ||||||
|     private int maxEpisodeTicks; |     private int maxEpisodeTicks; | ||||||
|     MainFrame gui; |  | ||||||
| 
 | 
 | ||||||
|     public AntWorld(int width, int height, double foodDensity){ |     public AntWorld(int width, int height, double foodDensity){ | ||||||
|         grid = new Grid(width, height, foodDensity); |         grid = new Grid(width, height, foodDensity); | ||||||
|         antAgent = new AntAgent(width, height); |         antAgent = new AntAgent(width, height); | ||||||
|         myAnt = new Ant(); |         myAnt = new Ant(); | ||||||
|         gui = new MainFrame(this, antAgent); |  | ||||||
|         maxEpisodeTicks = 1000; |         maxEpisodeTicks = 1000; | ||||||
|         reset(); |         reset(); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     public MainFrame getGui(){ |  | ||||||
|         return gui; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public AntWorld(){ |     public AntWorld(){ | ||||||
|         this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY); |         this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY); | ||||||
|     } |     } | ||||||
|  | @ -166,7 +160,6 @@ public class AntWorld implements Environment<AntAction>{ | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|         StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info); |         StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info); | ||||||
|         getGui().update(action, result); |  | ||||||
|         return result; |         return result; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  | @ -216,6 +209,6 @@ public class AntWorld implements Environment<AntAction>{ | ||||||
|                 new AntWorld(3, 3, 0.1), |                 new AntWorld(3, 3, 0.1), | ||||||
|                 new ListDiscreteActionSpace<>(AntAction.values()) |                 new ListDiscreteActionSpace<>(AntAction.values()) | ||||||
|         ); |         ); | ||||||
|         monteCarlo.learn(20000,5); |         monteCarlo.learn(20000); | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -0,0 +1,21 @@ | ||||||
|  | package example; | ||||||
|  | 
 | ||||||
|  | import core.RNG; | ||||||
|  | import core.algo.Method; | ||||||
|  | import core.controller.RLController; | ||||||
|  | import evironment.antGame.AntAction; | ||||||
|  | import evironment.antGame.AntWorld; | ||||||
|  | 
 | ||||||
|  | public class RunningAnt { | ||||||
|  |     public static void main(String[] args) { | ||||||
|  |         RNG.setSeed(1234); | ||||||
|  | 
 | ||||||
|  |         RLController<AntAction> rl = new RLController<AntAction>() | ||||||
|  |                             .setEnvironment(new AntWorld(3,3,0.1)) | ||||||
|  |                             .setAllowedActions(AntAction.values()) | ||||||
|  |                             .setMethod(Method.MC_ONPOLICY_EGREEDY) | ||||||
|  |                             .setDelay(10) | ||||||
|  |                             .setEpisodes(1000); | ||||||
|  |         rl.start(); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | @ -0,0 +1,52 @@ | ||||||
|  | package example; | ||||||
|  | 
 | ||||||
|  | public class Test { | ||||||
|  |     interface Drawable{ | ||||||
|  |         void draw(); | ||||||
|  |     } | ||||||
|  |     interface State{ | ||||||
|  |         int getInt(); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static class A implements  Drawable, State{ | ||||||
|  |         private int k; | ||||||
|  |         public A(int a){ | ||||||
|  |             k = a; | ||||||
|  |         } | ||||||
|  |         @Override | ||||||
|  |         public void draw() { | ||||||
|  |             System.out.println("draw " + k); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         @Override | ||||||
|  |         public int getInt() { | ||||||
|  |             System.out.println("getInt" + k); | ||||||
|  |             return k; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static class B  implements State{ | ||||||
|  |         @Override | ||||||
|  |         public int getInt() { | ||||||
|  |             return 0; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     public static void main(String[] args) { | ||||||
|  |         State state = new A(24); | ||||||
|  |         State state2 = new B(); | ||||||
|  |         state.getInt(); | ||||||
|  | 
 | ||||||
|  |         System.out.println(state2 instanceof Drawable); | ||||||
|  |         drawState(state2); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     static void drawState(State s){ | ||||||
|  |         if(s instanceof Drawable){ | ||||||
|  |             Drawable d = (Drawable) s; | ||||||
|  |             d.draw(); | ||||||
|  |         }else{ | ||||||
|  |             System.out.println("invalid"); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
		Loading…
	
		Reference in New Issue