remove the Action interface in favour of Enums

This commit is contained in:
Jan Löwenstrom 2019-12-09 17:30:14 +01:00
parent 8a533dda94
commit 0100f2e82a
11 changed files with 93 additions and 109 deletions

View File

@ -5,6 +5,7 @@
<output-test url="file://$MODULE_DIR$/build/classes/java/test" />
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/.gradle" />
<excludeFolder url="file://$MODULE_DIR$/build" />
</content>

View File

@ -1,8 +0,0 @@
package core;
public interface Action {
int getIndex();
String toString();
int hashCode();
boolean equals(Object obj);
}

View File

@ -1,9 +0,0 @@
package core;
import java.util.List;
public interface ActionSpace<A extends Enum> {
int getNumberOfAction();
void addAction(DiscreteAction<A> a);
void addActions(A[] as);
}

View File

@ -1,36 +0,0 @@
package core;
public class DiscreteAction<A extends Enum> implements Action{
private A action;
public DiscreteAction(A action){
this.action = action;
}
public A getValue(){
return action;
}
@Override
public int getIndex(){
return action.ordinal();
}
@Override
public String toString(){
return action.toString();
}
@Override
public int hashCode() {
return getIndex();
}
@Override
public boolean equals(Object obj) {
if(obj instanceof DiscreteAction){
return getIndex() == ((DiscreteAction) obj).getIndex();
}
return super.equals(obj);
}
}

View File

@ -1,33 +1,10 @@
package core;
import java.util.ArrayList;
import java.util.List;
public class DiscreteActionSpace<A extends Enum> implements ActionSpace<A>{
private List<DiscreteAction<A>> actions;
public DiscreteActionSpace(){
actions = new ArrayList<>();
}
@Override
public void addAction(DiscreteAction<A> action){
actions.add(action);
}
@Override
public void addActions(A[] as) {
for(A a : as){
actions.add(new DiscreteAction<>(a));
}
}
@Override
public int getNumberOfAction(){
return actions.size();
}
public List<DiscreteAction<A>> getAllDiscreteActions(){
return actions;
}
public interface DiscreteActionSpace<A extends Enum> {
int getNumberOfAction();
void addAction(A a);
void addActions(A... as);
List<A> getAllActions();
}

View File

@ -0,0 +1,5 @@
package core;
public interface Environment<A extends Enum> {
StepResult step(A action);
}

View File

@ -0,0 +1,38 @@
package core;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A> {
private List<A> actions;
public ListDiscreteActionSpace(){
actions = new ArrayList<>();
}
public ListDiscreteActionSpace(A... actions){
this.actions = new ArrayList<>(Arrays.asList(actions));
}
@Override
public void addAction(A action){
actions.add(action);
}
@SafeVarargs
@Override
public final void addActions(A... as) {
actions.addAll(Arrays.asList(as));
}
@Override
public List<A> getAllActions() {
return actions;
}
@Override
public int getNumberOfAction(){
return actions.size();
}
}

View File

@ -1,19 +1,21 @@
package core;
import evironment.antGame.AntAction;
import java.util.HashMap;
import java.util.Map;
/**
* Premise: All states have the complete action space
*/
public class StateActionHashTable<A extends Enum> implements StateActionTable {
public class StateActionHashTable<A extends Enum> implements StateActionTable<A> {
private final Map<State, Map<Action, Double>> table;
private ActionSpace<A> actionSpace;
private final Map<State, Map<A, Double>> table;
private DiscreteActionSpace<A> discreteActionSpace;
public StateActionHashTable(ActionSpace<A> actionSpace){
public StateActionHashTable(DiscreteActionSpace<A> discreteActionSpace){
table = new HashMap<>();
this.actionSpace = actionSpace;
this.discreteActionSpace = discreteActionSpace;
}
/*
@ -25,8 +27,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
method.
*/
@Override
public double getValue(State state, Action action) {
final Map<Action, Double> actionValues = table.get(state);
public double getValue(State state, A action) {
final Map<A, Double> actionValues = table.get(state);
if (actionValues != null) {
return actionValues.get(action);
}
@ -40,8 +42,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
from the action space initialized with the default value.
*/
@Override
public void setValue(State state, Action action, double value) {
final Map<Action, Double> actionValues;
public void setValue(State state, A action, double value) {
final Map<A, Double> actionValues;
if (table.containsKey(state)) {
actionValues = table.get(state);
} else {
@ -52,15 +54,27 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
}
@Override
public Map<Action, Double> getActionValues(State state) {
return null;
public Map<A, Double> getActionValues(State state) {
Map<A, Double> actionValues = table.get(state);
if(actionValues == null){
actionValues = createDefaultActionValues();
}
return actionValues;
}
private Map<Action, Double> createDefaultActionValues(){
final Map<Action, Double> defaultActionValues = new HashMap<>();
// for(Action action: actionSpace.getAllActions()){
// defaultActionValues.put(action, DEFAULT_VALUE);
//}
public static void main(String[] args) {
DiscreteActionSpace<AntAction> da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP);
StateActionTable sat = new StateActionHashTable<>(da);
State t = new State() {
};
System.out.println(sat.getActionValues(t));
}
private Map<A, Double> createDefaultActionValues(){
final Map<A, Double> defaultActionValues = new HashMap<>();
for(A action: discreteActionSpace.getAllActions()){
defaultActionValues.put(action, DEFAULT_VALUE);
}
return defaultActionValues;
}
}

View File

@ -2,11 +2,11 @@ package core;
import java.util.Map;
public interface StateActionTable {
public interface StateActionTable<A extends Enum> {
double DEFAULT_VALUE = 0.0;
double getValue(State state, Action action);
void setValue(State state, Action action, double value);
double getValue(State state, A action);
void setValue(State state, A action, double value);
Map<Action, Double> getActionValues(State state);
Map<A, Double> getActionValues(State state);
}

View File

@ -8,7 +8,7 @@ import lombok.Setter;
@Setter
@AllArgsConstructor
public class StepResult {
private State observation;
private State state;
private double reward;
private boolean done;
private String info;

View File

@ -3,10 +3,11 @@ package evironment.antGame;
import core.*;
import evironment.antGame.gui.MainFrame;
import javax.swing.*;
import java.awt.*;
public class AntWorld {
public class AntWorld implements Environment<AntAction>{
/**
*
*/
@ -53,7 +54,8 @@ public class AntWorld {
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
}
public StepResult step(DiscreteAction<AntAction> action){
@Override
public StepResult step(AntAction action){
AntObservation observation;
State newState;
double reward = 0;
@ -77,7 +79,7 @@ public class AntWorld {
// on the starting position
boolean checkCompletion = false;
switch (action.getValue()) {
switch (action) {
case MOVE_UP:
potentialNextPos.y -= 1;
stayOnCell = false;
@ -208,13 +210,13 @@ public class AntWorld {
public static void main(String[] args) {
RNG.setSeed(1993);
AntWorld a = new AntWorld(10, 10, 0.1);
DiscreteActionSpace<AntAction> actionSpace = new DiscreteActionSpace<>();
actionSpace.addActions(AntAction.values());
ListDiscreteActionSpace<AntAction> actionSpace =
new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
for(int i = 0; i< 1000; ++i){
DiscreteAction<AntAction> selectedAction = actionSpace.getAllDiscreteActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
StepResult step = a.step(selectedAction);
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction.getValue(), step));
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
try {
Thread.sleep(100);
} catch (InterruptedException e) {