add constant for default reward
This commit is contained in:
parent
e7404a8d24
commit
740289ee2b
|
@ -11,7 +11,7 @@
|
||||||
</list>
|
</list>
|
||||||
</option>
|
</option>
|
||||||
</component>
|
</component>
|
||||||
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="11.0.5" project-jdk-type="JavaSDK">
|
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="11.0.3" project-jdk-type="JavaSDK">
|
||||||
<output url="file://$PROJECT_DIR$/out" />
|
<output url="file://$PROJECT_DIR$/out" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
Binary file not shown.
Before Width: | Height: | Size: 18 KiB |
Binary file not shown.
Before Width: | Height: | Size: 28 KiB |
Binary file not shown.
Before Width: | Height: | Size: 34 KiB |
Binary file not shown.
Before Width: | Height: | Size: 27 KiB |
|
@ -16,8 +16,6 @@ import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.concurrent.CopyOnWriteArrayList;
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
import java.util.concurrent.ExecutorService;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
|
@ -99,7 +97,7 @@ public abstract class Learning<A extends Enum>{
|
||||||
|
|
||||||
public void save(ObjectOutputStream oos) throws IOException {
|
public void save(ObjectOutputStream oos) throws IOException {
|
||||||
oos.writeObject(rewardHistory);
|
oos.writeObject(rewardHistory);
|
||||||
oos.writeObject(stateActionTable);
|
// oos.writeObject(stateActionTable);
|
||||||
}
|
}
|
||||||
|
|
||||||
public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
|
public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {
|
||||||
|
|
|
@ -7,7 +7,6 @@ import core.policy.GreedyPolicy;
|
||||||
import core.policy.Policy;
|
import core.policy.Policy;
|
||||||
import evironment.antGame.Reward;
|
import evironment.antGame.Reward;
|
||||||
import example.ContinuousAnt;
|
import example.ContinuousAnt;
|
||||||
import example.DinoSampling;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -77,7 +76,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
|
||||||
foodCollected++;
|
foodCollected++;
|
||||||
foodTimestampsTotal += timestampTilFood;
|
foodTimestampsTotal += timestampTilFood;
|
||||||
if(foodCollected % 1000 == 0){
|
if(foodCollected % 1000 == 0){
|
||||||
System.out.println(foodTimestampsTotal/1000f);
|
System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode);
|
||||||
File file = new File(ContinuousAnt.FILE_NAME);
|
File file = new File(ContinuousAnt.FILE_NAME);
|
||||||
try {
|
try {
|
||||||
Files.writeString(Path.of(file.getPath()), foodTimestampsTotal/1000f +",", StandardOpenOption.APPEND);
|
Files.writeString(Path.of(file.getPath()), foodTimestampsTotal/1000f +",", StandardOpenOption.APPEND);
|
||||||
|
|
|
@ -54,7 +54,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
|
||||||
|
|
||||||
protected StepCalculation processStep(AntAction action) {
|
protected StepCalculation processStep(AntAction action) {
|
||||||
StepCalculation sc = new StepCalculation();
|
StepCalculation sc = new StepCalculation();
|
||||||
sc.reward = -1;
|
sc.reward = Reward.DEFAULT_REWARD;
|
||||||
sc.info = "";
|
sc.info = "";
|
||||||
sc.done = false;
|
sc.done = false;
|
||||||
Cell currentCell = grid.getCell(myAnt.getPos());
|
Cell currentCell = grid.getCell(myAnt.getPos());
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package evironment.antGame;
|
package evironment.antGame;
|
||||||
|
|
||||||
public class Reward {
|
public class Reward {
|
||||||
|
public static final double DEFAULT_REWARD = -1;
|
||||||
public static final double FOOD_PICK_UP_SUCCESS = 0;
|
public static final double FOOD_PICK_UP_SUCCESS = 0;
|
||||||
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
|
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
|
||||||
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;
|
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;
|
||||||
|
|
|
@ -44,7 +44,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
||||||
@Override
|
@Override
|
||||||
public StepResultEnvironment step(DinoAction action) {
|
public StepResultEnvironment step(DinoAction action) {
|
||||||
boolean done = false;
|
boolean done = false;
|
||||||
int reward = 0;
|
int reward = 1;
|
||||||
|
|
||||||
if(action == DinoAction.JUMP){
|
if(action == DinoAction.JUMP){
|
||||||
dino.jump();
|
dino.jump();
|
||||||
|
@ -68,7 +68,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
|
||||||
spawnNewObstacle();
|
spawnNewObstacle();
|
||||||
}
|
}
|
||||||
if(ranIntoObstacle()) {
|
if(ranIntoObstacle()) {
|
||||||
reward = -1;
|
reward = 0;
|
||||||
done = true;
|
done = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3,10 +3,8 @@ package example;
|
||||||
import core.RNG;
|
import core.RNG;
|
||||||
import core.algo.Method;
|
import core.algo.Method;
|
||||||
import core.controller.RLController;
|
import core.controller.RLController;
|
||||||
import core.controller.RLControllerGUI;
|
|
||||||
import evironment.antGame.AntAction;
|
import evironment.antGame.AntAction;
|
||||||
import evironment.antGame.AntWorldContinuous;
|
import evironment.antGame.AntWorldContinuous;
|
||||||
import evironment.antGame.AntWorldContinuousOriginalState;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -31,7 +29,7 @@ public class ContinuousAnt {
|
||||||
rl.setNrOfEpisodes(1);
|
rl.setNrOfEpisodes(1);
|
||||||
//0.99 0.9 0.5
|
//0.99 0.9 0.5
|
||||||
//0.99 0.95 0.9 0.7 0.5 0.3 0.1
|
//0.99 0.95 0.9 0.7 0.5 0.3 0.1
|
||||||
rl.setDiscountFactor(0.05f);
|
rl.setDiscountFactor(0.1f);
|
||||||
// 0.1, 0.3, 0.5, 0.7 0.9
|
// 0.1, 0.3, 0.5, 0.7 0.9
|
||||||
rl.setLearningRate(0.9f);
|
rl.setLearningRate(0.9f);
|
||||||
rl.setEpsilon(0.2f);
|
rl.setEpsilon(0.2f);
|
||||||
|
|
|
@ -6,7 +6,6 @@ import core.controller.RLController;
|
||||||
import core.controller.RLControllerGUI;
|
import core.controller.RLControllerGUI;
|
||||||
import evironment.jumpingDino.DinoAction;
|
import evironment.jumpingDino.DinoAction;
|
||||||
import evironment.jumpingDino.DinoWorld;
|
import evironment.jumpingDino.DinoWorld;
|
||||||
import evironment.jumpingDino.DinoWorldAdvanced;
|
|
||||||
|
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -34,13 +33,13 @@ public class DinoSampling {
|
||||||
System.out.println("seed: " + i * 13);
|
System.out.println("seed: " + i * 13);
|
||||||
RNG.setSeed(i * 13);
|
RNG.setSeed(i * 13);
|
||||||
|
|
||||||
RLController<DinoAction> rl = new RLController<>(
|
RLController<DinoAction> rl = new RLControllerGUI<>(
|
||||||
new DinoWorldAdvanced(),
|
new DinoWorld(),
|
||||||
Method.Q_LEARNING_OFF_POLICY_CONTROL,
|
Method.MC_CONTROL_FIRST_VISIT,
|
||||||
DinoAction.values());
|
DinoAction.values());
|
||||||
rl.setDelay(0);
|
rl.setDelay(300);
|
||||||
rl.setDiscountFactor(0.99f);
|
rl.setDiscountFactor(1f);
|
||||||
rl.setEpsilon(f);
|
rl.setEpsilon(0.5f);
|
||||||
rl.setLearningRate(0.9f);
|
rl.setLearningRate(0.9f);
|
||||||
rl.setNrOfEpisodes(400000);
|
rl.setNrOfEpisodes(400000);
|
||||||
rl.start();
|
rl.start();
|
||||||
|
|
Loading…
Reference in New Issue