import torch
from easydict import EasyDict
from lzero.policy import inverse_scalar_transform, select_action
import numpy as np
import random

from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree
from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
import time


class MuZeroModelFake(torch.nn.Module):
    """
    Overview:
        Fake model that returns all-zero outputs of fixed shapes. It stands in for a real network so that
        the tree-search code in this file can be exercised and timed without any learned weights.
    Interfaces:
        __init__, initial_inference, recurrent_inference
    """

    def __init__(self, action_num):
        """
        Arguments:
            - action_num: size of the action space; used as the width of ``policy_logits``.
        """
        super().__init__()
        self.action_num = action_num

    def _zero_output(self, batch_size, value_prefix):
        """
        Build the output dictionary shared by both inference methods.

        Arguments:
            - batch_size: leading dimension of every returned tensor.
            - value_prefix: object stored unchanged under the ``'value_prefix'`` key.
        Returns:
            - EasyDict with keys ``value``, ``searched_value``, ``value_prefix``, ``policy_logits``,
              ``latent_state`` and ``reward_hidden_state``.
        """
        # NOTE(review): 601 is presumably 2 * support_scale + 1 for support_scale=300 — confirm.
        value = torch.zeros(size=(batch_size, 601))
        policy_logits = torch.zeros(size=(batch_size, self.action_num))
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))

        output = {
            # The benchmark functions in this file read the value under 'value'.
            'value': value,
            # Historical key, kept as an alias of the same tensor for backward compatibility.
            'searched_value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state_state
        }

        return EasyDict(output)

    def initial_inference(self, observation):
        """
        Overview:
            Return all-zero root outputs for a batch of observations.
        Arguments:
            - observation: tensor whose first dimension is the batch size; its content is not used.
        Returns:
            - EasyDict of zero outputs; ``value_prefix`` is a Python list of ``batch_size`` floats.
        """
        batch_size = observation.shape[0]
        return self._zero_output(batch_size, value_prefix=[0. for _ in range(batch_size)])

    def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
        """
        Overview:
            Return all-zero outputs for one recurrent step.
        Arguments:
            - hidden_states: tensor/array whose first dimension is the batch size.
            - reward_hidden_states: ignored by this fake model.
            - actions: ignored by this fake model.
        Returns:
            - EasyDict of zero outputs; ``value_prefix`` is a ``(batch_size, 601)`` zero tensor.
        """
        batch_size = hidden_states.shape[0]
        return self._zero_output(batch_size, value_prefix=torch.zeros(size=(batch_size, 601)))


