remove the Action interface in favour of Enums
This commit is contained in:
parent
8a533dda94
commit
0100f2e82a
|
@ -5,6 +5,7 @@
|
||||||
<output-test url="file://$MODULE_DIR$/build/classes/java/test" />
|
<output-test url="file://$MODULE_DIR$/build/classes/java/test" />
|
||||||
<exclude-output />
|
<exclude-output />
|
||||||
<content url="file://$MODULE_DIR$">
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
|
||||||
<excludeFolder url="file://$MODULE_DIR$/.gradle" />
|
<excludeFolder url="file://$MODULE_DIR$/.gradle" />
|
||||||
<excludeFolder url="file://$MODULE_DIR$/build" />
|
<excludeFolder url="file://$MODULE_DIR$/build" />
|
||||||
</content>
|
</content>
|
||||||
|
|
|
@ -1,8 +0,0 @@
|
||||||
package core;
|
|
||||||
|
|
||||||
/**
 * Minimal contract for an action that can be addressed by an integer index.
 * This is the pre-commit side of the diff: the hunk header above shows the file being deleted.
 */
public interface Action {
|
|
||||||
/** Returns the integer index of this action; DiscreteAction implements it as the ordinal of the wrapped enum constant. */
int getIndex();
|
|
||||||
// toString, hashCode and equals re-declare public methods of java.lang.Object. Every class
// already inherits implementations of them, so these declarations do not force an implementer to override anything.
String toString();
|
|
||||||
int hashCode();
|
|
||||||
boolean equals(Object obj);
|
|
||||||
}
|
|
|
@ -1,9 +0,0 @@
|
||||||
package core;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
 * Contract for a collection of actions whose values are constants of the enum type A.
 * This is the pre-commit side of the diff: the hunk header above shows the file being deleted.
 * NOTE(review): the bound uses the raw type Enum rather than Enum&lt;A&gt;.
 */
public interface ActionSpace<A extends Enum> {
|
|
||||||
/** Returns the number of actions; the old DiscreteActionSpace implementation returns the size of its backing list. */
int getNumberOfAction();
|
|
||||||
/** Adds a single action that is already wrapped in a DiscreteAction. */
void addAction(DiscreteAction<A> a);
|
|
||||||
/** Adds several actions given as raw enum constants; the old DiscreteActionSpace implementation wraps each one in a DiscreteAction. */
void addActions(A[] as);
|
|
||||||
}
|
|
|
@ -1,36 +0,0 @@
|
||||||
package core;
|
|
||||||
|
|
||||||
/**
 * Action implementation that wraps a single enum constant of type A.
 * Index, hash code and equality are all derived from the ordinal of that constant.
 * This is the pre-commit side of the diff: the hunk header above shows the file being deleted.
 */
