diff --git a/.idea/misc.xml b/.idea/misc.xml
index a19b0c6..cbb7577 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -11,7 +11,7 @@
-
+
\ No newline at end of file
diff --git a/build.gradle b/build.gradle
index 5af83ec..66bdd43 100644
--- a/build.gradle
+++ b/build.gradle
@@ -7,8 +7,8 @@ plugins {
 group 'net.lwenstrom.jan'
 version '1.0-SNAPSHOT'
 
-sourceCompatibility = 11
-targetCompatibility = 11
+sourceCompatibility = 8
+targetCompatibility = 8
 
 repositories {
     mavenCentral()
@@ -29,8 +29,7 @@ dependencies {
 }
 
 // Include dependent libraries in archive.
-mainClassName = "example.DinoSampling"
-
+mainClassName = "example.GUIMain"
 
 jar {
     manifest {
diff --git a/src/main/java/core/controller/OpeningDialog.java b/src/main/java/core/controller/OpeningDialog.java
new file mode 100644
index 0000000..4281157
--- /dev/null
+++ b/src/main/java/core/controller/OpeningDialog.java
@@ -0,0 +1,157 @@
+package core.controller;
+
+import core.algo.Method;
+import evironment.antGame.AntAction;
+import evironment.antGame.AntWorldContinuous;
+import evironment.antGame.Constants;
+import evironment.blackjack.BlackJackTable;
+import evironment.blackjack.PlayerAction;
+import evironment.jumpingDino.DinoAction;
+import evironment.jumpingDino.DinoWorld;
+import evironment.jumpingDino.DinoWorldAdvanced;
+
+import javax.swing.*;
+import java.text.NumberFormat;
+import java.text.ParseException;
+
+/**
+ * Modal start-up dialog that collects the learning parameters and starts an
+ * {@link RLControllerGUI} for the selected scenario.
+ */
+public class OpeningDialog {
+
+    // JSlider is integer value only. Instead of creating a subclass
+    // the JSlider value is simply divided by this scale factor.
+    // 100 maps the slider range 0..100 onto 0 to 1 in 0.01 steps.
+    private static final int SCALE_FACTOR = 100;
+
+    /** Opens the parameter dialog as soon as the object is created. */
+    public OpeningDialog() {
+        createStartingDialog();
+    }
+
+    /** Converts the integer slider position into the fraction it stands for. */
+    private static float scaled(JSlider slider) {
+        return (float) slider.getValue() / SCALE_FACTOR;
+    }
+
+    /**
+     * Creates an integer-only input field holding {@code initialValue}.
+     * The value is set through {@code setValue} so that it can later be read
+     * back as a number instead of re-parsing the locale-formatted text.
+     */
+    private static JFormattedTextField createIntegerField(int initialValue) {
+        JFormattedTextField field = new JFormattedTextField(NumberFormat.getIntegerInstance());
+        field.setValue(initialValue);
+        return field;
+    }
+
+    /**
+     * Reads the integer currently held by {@code field}. Pending input is
+     * committed first; if it is not a valid number the last valid value is used.
+     */
+    private static int readInt(JFormattedTextField field) {
+        try {
+            field.commitEdit();
+        } catch (ParseException ignored) {
+            // Invalid text: keep the last committed value.
+        }
+        return ((Number) field.getValue()).intValue();
+    }
+
+    /** Writes "name: value" into the label, scaling the slider value if requested. */
+    private void setLabelText(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
+        if (scaledValue) {
+            label.setText(parameterName + ": " + scaled(slider));
+        } else {
+            label.setText(parameterName + ": " + slider.getValue());
+        }
+    }
+
+    /** Initialises the label and keeps it in sync with every slider change. */
+    private void linkSliderWithLabel(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
+        setLabelText(slider, label, scaledValue, parameterName);
+        slider.addChangeListener(changeEvent -> setLabelText(slider, label, scaledValue, parameterName));
+    }
+
+    /**
+     * Builds and shows the modal parameter dialog. On OK the controller for
+     * the selected scenario is created, configured and started; otherwise a
+     * notice is printed and nothing is started.
+     */
+    private void createStartingDialog() {
+        JComboBox<Scenario> scenarioSelection = new JComboBox<>(Scenario.values());
+        JComboBox<Method> algorithmSelection = new JComboBox<>(Method.values());
+
+        JSlider delaySlider = new JSlider(1, 1000, 200);
+        JLabel delayLabel = new JLabel();
+        linkSliderWithLabel(delaySlider, delayLabel, false, "Delay");
+
+        JSlider discountSlider = new JSlider(0, 100, 100);
+        JLabel discountLabel = new JLabel();
+        linkSliderWithLabel(discountSlider, discountLabel, true, "Discount Factor (gamma)");
+
+        JSlider epsilonSlider = new JSlider(0, 100, 15);
+        JLabel epsilonLabel = new JLabel();
+        linkSliderWithLabel(epsilonSlider, epsilonLabel, true, "Exploration Factor (epsilon)");
+
+        JSlider learningRateSlider = new JSlider(0, 100, 90);
+        JLabel learningRateLabel = new JLabel();
+        linkSliderWithLabel(learningRateSlider, learningRateLabel, true, "Learning rate (alpha)");
+
+        JFormattedTextField episodesToLearn = createIntegerField(10000);
+
+        // NOTE(review): the seed is shown in the dialog but not applied anywhere
+        // in this class yet - confirm how the controller expects to receive it.
+        JFormattedTextField seedTextField = createIntegerField(29);
+
+        Object[] parameters = {
+                "Environment:", scenarioSelection,
+                "Algorithm:", algorithmSelection,
+                discountLabel, discountSlider,
+                epsilonLabel, epsilonSlider,
+                learningRateLabel, learningRateSlider,
+                delayLabel, delaySlider,
+                "Episodes to learn:", episodesToLearn,
+                "RNG Seed: ", seedTextField,
+        };
+
+        int option = JOptionPane.showConfirmDialog(null, parameters, "Learning parameters", JOptionPane.OK_CANCEL_OPTION);
+        if (option != JOptionPane.OK_OPTION) {
+            System.out.println("Parameter dialog canceled");
+            return;
+        }
+
+        Scenario selectedScenario = (Scenario) scenarioSelection.getSelectedItem();
+        Method selectedMethod = (Method) algorithmSelection.getSelectedItem();
+
+        RLControllerGUI rl;
+        if (selectedScenario == Scenario.JUMPING_DINO_SIMPLE) {
+            rl = new RLControllerGUI(new DinoWorld(), selectedMethod, DinoAction.values());
+        } else if (selectedScenario == Scenario.JUMPING_DINO_ADVANCED) {
+            rl = new RLControllerGUI(new DinoWorldAdvanced(), selectedMethod, DinoAction.values());
+        } else if (selectedScenario == Scenario.ANTGAME) {
+            rl = new RLControllerGUI(new AntWorldContinuous(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT), selectedMethod, AntAction.values());
+        } else if (selectedScenario == Scenario.BLACKJACK) {
+            rl = new RLControllerGUI(new BlackJackTable(), selectedMethod, PlayerAction.values());
+        } else {
+            throw new IllegalArgumentException("Invalid learning scenario selected");
+        }
+
+        rl.setDelay(delaySlider.getValue());
+        rl.setDiscountFactor(scaled(discountSlider));
+        rl.setEpsilon(scaled(epsilonSlider));
+        rl.setLearningRate(scaled(learningRateSlider));
+        rl.setNrOfEpisodes(readInt(episodesToLearn));
+        rl.start();
+    }
+
+    /** Learning environments offered in the dialog. */
+    private enum Scenario {
+        JUMPING_DINO_SIMPLE,
+        JUMPING_DINO_ADVANCED,
+        ANTGAME,
+        BLACKJACK
+    }
+
+}
diff --git a/src/main/java/core/controller/RLControllerGUI.java b/src/main/java/core/controller/RLControllerGUI.java
index adc8a14..068edcb 100644
--- a/src/main/java/core/controller/RLControllerGUI.java
+++ b/src/main/java/core/controller/RLControllerGUI.java
@@ -18,6 +18,7 @@ public class RLControllerGUI extends RLController implements
         super(env, method, actions);
     }
 
+
     @Override
     protected void initListeners() {
         SwingUtilities.invokeLater(() -> {
diff --git a/src/main/java/core/gui/LearningInfoPanel.java b/src/main/java/core/gui/LearningInfoPanel.java
index dfd8dbc..36f09ce 100644
--- a/src/main/java/core/gui/LearningInfoPanel.java
+++ b/src/main/java/core/gui/LearningInfoPanel.java
@@ -38,7 +38,7 @@ public class LearningInfoPanel extends JPanel {
             episodeLabel = new JLabel();
             add(episodeLabel);
         }
-        delaySlider = new JSlider(0, 1000, learning.getDelay());
+        delaySlider = new JSlider(1, 1000, learning.getDelay());
         delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
         add(policyLabel);
         add(discountLabel);
diff --git a/src/main/java/evironment/antGame/Constants.java b/src/main/java/evironment/antGame/Constants.java
index cdc8040..f001248 100644
--- a/src/main/java/evironment/antGame/Constants.java
+++ b/src/main/java/evironment/antGame/Constants.java
@@ -1,8 +1,8 @@
 package evironment.antGame;
 
 public class Constants {
-    public static final int DEFAULT_GRID_WIDTH = 5;
-    public static final int DEFAULT_GRID_HEIGHT = 5;
+    public static final int DEFAULT_GRID_WIDTH = 8;
+    public static final int DEFAULT_GRID_HEIGHT = 8;
     public static final int START_X = 5;
     public static final int START_Y = 2;
 }
diff --git a/src/main/java/evironment/antGame/Grid.java b/src/main/java/evironment/antGame/Grid.java
index c6ef661..1ad6879 100644
--- a/src/main/java/evironment/antGame/Grid.java
+++ b/src/main/java/evironment/antGame/Grid.java
@@ -30,10 +30,9 @@ public class Grid {
                 initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
             }
         }
-        spawnNewFood(initialGrid);
         spawnObstacles();
         initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
-
+        spawnNewFood(initialGrid);
     }
 
     //TODO
diff --git a/src/main/java/example/GUIMain.java b/src/main/java/example/GUIMain.java
new file mode 100644
index 0000000..e35b71e
--- /dev/null
+++ b/src/main/java/example/GUIMain.java
@@ -0,0 +1,9 @@
+package example;
+
+import core.controller.OpeningDialog;
+
+public class GUIMain {
+    public static void main(String[] args) {
+        new OpeningDialog();
+    }
+}