remove the Action interface in favour of Enums
This commit is contained in:
parent
8a533dda94
commit
0100f2e82a
|
@ -5,6 +5,7 @@
|
|||
<output-test url="file://$MODULE_DIR$/build/classes/java/test" />
|
||||
<exclude-output />
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/src/main/java" isTestSource="false" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/.gradle" />
|
||||
<excludeFolder url="file://$MODULE_DIR$/build" />
|
||||
</content>
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
package core;
|
||||
|
||||
public interface Action {
|
||||
int getIndex();
|
||||
String toString();
|
||||
int hashCode();
|
||||
boolean equals(Object obj);
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
package core;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public interface ActionSpace<A extends Enum> {
|
||||
int getNumberOfAction();
|
||||
void addAction(DiscreteAction<A> a);
|
||||
void addActions(A[] as);
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
package core;
|
||||
|
||||
public class DiscreteAction<A extends Enum> implements Action{
|
||||
private A action;
|
||||
|
||||
public DiscreteAction(A action){
|
||||
this.action = action;
|
||||
}
|
||||
|
||||
public A getValue(){
|
||||
return action;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getIndex(){
|
||||
return action.ordinal();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(){
|
||||
return action.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return getIndex();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
if(obj instanceof DiscreteAction){
|
||||
return getIndex() == ((DiscreteAction) obj).getIndex();
|
||||
}
|
||||
return super.equals(obj);
|
||||
}
|
||||
}
|
|
@ -1,33 +1,10 @@
|
|||
package core;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
public class DiscreteActionSpace<A extends Enum> implements ActionSpace<A>{
|
||||
private List<DiscreteAction<A>> actions;
|
||||
|
||||
public DiscreteActionSpace(){
|
||||
actions = new ArrayList<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addAction(DiscreteAction<A> action){
|
||||
actions.add(action);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addActions(A[] as) {
|
||||
for(A a : as){
|
||||
actions.add(new DiscreteAction<>(a));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumberOfAction(){
|
||||
return actions.size();
|
||||
}
|
||||
|
||||
public List<DiscreteAction<A>> getAllDiscreteActions(){
|
||||
return actions;
|
||||
}
|
||||
public interface DiscreteActionSpace<A extends Enum> {
|
||||
int getNumberOfAction();
|
||||
void addAction(A a);
|
||||
void addActions(A... as);
|
||||
List<A> getAllActions();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
package core;
|
||||
|
||||
public interface Environment<A extends Enum> {
|
||||
StepResult step(A action);
|
||||
}
|
|
@ -0,0 +1,38 @@
|
|||
package core;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class ListDiscreteActionSpace<A extends Enum> implements DiscreteActionSpace<A> {
|
||||
private List<A> actions;
|
||||
|
||||
public ListDiscreteActionSpace(){
|
||||
actions = new ArrayList<>();
|
||||
}
|
||||
|
||||
public ListDiscreteActionSpace(A... actions){
|
||||
this.actions = new ArrayList<>(Arrays.asList(actions));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addAction(A action){
|
||||
actions.add(action);
|
||||
}
|
||||
|
||||
@SafeVarargs
|
||||
@Override
|
||||
public final void addActions(A... as) {
|
||||
actions.addAll(Arrays.asList(as));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<A> getAllActions() {
|
||||
return actions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getNumberOfAction(){
|
||||
return actions.size();
|
||||
}
|
||||
}
|
|
@ -1,19 +1,21 @@
|
|||
package core;
|
||||
|
||||
import evironment.antGame.AntAction;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Premise: All states have the complete action space
|
||||
*/
|
||||
public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
||||
public class StateActionHashTable<A extends Enum> implements StateActionTable<A> {
|
||||
|
||||
private final Map<State, Map<Action, Double>> table;
|
||||
private ActionSpace<A> actionSpace;
|
||||
private final Map<State, Map<A, Double>> table;
|
||||
private DiscreteActionSpace<A> discreteActionSpace;
|
||||
|
||||
public StateActionHashTable(ActionSpace<A> actionSpace){
|
||||
public StateActionHashTable(DiscreteActionSpace<A> discreteActionSpace){
|
||||
table = new HashMap<>();
|
||||
this.actionSpace = actionSpace;
|
||||
this.discreteActionSpace = discreteActionSpace;
|
||||
}
|
||||
|
||||
/*
|
||||
|
@ -25,8 +27,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
|||
method.
|
||||
*/
|
||||
@Override
|
||||
public double getValue(State state, Action action) {
|
||||
final Map<Action, Double> actionValues = table.get(state);
|
||||
public double getValue(State state, A action) {
|
||||
final Map<A, Double> actionValues = table.get(state);
|
||||
if (actionValues != null) {
|
||||
return actionValues.get(action);
|
||||
}
|
||||
|
@ -40,8 +42,8 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
|||
from the action space initialized with the default value.
|
||||
*/
|
||||
@Override
|
||||
public void setValue(State state, Action action, double value) {
|
||||
final Map<Action, Double> actionValues;
|
||||
public void setValue(State state, A action, double value) {
|
||||
final Map<A, Double> actionValues;
|
||||
if (table.containsKey(state)) {
|
||||
actionValues = table.get(state);
|
||||
} else {
|
||||
|
@ -52,15 +54,27 @@ public class StateActionHashTable<A extends Enum> implements StateActionTable {
|
|||
}
|
||||
|
||||
@Override
|
||||
public Map<Action, Double> getActionValues(State state) {
|
||||
return null;
|
||||
public Map<A, Double> getActionValues(State state) {
|
||||
Map<A, Double> actionValues = table.get(state);
|
||||
if(actionValues == null){
|
||||
actionValues = createDefaultActionValues();
|
||||
}
|
||||
return actionValues;
|
||||
}
|
||||
|
||||
private Map<Action, Double> createDefaultActionValues(){
|
||||
final Map<Action, Double> defaultActionValues = new HashMap<>();
|
||||
// for(Action action: actionSpace.getAllActions()){
|
||||
// defaultActionValues.put(action, DEFAULT_VALUE);
|
||||
//}
|
||||
public static void main(String[] args) {
|
||||
DiscreteActionSpace<AntAction> da = new ListDiscreteActionSpace<>(AntAction.MOVE_RIGHT, AntAction.PICK_UP);
|
||||
StateActionTable sat = new StateActionHashTable<>(da);
|
||||
State t = new State() {
|
||||
};
|
||||
|
||||
System.out.println(sat.getActionValues(t));
|
||||
}
|
||||
private Map<A, Double> createDefaultActionValues(){
|
||||
final Map<A, Double> defaultActionValues = new HashMap<>();
|
||||
for(A action: discreteActionSpace.getAllActions()){
|
||||
defaultActionValues.put(action, DEFAULT_VALUE);
|
||||
}
|
||||
return defaultActionValues;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,11 +2,11 @@ package core;
|
|||
|
||||
import java.util.Map;
|
||||
|
||||
public interface StateActionTable {
|
||||
public interface StateActionTable<A extends Enum> {
|
||||
double DEFAULT_VALUE = 0.0;
|
||||
|
||||
double getValue(State state, Action action);
|
||||
void setValue(State state, Action action, double value);
|
||||
double getValue(State state, A action);
|
||||
void setValue(State state, A action, double value);
|
||||
|
||||
Map<Action, Double> getActionValues(State state);
|
||||
Map<A, Double> getActionValues(State state);
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ import lombok.Setter;
|
|||
@Setter
|
||||
@AllArgsConstructor
|
||||
public class StepResult {
|
||||
private State observation;
|
||||
private State state;
|
||||
private double reward;
|
||||
private boolean done;
|
||||
private String info;
|
||||
|
|
|
@ -3,10 +3,11 @@ package evironment.antGame;
|
|||
import core.*;
|
||||
import evironment.antGame.gui.MainFrame;
|
||||
|
||||
|
||||
import javax.swing.*;
|
||||
import java.awt.*;
|
||||
|
||||
public class AntWorld {
|
||||
public class AntWorld implements Environment<AntAction>{
|
||||
/**
|
||||
*
|
||||
*/
|
||||
|
@ -53,7 +54,8 @@ public class AntWorld {
|
|||
this(Constants.DEFAULT_GRID_WIDTH, Constants.DEFAULT_GRID_HEIGHT, Constants.DEFAULT_FOOD_DENSITY);
|
||||
}
|
||||
|
||||
public StepResult step(DiscreteAction<AntAction> action){
|
||||
@Override
|
||||
public StepResult step(AntAction action){
|
||||
AntObservation observation;
|
||||
State newState;
|
||||
double reward = 0;
|
||||
|
@ -77,7 +79,7 @@ public class AntWorld {
|
|||
// on the starting position
|
||||
boolean checkCompletion = false;
|
||||
|
||||
switch (action.getValue()) {
|
||||
switch (action) {
|
||||
case MOVE_UP:
|
||||
potentialNextPos.y -= 1;
|
||||
stayOnCell = false;
|
||||
|
@ -208,13 +210,13 @@ public class AntWorld {
|
|||
public static void main(String[] args) {
|
||||
RNG.setSeed(1993);
|
||||
AntWorld a = new AntWorld(10, 10, 0.1);
|
||||
DiscreteActionSpace<AntAction> actionSpace = new DiscreteActionSpace<>();
|
||||
actionSpace.addActions(AntAction.values());
|
||||
ListDiscreteActionSpace<AntAction> actionSpace =
|
||||
new ListDiscreteActionSpace<>(AntAction.MOVE_LEFT, AntAction.MOVE_RIGHT);
|
||||
|
||||
for(int i = 0; i< 1000; ++i){
|
||||
DiscreteAction<AntAction> selectedAction = actionSpace.getAllDiscreteActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
|
||||
AntAction selectedAction = actionSpace.getAllActions().get(RNG.getRandom().nextInt(actionSpace.getNumberOfAction()));
|
||||
StepResult step = a.step(selectedAction);
|
||||
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction.getValue(), step));
|
||||
SwingUtilities.invokeLater(()-> a.getGui().update(selectedAction, step));
|
||||
try {
|
||||
Thread.sleep(100);
|
||||
} catch (InterruptedException e) {
|
||||
|
|
Loading…
Reference in New Issue