public class DiscreteAction<A extends Enum> implements Action{
|
|
||||||
// The wrapped enum constant; assigned once in the constructor but not declared final.
private A action;
|
|
||||||
|
|
||||||
/**
 * Wraps the given enum constant.
 * No null check is performed here; a null argument only fails later, when a method dereferences the field.
 */
public DiscreteAction(A action){
|
|
||||||
this.action = action;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Returns the wrapped enum constant itself. */
public A getValue(){
|
|
||||||
return action;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Returns the ordinal of the wrapped enum constant, i.e. its position in the enum declaration. */
@Override
|
|
||||||
public int getIndex(){
|
|
||||||
return action.ordinal();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Delegates to the toString() of the wrapped enum constant. */
@Override
|
|
||||||
public String toString(){
|
|
||||||
return action.toString();
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Uses the ordinal as hash code, which keeps hashCode consistent with the ordinal-based equals below. */
@Override
|
|
||||||
public int hashCode() {
|
|
||||||
return getIndex();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
 * Two DiscreteAction instances are equal when their ordinals match.
 * NOTE(review): the enum type is not compared, so wrappers around constants of different
 * enums that share an ordinal compare equal; the cast below also uses the raw type.
 */
@Override
|
|
||||||
public boolean equals(Object obj) {
|
|
||||||
if(obj instanceof DiscreteAction){
|
|
||||||
return getIndex() == ((DiscreteAction) obj).getIndex();
|
|
||||||
}
|
|
||||||
// Any other argument (including null) falls back to Object.equals, i.e. reference identity.
return super.equals(obj);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,33 +1,10 @@
|
||||||
package core;
|
package core;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
public class DiscreteActionSpace<A extends Enum> implements ActionSpace<A>{
|
public interface DiscreteActionSpace<A extends Enum> {
|
||||||
private List<DiscreteAction<A>> actions;
|
int getNumberOfAction();
|
||||||
|
void addAction(A a);
|
||||||
public DiscreteActionSpace(){
|
void addActions(A... as);
|
||||||
actions = new ArrayList<>();
|
List<A> getAllActions();
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addAction(DiscreteAction<A> action){
|
|
||||||
actions.add(action);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void addActions(A[] as) {
|
|
||||||
for(A a : as){
|
|
||||||
actions.add(new DiscreteAction<>(a));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int getNumberOfAction(){
|
|
||||||
return actions.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<DiscreteAction<A>> getAllDiscreteActions(){
|
|
||||||
return actions;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,5 @@
|
||||||
|
package core;
|
||||||
|
|
||||||
|
/**
 * Contract for an environment that is advanced by actions given directly as enum constants of type A.
 * NOTE(review): the bound uses the raw type Enum rather than Enum&lt;A&gt;.
 */
public interface Environment<A extends Enum> {
|
||||||
|
/** Applies the given action to the environment and returns the outcome as a StepResult. */
StepResult step(A action);
|
||||||
|
}
|
|
@ -0,0 +1,38 @@
|
||||||
|
package core;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
 * DiscreteActionSpace implementation that keeps its actions in an ArrayList, in insertion order.
 * Duplicates are not filtered: every added constant is appended to the list.
 * NOTE(review): the bound uses the raw type Enum rather than Enum&lt;A&gt;.
 */
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A> {
|
||||||
|
// Backing list; assigned exactly once by each constructor but not declared final.
private List<A> actions;
|
||||||
|
|
||||||
|
/** Creates an empty action space. */
public ListDiscreteActionSpace(){
|
||||||
|
actions = new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Creates an action space pre-filled with the given constants.
 * NOTE(review): unlike addActions below, this generic varargs constructor is not annotated
 * with @SafeVarargs, so javac will likely report an unchecked/heap-pollution warning - confirm.
 */
public ListDiscreteActionSpace(A... actions){
|
||||||
|
// Arrays.asList alone is a fixed-size view of the array; copying into an ArrayList keeps add() usable.
this.actions = new ArrayList<>(Arrays.asList(actions));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Appends a single action to the end of the list. */
@Override
|
||||||
|
public void addAction(A action){
|
||||||
|
actions.add(action);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Appends all given actions, in argument order; declared final, which @SafeVarargs requires on an instance method. */
@SafeVarargs
|
||||||
|
@Override
|
||||||
|
public final void addActions(A... as) {
|
||||||
|
actions.addAll(Arrays.asList(as));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Returns the backing list itself, not a copy.
 * NOTE(review): callers can therefore modify the action space through the returned list.
 */
@Override
|
||||||
|
public List<A> getAllActions() {
|
||||||
|
return actions;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns the current number of stored actions (the size of the backing list). */
@Override
|
||||||
|
public int getNumberOfAction(){
|
||||||
|
return actions.size();
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,19 +1,21 @@
|
||||||
package core;
|
package core;
|
||||||
|
|
||||||
|
import evironment.antGame.AntAction;
|
||||||
|
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Premise: All states have the complete action space
|
* Premise: All states have the complete action space
|
||||||
*/
|
*/
|
||||||
public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
public class StateActionHashTable<A extends Enum> implements StateActionTable<A> {
|
||||||
|
|
||||||
private final Map<State, Map<Action, Double>> table;
|
private final Map<State, Map<A, Double>> table;
|
||||||
private ActionSpace<A> actionSpace;
|
private DiscreteActionSpace<A> discreteActionSpace;
|
||||||
|
|
||||||
public StateActionHashTable(ActionSpace<A> actionSpace){
|
public StateActionHashTable(DiscreteActionSpace<A> discreteActionSpace){
|
||||||
table = new HashMap<>();
|
table = new HashMap<>();
|
||||||
this.actionSpace = actionSpace;
|
this.discreteActionSpace = discreteActionSpace;
|
||||||
}
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -25,8 +27,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
||||||
method.
|
method.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public double getValue(State state, Action action) {
|
public double getValue(State state, A action) {
|
||||||
final Map<Action, Double> actionValues = table.get(state);
|
final Map<A, Double> actionValues = table.get(state);
|
||||||
if (actionValues != null) {
|
if (actionValues != null) {
|
||||||
return actionValues.get(action);
|
return actionValues.get(action);
|
||||||
}
|
}
|
||||||
|
@ -40,8 +42,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
||||||
from the action space initialized with the default value.
|
from the action space initialized with the default value.
|
||||||
*/
|
*/
|
||||||
@Override
|
@Override
|
||||||
public void setValue(State state, Action action, double value) {
|
public void setValue(State state, A action, double value) {
|
||||||
final Map<Action, Double> actionValues;
|
final Map<A, Double> actionValues;
|
||||||
if (table.containsKey(state)) {
|
if (table.containsKey(state)) {
|
||||||
actionValues = table.get(state);
|
actionValues = table.get(state);
|
||||||
} else {
|
} else {
|
||||||
|
@ -52,15 +54,27 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Map<Action, Double> getActionValues(State state) {
|
public Map<A, Double> getActionValues(State state) {
|
||||||
return null;
|
Map<A, Double> actionValues = table.get(state);
|
||||||
|
if(actionValues == null){
|
||||||
|
actionValues = createDefaultActionValues();
|
||||||
|
}
|
||||||
|
return actionValues;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Map<Action, Double> createDefaultActionValues(){
|
public static void main(String[] args) {
|
||||||
final Map<Action, Double> defaultActionValues = new HashMap<>();
|
DiscreteActionSpace<AntAction> da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP);
|
||||||
// for(Action action: actionSpace.getAllActions()){
|
StateActionTable sat = new StateActionHashTable<>(da);
|
||||||
// defaultActionValues.put(action, DEFAULT_VALUE);
|
State t = new State() {
|
||||||
//}
|
};
|
||||||
|
|
||||||
|
System.out.println(sat.getActionValues(t));
|
||||||
|
}
|
||||||
|
private Map<A, Double> createDefaultActionValues(){
|
||||||
|
final Map<A, Double> defaultActionValues = new HashMap<>();
|
||||||
|
for(A action: discreteActionSpace.getAllActions()){
|
||||||
|
defaultActionValues.put(action, DEFAULT_VALUE);
|
||||||
|
}
|
||||||
return defaultActionValues;
|
return defaultActionValues;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,11 +2,11 @@ package core;
|
||||||
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
|
||||||
public interface StateActionTable {
|
public interface StateActionTable<A extends Enum> {
|
||||||
double DEFAULT_VALUE = 0.0;
|
double DEFAULT_VALUE = 0.0;
|
||||||
|
|
||||||
double getValue(State state, Action action);
|
double getValue(State state, A action);
|
||||||
void setValue(State state, Action action, double value);
|
void setValue(State state, A action, double value);
|
||||||
|
|
||||||
Map<Action, Double> getActionValues(State state);
|
Map<A, Double> getActionValues(State state);
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,7 @@ import lombok.Setter;
|
||||||
@Setter
|
@Setter
|
||||||
@AllArgsConstructor
|
@AllArgsConstructor
|
||||||
public class StepResult {
|
public class StepResult {
|
||||||
private State observation;
|
private State state;
|
||||||
private double reward;
|
private double reward;
|
||||||
private boolean done;
|
private boolean done;
|
||||||
private String info;
|
private String info;
|
||||||
|
|
|
@ -3,10 +3,11 @@ package evironment.antGame;
|
||||||
import core.*;
|
import core.*;
|
||||||
import evironment.antGame.gui.MainFrame;
|
import evironment.antGame.gui.MainFrame;
|
||||||
|
|
||||||
|
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
import java.awt.*;
|
import java.awt.*;
|
||||||
|
|
||||||
public class AntWorld {
|
public class AntWorld implements Environment<AntAction>{
|
||||||
/**
|
/**
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
|
@ -53,7 +54,8 @@ public class AntWorld {
|
||||||
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
|
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
|
||||||
}
|
}
|
||||||
|
|
||||||
public StepResult step(DiscreteAction<AntAction> action){
|
@Override
|
||||||
|
public StepResult step(AntAction action){
|
||||||
AntObservation observation;
|
AntObservation observation;
|
||||||
State newState;
|
State newState;
|
||||||
double reward = 0;
|
double reward = 0;
|
||||||
|
@ -77,7 +79,7 @@ public class AntWorld {
|
||||||
// on the starting position
|
// on the starting position
|
||||||
boolean checkCompletion = false;
|
boolean checkCompletion = false;
|
||||||
|
|
||||||
switch (action.getValue()) {
|
switch (action) {
|
||||||
case MOVE_UP:
|
case MOVE_UP:
|
||||||
potentialNextPos.y -= 1;
|
potentialNextPos.y -= 1;
|
||||||
stayOnCell = false;
|
stayOnCell = false;
|
||||||
|
@ -208,13 +210,13 @@ public class AntWorld {
|
||||||
public static void main(String[] args) {
|
public static void main(String[] args) {
|
||||||
RNG.setSeed(1993);
|
RNG.setSeed(1993);
|
||||||
AntWorld a = new AntWorld(10, 10, 0.1);
|
AntWorld a = new AntWorld(10, 10, 0.1);
|
||||||
DiscreteActionSpace<AntAction> actionSpace = new DiscreteActionSpace<>();
|
ListDiscreteActionSpace<AntAction> actionSpace =
|
||||||
actionSpace.addActions(AntAction.values());
|
new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
|
||||||
|
|
||||||
for(int i = 0; i< 1000; ++i){
|
for(int i = 0; i< 1000; ++i){
|
||||||
DiscreteAction<AntAction> selectedAction = actionSpace.getAllDiscreteActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
|
AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
|
||||||
StepResult step = a.step(selectedAction);
|
StepResult step = a.step(selectedAction);
|
||||||
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction.getValue(), step));
|
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
|
||||||
try {
|
try {
|
||||||
Thread.sleep(100);
|
Thread.sleep(100);
|
||||||
} catch (InterruptedException e) {
|
} catch (InterruptedException e) {
|
||||||
|
|
Loading…
Reference in New Issue