split DinoWorld between simple and advanced example

# Conflicts:
#	src/main/java/example/JumpingDino.java
This commit is contained in:
Jan Löwenstrom 2020-03-05 11:58:57 +01:00
parent 9b54b72a25
commit e67f40ad65
4 changed files with 83 additions and 77 deletions

View File

@ -12,24 +12,15 @@ import java.awt.*;
@Getter @Getter
public class DinoWorld implements Environment<DinoAction>, Visualizable { public class DinoWorld implements Environment<DinoAction>, Visualizable {
private Dino dino; protected Dino dino;
private Obstacle currentObstacle; protected Obstacle currentObstacle;
private boolean randomObstacleSpeed;
private boolean randomObstacleDistance;
private DinoWorldComponent comp; private DinoWorldComponent comp;
public DinoWorld(boolean randomObstacleSpeed, boolean randomObstacleDistance){ public DinoWorld(){
this.randomObstacleSpeed = randomObstacleSpeed;
this.randomObstacleDistance = randomObstacleDistance;
dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
spawnNewObstacle(); spawnNewObstacle();
comp = new DinoWorldComponent(this); comp = new DinoWorldComponent(this);
} }
public DinoWorld(){
this(false, false);
}
private boolean ranIntoObstacle(){ private boolean ranIntoObstacle(){
Obstacle o = currentObstacle; Obstacle o = currentObstacle;
Dino p = dino; Dino p = dino;
@ -43,7 +34,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
return xAxis && yAxis; return xAxis && yAxis;
} }
private int getDistanceToObstacle(){ protected int getDistanceToObstacle(){
return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE; return currentObstacle.getX() - dino.getX() + Config.DINO_SIZE;
} }
@ -78,27 +69,20 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
done = true; done = true;
} }
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx()), reward, done, ""); return new StepResultEnvironment(generateReturnState(), reward, done, "");
} }
/**
 * Builds the state object handed back to the caller by both the step result and {@code reset()}.
 * Subclasses may override this to expose a different state representation.
 *
 * @return a {@code DinoState} made of the current distance to the obstacle and the dino's jump flag
 */
protected State generateReturnState(){
    // Evaluate in the same order as the arguments were evaluated originally.
    final int distanceToObstacle = getDistanceToObstacle();
    final boolean jumping = dino.isInJump();
    return new DinoState(distanceToObstacle, jumping);
}
private void spawnNewObstacle(){ protected void spawnNewObstacle(){
int dx; int dx;
int xSpawn; int xSpawn;
if(randomObstacleSpeed){ dx = -Config.OBSTACLE_SPEED;
dx = -(int)((Math.random() + 0.5) * Config.OBSTACLE_SPEED); // instantly respawning on the left screen border
}else{ xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
dx = -Config.OBSTACLE_SPEED;
}
if(randomObstacleDistance){
// randomly spawning more right outside of the screen
xSpawn = (int)(Math.random() + 0.5 * Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE);
}else{
// instantly respawning on the left screen border
xSpawn = Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;
}
currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK); currentObstacle = new Obstacle(Config.OBSTACLE_SIZE, xSpawn, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE, dx, 0, Color.BLACK);
} }
@ -106,11 +90,12 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
private void spawnDino(){ private void spawnDino(){
dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN); dino = new Dino(Config.DINO_SIZE, Config.DINO_STARTING_X, Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE, 0, 0, Color.GREEN);
} }
@Override @Override
public State reset() { public State reset() {
spawnDino(); spawnDino();
spawnNewObstacle(); spawnNewObstacle();
return new DinoState(getDistanceToObstacle(), dino.isInJump()); return generateReturnState();
} }
@Override @Override

View File

@ -0,0 +1,26 @@
package evironment.jumpingDino;
import core.State;
import java.awt.*;
/**
 * Variant of {@link DinoWorld} in which every obstacle is spawned with a randomized speed and a
 * randomized distance to the right of the screen, and whose returned state additionally carries
 * the current obstacle speed.
 */