def ptree_func(policy_config, num_simulations):
    """
    Overview:
        Run one MCTS search on the Python tree implementation (``MCTSPtree``) for every simulation budget
        in ``num_simulations`` and record the wall-clock time spent in each stage.
    Arguments:
        - policy_config: config object. This function reads ``batch_size``, ``action_space_size``,
          ``model.support_scale``, ``root_dirichlet_alpha`` and ``root_noise_weight`` from it, overwrites
          ``num_simulations`` on every iteration, and passes the whole object to ``MCTSPtree``.
        - num_simulations: iterable of simulation budgets; one benchmark run is made per entry.
    Returns:
        - build_time: list of seconds spent in ``MCTSPtree.roots``, one entry per budget.
        - prepare_time: list of seconds spent in ``roots.prepare``.
        - search_time: list of seconds spent in ``MCTSPtree(policy_config).search``.
        - total_time: list of seconds from the start of the iteration to the end of the search.
    """
    batch_size = env_nums = policy_config.batch_size
    action_space_size = policy_config.action_space_size

    build_time = []
    prepare_time = []
    search_time = []
    total_time = []

    for n_s in num_simulations:
        # perf_counter is the clock intended for measuring short durations.
        t0 = time.perf_counter()
        model = MuZeroModelFake(action_num=action_space_size)
        stack_obs = torch.zeros(size=(batch_size, n_s), dtype=torch.float)

        policy_config.num_simulations = n_s
        network_output = model.initial_inference(stack_obs)

        latent_state_roots = network_output['latent_state']
        reward_hidden_state_state = network_output['reward_hidden_state']
        # Accept the legacy 'searched_value' key when the model does not provide 'value'.
        value_key = 'value' if 'value' in network_output else 'searched_value'
        pred_values_pool = network_output[value_key]
        value_prefix_pool = network_output['value_prefix']
        policy_logits_pool = network_output['policy_logits']

        # network output process
        # The transformed value is not consumed below; the call is kept so that total_time still
        # covers the same post-processing work.
        pred_values_pool = inverse_scalar_transform(pred_values_pool,
                                                    policy_config.model.support_scale).detach().cpu().numpy()
        latent_state_roots = latent_state_roots.detach().cpu().numpy()
        reward_hidden_state_state = (
            reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy()
        )
        policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()

        action_mask = []
        for _ in range(env_nums):
            mask = [random.randint(0, 1) for _ in range(action_space_size)]
            if not any(mask):
                # An all-zero mask would give this env an empty legal-action list, from which no
                # action can be selected below; force one legal action instead.
                mask[random.randrange(action_space_size)] = 1
            action_mask.append(mask)
        assert len(action_mask) == batch_size
        assert len(action_mask[0]) == action_space_size

        action_num = [sum(mask) for mask in action_mask]
        legal_actions_list = [[a for a, flag in enumerate(mask) if flag == 1] for mask in action_mask]
        to_play = [np.random.randint(1, 3) for _ in range(env_nums)]
        assert len(to_play) == batch_size
        # ============================================ptree=====================================#
        for i in range(env_nums):
            assert action_num[i] == len(legal_actions_list[i])

        t1 = time.perf_counter()
        roots = MCTSPtree.roots(env_nums, legal_actions_list)
        build_time.append(time.perf_counter() - t1)

        # One Dirichlet noise vector per env, with one entry per legal action.
        noises = [
            np.random.dirichlet([policy_config.root_dirichlet_alpha] * action_num[j]).astype(np.float32).tolist()
            for j in range(env_nums)
        ]

        t1 = time.perf_counter()
        roots.prepare(policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
        prepare_time.append(time.perf_counter() - t1)

        t1 = time.perf_counter()
        MCTSPtree(policy_config).search(roots, model, latent_state_roots, reward_hidden_state_state, to_play)
        search_time.append(time.perf_counter() - t1)
        total_time.append(time.perf_counter() - t0)

        roots_distributions = roots.get_distributions()
        roots_values = roots.get_values()
        assert len(roots_distributions) == env_nums
        assert len(roots_values) == env_nums
        for i in range(env_nums):
            assert len(roots_distributions[i]) == action_num[i]

        # Sanity check: the sampled index must map back to the same legal action through the mask.
        for i in range(env_nums):
            action_index, _ = select_action(roots_distributions[i], temperature=1, deterministic=False)
            action = np.where(np.array(action_mask[i]) == 1.0)[0][action_index]
            assert action_index < action_num[i]
            assert action == legal_actions_list[i][action_index]
            print('\n action_index={}, legal_action={}, action={}'.format(action_index, legal_actions_list[i], action))
    return build_time, prepare_time, search_time, total_time


def ctree_func(policy_config, num_simulations):
    """
    Overview:
        Run one MCTS search on the C++ tree implementation (``MCTSCtree``) for every simulation budget
        in ``num_simulations`` and record the wall-clock time spent in each stage.
    Arguments:
        - policy_config: config object. This function reads ``batch_size``, ``action_space_size``,
          ``model.support_scale``, ``root_dirichlet_alpha`` and ``root_noise_weight`` from it, overwrites
          ``num_simulations`` on every iteration, and passes the whole object to ``MCTSCtree``.
        - num_simulations: iterable of simulation budgets; one benchmark run is made per entry.
    Returns:
        - build_time: list of seconds spent in ``MCTSCtree.roots``, one entry per budget.
        - prepare_time: list of seconds spent in ``roots.prepare``.
        - search_time: list of seconds spent in ``MCTSCtree(policy_config).search``.
        - total_time: list of seconds from the start of the iteration to the end of the search.
    """
    batch_size = env_nums = policy_config.batch_size
    action_space_size = policy_config.action_space_size

    build_time = []
    prepare_time = []
    search_time = []
    total_time = []

    for n_s in num_simulations:
        # perf_counter is the clock intended for measuring short durations.
        t0 = time.perf_counter()
        model = MuZeroModelFake(action_num=action_space_size)
        stack_obs = torch.zeros(size=(batch_size, n_s), dtype=torch.float)
        policy_config.num_simulations = n_s

        network_output = model.initial_inference(stack_obs)

        latent_state_roots = network_output['latent_state']
        reward_hidden_state_state = network_output['reward_hidden_state']
        # Accept the legacy 'searched_value' key when the model does not provide 'value'.
        value_key = 'value' if 'value' in network_output else 'searched_value'
        pred_values_pool = network_output[value_key]
        value_prefix_pool = network_output['value_prefix']
        policy_logits_pool = network_output['policy_logits']

        # network output process
        # The transformed value is not consumed below; the call is kept so that total_time still
        # covers the same post-processing work.
        pred_values_pool = inverse_scalar_transform(pred_values_pool,
                                                    policy_config.model.support_scale).detach().cpu().numpy()
        latent_state_roots = latent_state_roots.detach().cpu().numpy()
        reward_hidden_state_state = (
            reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy()
        )
        policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()

        action_mask = []
        for _ in range(env_nums):
            mask = [random.randint(0, 1) for _ in range(action_space_size)]
            if not any(mask):
                # An all-zero mask would give this env an empty legal-action list, from which no
                # action can be selected below; force one legal action instead.
                mask[random.randrange(action_space_size)] = 1
            action_mask.append(mask)
        assert len(action_mask) == batch_size
        assert len(action_mask[0]) == action_space_size

        action_num = [sum(mask) for mask in action_mask]
        legal_actions_list = [[a for a, flag in enumerate(mask) if flag == 1] for mask in action_mask]
        to_play = [np.random.randint(1, 3) for _ in range(env_nums)]
        assert len(to_play) == batch_size
        # ============================================ctree=====================================#
        for i in range(env_nums):
            assert action_num[i] == len(legal_actions_list[i])

        t1 = time.perf_counter()
        roots = MCTSCtree.roots(env_nums, legal_actions_list)
        build_time.append(time.perf_counter() - t1)

        # One Dirichlet noise vector per env, with one entry per legal action.
        noises = [
            np.random.dirichlet([policy_config.root_dirichlet_alpha] * action_num[j]).astype(np.float32).tolist()
            for j in range(env_nums)
        ]

        t1 = time.perf_counter()
        roots.prepare(policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
        prepare_time.append(time.perf_counter() - t1)

        t1 = time.perf_counter()
        MCTSCtree(policy_config).search(roots, model, latent_state_roots, reward_hidden_state_state, to_play)
        search_time.append(time.perf_counter() - t1)
        total_time.append(time.perf_counter() - t0)

        roots_distributions = roots.get_distributions()
        roots_values = roots.get_values()
        assert len(roots_distributions) == env_nums
        assert len(roots_values) == env_nums
        for i in range(env_nums):
            assert len(roots_distributions[i]) == action_num[i]

        # Sanity check: the sampled index must map back to the same legal action through the mask.
        for i in range(env_nums):
            action_index, _ = select_action(roots_distributions[i], temperature=1, deterministic=False)
            action = np.where(np.array(action_mask[i]) == 1.0)[0][action_index]
            assert action_index < action_num[i]
            assert action == legal_actions_list[i][action_index]
            print('\n action_index={}, legal_action={}, action={}'.format(action_index, legal_actions_list[i], action))
    return build_time, prepare_time, search_time, total_time


def plot(ctree_time, ptree_time, iters, label):
    """
    Overview:
        Draw the mean timing curve with a +/- one standard deviation band for the ctree and the ptree
        measurements, and save the figure to ``'{label}-time.png'`` in the current working directory.
    Arguments:
        - ctree_time: array-like of ctree timings; mean and std are taken over axis 0.
        - ptree_time: array-like of ptree timings; mean and std are taken over axis 0.
        - iters: x-axis values (the simulation budgets), one per plotted point.
        - label: used as the figure title and as the prefix of the output file name.
    """
    import matplotlib.pyplot as plt

    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and later releases dropped
    # the old names, so try the new name first and fall back for older installations.
    try:
        plt.style.use('seaborn-v0_8-whitegrid')
    except OSError:
        plt.style.use('seaborn-whitegrid')
    palette = plt.get_cmap('Set1')
    font1 = {
        'family': 'Times New Roman',
        'weight': 'normal',
        'size': 18,
    }

    fig = plt.figure(figsize=(20, 10))
    try:
        for color_index, (name, timings) in enumerate((("ctree", ctree_time), ("ptree", ptree_time))):
            timings = np.asarray(timings)
            color = palette(color_index)
            avg = np.mean(timings, axis=0)
            std = np.std(timings, axis=0)
            plt.plot(iters, avg, color=color, label=name, linewidth=3.0)
            plt.fill_between(iters, avg - std, avg + std, color=color, alpha=0.2)

        plt.legend(loc='lower right', prop=font1)
        plt.title('{}'.format(label))
        plt.xlabel('simulations', fontsize=22)
        plt.ylabel('time', fontsize=22)
        plt.savefig('{}-time.png'.format(label))
    finally:
        # This function is called repeatedly by the driver; close the figure so they do not pile up.
        plt.close(fig)


if __name__ == "__main__":
    # Benchmark driver: for every (action space size, batch size) pair, time the ctree and the ptree
    # implementation several times over all simulation budgets and save one plot per stage.
    #
    # To profile one implementation, wrap a single call such as
    # ``ctree_func(policy_config, num_simulations=NUM_SIMULATIONS)`` with ``cProfile``.

    policy_config = EasyDict(
        dict(
            lstm_horizon_len=5,
            model=dict(
                support_scale=300,
                categorical_distribution=True,
            ),
            action_space_size=100,
            num_simulations=100,
            batch_size=512,
            pb_c_base=1,
            pb_c_init=1,
            discount_factor=0.9,
            root_dirichlet_alpha=0.3,
            root_noise_weight=0.2,
            dirichlet_alpha=0.3,
            exploration_fraction=1,
            device='cpu',
            value_delta_max=0.01,
        )
    )

    ACTION_SPACE_SIZE = [16, 50]
    BATCH_SIZE = [8, 64, 512]
    NUM_SIMULATIONS = list(range(20, 200, 20))
    # Number of repeated measurements per implementation; plot() averages over them.
    NUM_REPEATS = 3

    # Smaller sweep for a quick run:
    # ACTION_SPACE_SIZE = [50]
    # BATCH_SIZE = [512]
    # NUM_SIMULATIONS = list(range(10, 50, 10))

    # Stage names, in the order in which ctree_func / ptree_func return their timing lists.
    STAGES = ('build', 'prepare', 'search', 'total')

    for action_space_size in ACTION_SPACE_SIZE:
        for batch_size in BATCH_SIZE:
            policy_config.batch_size = batch_size
            policy_config.action_space_size = action_space_size

            # timings[tree_name][stage] is a list with one timing list per repetition.
            timings = {}
            for tree_name, tree_func in (('ctree', ctree_func), ('ptree', ptree_func)):
                per_stage = {stage: [] for stage in STAGES}
                for _ in range(NUM_REPEATS):
                    stage_results = tree_func(policy_config, num_simulations=NUM_SIMULATIONS)
                    for stage, stage_times in zip(STAGES, stage_results):
                        per_stage[stage].append(stage_times)
                timings[tree_name] = per_stage

            label = 'action_space_size_{}_batch_size_{}'.format(action_space_size, batch_size)
            for stage in STAGES:
                plot(
                    timings['ctree'][stage],
                    timings['ptree'][stage],
                    iters=NUM_SIMULATIONS,
                    label='{}_{}_time'.format(label, stage)
                )
