Compare commits

...

12 Commits

Author SHA1 Message Date
Jan Löwenstrom 149b8f4bd8 Merge branch 'master' of https://github.com/kono94/refo 2020-04-19 20:56:21 +02:00
Jan Löwenstrom 19d5a87ce0 add multiple food scenario 2020-04-19 20:55:42 +02:00
Jan Löwenstrom 4590562a4c add new every visit results, fix rngEnv NullPointer 2020-04-19 19:14:03 +02:00
Jan Löwenstrom 7d3d097599 add opening dialog to select all learning settings 2020-04-07 11:03:17 +02:00
Jan Löwenstrom 9d1f8dfd46 apply code improvements suggested by intelliJ 2020-04-05 14:44:48 +02:00
Jan Löwenstrom 94ad976a1f spawn start of antgame constant 2020-04-05 14:07:51 +02:00
Jan Löwenstrom bbccef1e71 removed unnecessary stuff from sampling branches 2020-04-05 13:37:38 +02:00
Jan Löwenstrom 0300f3b1fd Merge branch 'antWorldRewardAnalysis'
# Conflicts:
#	src/main/java/core/algo/EpisodicLearning.java
#	src/main/java/core/controller/RLController.java
#	src/main/java/evironment/jumpingDino/DinoWorld.java
#	src/main/java/evironment/jumpingDino/DinoWorldAdvanced.java
#	src/main/java/example/JumpingDino.java
2020-04-05 13:21:20 +02:00
Jan Löwenstrom b1d06293fe add shadowJar 2020-03-05 12:25:42 +01:00
Jan Löwenstrom 1f743cf8f2 fix eps/sec stat 2020-03-05 12:09:36 +01:00
Jan Löwenstrom 18d6e32f64 split DinoWorld between simple and advanced example 2020-03-05 11:58:57 +01:00
Jan Löwenstrom cffec63dc6 apply threading changes to master branch and clean up for tag version
- no testing or epsilon testing stuff
2020-03-05 11:49:51 +01:00
29 changed files with 257 additions and 201 deletions

View File

@ -11,7 +11,7 @@
</list>
</option>
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="11.0.3" project-jdk-type="JavaSDK">
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="1.8" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
</project>

View File

