add isJumping info to dinoState

This commit is contained in:
Jan Löwenstrom 2020-02-26 17:14:28 +01:00
parent 77898f4e5a
commit cff1a4e531
6 changed files with 15 additions and 14 deletions

View File

@ -42,7 +42,6 @@ public class QTableFrame<A extends Enum> extends JFrame {
}
}
protected void refreshQTable() {
System.out.println("ref");
int stateCount = stateActionTable.getStateCount();
stateCountLabel.setText("Total states: " + stateCount);
int idx = -1;

View File

@ -37,5 +37,6 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
// Take the action with the highest value
return greedyPolicy.chooseAction(actionValues);
}
}
}

View File

@ -14,12 +14,15 @@ import java.util.Objects;
@Getter
public class DinoState implements State, Serializable, Visualizable {
private int xDistanceToObstacle;
private boolean isJumping;
protected final double scale = 0.5;
@Override
public String toString() {
return "DinoState{" +
"xDistanceToObstacle=" + xDistanceToObstacle +
"isJumping=" + isJumping +
'}';
}
@ -28,12 +31,12 @@ public class DinoState implements State, Serializable, Visualizable {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
DinoState dinoState = (DinoState) o;
return xDistanceToObstacle == dinoState.xDistanceToObstacle;
return xDistanceToObstacle == dinoState.xDistanceToObstacle && isJumping == dinoState.isJumping;
}
@Override
public int hashCode() {
return Objects.hash(xDistanceToObstacle);
return Objects.hash(xDistanceToObstacle, isJumping);
}
@Override

View File

@ -10,8 +10,8 @@ import java.util.Objects;
public class DinoStateWithSpeed extends DinoState implements Visualizable {
private int obstacleSpeed;
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
super(xDistanceToObstacle);
public DinoStateWithSpeed(int xDistanceToObstacle, boolean isJumping, int obstacleSpeed) {
super(xDistanceToObstacle, isJumping);
this.obstacleSpeed = obstacleSpeed;
}

View File

@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
@Override
public StepResultEnvironment step(DinoAction action) {
boolean done = false;
int reward = 1;
int reward = 0;
if(action == DinoAction.JUMP){
dino.jump();
@ -74,11 +74,11 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
spawnNewObstacle();
}
if(ranIntoObstacle()) {
reward = 0;
reward = -1;
done = true;
}
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, "");
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx()), reward, done, "");
}
@ -110,7 +110,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
public State reset() {
spawnDino();
spawnNewObstacle();
return new DinoState(getDistanceToObstacle());
return new DinoState(getDistanceToObstacle(), dino.isInJump());
}
@Override

View File

@ -16,13 +16,11 @@ public class JumpingDino {
Method.Q_LEARNING_OFF_POLICY_CONTROL,
DinoAction.values());
rl.setDelay(10);
rl.setDiscountFactor(0.8f);
rl.setDelay(1000);
rl.setDiscountFactor(0.9f);
rl.setEpsilon(0.1f);
rl.setLearningRate(0.5f);
rl.setNrOfEpisodes(10000);
rl.setNrOfEpisodes(4000000);
rl.start();
}
}