add constant for default reward

This commit is contained in:
Jan Löwenstrom 2020-04-02 14:01:37 +02:00
parent e7404a8d24
commit 740289ee2b
12 changed files with 14 additions and 19 deletions

View File

@ -11,7 +11,7 @@
</list> </list>
</option> </option>
</component> </component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="11.0.5" project-jdk-type="JavaSDK"> <component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="11.0.3" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" /> <output url="file://$PROJECT_DIR$/out" />
</component> </component>
</project> </project>

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 28 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

View File

@ -16,8 +16,6 @@ import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/** /**
* *
@ -99,7 +97,7 @@ public abstract class Learning<A extends Enum>{
public void save(ObjectOutputStream oos) throws IOException { public void save(ObjectOutputStream oos) throws IOException {
oos.writeObject(rewardHistory); oos.writeObject(rewardHistory);
oos.writeObject(stateActionTable); // oos.writeObject(stateActionTable);
} }
public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException { public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {

View File

@ -7,7 +7,6 @@ import core.policy.GreedyPolicy;
import core.policy.Policy; import core.policy.Policy;
import evironment.antGame.Reward; import evironment.antGame.Reward;
import example.ContinuousAnt; import example.ContinuousAnt;
import example.DinoSampling;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -77,7 +76,7 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
foodCollected++; foodCollected++;
foodTimestampsTotal += timestampTilFood; foodTimestampsTotal += timestampTilFood;
if(foodCollected % 1000 == 0){ if(foodCollected % 1000 == 0){
System.out.println(foodTimestampsTotal/1000f); System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode);
File file = new File(ContinuousAnt.FILE_NAME); File file = new File(ContinuousAnt.FILE_NAME);
try { try {
Files.writeString(Path.of(file.getPath()), foodTimestampsTotal/1000f +",", StandardOpenOption.APPEND); Files.writeString(Path.of(file.getPath()), foodTimestampsTotal/1000f +",", StandardOpenOption.APPEND);

View File

@ -54,7 +54,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
protected StepCalculation processStep(AntAction action) { protected StepCalculation processStep(AntAction action) {
StepCalculation sc = new StepCalculation(); StepCalculation sc = new StepCalculation();
sc.reward = -1; sc.reward = Reward.DEFAULT_REWARD;
sc.info = ""; sc.info = "";
sc.done = false; sc.done = false;
Cell currentCell = grid.getCell(myAnt.getPos()); Cell currentCell = grid.getCell(myAnt.getPos());

View File

@ -1,6 +1,7 @@
package evironment.antGame; package evironment.antGame;
public class Reward { public class Reward {
public static final double DEFAULT_REWARD = -1;
public static final double FOOD_PICK_UP_SUCCESS = 0; public static final double FOOD_PICK_UP_SUCCESS = 0;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1; public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1; public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;

View File

@ -44,7 +44,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
@Override @Override
public StepResultEnvironment step(DinoAction action) { public StepResultEnvironment step(DinoAction action) {
boolean done = false; boolean done = false;
int reward = 0; int reward = 1;
if(action == DinoAction.JUMP){ if(action == DinoAction.JUMP){
dino.jump(); dino.jump();
@ -68,7 +68,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
spawnNewObstacle(); spawnNewObstacle();
} }
if(ranIntoObstacle()) { if(ranIntoObstacle()) {
reward = -1; reward = 0;
done = true; done = true;
} }

View File

@ -3,10 +3,8 @@ package example;
import core.RNG; import core.RNG;
import core.algo.Method; import core.algo.Method;
import core.controller.RLController; import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.antGame.AntAction; import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous; import evironment.antGame.AntWorldContinuous;
import evironment.antGame.AntWorldContinuousOriginalState;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -31,7 +29,7 @@ public class ContinuousAnt {
rl.setNrOfEpisodes(1); rl.setNrOfEpisodes(1);
//0.99 0.9 0.5 //0.99 0.9 0.5
//0.99 0.95 0.9 0.7 0.5 0.3 0.1 //0.99 0.95 0.9 0.7 0.5 0.3 0.1
rl.setDiscountFactor(0.05f); rl.setDiscountFactor(0.1f);
// 0.1, 0.3, 0.5, 0.7 0.9 // 0.1, 0.3, 0.5, 0.7 0.9
rl.setLearningRate(0.9f); rl.setLearningRate(0.9f);
rl.setEpsilon(0.2f); rl.setEpsilon(0.2f);

View File

@ -6,7 +6,6 @@ import core.controller.RLController;
import core.controller.RLControllerGUI; import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction; import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld; import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
@ -34,13 +33,13 @@ public class DinoSampling {
System.out.println("seed: " + i * 13); System.out.println("seed: " + i * 13);
RNG.setSeed(i * 13); RNG.setSeed(i * 13);
RLController<DinoAction> rl = new RLController<>( RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorldAdvanced(), new DinoWorld(),
Method.Q_LEARNING_OFF_POLICY_CONTROL, Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values()); DinoAction.values());
rl.setDelay(0); rl.setDelay(300);
rl.setDiscountFactor(0.99f); rl.setDiscountFactor(1f);
rl.setEpsilon(f); rl.setEpsilon(0.5f);
rl.setLearningRate(0.9f); rl.setLearningRate(0.9f);
rl.setNrOfEpisodes(400000); rl.setNrOfEpisodes(400000);
rl.start(); rl.start();