add isJumping info to dinoState
This commit is contained in:
parent
77898f4e5a
commit
cff1a4e531
|
@ -42,7 +42,6 @@ public class QTableFrame<A extends Enum> extends JFrame {
|
|||
}
|
||||
}
|
||||
protected void refreshQTable() {
|
||||
System.out.println("ref");
|
||||
int stateCount = stateActionTable.getStateCount();
|
||||
stateCountLabel.setText("Total states: " + stateCount);
|
||||
int idx = -1;
|
||||
|
|
|
@ -37,5 +37,6 @@ public class EpsilonGreedyPolicy<A extends Enum> implements EpsilonPolicy<A>{
|
|||
// Take the action with the highest value
|
||||
return greedyPolicy.chooseAction(actionValues);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,12 +14,15 @@ import java.util.Objects;
|
|||
@Getter
|
||||
public class DinoState implements State, Serializable, Visualizable {
|
||||
private int xDistanceToObstacle;
|
||||
private boolean isJumping;
|
||||
|
||||
protected final double scale = 0.5;
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "DinoState{" +
|
||||
"xDistanceToObstacle=" + xDistanceToObstacle +
|
||||
"isJumping=" + isJumping +
|
||||
'}';
|
||||
}
|
||||
|
||||
|
@ -28,12 +31,12 @@ public class DinoState implements State, Serializable, Visualizable {
|
|||
if (this == o) return true;
|
||||
if (o == null || getClass() != o.getClass()) return false;
|
||||
DinoState dinoState = (DinoState) o;
|
||||
return xDistanceToObstacle == dinoState.xDistanceToObstacle;
|
||||
return xDistanceToObstacle == dinoState.xDistanceToObstacle && isJumping == dinoState.isJumping;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(xDistanceToObstacle);
|
||||
return Objects.hash(xDistanceToObstacle, isJumping);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -10,8 +10,8 @@ import java.util.Objects;
|
|||
public class DinoStateWithSpeed extends DinoState implements Visualizable {
|
||||
private int obstacleSpeed;
|
||||
|
||||
public DinoStateWithSpeed(int xDistanceToObstacle, int obstacleSpeed) {
|
||||
super(xDistanceToObstacle);
|
||||
public DinoStateWithSpeed(int xDistanceToObstacle, boolean isJumping, int obstacleSpeed) {
|
||||
super(xDistanceToObstacle, isJumping);
|
||||
this.obstacleSpeed = obstacleSpeed;
|
||||
}
|
||||
|
||||
|
|
|
@ -50,7 +50,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
|||
@Override
|
||||
public StepResultEnvironment step(DinoAction action) {
|
||||
boolean done = false;
|
||||
int reward = 1;
|
||||
int reward = 0;
|
||||
|
||||
if(action == DinoAction.JUMP){
|
||||
dino.jump();
|
||||
|
@ -74,11 +74,11 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
|||
spawnNewObstacle();
|
||||
}
|
||||
if(ranIntoObstacle()) {
|
||||
reward = 0;
|
||||
reward = -1;
|
||||
done = true;
|
||||
}
|
||||
|
||||
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), getCurrentObstacle().getDx()), reward, done, "");
|
||||
return new StepResultEnvironment(new DinoStateWithSpeed(getDistanceToObstacle(), dino.isInJump(), getCurrentObstacle().getDx()), reward, done, "");
|
||||
}
|
||||
|
||||
|
||||
|
@ -110,7 +110,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
|||
public State reset() {
|
||||
spawnDino();
|
||||
spawnNewObstacle();
|
||||
return new DinoState(getDistanceToObstacle());
|
||||
return new DinoState(getDistanceToObstacle(), dino.isInJump());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -16,13 +16,11 @@ public class JumpingDino {
|
|||
Method.Q_LEARNING_OFF_POLICY_CONTROL,
|
||||
DinoAction.values());
|
||||
|
||||
rl.setDelay(10);
|
||||
rl.setDiscountFactor(0.8f);
|
||||
rl.setDelay(1000);
|
||||
rl.setDiscountFactor(0.9f);
|
||||
rl.setEpsilon(0.1f);
|
||||
rl.setLearningRate(0.5f);
|
||||
rl.setNrOfEpisodes(10000);
|
||||
rl.setNrOfEpisodes(4000000);
|
||||
rl.start();
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue