diff --git a/.idea/refo.iml b/.idea/refo.iml
index dbef0c9..9ec6aa6 100644
--- a/.idea/refo.iml
+++ b/.idea/refo.iml
@@ -5,6 +5,7 @@
+
diff --git a/src/main/java/core/Action.java b/src/main/java/core/Action.java
deleted file mode 100644
index 86c7c18..0000000
--- a/src/main/java/core/Action.java
+++ /dev/null
@@ -1,8 +0,0 @@
-package core;
-
-public interface Action {
- int getIndex();
- String toString();
- int hashCode();
- boolean equals(Object obj);
-}
diff --git a/src/main/java/core/ActionSpace.java b/src/main/java/core/ActionSpace.java
deleted file mode 100644
index 7fe9446..0000000
--- a/src/main/java/core/ActionSpace.java
+++ /dev/null
@@ -1,9 +0,0 @@
-package core;
-
-import java.util.List;
-
-public interface ActionSpace {
- int getNumberOfAction();
- void addAction(DiscreteAction a);
- void addActions(A[] as);
-}
diff --git a/src/main/java/core/DiscreteAction.java b/src/main/java/core/DiscreteAction.java
deleted file mode 100644
index 52a8c0a..0000000
--- a/src/main/java/core/DiscreteAction.java
+++ /dev/null
@@ -1,36 +0,0 @@
-package core;
-
-public class DiscreteAction implements Action{
- private A action;
-
- public DiscreteAction(A action){
- this.action = action;
- }
-
- public A getValue(){
- return action;
- }
-
- @Override
- public int getIndex(){
- return action.ordinal();
- }
-
- @Override
- public String toString(){
- return action.toString();
- }
-
- @Override
- public int hashCode() {
- return getIndex();
- }
-
- @Override
- public boolean equals(Object obj) {
- if(obj instanceof DiscreteAction){
- return getIndex() == ((DiscreteAction) obj).getIndex();
- }
- return super.equals(obj);
- }
-}
diff --git a/src/main/java/core/DiscreteActionSpace.java b/src/main/java/core/DiscreteActionSpace.java
index fa722b9..a6b38fe 100644
--- a/src/main/java/core/DiscreteActionSpace.java
+++ b/src/main/java/core/DiscreteActionSpace.java
@@ -1,33 +1,10 @@
package core;
-import java.util.ArrayList;
import java.util.List;
-public class DiscreteActionSpace implements ActionSpace{
- private List> actions;
-
- public DiscreteActionSpace(){
- actions = new ArrayList<>();
- }
-
- @Override
- public void addAction(DiscreteAction action){
- actions.add(action);
- }
-
- @Override
- public void addActions(A[] as) {
- for(A a : as){
- actions.add(new DiscreteAction<>(a));
- }
- }
-
- @Override
- public int getNumberOfAction(){
- return actions.size();
- }
-
- public List> getAllDiscreteActions(){
- return actions;
- }
+public interface DiscreteActionSpace { // a set of actions of type A that can be counted, extended and listed
+ int getNumberOfAction(); // how many actions this space currently holds
+ void addAction(A a); // adds a single action to the space
+ void addActions(A... as); // adds several actions in one call
+ List getAllActions(); // every action of this space as a list
}
diff --git a/src/main/java/core/Environment.java b/src/main/java/core/Environment.java
new file mode 100644
index 0000000..a93d1bb
--- /dev/null
+++ b/src/main/java/core/Environment.java
@@ -0,0 +1,5 @@
+package core;
+
+public interface Environment { // an environment exposes one operation, step, driven by an action of type A
+ StepResult step(A action); // takes an action and returns the outcome as a StepResult
+}
diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java
new file mode 100644
index 0000000..76babaf
--- /dev/null
+++ b/src/main/java/core/ListDiscreteActionSpace.java
@@ -0,0 +1,38 @@
+package core;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+public class ListDiscreteActionSpace implements DiscreteActionSpace { // DiscreteActionSpace backed by an ArrayList; actions keep the order they were added in
+    private final List actions; // assigned exactly once by whichever constructor runs
+    public ListDiscreteActionSpace(){ // creates a space with no actions
+        actions = new ArrayList<>();
+    }
+
+    @SafeVarargs // safe: the varargs array is only read and copied, never stored or written to
+    public ListDiscreteActionSpace(A... actions){ // creates a space pre-filled with the given actions
+        this.actions = new ArrayList<>(Arrays.asList(actions)); // copy into a growable list; Arrays.asList alone is fixed-size
+    }
+
+    @Override
+    public void addAction(A action){ // appends one action to the end of the list
+        actions.add(action);
+    }
+
+    @SafeVarargs
+    @Override
+    public final void addActions(A... as) { // appends all given actions, in argument order
+        actions.addAll(Arrays.asList(as));
+    }
+
+    @Override
+    public List getAllActions() { // returns the backing list itself, not a copy
+        return actions;
+    }
+
+    @Override
+    public int getNumberOfAction(){ // current number of stored actions
+        return actions.size();
+    }
+}
diff --git a/src/main/java/core/StateActionHashTable.java b/src/main/java/core/StateActionHashTable.java
index 4d7b6a8..66b6ae0 100644
--- a/src/main/java/core/StateActionHashTable.java
+++ b/src/main/java/core/StateActionHashTable.java
@@ -1,19 +1,21 @@
package core;
+import evironment.antGame.AntAction;
+
import java.util.HashMap;
import java.util.Map;
/**
* Premise: All states have the complete action space
*/
-public class StateActionHashTable implements StateActionTable {
+public class StateActionHashTable implements StateActionTable {
- private final Map> table;
- private ActionSpace actionSpace;
+ private final Map> table;
+ private DiscreteActionSpace discreteActionSpace;
- public StateActionHashTable(ActionSpace actionSpace){
+ public StateActionHashTable(DiscreteActionSpace discreteActionSpace){ // creates an empty table over the given action space
table = new HashMap<>();
- this.actionSpace = actionSpace;
+ this.discreteActionSpace = discreteActionSpace; // kept so createDefaultActionValues can enumerate the actions
}
/*
@@ -25,8 +27,8 @@ public class StateActionHashTable implements StateActionTable {
method.
*/
@Override
- public double getValue(State state, Action action) {
- final Map actionValues = table.get(state);
+ public double getValue(State state, A action) {
+ final Map actionValues = table.get(state);
if (actionValues != null) {
return actionValues.get(action);
}
@@ -40,8 +42,8 @@ public class StateActionHashTable implements StateActionTable {
from the action space initialized with the default value.
*/
@Override
- public void setValue(State state, Action action, double value) {
- final Map actionValues;
+ public void setValue(State state, A action, double value) {
+ final Map actionValues;
if (table.containsKey(state)) {
actionValues = table.get(state);
} else {
@@ -52,15 +54,27 @@ public class StateActionHashTable implements StateActionTable {
}
@Override
- public Map getActionValues(State state) {
- return null;
+ public Map getActionValues(State state) { // action-to-value map for state; never returns null
+ Map actionValues = table.get(state);
+ if(actionValues == null){ // nothing stored for this state yet
+ actionValues = createDefaultActionValues(); // fresh default map; it is not put into table here
+ }
+ return actionValues;
}
- private Map createDefaultActionValues(){
- final Map defaultActionValues = new HashMap<>();
- // for(Action action: actionSpace.getAllActions()){
- // defaultActionValues.put(action, DEFAULT_VALUE);
- //}
+ public static void main(String[] args) { // ad-hoc demo; NOTE(review): ties core to evironment.antGame via AntAction - consider moving it to a test
+ DiscreteActionSpace da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP); // action space holding two ant actions
+ StateActionTable sat = new StateActionHashTable<>(da);
+ State t = new State() { // anonymous State with an empty body
+ };
+
+ System.out.println(sat.getActionValues(t)); // t was never stored, so this prints the default map built from da
+ }
+ private Map createDefaultActionValues(){ // builds a new map that gives every action of the action space DEFAULT_VALUE
+ final Map defaultActionValues = new HashMap<>();
+ for(A action: discreteActionSpace.getAllActions()){ // one entry per action in the space
+ defaultActionValues.put(action, DEFAULT_VALUE);
+ }
return defaultActionValues;
}
}
diff --git a/src/main/java/core/StateActionTable.java b/src/main/java/core/StateActionTable.java
index 7fecafb..851ffba 100644
--- a/src/main/java/core/StateActionTable.java
+++ b/src/main/java/core/StateActionTable.java
@@ -2,11 +2,11 @@ package core;
import java.util.Map;
-public interface StateActionTable {
+public interface StateActionTable { // table of double values addressed by a State and an action of type A
double DEFAULT_VALUE = 0.0;
- double getValue(State state, Action action);
- void setValue(State state, Action action, double value);
+ double getValue(State state, A action); // reads the value for the given state/action pair
+ void setValue(State state, A action, double value); // writes the value for the given state/action pair
- Map getActionValues(State state);
+ Map getActionValues(State state); // map from action to value for the given state
}
diff --git a/src/main/java/core/StepResult.java b/src/main/java/core/StepResult.java
index 89c3bf9..4ed1586 100644
--- a/src/main/java/core/StepResult.java
+++ b/src/main/java/core/StepResult.java
@@ -8,7 +8,7 @@ import lombok.Setter;
@Setter
@AllArgsConstructor
public class StepResult {
- private State observation;
+ private State state;
private double reward;
private boolean done;
private String info;
diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java
index 9fe900a..ef305e0 100644
--- a/src/main/java/evironment/antGame/AntWorld.java
+++ b/src/main/java/evironment/antGame/AntWorld.java
@@ -3,10 +3,11 @@ package evironment.antGame;
import core.*;
import evironment.antGame.gui.MainFrame;
+
import javax.swing.*;
import java.awt.*;
-public class AntWorld {
+public class AntWorld implements Environment{
/**
*
*/
@@ -53,7 +54,8 @@ public class AntWorld {
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
}
- public StepResult step(DiscreteAction action){
+ @Override
+ public StepResult step(AntAction action){
AntObservation observation;
State newState;
double reward = 0;
@@ -77,7 +79,7 @@ public class AntWorld {
// on the starting position
boolean checkCompletion = false;
- switch (action.getValue()) {
+ switch (action) {
case MOVE_UP:
potentialNextPos.y -= 1;
stayOnCell = false;
@@ -208,13 +210,13 @@ public class AntWorld {
public static void main(String[] args) {
RNG.setSeed(1993);
AntWorld a = new AntWorld(10, 10, 0.1);
- DiscreteActionSpace actionSpace = new DiscreteActionSpace<>();
- actionSpace.addActions(AntAction.values());
+ ListDiscreteActionSpace actionSpace =
+ new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
for(int i = 0; i< 1000; ++i){
- DiscreteAction selectedAction = actionSpace.getAllDiscreteActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
+ AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
StepResult step = a.step(selectedAction);
- SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction.getValue(), step));
+ SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
try {
Thread.sleep(100);
} catch (InterruptedException e) {