import numpy as np
import datetime

def getStateUniform():
    """Draw a random starting hand description.

    A target total is drawn from 12..21 and a coin flip decides whether the
    hand holds a usable ace.  The total is always drawn before the coin flip.

    Returns:
        ``[1, total - 11]`` when the ace is usable (the ace counts as 11),
        otherwise the single-element list ``[total]``.
    """
    total = np.random.randint(12, 22)
    if bool(np.random.randint(0, 2)):
        # Ace counts as 11, so the second card supplies the remainder.
        return [1, total - 11]
    return [total]

class Player(object):
    """The player's hand plus the dealer's visible card.

    Cards are stored as numbers 1-10, where 1 is an ace.
    """

    def __init__(self, dealersCard):
        self.dealersCard = dealersCard
        self.cards = []  # each card is a number 1-10

    def CheckAceUsable(self):
        """Return True when the hand holds an ace that can count as 11."""
        # With every ace counted as 1, promoting one ace adds 10 points, so
        # that is only possible while the low total is at most 11.
        return 1 in self.cards and sum(self.cards) <= 11

    def GetValue(self):
        """Return the hand value, counting a usable ace as 11."""
        lowTotal = sum(self.cards)  # every ace counted as 1
        if 1 in self.cards and lowTotal <= 11:
            return lowTotal + 10
        return lowTotal

    def Bust(self):
        """Return True when the hand value exceeds 21."""
        return self.GetValue() > 21

    def AddCard(self, card):
        """Append one card to the hand."""
        self.cards.append(card)

    def GetState(self):
        """Return the state tuple (hand value, usable ace, dealer's card)."""
        return (self.GetValue(), self.CheckAceUsable(), self.dealersCard)

    def ShouldHit(self, policyMap):
        """Look up the action for the current state; True means hit."""
        return policyMap[self.GetState()]

class Dealer(object):
    """The dealer's hand, played with a fixed rule: hit below 17.

    Cards are numbers where 1 is an ace.  The list passed to the constructor
    is stored as-is (not copied), so cards added later are visible through
    the caller's list as well.
    """

    def __init__(self, cards):
        self.cards = cards

    def AddCard(self, card):
        """Append one card to the hand."""
        self.cards.append(card)

    def GetValue(self):
        """Return the hand value, counting one ace as 11 when that fits."""
        total = sum(self.cards)  # every ace counted as 1
        holdsAce = any(card == 1 for card in self.cards)
        # Promoting an ace adds 10, which is only possible from 11 or less.
        if holdsAce and total <= 11:
            total += 10
        return total

    def Bust(self):
        """Return True when the hand value exceeds 21."""
        return self.GetValue() > 21

    def ShouldHit(self):
        """Return True while the hand value is below 17."""
        return self.GetValue() < 17

class StateActionInfo(object):
    """Ordered collection of the distinct (state, action) pairs of an episode.

    ``stateActionPairs`` keeps the pairs in first-seen order and
    ``stateActionMap`` is a set of the same pairs used for duplicate checks.
    """

    def __init__(self):
        self.stateActionPairs = []
        self.stateActionMap = set()

    def AddPair(self, pair):
        """Record ``pair`` unless it has already been recorded."""
        if pair not in self.stateActionMap:
            self.stateActionMap.add(pair)
            self.stateActionPairs.append(pair)

def EvaluateAndImprovePolicy(actionValueMap, policyMap, returns, stateActionPairs, reward, episode):
    AbsQUpdateErr[seed].append(0)
    for pair in stateActionPairs:
        returns[pair] += 1
        AbsQUpdateErr[seed][episode] += abs(((reward - actionValueMap[pair]) / returns[pair]))
        actionValueMap[pair] = actionValueMap[pair] + ((reward - actionValueMap[pair]) / returns[pair])

        state = pair[0]
        shouldHit = False

        if actionValueMap[(state, True)] >= actionValueMap[(state, False)]:
            shouldHit = True

        policyMap[state] = shouldHit
    AbsQUpdateErr[seed][episode] /=len(stateActionPairs)

def GenerateCard():
    """Draw one of 13 equally likely ranks; every rank above 9 is returned as 10."""
    return min(np.random.randint(1, 14), 10)

def GetNewCard():
    """Return a random card from 13 equally likely ranks; J, Q, K come back as 10."""
    rank = np.random.randint(1, 14)
    return rank if rank <= 10 else 10