public class DinoWorldAdvanced extends DinoWorld {

    /** Creates the world by delegating entirely to the {@link DinoWorld} constructor. */
    public DinoWorldAdvanced() {
        super();
    }

    /**
     * Builds the state returned to the caller, extended by the obstacle's horizontal speed.
     *
     * @return a {@code DinoStateWithSpeed} made of the distance to the obstacle, the dino's jump
     *         flag and the current obstacle's {@code dx}
     */
    @Override
    protected State generateReturnState() {
        return new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx());
    }

    /**
     * Replaces {@code currentObstacle} with a new obstacle on ground level that has a random
     * leftward speed and a random spawn position outside the right screen border.
     */
    @Override
    protected void spawnNewObstacle() {
        // NOTE(review): Math.random() cannot be seeded, so these draws are presumably not
        // reproducible via the project's RNG seed - confirm whether that is intended.

        // Speed factor in [0.5, 1.5) of the configured obstacle speed; negative = moving left.
        int dx = -(int) ((Math.random() + 0.5) * Config.OBSTACLE_SPEED);

        // Randomly spawning further right, outside of the screen: an extra [0.5, 1.5) frame widths.
        // The parentheses around (Math.random() + 0.5) are essential. Without them the random
        // term only adds a value below 1 that the int cast truncates away, which makes the
        // spawn distance effectively constant.
        int xSpawn = (int) ((Math.random() + 0.5) * Config.FRAME_WIDTH
                + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE);

        currentObstacle = new Obstacle(
                Config.OBSTACLE_SIZE,
                xSpawn,
                Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE,
                dx,
                0,
                Color.BLACK);
    }
}

View File

@ -6,22 +6,42 @@ import core.controller.RLController;
import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld; import evironment.jumpingDino.DinoWorld;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
public class DinoSampling { public class DinoSampling {
public static void main(String[] args) { public static void main(String[] args) {
for (int i = 0; i < 10 ; i++) { File file = new File("convergence.txt");
RNG.setSeed(55); for(float f = 0.05f; f <=1.003 ; f+=0.05f){
try {
Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
for(int i = 1; i <= 100; i++) {
System.out.println("seed: " + i *13);
RNG.setSeed(i *13);
RLController<DinoAction> rl = new RLController<>( RLController<DinoAction> rl = new RLController<>(
new DinoWorld(false, false), new DinoWorld(),
Method.MC_CONTROL_FIRST_VISIT, Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values()); DinoAction.values());
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(20000);
rl.start();
rl.setDelay(0); }
rl.setDiscountFactor(1f); try {
rl.setEpsilon(0.15f); Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
rl.setLearningRate(1f); } catch (IOException e) {
rl.setNrOfEpisodes(400); e.printStackTrace();
rl.start(); }
} }
} }
} }

View File

@ -7,45 +7,20 @@ import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld; import evironment.jumpingDino.DinoWorld;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
public class JumpingDino { public class JumpingDino {
public static void main(String[] args) { public static void main(String[] args) {
File file = new File("convergence.txt"); RNG.setSeed(55);
for(float f = 0.05f; f <=1.003 ; f+=0.05f){
try {
Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
for(int i = 1; i <= 100; i++) {
System.out.println("seed: " + i *13);
RNG.setSeed(i *13);
RLController<DinoAction> rl = new RLController<>( RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorld(false, false), new DinoWorld(),
Method.MC_CONTROL_FIRST_VISIT, Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values()); DinoAction.values());
rl.setDelay(100);
rl.setDelay(0); rl.setDiscountFactor(1f);
rl.setDiscountFactor(1f); rl.setEpsilon(0.15f);
rl.setEpsilon(f); rl.setLearningRate(1f);
rl.setLearningRate(1f); rl.setNrOfEpisodes(10000);
rl.setNrOfEpisodes(20000); rl.start();
rl.start();
}
try {
Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
}
System.out.println("kek");
} }
} }