add BlackJack environment and fix save and load

- the saveState and loadState method names were swapped
This commit is contained in:
Jan Löwenstrom 2020-03-01 13:51:47 +01:00
parent cff1a4e531
commit 18a702ba62
12 changed files with 337 additions and 2 deletions

View File

@ -109,7 +109,7 @@ public class RLController<A extends Enum> implements LearningListener {
}
}
protected void saveState(String fileName) {
protected void loadState(String fileName) {
FileInputStream fis;
ObjectInputStream in;
try {
@ -124,7 +124,7 @@ public class RLController<A extends Enum> implements LearningListener {
}
}
protected void loadState(String fileName) {
protected void saveState(String fileName) {
FileOutputStream fos;
ObjectOutputStream out;
try {

View File

@ -0,0 +1,124 @@
package evironment.blackjack;
import core.Environment;
import core.State;
import core.StepResultEnvironment;
import core.gui.Visualizable;
import evironment.blackjack.cards.CardDeck;
import evironment.blackjack.cards.Rank;
import evironment.blackjack.gui.BlackJackTableComponent;
import lombok.Getter;
import javax.swing.*;
/**
 * Blackjack environment in which a single player faces a dealer that follows a fixed strategy.
 * <p>
 * An episode is started with {@link #reset()} and advanced with {@link #step(PlayerAction)}.
 * Cards are drawn from a {@link CardDeck}; hand totals are kept in {@link Player} objects where
 * an ace is always added as 1 and a separate "usable ace" flag marks that 10 more points may be
 * counted without exceeding 21.
 */
public class BlackJackTable implements Environment<PlayerAction>, Visualizable {
    // highest hand total that is not bust
    private static final int MAX_SUM = 21;
    // additional points when an ace is counted as 11 instead of 1
    private static final int USABLE_ACE_BONUS = 10;

    private final CardDeck cardDeck;
    @Getter
    private final Player player;
    private final Player dealer;
    // dealer total at the start of the episode (one card); this is what the state exposes
    private int dealerSumShowing;
    @Getter
    private int playerSum;
    // current dealer total, updated while the dealer plays out its hand
    @Getter
    private int dealerSum;
    private final BlackJackTableComponent comp;

    /**
     * Creates the table with a fresh deck, empty hands and the GUI component that renders it.
     */
    public BlackJackTable() {
        cardDeck = new CardDeck();
        player = new Player(0, false);
        dealer = new Player(0, false);
        comp = new BlackJackTableComponent(this);
    }

    /**
     * Applies the player's action.
     * <ul>
     * <li>HIT: the player draws a card; exceeding 21 ends the episode with reward -1.</li>
     * <li>STICK: the episode ends, the dealer plays out its hand and the reward is decided
     * by comparing both totals.</li>
     * </ul>
     * Any other value (including {@code null}) leaves the table untouched and yields reward 0
     * with the episode still running.
     *
     * @param action action chosen by the player
     * @return resulting state (player sum, dealer's initially shown sum, usable ace flag),
     * reward and done flag
     */
    @Override
    public StepResultEnvironment step(PlayerAction action) {
        boolean done = false;
        int reward = 0;

        if(action == PlayerAction.HIT) {
            obtainCard(player);
            playerSum = calculateSum(player);
            // bust
            if(playerSum > MAX_SUM) {
                done = true;
                reward = -1;
            }
        } else if(action == PlayerAction.STICK) {
            done = true;
            reward = playOutDealerHand();
        }

        return new StepResultEnvironment(new TableState(playerSum, dealerSumShowing, player.isUsableAce()), reward, done, "");
    }

    /**
     * Lets the dealer draw one card and then keep drawing while its sum is below
     * {@link Config#DEALER_HOLD_BOUND}. Only {@code dealerSum} is updated;
     * {@code dealerSumShowing} is deliberately left as it was at reset because it is part
     * of the state handed to the learner.
     *
     * @return +1 if the dealer went bust, otherwise the result of comparing the player's sum
     * with the dealer's sum (positive: player is closer to 21, zero: draw, negative: dealer
     * is closer)
     */
    private int playOutDealerHand() {
        obtainCard(dealer);
        dealerSum = calculateSum(dealer);

        while(dealerSum < Config.DEALER_HOLD_BOUND) {
            obtainCard(dealer);
            dealerSum = calculateSum(dealer);
            comp.repaint();
        }

        // dealer went bust, player wins
        if(dealerSum > MAX_SUM) {
            return 1;
        }
        // Neither side is bust here, so "closer to 21" is the same as "higher sum".
        // Equal sums (including 21 vs. 21) compare to 0, i.e. a draw.
        return Integer.compare(playerSum, dealerSum);
    }

    /**
     * Starts a new episode: both hands are cleared, the player receives two cards and the
     * dealer one.
     *
     * @return initial state; the dealer value is the dealer's sum after its single card,
     * which is 11 when that card is an ace because the ace is then usable
     */
    @Override
    public State reset() {
        player.setHandValue(0);
        player.setUsableAce(false);
        dealer.setHandValue(0);
        dealer.setUsableAce(false);

        // player gets two cards
        obtainCard(player);
        obtainCard(player);
        playerSum = calculateSum(player);

        // dealer is only showing one card
        obtainCard(dealer);
        dealerSum = calculateSum(dealer);
        dealerSumShowing = dealerSum;

        return new TableState(playerSum, dealerSumShowing, player.isUsableAce());
    }

    /**
     * @param p hand to evaluate
     * @return the hand value, plus 10 if the hand holds a usable ace
     */
    private int calculateSum(Player p) {
        int sum = p.getHandValue();
        if(p.isUsableAce()) {
            sum += USABLE_ACE_BONUS;
        }
        return sum;
    }

    /**
     * Draws the next card from the deck and adds it to the given hand, keeping the usable
     * ace flag consistent with the new hand value.
     *
     * @param p hand that receives the card
     */
    private void obtainCard(Player p) {
        Rank rank = cardDeck.nextCard().getRank();

        // an ace becomes usable if counting it as 11 does not exceed 21
        if(rank == Rank.ACE && p.getHandValue() + 1 + USABLE_ACE_BONUS <= MAX_SUM) {
            p.setUsableAce(true);
        }
        p.addValue(cardValue(rank));

        // a previously usable ace has to be counted as 1 once the bonus would bust the hand
        if(p.isUsableAce() && p.getHandValue() + USABLE_ACE_BONUS > MAX_SUM) {
            p.setUsableAce(false);
        }
    }

    /**
     * Maps a rank to the points it adds to a hand. The mapping is spelled out explicitly so
     * that it does not depend on the declaration order of {@link Rank}.
     *
     * @param rank rank of the drawn card
     * @return 2-10 for number cards, 10 for face cards and 1 for an ace
     */
    private static int cardValue(Rank rank) {
        switch(rank) {
            case TWO:
                return 2;
            case THREE:
                return 3;
            case FOUR:
                return 4;
            case FIVE:
                return 5;
            case SIX:
                return 6;
            case SEVEN:
                return 7;
            case EIGHT:
                return 8;
            case NINE:
                return 9;
            case TEN:
            case JACK:
            case QUEEN:
            case KING:
                return 10;
            case ACE:
                return 1;
            default:
                throw new IllegalStateException("Unknown rank: " + rank);
        }
    }

    /**
     * @return the Swing component that renders this table
     */
    @Override
    public JComponent visualize() {
        return comp;
    }
}

View File

@ -0,0 +1,7 @@
package evironment.blackjack;
/**
 * Constants shared by the blackjack environment and its GUI component.
 * Not meant to be instantiated or extended.
 */
public final class Config {
    /** The dealer keeps drawing cards while its sum is below this bound. */
    public static final int DEALER_HOLD_BOUND = 17;
    /** Preferred width of the table's Swing component. */
    public static final int COMPONENT_WIDTH = 400;
    /** Preferred height of the table's Swing component. */
    public static final int COMPONENT_HEIGHT = 400;

    private Config() {
        // constants holder, no instances
    }
}

View File

@ -0,0 +1,17 @@
package evironment.blackjack;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
/**
 * Mutable hand of a blackjack participant.
 * Accessors, mutators and the {@code (handValue, usableAce)} constructor are generated by Lombok.
 */
@Getter
@Setter
@AllArgsConstructor
public class Player {
    // accumulated points of all cards added so far
    private int handValue;
    // whether the hand currently holds an ace that is flagged as usable
    private boolean usableAce;

    /**
     * Increases the hand value.
     *
     * @param value points to add to the current hand value
     */
    public void addValue(int value) {
        this.handValue = this.handValue + value;
    }
}

View File

@ -0,0 +1,6 @@
package evironment.blackjack;
/**
 * Actions the player can choose from on each turn.
 */
public enum PlayerAction {
    /** Draw another card. */
    HIT,
    /** Take no more cards. */
    STICK;
}

View File

@ -0,0 +1,46 @@
package evironment.blackjack;
import core.State;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import java.io.Serializable;
import java.util.Objects;
/**
 * State of the blackjack table as observed by the player.
 * Two states are equal if all three components match; {@link #hashCode()} is consistent
 * with {@link #equals(Object)}.
 * <p>
 * NOTE(review): Lombok generates setters for every field. Changing a field while the
 * instance is used as a key in a hash-based collection would break lookups - confirm
 * that the setters are really needed.
 */
@AllArgsConstructor
@Getter
@Setter
public class TableState implements State, Serializable {
    // current total of the player's hand (a usable ace already counted as 11)
    private int playerSum;
    // total the dealer was showing when the episode started (one card);
    // BlackJackTable reports an ace here as 11
    private int dealerCardValue;
    // if the player holds an ace that can count as 11 without going bust
    private boolean usableAce;

    /**
     * @return human readable form listing all three components by field name
     */
    @Override
    public String toString() {
        return "TableState{" +
                "playerSum=" + playerSum +
                ", dealerCardValue=" + dealerCardValue +
                ", usableAce=" + usableAce +
                '}';
    }

    @Override
    public boolean equals(Object o) {
        if(this == o) return true;
        if(!(o instanceof TableState)) return false;
        TableState that = (TableState) o;
        return playerSum == that.playerSum &&
                usableAce == that.usableAce &&
                dealerCardValue == that.dealerCardValue;
    }

    @Override
    public int hashCode() {
        return Objects.hash(playerSum, dealerCardValue, usableAce);
    }
}

View File

@ -0,0 +1,25 @@
package evironment.blackjack.cards;
/**
 * Immutable playing card identified by its suit and rank.
 * Two cards are equal if both suit and rank match.
 */
public class Card {
    private final Suit suit;
    private final Rank rank;

    /**
     * @param suit suit of the card
     * @param rank rank of the card
     */
    public Card(Suit suit, Rank rank) {
        this.suit = suit;
        this.rank = rank;
    }

    public Suit getSuit() {
        return suit;
    }

    public Rank getRank() {
        return rank;
    }

    @Override
    public boolean equals(Object o) {
        if(this == o) return true;
        if(!(o instanceof Card)) return false;
        Card other = (Card) o;
        return rank == other.rank && suit == other.suit;
    }

    /**
     * Defined together with {@link #equals(Object)} so that equal cards produce equal
     * hash codes, which hash-based collections rely on.
     */
    @Override
    public int hashCode() {
        return java.util.Objects.hash(suit, rank);
    }
}

View File

@ -0,0 +1,34 @@
package evironment.blackjack.cards;
import core.RNG;
import java.util.ArrayList;
/**
 * Deck holding one card for every combination of suit and rank.
 */
public class CardDeck {
    private ArrayList<Card> cards;

    /**
     * Fills the deck suit by suit, each suit in rank declaration order.
     */
    public CardDeck() {
        Suit[] suits = Suit.values();
        Rank[] ranks = Rank.values();
        cards = new ArrayList<>(suits.length * ranks.length);
        for(Suit suit : suits) {
            for(Rank rank : ranks) {
                cards.add(new Card(suit, rank));
            }
        }
    }

    /**
     * Draws a card uniformly at random WITHOUT removing it, i.e. the deck behaves like an
     * infinite deck (drawing with replacement). Keeping track of the cards already dealt
     * therefore gives no advantage.
     *
     * @return next card
     */
    public Card nextCard() {
        // nextInt(bound) yields a value in [0, bound), so every index of the list is valid
        int index = RNG.getRandom().nextInt(cards.size());
        return cards.get(index);
    }
}

View File

@ -0,0 +1,17 @@
package evironment.blackjack.cards;
/**
 * Card ranks, declared from TWO up to ACE (so {@code ordinal()} runs from 0 to 12).
 */
public enum Rank {
    // number cards
    TWO, THREE, FOUR, FIVE, SIX, SEVEN, EIGHT, NINE, TEN,
    // face cards
    JACK, QUEEN, KING,
    // ace is declared last
    ACE
}

View File

@ -0,0 +1,8 @@
package evironment.blackjack.cards;
/**
 * The four suits of a card deck.
 */
public enum Suit {
    SPADES, HEARTS, DIAMONDS, CLUBS
}

View File

@ -0,0 +1,25 @@
package evironment.blackjack.gui;
import evironment.blackjack.BlackJackTable;
import evironment.blackjack.Config;
import javax.swing.*;
import java.awt.*;
/**
 * Swing component that renders the current sums of a {@link BlackJackTable} as text.
 */
public class BlackJackTableComponent extends JComponent {
    private BlackJackTable blackJackTable;

    /**
     * @param blackJackTable table whose sums are read on every repaint
     */
    public BlackJackTableComponent(BlackJackTable blackJackTable) {
        this.blackJackTable = blackJackTable;
        Dimension preferredSize = new Dimension(Config.COMPONENT_WIDTH, Config.COMPONENT_HEIGHT);
        setPreferredSize(preferredSize);
        setVisible(true);
    }

    /**
     * Draws the player's sum together with the usable ace flag, and the dealer's sum above it.
     */
    @Override
    protected void paintComponent(Graphics g) {
        super.paintComponent(g);
        g.setColor(Color.BLACK);

        // player line: "<sum> <usableAce>"
        String playerText = blackJackTable.getPlayerSum() + " " + blackJackTable.getPlayer().isUsableAce();
        g.drawString(playerText, 150, 300);

        // dealer line: "<sum>"
        String dealerText = String.valueOf(blackJackTable.getDealerSum());
        g.drawString(dealerText, 150, 150);
    }
}

View File

@ -0,0 +1,26 @@
package example;
import core.RNG;
import core.algo.Method;
import core.controller.RLController;
import core.controller.RLControllerGUI;
import evironment.blackjack.BlackJackTable;
import evironment.blackjack.PlayerAction;
/**
 * Example entry point that runs the GUI controller on the blackjack environment.
 */
public class BlackJack {
    /**
     * Seeds the RNG, builds the controller for {@link BlackJackTable} with
     * {@code Method.MC_CONTROL_EGREEDY}, configures it and starts it.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // fixed seed so that runs are repeatable
        RNG.setSeed(55);

        BlackJackTable table = new BlackJackTable();
        PlayerAction[] actions = PlayerAction.values();
        RLController<PlayerAction> controller = new RLControllerGUI<>(
                table,
                Method.MC_CONTROL_EGREEDY,
                actions);

        controller.setDelay(1000);
        controller.setDiscountFactor(1f);
        controller.setEpsilon(0.1f);
        controller.setLearningRate(0.5f);
        controller.setNrOfEpisodes(1000);
        controller.start();
    }
}