from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import random
import numpy as np
from scipy import stats
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from open_spiel.python import policy
import pyspiel

# One training sample for a per-player advantage network: the information
# state tensor, the solver iteration at which it was recorded, and the
# per-action sampled regrets (see `DeepCFRSolver.mb_traverse_game_tree`).
AdvantageMemory = collections.namedtuple("AdvantageMemory", "info_state iteration advantage")
# One training sample for the policy network: the information state tensor,
# the solver iteration at which it was recorded, and the regret-matching
# strategy that was in use at that state.
StrategyMemory = collections.namedtuple("StrategyMemory", "info_state iteration strategy_action_probs")


class SonnetLinear(nn.Module):
    """A Sonnet-style linear module.

    Always includes biases and only supports ReLU activations. Weights are
    drawn from a normal distribution with standard deviation
    `1 / sqrt(in_size)` truncated at two standard deviations; biases start
    at zero.
    """
    def __init__(self, in_size, out_size, activate_relu=True):
        """Creates a Sonnet linear layer.

        Args:
          in_size: (int) number of inputs
          out_size: (int) number of outputs
          activate_relu: (bool) whether to include a ReLU activation layer
        """
        super(SonnetLinear, self).__init__()
        self._activate_relu = activate_relu
        self._in_size = in_size
        self._out_size = out_size
        # Placeholders; `reset()` replaces them with registered parameters.
        self._weight = None
        self._bias = None
        self.reset()

    def forward(self, tensor):
        """Applies the affine map, followed by ReLU when it is enabled.

        Args:
          tensor: input whose last dimension has `in_size` entries.

        Returns:
          The layer output, with `out_size` entries in the last dimension.
        """
        y = F.linear(tensor, self._weight, self._bias)
        return F.relu(y) if self._activate_relu else y

    def reset(self):
        """Re-initialises the weights and biases with fresh parameters.

        When the layer already owns parameters (i.e. on every call after
        construction), the new parameters are created on the same device and
        with the same dtype as the ones they replace. Without this, a module
        that was moved with `.to(device)` or cast with `.double()` would
        silently fall back to float32 CPU parameters and no longer match its
        inputs.

        New `nn.Parameter` objects are created, so any optimizer built on the
        previous parameters must be rebuilt afterwards.
        """
        stddev = 1.0 / math.sqrt(self._in_size)
        mean = 0
        # Truncation bounds are expressed in units of standard deviations.
        lower = (-2 * stddev - mean) / stddev
        upper = (2 * stddev - mean) / stddev
        samples = stats.truncnorm.rvs(
            lower,
            upper,
            loc=mean,
            scale=stddev,
            size=[self._out_size, self._in_size])
        weight = torch.Tensor(samples)

        previous = self._weight
        if previous is not None:
            # Keep the placement chosen by the owner of this module.
            weight = weight.to(device=previous.device, dtype=previous.dtype)
        bias = torch.zeros(
            [self._out_size], dtype=weight.dtype, device=weight.device)

        self._weight = nn.Parameter(weight)
        self._bias = nn.Parameter(bias)


