Compare commits

9 Commits

Author SHA1 Message Date
Jan Löwenstrom 7de2a5d1af spawn start field of antGame in the same spot everytime 2020-04-05 14:10:47 +02:00
Jan Löwenstrom 737d78c6da mend 2020-04-05 12:54:14 +02:00
Jan Löwenstrom e8f4fa06b6 add specific environment RNG 2020-04-05 12:52:49 +02:00
Jan Löwenstrom 42dfebb048 Merge branch 'epsilonBehavior' into DinoSampling
# Conflicts:
#	src/main/java/example/DinoSampling.java
2020-04-05 12:07:39 +02:00
Jan Löwenstrom 3bdcbb39bc reset DinoSampling to advanced Every Visit 2020-04-02 18:45:36 +02:00
Jan Löwenstrom b0ca634b64 add every visit no jump results 2020-04-02 17:07:15 +02:00
Jan Löwenstrom f2aa7487af Merge remote-tracking branch 'origin/epsilonBehavior' into epsilonBehavior
# Conflicts:
#	src/main/java/core/algo/EpisodicLearning.java
2020-04-02 15:57:32 +02:00
Jan Löwenstrom 6477251545 add Every-Visit Monte-Carlo 2020-04-02 15:56:11 +02:00
Jan Löwenstrom 28c40c58dd add updated R script 2020-03-27 17:09:03 +01:00
36 changed files with 286 additions and 316 deletions

View File

@ -11,7 +11,7 @@
</list>
</option>
</component>
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="1.8" project-jdk-type="JavaSDK">
<component name="ProjectRootManager" version="2" languageLevel="JDK_11" default="false" project-jdk-name="11.0.3" project-jdk-type="JavaSDK">
<output url="file://$PROJECT_DIR$/out" />
</component>
</project>

View File

@ -7,8 +7,8 @@ plugins {
group 'net.lwenstrom.jan'
version '1.0-SNAPSHOT'
sourceCompatibility = 8
targetCompatibility = 8
sourceCompatibility = 11
targetCompatibility = 11
repositories {
mavenCentral()
@ -29,7 +29,8 @@ dependencies {
}
// Include dependent libraries in archive.
mainClassName = "example.GUIMain"
mainClassName = "example.DinoSampling"
jar {
manifest {

View File

@ -17,5 +17,5 @@
0.8,19870,3288,13724,4492,8159,5058,16764,5648,9462,19071,3914,1242,8262,26004,4036,9421,4914,2535,5362,7298,9587,37133,1837,35325,15272,14922,14138,7115,17236,5123,12157,37380,6086,37390,1672,15573,14241,2049,2602,6802,22362,7936,7544,5330,13155,16016,4544,1489,3780,6326,7794,31553,2808,1493,7788,12646,30464,22312,1681,12084,4163,2197,7950,22478,5106,26771,4382,10615,2586,12214,4799,6297,7589,4585,30365,32302,15734,5480,8626,7387,11932,4245,21532,1710,12737,7132,4740,14578,10680,8266,17300,4213,3264,35920,38026,10272,3984,2279,9739,33900
0.85,5493,10568,19366,5705,15430,8183,5721,13314,36667,33059,3753,40243,23888,25085,21843,6856,2803,9434,4794,29944,10730,39271,4484,23990,6350,16180,8099,4298,11220,4624,5946,24895,8464,4416,6619,2800,4081,12459,1981,12488,6380,9597,10328,1901,24563,13059,3639,12988,2604,4440,22666,1775,4078,5175,1144,3759,11119,1856,34970,10831,2229,5333,17121,9698,14919,2353,3963,8189,36145,13920,5301,16516,2446,46848,3985,,20640,151501,17556,1882,44216,39795,1638,57957,62050,3130,3693,5563,9780,3327,22969,39357,13749,37555,60070,9249,35426,4405,8340,18973
0.9,27355,24592,18962,2318,17604,35725,14327,38167,25602,50236,4999,9023,5562,7541,11799,25139,8724,12642,28509,57095,2147,5909,5414,12572,10018,68830,45393,18962,51656,25601,3444,45667,16813,57110,16492,3991,7315,17775,69277,34769,29824,11087,26371,3479,2540,9597,32593,13169,8588,2794,40136,56004,65307,24864,35523,19491,2673,5363,4799,5852,28566,42427,44011,40146,3757,1115,49574,5798,24249,2576,118943,6169,65584,7057,49505,116138,52083,1809,127776,3214,25689,103442,15260,62754,12390,3233,35309,68989,6615,30593,2503,29359,98237,11900,3240,64969,84134,25361,7384,13141
0.95,24269,14543,6828,3800,41079,47279,27177,17286,9802,7114,3756,85275,14507,34993,15139,15184,90742,27554,23713,6453,15157,7045,8048,47550,84540,93729,68601,6274,4713,30578,5024,94239,7315,8193,46871,96466,3695,70915,62947,32258,66228,2114,5084,12686,62905,19158,20940,36270,9037,34034,15016,15530,46276,11063,8586,15635,7196,70708,50836,22464,13463,86986,43541,2001,40565,28534,44700,5625,6552,16140,2450,8492,3304,22904,20951,100472,131147,131728,43674,514,79827,181148,31431,4761,1515,2075,138139,137795,71014,170145,60000,42790,179835,18982,48085,28398,56788,126115,5442,118289
0.95,24269,14543,6828,3800,41079,47279,27177,17286,9802,7114,3756,85275,14507,34993,15139,15184,90742,27554,23713,6453,15157,7045,8048,47550,84540,93729,68601,6274,4713,30578,5024,94239,7315,8193,46871,96466,3695,70915,62947,32258,66228,2114,5084,12686,62905,19158,20940,36270,9037,34034,15016,15530,46276,11063,8586,15635,7196,70708,50836,22464,13463,86986,43541,2001,40565,28534,44700,5625,6552,16140,2450,8492,3304,22904,20951,100472,131147,131728,43674,514,79827,181148,31431,4761,1515,2075,138139,137795,71014170145,60000,42790,179835,18982,48085,28398,56788,126115,5442,118289,9386
1.0,11364,6363,8012,109822,19730,8425,21388,7864,18427,34072,3126,52381,35105,86487,73913,88033,76264,105864,30103,9522,31049,3180,4838,4078,133687,39236,59239,22968,21540,98395,109063,4050,5612,4990,9933,83766,140114,116077,135653,130826,130070,92207,14994,87801,1577,70868,133816,79790,1587,23322,22071,13903,3584,9721,,38605,52375,67392,10075,97733,46173,29647,2558,28151,162569,4054,10537,30871,45538,97835,45132,35042,70203,3862,100614,84525,140691,81880,80914,35187,11596,51448,2945,56551,39236,84707,64324,100588,78645,12929,32701,63306,163991,2864,34802,72929,198161,71332,98627,137754

View File

@ -1,21 +0,0 @@
# No Jump MC Every Visist
0.05,36,29147,7945,707,12,595,68,4913,8034,300,804,174721,1731,715,64899,423,703,2685,173,468,8220,4182,80071,1418,157366,17769,216,3859,44746,218,6794,2,4182,35,258,9015,2581,62,39,17439,7346,1000,56974,21,71655,20301,4834,18144,18393,613,5,39,609,586,6,28871,1021,18051,457,13629,59217,10278,173,24407,1373,48513,31749,190,25159,130,121036,609,609,55,924,77,4379,72708,202,2096,25965,14394,1230,2833,1791,25512,230,16227
0.1,742,1269,9139,6764,782,25068,539,284,678,9327,2166,3469,3057,2971,1335,136645,64551,22900,35519,62,29,15732,906,16,17,960,190,32183,2,40863,3324,37,6610,97,1578,12006,1784,3388,54,576,64,404,1638,95916,108,281,26288,1825,28,89,2359,49218,853,17821,10514,283,153967,44669,256,6,12708,13085,3466,56628,8513,870,15,53906,20,1533,2164,1458,538,226,99,163,16,31398,76,2075,1474,36605,434,120,3744,91,2885,112,3285,69283,21176,3029,10113
0.15,1085,42379,2065,4847,124,27681,2257,1617,4597,1033,39095,254,120,1343,2007,1907,2992,15207,95,7835,495,16224,362,97541,141,8692,32124,1302,1768,10853,480,2,2099,25,1889,66,67,11874,290,77899,7253,31,1190,72,95,141,7732,32,26,16,5,181,361,12321,22,1631,7180,18,40,310,95444,7146,102,5210,6,19593,434,3581,3120,35,14,980,19,3816,677,1494,337,68,2764,53854,147,28,732,23755,270,4818,15,3046,1557,117,13049,323,26709,2036,431,85
0.2,11,119,181736,8993,10,949,1495,20497,1325,5672,39392,35828,233,56,2707,456,12,3630,112,257,936,108,618,2353,52,2563,1632,10094,12336,24,1390,11,1130,4381,78,14658,1228,514,4706,215,15898,100,639,110,723,340,21,25,836,4,83,922,90321,22,747,11599,88,13,11,2101,43,642,1272,1742,169,614,832,681,41884,17,3159,392,9,178,2183,3271,10,378,69,35741,293,63,2733,3825,745,9,3151,2934,588,214,30150,1138,1304,98,566,1497,4474
0.25,21,222,230,10720,108,2978,2377,1398,3310,10274,27275,1674,22,720,293,52,863,16756,76,520,28985,247,195,5507,97,5916,96,11,1346,77,157,8,14,501,40,1502,32,58,561,8,327,74,91,20,5444,914,4278,30,27,399,2186,4,75,1149,2386,22,3415,3918,56,438,8,160,451,18150,520,329,139654,33850,16,625,128,463,73,159639,147,33074,2545,198,1210,637,1143,180,824,171,293,16,9592,699,8,929,664,1637,2524,18383,256,31,540,233,265,998
0.3,15,2847,8254,2332,2921,62,61322,1967,2375,4488,1142,1502,83,5177,519,128,7727,58,25,552,365,119,141,11859,7,9742,8,754,415,1534,641,8,498,12364,38,1062,75,336,3247,8,540,683,13,396,9242,194,24,935,62,83,611,4,5,5047,3134,243,1647,3294,96,257,9,1777,38112,6059,118,1702,4305,48,25907,113,715,8233,797,1721,147,39342,1111,26,11991,516,3292,24,126,1012,3290,305,5371,145,11939,65779,9057,3241,76,6924,2896,266,4380,902,634,577
0.35,68,47,70,1307,4155,1031,15740,175,829,920,2649,599,214,494,544,5026,1647,2007,48,182,294,670,156,613,20,1161,250,1710,4153,270,78,81,17,81,4816,5955,87,428,189,8,235,225,3151,7884,205,1322,25,31,787,15,3728,4,5,1314,3193,454,1175,167,176,7596,49,74,3110,389,385,2940,317,17,3846,94,99,202,88,22,1957,749,2767,57,334,326,27205,1928,14,285,3293,43,6288,43,643,36225,58,2835,1126,8335,1101,1874,2916,391,2455,13320
0.40,430,2482,4292,2558,7285,165,295,845,3480,43,2391,3,269,448,886,7400,6280,704,61,178,2706,4046,1089,139,947,855,46,415,1668,6922,235,22,5447,19,401,706,67,5379,73,527,178,3307,1254,996,276,4191,760,554,1140,15987,175,4,22,7793,4102,129,583,1056,228,27992,615,51,32,259,1508,867,2,1314,170,1135,390,3982,69,1228,1556,324,288,4165,504,504,173,313,85,2101,390,62,26544,468,793,1533,63,1838,167,3753,4042,402,6142,21336,841,72251
0.45,684,893,1075,3710,10131,9449,999,337,136,3812,80,3,3113,132,11540,277,639,315,74,55,2061,53,23,3426,513,4798,60,37,7000,76680,2757,22,1241,87,1207,4540,49,4373,48,10391,561,1698,214,507,382,251,245,998,142,1680,290,4,21,390,13326,209,916,2873,108,38,417,16,13,503,247,4716,8,23,4030,1862,104,681,114,40,638,647,3693,22959,90,249,85,485,1448,346,1575,524,429,3365,489,1124,177,3188,188,1880,110,263,108,34,408,25180
0.50,42,16830,1941,220,3775,815,2031,498,3436,2897,1061,3,2619,192,1034,672,14834,214,125,3387,1480,345,32,4599,682,350,45,572,863,45168,66,186,231,4381,596,1034,723,3776,10695,110,3889,975,809,86,579,16,40,228,180,259,54,4,1241,649,22145,372,671,1055,2444,1949,61,229,17,2392,266,1586,1820,24,2182,3007,124,270,233,46,1526,295,1276,996,352,1225,17,341,141443,271,512,3197,474,3692,8,888,581,1091,71,4663,87,4268,1127,3638,1529,661
0.55,3815,777,987,185,27,1230,1594,2210,312,303,15,2,2256,3099,7131,134,3533,158,5467,1954,3945,160,281,1095,216,676,60,6476,429,267,1781,966,1156,2818,3082,16615,233,3614,95,1157,2991,1286,78,1461,783,16,40,1439,5110,22,322,4,878,5,111,1464,249,6612,631,10027,145,1804,1304,2465,992,196,32,13,204,1639,255,1151,409,6002,331,585,1023,2579,1625,730,16,249,929,265,782,327,3763,142,24,4778,130,27170,653,30,64,2757,7624,10265,242,782
0.60,72,219,1014,1553,13880,289,3735,1411,4136,112,3301,2,4284,1165,7631,675,1641,3996,728,1337,127,33,2489,3893,429,171,5,3243,128,5598,1049,1360,831,862,178,463,113,86,3557,1211,40,1309,185,34750,75,15,28,6889,743,49,15,4,582,3927,61,83,841,1639,4,984,195,2585,50,2949,2506,7500,33,14,901,2253,267,1783,254,2056,110,225,4108,3027,255,1567,24,905,91,3740,500,2585,2344,2102,160,468,258,3133,3619,5072,2158,2844,2325,343,18,2851
0.65,259,481,1669,5179,15,2300,180,118,200,5180,238,2,15,1424,6185,1967,149,291,9749,666,1616,680,3145,96,650,5056,1247,1082,310,374,1614,937,563,322,1963,161,22,38,78,821,206,1006,390,319,1613,1992,1230,535,96,320,17,4,3067,6450,552,991,17929,3373,28,204,47,143,73,632,2433,2131,78,14,412,3296,782,4784,1067,629,2132,4082,9844,6593,2112,412,401,19872,1336,368,2487,319,814,27,170,1362,100,1678,127,1800,6572,5779,693,1565,2088,798
0.70,1007,185,1973,2285,15,28,406,1468,382,1237,11692,2,55,1204,172,1042,3029,583,47,1430,2597,3431,2986,3363,4073,355,21980,1160,916,6140,893,535,13,418,578,535,3,47,3353,9003,816,3294,393,2578,2874,4252,1128,786,381,207,10,4,6397,895,2210,4982,763,54,11,1537,2,712,26,8748,123,3861,65,958,811,62,116,429,409,394,1029,1611,7386,298,1518,1322,3053,1359,5281,385,3349,4487,1939,5949,533,1506,3234,1492,2744,4,2985,7713,1492,3836,7254,342
0.75,6523,317,3968,233,15,80,2863,4066,16,3262,883,704,117,5443,2409,59,3848,3450,65,87,3764,1652,62,1646,2864,512,1226,1816,3587,378,4243,1837,10,5908,16119,4734,5,328,2639,897,353,6208,4281,1671,1520,610,451,399,3650,2542,10,3,23,66,245,2607,58,470,4,1989,2,60,3256,3603,4903,568,743,251,14701,2671,4777,698,176,26,1302,787,3275,2738,1335,2634,8791,5806,2222,13223,311,4218,8479,5535,8898,189,2425,1919,2482,4,1536,3469,2201,345,390,5228
0.80,83,2976,2164,17000,16090,4483,13563,987,14,505,282,3317,2053,387,12252,247,3519,4557,378,71,756,910,1921,2874,1862,1829,731,54,234,715,682,10015,530,633,3911,4709,265,116,104,9814,1312,601,61,19856,1102,595,10881,3728,3633,3388,9,3,811,1999,4075,2106,1888,2582,4,12564,154,383,15,2438,436,824,947,6212,3273,73,772,49,2818,11875,4322,4457,3292,3180,170,55,4707,100,5989,1835,4546,5414,808,65,7,811,1783,19664,219,4,248,1953,996,32118,7,257
0.85,2075,497,11265,1960,3624,34,366,957,14,2817,4413,18269,5069,2370,2272,12272,330,2063,4130,6284,3628,9,1938,613,15316,3474,5446,100,3894,2899,67,892,4003,348,135,1151,5953,6105,6149,5830,2078,1831,4313,626,269,877,3822,19,1106,139,9,3,18364,1119,8355,22487,3265,1415,5,13005,781,27358,728,1977,616,6789,341,7050,2368,24,1281,1913,15870,1985,619,7965,5509,538,620,2012,2022,10890,20395,1500,12,6141,8327,107,7,8407,109,45,100,4,874,632,2101,396,7,5154
0.90,902,3427,5470,42978,9193,2969,5514,2288,91,2048,1085,982,16,4210,3702,103,14119,5531,1076,2111,1360,16,10217,4417,544,607,1414,5350,1027,9041,3697,12502,43926,105,2941,31782,995,15661,773,338,153,19508,11985,33683,3006,4043,2942,2541,3817,2281,8561,6771,2305,37689,9198,19,899,13884,1818,12144,319,416,250,11815,44066,4825,783,4261,286,162,13125,138,3183,627,609,1736,1554,989,7559,3504,6641,3354,47,972,11,14096,278,15664,7,15571,389,286,1367,4,2434,25,297,3388,7,2743
0.95,493,6960,769,1563,660,4208,778,8917,98,72,21014,100,6872,7998,26519,154,4073,178,11475,3714,184,9,7250,2416,1161,38061,85,3569,25036,7121,30470,300,270,8458,668,5242,610,21,2330,15540,3238,2429,8334,103,3945,6469,166,581,27569,8,3,7695,4772,212,4868,15107,11469,132,2466,1252,4425,21010,6871,2827,155,4470,59064,5478,504,38380,5933,18,4313,4131,26218,5916,1479,4514,14507,64879,4715,2885,11,7370,2070,4241,7,11518,6675,3896,6503,26757,38,15158,1246,63,7,86
1.00,5753,31823,1131,5008,814,2498,3289,4148,1189,3345,455,16054,1727,4275,8701,177,354,21971,2838,2195,9,7887,682,34,1657,23710,296,29250,5562,79,7168,176,369,9651,140,19561,4518,6518,3711,20169,4594,105804,3547,809,14814,220,774,16,2280,3,318,1921,6150,254,8587,122362,7848,23995,2323,7108,2348,58820,3945,62059,984,10839,10909,28613,8046,20249,12086,4325,9060,176,86,47797,1107,119,12405,8481,12,4521,683,171076,92,419,1714,584,596,102,47669,23214,665,98,8,486

Binary file not shown (image; before: 26 KiB).

View File

@ -1,10 +1,7 @@
package core;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.*;
/**
* Implementation of a discrete action space.
@ -21,7 +18,6 @@ public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSp
actions = new ArrayList<>();
}
@SafeVarargs
public ListDiscreteActionSpace(A... actions){
this.actions = new ArrayList<>(Arrays.asList(actions));
}

View File

@ -17,30 +17,24 @@ public class RNG {
private static Random rng;
private static Random rngEnv;
private static int seed = 123;
private static int envSeed = 13;
static {
rng = new Random();
rngEnv = new Random();
setSeed(seed, true);
}
public static Random getRandom() {
return rng;
}
public static Random getRandomEnv() {
public static Random getEnvRandom() {
return rngEnv;
}
public static void setSeed(int seed, boolean setEnvRandom) {
public static void setSeed(int seed, boolean setEnvSeed) {
RNG.seed = seed;
rng.setSeed(seed);
if(setEnvRandom) {
rngEnv.setSeed(13);
if(setEnvSeed) {
rngEnv.setSeed(seed);
}
}
public static void setSeed(int seed) {
setSeed(seed, true);
}
}
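
The reworked RNG above keeps two generators: an agent-side one behind getRandom() and an environment one behind getEnvRandom(), with setSeed(seed, setEnvSeed) optionally reseeding both to the same value. A minimal usage sketch, assuming core.RNG exactly as shown in this hunk; the demo class itself is hypothetical:

import core.RNG;

import java.util.Random;

public class RngSeedDemo {
    public static void main(String[] args) {
        // Reseed both generators, as the updated example mains (JumpingDino, BlackJack, RunningAnt) now do.
        RNG.setSeed(29, true);

        Random agentRng = RNG.getRandom();     // agent/learning side
        Random envRng   = RNG.getEnvRandom();  // environment side: Grid.spawnNewFood, CardDeck, DinoWorldAdvanced

        // Reseeding only the agent generator leaves the environment's random sequence untouched,
        // so different agent seeds can be compared against an identical environment.
        RNG.setSeed(30, false);

        System.out.println(agentRng.nextInt(100) + " " + envRng.nextInt(100));
    }
}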

View File

@ -5,26 +5,32 @@ import core.Environment;
import core.LearningConfig;
import core.StepResult;
import core.listener.LearningListener;
import example.DinoSampling;
import lombok.Getter;
import lombok.Setter;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
public abstract class EpisodicLearning<A extends Enum> extends Learning<A> implements Episodic {
private volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
private int episodeSumCurrentSecond;
@Setter
protected int currentEpisode = 0;
protected volatile AtomicInteger episodesToLearn = new AtomicInteger(0);
@Getter
protected volatile int episodePerSecond;
protected int episodeSumCurrentSecond;
protected double sumOfRewards;
protected List<StepResult<A>> episode = new ArrayList<>();
protected int timestampCurrentEpisode = 0;
protected boolean converged;
public EpisodicLearning(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, int delay) {
super(environment, actionSpace, discountFactor, delay);
initBenchMarking();
@ -49,7 +55,7 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void initBenchMarking(){
new Thread(()->{
while (true){
while (currentlyLearning){
episodePerSecond = episodeSumCurrentSecond;
episodeSumCurrentSecond = 0;
try {
@ -84,6 +90,21 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
protected void dispatchStepEnd() {
super.dispatchStepEnd();
timestamp++;
timestampCurrentEpisode++;
// TODO: more sophisticated way to check convergence
if(timestampCurrentEpisode > 50000) {
converged = true;
// t
File file = new File(DinoSampling.FILE_NAME);
try {
Files.writeString(Path.of(file.getPath()), currentEpisode/2 + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
System.out.println("converged after: " + currentEpisode/2 + " episode!");
episodesToLearn.set(0);
dispatchLearningEnd();
}
}
@Override
@ -94,7 +115,9 @@ public abstract class EpisodicLearning<A extends Enum> extends Learning<A> imple
private void startLearning(){
dispatchLearningStart();
while(episodesToLearn.get() > 0){
dispatchEpisodeStart();
timestampCurrentEpisode = 0;
nextEpisode();
dispatchEpisodeEnd();
}
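
The new block in dispatchStepEnd() above treats an episode that exceeds 50 000 steps as converged: it appends currentEpisode/2 to DinoSampling.FILE_NAME, prints the count, clears the remaining episodes and dispatches the learning end. The division by two matches the episode interleaving introduced in the Monte-Carlo and TD diffs further down, where every second episode is a pure greedy evaluation episode rather than a learning episode. A standalone sketch of just the logging step, with the file name and threshold taken from this diff and everything else assumed for illustration:

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;

public class ConvergenceLogSketch {
    private static final String FILE_NAME = "advancedEveryVisit.txt"; // DinoSampling.FILE_NAME in this compare
    private static final int STEP_LIMIT = 50_000;

    public static void main(String[] args) throws IOException {
        int currentEpisode = 842;              // hypothetical value
        int timestampCurrentEpisode = 50_001;  // hypothetical value

        if (timestampCurrentEpisode > STEP_LIMIT) {
            // only every second episode is a learning episode, hence currentEpisode / 2
            Files.writeString(Path.of(FILE_NAME), currentEpisode / 2 + ",",
                    StandardOpenOption.CREATE, StandardOpenOption.APPEND);
            System.out.println("converged after: " + currentEpisode / 2 + " learning episodes");
        }
    }
}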

View File

@ -16,6 +16,8 @@ import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
*
@ -23,6 +25,10 @@ import java.util.concurrent.CopyOnWriteArrayList;
*/
@Getter
public abstract class Learning<A extends Enum>{
// TODO: temp testing -> extract to dedicated test
protected int checkSum;
protected int rewardCheckSum;
// current discrete timestamp t
protected int timestamp;
protected int currentEpisode;
@ -93,7 +99,7 @@ public abstract class Learning<A extends Enum>{
public void save(ObjectOutputStream oos) throws IOException {
oos.writeObject(rewardHistory);
// oos.writeObject(stateActionTable);
oos.writeObject(stateActionTable);
}
public void load(ObjectInputStream ois) throws IOException, ClassNotFoundException {

View File

@ -3,6 +3,8 @@ package core.algo.mc;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
@ -12,7 +14,7 @@ import java.io.ObjectOutputStream;
import java.util.*;
/**
* Includes both! variants of Monte-Carlo methods
* Includes both variants of Monte-Carlo methods
* Default method is First-Visit.
* Change to Every-Visit by setting flag "useEveryVisit" in the constructor to true.
* @param <A>
@ -21,12 +23,19 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
private Map<Pair<State, A>, Double> returnSum;
private Map<Pair<State, A>, Integer> returnCount;
private final boolean isEveryVisit;
private boolean isEveryVisit;
// t
private float epsilon;
// t
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public MonteCarloControlEGreedy(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, int delay, boolean useEveryVisit) {
super(environment, actionSpace, discountFactor, delay);
isEveryVisit = useEveryVisit;
// t
this.epsilon = epsilon;
this.policy = new EpsilonGreedyPolicy<>(epsilon);
this.stateActionTable = new DeterministicStateActionTable<>(this.actionSpace);
returnSum = new HashMap<>();
@ -55,11 +64,17 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
while(envResult == null || !envResult.isDone()) {
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A chosenAction = policy.chooseAction(actionValues);
A chosenAction;
if(currentEpisode % 2 == 1){
chosenAction = greedyPolicy.chooseAction(actionValues);
}else{
chosenAction = policy.chooseAction(actionValues);
}
envResult = environment.step(chosenAction);
State nextState = envResult.getState();
sumOfRewards += envResult.getReward();
rewardCheckSum += envResult.getReward();
episode.add(new StepResult<>(state, chosenAction, envResult.getReward()));
state = nextState;
@ -71,6 +86,11 @@ public class MonteCarloControlEGreedy<A extends Enum> extends EpisodicLearning<A
}
timestamp++;
dispatchStepEnd();
if(converged) return;
}
if(currentEpisode % 2 == 1){
return;
}
// System.out.printf("Episode %d \t Reward: %f \n", currentEpisode, sumOfRewards);
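
The changes above interleave learning and evaluation: even-numbered episodes act ε-greedily and feed the Monte-Carlo return update, while odd-numbered episodes act purely greedily and return before the update, so they only probe current performance. The SARSA and Q-learning hunks below apply the same scheme, with Q-learning running its greedy episodes on even indices instead. A minimal sketch of the action-selection part, assuming the Policy, GreedyPolicy and EpsilonGreedyPolicy types as used in this hunk; the helper class and the epsilon value are hypothetical:

import java.util.Map;

import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;

public class AlternatingPolicy<A extends Enum<A>> {
    private final Policy<A> behaviourPolicy = new EpsilonGreedyPolicy<>(0.1f); // epsilon assumed
    private final Policy<A> greedyPolicy = new GreedyPolicy<>();

    public A chooseAction(int currentEpisode, Map<A, Double> actionValues) {
        if (currentEpisode % 2 == 1) {
            // evaluation episode: follow the greedy policy, skip the learning update afterwards
            return greedyPolicy.chooseAction(actionValues);
        }
        // learning episode: explore epsilon-greedily, then apply the Monte-Carlo update
        return behaviourPolicy.chooseAction(actionValues);
    }
}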

View File

@ -10,7 +10,6 @@ import java.util.Map;
public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearning<A> {
private float alpha;
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public QLearningOffPolicyTDControl(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
@ -34,19 +33,29 @@ public class QLearningOffPolicyTDControl<A extends Enum> extends EpisodicLearnin
}
StepResultEnvironment envResult = null;
Map<A, Double> actionValues;
Map<A, Double> actionValues = null;
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
A action;
if(currentEpisode % 2 == 0) {
action = greedyPolicy.chooseAction(actionValues);
} else {
action = policy.chooseAction(actionValues);
}
if(converged) return;
// Take a step
envResult = environment.step(action);
double reward = envResult.getReward();
State nextState = envResult.getState();
sumOfRewards += reward;
if(currentEpisode % 2 == 0) {
state = nextState;
dispatchStepEnd();
continue;
}
// Q Update
double currentQValue = stateActionTable.getActionValues(state).get(action);
// maxQ(S', a);

View File

@ -3,12 +3,15 @@ package core.algo.td;
import core.*;
import core.algo.EpisodicLearning;
import core.policy.EpsilonGreedyPolicy;
import core.policy.GreedyPolicy;
import core.policy.Policy;
import java.util.Map;
public class SARSA<A extends Enum> extends EpisodicLearning<A> {
private float alpha;
private Policy<A> greedyPolicy = new GreedyPolicy<>();
public SARSA(Environment<A> environment, DiscreteActionSpace<A> actionSpace, float discountFactor, float epsilon, float learningRate, int delay) {
super(environment, actionSpace, discountFactor, delay);
@ -32,11 +35,18 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
StepResultEnvironment envResult = null;
Map<A, Double> actionValues = stateActionTable.getActionValues(state);
A action = policy.chooseAction(actionValues);
sumOfRewards = 0;
A action;
if(currentEpisode % 2 == 1){
action = greedyPolicy.chooseAction(actionValues);
}else{
action = policy.chooseAction(actionValues);
}
//A action = policy.chooseAction(actionValues);
sumOfRewards = 0;
while(envResult == null || !envResult.isDone()) {
if(converged) return;
// Take a step
envResult = environment.step(action);
sumOfRewards += envResult.getReward();
@ -46,8 +56,19 @@ public class SARSA<A extends Enum> extends EpisodicLearning<A> {
// Pick next action
actionValues = stateActionTable.getActionValues(nextState);
A nextAction = policy.chooseAction(actionValues);
A nextAction;
if(currentEpisode % 2 == 1){
nextAction = greedyPolicy.chooseAction(actionValues);
}else{
nextAction = policy.chooseAction(actionValues);
}
//A nextAction = policy.chooseAction(actionValues);
if(currentEpisode % 2 == 1){
state = nextState;
action = nextAction;
dispatchStepEnd();
continue;
}
// td update
// target = reward + gamma * Q(nextState, nextAction)
double currentQValue = stateActionTable.getActionValues(state).get(action);

View File

@ -1,117 +0,0 @@
package core.controller;
import core.algo.Method;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
import evironment.antGame.Constants;
import evironment.blackjack.BlackJackTable;
import evironment.blackjack.PlayerAction;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorld;
import evironment.jumpingDino.DinoWorldAdvanced;
import javax.swing.*;
import java.text.NumberFormat;
public class OpeningDialog {
// JSlider is integer value only. Instead of creating a subclass
// the JSlider value is simply divided by this scale factor.
// 100 does mean 0 to 1 in 0.01 steps.
private int scaleFactor = 100;
public OpeningDialog() {
createStartingDialog();
}
private void setLabelText(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
if(scaledValue) {
label.setText(parameterName + ": " + (float) slider.getValue() / scaleFactor);
} else {
label.setText(parameterName + ": " + slider.getValue());
}
}
private void linkSliderWithLabel(JSlider slider, JLabel label, boolean scaledValue, String parameterName) {
setLabelText(slider, label, scaledValue, parameterName);
slider.addChangeListener(changeEvent -> setLabelText(slider, label, scaledValue, parameterName));
}
private void createStartingDialog() {
JComboBox<Scenario> scenarioSelection = new JComboBox<>(Scenario.values());
JComboBox<Method> algorithmSelection = new JComboBox<>(Method.values());
JSlider delaySlider = new JSlider(1, 1000, 200);
JLabel delayLabel = new JLabel();
linkSliderWithLabel(delaySlider, delayLabel, false, "Delay");
JSlider discountSlider = new JSlider(0, 100, 100);
JLabel discountLabel = new JLabel();
linkSliderWithLabel(discountSlider, discountLabel, true, "Discount Factor (gamma)");
JSlider epsilonSlider = new JSlider(0, 100, 15);
JLabel epsilonLabel = new JLabel();
linkSliderWithLabel(epsilonSlider, epsilonLabel, true, "Exploration Factor (epsilon)");
JSlider learningRateSlider = new JSlider(0, 100, 90);
JLabel learningRateLabel = new JLabel();
linkSliderWithLabel(learningRateSlider, learningRateLabel, true, "Learning rate (alpha)");
JTextField episodesToLearn = new JFormattedTextField(NumberFormat.getIntegerInstance());
episodesToLearn.setText("10000");
JTextField seedTextField = new JFormattedTextField(NumberFormat.getIntegerInstance());
seedTextField.setText("29");
Object[] parameters = {
"Environment:", scenarioSelection,
"Algorithm:", algorithmSelection,
discountLabel, discountSlider,
epsilonLabel, epsilonSlider,
learningRateLabel, learningRateSlider,
delayLabel, delaySlider,
"Episodes to learn:", episodesToLearn,
"RNG Seed: ", seedTextField,
};
int option = JOptionPane.showConfirmDialog(null, parameters, "Learning parameters", JOptionPane.OK_CANCEL_OPTION);
if(option == JOptionPane.OK_OPTION) {
Scenario selectedScenario = (Scenario) scenarioSelection.getSelectedItem();
RLControllerGUI rl;
if(selectedScenario == Scenario.JUMPING_DINO_SIMPLE) {
rl = new RLControllerGUI<DinoAction>(new DinoWorld(), (Method) algorithmSelection.getSelectedItem(), DinoAction.values());
} else if(selectedScenario == Scenario.JUMPING_DINO_ADVANCED) {
rl = new RLControllerGUI<DinoAction>(new DinoWorldAdvanced(), (Method) algorithmSelection.getSelectedItem(), DinoAction.values());
} else if(selectedScenario == Scenario.ANTGAME_ONE_FOOD) {
rl = new RLControllerGUI<AntAction>(new AntWorldContinuous(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, 1), (Method) algorithmSelection.getSelectedItem(), AntAction.values());
} else if(selectedScenario == Scenario.ANTGAME_TWO_FOOD) {
rl = new RLControllerGUI<AntAction>(new AntWorldContinuous(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, 2), (Method) algorithmSelection.getSelectedItem(), AntAction.values());
} else if(selectedScenario == Scenario.BLACKJACK) {
rl = new RLControllerGUI<PlayerAction>(new BlackJackTable(), (Method) algorithmSelection.getSelectedItem(), PlayerAction.values());
} else {
throw new IllegalArgumentException("Invalid learning scenario selected");
}
rl.setDelay(delaySlider.getValue());
rl.setDiscountFactor((float) discountSlider.getValue() / scaleFactor);
rl.setEpsilon((float) epsilonSlider.getValue() / scaleFactor);
rl.setLearningRate((float) learningRateSlider.getValue() / scaleFactor);
rl.setNrOfEpisodes(Integer.parseInt(episodesToLearn.getText()));
rl.start();
} else {
System.out.println("Parameter dialog canceled");
}
}
private enum Scenario {
JUMPING_DINO_SIMPLE,
JUMPING_DINO_ADVANCED,
ANTGAME_ONE_FOOD,
ANTGAME_TWO_FOOD,
BLACKJACK
}
}

View File

@ -39,7 +39,6 @@ public class RLController<A extends Enum> implements LearningListener {
protected int prevDelay;
protected volatile boolean printNextEpisode;
@SafeVarargs
public RLController(Environment<A> env, Method method, A... actions) {
setEnvironment(env);
setMethod(method);
@ -88,7 +87,7 @@ public class RLController<A extends Enum> implements LearningListener {
private void initLearning() {
if(learning instanceof EpisodicLearning) {
System.out.println("Starting learning of <" + nrOfEpisodes + "> episodes");
((EpisodicLearning<A>) learning).learn(nrOfEpisodes);
((EpisodicLearning) learning).learn(nrOfEpisodes);
} else {
learning.learn();
}
@ -101,9 +100,11 @@ public class RLController<A extends Enum> implements LearningListener {
protected void learnMoreEpisodes(int nrOfEpisodes) {
if(learning instanceof EpisodicLearning) {
if(learning.isCurrentlyLearning()){
((EpisodicLearning<A>) learning).learnMoreEpisodes(nrOfEpisodes);
((EpisodicLearning) learning).learnMoreEpisodes(nrOfEpisodes);
}else{
new Thread(() -> ((EpisodicLearning<A>) learning).learn(nrOfEpisodes)).start();
new Thread(() -> {
((EpisodicLearning) learning).learn(nrOfEpisodes);
}).start();
}
} else {
throw new RuntimeException("Triggering onLearnMoreEpisodes on non-episodic learning!");
@ -178,8 +179,8 @@ public class RLController<A extends Enum> implements LearningListener {
public void onEpisodeEnd(List<Double> rewardHistory) {
latestRewardsHistory = rewardHistory;
if(printNextEpisode) {
System.out.println("Episode " + learning.getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning<A>) learning).getEpisodePerSecond());
System.out.println("Episode " + ((EpisodicLearning) learning).getCurrentEpisode() + " Latest Reward: " + rewardHistory.get(rewardHistory.size() - 1));
System.out.println("Eps/sec: " + ((EpisodicLearning) learning).getEpisodePerSecond());
printNextEpisode = false;
}
}

View File

@ -13,12 +13,10 @@ import java.util.List;
public class RLControllerGUI<A extends Enum> extends RLController<A> implements ViewListener {
private LearningView learningView;
@SafeVarargs
public RLControllerGUI(Environment<A> env, Method method, A... actions) {
super(env, method, actions);
}
@Override
protected void initListeners() {
SwingUtilities.invokeLater(() -> {
@ -104,7 +102,7 @@ public class RLControllerGUI<A extends Enum> extends RLController<A> implements
@Override
public void onLearningEnd() {
super.onLearningEnd();
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + learning.getCurrentEpisode() : ""));
onSaveState(method.toString() + System.currentTimeMillis() / 1000 + (learning instanceof EpisodicLearning ? "e" + ((EpisodicLearning) learning).getCurrentEpisode() : ""));
SwingUtilities.invokeLater(() -> learningView.updateRewardGraph(latestRewardsHistory));
}
}

View File

@ -38,7 +38,7 @@ public class LearningInfoPanel extends JPanel {
episodeLabel = new JLabel();
add(episodeLabel);
}
delaySlider = new JSlider(1, 1000, learning.getDelay());
delaySlider = new JSlider(0, 1000, learning.getDelay());
delaySlider.addChangeListener(e -> viewListener.onDelayChange(delaySlider.getValue()));
add(policyLabel);
add(discountLabel);
@ -60,9 +60,9 @@ public class LearningInfoPanel extends JPanel {
viewListener.onFastLearnChange(fastLearning);
});
smoothGraphCheckbox = new JCheckBox("Smoothen Graph");
smoothGraphCheckbox.setSelected(true);
smoothGraphCheckbox.setSelected(false);
last100Checkbox = new JCheckBox("Only show last 100 Rewards");
last100Checkbox.setSelected(false);
last100Checkbox.setSelected(true);
drawEnvironmentCheckbox = new JCheckBox("Update Environment");
drawEnvironmentCheckbox.setSelected(true);
@ -85,7 +85,9 @@ public class LearningInfoPanel extends JPanel {
add(learnMoreEpisodesButton);
}
showQTableButton = new JButton("Show Q-Table");
showQTableButton.addActionListener(e -> viewListener.onShowQTable());
showQTableButton.addActionListener(e -> {
viewListener.onShowQTable();
});
add(drawEnvironmentCheckbox);
add(smoothGraphCheckbox);
add(last100Checkbox);

View File

@ -1,6 +1,7 @@
package evironment.antGame;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;

View File

@ -86,12 +86,12 @@ public class AntState implements State, Visualizable {
public JComponent visualize() {
return new JScrollPane() {
private int cellSize;
private final int paneWidth = 500;
private final int paneHeight = 500;
private Font font;
{
int paneWidth = 500;
int paneHeight = 500;
setPreferredSize(new Dimension(paneWidth, paneHeight));
cellSize = (paneWidth - knownWorld.length) / knownWorld.length;
cellSize = (paneWidth- knownWorld.length) /knownWorld.length;
font = new Font("plain", Font.BOLD, cellSize);
JPanel worldPanel = new JPanel(){
{

View File

@ -36,14 +36,12 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
* various lectures could be possible as well.
*/
protected AntAgent antAgent;
protected int numberOfConcurrentFood;
protected int tick;
private int maxEpisodeTicks;
public AntWorld(int width, int height, int numberOfConcurrentFood) {
this.numberOfConcurrentFood = numberOfConcurrentFood;
grid = new Grid(width, height, numberOfConcurrentFood);
public AntWorld(int width, int height) {
grid = new Grid(width, height);
antAgent = new AntAgent(width, height);
myAnt = new Ant();
maxEpisodeTicks = 1000;
@ -51,12 +49,12 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
}
public AntWorld(){
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_CONCURRENT_FOOD);
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT);
}
protected StepCalculation processStep(AntAction action) {
StepCalculation sc = new StepCalculation();
sc.reward = Reward.DEFAULT_REWARD;
sc.reward = -1;
sc.info = "";
sc.done = false;
Cell currentCell = grid.getCell(myAnt.getPos());
@ -64,7 +62,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
sc.stayOnCell = true;
// flag to enable a check if all food has been collected only fired if food was dropped
// on the starting position
sc.foodCollected = false;
sc.checkCompletion = false;
switch(action) {
case MOVE_UP:
@ -86,10 +84,10 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
case PICK_UP:
if(myAnt.hasFood()) {
// Ant tries to pick up food but can only hold one piece
sc.reward = Reward.FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY;
sc.reward += Reward.FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY;
} else if(currentCell.getFood() == 0) {
// Ant tries to pick up food on cell that has no food on it
sc.reward = Reward.FOOD_PICK_UP_FAIL_NO_FOOD;
sc.reward += Reward.FOOD_PICK_UP_FAIL_NO_FOOD;
} else if(currentCell.getFood() > 0) {
// Ant successfully picks up food
currentCell.setFood(currentCell.getFood() - 1);
@ -100,19 +98,19 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
case DROP_DOWN:
if(!myAnt.hasFood()) {
// Ant had no food to drop
sc.reward = Reward.FOOD_DROP_DOWN_FAIL_NO_FOOD;
sc.reward += Reward.FOOD_DROP_DOWN_FAIL_NO_FOOD;
} else {
myAnt.setHasFood(false);
// negative reward if the agent drops food on any other field
// than the starting point
if(currentCell.getType() != CellType.START) {
sc.reward = Reward.FOOD_DROP_DOWN_FAIL_NOT_START;
sc.reward += Reward.FOOD_DROP_DOWN_FAIL_NOT_START;
// Drop food onto the ground
currentCell.setFood(currentCell.getFood() + 1);
} else {
sc.reward = Reward.FOOD_DROP_DOWN_SUCCESS;
myAnt.setPoints(myAnt.getPoints() + 1);
sc.foodCollected = true;
sc.checkCompletion = true;
}
}
break;
@ -124,10 +122,10 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
if(!sc.stayOnCell) {
if(!isInGrid(sc.potentialNextPos)) {
sc.stayOnCell = true;
sc.reward = Reward.RAN_INTO_WALL;
sc.reward += Reward.RAN_INTO_WALL;
} else if(hitObstacle(sc.potentialNextPos)) {
sc.stayOnCell = true;
sc.reward = Reward.RAN_INTO_OBSTACLE;
sc.reward += Reward.RAN_INTO_OBSTACLE;
}
}
@ -141,10 +139,16 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
// valid movement
if(!sc.stayOnCell) {
myAnt.getPos().setLocation(sc.potentialNextPos);
antAgent.getCell(myAnt.getPos());// the ant will move to a cell that was previously unknown
if(antAgent.getCell(myAnt.getPos()).getType() == CellType.UNKNOWN){
// the ant will move to a cell that was previously unknown
// TODO: not optimal for going straight for food
// sc.reward = Reward.UNKNOWN_FIELD_EXPLORED;
}
}
if(sc.foodCollected) {
if(sc.checkCompletion) {
sc.done = grid.isAllFoodCollected();
}
@ -179,7 +183,7 @@ public class AntWorld implements Environment<AntAction>, Visualizable {
boolean stayOnCell = true;
// flag to enable a check if all food has been collected only fired if food was dropped
// on the starting position
boolean foodCollected = false;
boolean checkCompletion = false;
}
public State reset() {
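
With the switch from sc.reward = ... to sc.reward += ... above, penalties now stack on top of the per-step base of -1 instead of replacing it (the successful food drop still overwrites the reward). A small arithmetic check using the new constants from the Reward.java hunk further down; the demo class is hypothetical:

import evironment.antGame.Reward;

public class RewardShapingCheck {
    public static void main(String[] args) {
        double base = -1;  // default per-step reward set at the top of processStep()

        double ranIntoWall = base + Reward.RAN_INTO_WALL;                   // -1 + (-1) = -2
        double pickUpWithoutFood = base + Reward.FOOD_PICK_UP_FAIL_NO_FOOD; // -1 + (-1) = -2

        System.out.println("ran into wall: " + ranIntoWall);
        System.out.println("pick up on empty cell: " + pickUpWithoutFood);
    }
}

The totals match the old flat -2 penalties; the difference is that the base step cost is always included and the shaping terms are added rather than substituted.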

View File

@ -4,8 +4,8 @@ import core.State;
import core.StepResultEnvironment;
public class AntWorldContinuous extends AntWorld {
public AntWorldContinuous(int width, int height, int numberOfConcurrentFood) {
super(width, height, numberOfConcurrentFood);
public AntWorldContinuous(int width, int height) {
super(width, height);
}
public AntWorldContinuous() {
@ -14,15 +14,13 @@ public class AntWorldContinuous extends AntWorld {
@Override
public StepResultEnvironment step(AntAction action) {
Cell currentCell = grid.getCell(myAnt.getPos());
StepCalculation sc = processStep(action);
// flag is set to true if food gets dropped onto starts
if(sc.foodCollected) {
grid.removeAllFood();
System.out.println(numberOfConcurrentFood);
for(int i = 0; i < numberOfConcurrentFood; ++i) {
grid.spawnNewFood();
}
if(sc.checkCompletion) {
grid.spawnNewFood();
}
// valid movement
if(!sc.stayOnCell) {

View File

@ -3,8 +3,8 @@ package evironment.antGame;
import core.State;
public class AntWorldContinuousOriginalState extends AntWorldContinuous {
public AntWorldContinuousOriginalState(int width, int height, int numberOfConcurrentFood) {
super(width, height, numberOfConcurrentFood);
public AntWorldContinuousOriginalState(int width, int height) {
super(width, height);
}
public AntWorldContinuousOriginalState() {

View File

@ -1,9 +1,8 @@
package evironment.antGame;
public class Constants {
public static final int DEFAULT_CONCURRENT_FOOD = 1;
public static final int DEFAULT_GRID_WIDTH = 8;
public static final int DEFAULT_GRID_HEIGHT = 8;
public static final int DEFAULT_GRID_WIDTH = 5;
public static final int DEFAULT_GRID_HEIGHT = 5;
public static final int START_X = 5;
public static final int START_Y = 2;
}

View File

@ -10,12 +10,10 @@ public class Grid {
private Point start;
private Cell[][] grid;
private Cell[][] initialGrid;
private int numberOfConcurrentFood;
public Grid(int width, int height, int numberOfConcurrentFood) {
public Grid(int width, int height) {
this.width = width;
this.height = height;
this.numberOfConcurrentFood = numberOfConcurrentFood;
grid = new Cell[width][height];
initialGrid = new Cell[width][height];
start = new Point(Constants.START_X, Constants.START_Y);
@ -32,11 +30,11 @@ public class Grid {
initialGrid[x][y] = new Cell(new Point(x, y), CellType.FREE);
}
}
spawnNewFood(initialGrid);
spawnObstacles();
initialGrid[start.x][start.y] = new Cell(new Point(start.x, start.y), CellType.START);
for(int i = 0; i < numberOfConcurrentFood; ++i) {
spawnNewFood(initialGrid);
}
;
}
//TODO
@ -62,8 +60,8 @@ public class Grid {
Point potFood = new Point(0, 0);
CellType potFieldType;
while(!foodSpawned) {
potFood.x = RNG.getRandomEnv().nextInt(width);
potFood.y = RNG.getRandomEnv().nextInt(height);
potFood.x = RNG.getEnvRandom().nextInt(width);
potFood.y = RNG.getEnvRandom().nextInt(height);
potFieldType = grid[potFood.x][potFood.y].getType();
if(potFieldType != CellType.START && grid[potFood.x][potFood.y].getFood() == 0 && potFieldType != CellType.OBSTACLE) {
grid[potFood.x][potFood.y].setFood(1);
@ -74,14 +72,6 @@ public class Grid {
}
}
public void removeAllFood() {
for(int x = 0; x < width; ++x) {
for(int y = 0; y < height; ++y) {
grid[x][y].setFood(0);
}
}
}
public void spawnNewFood() {
spawnNewFood(grid);
}

View File

@ -1,17 +1,16 @@
package evironment.antGame;
public class Reward {
public static final double DEFAULT_REWARD = -1;
public static final double FOOD_PICK_UP_SUCCESS = 0;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -2;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -2;
public static final double FOOD_PICK_UP_SUCCESS = 1;
public static final double FOOD_PICK_UP_FAIL_NO_FOOD = -1;
public static final double FOOD_PICK_UP_FAIL_HAS_FOOD_ALREADY = -1;
public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -2;
public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -2;
public static final double FOOD_DROP_DOWN_SUCCESS = 1;
public static final double FOOD_DROP_DOWN_FAIL_NO_FOOD = -1;
public static final double FOOD_DROP_DOWN_FAIL_NOT_START = -1;
public static final double FOOD_DROP_DOWN_SUCCESS = 40;
public static final double UNKNOWN_FIELD_EXPLORED = 0;
public static final double RAN_INTO_WALL = -2;
public static final double RAN_INTO_OBSTACLE = -2;
public static final double RAN_INTO_WALL = -1;
public static final double RAN_INTO_OBSTACLE = -1;
}

View File

@ -7,8 +7,10 @@ import javax.swing.*;
import java.awt.*;
public class AntWorldComponent extends JComponent {
private AntWorld antWorld;
public AntWorldComponent(AntWorld antWorld, AntAgent antAgent){
this.antWorld = antWorld;
setLayout(new BorderLayout());
CellsScrollPane worldPane = new CellsScrollPane(antWorld.getCellArray(), antWorld.getAnt(), 10);
CellsScrollPane antBrainPane = new CellsScrollPane(antAgent.getKnownWorld(), antWorld.getAnt(), 10);

View File

@ -29,6 +29,6 @@ public class CardDeck {
nextInt(int bound) returns random int value from (inclusive) 0
and EXCLUSIVE! bound
*/
return cards.get(RNG.getRandomEnv().nextInt(cards.size()));
return cards.get(RNG.getEnvRandom().nextInt(cards.size()));
}
}

View File

@ -2,20 +2,21 @@ package evironment.jumpingDino;
import core.State;
import core.gui.Visualizable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import javax.swing.*;
import java.awt.*;
import java.io.Serializable;
import java.util.Objects;
@AllArgsConstructor
@Getter
public class DinoState extends DinoStateSimple implements State, Serializable, Visualizable {
public class DinoState implements State, Serializable, Visualizable {
private int xDistanceToObstacle;
private boolean isJumping;
public DinoState(int xDistanceToObstacle, boolean isJumping) {
super(xDistanceToObstacle);
this.isJumping = isJumping;
}
protected final double scale = 0.5;
@Override
public String toString() {
@ -39,15 +40,29 @@ public class DinoState extends DinoStateSimple implements State, Serializable, V
}
@Override
protected void drawDinoInfo(Graphics g) {
int dinoY;
if(!isJumping) {
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE;
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
} else {
dinoY = Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE - (int) (scale * Config.MAX_JUMP_HEIGHT);
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
}
g.drawString("Distance: " + xDistanceToObstacle + " inJump: " + isJumping, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (dinoY - 20)));
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int)(scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@Override
protected void paintComponent(Graphics g) {
super.paintComponents(g);
drawObjects(g);
}
};
}
public void drawObjects(Graphics g){
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int)(scale * Config.DINO_STARTING_X), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int)(scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int)(scale * Config.DINO_STARTING_X),(int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40) ));
g.fillRect((int)(scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int)(scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int)(scale * Config.OBSTACLE_SIZE), (int)(scale *Config.OBSTACLE_SIZE));
}
}

View File

@ -14,7 +14,7 @@ import java.util.Objects;
@Getter
public class DinoStateSimple implements State, Serializable, Visualizable {
protected final double scale = 0.5;
protected int xDistanceToObstacle;
private int xDistanceToObstacle;
@Override
public String toString() {
@ -40,7 +40,7 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
public JComponent visualize() {
return new JComponent() {
{
setPreferredSize(new Dimension((int) (scale * Config.FRAME_WIDTH), (int) (scale * Config.FRAME_HEIGHT)));
setPreferredSize(new Dimension(Config.FRAME_WIDTH, (int) (scale * Config.FRAME_HEIGHT)));
setVisible(true);
}
@ -52,15 +52,14 @@ public class DinoStateSimple implements State, Serializable, Visualizable {
};
}
protected void drawDinoInfo(Graphics g) {
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
}
public void drawObjects(Graphics g) {
g.setColor(Color.BLACK);
g.fillRect(0, (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y)), Config.FRAME_WIDTH, 2);
g.fillRect((int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.DINO_SIZE)), (int) (scale * Config.DINO_SIZE), (int) (scale * Config.DINO_SIZE));
g.drawString("Distance: " + xDistanceToObstacle, (int) (scale * Config.DINO_STARTING_X), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE - 40)));
g.fillRect((int) (scale * (Config.DINO_STARTING_X + getXDistanceToObstacle())), (int) (scale * (Config.FRAME_HEIGHT - Config.GROUND_Y - Config.OBSTACLE_SIZE)), (int) (scale * Config.OBSTACLE_SIZE), (int) (scale * Config.OBSTACLE_SIZE));
drawDinoInfo(g);
}
}

View File

@ -49,6 +49,19 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
if(action == DinoAction.JUMP){
dino.jump();
}
// for(int i= 0; i < 5; ++i){
// dino.tick();
// currentObstacle.tick();
// if(currentObstacle.getX() < -Config.OBSTACLE_SIZE){
// spawnNewObstacle();
// }
// comp.repaint();
// if(ranIntoObstacle()){
// done = true;
// break;
// }
// }
dino.tick();
currentObstacle.tick();
if(currentObstacle.getX() < -Config.OBSTACLE_SIZE) {
@ -63,7 +76,7 @@ public class DinoWorld implements Environment<DinoAction>, Visualizable {
}
protected State generateReturnState(){
return new DinoState(getDistanceToObstacle(), dino.isInJump());
return new DinoStateSimple(getDistanceToObstacle());
}
protected void spawnNewObstacle(){

View File

@ -31,7 +31,7 @@ public class DinoWorldAdvanced extends DinoWorld{
protected void spawnNewObstacle() {
int dx;
int xSpawn;
double ran = RNG.getRandomEnv().nextDouble();
double ran = RNG.getEnvRandom().nextDouble();
if(ran < 0.25){
dx = -(int) (0.35 * Config.OBSTACLE_SPEED);
}else if(ran < 0.5){
@ -41,7 +41,7 @@ public class DinoWorldAdvanced extends DinoWorld{
} else{
dx = -(int) (3.5 * Config.OBSTACLE_SPEED);
}
double ran2 = RNG.getRandomEnv().nextDouble();
double ran2 = RNG.getEnvRandom().nextDouble();
if(ran2 < 0.25) {
// randomly spawning more right outside of the screen
xSpawn = Config.FRAME_WIDTH + Config.FRAME_WIDTH + Config.OBSTACLE_SIZE;

View File

@ -9,7 +9,7 @@ import evironment.blackjack.PlayerAction;
public class BlackJack {
public static void main(String[] args) {
RNG.setSeed(55);
RNG.setSeed(55, true);
RLController<PlayerAction> rl = new RLControllerGUI<>(
new BlackJackTable(),

View File

@ -1,24 +0,0 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.antGame.AntAction;
import evironment.antGame.AntWorldContinuous;
public class ContinuousAnt {
public static void main(String[] args) {
RNG.setSeed(13, true);
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorldContinuous(8, 8, 1),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());
rl.setDelay(200);
rl.setNrOfEpisodes(1);
rl.setDiscountFactor(0.3f);
rl.setLearningRate(0.9f);
rl.setEpsilon(0.15f);
rl.start();
}
}

View File

@ -0,0 +1,52 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorldAdvanced;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
public class DinoSampling {
public static final String FILE_NAME = "advancedEveryVisit.txt";
public static void main(String[] args) {
File file = new File(FILE_NAME);
try {
file.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
for(float f = 0.05f; f <= 1.003; f += 0.05f) {
try {
Files.writeString(Path.of(file.getPath()), f + ",", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
for(int i = 1; i <= 100; i++) {
System.out.println("seed: " + i * 13);
RNG.setSeed(i * 13, true);
RLController<DinoAction> rl = new RLController<>(
new DinoWorldAdvanced(),
Method.MC_CONTROL_EVERY_VISIT,
DinoAction.values());
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(f);
rl.setLearningRate(1f);
rl.setNrOfEpisodes(400000);
rl.start();
}
try {
Files.writeString(Path.of(file.getPath()), "\n", StandardOpenOption.APPEND);
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
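
DinoSampling above sweeps the exploration factor ε from 0.05 to 1.0 in steps of 0.05 and runs 100 seeds (13, 26, ...) per value; together with the convergence hook in EpisodicLearning this produces one row per ε, starting with the ε value followed by per-seed episode counts, which is the shape of the CSV-like rows near the top of this compare. A small reader sketch for such a file; the file name comes from the diff, while the class and the aggregation are hypothetical, and blank fields (which occur in the recorded data) are skipped:

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;

public class DinoSamplingReader {
    public static void main(String[] args) throws IOException {
        List<String> rows = Files.readAllLines(Path.of("advancedEveryVisit.txt"));
        for (String row : rows) {
            String[] fields = row.split(",");
            double epsilon = Double.parseDouble(fields[0]);
            long sum = 0;
            int runs = 0;
            for (int i = 1; i < fields.length; i++) {
                if (fields[i].isBlank()) {
                    continue; // skip occasional empty fields
                }
                sum += Long.parseLong(fields[i]);
                runs++;
            }
            if (runs == 0) {
                continue; // row without any recorded runs
            }
            System.out.printf("epsilon=%.2f  mean episodes to convergence=%.1f  (%d runs)%n",
                    epsilon, (double) sum / runs, runs);
        }
    }
}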

View File

@ -1,9 +0,0 @@
package example;
import core.controller.OpeningDialog;
public class GUIMain {
public static void main(String[] args) {
new OpeningDialog();
}
}

View File

@ -3,20 +3,19 @@ package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.jumpingDino.DinoAction;
import evironment.jumpingDino.DinoWorldAdvanced;
public class JumpingDino {
public static void main(String[] args) {
RNG.setSeed(29);
RNG.setSeed(29, true);
RLController<DinoAction> rl = new RLControllerGUI<>(
RLController<DinoAction> rl = new RLController<>(
new DinoWorldAdvanced(),
Method.MC_CONTROL_EVERY_VISIT,
Method.MC_CONTROL_FIRST_VISIT,
DinoAction.values());
rl.setDelay(200);
rl.setDelay(0);
rl.setDiscountFactor(1f);
rl.setEpsilon(0.05f);
rl.setLearningRate(1f);

View File

@ -9,17 +9,16 @@ import evironment.antGame.AntWorld;
public class RunningAnt {
public static void main(String[] args) {
RNG.setSeed(56);
RNG.setSeed(56, true);
RLController<AntAction> rl = new RLControllerGUI<>(
new AntWorld(8, 8, 1),
new AntWorld(8, 8),
Method.Q_LEARNING_OFF_POLICY_CONTROL,
AntAction.values());
rl.setDelay(200);
rl.setNrOfEpisodes(10000);
rl.setDiscountFactor(0.9f);
rl.setLearningRate(0.9f);
rl.setEpsilon(0.15f);
rl.start();
}