import os
import pickle
from typing import NamedTuple

import numpy as np
import torch

from .model import DreamerNetwork, ActorOutput
from .utils import write_gif
from .env import EnvBatcher
from .utils import root_search
from .mcts import Node, MCTS


# Summary of one evaluation run, as returned by `test`.
#   score      -- mean over environments of the summed episode reward
#   avg_repeat -- mean over environments of the mean chosen action-repeat
TestOutput = NamedTuple('TestOutput', [
    ('score', float),
    ('avg_repeat', float),
])


def test(env: EnvBatcher, model: DreamerNetwork, config, render: bool = False, recording_path=None,
         save_video: bool = False, save_test_data: bool = False, save_path=None, mode='no-search',
         mcts_num_simulations=None):
    """Run one evaluation episode in every environment of ``env`` and summarise it.

    The model is moved to ``config.device`` and put in eval mode. At every
    decision step the latent state is updated from the encoded observation, an
    action and an action-repeat count are chosen according to ``mode``, and each
    environment is stepped with its action for up to "repeat" steps (stopping
    early when that environment reports ``done``). The loop ends once every
    environment reported ``done`` in the same sweep.

    Args:
        env: Batched environment; uses ``n``, ``reset``, ``step_single`` and
            ``render``.
        model: Dreamer network providing the encoder, transition, actor,
            actor_repeat and reward heads.
        config: Run configuration. ``device`` is always read; the rollout and
            MCTS modes read further fields.
        render: Render each environment step.
        recording_path: Directory that receives ``ep_<i>.gif`` files. Required
            when ``render`` and ``save_video`` are both set.
        save_video: With ``render``, capture RGB frames and write GIFs instead
            of rendering to screen.
        save_test_data: Pickle ``(episode_rewards, action_repeats)`` to
            ``save_path``.
        save_path: Target file for ``save_test_data``.
        mode: One of ``'no-search'`` (deterministic actor), ``'rollout'``
            (``root_search`` and arg-max over its q-values), ``'mcts'`` or
            ``'mcts+fixed'`` (per-environment MCTS; ``'fixed'`` turns the
            ``progressive`` flag off).
        mcts_num_simulations: Number of MCTS simulations per decision. Falls
            back to ``config.num_simulations`` when ``None``.

    Returns:
        TestOutput: mean over environments of the summed episode reward, and
        mean over environments of the mean action-repeat.

    Raises:
        NotImplementedError: If ``mode`` is not one of the supported modes.
        ValueError: If an output path needed by the selected options is missing.
    """
    # Fail fast: validate before resetting environments or running the model.
    if mode not in ('mcts', 'mcts+fixed', 'rollout', 'no-search'):
        raise NotImplementedError('unsupported test mode: {!r}'.format(mode))
    if render and save_video and recording_path is None:
        raise ValueError('recording_path is required when render and save_video are both set')
    if save_test_data and save_path is None:
        raise ValueError('save_path is required when save_test_data is set')

    # Per-environment episode logs, keyed by environment index.
    episode_rewards = {key_i: [] for key_i in range(env.n)}
    action_repeats = {key_i: [] for key_i in range(env.n)}
    episode_q_values = {key_i: [] for key_i in range(env.n)}
    episode_images = {key_i: [] for key_i in range(env.n)}

    model.to(config.device)
    model.eval()

    # init
    belief = model.init_belief(env.n).to(config.device)
    posterior_state = model.init_state(env.n).to(config.device)
    action = model.init_action(env.n).to(config.device)
    action_repeat_one_hot = model.init_action_repeat_one_hot(env.n).to(config.device)

    obs = env.reset()
    dones = [False] * env.n

    while not all(dones):
        obs_tensor = torch.FloatTensor(obs).to(config.device)
        with torch.no_grad():
            # The transition model is fed inputs with an extra leading dim of
            # size 1 (presumably a time axis -- TODO confirm); it is squeezed
            # away again for the per-step heads below.
            transition_output = model.transition(belief, posterior_state, action.unsqueeze(0),
                                                 action_repeat_one_hot.unsqueeze(0),
                                                 model.encoder(obs_tensor).unsqueeze(0))
            belief = transition_output.belief.squeeze(0)
            posterior_state = transition_output.posterior_state.squeeze(0)

            if mode == 'no-search':
                # Greedy policy: deterministic action, then deterministic repeat.
                actor_output = model.actor(belief, posterior_state)
                action = model.actor.action_sample(actor_output, deterministic=True)

                actor_repeat_output = model.actor_repeat(belief, posterior_state, action)
                action_repeat_one_hot, action_repeat = model.actor_repeat.sample(actor_repeat_output,
                                                                                 deterministic=True)
                action_repeat = action_repeat.squeeze(1).cpu().int().tolist()
            elif mode == 'rollout':
                root_search_output = root_search(env.n, model,
                                                 transition_output.belief,
                                                 transition_output.posterior_state,
                                                 config.proposal_action_sample,
                                                 config.uniform_action_sample,
                                                 config)

                # Pick, per environment, the candidate with the highest q-value.
                batch_idx = torch.arange(env.n)
                best_idx = root_search_output.q_values.argmax(dim=1)
                action = root_search_output.actions[batch_idx, best_idx]
                action_repeat = root_search_output.actions_repeat[batch_idx, best_idx].cpu().int().tolist()
                action_repeat_one_hot = root_search_output.actions_repeat_one_hot[batch_idx, best_idx]

                for key_i in range(env.n):
                    episode_q_values[key_i].append(root_search_output.q_values[key_i].data.cpu().numpy().tolist())
            else:
                # 'mcts' / 'mcts+fixed': sample one child per environment to seed
                # the root, then search each environment independently.
                actor_output = model.actor(belief, posterior_state)
                child_action = model.actor.action_sample(actor_output, deterministic=False)
                child_actor_repeat_output = model.actor_repeat(belief, posterior_state, child_action)
                root_reward = model.reward(belief, posterior_state)
                child_action_repeat_one_hot, child_action_repeat = model.actor_repeat.sample(child_actor_repeat_output,
                                                                                             deterministic=False)
                actor_dist = model.actor.action_dist(actor_output)
                # NOTE(review): the negated `entropy(child_action)` is passed on as
                # the child's log-probability; confirm this project-specific
                # distribution API really returns what is wanted here.
                child_action_log_prob = - actor_dist.entropy(child_action)

                progressive = 'fixed' not in mode
                # Honour the explicit override; fall back to the configured count.
                num_simulations = (config.num_simulations if mcts_num_simulations is None
                                   else mcts_num_simulations)

                chosen_actions, chosen_repeat_one_hots = [], []
                action_repeat = []
                for env_i in range(env.n):
                    root = Node(0, root=True)
                    root.expand(model,
                                ActorOutput(actor_output.mu[env_i, :].unsqueeze(0),
                                            actor_output.std_dev[env_i, :].unsqueeze(0)),
                                child_action_log_prob[env_i].item(),
                                child_action[env_i, :].unsqueeze(0).unsqueeze(0),
                                child_action_repeat[env_i, :].int().item(),
                                child_action_repeat_one_hot[env_i, :].unsqueeze(0).unsqueeze(0),
                                (belief[env_i, :].unsqueeze(0),
                                 posterior_state[env_i, :].unsqueeze(0)),
                                root_reward[env_i, :].item())
                    MCTS(config, progressive=progressive).run(root, model,
                                                              num_simulations=num_simulations)

                    # Best child by value, ties broken by visit count. A key
                    # function avoids ever comparing the action/Node objects.
                    action_cap, _ = max(root.children.items(),
                                        key=lambda item: (item[1].value(), item[1].visit_count))

                    episode_q_values[env_i].append([child.value() for child in root.children.values()])

                    chosen_actions.append(action_cap.action)
                    action_repeat.append(action_cap.repeat)
                    chosen_repeat_one_hots.append(action_cap.repeat_one_hot)

                # Join the per-environment picks along dim 1, then drop the
                # leading dim when it has size 1.
                action = torch.cat(chosen_actions, dim=1).squeeze(0)
                action_repeat_one_hot = torch.cat(chosen_repeat_one_hots, dim=1).squeeze(0)

        # step action
        step_action = action.cpu().numpy()
        for key_i in range(env.n):
            action_repeats[key_i].append(action_repeat[key_i])

        # NOTE(review): `dones` is rebuilt on every sweep, so an environment that
        # finished in an earlier sweep is stepped again until all environments
        # report done together -- confirm `step_single` tolerates this.
        next_obs, dones = [], []
        for env_idx in range(env.n):
            # NOTE(review): assumes every repeat count is >= 1; with 0 the
            # `done` / observation from a previous step would be reused.
            for _ in range(action_repeat[env_idx]):

                if render and save_video:
                    img = env.render(mode='rgb_array', idx=env_idx)
                    episode_images[env_idx].append(img)
                elif render:
                    env.render(mode='human', idx=env_idx)

                # step
                env_obs, env_reward, done, _info = env.step_single(step_action[env_idx], idx=env_idx)
                episode_rewards[env_idx].append(env_reward)

                if done:
                    break
            dones.append(done)
            next_obs.append(env_obs)

        obs = next_obs

    if render and save_video:
        for key_i in range(env.n):
            write_gif(episode_images[key_i], action_repeats[key_i], episode_rewards[key_i], episode_q_values[key_i],
                      os.path.join(recording_path, 'ep_{}.gif'.format(key_i)))
    if save_test_data:
        # Context manager guarantees the file is flushed and closed.
        with open(save_path, 'wb') as data_file:
            pickle.dump((episode_rewards, action_repeats), data_file)

    return TestOutput(np.mean([sum(episode_rewards[k]) for k in range(env.n)]),
                      np.mean([np.mean(action_repeats[k]) for k in range(env.n)]))
