diff --git a/.idea/refo.iml b/.idea/refo.iml index dbef0c9..9ec6aa6 100644 --- a/.idea/refo.iml +++ b/.idea/refo.iml @@ -5,6 +5,7 @@ + diff --git a/src/main/java/core/Action.java b/src/main/java/core/Action.java deleted file mode 100644 index 86c7c18..0000000 --- a/src/main/java/core/Action.java +++ /dev/null @@ -1,8 +0,0 @@ -package core; - -public interface Action { - int getIndex(); - String toString(); - int hashCode(); - boolean equals(Object obj); -} diff --git a/src/main/java/core/ActionSpace.java b/src/main/java/core/ActionSpace.java deleted file mode 100644 index 7fe9446..0000000 --- a/src/main/java/core/ActionSpace.java +++ /dev/null @@ -1,9 +0,0 @@ -package core; - -import java.util.List; - -public interface ActionSpace { - int getNumberOfAction(); - void addAction(DiscreteAction a); - void addActions(A[] as); -} diff --git a/src/main/java/core/DiscreteAction.java b/src/main/java/core/DiscreteAction.java deleted file mode 100644 index 52a8c0a..0000000 --- a/src/main/java/core/DiscreteAction.java +++ /dev/null @@ -1,36 +0,0 @@ -package core; - -public class DiscreteAction implements Action{ - private A action; - - public DiscreteAction(A action){ - this.action = action; - } - - public A getValue(){ - return action; - } - - @Override - public int getIndex(){ - return action.ordinal(); - } - - @Override - public String toString(){ - return action.toString(); - } - - @Override - public int hashCode() { - return getIndex(); - } - - @Override - public boolean equals(Object obj) { - if(obj instanceof DiscreteAction){ - return getIndex() == ((DiscreteAction) obj).getIndex(); - } - return super.equals(obj); - } -} diff --git a/src/main/java/core/DiscreteActionSpace.java b/src/main/java/core/DiscreteActionSpace.java index fa722b9..a6b38fe 100644 --- a/src/main/java/core/DiscreteActionSpace.java +++ b/src/main/java/core/DiscreteActionSpace.java @@ -1,33 +1,10 @@ package core; -import java.util.ArrayList; import java.util.List; -public class DiscreteActionSpace implements ActionSpace{ - private List> actions; - - public DiscreteActionSpace(){ - actions = new ArrayList<>(); - } - - @Override - public void addAction(DiscreteAction action){ - actions.add(action); - } - - @Override - public void addActions(A[] as) { - for(A a : as){ - actions.add(new DiscreteAction<>(a)); - } - } - - @Override - public int getNumberOfAction(){ - return actions.size(); - } - - public List> getAllDiscreteActions(){ - return actions; - } +public interface DiscreteActionSpace { + int getNumberOfAction(); + void addAction(A a); + void addActions(A... as); + List getAllActions(); } diff --git a/src/main/java/core/Environment.java b/src/main/java/core/Environment.java new file mode 100644 index 0000000..a93d1bb --- /dev/null +++ b/src/main/java/core/Environment.java @@ -0,0 +1,5 @@ +package core; + +public interface Environment { + StepResult step(A action); +} diff --git a/src/main/java/core/ListDiscreteActionSpace.java b/src/main/java/core/ListDiscreteActionSpace.java new file mode 100644 index 0000000..76babaf --- /dev/null +++ b/src/main/java/core/ListDiscreteActionSpace.java @@ -0,0 +1,38 @@ +package core; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +public class ListDiscreteActionSpace implements DiscreteActionSpace { + private List actions; + + public ListDiscreteActionSpace(){ + actions = new ArrayList<>(); + } + + public ListDiscreteActionSpace(A... actions){ + this.actions = new ArrayList<>(Arrays.asList(actions)); + } + + @Override + public void addAction(A action){ + actions.add(action); + } + + @SafeVarargs + @Override + public final void addActions(A... as) { + actions.addAll(Arrays.asList(as)); + } + + @Override + public List getAllActions() { + return actions; + } + + @Override + public int getNumberOfAction(){ + return actions.size(); + } +} diff --git a/src/main/java/core/StateActionHashTable.java b/src/main/java/core/StateActionHashTable.java index 4d7b6a8..66b6ae0 100644 --- a/src/main/java/core/StateActionHashTable.java +++ b/src/main/java/core/StateActionHashTable.java @@ -1,19 +1,21 @@ package core; +import evironment.antGame.AntAction; + import java.util.HashMap; import java.util.Map; /** * Premise: All states have the complete action space */ -public class StateActionHashTable implements StateActionTable { +public class StateActionHashTable implements StateActionTable { - private final Map> table; - private ActionSpace actionSpace; + private final Map> table; + private DiscreteActionSpace discreteActionSpace; - public StateActionHashTable(ActionSpace actionSpace){ + public StateActionHashTable(DiscreteActionSpace discreteActionSpace){ table = new HashMap<>(); - this.actionSpace = actionSpace; + this.discreteActionSpace = discreteActionSpace; } /* @@ -25,8 +27,8 @@ public class StateActionHashTable implements StateActionTable { method. */ @Override - public double getValue(State state, Action action) { - final Map actionValues = table.get(state); + public double getValue(State state, A action) { + final Map actionValues = table.get(state); if (actionValues != null) { return actionValues.get(action); } @@ -40,8 +42,8 @@ public class StateActionHashTable implements StateActionTable { from the action space initialized with the default value. */ @Override - public void setValue(State state, Action action, double value) { - final Map actionValues; + public void setValue(State state, A action, double value) { + final Map actionValues; if (table.containsKey(state)) { actionValues = table.get(state); } else { @@ -52,15 +54,27 @@ public class StateActionHashTable implements StateActionTable { } @Override - public Map getActionValues(State state) { - return null; + public Map getActionValues(State state) { + Map actionValues = table.get(state); + if(actionValues == null){ + actionValues = createDefaultActionValues(); + } + return actionValues; } - private Map createDefaultActionValues(){ - final Map defaultActionValues = new HashMap<>(); - // for(Action action: actionSpace.getAllActions()){ - // defaultActionValues.put(action, DEFAULT_VALUE); - //} + public static void main(String[] args) { + DiscreteActionSpace da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP); + StateActionTable sat = new StateActionHashTable<>(da); + State t = new State() { + }; + + System.out.println(sat.getActionValues(t)); + } + private Map createDefaultActionValues(){ + final Map defaultActionValues = new HashMap<>(); + for(A action: discreteActionSpace.getAllActions()){ + defaultActionValues.put(action, DEFAULT_VALUE); + } return defaultActionValues; } } diff --git a/src/main/java/core/StateActionTable.java b/src/main/java/core/StateActionTable.java index 7fecafb..851ffba 100644 --- a/src/main/java/core/StateActionTable.java +++ b/src/main/java/core/StateActionTable.java @@ -2,11 +2,11 @@ package core; import java.util.Map; -public interface StateActionTable { +public interface StateActionTable { double DEFAULT_VALUE = 0.0; - double getValue(State state, Action action); - void setValue(State state, Action action, double value); + double getValue(State state, A action); + void setValue(State state, A action, double value); - Map getActionValues(State state); + Map getActionValues(State state); } diff --git a/src/main/java/core/StepResult.java b/src/main/java/core/StepResult.java index 89c3bf9..4ed1586 100644 --- a/src/main/java/core/StepResult.java +++ b/src/main/java/core/StepResult.java @@ -8,7 +8,7 @@ import lombok.Setter; @Setter @AllArgsConstructor public class StepResult { - private State observation; + private State state; private double reward; private boolean done; private String info; diff --git a/src/main/java/evironment/antGame/AntWorld.java b/src/main/java/evironment/antGame/AntWorld.java index 9fe900a..ef305e0 100644 --- a/src/main/java/evironment/antGame/AntWorld.java +++ b/src/main/java/evironment/antGame/AntWorld.java @@ -3,10 +3,11 @@ package evironment.antGame; import core.*; import evironment.antGame.gui.MainFrame; + import javax.swing.*; import java.awt.*; -public class AntWorld { +public class AntWorld implements Environment{ /** * */ @@ -53,7 +54,8 @@ public class AntWorld { this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY); } - public StepResult step(DiscreteAction action){ + @Override + public StepResult step(AntAction action){ AntObservation observation; State newState; double reward = 0; @@ -77,7 +79,7 @@ public class AntWorld { // on the starting position boolean checkCompletion = false; - switch (action.getValue()) { + switch (action) { case MOVE_UP: potentialNextPos.y -= 1; stayOnCell = false; @@ -208,13 +210,13 @@ public class AntWorld { public static void main(String[] args) { RNG.setSeed(1993); AntWorld a = new AntWorld(10, 10, 0.1); - DiscreteActionSpace actionSpace = new DiscreteActionSpace<>(); - actionSpace.addActions(AntAction.values()); + ListDiscreteActionSpace actionSpace = + new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT); for(int i = 0; i< 1000; ++i){ - DiscreteAction selectedAction = actionSpace.getAllDiscreteActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction())); + AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction())); StepResult step = a.step(selectedAction); - SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction.getValue(), step)); + SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step)); try { Thread.sleep(100); } catch (InterruptedException e) {