def GenerateEpisode(actionValueMap, policyMap, returns, episode, uniform_initial = False, initial_update = False):
    """Play one blackjack episode and update Q and the policy from its outcome.

    The first action is chosen at random; later player decisions follow
    ``policyMap``.  The dealer hits until ``Dealer.ShouldHit`` is False.  The
    final reward is +1 (win), 0 (draw) or -1 (loss or player bust), and it is
    handed to ``EvaluateAndImprovePolicy`` together with the visited pairs.

    Args:
        actionValueMap: dict of Q values, forwarded to the update step.
        policyMap: dict mapping state to True (hit) or False (stick).
        returns: dict of visit counts, forwarded to the update step.
        episode: episode index, forwarded to the update step.
        uniform_initial: when true, the starting hand is built from a total
            drawn from 12..21 and a coin flip for a usable ace; otherwise two
            cards are dealt and more are drawn until the value is at least 11.
        initial_update: when true, only the initial (state, action) pair is
            recorded for the update; otherwise every distinct pair is recorded.
    """
    # Both start modes begin by drawing the dealer's visible card.
    dealersCard1 = GetNewCard()
    dealer = Dealer([dealersCard1])
    player = Player(dealersCard1)

    if uniform_initial:
        targetValue = np.random.randint(12, 22)
        if bool(np.random.randint(0, 2)):
            # Usable ace: the ace counts as 11 and the second card fills the rest.
            player.AddCard(1)
            player.AddCard(targetValue - 11)
        else:
            # No usable ace: three cards of about a third each, none of which
            # can be an ace because targetValue // 3 is at least 4.
            third = targetValue // 3
            player.AddCard(third)
            player.AddCard(third)
            player.AddCard(targetValue - third * 2)
    else:
        player.AddCard(GetNewCard())
        player.AddCard(GetNewCard())
        # Keep drawing until the hand is worth at least 11.
        while player.GetValue() < 11:
            player.AddCard(GetNewCard())

    visited = StateActionInfo()  # distinct (state, action) pairs of this episode
    firstActionIsHit = bool(np.random.randint(0, 2))  # random initial action
    visited.AddPair((player.GetState(), firstActionIsHit))

    if firstActionIsHit:
        player.AddCard(GenerateCard())
        while not player.Bust() and player.ShouldHit(policyMap):
            if not initial_update:
                visited.AddPair((player.GetState(), True))
            player.AddCard(GenerateCard())

    if player.Bust():
        reward = -1
    else:
        if not initial_update:
            visited.AddPair((player.GetState(), False))

        dealer.AddCard(GenerateCard())  # dealer gets 2nd card (mandatory by rule)
        while not dealer.Bust() and dealer.ShouldHit():
            dealer.AddCard(GenerateCard())

        dealerValue = dealer.GetValue()
        playerValue = player.GetValue()
        if dealerValue > 21 or dealerValue < playerValue:
            reward = 1
        elif dealerValue > playerValue:
            reward = -1
        else:
            reward = 0

    EvaluateAndImprovePolicy(actionValueMap, policyMap, returns, visited.stateActionPairs, reward, episode)

def EvaluatePerformance(policyMap, game_number = 10000):
    """Play ``game_number`` games with ``policyMap`` and return the mean reward.

    Each game deals the dealer one card and the player two, draws for the
    player until the hand is worth at least 11, then follows ``policyMap``.
    A player bust scores -1; otherwise the dealer draws while
    ``Dealer.ShouldHit`` is True and the game scores +1, 0 or -1.

    Args:
        policyMap: dict mapping state to True (hit) or False (stick).
        game_number: number of games to play.

    Returns:
        Sum of the per-game rewards divided by ``game_number``.
    """
    totalReward = 0
    for _ in range(game_number):
        dealersCard1 = GetNewCard()
        dealer_test = Dealer([dealersCard1])
        player_test = Player(dealersCard1)
        player_test.AddCard(GetNewCard())
        player_test.AddCard(GetNewCard())
        # Keep drawing until the hand is worth at least 11.
        while player_test.GetValue() < 11:
            player_test.AddCard(GetNewCard())

        while not player_test.Bust() and player_test.ShouldHit(policyMap):
            player_test.AddCard(GenerateCard())

        if player_test.Bust():
            totalReward -= 1
            continue

        while not dealer_test.Bust() and dealer_test.ShouldHit():
            dealer_test.AddCard(GenerateCard())

        dealerValue = dealer_test.GetValue()
        playerValue = player_test.GetValue()
        if dealerValue > 21 or dealerValue < playerValue:
            totalReward += 1
        elif dealerValue > playerValue:
            totalReward -= 1
        # a draw contributes 0

    return totalReward / game_number

def PerformMonteCarloES(seed, uniform_initial = False, multi_update = True, n_episode = int(1e7)):
    """Run Monte Carlo control with random initial actions and print the Q table.

    Q values and visit counts start at 0 for every state
    (player sum 11..21, usable ace or not, dealer card 1..10) and both
    actions.  The initial policy hits on every sum below 20 and sticks on 20
    and 21.  Each episode is produced and learned from by ``GenerateEpisode``.

    During the first 1e6 episodes, after every 10th episode the current
    policy is scored with ``EvaluatePerformance(policyMap, 1000)`` and the
    score is appended to the module-level ``multiseed_performance[seed]``.
    The episode index is printed after every 100000th episode.  At the end a
    table of ``sum,dealer:Q(hit)/Q(stick)`` is printed for both ace cases.

    Args:
        seed: index of the current run in the module-level logs.  It is not
            used to seed the random generator here.
        uniform_initial: forwarded to ``GenerateEpisode`` to select the
            uniform starting-hand scheme.
        multi_update: when False, ``GenerateEpisode`` is told to record only
            the initial (state, action) pair of each episode.
        n_episode: number of training episodes; defaults to the previously
            hard-coded 1e7.
    """
    actionValueMap = { }
    policyMap = { } # map playerState to True or False, True is hit
    returns = { } # visit count per (state, action)

    for usableAce in (False, True):
        for playerSum in range(11, 22):
            for dealersCard in range(1, 11):
                playerState = (playerSum, usableAce, dealersCard)
                for action in (False, True):
                    actionValueMap[(playerState, action)] = 0 # Q value
                    returns[(playerState, action)] = 0
                # initial policy: stick only on 20 and 21
                policyMap[playerState] = playerSum < 20

    for i in range(n_episode):
        # runs one episode and updates Q and the policy from it
        GenerateEpisode(actionValueMap, policyMap, returns, i, uniform_initial, not multi_update)
        if (i+1) % 10 == 0 and i < int(1e6):
            score = EvaluatePerformance(policyMap, 1000)
            multiseed_performance[seed].append(score)
        if (i+1) % 100000 == 0:
            print(i)

    for usableAce in range(2):
        print("usable:", usableAce)
        for playerSum in range(11, 22):
            cells = []
            for dealersCard in range(1, 11):
                playerState = (playerSum, bool(usableAce), dealersCard)
                cells.append('%d,%d:%.3f/%.3f\t' % (playerSum, dealersCard, actionValueMap[(playerState, True)], actionValueMap[(playerState, False)]))
            print(''.join(cells))
        print()

# Experiment grid: every combination of update scheme and initialization.
#   multi-update (S & B)             -> multiple_update = True
#   first-update (Tsitsiklis (MIT))  -> multiple_update = False
#   uniform / standard initialization -> uniform_initialization = True / False
# Each combination is run for seeds 0..4 and written to two files.

date = datetime.date.today()
for uniform_initialization in (False, True):
    for multiple_update in (False, True):
        # Fresh logs for this configuration, one list per seed.  They are
        # module-level names because the training functions read them directly.
        multiseed_performance = [[] for _ in range(5)]
        AbsQUpdateErr = [[] for _ in range(5)]

        for seed in range(5):
            np.random.seed(seed)
            # arguments: seed, whether uniform initial, whether multi-update
            PerformMonteCarloES(seed, uniform_initialization, multiple_update)

        # save files; tofile writes the values flattened and comma-separated
        runTag = (uniform_initialization, multiple_update, date)

        performanceFile = np.array(multiseed_performance)
        performanceFile.tofile('Performance_UI_%s_M_%s_%s.csv' % runTag, sep=',')

        QFile = np.array(AbsQUpdateErr)
        QFile.tofile('AbsQUpdateErr_UI_%s_M_%s_%s.csv' % runTag, sep=',')

