add opening dialog to select all learning settings
This commit is contained in:
parent
9d1f8dfd46
commit
7d3d097599
|
@ -11,7 +11,7 @@
|
||||||
</list>
|
</list>
|
||||||
</option>
|
</option>
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="11.0.3" project-jdk-type="JavaSDK">
|
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="1.8" project-jdk-type="JavaSDK">
|
||||||
<output url="file://$PROJECT_DIR$/out" />
|
<output url="file://$PROJECT_DIR$/out" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
|
@ -7,8 +7,8 @@ plugins {
|
||||||
group 'net.lwenstrom.jan'
|
group 'net.lwenstrom.jan'
|
||||||
version '1.0-SNAPSHOT'
|
version '1.0-SNAPSHOT'
|
||||||
|
|
||||||
sourceCompatibility = 11
|
sourceCompatibility = 8
|
||||||
targetCompatibility = 11
|
targetCompatibility = 8
|
||||||
|
|
||||||
repositories {
|
repositories {
|
||||||
mavenCentral()
|
mavenCentral()
|
||||||
|
@ -29,8 +29,7 @@ dependencies {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Include dependent libraries in archive.
|
// Include dependent libraries in archive.
|
||||||
mainClassName = "example.DinoSampling"
|
mainClassName = "example.GUIMain"
|
||||||
|
|
||||||
|
|
||||||
jar {
|
jar {
|
||||||
manifest {
|
manifest {
|
||||||
|
|
|
@ -0,0 +1,114 @@
|
||||||
|
package core.controller;
|
||||||
|
|
||||||
|
import core.algo.Method;
|
||||||
|
import evironment.antGame.AntAction;
|
||||||
|
import evironment.antGame.AntWorldContinuous;
|
||||||
|
import evironment.antGame.Constants;
|
||||||
|
import evironment.blackjack.BlackJackTable;
|
||||||
|
import evironment.blackjack.PlayerAction;
|
||||||
|
import evironment.jumpingDino.DinoAction;
|
||||||
|
import evironment.jumpingDino.DinoWorld;
|
||||||
|
import evironment.jumpingDino.DinoWorldAdvanced;
|
||||||
|
|
||||||
|
import javax.swing.*;
|
||||||
|
import java.text.NumberFormat;
|
||||||
|
|
||||||
|
public class OpeningDialog {
|
||||||
|
|
||||||
|
// JSlider is integer value only. Instead of creating a subclass
|
||||||
|
// the JSlider value is simply divided by this scale factor.
|
||||||
|
// 100 does mean 0 to 1 in 0.01 steps.
|
||||||
|
private int scaleFactor = 100;
|
||||||
|
|
||||||
|
public OpeningDialog() {
|
||||||
|
createStartingDialog();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void setLabelText(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
|
||||||
|
if(scaledValue) {
|
||||||
|
label.setText(parameterName + ": " + (float) slider.getValue() / scaleFactor);
|
||||||
|
} else {
|
||||||
|
label.setText(parameterName + ": " + slider.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private void linkSliderWithLabel(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
|
||||||
|
setLabelText(slider, label, scaledValue, parameterName);
|
||||||
|
|
||||||
|
slider.addChangeListener(changeEvent -> setLabelText(slider, label, scaledValue, parameterName));
|
||||||
|
}
|
||||||
|
|
||||||
|
private void createStartingDialog() {
|
||||||
|
JComboBox<Scenario> scenarioSelection = new JComboBox<>(Scenario.values());
|
||||||
|
JComboBox<Method> algorithmSelection = new JComboBox<>(Method.values());
|
||||||
|
|
||||||
|
JSlider delaySlider = new JSlider(1, 1000, 200);
|
||||||
|
JLabel delayLabel = new JLabel();
|
||||||
|
linkSliderWithLabel(delaySlider, delayLabel, false, "Delay");
|
||||||
|
|
||||||
|
JSlider discountSlider = new JSlider(0, 100, 100);
|
||||||
|
JLabel discountLabel = new JLabel();
|
||||||
|
linkSliderWithLabel(discountSlider, discountLabel, true, "Discount Factor (gamma)");
|
||||||
|
|
||||||
|
JSlider epsilonSlider = new JSlider(0, 100, 15);
|
||||||
|
JLabel epsilonLabel = new JLabel();
|
||||||
|
linkSliderWithLabel(epsilonSlider, epsilonLabel, true, "Exploration Factor (epsilon)");
|
||||||
|
|
||||||
|
JSlider learningRateSlider = new JSlider(0, 100, 90);
|
||||||
|
JLabel learningRateLabel = new JLabel();
|
||||||
|
linkSliderWithLabel(learningRateSlider, learningRateLabel, true, "Learning rate (alpha)");
|
||||||
|
|
||||||
|
JTextField episodesToLearn = new JFormattedTextField(NumberFormat.getIntegerInstance());
|
||||||
|
episodesToLearn.setText("10000");
|
||||||
|
|
||||||
|
JTextField seedTextField = new JFormattedTextField(NumberFormat.getIntegerInstance());
|
||||||
|
seedTextField.setText("29");
|
||||||
|
|
||||||
|
|
||||||
|
Object[] parameters = {
|
||||||
|
"Environment:", scenarioSelection,
|
||||||
|
"Algorithm:", algorithmSelection,
|
||||||
|
discountLabel, discountSlider,
|
||||||
|
epsilonLabel, epsilonSlider,
|
||||||
|
learningRateLabel, learningRateSlider,
|
||||||
|
delayLabel, delaySlider,
|
||||||
|
"Episodes to learn:", episodesToLearn,
|
||||||
|
"RNG Seed: ", seedTextField,
|
||||||
|
};
|
||||||
|
|
||||||
|
int option = JOptionPane.showConfirmDialog(null, parameters, "Learning parameters", JOptionPane.OK_CANCEL_OPTION);
|
||||||
|
if(option == JOptionPane.OK_OPTION) {
|
||||||
|
Scenario selectedScenario = (Scenario) scenarioSelection.getSelectedItem();
|
||||||
|
|
||||||
|
RLControllerGUI rl;
|
||||||
|
if(selectedScenario == Scenario.JUMPING_DINO_SIMPLE) {
|
||||||
|
rl = new RLControllerGUI<DinoAction>(new DinoWorld(), (Method) algorithmSelection.getSelectedItem(), DinoAction.values());
|
||||||
|
} else if(selectedScenario == Scenario.JUMPING_DINO_ADVANCED) {
|
||||||
|
rl = new RLControllerGUI<DinoAction>(new DinoWorldAdvanced(), (Method) algorithmSelection.getSelectedItem(), DinoAction.values());
|
||||||
|
} else if(selectedScenario == Scenario.ANTGAME) {
|
||||||
|
rl = new RLControllerGUI<AntAction>(new AntWorldContinuous(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT), (Method) algorithmSelection.getSelectedItem(), AntAction.values());
|
||||||
|
} else if(selectedScenario == Scenario.BLACKJACK) {
|
||||||
|
rl = new RLControllerGUI<PlayerAction>(new BlackJackTable(), (Method) algorithmSelection.getSelectedItem(), PlayerAction.values());
|
||||||
|
} else {
|
||||||
|
throw new IllegalArgumentException("Invalid learning scenario selected");
|
||||||
|
}
|
||||||
|
rl.setDelay(delaySlider.getValue());
|
||||||
|
rl.setDiscountFactor((float) discountSlider.getValue() / scaleFactor);
|
||||||
|
rl.setEpsilon((float) epsilonSlider.getValue() / scaleFactor);
|
||||||
|
rl.setLearningRate((float) learningRateSlider.getValue() / scaleFactor);
|
||||||
|
rl.setNrOfEpisodes(Integer.parseInt(episodesToLearn.getText()));
|
||||||
|
rl.start();
|
||||||
|
} else {
|
||||||
|
System.out.println("Parameter dialog canceled");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private enum Scenario {
|
||||||
|
JUMPING_DINO_SIMPLE,
|
||||||
|
JUMPING_DINO_ADVANCED,
|
||||||
|
ANTGAME,
|
||||||
|
BLACKJACK
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -18,6 +18,7 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> implements
|
||||||
super(env, method, actions);
|
super(env, method, actions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void initListeners() {
|
protected void initListeners() {
|
||||||
SwingUtilities.invokeLater(() -> {
|
SwingUtilities.invokeLater(() -> {
|
||||||
|
|
|
@ -38,7 +38,7 @@ public class LearningInfoPanel extends JPanel {
|
||||||
episodeLabel = new JLabel();
|
episodeLabel = new JLabel();
|
||||||
add(episodeLabel);
|
add(episodeLabel);
|
||||||
}
|
}
|
||||||
delaySlider = new JSlider(0, 1000, learning.getDelay());
|
delaySlider = new JSlider(1, 1000, learning.getDelay());
|
||||||
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
|
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
|
||||||
add(policyLabel);
|
add(policyLabel);
|
||||||
add(discountLabel);
|
add(discountLabel);
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
package evironment.antGame;
|
package evironment.antGame;
|
||||||
|
|
||||||
public class Constants {
|
public class Constants {
|
||||||
public static final int DEFAULT_GRID_WIDTH = 5;
|
public static final int DEFAULT_GRID_WIDTH = 8;
|
||||||
public static final int DEFAULT_GRID_HEIGHT = 5;
|
public static final int DEFAULT_GRID_HEIGHT = 8;
|
||||||
public static final int START_X = 5;
|
public static final int START_X = 5;
|
||||||
public static final int START_Y = 2;
|
public static final int START_Y = 2;
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,10 +30,9 @@ public class Grid {
|
||||||
initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
|
initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
spawnNewFood(initialGrid);
|
|
||||||
spawnObstacles();
|
spawnObstacles();
|
||||||
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
|
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
|
||||||
|
spawnNewFood(initialGrid);
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO
|
//TODO
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
package example;
|
||||||
|
|
||||||
|
import core.controller.OpeningDialog;
|
||||||
|
|
||||||
|
public class GUIMain {
|
||||||
|
public static void main(String[] args) {
|
||||||
|
new OpeningDialog();
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue