remove the Action interface in favour of Enums

This commit is contained in:
Jan Löwenstrom 2019-12-09 17:30:14 +01:00
parent 8a533dda94
commit 0100f2e82a
11 changed files with 93 additions and 109 deletions

View File

@ -5,6 +5,7 @@
<output-test url="file://$MODULE_DIR$/build/classes/java/test" /> <output-test url="file://$MODULE_DIR$/build/classes/java/test" />
<exclude-output /> <exclude-output />
<content url="file://$MODULE_DIR$"> <content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/.gradle" /> <excludeFolder url="file://$MODULE_DIR$/.gradle" />
<excludeFolder url="file://$MODULE_DIR$/build" /> <excludeFolder url="file://$MODULE_DIR$/build" />
</content> </content>

View File

@ -1,8 +0,0 @@
package core;
public interface Action {
int getIndex();
String toString();
int hashCode();
boolean equals(Object obj);
}

View File

@ -1,9 +0,0 @@
package core;
import java.util.List;
/**
 * A set of actions whose values are constants of the enum type {@code A}.
 *
 * @param <A> enum type the actions are taken from
 */
public interface ActionSpace<A extends Enum> {

    /**
     * Adds one already wrapped action to this action space.
     *
     * @param a the wrapped action to add
     */
    void addAction(DiscreteAction<A> a);

    /**
     * Adds the given enum constants to this action space.
     *
     * @param as the enum constants to add
     */
    void addActions(A[] as);

    /**
     * @return the number of actions in this action space
     */
    int getNumberOfAction();
}

View File

@ -1,36 +0,0 @@
package core;
/**
 * Wraps a single enum constant so that it can be used as an {@link Action}.
 *
 * <p>Two wrappers are equal exactly when they wrap the same enum constant.
 * The index and the hash code are both the ordinal of the wrapped constant.
 *
 * @param <A> enum type of the wrapped constant
 */
public class DiscreteAction<A extends Enum> implements Action {
    /** The wrapped enum constant; never reassigned after construction. */
    private final A action;

    /**
     * @param action the enum constant to wrap
     */
    public DiscreteAction(A action) {
        this.action = action;
    }

    /**
     * @return the wrapped enum constant
     */
    public A getValue() {
        return action;
    }

    /**
     * @return the ordinal of the wrapped enum constant
     */
    @Override
    public int getIndex() {
        return action.ordinal();
    }

    /**
     * @return the {@code toString()} of the wrapped enum constant
     */
    @Override
    public String toString() {
        return action.toString();
    }

    /**
     * @return the index of this action; equal wrappers wrap the same constant
     *         and therefore share the same ordinal
     */
    @Override
    public int hashCode() {
        return getIndex();
    }

    /**
     * Compares the wrapped constants by identity.
     *
     * <p>Comparing ordinals alone would treat constants of unrelated enum
     * types as equal whenever they sit at the same position in their
     * declarations, so the constants themselves are compared instead.
     *
     * @param obj the object to compare with
     * @return {@code true} if {@code obj} is a {@code DiscreteAction} wrapping
     *         the same enum constant
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof DiscreteAction)) {
            return false;
        }
        // Enum constants are singletons, so reference equality is exact.
        return action == ((DiscreteAction<?>) obj).action;
    }
}

View File

@ -1,33 +1,10 @@
package core; package core;
import java.util.ArrayList;
import java.util.List; import java.util.List;
public class DiscreteActionSpace<A extends Enum> implements ActionSpace<A>{ public interface DiscreteActionSpace<A extends Enum> {
private List<DiscreteAction<A>> actions; int getNumberOfAction();
void addAction(A a);
public DiscreteActionSpace(){ void addActions(A... as);
actions = new ArrayList<>(); List<A> getAllActions();
}
@Override
public void addAction(DiscreteAction<A> action){
actions.add(action);
}
@Override
public void addActions(A[] as) {
for(A a : as){
actions.add(new DiscreteAction<>(a));
}
}
@Override
public int getNumberOfAction(){
return actions.size();
}
public List<DiscreteAction<A>> getAllDiscreteActions(){
return actions;
}
} }

View File

@ -0,0 +1,5 @@
package core;
/**
 * An environment that is driven one step at a time by actions taken from the
 * enum type {@code A}.
 *
 * @param <A> enum type of the actions this environment accepts
 */
public interface Environment<A extends Enum> {

    /**
     * Performs one step with the given action.
     *
     * @param action the action to apply
     * @return the {@link StepResult} produced by this step
     */
    StepResult step(A action);
}

View File

@ -0,0 +1,38 @@
package core;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * A {@link DiscreteActionSpace} that keeps its actions in an {@link ArrayList},
 * in insertion order and without removing duplicates.
 *
 * @param <A> enum type of the actions
 */
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A> {
    /** Backing list; the reference never changes, only its contents. */
    private final List<A> actions;

    /** Creates an empty action space. */
    public ListDiscreteActionSpace() {
        actions = new ArrayList<>();
    }

    /**
     * Creates an action space pre-filled with the given actions.
     *
     * <p>The varargs array is only read and copied into a fresh list, never
     * stored or written to, which is what makes {@code @SafeVarargs} valid here.
     *
     * @param actions the initial actions, in the order they should be stored
     */
    @SafeVarargs
    public ListDiscreteActionSpace(A... actions) {
        this.actions = new ArrayList<>(Arrays.asList(actions));
    }

    /**
     * Appends a single action.
     *
     * @param action the action to add
     */
    @Override
    public void addAction(A action) {
        actions.add(action);
    }

    /**
     * Appends all given actions in order.
     *
     * @param as the actions to add
     */
    @SafeVarargs
    @Override
    public final void addActions(A... as) {
        actions.addAll(Arrays.asList(as));
    }

    /**
     * @return the backing list itself (not a copy), so changes made through
     *         the returned list are visible in this action space
     */
    @Override
    public List<A> getAllActions() {
        return actions;
    }

    /**
     * @return the number of stored actions
     */
    @Override
    public int getNumberOfAction() {
        return actions.size();
    }
}

View File

@ -1,19 +1,21 @@
package core; package core;
import evironment.antGame.AntAction;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
/** /**
* Premise: All states have the complete action space * Premise: All states have the complete action space
*/ */
public class StateActionHashTable<A extends Enum> implements StateActionTable { public class StateActionHashTable<A extends Enum> implements StateActionTable<A> {
private final Map<State, Map<Action, Double>> table; private final Map<State, Map<A, Double>> table;
private ActionSpace<A> actionSpace; private DiscreteActionSpace<A> discreteActionSpace;
public StateActionHashTable(ActionSpace<A> actionSpace){ public StateActionHashTable(DiscreteActionSpace<A> discreteActionSpace){
table = new HashMap<>(); table = new HashMap<>();
this.actionSpace = actionSpace; this.discreteActionSpace = discreteActionSpace;
} }
/* /*
@ -25,8 +27,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
method. method.
*/ */
@Override @Override
public double getValue(State state, Action action) { public double getValue(State state, A action) {
final Map<Action, Double> actionValues = table.get(state); final Map<A, Double> actionValues = table.get(state);
if (actionValues != null) { if (actionValues != null) {
return actionValues.get(action); return actionValues.get(action);
} }
@ -40,8 +42,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
from the action space initialized with the default value. from the action space initialized with the default value.
*/ */
@Override @Override
public void setValue(State state, Action action, double value) { public void setValue(State state, A action, double value) {
final Map<Action, Double> actionValues; final Map<A, Double> actionValues;
if (table.containsKey(state)) { if (table.containsKey(state)) {
actionValues = table.get(state); actionValues = table.get(state);
} else { } else {
@ -52,15 +54,27 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
} }
@Override @Override
public Map<Action, Double> getActionValues(State state) { public Map<A, Double> getActionValues(State state) {
return null; Map<A, Double> actionValues = table.get(state);
if(actionValues == null){
actionValues = createDefaultActionValues();
}
return actionValues;
} }
private Map<Action, Double> createDefaultActionValues(){ public static void main(String[] args) {
final Map<Action, Double> defaultActionValues = new HashMap<>(); DiscreteActionSpace<AntAction> da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP);
// for(Action action: actionSpace.getAllActions()){ StateActionTable sat = new StateActionHashTable<>(da);
// defaultActionValues.put(action, DEFAULT_VALUE); State t = new State() {
//} };
System.out.println(sat.getActionValues(t));
}
private Map<A, Double> createDefaultActionValues(){
final Map<A, Double> defaultActionValues = new HashMap<>();
for(A action: discreteActionSpace.getAllActions()){
defaultActionValues.put(action, DEFAULT_VALUE);
}
return defaultActionValues; return defaultActionValues;
} }
} }

View File

@ -2,11 +2,11 @@ package core;
import java.util.Map; import java.util.Map;
public interface StateActionTable { public interface StateActionTable<A extends Enum> {
double DEFAULT_VALUE = 0.0; double DEFAULT_VALUE = 0.0;
double getValue(State state, Action action); double getValue(State state, A action);
void setValue(State state, Action action, double value); void setValue(State state, A action, double value);
Map<Action, Double> getActionValues(State state); Map<A, Double> getActionValues(State state);
} }

View File

@ -8,7 +8,7 @@ import lombok.Setter;
@Setter @Setter
@AllArgsConstructor @AllArgsConstructor
public class StepResult { public class StepResult {
private State observation; private State state;
private double reward; private double reward;
private boolean done; private boolean done;
private String info; private String info;

View File

@ -3,10 +3,11 @@ package evironment.antGame;
import core.*; import core.*;
import evironment.antGame.gui.MainFrame; import evironment.antGame.gui.MainFrame;
import javax.swing.*; import javax.swing.*;
import java.awt.*; import java.awt.*;
public class AntWorld { public class AntWorld implements Environment<AntAction>{
/** /**
* *
*/ */
@ -53,7 +54,8 @@ public class AntWorld {
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY); this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
} }
public StepResult step(DiscreteAction<AntAction> action){ @Override
public StepResult step(AntAction action){
AntObservation observation; AntObservation observation;
State newState; State newState;
double reward = 0; double reward = 0;
@ -77,7 +79,7 @@ public class AntWorld {
// on the starting position // on the starting position
boolean checkCompletion = false; boolean checkCompletion = false;
switch (action.getValue()) { switch (action) {
case MOVE_UP: case MOVE_UP:
potentialNextPos.y -= 1; potentialNextPos.y -= 1;
stayOnCell = false; stayOnCell = false;
@ -208,13 +210,13 @@ public class AntWorld {
public static void main(String[] args) { public static void main(String[] args) {
RNG.setSeed(1993); RNG.setSeed(1993);
AntWorld a = new AntWorld(10, 10, 0.1); AntWorld a = new AntWorld(10, 10, 0.1);
DiscreteActionSpace<AntAction> actionSpace = new DiscreteActionSpace<>(); ListDiscreteActionSpace<AntAction> actionSpace =
actionSpace.addActions(AntAction.values()); new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
for(int i = 0; i< 1000; ++i){ for(int i = 0; i< 1000; ++i){
DiscreteAction<AntAction> selectedAction = actionSpace.getAllDiscreteActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction())); AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
StepResult step = a.step(selectedAction); StepResult step = a.step(selectedAction);
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction.getValue(), step)); SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
try { try {
Thread.sleep(100); Thread.sleep(100);
} catch (InterruptedException e) { } catch (InterruptedException e) {