import torch
import open_spiel.python.rl_agent as rl_agent
from algorithms.escher_orig.escher_orig import ESCHERSolver
from torch.distributions import Categorical
import numpy as np
import torch.nn as nn
import wandb


class EscherRLAgent(rl_agent.AbstractAgent):
    """Agent that samples actions from a torch policy network's masked logits.

    The network is moved to the CPU on construction and is only used for
    forward passes by this class.
    """

    def __init__(self, policy_network, player_id, n_obs, n_actions):
        """Store the policy network and the player/game dimensions.

        Args:
            policy_network: Torch module called on an info-state tensor; its
                output is used as one logit per action. Moved to the CPU here.
            player_id: Index selecting this player's entries in
                ``time_step.observations``.
            n_obs: Observation size; stored but not otherwise used by this
                class.
            n_actions: Number of distinct actions; sizes the legal-action
                mask and the returned probability vector.
        """
        self.model = policy_network
        self.model.to("cpu")
        self.player_id = player_id
        self.n_obs = n_obs
        self.n_actions = n_actions

    def step(self, time_step, is_evaluation=False):
        """Sample one action for this player from the masked policy logits.

        Args:
            time_step: Object whose ``observations`` mapping holds per-player
                ``"info_state"`` and ``"legal_actions"`` entries.
            is_evaluation: Accepted for interface compatibility; unused here.

        Returns:
            ``rl_agent.StepOutput`` holding the sampled action as a NumPy
            array (0-d when the network returns a 1-D logit vector) and a
            one-hot ``probs`` vector marking that action.
        """
        observations = time_step.observations
        info_state = observations["info_state"][self.player_id]
        # Force an integer dtype: an empty legal-action list would otherwise
        # become a float64 array, which cannot be used to index a tensor.
        legal_actions = np.asarray(
            observations["legal_actions"][self.player_id], dtype=np.int64
        )
        legal_mask = torch.zeros((self.n_actions,), dtype=torch.bool)
        legal_mask[legal_actions] = True

        # Inference only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            logits = self.model(
                torch.tensor(np.array(info_state), dtype=torch.float)
            )
            # Give illegal actions a very low logit so that they are
            # effectively never sampled.
            logits = logits.masked_fill(~legal_mask, -1e6)
            action = Categorical(logits=logits).sample().numpy()

        # The reported distribution is one-hot on the sampled action, not the
        # policy's softmax probabilities.
        probs = np.zeros(self.n_actions)
        probs[action] = 1.0
        return rl_agent.StepOutput(action=action, probs=probs)

    def get_model(self):
        """Return a callable that runs the policy network on an info state.

        The callable converts its argument to a float32 tensor and returns the
        network's raw output; no legal-action masking or softmax is applied.
        """

        def model(info_state):
            return self.model(torch.tensor(info_state, dtype=torch.float32))

        return model


class MLP(nn.Module):
    """Feed-forward network: a Linear+ReLU pair per hidden size, then a Linear head."""

    def __init__(self, input_size, hidden_sizes, output_size):
        """Build the layer stack.

        Args:
            input_size: Width of the input features.
            hidden_sizes: Iterable of hidden-layer widths, in order.
            output_size: Width of the final linear layer's output.
        """
        super().__init__()
        # widths[k] feeds widths[k + 1]; the last entry feeds the output head.
        widths = [input_size, *hidden_sizes]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        modules.append(nn.Linear(widths[-1], output_size))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the sequential layer stack to ``x`` and return the result."""
        return self.layers(x)

    def load_params_from_dict(self, params):
        """Overwrite every linear layer's weight and bias from ``params``.

        Args:
            params: Sliceable flat sequence alternating weight, bias, weight,
                bias, ... in layer order. Each entry must expose ``.numpy()``.
                Despite the method name, this is indexed as a sequence, not a
                mapping.
        """
        kernels, offsets = params[::2], params[1::2]
        linear_layers = (
            module for module in self.layers if isinstance(module, nn.Linear)
        )
        for idx, linear in enumerate(linear_layers):
            # Incoming weights are transposed before assignment (presumably
            # stored as (in_features, out_features) by the source framework).
            linear.weight.data = torch.tensor(kernels[idx].numpy()).T
            linear.bias.data = torch.tensor(offsets[idx].numpy())


def convert_tf_list_to_torch(params):
    """Build a torch ``MLP`` sized from, and loaded with, a flat parameter list.

    Args:
        params: Sliceable sequence alternating weight, bias, weight, bias, ...
            whose entries expose ``.numpy()``. Each weight's shape is read as
            ``(input_width, output_width)``.

    Returns:
        An ``MLP`` whose linear layers have been filled from ``params``.
    """
    all_shapes = [tensor.numpy().shape for tensor in params]
    weight_shapes = all_shapes[::2]  # even positions hold the weight matrices
    input_width = weight_shapes[0][0]
    # Every weight except the last one defines a hidden layer's width.
    hidden_widths = [shape[1] for shape in weight_shapes[:-1]]
    output_width = weight_shapes[-1][1]

    network = MLP(input_width, hidden_widths, output_width)
    network.load_params_from_dict(params)
    return network


class RunEscherOrig:
    """Driver that trains an ``ESCHERSolver`` and periodically reports policies.

    After enough nodes have been visited, the average policy network is
    re-trained, wrapped as torch agents and handed to ``expl_callback``.
    """

    def __init__(self, args, game, expl_callback):
        """Create the solver from the run configuration.

        Args:
            args: Namespace-like object providing the solver hyperparameters
                read below, plus ``compute_exploitability_every`` and
                ``experiment_dir``, which are used by ``run``.
            game: Game object passed to the solver; ``wrap_rl_agent`` also
                queries it for observation size, action count and player count.
            expl_callback: Called from ``run`` as
                ``expl_callback(model_0, model_1, nodes_visited)``.
        """
        self.args = args
        self.game = game
        self.expl_callback = expl_callback
        self.deep_cfr_solver = ESCHERSolver(
            self.game,
            num_traversals=args.num_traversals,
            num_val_fn_traversals=args.num_val_fn_traversals,
            num_iterations=args.iters,
            regret_network_train_steps=args.regret_train_steps,
            policy_network_train_steps=args.policy_net_train_steps,
            batch_size_regret=args.batch_size_regret,
            value_network_train_steps=args.val_train_steps,
            batch_size_value=args.batch_size_val,
            train_device=args.train_device,
            num_random_games=args.num_random_games,
            max_steps=args.max_steps,
            use_wandb=args.wandb,
            eval_every=args.eval_every,
        )

    def wrap_rl_agent(self, save_path=None):
        """Convert the solver's policy network into one agent per player.

        The policy weights are converted to a single torch model that all
        returned agents share.

        Args:
            save_path: Optional file path; when truthy, the agent list is
                pickled to it.

        Returns:
            List of ``EscherRLAgent``, one for each player index of the game.
        """
        model = convert_tf_list_to_torch(self.deep_cfr_solver._policy_network.weights)
        n_obs = self.game.observation_tensor_size()
        n_actions = self.game.num_distinct_actions()
        agents = [
            EscherRLAgent(model, player_id, n_obs, n_actions)
            for player_id in range(self.game.num_players())
        ]
        if save_path:
            import pickle

            with open(save_path, "wb") as f:
                pickle.dump(agents, f)
        return agents

    def get_model(self):
        """Return two callables mapping info states to action probabilities.

        Both list entries are the same callable, because both players are
        served by the solver's single policy network. The callable converts
        its input to a float32 tensor, feeds it to that network and applies a
        softmax over dimension 1, so the input must be batched (2-D).

        NOTE(review): ``wrap_rl_agent`` converts this network's weights from
        TF-style variables; confirm that the network accepts torch tensors
        when called directly like this.
        """

        def policy_fn(info_state):
            action_values = self.deep_cfr_solver._policy_network(
                torch.tensor(info_state, dtype=torch.float32)
            )
            return torch.softmax(action_values, dim=1)

        return [policy_fn, policy_fn]

    def run(self):
        """Run solver iterations until ``args.max_steps`` nodes were visited.

        Every time more than ``args.compute_exploitability_every`` nodes have
        been visited since the last report, the average policy network is
        re-initialised and re-trained, wrapped as agents, and the agents'
        models are passed to ``expl_callback`` with the visited-node count.

        Returns:
            Whatever ``save_policy_network(args.experiment_dir)`` returns.
        """
        solver = self.deep_cfr_solver
        solver.pre_iteration()
        nodes_since_report = 0
        while True:
            nodes_before = solver._nodes_visited

            iteration_log = solver.iteration()
            # wandb.log raises if wandb.init was never called; wandb usage is
            # optional for this runner, so only log into an active run.
            if wandb.run is not None:
                wandb.log(iteration_log)
            nodes_since_report += solver._nodes_visited - nodes_before

            if nodes_since_report > self.args.compute_exploitability_every:
                solver._reinitialize_policy_network()
                solver._learn_average_policy_network()
                models = [agent.get_model() for agent in self.wrap_rl_agent()]
                self.expl_callback(models[0], models[1], solver._nodes_visited)
                nodes_since_report = 0

            if solver._nodes_visited >= self.args.max_steps:
                break

        return solver.save_policy_network(self.args.experiment_dir)