class MLP(nn.Module):
    """A feed-forward network made of stacked `SonnetLinear` layers."""
    def __init__(self,
                 input_size,
                 hidden_sizes,
                 output_size,
                 activate_final=False):
        """Create the MLP.

        Args:
          input_size: (int) number of inputs
          hidden_sizes: (list) sizes (number of units) of each hidden layer
          output_size: (int) number of outputs
          activate_final: (bool) whether the final layer should include a ReLU
        """
        super(MLP, self).__init__()
        # widths[i] -> widths[i + 1] describes the i-th hidden layer.
        widths = [input_size] + list(hidden_sizes)
        self._layers = [
            SonnetLinear(in_size=fan_in, out_size=fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]
        # The output layer consumes the width of the last hidden layer (or the
        # raw input width when there are no hidden layers).
        output_layer = SonnetLinear(
            in_size=widths[-1],
            out_size=output_size,
            activate_relu=activate_final)
        self._layers.append(output_layer)

        # Registering the layers in a ModuleList exposes their parameters.
        self.model = nn.ModuleList(self._layers)

    def forward(self, x):
        """Feeds `x` through every layer in order and returns the result."""
        activation = x
        for module in self.model:
            activation = module(activation)
        return activation

    def reset(self):
        """Re-initialises every layer by calling its own `reset()`."""
        for module in self._layers:
            module.reset()


class ReservoirBuffer(object):
    """Fixed-capacity buffer that keeps a sample of everything added to it.

    Elements are appended until the capacity is reached; after that, each new
    element replaces a randomly chosen stored one with a probability that
    shrinks as more elements are offered (reservoir sampling).
    """
    def __init__(self, reservoir_buffer_capacity):
        """Creates an empty buffer.

        Args:
          reservoir_buffer_capacity: maximum number of elements kept.
        """
        self._reservoir_buffer_capacity = reservoir_buffer_capacity
        self._data = []
        self._add_calls = 0

    def add(self, element):
        """Potentially adds `element` to the reservoir buffer.

        Args:
          element: data to be added to the reservoir buffer.
        """
        capacity = self._reservoir_buffer_capacity
        if len(self._data) >= capacity:
            # Buffer is full: draw a slot among all elements seen so far
            # (including this one) and overwrite only if it falls inside.
            slot = np.random.randint(0, self._add_calls + 1)
            if slot < capacity:
                self._data[slot] = element
        else:
            self._data.append(element)
        self._add_calls += 1

    def sample(self, num_samples):
        """Returns `num_samples` distinct stored elements chosen at random.

        Args:
          num_samples: number of elements to draw.

        Returns:
          A list with the drawn elements.

        Raises:
          ValueError: if fewer than `num_samples` elements are stored.
        """
        available = len(self._data)
        if num_samples > available:
            raise ValueError("{} elements could not be sampled from size {}".format(
                num_samples, available))
        return random.sample(self._data, num_samples)

    def clear(self):
        """Drops all stored elements and resets the add counter."""
        self._add_calls = 0
        self._data = []

    def __len__(self):
        """Number of elements currently stored."""
        return len(self._data)

    def __iter__(self):
        """Iterates over the stored elements in storage order."""
        return iter(self._data)


class DeepCFRSolver(policy.Policy):
    def __init__(self,
                 game,
                 env_model,
                 device,
                 policy_network_layers=(256, 256),
                 advantage_network_layers=(128, 128),
                 num_iterations: int = 100,
                 num_traversals: int = 20,
                 learning_rate: float = 1e-4,
                 batch_size_advantage=None,
                 batch_size_strategy=None,
                 memory_capacity: int = int(1e6),
                 policy_network_train_steps: int = 1,
                 advantage_network_train_steps: int = 1,
                 reinitialize_advantage_networks: bool = True):
        """Builds the networks, optimizers and memories used by the solver.

        Args:
          game: game object; it is queried for the number of players, the
            number of distinct actions, its type and its initial state.
          env_model: stored as `self.env_model`; its `step` method is called
            by `mb_traverse_game_tree`.
          device: torch device on which the networks and batches are placed.
          policy_network_layers: hidden layer sizes of the policy network.
          advantage_network_layers: hidden layer sizes of each advantage
            network.
          num_iterations: number of outer iterations run by `solve`.
          num_traversals: number of traversals per player and iteration in
            `solve`.
          learning_rate: learning rate of every Adam optimizer.
          batch_size_advantage: batch size for advantage training; when falsy
            the whole advantage memory is used.
          batch_size_strategy: batch size for policy training; when falsy the
            whole strategy memory is used.
          memory_capacity: capacity of each reservoir buffer.
          policy_network_train_steps: gradient steps per policy training call.
          advantage_network_train_steps: gradient steps per advantage training
            call.
          reinitialize_advantage_networks: whether `solve` re-initialises a
            player's advantage network before training it.

        Raises:
          ValueError: if the game has simultaneous dynamics.
        """
        player_ids = list(range(game.num_players()))
        super(DeepCFRSolver, self).__init__(game, player_ids)
        self._game = game
        self.env_model = env_model
        self.device = device

        dynamics = game.get_type().dynamics
        if dynamics == pyspiel.GameType.Dynamics.SIMULTANEOUS:
            # `_traverse_game_tree` does not take into account this option.
            raise ValueError("Simulatenous games are not supported.")

        # Training hyper-parameters.
        self._batch_size_advantage = batch_size_advantage
        self._batch_size_strategy = batch_size_strategy
        self._policy_network_train_steps = policy_network_train_steps
        self._advantage_network_train_steps = advantage_network_train_steps
        self._num_iterations = num_iterations
        self._num_traversals = num_traversals
        self._reinitialize_advantage_networks = reinitialize_advantage_networks
        self._learning_rate = learning_rate

        # Game-derived sizes.
        self._num_players = game.num_players()
        self._root_node = self._game.new_initial_state()
        self._embedding_size = len(self._root_node.information_state_tensor(0))
        self._num_actions = game.num_distinct_actions()
        self._iteration = 1

        # Strategy (policy) network with its memory, loss and optimizer.
        self._strategy_memories = ReservoirBuffer(memory_capacity)
        policy_net = MLP(self._embedding_size,
                         list(policy_network_layers),
                         self._num_actions)
        self._policy_network = policy_net.to(self.device)
        # Illegal actions are handled in the traversal code where expected
        # payoff and sampled regret are computed from the advantage networks.
        self._policy_sm = nn.Softmax(dim=-1).to(self.device)
        self._loss_policy = nn.MSELoss()
        self._optimizer_policy = torch.optim.Adam(
            self._policy_network.parameters(), lr=learning_rate)

        # One advantage memory, network and optimizer per player.
        self._advantage_memories = []
        for _ in range(self._num_players):
            self._advantage_memories.append(ReservoirBuffer(memory_capacity))
        self._advantage_networks = []
        for _ in range(self._num_players):
            advantage_net = MLP(self._embedding_size,
                                list(advantage_network_layers),
                                self._num_actions)
            self._advantage_networks.append(advantage_net.to(self.device))
        self._loss_advantages = nn.MSELoss(reduction="mean")
        self._optimizer_advantages = [
            torch.optim.Adam(network.parameters(), lr=learning_rate)
            for network in self._advantage_networks
        ]

    @property
    def advantage_buffers(self):
        """The list of per-player advantage memories (one buffer per player)."""
        return self._advantage_memories

    @property
    def strategy_buffer(self):
        """The single strategy memory shared by all players."""
        return self._strategy_memories

    def clear_advantage_buffers(self):
        """Empties the advantage memory of every player."""
        for player_memory in self._advantage_memories[:self._num_players]:
            player_memory.clear()

    def reinitialize_advantage_network(self, player):
        """Resets one player's advantage network and rebuilds its optimizer.

        Args:
          player: index of the player whose network is re-initialised.
        """
        network = self._advantage_networks[player]
        network.reset()
        # The optimizer is rebuilt after the reset so that it is constructed
        # from the network's current parameters.
        fresh_optimizer = torch.optim.Adam(
            network.parameters(), lr=self._learning_rate)
        self._optimizer_advantages[player] = fresh_optimizer

    def reinitialize_advantage_networks(self):
        """Re-initialises the advantage network of every player in turn."""
        player = 0
        while player < self._num_players:
            self.reinitialize_advantage_network(player)
            player += 1

    def solve(self):
        """Runs the training loop and returns the trained policy network.

        For every iteration and every player: run `num_traversals` traversals,
        optionally re-initialise that player's advantage network, then train
        it and record the loss. The policy network is trained once, after all
        iterations.

        NOTE(review): `_traverse_game_tree` is not defined in the visible part
        of this class (only `mb_traverse_game_tree` is) -- confirm it exists
        elsewhere before relying on this method.

        Returns:
          A tuple `(policy_network, advantage_losses, policy_loss)` where
          `advantage_losses` maps each player to the list of losses returned
          by `_learn_advantage_network`, and `policy_loss` is the value
          returned by `_learn_strategy_network`.
        """
        advantage_losses = collections.defaultdict(list)
        for _iteration in range(self._num_iterations):
            for player in range(self._num_players):
                for _traversal in range(self._num_traversals):
                    self._traverse_game_tree(self._root_node, player)
                if self._reinitialize_advantage_networks:
                    # Train this player's advantage network from scratch.
                    self.reinitialize_advantage_network(player)
                player_loss = self._learn_advantage_network(player)
                advantage_losses[player].append(player_loss)
            self._iteration += 1
        # Train the policy network once, on everything gathered so far.
        policy_loss = self._learn_strategy_network()
        return self._policy_network, advantage_losses, policy_loss

    def mb_traverse_game_tree(self, time_step, player, game_length, environment_steps=0):
        """Recursively traverses the game through `env_model`, storing samples.

        At the traversing player's nodes every legal action is expanded and
        the resulting sampled regrets are added to that player's advantage
        memory. At any other player's node a single action is sampled from
        that player's regret-matching strategy, which is added to the strategy
        memory.

        Args:
          time_step: current time step; must provide `last()`, `rewards` and
            an `observations` mapping with the keys "current_player",
            "legal_actions" and "info_state".
          player: index of the traversing player.
          game_length: traversal stops once `environment_steps` exceeds this.
          environment_steps: number of model steps taken so far on this path.

        Returns:
          The reward of `player` at a terminal/cut-off step, otherwise the
          value of the node for `player` under the current strategies.
        """
        if time_step.last() or environment_steps > game_length:
            # Terminal (or cut-off) step: return the traversing player's reward.
            return time_step.rewards[player]

        current_player = time_step.observations["current_player"]
        if current_player == player:
            _, strategy = self.mb_sample_action_from_advantage(time_step, player)
            legal_actions = time_step.observations["legal_actions"][player]

            # Keep one snapshot that is never handed to `env_model.step`; each
            # action gets its own copy of it so that a model that mutates its
            # input cannot leak state between sibling branches.
            pristine_step = copy.deepcopy(time_step)
            expected_payoff = {}
            for action in legal_actions:
                branch_step = copy.deepcopy(pristine_step)
                next_time_step = self.env_model.step(branch_step, [action])
                expected_payoff[action] = self.mb_traverse_game_tree(
                    next_time_step, player, game_length, environment_steps + 1)

            # Node value under the current strategy.
            cfv = 0
            for a_ in legal_actions:
                cfv += strategy[a_] * expected_payoff[a_]

            # Regret of each legal action relative to the node value; illegal
            # actions keep a regret of zero.
            sampled_regret_arr = [0] * self._num_actions
            for action in legal_actions:
                sampled_regret_arr[action] = expected_payoff[action] - cfv

            # Read the info state from the untouched snapshot, not from a copy
            # that has already been passed through `env_model.step`.
            information_state_tensor = pristine_step.observations["info_state"][player]
            self._advantage_memories[player].add(
                AdvantageMemory(information_state_tensor, self._iteration,
                                sampled_regret_arr))
            return cfv

        other_player = current_player
        _, strategy = self.mb_sample_action_from_advantage(time_step, other_player)
        # Recompute the distribution to absorb numerical errors.
        probs = np.array(strategy)
        probs /= probs.sum()
        sampled_action = np.random.choice(range(self._num_actions), p=probs)

        information_state_tensor = time_step.observations["info_state"][other_player]
        self._strategy_memories.add(
            StrategyMemory(information_state_tensor, self._iteration, strategy))

        next_time_step = self.env_model.step(time_step, [sampled_action])
        return self.mb_traverse_game_tree(next_time_step, player, game_length,
                                          environment_steps + 1)

    def mb_sample_action_from_advantage(self, time_step, player):
        """Computes a regret-matching strategy from a player's advantage network.

        Args:
          time_step: time step whose `observations` mapping holds the
            "info_state" and "legal_actions" entries for `player`.
          player: index of the player whose advantage network is evaluated.

        Returns:
          A tuple `(advantages, matched_regrets)`: the network outputs clipped
          at zero (a list over all actions) and an array of action
          probabilities. Probabilities are proportional to the clipped
          advantages of the legal actions, or put all mass on the legal action
          with the highest raw advantage when none is positive.
        """
        observations = time_step.observations
        info_state = observations["info_state"][player]
        legal_actions = observations["legal_actions"][player]

        with torch.no_grad():
            batch = torch.FloatTensor(np.expand_dims(info_state, axis=0)).to(self.device)
            raw_advantages = self._advantage_networks[player](batch)[0].cpu().numpy()

        advantages = [max(0., value) for value in raw_advantages]
        cumulative_regret = np.sum([advantages[a] for a in legal_actions])
        matched_regrets = np.zeros(self._num_actions)

        if not cumulative_regret > 0.:
            # No positive regret: play the best legal action deterministically.
            best_action = max(legal_actions, key=lambda a: raw_advantages[a])
            matched_regrets[best_action] = 1
            return advantages, matched_regrets

        for a in legal_actions:
            matched_regrets[a] = advantages[a] / cumulative_regret
        return advantages, matched_regrets

    def action_probabilities(self, state, player_id=None):
        """Returns the policy network's probabilities for the legal actions.

        Args:
          state: game state; queried for the current player, that player's
            legal actions and the information state tensor.
          player_id: unused by this implementation.

        Returns:
          A dict mapping each legal action to the softmax output of the policy
          network for that action. The softmax is taken over all actions, so
          the returned values are not renormalised over the legal ones.
        """
        acting_player = state.current_player()
        legal_actions = state.legal_actions(acting_player)

        info_state_vector = np.array(state.information_state_tensor())
        if info_state_vector.ndim == 1:
            # Add the batch dimension expected by the network.
            info_state_vector = np.expand_dims(info_state_vector, axis=0)

        with torch.no_grad():
            network_input = torch.FloatTensor(info_state_vector).to(self.device)
            logits = self._policy_network(network_input)
            probs = self._policy_sm(logits).cpu().numpy()

        first_row = probs[0]
        action_probs = {}
        for action in legal_actions:
            action_probs[action] = first_row[action]
        return action_probs

    def _learn_advantage_network(self, player):
        """Trains one player's advantage network on its advantage memory.

        Runs `advantage_network_train_steps` gradient steps. Each step uses a
        random batch of `batch_size_advantage` samples, or the whole memory
        when no batch size is configured. The squared error is weighted by
        the iteration at which each sample was recorded (both predictions and
        targets are scaled by the square root of the iteration).

        Args:
          player: index of the player whose network is trained.

        Returns:
          The loss of the last gradient step as a numpy value, or None when
          no step was taken (not enough samples, empty memory, or a train
          step count of zero).
        """
        # Stays None when the loop body never runs, so that a train step count
        # of zero returns None instead of raising on an unbound local.
        loss_advantages = None
        for _ in range(self._advantage_network_train_steps):

            if self._batch_size_advantage:
                if self._batch_size_advantage > len(self._advantage_memories[player]):
                    ## Skip if there aren't enough samples
                    return None
                samples = self._advantage_memories[player].sample(
                    self._batch_size_advantage)
            else:
                samples = self._advantage_memories[player]
            info_states = []
            advantages = []
            iterations = []
            for s in samples:
                info_states.append(s.info_state)
                advantages.append(s.advantage)
                iterations.append([s.iteration])
            # Ensure some samples have been gathered.
            if not info_states:
                return None
            self._optimizer_advantages[player].zero_grad()
            advantages = torch.FloatTensor(np.array(advantages)).to(self.device)
            # One weight per sample, kept as a column so that it broadcasts
            # across the action dimension.
            iters = torch.FloatTensor(np.sqrt(np.array(iterations))).to(self.device)
            outputs = self._advantage_networks[player](
                torch.FloatTensor(np.array(info_states)).to(self.device))
            loss_advantages = self._loss_advantages(iters * outputs,
                                                    iters * advantages)
            loss_advantages.backward()
            self._optimizer_advantages[player].step()

        if loss_advantages is None:
            return None
        return loss_advantages.cpu().detach().numpy()

    def _learn_strategy_network(self):
        """Trains the policy network on the strategy memory.

        Runs `policy_network_train_steps` gradient steps. Each step uses a
        random batch of `batch_size_strategy` samples, or the whole memory
        when no batch size is configured. The squared error between the
        softmax of the network output and the stored strategy is weighted by
        the iteration at which each sample was recorded (both sides are scaled
        by the square root of the iteration).

        Returns:
          The loss of the last gradient step as a numpy value, or None when
          no step was taken (not enough samples, empty memory, or a train
          step count of zero).
        """
        # Stays None when the loop body never runs, so that a train step count
        # of zero returns None instead of raising on an unbound local.
        loss_strategy = None
        for _ in range(self._policy_network_train_steps):
            if self._batch_size_strategy:
                if self._batch_size_strategy > len(self._strategy_memories):
                    ## Skip if there aren't enough samples
                    return None
                samples = self._strategy_memories.sample(self._batch_size_strategy)
            else:
                samples = self._strategy_memories
            info_states = []
            action_probs = []
            iterations = []
            for s in samples:
                info_states.append(s.info_state)
                action_probs.append(s.strategy_action_probs)
                iterations.append([s.iteration])
            # Ensure some samples have been gathered (same guard as the
            # advantage learner); an empty batch cannot be fed to the network.
            if not info_states:
                return None

            self._optimizer_policy.zero_grad()
            # One weight per sample, kept as a column so that it broadcasts
            # across the action dimension.
            iters = torch.FloatTensor(np.sqrt(np.array(iterations))).to(self.device)
            # Force a (batch, actions) layout. Unlike `np.squeeze`, this keeps
            # the action axis when there is a single distinct action, while
            # still flattening any extra singleton axis in the stored
            # strategies.
            target_probs = np.array(action_probs).reshape(len(info_states), -1)
            ac_probs = torch.FloatTensor(target_probs).to(self.device)
            logits = self._policy_network(torch.FloatTensor(np.array(info_states)).to(self.device))
            outputs = self._policy_sm(logits)
            loss_strategy = self._loss_policy(iters * outputs, iters * ac_probs)
            loss_strategy.backward()
            self._optimizer_policy.step()

        if loss_strategy is None:
            return None
        return loss_strategy.cpu().detach().numpy()

    def get_policy_network(self):
        """Returns the policy network module held by this solver."""
        return self._policy_network