@ -7,8 +7,8 @@ plugins {
group 'net.lwenstrom.jan'
version '1.0-SNAPSHOT'
sourceCompatibility = 11
targetCompatibility = 11
sourceCompatibility = 8
targetCompatibility = 8
repositories {
mavenCentral()
@ -29,8 +29,7 @@ dependencies {
}
// Include dependent libraries in archive.
mainClassName = "example.DinoSampling"
mainClassName = "example.GUIMain"
jar {
manifest {

View File

@ -0,0 +1,21 @@
# No Jump MC Every Visit
0.05,36,29147,7945,707,12,595,68,4913,8034,300,804,174721,1731,715,64899,423,703,2685,173,468,8220,4182,80071,1418,157366,17769,216,3859,44746,218,6794,2,4182,35,258,9015,2581,62,39,17439,7346,1000,56974,21,71655,20301,4834,18144,18393,613,5,39,609,586,6,28871,1021,18051,457,13629,59217,10278,173,24407,1373,48513,31749,190,25159,130,121036,609,609,55,924,77,4379,72708,202,2096,25965,14394,1230,2833,1791,25512,230,16227
0.1,742,1269,9139,6764,782,25068,539,284,678,9327,2166,3469,3057,2971,1335,136645,64551,22900,35519,62,29,15732,906,16,17,960,190,32183,2,40863,3324,37,6610,97,1578,12006,1784,3388,54,576,64,404,1638,95916,108,281,26288,1825,28,89,2359,49218,853,17821,10514,283,153967,44669,256,6,12708,13085,3466,56628,8513,870,15,53906,20,1533,2164,1458,538,226,99,163,16,31398,76,2075,1474,36605,434,120,3744,91,2885,112,3285,69283,21176,3029,10113
0.15,1085,42379,2065,4847,124,27681,2257,1617,4597,1033,39095,254,120,1343,2007,1907,2992,15207,95,7835,495,16224,362,97541,141,8692,32124,1302,1768,10853,480,2,2099,25,1889,66,67,11874,290,77899,7253,31,1190,72,95,141,7732,32,26,16,5,181,361,12321,22,1631,7180,18,40,310,95444,7146,102,5210,6,19593,434,3581,3120,35,14,980,19,3816,677,1494,337,68,2764,53854,147,28,732,23755,270,4818,15,3046,1557,117,13049,323,26709,2036,431,85
0.2,11,119,181736,8993,10,949,1495,20497,1325,5672,39392,35828,233,56,2707,456,12,3630,112,257,936,108,618,2353,52,2563,1632,10094,12336,24,1390,11,1130,4381,78,14658,1228,514,4706,215,15898,100,639,110,723,340,21,25,836,4,83,922,90321,22,747,11599,88,13,11,2101,43,642,1272,1742,169,614,832,681,41884,17,3159,392,9,178,2183,3271,10,378,69,35741,293,63,2733,3825,745,9,3151,2934,588,214,30150,1138,1304,98,566,1497,4474
0.25,21,222,230,10720,108,2978,2377,1398,3310,10274,27275,1674,22,720,293,52,863,16756,76,520,28985,247,195,5507,97,5916,96,11,1346,77,157,8,14,501,40,1502,32,58,561,8,327,74,91,20,5444,914,4278,30,27,399,2186,4,75,1149,2386,22,3415,3918,56,438,8,160,451,18150,520,329,139654,33850,16,625,128,463,73,159639,147,33074,2545,198,1210,637,1143,180,824,171,293,16,9592,699,8,929,664,1637,2524,18383,256,31,540,233,265,998
0.3,15,2847,8254,2332,2921,62,61322,1967,2375,4488,1142,1502,83,5177,519,128,7727,58,25,552,365,119,141,11859,7,9742,8,754,415,1534,641,8,498,12364,38,1062,75,336,3247,8,540,683,13,396,9242,194,24,935,62,83,611,4,5,5047,3134,243,1647,3294,96,257,9,1777,38112,6059,118,1702,4305,48,25907,113,715,8233,797,1721,147,39342,1111,26,11991,516,3292,24,126,1012,3290,305,5371,145,11939,65779,9057,3241,76,6924,2896,266,4380,902,634,577
0.35,68,47,70,1307,4155,1031,15740,175,829,920,2649,599,214,494,544,5026,1647,2007,48,182,294,670,156,613,20,1161,250,1710,4153,270,78,81,17,81,4816,5955,87,428,189,8,235,225,3151,7884,205,1322,25,31,787,15,3728,4,5,1314,3193,454,1175,167,176,7596,49,74,3110,389,385,2940,317,17,3846,94,99,202,88,22,1957,749,2767,57,334,326,27205,1928,14,285,3293,43,6288,43,643,36225,58,2835,1126,8335,1101,1874,2916,391,2455,13320
0.40,430,2482,4292,2558,7285,165,295,845,3480,43,2391,3,269,448,886,7400,6280,704,61,178,2706,4046,1089,139,947,855,46,415,1668,6922,235,22,5447,19,401,706,67,5379,73,527,178,3307,1254,996,276,4191,760,554,1140,15987,175,4,22,7793,4102,129,583,1056,228,27992,615,51,32,259,1508,867,2,1314,170,1135,390,3982,69,1228,1556,324,288,4165,504,504,173,313,85,2101,390,62,26544,468,793,1533,63,1838,167,3753,4042,402,6142,21336,841,72251
0.45,684,893,1075,3710,10131,9449,999,337,136,3812,80,3,3113,132,11540,277,639,315,74,55,2061,53,23,3426,513,4798,60,37,7000,76680,2757,22,1241,87,1207,4540,49,4373,48,10391,561,1698,214,507,382,251,245,998,142,1680,290,4,21,390,13326,209,916,2873,108,38,417,16,13,503,247,4716,8,23,4030,1862,104,681,114,40,638,647,3693,22959,90,249,85,485,1448,346,1575,524,429,3365,489,1124,177,3188,188,1880,110,263,108,34,408,25180
0.50,42,16830,1941,220,3775,815,2031,498,3436,2897,1061,3,2619,192,1034,672,14834,214,125,3387,1480,345,32,4599,682,350,45,572,863,45168,66,186,231,4381,596,1034,723,3776,10695,110,3889,975,809,86,579,16,40,228,180,259,54,4,1241,649,22145,372,671,1055,2444,1949,61,229,17,2392,266,1586,1820,24,2182,3007,124,270,233,46,1526,295,1276,996,352,1225,17,341,141443,271,512,3197,474,3692,8,888,581,1091,71,4663,87,4268,1127,3638,1529,661
0.55,3815,777,987,185,27,1230,1594,2210,312,303,15,2,2256,3099,7131,134,3533,158,5467,1954,3945,160,281,1095,216,676,60,6476,429,267,1781,966,1156,2818,3082,16615,233,3614,95,1157,2991,1286,78,1461,783,16,40,1439,5110,22,322,4,878,5,111,1464,249,6612,631,10027,145,1804,1304,2465,992,196,32,13,204,1639,255,1151,409,6002,331,585,1023,2579,1625,730,16,249,929,265,782,327,3763,142,24,4778,130,27170,653,30,64,2757,7624,10265,242,782
0.60,72,219,1014,1553,13880,289,3735,1411,4136,112,3301,2,4284,1165,7631,675,1641,3996,728,1337,127,33,2489,3893,429,171,5,3243,128,5598,1049,1360,831,862,178,463,113,86,3557,1211,40,1309,185,34750,75,15,28,6889,743,49,15,4,582,3927,61,83,841,1639,4,984,195,2585,50,2949,2506,7500,33,14,901,2253,267,1783,254,2056,110,225,4108,3027,255,1567,24,905,91,3740,500,2585,2344,2102,160,468,258,3133,3619,5072,2158,2844,2325,343,18,2851
0.65,259,481,1669,5179,15,2300,180,118,200,5180,238,2,15,1424,6185,1967,149,291,9749,666,1616,680,3145,96,650,5056,1247,1082,310,374,1614,937,563,322,1963,161,22,38,78,821,206,1006,390,319,1613,1992,1230,535,96,320,17,4,3067,6450,552,991,17929,3373,28,204,47,143,73,632,2433,2131,78,14,412,3296,782,4784,1067,629,2132,4082,9844,6593,2112,412,401,19872,1336,368,2487,319,814,27,170,1362,100,1678,127,1800,6572,5779,693,1565,2088,798
0.70,1007,185,1973,2285,15,28,406,1468,382,1237,11692,2,55,1204,172,1042,3029,583,47,1430,2597,3431,2986,3363,4073,355,21980,1160,916,6140,893,535,13,418,578,535,3,47,3353,9003,816,3294,393,2578,2874,4252,1128,786,381,207,10,4,6397,895,2210,4982,763,54,11,1537,2,712,26,8748,123,3861,65,958,811,62,116,429,409,394,1029,1611,7386,298,1518,1322,3053,1359,5281,385,3349,4487,1939,5949,533,1506,3234,1492,2744,4,2985,7713,1492,3836,7254,342
0.75,6523,317,3968,233,15,80,2863,4066,16,3262,883,704,117,5443,2409,59,3848,3450,65,87,3764,1652,62,1646,2864,512,1226,1816,3587,378,4243,1837,10,5908,16119,4734,5,328,2639,897,353,6208,4281,1671,1520,610,451,399,3650,2542,10,3,23,66,245,2607,58,470,4,1989,2,60,3256,3603,4903,568,743,251,14701,2671,4777,698,176,26,1302,787,3275,2738,1335,2634,8791,5806,2222,13223,311,4218,8479,5535,8898,189,2425,1919,2482,4,1536,3469,2201,345,390,5228
0.80,83,2976,2164,17000,16090,4483,13563,987,14,505,282,3317,2053,387,12252,247,3519,4557,378,71,756,910,1921,2874,1862,1829,731,54,234,715,682,10015,530,633,3911,4709,265,116,104,9814,1312,601,61,19856,1102,595,10881,3728,3633,3388,9,3,811,1999,4075,2106,1888,2582,4,12564,154,383,15,2438,436,824,947,6212,3273,73,772,49,2818,11875,4322,4457,3292,3180,170,55,4707,100,5989,1835,4546,5414,808,65,7,811,1783,19664,219,4,248,1953,996,32118,7,257
0.85,2075,497,11265,1960,3624,34,366,957,14,2817,4413,18269,5069,2370,2272,12272,330,2063,4130,6284,3628,9,1938,613,15316,3474,5446,100,3894,2899,67,892,4003,348,135,1151,5953,6105,6149,5830,2078,1831,4313,626,269,877,3822,19,1106,139,9,3,18364,1119,8355,22487,3265,1415,5,13005,781,27358,728,1977,616,6789,341,7050,2368,24,1281,1913,15870,1985,619,7965,5509,538,620,2012,2022,10890,20395,1500,12,6141,8327,107,7,8407,109,45,100,4,874,632,2101,396,7,5154
0.90,902,3427,5470,42978,9193,2969,5514,2288,91,2048,1085,982,16,4210,3702,103,14119,5531,1076,2111,1360,16,10217,4417,544,607,1414,5350,1027,9041,3697,12502,43926,105,2941,31782,995,15661,773,338,153,19508,11985,33683,3006,4043,2942,2541,3817,2281,8561,6771,2305,37689,9198,19,899,13884,1818,12144,319,416,250,11815,44066,4825,783,4261,286,162,13125,138,3183,627,609,1736,1554,989,7559,3504,6641,3354,47,972,11,14096,278,15664,7,15571,389,286,1367,4,2434,25,297,3388,7,2743
0.95,493,6960,769,1563,660,4208,778,8917,98,72,21014,100,6872,7998,26519,154,4073,178,11475,3714,184,9,7250,2416,1161,38061,85,3569,25036,7121,30470,300,270,8458,668,5242,610,21,2330,15540,3238,2429,8334,103,3945,6469,166,581,27569,8,3,7695,4772,212,4868,15107,11469,132,2466,1252,4425,21010,6871,2827,155,4470,59064,5478,504,38380,5933,18,4313,4131,26218,5916,1479,4514,14507,64879,4715,2885,11,7370,2070,4241,7,11518,6675,3896,6503,26757,38,15158,1246,63,7,86
1.00,5753,31823,1131,5008,814,2498,3289,4148,1189,3345,455,16054,1727,4275,8701,177,354,21971,2838,2195,9,7887,682,34,1657,23710,296,29250,5562,79,7168,176,369,9651,140,19561,4518,6518,3711,20169,4594,105804,3547,809,14814,220,774,16,2280,3,318,1921,6150,254,8587,122362,7848,23995,2323,7108,2348,58820,3945,62059,984,10839,10909,28613,8046,20249,12086,4325,9060,176,86,47797,1107,119,12405,8481,12,4521,683,171076,92,419,1714,584,596,102,47669,23214,665,98,8,486

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

View File

@ -1,7 +1,10 @@
package core;
import java.io.Serializable;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
/**
* Implementation of a discrete action space.
@ -18,6 +21,7 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
actions = new ArrayList<>();
}
/**
 * Creates an action space that holds the given actions.
 * The elements of the varargs array are copied into a new {@link ArrayList},
 * so the passed array itself is not retained by this instance.
 *
 * @param actions the actions that make up this action space
 */
@SafeVarargs
public ListDiscreteActionSpace(A... actions){
this.actions = new ArrayList<>(Arrays.asList(actions));
}

View File

@ -16,16 +16,15 @@ import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
private volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
private int episodeSumCurrentSecond;
@Setter
protected int currentEpisode = 0;
protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter
protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond;
protected double sumOfRewards;
protected List<StepResult<A>> episode = new ArrayList<>();
protected int timestampCurrentEpisode = 0;
protected boolean converged;
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
initBenchMarking();
@ -50,7 +49,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void initBenchMarking(){
new Thread(()->{
while (currentlyLearning){
while (true){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@ -85,7 +84,6 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
timestampCurrentEpisode++;
}
@Override
@ -96,9 +94,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void startLearning(){
dispatchLearningStart();
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
timestampCurrentEpisode = 0;
nextEpisode();
dispatchEpisodeEnd();
}

View File

@ -23,10 +23,6 @@ import java.util.concurrent.CopyOnWriteArrayList;
*/
@Getter
public abstract class Learning<A extends Enum>{
// TODO: temp testing -> extract to dedicated test
protected int checkSum;
protected int rewardCheckSum;
// current discrete timestamp t
protected int timestamp;
protected int currentEpisode;

View File

@ -21,7 +21,7 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
private boolean isEveryVisit;
private final boolean isEveryVisit;
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
@ -60,7 +60,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
sumOfRewards += envResult.getReward();
rewardCheckSum += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
state = nextState;
@ -74,8 +73,6 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
dispatchStepEnd();
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
HashMap<Pair<State, A>, List<Integer>> stateActionPairs = new LinkedHashMap<>();

View File

@ -5,18 +5,12 @@ import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import evironment.antGame.Reward;
import example.ContinuousAnt;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map;
public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
private float alpha;
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
@ -40,13 +34,9 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
}
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = null;
Map<A, Double> actionValues;
sumOfRewards = 0;
int timestampTilFood = 0;
int foodCollected = 0;
int foodTimestampsTotal= 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
@ -56,44 +46,6 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
timestampTilFood++;
if(reward == Reward.FOOD_DROP_DOWN_SUCCESS) {
foodCollected++;
foodTimestampsTotal += timestampTilFood;
File file = new File(ContinuousAnt.FILE_NAME);
if(foodCollected % 1000 == 0) {
System.out.println(foodTimestampsTotal / 1000f + " " + timestampCurrentEpisode);
try {
Files.writeString(Path.of(file.getPath()), foodTimestampsTotal / 1000f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
foodTimestampsTotal = 0;
}
if(foodCollected == 1000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.15f);
}
if(foodCollected == 2000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.10f);
}
if(foodCollected == 3000){
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.05f);
}
if(foodCollected == 4000){
System.out.println("Reached 0 exploration");
((EpsilonGreedyPolicy<A>) this.policy).setEpsilon(0.00f);
}
if(foodCollected == 15000){
try {
Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
return;
}
timestampTilFood = 0;
}
// Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action);

View File

@ -3,8 +3,6 @@ package core.algo.td;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import java.util.Map;
@ -35,10 +33,8 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
//A action = policy.chooseAction(actionValues);
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
// Take a step

View File

@ -0,0 +1,117 @@
package core.controller;
import core.algo.Method;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
import evironment.antGame.Constants;
import evironment.blackjack.BlackJackTable;
import evironment.blackjack.PlayerAction;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced;
import javax.swing.*;
import java.text.NumberFormat;
/**
 * Start-up dialog that lets the user pick the learning scenario, the algorithm
 * and the numeric learning parameters, and then creates and starts the
 * matching {@link RLControllerGUI}.
 *
 * <p>The dialog is shown directly from the constructor.
 */
public class OpeningDialog {
    // JSlider is integer value only. Instead of creating a subclass
    // the JSlider value is simply divided by this scale factor.
    // 100 means 0 to 1 in 0.01 steps.
    private static final int SCALE_FACTOR = 100;

    /**
     * Shows the parameter dialog and, if it is confirmed, starts learning.
     */
    public OpeningDialog() {
        createStartingDialog();
    }

    /**
     * Writes the current slider value into the label.
     *
     * @param slider        slider to read the value from
     * @param label         label to update
     * @param scaledValue   if {@code true} the slider value is divided by {@link #SCALE_FACTOR}
     * @param parameterName name printed in front of the value
     */
    private void setLabelText(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
        if(scaledValue) {
            label.setText(parameterName + ": " + toScaledValue(slider));
        } else {
            label.setText(parameterName + ": " + slider.getValue());
        }
    }

    /**
     * Initializes the label text and keeps it in sync with every slider change.
     */
    private void linkSliderWithLabel(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
        setLabelText(slider, label, scaledValue, parameterName);
        slider.addChangeListener(changeEvent -> setLabelText(slider, label, scaledValue, parameterName));
    }

    /**
     * Converts the integer slider position into its fractional value.
     */
    private static float toScaledValue(JSlider slider) {
        return (float) slider.getValue() / SCALE_FACTOR;
    }

    /**
     * Creates a text field restricted to integer input.
     *
     * <p>Grouping is disabled on purpose: a committed value would otherwise be
     * re-rendered with locale specific separators (e.g. "10,000"), which
     * {@link Integer#parseInt(String)} cannot read back.
     *
     * @param initialText text shown when the dialog opens
     */
    private static JTextField createIntegerField(String initialText) {
        NumberFormat integerFormat = NumberFormat.getIntegerInstance();
        integerFormat.setGroupingUsed(false);
        JTextField field = new JFormattedTextField(integerFormat);
        field.setText(initialText);
        return field;
    }

    /**
     * Builds the controller for the selected scenario.
     *
     * @param scenario scenario chosen in the dialog
     * @param method   learning algorithm chosen in the dialog
     * @return a controller wired to the environment of the scenario
     * @throws IllegalArgumentException if no or an unknown scenario is selected
     */
    private static RLControllerGUI<?> createController(Scenario scenario, Method method) {
        if(scenario == null) {
            // switching on null would raise a NullPointerException instead
            throw new IllegalArgumentException("Invalid learning scenario selected");
        }
        switch(scenario) {
            case JUMPING_DINO_SIMPLE:
                return new RLControllerGUI<DinoAction>(new DinoWorld(), method, DinoAction.values());
            case JUMPING_DINO_ADVANCED:
                return new RLControllerGUI<DinoAction>(new DinoWorldAdvanced(), method, DinoAction.values());
            case ANTGAME_ONE_FOOD:
                return new RLControllerGUI<AntAction>(new AntWorldContinuous(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, 1), method, AntAction.values());
            case ANTGAME_TWO_FOOD:
                return new RLControllerGUI<AntAction>(new AntWorldContinuous(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, 2), method, AntAction.values());
            case BLACKJACK:
                return new RLControllerGUI<PlayerAction>(new BlackJackTable(), method, PlayerAction.values());
            default:
                throw new IllegalArgumentException("Invalid learning scenario selected");
        }
    }

    /**
     * Assembles the dialog, shows it and starts the learning run when the
     * user confirms with OK.
     */
    private void createStartingDialog() {
        JComboBox<Scenario> scenarioSelection = new JComboBox<>(Scenario.values());
        JComboBox<Method> algorithmSelection = new JComboBox<>(Method.values());

        JSlider delaySlider = new JSlider(1, 1000, 200);
        JLabel delayLabel = new JLabel();
        linkSliderWithLabel(delaySlider, delayLabel, false, "Delay");

        JSlider discountSlider = new JSlider(0, 100, 100);
        JLabel discountLabel = new JLabel();
        linkSliderWithLabel(discountSlider, discountLabel, true, "Discount Factor (gamma)");

        JSlider epsilonSlider = new JSlider(0, 100, 15);
        JLabel epsilonLabel = new JLabel();
        linkSliderWithLabel(epsilonSlider, epsilonLabel, true, "Exploration Factor (epsilon)");

        JSlider learningRateSlider = new JSlider(0, 100, 90);
        JLabel learningRateLabel = new JLabel();
        linkSliderWithLabel(learningRateSlider, learningRateLabel, true, "Learning rate (alpha)");

        JTextField episodesToLearn = createIntegerField("10000");
        // NOTE(review): the seed is shown in the dialog but its value is never
        // read in this class - TODO apply it to the RNG or remove the field.
        JTextField seedTextField = createIntegerField("29");

        Object[] parameters = {
                "Environment:", scenarioSelection,
                "Algorithm:", algorithmSelection,
                discountLabel, discountSlider,
                epsilonLabel, epsilonSlider,
                learningRateLabel, learningRateSlider,
                delayLabel, delaySlider,
                "Episodes to learn:", episodesToLearn,
                "RNG Seed: ", seedTextField,
        };

        int option = JOptionPane.showConfirmDialog(null, parameters, "Learning parameters", JOptionPane.OK_CANCEL_OPTION);
        if(option != JOptionPane.OK_OPTION) {
            System.out.println("Parameter dialog canceled");
            return;
        }

        // Validate the free text input before any controller is created.
        int nrOfEpisodes;
        try {
            nrOfEpisodes = Integer.parseInt(episodesToLearn.getText().trim());
        } catch (NumberFormatException e) {
            JOptionPane.showMessageDialog(null,
                    "Episodes to learn must be a whole number, but was: " + episodesToLearn.getText(),
                    "Invalid learning parameters", JOptionPane.ERROR_MESSAGE);
            return;
        }

        RLControllerGUI<?> rl = createController(
                (Scenario) scenarioSelection.getSelectedItem(),
                (Method) algorithmSelection.getSelectedItem());

        rl.setDelay(delaySlider.getValue());
        rl.setDiscountFactor(toScaledValue(discountSlider));
        rl.setEpsilon(toScaledValue(epsilonSlider));
        rl.setLearningRate(toScaledValue(learningRateSlider));
        rl.setNrOfEpisodes(nrOfEpisodes);
        rl.start();
    }

    /**
     * Learning scenarios that can be selected in the dialog.
     */
    private enum Scenario {
        JUMPING_DINO_SIMPLE,
        JUMPING_DINO_ADVANCED,
        ANTGAME_ONE_FOOD,
        ANTGAME_TWO_FOOD,
        BLACKJACK
    }
}

View File

@ -39,6 +39,7 @@ public class RLController<A extends Enum> implements LearningListener {
protected int prevDelay;
protected volatile boolean printNextEpisode;
@SafeVarargs
public RLController(Environment<A> env, Method method, A... actions) {
setEnvironment(env);
setMethod(method);
@ -87,7 +88,7 @@ public class RLController<A extends Enum> implements LearningListener {
private void initLearning() {
if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning) learning).learn(nrOfEpisodes);
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
} else {
learning.learn();
}
@ -100,11 +101,9 @@ public class RLController<A extends Enum> implements LearningListener {
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
if(learning.isCurrentlyLearning()){
((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes);
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
new Thread(() -> {
((EpisodicLearning) learning).learn(nrOfEpisodes);
}).start();
new Thread(() -> ((EpisodicLearning<A>) learning).learn(nrOfEpisodes)).start();
}
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
@ -179,8 +178,8 @@ public class RLController<A extends Enum> implements LearningListener {
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode) {
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
printNextEpisode = false;
}
}

View File

@ -13,10 +13,12 @@ import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
private LearningView learningView;
@SafeVarargs
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
super(env, method, actions);
}
@Override
protected void initListeners() {
SwingUtilities.invokeLater(() -> {
@ -102,7 +104,7 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> implements
@Override
public void onLearningEnd() {
super.onLearningEnd();
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : ""));
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
}

View File

@ -38,7 +38,7 @@ public class LearningInfoPanel extends JPanel {
episodeLabel = new JLabel();
add(episodeLabel);
}
delaySlider = new JSlider(0, 1000, learning.getDelay());
delaySlider = new JSlider(1, 1000, learning.getDelay());
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel);
add(discountLabel);
@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel {
viewListener.onFastLearnChange(fastLearning);
});
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
smoothGraphCheckbox.setSelected(false);
smoothGraphCheckbox.setSelected(true);
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
last100Checkbox.setSelected(true);
last100Checkbox.setSelected(false);
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
drawEnvironmentCheckbox.setSelected(true);
@ -85,9 +85,7 @@ public class LearningInfoPanel extends JPanel {
add(learnMoreEpisodesButton);
}
showQTableButton = new JButton("Show Q-Table");
showQTableButton.addActionListener(e -> {
viewListener.onShowQTable();
});
showQTableButton.addActionListener(e -> viewListener.onShowQTable());
add(drawEnvironmentCheckbox);
add(smoothGraphCheckbox);
add(last100Checkbox);

View File

@ -1,7 +1,6 @@
package evironment.antGame;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;

View File

@ -86,10 +86,10 @@ public class AntState implements State, Visualizable {
public JComponent visualize() {
return new JScrollPane() {
private int cellSize;
private final int paneWidth = 500;
private final int paneHeight = 500;
private Font font;
{
int paneWidth = 500;
int paneHeight = 500;
setPreferredSize(new Dimension(paneWidth, paneHeight));
cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
font = new Font("plain", Font.BOLD, cellSize);

View File

@ -36,12 +36,14 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
* various lectures could be possible as well.
*/
protected AntAgent antAgent;
protected int numberOfConcurrentFood;
protected int tick;
private int maxEpisodeTicks;
public AntWorld(int width, int height) {
grid = new Grid(width, height);
public AntWorld(int width, int height, int numberOfConcurrentFood) {
this.numberOfConcurrentFood = numberOfConcurrentFood;
grid = new Grid(width, height, numberOfConcurrentFood);
antAgent = new AntAgent(width, height);
myAnt = new Ant();
maxEpisodeTicks = 1000;
@ -49,7 +51,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
}
public AntWorld(){
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT);
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_CONCURRENT_FOOD);
}
protected StepCalculation processStep(AntAction action) {
@ -62,7 +64,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
sc.stayOnCell = true;
// flag to enable a check if all food has been collected only fired if food was dropped
// on the starting position
sc.checkCompletion = false;
sc.foodCollected = false;
switch(action) {
case MOVE_UP:
@ -110,7 +112,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
} else {
sc.reward = Reward.FOOD_DROP_DOWN_SUCCESS;
myAnt.setPoints(myAnt.getPoints() + 1);
sc.checkCompletion = true;
sc.foodCollected = true;
}
}
break;
@ -139,16 +141,10 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
// valid movement
if(!sc.stayOnCell) {
myAnt.getPos().setLocation(sc.potentialNextPos);
if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){
// the ant will move to a cell that was previously unknown
// TODO: not optimal for going straight for food
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
}
antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown
}
if(sc.checkCompletion) {
if(sc.foodCollected) {
sc.done = grid.isAllFoodCollected();
}
@ -183,7 +179,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
boolean stayOnCell = true;
// flag to enable a check if all food has been collected only fired if food was dropped
// on the starting position
boolean checkCompletion = false;
boolean foodCollected = false;
}
public State reset() {

View File

@ -4,8 +4,8 @@ import core.State;
import core.StepResultEnvironment;
public class AntWorldContinuous extends AntWorld {
public AntWorldContinuous(int width, int height) {
super(width, height);
public AntWorldContinuous(int width, int height, int numberOfConcurrentFood) {
super(width, height, numberOfConcurrentFood);
}
public AntWorldContinuous() {
@ -14,14 +14,16 @@ public class AntWorldContinuous extends AntWorld {
@Override
public StepResultEnvironment step(AntAction action) {
Cell currentCell = grid.getCell(myAnt.getPos());
StepCalculation sc = processStep(action);
// flag is set to true if food gets dropped onto the starting position
if(sc.checkCompletion) {
if(sc.foodCollected) {
grid.removeAllFood();
System.out.println(numberOfConcurrentFood);
for(int i = 0; i < numberOfConcurrentFood; ++i) {
grid.spawnNewFood();
}
}
// valid movement
if(!sc.stayOnCell) {
myAnt.getPos().setLocation(sc.potentialNextPos);

View File

@ -3,8 +3,8 @@ package evironment.antGame;
import core.State;
public class AntWorldContinuousOriginalState extends AntWorldContinuous {
public AntWorldContinuousOriginalState(int width, int height) {
super(width, height);
public AntWorldContinuousOriginalState(int width, int height, int numberOfConcurrentFood) {
super(width, height, numberOfConcurrentFood);
}
public AntWorldContinuousOriginalState() {

View File

@ -1,6 +1,9 @@
package evironment.antGame;
public class Constants {
public static final int DEFAULT_GRID_WIDTH = 5;
public static final int DEFAULT_GRID_HEIGHT = 5;
public static final int DEFAULT_CONCURRENT_FOOD = 1;
public static final int DEFAULT_GRID_WIDTH = 8;
public static final int DEFAULT_GRID_HEIGHT = 8;
public static final int START_X = 5;
public static final int START_Y = 2;
}

View File

@ -10,12 +10,15 @@ public class Grid {
private Point start;
private Cell[][] grid;
private Cell[][] initialGrid;
private int numberOfConcurrentFood;
public Grid(int width, int height) {
public Grid(int width, int height, int numberOfConcurrentFood) {
this.width = width;
this.height = height;
this.numberOfConcurrentFood = numberOfConcurrentFood;
grid = new Cell[width][height];
initialGrid = new Cell[width][height];
start = new Point(Constants.START_X, Constants.START_Y);
initRandomWorld();
}
@ -29,10 +32,11 @@ public class Grid {
initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
}
}
start = new Point(RNG.getRandomEnv().nextInt(width), RNG.getRandomEnv().nextInt(height));
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
spawnNewFood(initialGrid);
spawnObstacles();
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
for(int i = 0; i < numberOfConcurrentFood; ++i) {
spawnNewFood(initialGrid);
}
}
//TODO
@ -53,7 +57,7 @@ public class Grid {
/**
* Spawns one additional food on a random field EXCEPT for the starting position
*/
public void spawnNewFood(Cell[][] grid) {
private void spawnNewFood(Cell[][] grid) {
boolean foodSpawned = false;
Point potFood = new Point(0, 0);
CellType potFieldType;
@ -70,6 +74,14 @@ public class Grid {
}
}
/**
 * Sets the food value of every cell in {@code grid} to 0.
 * Visits all {@code width} x {@code height} positions, column by column.
 */
public void removeAllFood() {
    for (int col = 0; col < width; col++) {
        final Cell[] column = grid[col];
        for (int row = 0; row < height; row++) {
            column[row].setFood(0);
        }
    }
}
/**
 * Spawns food on this instance's {@code grid} field by delegating to
 * {@code spawnNewFood(Cell[][])}.
 */
public void spawnNewFood() {
spawnNewFood(grid);
}

View File

@ -7,10 +7,8 @@ import javax.swing.*;
import java.awt.*;
public class AntWorldComponent extends JComponent {
private AntWorld antWorld;
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
this.antWorld = antWorld;
setLayout(new BorderLayout());
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);

View File

@ -2,21 +2,20 @@ package evironment.jumpingDino;
import core.State;
import core.gui.Visualizable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import javax.swing.*;
import java.awt.*;
import java.io.Serializable;
import java.util.Objects;
@AllArgsConstructor
@Getter
public class DinoState implements State, Serializable, Visualizable {
private int xDistanceToObstacle;
public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable {
private boolean isJumping;
protected final double scale = 0.5;
public DinoState(int xDistanceToObstacle, boolean isJumping) {
super(xDistanceToObstacle);
this.isJumping = isJumping;
}
@Override
public String toString() {
@ -40,29 +39,15 @@ public class DinoState implements State, Serializable, Visualizable {
}
@Override
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
setVisible(true);
protected void drawDinoInfo(Graphics g) {
int dinoY;
if(!isJumping) {
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
} else {
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - (int) (scale * Config.MAX_JUMP_HEIGHT);
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
}
@Override
protected void paintComponent(Graphics g) {
super.paintComponents(g);
drawObjects(g);
}
};
}
public void drawObjects(Graphics g){
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY - 20)));
}
}

View File

@ -14,7 +14,7 @@ import java.util.Objects;
@Getter
public class DinoStateSimple implements State, Serializable, Visualizable {
protected final double scale = 0.5;
private int xDistanceToObstacle;
protected int xDistanceToObstacle;
@Override
public String toString() {
@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT)));
setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@ -52,14 +52,15 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
};
}
/**
 * Draws the dino as a filled square and, next to it, a text label with the
 * current {@code xDistanceToObstacle}. All coordinates are multiplied by {@code scale}.
 *
 * @param g graphics context to draw on
 */
protected void drawDinoInfo(Graphics g) {
    final int dinoLeft = (int) (scale * Config.DINO_STARTING_X);
    final int dinoTop = (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE));
    final int dinoEdge = (int) (scale * Config.DINO_SIZE);
    g.fillRect(dinoLeft, dinoTop, dinoEdge, dinoEdge);

    final int labelY = (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40));
    final String label = "Distance: " + xDistanceToObstacle;
    g.drawString(label, dinoLeft, labelY);
}
/**
 * Paints this state: the ground line, the obstacle, and finally the dino with its
 * label via {@link #drawDinoInfo(Graphics)}.
 *
 * <p>The dino and the distance label are drawn only through {@code drawDinoInfo}.
 * Drawing them inline here as well would paint them twice and would ignore any
 * override of {@code drawDinoInfo} that positions the dino differently.
 *
 * @param g graphics context to draw on; its colour is set to black
 */
public void drawObjects(Graphics g) {
    g.setColor(Color.BLACK);

    // Ground: a 2-pixel-high bar starting at x = 0.
    g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);

    // Obstacle: a square placed xDistanceToObstacle to the right of the dino's x position.
    g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));

    // Dino and label are delegated so that overriding classes control how they are drawn.
    drawDinoInfo(g);
}
}

View File

@ -49,19 +49,6 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
if(action == DinoAction.JUMP){
dino.jump();
}
// for(int i= 0; i < 5; ++i){
// dino.tick();
// currentObstacle.tick();
// if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
// spawnNewObstacle();
// }
// comp.repaint();
// if(ranIntoObstacle()){
// done = true;
// break;
// }
// }
dino.tick();
currentObstacle.tick();
if(currentObstacle.getX() < -Config.OBSTACLE_SIZE) {
@ -76,7 +63,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
}
protected State generateReturnState(){
return new DinoStateSimple(getDistanceToObstacle());
return new DinoState(getDistanceToObstacle(), dino.isInJump());
}
protected void spawnNewObstacle(){

View File

@ -7,31 +7,18 @@ import core.controller.RLControllerGUI;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
import java.io.File;
import java.io.IOException;
public class ContinuousAnt {
public static final String FILE_NAME = "converge.txt";
public static void main(String[] args) {
File file = new File(FILE_NAME);
try {
file.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
RNG.setSeed(13, true);
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorldContinuous(8, 8),
new AntWorldContinuous(8, 8, 1),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());
rl.setDelay(20);
rl.setDelay(200);
rl.setNrOfEpisodes(1);
// 0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 0.99
rl.setDiscountFactor(0.05f);
// 0.1, 0.3, 0.5, 0.7 0.9
rl.setDiscountFactor(0.3f);
rl.setLearningRate(0.9f);
rl.setEpsilon(0.2f);
rl.setEpsilon(0.15f);
rl.start();
}
}

View File

@ -0,0 +1,9 @@
package example;
import core.controller.OpeningDialog;
/**
 * Launcher class whose {@code main} method constructs an {@link OpeningDialog}.
 */
public class GUIMain {
/**
 * Creates a new {@link OpeningDialog}; the created instance is not stored.
 * NOTE(review): presumably the dialog shows itself from its constructor - confirm in OpeningDialog.
 *
 * @param args command-line arguments; not read by this method
 */
public static void main(String[] args) {
new OpeningDialog();
}
}

View File

@ -3,23 +3,23 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced;
public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(29);
RLController<DinoAction> rl = new RLController<>(
RLController<DinoAction> rl = new RLControllerGUI<>(
new DinoWorldAdvanced(),
Method.MC_CONTROL_FIRST_VISIT,
Method.MC_CONTROL_EVERY_VISIT,
DinoAction.values());
rl.setDelay(0);
rl.setDiscountFactor(9f);
rl.setDelay(200);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.05f);
rl.setLearningRate(0.8f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(100000);
rl.start();
}

View File

@ -12,7 +12,7 @@ public class RunningAnt {
RNG.setSeed(56);
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorld(8, 8),
new AntWorld(8, 8, 1),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());