adopt MVC pattern and add real time graph interface
This commit is contained in:
parent
7f18a66e98
commit
e0160ca1df
|
@ -5,4 +5,4 @@
|
||||||
<profile default="true" name="Default" enabled="true" />
|
<profile default="true" name="Default" enabled="true" />
|
||||||
</annotationProcessing>
|
</annotationProcessing>
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
||||||
|
|
|
@ -13,6 +13,9 @@ repositories {
|
||||||
}
|
}
|
||||||
|
|
||||||
dependencies {
|
dependencies {
|
||||||
|
// https://mvnrepository.com/artifact/org.jfree/jfreechart
|
||||||
|
// https://mvnrepository.com/artifact/org.knowm.xchart/xchart
|
||||||
|
compile group: 'org.knowm.xchart', name: 'xchart', version: '3.2.2'
|
||||||
testCompile group: 'junit', name: 'junit', version: '4.12'
|
testCompile group: 'junit', name: 'junit', version: '4.12'
|
||||||
compileOnly 'org.projectlombok:lombok:1.18.10'
|
compileOnly 'org.projectlombok:lombok:1.18.10'
|
||||||
annotationProcessor 'org.projectlombok:lombok:1.18.10'
|
annotationProcessor 'org.projectlombok:lombok:1.18.10'
|
||||||
|
|
|
@ -3,7 +3,7 @@ package core;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public interface DiscreteActionSpace<A extends Enum> {
|
public interface DiscreteActionSpace<A extends Enum> {
|
||||||
int getNumberOfAction();
|
int getNumberOfActions();
|
||||||
void addAction(A a);
|
void addAction(A a);
|
||||||
void addActions(A... as);
|
void addActions(A... as);
|
||||||
List<A> getAllActions();
|
List<A> getAllActions();
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
package core;
|
||||||
|
|
||||||
|
public class LearningConfig {
|
||||||
|
public static final int DEFAULT_DELAY = 1;
|
||||||
|
public static final float DEFAULT_EPSILON = 0.1f;
|
||||||
|
public static final float DEFAULT_DISCOUNT_FACTOR = 1.0f;
|
||||||
|
}
|
|
@ -32,7 +32,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int getNumberOfAction(){
|
public int getNumberOfActions(){
|
||||||
return actions.size();
|
return actions.size();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,26 +2,62 @@ package core.algo;
|
||||||
|
|
||||||
import core.DiscreteActionSpace;
|
import core.DiscreteActionSpace;
|
||||||
import core.Environment;
|
import core.Environment;
|
||||||
|
import core.LearningConfig;
|
||||||
import core.StateActionTable;
|
import core.StateActionTable;
|
||||||
|
import core.listener.LearningListener;
|
||||||
import core.policy.Policy;
|
import core.policy.Policy;
|
||||||
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
@Getter
|
||||||
public abstract class Learning<A extends Enum> {
|
public abstract class Learning<A extends Enum> {
|
||||||
protected Policy<A> policy;
|
protected Policy<A> policy;
|
||||||
protected DiscreteActionSpace<A> actionSpace;
|
protected DiscreteActionSpace<A> actionSpace;
|
||||||
protected StateActionTable<A> stateActionTable;
|
protected StateActionTable<A> stateActionTable;
|
||||||
protected Environment<A> environment;
|
protected Environment<A> environment;
|
||||||
protected float discountFactor;
|
protected float discountFactor;
|
||||||
|
@Setter
|
||||||
protected float epsilon;
|
protected float epsilon;
|
||||||
|
protected Set<LearningListener> learningListeners;
|
||||||
|
@Setter
|
||||||
|
protected int delay;
|
||||||
|
|
||||||
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
|
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay){
|
||||||
this.environment = environment;
|
this.environment = environment;
|
||||||
this.actionSpace = actionSpace;
|
this.actionSpace = actionSpace;
|
||||||
this.discountFactor = discountFactor;
|
this.discountFactor = discountFactor;
|
||||||
this.epsilon = epsilon;
|
this.epsilon = epsilon;
|
||||||
}
|
this.delay = delay;
|
||||||
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
|
learningListeners = new HashSet<>();
|
||||||
this(environment, actionSpace, 1.0f, 0.1f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract void learn(int nrOfEpisodes, int delay);
|
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon){
|
||||||
|
this(environment, actionSpace, discountFactor, epsilon, LearningConfig.DEFAULT_DELAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Learning(Environment<A> environment, DiscreteActionSpace<A> actionSpace){
|
||||||
|
this(environment, actionSpace, LearningConfig.DEFAULT_DISCOUNT_FACTOR, LearningConfig.DEFAULT_EPSILON, LearningConfig.DEFAULT_DELAY);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract void learn(int nrOfEpisodes);
|
||||||
|
|
||||||
|
public void addListener(LearningListener learningListener){
|
||||||
|
learningListeners.add(learningListener);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void dispatchEpisodeEnd(double sum){
|
||||||
|
for(LearningListener l: learningListeners) {
|
||||||
|
l.onEpisodeEnd(sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected void dispatchEpisodeStart(){
|
||||||
|
for(LearningListener l: learningListeners){
|
||||||
|
l.onEpisodeStart();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,22 +34,21 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void learn(int nrOfEpisodes, int delay) {
|
public void learn(int nrOfEpisodes) {
|
||||||
|
|
||||||
Map<Pair<State, A>, Double> returnSum = new HashMap<>();
|
Map<Pair<State, A>, Double> returnSum = new HashMap<>();
|
||||||
Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
|
Map<Pair<State, A>, Integer> returnCount = new HashMap<>();
|
||||||
|
|
||||||
State startingState = environment.reset();
|
|
||||||
for(int i = 0; i < nrOfEpisodes; ++i) {
|
for(int i = 0; i < nrOfEpisodes; ++i) {
|
||||||
List<StepResult<A>> episode = new ArrayList<>();
|
List<StepResult<A>> episode = new ArrayList<>();
|
||||||
State state = environment.reset();
|
State state = environment.reset();
|
||||||
double rewardSum = 0;
|
double sumOfRewards = 0;
|
||||||
for(int j=0; j < 10; ++j){
|
for(int j=0; j < 10; ++j){
|
||||||
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
|
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
|
||||||
A chosenAction = policy.chooseAction(actionValues);
|
A chosenAction = policy.chooseAction(actionValues);
|
||||||
StepResultEnvironment envResult = environment.step(chosenAction);
|
StepResultEnvironment envResult = environment.step(chosenAction);
|
||||||
State nextState = envResult.getState();
|
State nextState = envResult.getState();
|
||||||
rewardSum += envResult.getReward();
|
sumOfRewards += envResult.getReward();
|
||||||
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
|
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
|
||||||
|
|
||||||
if(envResult.isDone()) break;
|
if(envResult.isDone()) break;
|
||||||
|
@ -57,13 +56,14 @@ public class MonteCarloOnPolicyEGreedy<A extends Enum> extends Learning<A> {
|
||||||
state = nextState;
|
state = nextState;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Thread.sleep(1);
|
Thread.sleep(delay);
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
e.printStackTrace();
|
e.printStackTrace();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
System.out.printf("Episode %d \t Reward: %f \n", i, rewardSum);
|
dispatchEpisodeEnd(sumOfRewards);
|
||||||
|
System.out.printf("Episode %d \t Reward: %f \n", i, sumOfRewards);
|
||||||
Set<Pair<State, A>> stateActionPairs = new HashSet<>();
|
Set<Pair<State, A>> stateActionPairs = new HashSet<>();
|
||||||
|
|
||||||
for(StepResult<A> sr: episode){
|
for(StepResult<A> sr: episode){
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
package core.algo;
|
||||||
|
|
||||||
|
public enum Method {
|
||||||
|
MC_ONPOLICY_EGREEDY, TD_ONPOLICY
|
||||||
|
}
|
|
@ -0,0 +1,81 @@
|
||||||
|
package core.controller;
|
||||||
|
|
||||||
|
import core.DiscreteActionSpace;
|
||||||
|
import core.Environment;
|
||||||
|
import core.ListDiscreteActionSpace;
|
||||||
|
import core.algo.Learning;
|
||||||
|
import core.algo.Method;
|
||||||
|
import core.algo.mc.MonteCarloOnPolicyEGreedy;
|
||||||
|
import core.gui.View;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
public class RLController<A extends Enum> implements ViewListener{
|
||||||
|
protected Environment<A> environment;
|
||||||
|
protected Learning<A> learning;
|
||||||
|
protected DiscreteActionSpace<A> discreteActionSpace;
|
||||||
|
protected View<A> view;
|
||||||
|
private int delay;
|
||||||
|
private int nrOfEpisodes;
|
||||||
|
private Method method;
|
||||||
|
|
||||||
|
public RLController(){
|
||||||
|
}
|
||||||
|
|
||||||
|
public void start(){
|
||||||
|
if(environment == null || discreteActionSpace == null || method == null){
|
||||||
|
throw new RuntimeException("Set environment, discreteActionSpace and method before calling .start()");
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (method){
|
||||||
|
case MC_ONPOLICY_EGREEDY:
|
||||||
|
learning = new MonteCarloOnPolicyEGreedy<>(environment, discreteActionSpace);
|
||||||
|
break;
|
||||||
|
case TD_ONPOLICY:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Undefined method");
|
||||||
|
}
|
||||||
|
SwingUtilities.invokeLater(() ->{
|
||||||
|
view = new View<>(learning, this);
|
||||||
|
learning.addListener(view);
|
||||||
|
});
|
||||||
|
learning.learn(nrOfEpisodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEpsilonChange(float epsilon) {
|
||||||
|
learning.setEpsilon(epsilon);
|
||||||
|
SwingUtilities.invokeLater(() -> view.updateLearningInfoPanel());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onDelayChange(int delay) {
|
||||||
|
}
|
||||||
|
|
||||||
|
public RLController<A> setMethod(Method method){
|
||||||
|
this.method = method;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
public RLController<A> setEnvironment(Environment<A> environment){
|
||||||
|
this.environment = environment;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
@SafeVarargs
|
||||||
|
public final RLController<A> setAllowedActions(A... actions){
|
||||||
|
this.discreteActionSpace = new ListDiscreteActionSpace<>(actions);
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RLController<A> setDelay(int delay){
|
||||||
|
this.delay = delay;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RLController<A> setEpisodes(int nrOfEpisodes){
|
||||||
|
this.nrOfEpisodes = nrOfEpisodes;
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
package core.controller;
|
||||||
|
|
||||||
|
public interface ViewListener {
|
||||||
|
void onEpsilonChange(float epsilon);
|
||||||
|
void onDelayChange(int delay);
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
package core.gui;
|
||||||
|
|
||||||
|
import core.algo.Learning;
|
||||||
|
import core.controller.ViewListener;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
|
||||||
|
public class LearningInfoPanel extends JPanel {
|
||||||
|
private Learning learning;
|
||||||
|
private JLabel policyLabel;
|
||||||
|
private JLabel discountLabel;
|
||||||
|
private JLabel epsilonLabel;
|
||||||
|
private JSlider epsilonSlider;
|
||||||
|
private JSlider delaySlider;
|
||||||
|
|
||||||
|
public LearningInfoPanel(Learning learning, ViewListener viewListener){
|
||||||
|
this.learning = learning;
|
||||||
|
setLayout(new BoxLayout(this, BoxLayout.Y_AXIS));
|
||||||
|
policyLabel = new JLabel();
|
||||||
|
discountLabel = new JLabel();
|
||||||
|
epsilonLabel = new JLabel();
|
||||||
|
epsilonSlider = new JSlider(0, 100, (int)(learning.getEpsilon() * 100));
|
||||||
|
epsilonSlider.addChangeListener(e -> viewListener.onEpsilonChange(epsilonSlider.getValue() / 100f));
|
||||||
|
add(policyLabel);
|
||||||
|
add(discountLabel);
|
||||||
|
add(epsilonLabel);
|
||||||
|
add(epsilonSlider);
|
||||||
|
refreshLabels();
|
||||||
|
setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void refreshLabels(){
|
||||||
|
policyLabel.setText("Policy: " + learning.getPolicy().getClass());
|
||||||
|
discountLabel.setText("Discount factor: " + learning.getDiscountFactor());
|
||||||
|
epsilonLabel.setText("Exploration (Epsilon): " + learning.getEpsilon());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected JSlider getEpsilonSlider(){
|
||||||
|
return epsilonSlider;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,102 @@
|
||||||
|
package core.gui;
|
||||||
|
|
||||||
|
import core.algo.Learning;
|
||||||
|
import core.controller.ViewListener;
|
||||||
|
import core.listener.LearningListener;
|
||||||
|
import lombok.Getter;
|
||||||
|
import org.knowm.xchart.QuickChart;
|
||||||
|
import org.knowm.xchart.XChartPanel;
|
||||||
|
import org.knowm.xchart.XYChart;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.awt.*;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class View<A extends Enum> implements LearningListener {
|
||||||
|
private Learning<A> learning;
|
||||||
|
@Getter
|
||||||
|
private XYChart chart;
|
||||||
|
@Getter
|
||||||
|
private LearningInfoPanel learningInfoPanel;
|
||||||
|
@Getter
|
||||||
|
private JFrame mainFrame;
|
||||||
|
private XChartPanel<XYChart> rewardChartPanel;
|
||||||
|
private ViewListener viewListener;
|
||||||
|
private List<Double> rewardHistory;
|
||||||
|
|
||||||
|
public View(Learning<A> learning, ViewListener viewListener){
|
||||||
|
this.learning = learning;
|
||||||
|
this.viewListener = viewListener;
|
||||||
|
rewardHistory = new ArrayList<>();
|
||||||
|
this.initMainFrame();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initMainFrame(){
|
||||||
|
mainFrame = new JFrame();
|
||||||
|
mainFrame.setPreferredSize(new Dimension(1280, 720));
|
||||||
|
mainFrame.setLayout(new BorderLayout());
|
||||||
|
|
||||||
|
initLearningInfoPanel();
|
||||||
|
initRewardChart();
|
||||||
|
|
||||||
|
mainFrame.add(BorderLayout.WEST, learningInfoPanel);
|
||||||
|
mainFrame.add(BorderLayout.CENTER, rewardChartPanel);
|
||||||
|
|
||||||
|
mainFrame.setDefaultCloseOperation(WindowConstants.EXIT_ON_CLOSE);
|
||||||
|
mainFrame.pack();
|
||||||
|
mainFrame.setVisible(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initLearningInfoPanel(){
|
||||||
|
learningInfoPanel = new LearningInfoPanel(learning, viewListener);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void initRewardChart(){
|
||||||
|
chart =
|
||||||
|
QuickChart.getChart(
|
||||||
|
"Rewards per Episode",
|
||||||
|
"Episode",
|
||||||
|
"Reward",
|
||||||
|
"randomWalk",
|
||||||
|
new double[] {0},
|
||||||
|
new double[] {0});
|
||||||
|
chart.getStyler().setLegendVisible(true);
|
||||||
|
chart.getStyler().setXAxisTicksVisible(true);
|
||||||
|
rewardChartPanel = new XChartPanel<>(chart);
|
||||||
|
rewardChartPanel.setPreferredSize(new Dimension(300,300));
|
||||||
|
}
|
||||||
|
|
||||||
|
public void showState(Visualizable state){
|
||||||
|
new JFrame(){
|
||||||
|
{
|
||||||
|
JComponent stateComponent = state.visualize();
|
||||||
|
setPreferredSize(new Dimension(stateComponent.getWidth(), stateComponent.getHeight()));
|
||||||
|
add(stateComponent);
|
||||||
|
setVisible(true);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateRewardGraph(double recentReward){
|
||||||
|
rewardHistory.add(recentReward);
|
||||||
|
chart.updateXYSeries("randomWalk", null, rewardHistory, null);
|
||||||
|
rewardChartPanel.revalidate();
|
||||||
|
rewardChartPanel.repaint();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void updateLearningInfoPanel(){
|
||||||
|
this.learningInfoPanel.refreshLabels();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEpisodeEnd(double sumOfRewards) {
|
||||||
|
SwingUtilities.invokeLater(()->updateRewardGraph(sumOfRewards));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void onEpisodeStart() {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,7 @@
|
||||||
|
package core.gui;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
|
||||||
|
public interface Visualizable {
|
||||||
|
JComponent visualize();
|
||||||
|
}
|
|
@ -0,0 +1,6 @@
|
||||||
|
package core.listener;
|
||||||
|
|
||||||
|
public interface LearningListener{
|
||||||
|
void onEpisodeEnd(double sumOfRewards);
|
||||||
|
void onEpisodeStart();
|
||||||
|
}
|
|
@ -1,7 +1,10 @@
|
||||||
package evironment.antGame;
|
package evironment.antGame;
|
||||||
|
|
||||||
import core.State;
|
import core.State;
|
||||||
|
import core.gui.Visualizable;
|
||||||
|
import evironment.antGame.gui.CellColor;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
|
|
||||||
|
@ -10,7 +13,7 @@ import java.util.Arrays;
|
||||||
* Essentially a snapshot of the current Ant Agent
|
* Essentially a snapshot of the current Ant Agent
|
||||||
* and therefor has to be deep copied
|
* and therefor has to be deep copied
|
||||||
*/
|
*/
|
||||||
public class AntState implements State {
|
public class AntState implements State, Visualizable {
|
||||||
private final Cell[][] knownWorld;
|
private final Cell[][] knownWorld;
|
||||||
private final Point pos;
|
private final Point pos;
|
||||||
private final boolean hasFood;
|
private final boolean hasFood;
|
||||||
|
@ -29,12 +32,12 @@ public class AntState implements State {
|
||||||
|
|
||||||
int unknown = 0;
|
int unknown = 0;
|
||||||
int diff = 0;
|
int diff = 0;
|
||||||
for (int i = 0; i < knownWorld.length; i++) {
|
for (Cell[] cells : knownWorld) {
|
||||||
for (int j = 0; j < knownWorld[i].length; j++) {
|
for (Cell cell : cells) {
|
||||||
if(knownWorld[i][j].getType() == CellType.UNKNOWN){
|
if (cell.getType() == CellType.UNKNOWN) {
|
||||||
unknown += 1;
|
unknown += 1;
|
||||||
}else{
|
} else {
|
||||||
diff +=1;
|
diff += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -62,7 +65,7 @@ public class AntState implements State {
|
||||||
@Override
|
@Override
|
||||||
public String toString(){
|
public String toString(){
|
||||||
return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.toString(knownWorld));
|
return String.format("Pos: %s, hasFood: %b, knownWorld: %s", pos.toString(), hasFood, Arrays.toString(knownWorld));
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers
|
//TODO: make this a utility function to generate hash Code based upon 2 prime numbers
|
||||||
@Override
|
@Override
|
||||||
|
@ -89,4 +92,62 @@ public class AntState implements State {
|
||||||
}
|
}
|
||||||
return super.equals(obj);
|
return super.equals(obj);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public JComponent visualize() {
|
||||||
|
return new JScrollPane() {
|
||||||
|
private int cellSize;
|
||||||
|
private final int paneWidth = 500;
|
||||||
|
private final int paneHeight = 500;
|
||||||
|
private Font font;
|
||||||
|
{
|
||||||
|
setPreferredSize(new Dimension(paneWidth, paneHeight));
|
||||||
|
cellSize = (paneWidth- knownWorld.length) /knownWorld.length;
|
||||||
|
font = new Font("plain", Font.BOLD, cellSize);
|
||||||
|
JPanel worldPanel = new JPanel(){
|
||||||
|
{
|
||||||
|
setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
|
||||||
|
setVisible(true);
|
||||||
|
|
||||||
|
addMouseWheelListener(e -> {
|
||||||
|
if(e.getWheelRotation() > 0){
|
||||||
|
cellSize -= 1;
|
||||||
|
}else {
|
||||||
|
cellSize += 1;
|
||||||
|
}
|
||||||
|
font = new Font("plain", Font.BOLD, cellSize);
|
||||||
|
setPreferredSize(new Dimension(knownWorld.length * cellSize, knownWorld[0].length * cellSize));
|
||||||
|
revalidate();
|
||||||
|
repaint();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void paintComponent(Graphics g) {
|
||||||
|
super.paintComponent(g);
|
||||||
|
for (int i = 0; i < knownWorld.length; i++) {
|
||||||
|
for (int j = 0; j < knownWorld[0].length; j++) {
|
||||||
|
g.setColor(Color.BLACK);
|
||||||
|
g.drawRect(i*cellSize, j*cellSize, cellSize, cellSize);
|
||||||
|
g.setColor(CellColor.map.get(knownWorld[i][j].getType()));
|
||||||
|
if(knownWorld[i][j].getFood() > 0){
|
||||||
|
g.setColor(Color.YELLOW);
|
||||||
|
}
|
||||||
|
g.fillRect(i*cellSize+1, j*cellSize+1, cellSize -1, cellSize-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(hasFocus()){
|
||||||
|
g.setColor(Color.RED);
|
||||||
|
}else {
|
||||||
|
g.setColor(Color.BLACK);
|
||||||
|
}
|
||||||
|
g.setFont(font);
|
||||||
|
g.drawString("A", pos.x * cellSize, (pos.y + 1) * cellSize);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
getViewport().add(worldPanel);
|
||||||
|
setVisible(true);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -35,21 +35,15 @@ public class AntWorld implements Environment<AntAction>{
|
||||||
|
|
||||||
private int tick;
|
private int tick;
|
||||||
private int maxEpisodeTicks;
|
private int maxEpisodeTicks;
|
||||||
MainFrame gui;
|
|
||||||
|
|
||||||
public AntWorld(int width, int height, double foodDensity){
|
public AntWorld(int width, int height, double foodDensity){
|
||||||
grid = new Grid(width, height, foodDensity);
|
grid = new Grid(width, height, foodDensity);
|
||||||
antAgent = new AntAgent(width, height);
|
antAgent = new AntAgent(width, height);
|
||||||
myAnt = new Ant();
|
myAnt = new Ant();
|
||||||
gui = new MainFrame(this, antAgent);
|
|
||||||
maxEpisodeTicks = 1000;
|
maxEpisodeTicks = 1000;
|
||||||
reset();
|
reset();
|
||||||
}
|
}
|
||||||
|
|
||||||
public MainFrame getGui(){
|
|
||||||
return gui;
|
|
||||||
}
|
|
||||||
|
|
||||||
public AntWorld(){
|
public AntWorld(){
|
||||||
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
|
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
|
||||||
}
|
}
|
||||||
|
@ -166,7 +160,6 @@ public class AntWorld implements Environment<AntAction>{
|
||||||
|
|
||||||
|
|
||||||
StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
|
StepResultEnvironment result = new StepResultEnvironment(newState, reward, done, info);
|
||||||
getGui().update(action, result);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -216,6 +209,6 @@ public class AntWorld implements Environment<AntAction>{
|
||||||
new AntWorld(3, 3, 0.1),
|
new AntWorld(3, 3, 0.1),
|
||||||
new ListDiscreteActionSpace<>(AntAction.values())
|
new ListDiscreteActionSpace<>(AntAction.values())
|
||||||
);
|
);
|
||||||
monteCarlo.learn(20000,5);
|
monteCarlo.learn(20000);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,21 @@
|
||||||
|
package example;
|
||||||
|
|
||||||
|
import core.RNG;
|
||||||
|
import core.algo.Method;
|
||||||
|
import core.controller.RLController;
|
||||||
|
import evironment.antGame.AntAction;
|
||||||
|
import evironment.antGame.AntWorld;
|
||||||
|
|
||||||
|
public class RunningAnt {
|
||||||
|
public static void main(String[] args) {
|
||||||
|
RNG.setSeed(1234);
|
||||||
|
|
||||||
|
RLController<AntAction> rl = new RLController<AntAction>()
|
||||||
|
.setEnvironment(new AntWorld(3,3,0.1))
|
||||||
|
.setAllowedActions(AntAction.values())
|
||||||
|
.setMethod(Method.MC_ONPOLICY_EGREEDY)
|
||||||
|
.setDelay(10)
|
||||||
|
.setEpisodes(1000);
|
||||||
|
rl.start();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
package example;
|
||||||
|
|
||||||
|
public class Test {
|
||||||
|
interface Drawable{
|
||||||
|
void draw();
|
||||||
|
}
|
||||||
|
interface State{
|
||||||
|
int getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
static class A implements Drawable, State{
|
||||||
|
private int k;
|
||||||
|
public A(int a){
|
||||||
|
k = a;
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public void draw() {
|
||||||
|
System.out.println("draw " + k);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getInt() {
|
||||||
|
System.out.println("getInt" + k);
|
||||||
|
return k;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static class B implements State{
|
||||||
|
@Override
|
||||||
|
public int getInt() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public static void main(String[] args) {
|
||||||
|
State state = new A(24);
|
||||||
|
State state2 = new B();
|
||||||
|
state.getInt();
|
||||||
|
|
||||||
|
System.out.println(state2 instanceof Drawable);
|
||||||
|
drawState(state2);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void drawState(State s){
|
||||||
|
if(s instanceof Drawable){
|
||||||
|
Drawable d = (Drawable) s;
|
||||||
|
d.draw();
|
||||||
|
}else{
|
||||||
|
System.out.println("invalid");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue