import copy

from ExtensiveGame import ExtensiveGame
import numpy as np


def deal_cards(state):
    """Pick one information set among ``state``'s children uniformly at random.

    The distinct ``i_set`` values of the children are collected in order of
    first appearance; one of them is drawn with ``np.random.choice``.

    Returns:
        A pair ``(matching_children, chosen_index)`` where ``matching_children``
        are the children whose ``i_set`` equals the drawn one and
        ``chosen_index`` is its position in first-appearance order.
    """
    candidates = state.children
    distinct_sets = []
    for node in candidates:
        if node.i_set in distinct_sets:
            continue
        distinct_sets.append(node.i_set)
    chosen = np.random.choice(list(range(len(distinct_sets))))
    target = distinct_sets[chosen]
    matching = [node for node in candidates if node.i_set == target]
    return matching, chosen


def normalize_ranges(ranges):
    """Rescale ``ranges`` in place so that its entries sum to one.

    If the entries sum to zero, every entry is set to ``1 / len(ranges)``
    instead of dividing by zero. Nothing is returned.
    """
    total = np.sum(ranges)
    size = len(ranges)
    degenerate = total == 0
    for position in range(size):
        if degenerate:
            ranges[position] = 1. / size
        else:
            ranges[position] /= total


def win_probability_from_ranges(win_probabilities, ranges):
    """Return the range-weighted win probability.

    Computes the dot product of the per-state win probabilities with the
    weights in ``ranges``.
    """
    return np.asarray(win_probabilities).dot(ranges)


def vector_win_probability(states):
    """Return a list with ``state_win_probability`` evaluated for each state, in order."""
    return [state_win_probability(node) for node in states]


def state_win_probability(state):
    """Estimate the probability that the hand ends with a positive value.

    The estimate is a "check-down" rollout:

    * terminal nodes score 1 for a positive value, 0 for a negative value
      and 0.5 for a tie;
    * chance nodes (``player == 2``) average their children uniformly;
    * at decision nodes the ``"c"`` action is always taken, so betting never
      changes the outcome, only the cards do.

    Raises:
        ValueError: if a decision node offers no ``"c"`` action to follow.
    """
    if state.is_terminal():
        if state.value > 0:
            return 1
        if state.value < 0:
            return 0
        return 0.5
    if state.player == 2:
        total = 0
        for child in state.children:
            total += state_win_probability(child)
        return total / len(state.children)
    call_index = next(
        (index for index in range(len(state.children)) if state.labels[index] == "c"),
        None,
    )
    if call_index is None:
        # Previously this surfaced as an opaque TypeError from children[None].
        raise ValueError("Decision node has no 'c' action to roll out.")
    return state_win_probability(state.children[call_index])


def leduc_local_best_response_step(states, ranges, cards, strategy, last_action, pot, play_round, win_probabilities):
    """Advance one sampled step of a local-best-response (LBR) playout.

    All of ``states``, ``ranges``, ``cards``, ``last_action``, ``pot`` and
    ``play_round`` are mutated in place:

    * ``states`` -- game nodes advanced in lockstep, presumably one per
      possible opponent card (TODO confirm against the .efg layout);
    * ``ranges`` -- posterior weight of each entry of ``states``; must sum to 1;
    * ``cards`` -- ``cards[0]`` is the responder's card index, ``cards[1]``
      receives the table card when it is dealt;
    * ``last_action`` -- one-element list set to the label just played, or to
      the dealt table card as a string;
    * ``pot`` -- ``pot[0]`` / ``pot[1]`` are the amounts committed by player 0
      and player 1;
    * ``play_round`` -- one-element list, incremented when the table card is
      dealt; the bet size is ``2 * play_round[0]``.

    ``states[0].player`` selects the step: 2 deals the table card, 1 samples
    an action from the fixed ``strategy``, 0 plays the LBR action.

    Raises:
        AssertionError: if ``ranges`` is not normalized.
        ValueError: if ``states[0].player`` is not 0, 1 or 2.
    """
    assert np.abs(np.sum(ranges) - 1) < 10 ** -6, np.sum(ranges)
    player = states[0].player
    if player == 2:
        _deal_table_card(states, ranges, cards, last_action, play_round)
    elif player == 1:
        _play_opponent_action(states, ranges, strategy, last_action, pot, play_round)
    elif player == 0:
        _play_best_response_action(states, ranges, strategy, last_action, pot, play_round, win_probabilities)
    else:
        raise ValueError("Impossible player value.")


def _last_action_index(state, label):
    """Return the last child index of ``state`` whose label equals ``label`` (None if absent)."""
    found = None
    for index in range(len(state.children)):
        if state.labels[index] == label:
            found = index
    return found


def _deal_table_card(states, ranges, cards, last_action, play_round):
    """Sample the table card, move every state past the chance node and update ``ranges``."""
    player_card = cards[0]
    n_cards = len(states)
    # The table card cannot be the opponent's card. Because the playout
    # marginalises over the opponent's card via ``ranges``, the table card's
    # predictive probability is (1 - r[t]) / (n_cards - 1), not uniform.
    # Clipping guards against tiny negative weights from rounding.
    weights = np.clip(1. - np.asarray(ranges, dtype=float), 0., None)
    table_card = np.random.choice(n_cards, p=weights / np.sum(weights))
    for index in range(n_cards):
        if table_card == index:
            # Impossible combination: its range is zeroed below, so any child will do.
            states[index] = states[index].children[0]
        else:
            # The chance node's children skip this state's opponent card.
            states[index] = states[index].children[table_card if table_card < index else table_card - 1]
    ranges[table_card] = 0
    # Convert from "index among cards other than the player's" to an absolute card index.
    if table_card >= player_card:
        table_card += 1
    cards[1] = table_card
    normalize_ranges(ranges)
    last_action[0] = str(table_card)
    play_round[0] += 1


def _play_opponent_action(states, ranges, strategy, last_action, pot, play_round):
    """Sample player 1's action from its range-weighted strategy and condition ``ranges`` on it."""
    sample_state = states[0]
    n_actions = len(sample_state.children)
    # Marginal probability of each action, averaged over the possible states.
    state_strategy = [
        sum(ranges[index] * strategy[states[index].i_set][action] for index in range(len(states)))
        for action in range(n_actions)
    ]
    action = np.random.choice(list(range(n_actions)), p=state_strategy)
    label = sample_state.labels[action]
    if label == "b":
        pot[1] = 2 * play_round[0] + pot[0]
    if label == "c":
        pot[1] = pot[0]
    last_action[0] = label
    for index in range(len(states)):
        # Bayes update: weight each state by the probability it takes the sampled action.
        ranges[index] *= strategy[states[index].i_set][action]
        states[index] = states[index].children[action]
    normalize_ranges(ranges)


def _play_best_response_action(states, ranges, strategy, last_action, pot, play_round, win_probabilities):
    """Choose player 0's LBR action (bet, call or fold) and advance every state with it."""
    sample_state = states[0]
    labels = sample_state.labels
    bet_size = 2 * play_round[0]
    wp = win_probability_from_ranges(win_probabilities, ranges)
    asked = pot[1] - pot[0]
    full_pot = pot[1] + pot[0]
    # Values are measured relative to folding right now (fold == 0).
    call_value = wp * full_pot - (1 - wp) * asked
    bet_value = -100
    if "b" in labels:
        bet_action = _last_action_index(sample_state, "b")
        fold_probability = 0
        call_ranges = copy.deepcopy(ranges)
        for index in range(len(states)):
            child = states[index].children[bet_action]
            assert child.labels[0] == "f"
            fold_probability += ranges[index] * strategy[child.i_set][0]
            call_ranges[index] *= (1 - strategy[child.i_set][0])
        # Condition on the opponent calling. Without renormalising, the dot
        # product would be P(win and call) rather than P(win | call), although
        # it is already multiplied by (1 - fold_probability) below.
        normalize_ranges(call_ranges)
        bet_wp = win_probability_from_ranges(win_probabilities, call_ranges)
        bet_value = fold_probability * full_pot + (1 - fold_probability) * (
            bet_wp * (full_pot + bet_size) - (1 - bet_wp) * (asked + bet_size))
    if bet_value > call_value:
        action_text = "b" if bet_value > 0 else "f"
    else:
        action_text = "c" if call_value > 0 else "f"
    if action_text not in labels:
        # Fall back to "c" when the preferred action is not offered at this node.
        action_text = "c"
    action = _last_action_index(sample_state, action_text)
    last_action[0] = labels[action]
    if labels[action] == "b":
        pot[0] = bet_size + pot[1]
    if labels[action] == "c":
        pot[0] = pot[1]
    for index in range(len(states)):
        states[index] = states[index].children[action]


def leduc_local_best_response(strategy):
    """Play one sampled hand with player 0 as a local best responder against ``strategy``.

    Loads the game tree from ``data/leduc_holdem.efg`` and deals player 0's
    card with ``deal_cards``. It then repeatedly calls
    ``leduc_local_best_response_step`` until the tracked states are terminal.

    Args:
        strategy: mapping from an information set (``state.i_set``) to the
            per-action probabilities used for player 1.

    Returns:
        The sum of the terminal values weighted by the final ``ranges``.
    """
    leduc = ExtensiveGame()
    leduc.load("data/leduc_holdem.efg")

    states, player_card = deal_cards(leduc.root)
    cards = [player_card, None]
    ranges = [1. / len(states)] * len(states)
    last_action = [str(player_card)]
    pot = [1, 1]  # each player starts with 1 committed
    play_round = [1]
    win_probabilities = vector_win_probability(states)
    while not states[0].is_terminal():
        leduc_local_best_response_step(states, ranges, cards, strategy, last_action, pot, play_round, win_probabilities)
        # The check-down win probabilities only change when a card is dealt.
        # ``last_action`` is a one-element list, so its element must be compared.
        if last_action[0] not in ("f", "b", "c"):
            win_probabilities = vector_win_probability(states)
    won = 0
    for state, weight in zip(states, ranges):
        assert state.value == 0 or state.value == pot[1] or state.value == -pot[0]
        won += state.value * weight
    return won
