import logging
import os
import shutil
from torch.nn.functional import gumbel_softmax

import torch
from gym.spaces import Space
from typing import Iterable
from torch.nn import Module

from torch.nn import functional as F

from typing import NamedTuple
from torch import Tensor


class ImaginationOutput(NamedTuple):
    """Result of ``imagine_ahead``: per-step tensors of an imagined rollout.

    As built by ``imagine_ahead``, every field is a ``torch.stack`` along dim 0 with
    ``planning_horizon - 1`` entries (the starting belief/state are not included).
    The remaining dimensions depend on the model — TODO confirm exact shapes.
    """
    belief: Tensor  # deterministic hidden state produced by model.transition.rnn at each step
    prior_state: Tensor  # state sampled from the transition prior (mean + std * noise)
    prior_mean: Tensor  # mean of the transition prior
    prior_std_dev: Tensor  # std-dev of the transition prior (softplus output + min_std_dev)
    actions: Tensor  # actions taken at each step (actor sample, or uniform at masked root rows)
    actions_entropy: Tensor  # entropy of the actor's action distribution
    action_repeats: Tensor  # chosen repeat value, looked up in model.actor_repeat.action_repeat_set
    action_repeats_one_hot: Tensor  # hard Gumbel-softmax one-hot over the repeat choices
    action_repeats_entropy: Tensor  # `entropy` reported by the action-repeat head
    action_repeats_log_prob: Tensor  # `log_prob_mean` reported by the action-repeat head


class RootSearchOutput(NamedTuple):
    """Result of ``root_search``: one entry per (environment, candidate root action).

    The shapes noted below are the ones ``root_search`` reshapes its outputs to;
    ``total_actions = proposal_action_sample + uniform_action_sample``.
    """
    actions: Tensor  # (num_env, total_actions, action_dim) candidate root actions
    actions_repeat: Tensor  # (num_env, total_actions) repeat value chosen for each candidate
    actions_repeat_one_hot: Tensor  # (num_env, total_actions, n_repeat_choices) one-hot repeat choice
    q_values: Tensor  # (num_env, total_actions) lambda-return of the first imagined step


def get_epsilon(max_eps: float, min_eps: float, curr_steps: int, max_steps: int, decay_fraction: float = 0.5):
    """Linearly annealed epsilon schedule.

    Epsilon falls linearly from ``max_eps`` at step 0 and reaches ``min_eps`` after
    ``decay_fraction * max_steps`` steps. It is clamped from below at ``min_eps``.

    :param max_eps: epsilon at step 0.
    :param min_eps: lower bound of epsilon.
    :param curr_steps: current step count.
    :param max_steps: total number of steps of the run.
    :param decay_fraction: fraction of ``max_steps`` over which epsilon is annealed
        (default 0.5, the value that used to be hard-coded).
    :returns: epsilon for ``curr_steps``. Returns ``min_eps`` when the decay horizon is not
        positive (e.g. ``max_steps == 0``), which previously raised ZeroDivisionError.
    """
    decay_steps = decay_fraction * max_steps
    if decay_steps <= 0:
        # Nothing to anneal over; also avoids dividing by zero.
        return min_eps
    epsilon = max(min_eps, max_eps - (max_eps - min_eps) * curr_steps / decay_steps)
    return epsilon


def clip_action(action, action_space: Space):
    """Clamp every action component into the bounds of ``action_space``.

    :param action: rank-3 tensor whose last dimension indexes action components;
        component ``i`` is clamped to ``[action_space.low[i], action_space.high[i]]``.
    :param action_space: space providing per-component ``low`` and ``high`` bounds.
    :returns: a new tensor of the same shape with clamped values (input is not modified).
    """
    assert len(action.shape) == 3, 'Expected batch of (batch of actions)'
    num_components = action.shape[2]

    columns = []
    for comp in range(num_components):
        lower = action_space.low[comp]
        upper = action_space.high[comp]
        columns.append(torch.clamp(action[:, :, comp], lower, upper))

    # Re-assemble the clamped components along the last axis.
    clipped = torch.stack(columns, dim=2)

    assert clipped.shape == action.shape
    return clipped


def make_results_dir(exp_path, args):
    """Create the experiment directory and its ``logs`` subdirectory.

    When training (``args.opr == 'train'``) into a directory that already has content,
    the directory is wiped if ``args.force`` is truthy; otherwise FileExistsError is raised.

    :param exp_path: experiment output directory (created if missing).
    :param args: namespace with ``opr`` and, for training runs, ``force``.
    :returns: path of the ``logs`` directory inside ``exp_path``.
    :raises FileExistsError: training into a non-empty directory without ``--force``.
    """
    os.makedirs(exp_path, exist_ok=True)

    # exp_path is guaranteed to exist here, so only its contents matter.
    needs_clean_slate = args.opr == 'train' and bool(os.listdir(exp_path))
    if needs_clean_slate:
        if args.force:
            shutil.rmtree(exp_path)
            os.makedirs(exp_path)
        else:
            raise FileExistsError('{} is not empty. Please use --force to overwrite it'.format(exp_path))

    logs_dir = os.path.join(exp_path, 'logs')
    os.makedirs(logs_dir, exist_ok=True)
    return logs_dir


def init_logger(base_path):
    """Attach a console handler and a file handler to the 'train', 'test', 'train_eval' and 'root' loggers.

    For each mode, a ``StreamHandler`` and a ``FileHandler`` appending to
    ``<base_path>/<mode>.log`` are attached with a shared format, and the logger level
    is set to DEBUG.

    Safe to call more than once. Each handler added here is tagged with a name, and
    is only added if a handler with that name is not already attached. Repeated
    calls therefore neither duplicate log lines nor leak file handles.

    :param base_path: existing directory in which the ``*.log`` files are created.
    """
    formatter = logging.Formatter('[%(asctime)s][%(name)s][%(levelname)s][%(filename)s>%(funcName)s] ==> %(message)s')
    # NOTE(review): on Python >= 3.9, logging.getLogger('root') returns the real root logger.
    # The other three loggers propagate to it, so each of their records is also echoed by
    # root's console handler and written to root.log. Confirm this is intended.
    for mode in ['train', 'test', 'train_eval', 'root']:
        file_path = os.path.join(base_path, mode + '.log')
        logger = logging.getLogger(mode)

        console_name = 'init_logger:console:' + mode
        # FileHandler stores an absolute path, so key on it to treat equivalent paths as equal.
        file_name = 'init_logger:file:' + os.path.abspath(file_path)
        attached = {h.name for h in logger.handlers}

        if console_name not in attached:
            handler = logging.StreamHandler()
            handler.set_name(console_name)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        if file_name not in attached:
            handler = logging.FileHandler(file_path, mode='a')
            handler.set_name(file_name)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)


def write_gif(episode_images, action_repeats, episode_rewards, episode_q_values, gif_path, save_mp4=True):
    """Render an episode as an annotated GIF and optionally convert it to MP4.

    Each frame is a grid of 500x500 panels:
    row 1 = resized observation | action-repeat history | cumulative episode score,
    row 2 = per-step reward | bar chart of the current decision's q-values.

    Each entry of ``action_repeats`` emits ``repeat + 1`` frames (``pop_i`` runs 0..repeat).
    Emission stops early once ``episode_images`` is exhausted.

    :param episode_images: per-step images accepted by ``PIL.Image.fromarray``.
    :param action_repeats: repeat value chosen at each decision.
    :param episode_rewards: reward per step; same length as ``episode_images``.
    :param episode_q_values: per decision, a sequence of q-values to plot.
    :param gif_path: output GIF path. The MP4 gets the same path with its extension
        replaced by ``.mp4``.
    :param save_mp4: also convert the GIF to MP4 with moviepy.
    :raises ValueError: if ``episode_images`` is empty.
    :raises AssertionError: if the lengths disagree, or if not every image was consumed.
    """
    assert len(episode_images) == len(episode_rewards)
    if len(episode_images) == 0:
        raise ValueError('episode_images is empty; nothing to render')

    import plotly.graph_objects as go
    from io import BytesIO
    from PIL import Image
    from tqdm import tqdm

    width, height = 500, 500
    # Image.ANTIALIAS was removed in Pillow 10. It was an alias of LANCZOS, which all versions provide.
    resample = Image.LANCZOS

    def _new_fig(trace, title, xaxis_title, yaxis_title):
        # Empty figure with a centred title; its data is extended as frames are produced.
        fig = go.Figure(data=trace)
        fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title, title_x=0.5)
        return fig

    def _to_image(fig):
        # Rasterise a plotly figure into a PIL image of the common panel size.
        return Image.open(BytesIO(fig.to_image(format="png", width=width, height=height)))

    rep_fig = _new_fig(go.Scatter(x=[], y=[]), "Action Repeat Count", "Time Step", "Action Repeat")
    episode_score_fig = _new_fig(go.Scatter(x=[], y=[]), "Episode Score", "Time Step", "Score")
    step_score_fig = _new_fig(go.Scatter(x=[], y=[]), "Step Score", "Time Step", "Score")
    q_values_score_fig = _new_fig(go.Bar(x=[], y=[]), "Action-value", "action idx", "q-values")

    episode_stats = []
    total_reward = 0
    step_i = 0
    for repeat_i, repeat in enumerate(tqdm(action_repeats)):
        # Action-repeat history: one point per decision, placed at the current step.
        rep_fig['data'][0]['x'] += (step_i,)
        rep_fig['data'][0]['y'] += (repeat,)
        repeat_img = _to_image(rep_fig)

        # Q-values of this decision, replacing the previous bars.
        q_values = episode_q_values[repeat_i]
        q_values_score_fig['data'][0]['x'] = tuple(range(len(q_values)))
        q_values_score_fig['data'][0]['y'] = tuple(q_values)
        q_values_img = _to_image(q_values_score_fig)

        pop_i = 0
        while pop_i <= repeat and step_i < len(episode_images):
            obs = Image.fromarray(episode_images[step_i]).resize((width, height), resample)

            # Score plots are extended by one point per step.
            total_reward += episode_rewards[step_i]
            episode_score_fig['data'][0]['x'] += (step_i,)
            episode_score_fig['data'][0]['y'] += (total_reward,)
            step_score_fig['data'][0]['x'] += (step_i,)
            step_score_fig['data'][0]['y'] += (episode_rewards[step_i],)
            episode_score_img = _to_image(episode_score_fig)
            step_score_img = _to_image(step_score_fig)

            overall_img = Image.new('RGB', (repeat_img.width + obs.width + episode_score_img.width,
                                            repeat_img.height + max(step_score_img.height, q_values_img.height)),
                                    (255, 255, 255))
            # 1st row: observation | repeat history | episode score
            overall_img.paste(obs, (0, 0))
            overall_img.paste(repeat_img, (obs.width, 0))
            overall_img.paste(episode_score_img, (obs.width + repeat_img.width, 0))
            # 2nd row: step reward | q-values
            overall_img.paste(step_score_img, (0, repeat_img.height))
            overall_img.paste(q_values_img, (step_score_img.width, repeat_img.height))
            episode_stats.append(overall_img)

            step_i += 1
            pop_i += 1

    assert total_reward == sum(episode_rewards)
    assert step_i == len(episode_images)

    # save as gif
    episode_stats[0].save(gif_path, save_all=True, append_images=episode_stats[1:], optimize=False, loop=1)

    # save video
    if save_mp4:
        import moviepy.editor as mp
        # Swap only the extension; str.replace('.gif', ...) would also rewrite '.gif'
        # appearing elsewhere in the path (e.g. in a directory name).
        mp4_path = os.path.splitext(gif_path)[0] + '.mp4'
        clip = mp.VideoFileClip(gif_path)
        try:
            clip.write_videofile(mp4_path)
        finally:
            # Release the file reader held by the clip.
            clip.close()


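# "get_parameters" and "FreezeParameters" are adapted from an external repository whose
# URL was not recorded here. TODO: add the source link (and license notice, if required).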
def get_parameters(modules: Iterable[Module]):
    """
    Flatten the parameters of several torch modules into one list.
    :param modules: iterable of modules; each one's ``.parameters()`` is gathered, in order
    :returns: a list of parameters (module order, then each module's own parameter order)
    """
    return [param for module in modules for param in module.parameters()]


class FreezeParameters:
    def __init__(self, modules: Iterable[Module]):
        """
        Context manager to locally freeze gradients.
        Inside the block every parameter of ``modules`` has ``requires_grad = False``.
        On exit, each parameter's ``requires_grad`` is restored to its value at block entry.
        In some cases this can speed up computation, because gradients aren't calculated for these modules.
        example:
        ```
        with FreezeParameters([module]):
            output_tensor = module(input_tensor)
        ```
        :param modules: iterable of modules. used to call .parameters() to freeze gradients.
        """
        # Materialise once. A generator would otherwise be exhausted by the first traversal,
        # leaving __enter__/__exit__ with nothing to freeze or restore.
        self.modules = list(modules)
        self.param_states = [p.requires_grad for p in get_parameters(self.modules)]

    def __enter__(self):
        params = get_parameters(self.modules)
        # Snapshot at entry, not construction, so requires_grad changes made in between
        # (or by an enclosing freeze) are what gets restored.
        self.param_states = [p.requires_grad for p in params]
        for param in params:
            param.requires_grad = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for param, state in zip(get_parameters(self.modules), self.param_states):
            param.requires_grad = state


def lambda_return(imged_reward, value_pred, imged_action_entropy, action_repeats, bootstrap, discount=0.99,
                  lambda_=0.95):
    """
    Lambda-returns over an imagined trajectory, accumulated backwards in time (dim 0).

    Each step uses its own discount ``g_t = discount ** action_repeats[t]``. ``v_{t+1}`` is
    ``value_pred[t + 1]``, or ``bootstrap`` for the final step. The recursion is:

        R_t = (r_t + H_t) + g_t * (1 - lambda_) * v_{t+1} + g_t * lambda_ * R_{t+1},  R_T = bootstrap

    ``lambda_ = 1`` yields a discounted Monte-Carlo return; ``lambda_ = 0`` a one-step target.

    :param imged_reward: rewards ``r_t``, time along dim 0.
    :param value_pred: value estimates, time along dim 0; ``value_pred[1:]`` are the next-step values.
    :param imged_action_entropy: bonus ``H_t`` added to each reward (broadcast-compatible with it).
    :param action_repeats: per-step exponent applied to ``discount``.
    :param bootstrap: value used beyond the last step.
    :returns: tensor of ``R_t`` stacked along dim 0, same length as ``imged_reward``.
    """
    next_values = torch.cat([value_pred[1:], bootstrap[None]], 0)
    step_discount = discount ** action_repeats
    # One-step part of each target; the lambda-weighted tail is added in the backward sweep.
    one_step = (imged_reward + imged_action_entropy) + step_discount * next_values * (1 - lambda_)

    running = bootstrap
    backwards = []
    for t in range(len(one_step) - 1, -1, -1):
        running = one_step[t] + step_discount[t] * lambda_ * running
        backwards.append(running)
    return torch.stack(backwards[::-1], 0)


def imagine_ahead(prev_state, prev_belief, model, planning_horizon=12, deterministic=False, root_deterministic=False,
                  root_uniform_action_mask=None):
    '''
    Draw an imagined trajectory using the learned dynamics model, the actor and the action-repeat actor.

    Starting from (prev_state, prev_belief), each of the ``planning_horizon - 1`` steps does three things:
    samples an action from ``model.actor``; picks an action repeat with a hard Gumbel-softmax over
    ``model.actor_repeat``'s logits; and advances the transition model (new belief via
    ``model.transition.rnn``, then a sampled prior state).

    :param prev_state: current (posterior) state. ``flatten`` merges its first two dims, so it is
        expected to be at least 3-D; e.g. root_search passes shape (1, N, state_size).
    :param prev_belief: current (hidden) belief, flattened the same way as prev_state.
    :param model: object exposing ``actor``, ``actor_repeat`` and ``transition`` as used below.
    :param planning_horizon: T; the rollout produces T - 1 imagined steps.
    :param deterministic: forwarded to ``model.actor.action_sample`` on every step except the root
        step when ``root_uniform_action_mask`` is given.
    :param root_deterministic: NOTE(review): not used anywhere in this function. Confirm whether it
        was meant to be forwarded to the root-step ``action_sample`` call.
    :param root_uniform_action_mask: optional 1-D mask used only at t == 0. Rows where it is non-zero
        receive a uniformly sampled action instead of the actor's sample.
    :returns: ImaginationOutput; each field is stacked along dim 0 with T - 1 entries (the starting
        belief/state are excluded).

    NOTE(review): older docs quoted sizes such as input torch.Size([50, 30]) and output
    torch.Size([49, 50, 200]). These do not match ``flatten`` for 2-D inputs; treat them as illustrative only.
    '''
    # Merge the first two dims into a single batch dim, e.g. (1, N, F) -> (N, F).
    flatten = lambda x: x.view([-1] + list(x.size()[2:]))
    prev_belief = flatten(prev_belief)
    prev_state = flatten(prev_state)

    # Create lists for hidden states (cannot use single tensor
    # as buffer because autograd won't work with inplace writes)
    # Slot 0 of beliefs/prior_states holds the starting point. prior_means[0] and prior_std_devs[0]
    # stay empty placeholders. Slot 0 is dropped by the [1:] slicing in the return value.
    T = planning_horizon
    beliefs = [torch.empty(0)] * T
    prior_states = [torch.empty(0)] * T
    prior_means = [torch.empty(0)] * T
    actions = [torch.empty(0)] * (T - 1)
    actions_entropy = [torch.empty(0)] * (T - 1)
    action_repeats = [torch.empty(0)] * (T - 1)
    action_repeats_one_hot_rollout = [torch.empty(0)] * (T - 1)
    action_repeats_entropy = [torch.empty(0)] * (T - 1)
    action_repeats_log_prob = [torch.empty(0)] * (T - 1)
    prior_std_devs = [torch.empty(0)] * T
    beliefs[0], prior_states[0] = prev_belief, prev_state

    # Loop over time sequence
    for t in range(T - 1):
        _state = prior_states[t]

        # Inputs are detached so gradients from the actor output do not flow back into beliefs/prior_states.
        actor_output = model.actor(beliefs[t].detach(), _state.detach())
        if t == 0 and root_uniform_action_mask is not None:
            # Masked root step: sample from the actor with action_sample's defaults (`deterministic`
            # is NOT forwarded here). Masked rows are then overwritten with uniformly sampled actions.
            _actions = model.actor.action_sample(actor_output)
            uniform_actions = model.actor.action_sample(actor_output, batch_size=_actions.shape[0], uniform=True,
                                                        device=_state.device)
            # squeeze_(0) drops a leading size-1 dim if present. Presumably the uniform sample has one;
            # TODO confirm against actor.action_sample.
            uniform_actions.squeeze_(0)
            _actions[root_uniform_action_mask.bool()] = uniform_actions[root_uniform_action_mask.bool()]
        else:
            _actions = model.actor.action_sample(actor_output, deterministic=deterministic)

        actions[t] = _actions
        actions_entropy[t] = model.actor.action_dist(actor_output).entropy()
        action_repeat_output = model.actor_repeat(beliefs[t].detach(), _state.detach(), _actions.detach())
        # hard=True: one-hot in the forward pass, soft-sample gradients in the backward pass
        # (straight-through). The low temperature keeps the soft sample close to one-hot.
        action_repeats_one_hot = gumbel_softmax(action_repeat_output.logit, tau=0.1, hard=True)
        action_repeats_one_hot_rollout[t] = action_repeats_one_hot
        action_repeats_entropy[t] = action_repeat_output.entropy
        action_repeats_log_prob[t] = action_repeat_output.log_prob_mean

        with torch.no_grad():
            # Convert the one-hot choice into its repeat value. Broadcast action_repeat_set to every
            # row, keep the selected entry, and sum over dim 1.
            action_repeat_set = torch.zeros_like(action_repeats_one_hot).to(prev_state.device)
            action_repeat_set[:] = torch.Tensor(model.actor_repeat.action_repeat_set).to(prev_state.device)
            action_repeats[t] = (action_repeat_set * action_repeats_one_hot).sum(1).detach()

        # Compute belief (deterministic hidden state)
        # The one-hot repeat choice is fed to the transition together with the state and action.
        hidden = model.transition.fc_embed_state_action(torch.cat([_state, _actions, action_repeats_one_hot], dim=1))
        hidden = F.elu(hidden)
        beliefs[t + 1] = model.transition.rnn(hidden, beliefs[t])

        # Compute state prior by applying transition dynamics
        # The next prior state is a reparameterised sample: mean + std * N(0, 1) noise.
        hidden = F.elu(model.transition.fc_embed_belief_prior(beliefs[t + 1]))
        prior_means[t + 1], _prior_std_dev = torch.chunk(model.transition.fc_state_prior(hidden), 2, dim=1)
        prior_std_devs[t + 1] = F.softplus(_prior_std_dev) + model.transition.min_std_dev
        prior_states[t + 1] = prior_means[t + 1] + prior_std_devs[t + 1] * torch.randn_like(prior_means[t + 1])

    # Drop slot 0 (starting belief/state and empty placeholders) and stack each sequence along dim 0.
    return ImaginationOutput(torch.stack(beliefs[1:], dim=0), torch.stack(prior_states[1:], dim=0),
                             torch.stack(prior_means[1:], dim=0), torch.stack(prior_std_devs[1:], dim=0),
                             torch.stack(actions, dim=0),
                             torch.stack(actions_entropy, dim=0),
                             torch.stack(action_repeats, dim=0),
                             torch.stack(action_repeats_one_hot_rollout, dim=0),
                             torch.stack(action_repeats_entropy, dim=0),
                             torch.stack(action_repeats_log_prob, dim=0))


def bottle(f, x_tuple):
    """ Apply ``f`` to (time x batch x features) inputs by folding time into the batch dim.

    Every tensor in ``x_tuple`` is viewed as ((time * batch) x features) before calling ``f``.
    The single output is then viewed back to (time x batch x ...), with time and batch taken
    from the first input.
    """
    shapes = [tensor.size() for tensor in x_tuple]
    folded = [tensor.view(shape[0] * shape[1], *shape[2:]) for tensor, shape in zip(x_tuple, shapes)]
    result = f(*folded)
    steps, batch = shapes[0][0], shapes[0][1]
    return result.view(steps, batch, *result.size()[1:])


def root_search(num_env, model, belief, posterior_state, proposal_action_sample, uniform_action_sample,
                config, planning_horizon=12):
    """
    Score candidate root actions for each environment with imagined rollouts.

    Every root state is replicated ``proposal_action_sample + uniform_action_sample`` times, and
    ``imagine_ahead`` is run from all copies with ``deterministic=True``. At the root step, per
    environment, the first ``proposal_action_sample`` copies keep the actor's sample. The remaining
    ``uniform_action_sample`` copies are switched to uniformly sampled actions via the mask built below.
    The lambda-return at the first imagined step (with no entropy bonus) is the q-value of each root child.

    :returns: RootSearchOutput with tensors reshaped to (num_env, n_children, ...).
    """
    n_children = proposal_action_sample + uniform_action_sample

    # 1 marks children whose root action is drawn uniformly: columns proposal_action_sample..n_children-1
    # of every env. Flattened env-major to line up with the tiled states below.
    uniform_mask = torch.zeros((num_env, n_children)).to(config.device)
    uniform_mask[:, torch.arange(proposal_action_sample, n_children)] = 1
    uniform_mask = uniform_mask.flatten().long()

    # Tile root states so each env's state occupies n_children consecutive rows
    # (assumes belief/posterior_state are laid out as (1, num_env, features) — TODO confirm).
    tiled_belief = belief.repeat(1, 1, n_children).reshape((1, num_env * n_children, model.belief_size))
    tiled_state = posterior_state.repeat(1, 1, n_children).reshape((1, num_env * n_children, model.state_size))

    rollout = imagine_ahead(tiled_state, tiled_belief, model, deterministic=True,
                            root_uniform_action_mask=uniform_mask,
                            planning_horizon=planning_horizon)
    features = (rollout.belief, rollout.prior_state)
    imged_reward = bottle(model.reward, features)
    value_pred = bottle(model.value, features)

    # The actor-entropy bonus is disabled during search: zeros take its place.
    no_entropy = torch.zeros_like(imged_reward).to(rollout.actions_entropy.device)

    returns = lambda_return(imged_reward, value_pred, no_entropy,
                            action_repeats=rollout.action_repeats.unsqueeze(2),
                            bootstrap=value_pred[-1],
                            discount=config.gamma,
                            lambda_=config.disclam)

    # Return at the first imagined step = q-value of each root child.
    q_values = returns[0, :, :].reshape((num_env, n_children))

    first_actions = rollout.actions[0, :, :]
    first_repeats = rollout.action_repeats[0, :]
    first_repeats_one_hot = rollout.action_repeats_one_hot[0, :]

    return RootSearchOutput(
        first_actions.reshape((num_env, n_children, first_actions.shape[-1])),
        first_repeats.reshape((num_env, n_children)),
        first_repeats_one_hot.reshape((num_env, n_children, first_repeats_one_hot.shape[-1])),
        q_values,
    )
