add isJumping info to dinoState
This commit is contained in:
parent
77898f4e5a
commit
cff1a4e531
|
@ -42,7 +42,6 @@ public class QTableFrame<A extends Enum> extends JFrame {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
protected void refreshQTable() {
|
protected void refreshQTable() {
|
||||||
System.out.println("ref");
|
|
||||||
int stateCount = stateActionTable.getStateCount();
|
int stateCount = stateActionTable.getStateCount();
|
||||||
stateCountLabel.setText("Total states: " + stateCount);
|
stateCountLabel.setText("Total states: " + stateCount);
|
||||||
int idx = -1;
|
int idx = -1;
|
||||||
|
|
|
@ -37,5 +37,6 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
|
||||||
// Take the action with the highest value
|
// Take the action with the highest value
|
||||||
return greedyPolicy.chooseAction(actionValues);
|
return greedyPolicy.chooseAction(actionValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,12 +14,15 @@ import java.util.Objects;
|
||||||
@Getter
|
@Getter
|
||||||
public class DinoState implements State, Serializable, Visualizable {
|
public class DinoState implements State, Serializable, Visualizable {
|
||||||
private int xDistanceToObstacle;
|
private int xDistanceToObstacle;
|
||||||
|
private boolean isJumping;
|
||||||
|
|
||||||
protected final double scale = 0.5;
|
protected final double scale = 0.5;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "DinoState{" +
|
return "DinoState{" +
|
||||||
"xDistanceToObstacle=" + xDistanceToObstacle +
|
"xDistanceToObstacle=" + xDistanceToObstacle +
|
||||||
|
"isJumping=" + isJumping +
|
||||||
'}';
|
'}';
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -28,12 +31,12 @@ public class DinoState implements State, Serializable, Visualizable {
|
||||||
if (this == o) return true;
|
if (this == o) return true;
|
||||||
if (o == null || getClass() != o.getClass()) return false;
|
if (o == null || getClass() != o.getClass()) return false;
|
||||||
DinoState dinoState = (DinoState) o;
|
DinoState dinoState = (DinoState) o;
|
||||||
return xDistanceToObstacle == dinoState.xDistanceToObstacle;
|
return xDistanceToObstacle == dinoState.xDistanceToObstacle && isJumping == dinoState.isJumping;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int hashCode() {
|
public int hashCode() {
|
||||||
return Objects.hash(xDistanceToObstacle);
|
return Objects.hash(xDistanceToObstacle, isJumping);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -10,8 +10,8 @@ import java.util.Objects;
|
||||||
public class DinoStateWithSpeed extends DinoState implements Visualizable {
|
public class DinoStateWithSpeed extends DinoState implements Visualizable {
|
||||||
private int obstacleSpeed;
|
private int obstacleSpeed;
|
||||||
|
|
||||||
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
|
public DinoStateWithSpeed(int xDistanceToObstacle, boolean isJumping, int obstacleSpeed) {
|
||||||
super(xDistanceToObstacle);
|
super(xDistanceToObstacle, isJumping);
|
||||||
this.obstacleSpeed = obstacleSpeed;
|
this.obstacleSpeed = obstacleSpeed;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
||||||
@Override
|
@Override
|
||||||
public StepResultEnvironment step(DinoAction action) {
|
public StepResultEnvironment step(DinoAction action) {
|
||||||
boolean done = false;
|
boolean done = false;
|
||||||
int reward = 1;
|
int reward = 0;
|
||||||
|
|
||||||
if(action == DinoAction.JUMP){
|
if(action == DinoAction.JUMP){
|
||||||
dino.jump();
|
dino.jump();
|
||||||
|
@ -74,11 +74,11 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
||||||
spawnNewObstacle();
|
spawnNewObstacle();
|
||||||
}
|
}
|
||||||
if(ranIntoObstacle()) {
|
if(ranIntoObstacle()) {
|
||||||
reward = 0;
|
reward = -1;
|
||||||
done = true;
|
done = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, "");
|
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx()), reward, done, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
||||||
public State reset() {
|
public State reset() {
|
||||||
spawnDino();
|
spawnDino();
|
||||||
spawnNewObstacle();
|
spawnNewObstacle();
|
||||||
return new DinoState(getDistanceToObstacle());
|
return new DinoState(getDistanceToObstacle(), dino.isInJump());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -16,13 +16,11 @@ public class JumpingDino {
|
||||||
Method.Q_LEARNING_OFF_POLICY_CONTROL,
|
Method.Q_LEARNING_OFF_POLICY_CONTROL,
|
||||||
DinoAction.values());
|
DinoAction.values());
|
||||||
|
|
||||||
rl.setDelay(10);
|
rl.setDelay(1000);
|
||||||
rl.setDiscountFactor(0.8f);
|
rl.setDiscountFactor(0.9f);
|
||||||
rl.setEpsilon(0.1f);
|
rl.setEpsilon(0.1f);
|
||||||
rl.setLearningRate(0.5f);
|
rl.setLearningRate(0.5f);
|
||||||
rl.setNrOfEpisodes(10000);
|
rl.setNrOfEpisodes(4000000);
|
||||||
rl.start();
|
rl.start();
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue