import copy
import torch
from torch import nn
import numpy as np
from PIL import ImageColor, Image, ImageDraw, ImageFont

import networks
import tools

def to_np(x):
    """Return ``x`` as a NumPy array, detached from autograd and copied to CPU."""
    detached = x.detach()
    return detached.cpu().numpy()


class RewardEMA(object):
    """Exponential moving average of the 5th and 95th percentiles of its input.

    Each call folds the current batch's percentiles into the running
    estimate and returns ``(offset, scale)`` for normalising returns:
    ``offset`` is the smoothed 5th percentile and ``scale`` is the smoothed
    5th-95th percentile spread, clipped from below at 1.0.
    """

    def __init__(self, device, alpha=1e-2):
        self.device = device
        self.alpha = alpha
        # Running [low, high] percentile estimates, initialised to zero.
        self.values = torch.zeros((2,)).to(device)
        # Quantile levels tracked by the EMA.
        self.range = torch.tensor([0.05, 0.95]).to(device)

    def __call__(self, x):
        samples = x.detach().flatten()
        batch_quantiles = torch.quantile(input=samples, q=self.range)
        keep = 1 - self.alpha
        self.values = self.alpha * batch_quantiles + keep * self.values
        low = self.values[0]
        high = self.values[1]
        # Never divide by a spread smaller than 1.0.
        spread = torch.clip(high - low, min=1.0)
        return low.detach(), spread.detach()


class WorldModel(nn.Module):
    """World model with two RSSM latent streams, ``latent_z`` and ``latent_c``.

    The reward and continuation heads read only ``latent_z`` features, while
    the observation decoder reads the concatenated ``latent_z_latent_c``
    features. Only configurations with a truthy ``config.z_dyn_discrete`` are
    supported; anything else raises ``NotImplementedError``.
    """

    def __init__(self, obs_space, act_space, step, config):
        super(WorldModel, self).__init__()
        self._step = step
        self._use_amp = config.precision == 16
        self._config = config
        shapes = {k: tuple(v.shape) for k, v in obs_space.spaces.items()}
        self.encoder = networks.MultiEncoder(shapes, **config.encoder)
        self.embed_size = self.encoder.outdim
        self.dynamics = networks.RSSM(
            config.z_dyn_stoch,
            config.z_dyn_deter,
            config.dyn_hidden,
            config.dyn_input_layers,
            config.dyn_output_layers,
            config.dyn_rec_depth,
            config.dyn_shared,
            config.z_dyn_discrete,
            config.act,
            config.norm,
            config.dyn_mean_act,
            config.dyn_std_act,
            config.dyn_temp_post,
            config.dyn_min_std,
            config.dyn_cell,
            config.unimix_ratio,
            config.initial,
            config.num_actions,
            self.embed_size,
            config.device,
            c_stoch=config.c_dyn_stoch,
            c_discrete=config.c_dyn_discrete,
            c_deter=config.c_dyn_deter,
        )
        self.heads = nn.ModuleDict()
        if not config.z_dyn_discrete:
            # The continuous-latent variant has not been implemented here.
            raise NotImplementedError(
                "WorldModel only supports discrete latents (z_dyn_discrete)"
            )
        # Input widths of the heads: reward/cont see the z stream only,
        # the decoder sees z and c concatenated.
        z_feat_size = config.z_dyn_stoch * config.z_dyn_discrete + config.z_dyn_deter
        c_feat_size = config.c_dyn_stoch * config.c_dyn_discrete + config.c_dyn_deter
        reward_feat_size = z_feat_size
        cont_feat_size = z_feat_size
        decoder_feat_size = z_feat_size + c_feat_size

        self.heads["decoder"] = networks.MultiDecoder(
            decoder_feat_size, shapes, **config.decoder
        )
        # A "symlog_disc" reward head gets a (255,) output shape; any other
        # head type gets an empty shape.
        reward_shape = (255,) if config.reward_head == "symlog_disc" else []
        self.heads["reward"] = networks.MLP(
            reward_feat_size,  # pytorch version
            reward_shape,
            config.reward_layers,
            config.units,
            config.act,
            config.norm,
            dist=config.reward_head,
            outscale=0.0,
            device=config.device,
        )
        self.heads["cont"] = networks.MLP(
            cont_feat_size,  # pytorch version
            [],
            config.cont_layers,
            config.units,
            config.act,
            config.norm,
            dist="binary",
            device=config.device,
        )
        for name in config.grad_heads:
            assert name in self.heads, name
        self._model_opt = tools.Optimizer(
            "model",
            self.parameters(),
            config.model_lr,
            config.opt_eps,
            config.grad_clip,
            config.weight_decay,
            opt=config.opt,
            use_amp=self._use_amp,
        )
        self._scales = dict(reward=config.reward_scale, cont=config.cont_scale)

    def _train(self, data):
        """Run one optimisation step of the world model on ``data``.

        The loss is the sum of the scaled head negative log-likelihoods, the
        KL losses of the ``latent_z`` and ``latent_c`` streams, and
        ``z_infer_scale`` times the mean KL between the posterior
        ``z_logit`` and ``z_infer_logit`` one-hot categoricals.

        Returns:
            post: posterior state dict with every tensor detached.
            context: ``embed``, ``latent_z`` posterior ``feat``, ``z_kl`` and
                the ``latent_z`` posterior entropy ``postent``.
            metrics: optimizer metrics plus losses and diagnostics.
        """
        from torch import distributions as torchd

        # action (batch_size, batch_length, act_dim)
        # image (batch_size, batch_length, h, w, ch)
        # reward (batch_size, batch_length)
        # discount (batch_size, batch_length)
        data = self.preprocess(data)
        config = self._config
        kld = torchd.kl.kl_divergence

        with tools.RequiresGrad(self):
            with torch.cuda.amp.autocast(self._use_amp):
                embed = self.encoder(data)
                post, prior = self.dynamics.observe(
                    embed, data["action"], data["is_first"]
                )

                z_kl_free = config.z_kl_free
                z_dyn_scale = config.z_dyn_scale
                z_rep_scale = config.z_rep_scale
                z_kl_loss, z_kl_value, z_dyn_loss, z_rep_loss = self.dynamics.kl_loss(
                    post, prior, z_kl_free, z_dyn_scale, z_rep_scale, latent_type="latent_z"
                )

                c_kl_free = config.c_kl_free
                c_dyn_scale = config.c_dyn_scale
                c_rep_scale = config.c_rep_scale
                c_kl_loss, c_kl_value, c_dyn_loss, c_rep_loss = self.dynamics.kl_loss(
                    post, prior, c_kl_free, c_dyn_scale, c_rep_scale, latent_type="latent_c"
                )

                preds = {}
                for name, head in self.heads.items():
                    if name in ("reward", "cont"):
                        feat = self.dynamics.get_feat(post, latent_type="latent_z")
                    elif name == "decoder":
                        feat = self.dynamics.get_feat(post, latent_type="latent_z_latent_c")
                    else:
                        raise NotImplementedError(name)
                    # Heads not listed in grad_heads are trained on detached
                    # features, so their loss does not flow back into the RSSM.
                    if name not in config.grad_heads:
                        feat = feat.detach()
                    pred = head(feat)
                    if isinstance(pred, dict):
                        preds.update(pred)
                    else:
                        preds[name] = pred

                losses = {}
                for name, pred in preds.items():
                    like = pred.log_prob(data[name])
                    losses[name] = -torch.mean(like) * self._scales.get(name, 1.0)
                model_loss = sum(losses.values()) + z_kl_loss + c_kl_loss

                # KL(posterior z || inferred z), both as independent one-hot
                # categoricals over the last logit dimension.
                z_post_dist = torchd.independent.Independent(
                    torchd.one_hot_categorical.OneHotCategorical(logits=post["z_logit"]), 1
                )
                z_infer_dist = torchd.independent.Independent(
                    torchd.one_hot_categorical.OneHotCategorical(logits=post["z_infer_logit"]), 1
                )
                z_infer_loss = config.z_infer_scale * torch.mean(
                    kld(z_post_dist, z_infer_dist)
                )
                model_loss += z_infer_loss

            metrics = self._model_opt(model_loss, self.parameters())
        metrics.update({"z_infer_loss": to_np(z_infer_loss)})

        metrics.update({f"{name}_loss": to_np(loss) for name, loss in losses.items()})
        metrics["z_infer_scale"] = config.z_infer_scale
        metrics["z_kl_free"] = z_kl_free
        metrics["z_dyn_scale"] = z_dyn_scale
        metrics["z_rep_scale"] = z_rep_scale
        metrics["z_dyn_loss"] = to_np(z_dyn_loss)
        metrics["z_rep_loss"] = to_np(z_rep_loss)
        metrics["z_kl"] = to_np(torch.mean(z_kl_value))
        metrics["c_kl_free"] = c_kl_free
        metrics["c_dyn_scale"] = c_dyn_scale
        metrics["c_rep_scale"] = c_rep_scale
        metrics["c_dyn_loss"] = to_np(c_dyn_loss)
        metrics["c_rep_loss"] = to_np(c_rep_loss)
        metrics["c_kl"] = to_np(torch.mean(c_kl_value))
        with torch.cuda.amp.autocast(self._use_amp):
            metrics["z_prior_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(prior, latent_type="latent_z").entropy())
            )
            metrics["z_post_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(post, latent_type="latent_z").entropy())
            )
            metrics["c_prior_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(prior, latent_type="latent_c").entropy())
            )
            metrics["c_post_ent"] = to_np(
                torch.mean(self.dynamics.get_dist(post, latent_type="latent_c").entropy())
            )
            context = dict(
                embed=embed,
                feat=self.dynamics.get_feat(post, latent_type="latent_z"),
                z_kl=z_kl_value,
                postent=self.dynamics.get_dist(post, latent_type="latent_z").entropy(),
            )
        post = {k: v.detach() for k, v in post.items()}
        return post, context, metrics

    # this function is called during both rollout and training
    def preprocess(self, obs):
        """Convert a raw observation/batch dict to tensors on the model device.

        Images are mapped from [0, 255] to [-0.5, 0.5]. If present,
        ``discount`` is multiplied by ``config.discount`` and gets a trailing
        unit axis. A ``cont`` target ``1 - is_terminal`` with a trailing unit
        axis is added. The caller's dict and its arrays are left unmodified.

        Raises:
            AssertionError: if ``is_first`` or ``is_terminal`` is missing.
        """
        obs = obs.copy()
        obs["image"] = torch.Tensor(obs["image"]) / 255.0 - 0.5
        if "discount" in obs:
            # Rebind rather than ``*=``: ``obs`` is only a shallow copy, so an
            # in-place multiply would also scale the caller's array.
            discount = obs["discount"] * self._config.discount
            # (batch_size, batch_length) -> (batch_size, batch_length, 1)
            obs["discount"] = torch.Tensor(discount).unsqueeze(-1)
        # 'is_first' is necessary to initialize the hidden state during training
        assert "is_first" in obs
        # 'is_terminal' is necessary to train the cont head
        assert "is_terminal" in obs
        obs["cont"] = torch.Tensor(1.0 - obs["is_terminal"]).unsqueeze(-1)
        obs = {k: torch.Tensor(v).to(self._config.device) for k, v in obs.items()}
        return obs

    def video_pred(self, data):
        """Build a comparison video for the first (up to) 6 sequences.

        Panels are concatenated along dim 2 in this order:
        ground truth; z-only open-loop prediction (posterior reconstruction
        for the first 5 steps, then an imagined rollout with the c features
        zero-padded); full z+c posterior reconstruction; reconstruction with
        the z features zeroed; reconstruction with the c features zeroed.
        Each reconstruction is followed by an error panel
        ``(recon - truth + 1) / 2``. Images are shifted back by +0.5.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        decoder = self.heads["decoder"]

        # Posterior over the first 5 observed steps.
        states, _ = self.dynamics.observe(
            embed[:6, :5], data["action"][:6, :5], data["is_first"][:6, :5]
        )
        post_z_c_feat = self.dynamics.get_feat(states, latent_type="latent_z_latent_c")
        recon = decoder(post_z_c_feat)["image"].mode()[:6]

        # Open-loop imagination from the last observed state.
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.dynamics.imagine(data["action"][:6, 5:], init)
        prior_z_feat = self.dynamics.get_feat(prior, latent_type="latent_z")
        z_width = prior_z_feat.shape[-1]
        # Imagination yields only z features; the decoder expects z+c, so
        # the c slots are filled with zeros.
        c_padding = prior_z_feat.new_zeros(
            list(prior_z_feat.shape[:-1]) + [post_z_c_feat.shape[-1] - z_width]
        )
        openl = decoder(torch.cat([prior_z_feat, c_padding], -1))["image"].mode()

        # Posterior over the full sequences, decoded with both latent streams
        # and with each stream zeroed out in turn. Clones keep the in-place
        # zeroing from touching the shared feature tensor.
        full_traj_states, _ = self.dynamics.observe(
            embed[:6, :], data["action"][:6, :], data["is_first"][:6, :]
        )
        full_feat = self.dynamics.get_feat(full_traj_states, latent_type="latent_z_latent_c")
        assert len(full_feat.shape) == 3
        z_c_recon = decoder(full_feat)["image"].mode()[:6]

        c_posterior = full_feat.clone()
        c_posterior[:, :, :z_width] = 0.0
        c_posterior_recon = decoder(c_posterior)["image"].mode()

        z_posterior = full_feat.clone()
        z_posterior[:, :, z_width:] = 0.0
        z_posterior_recon = decoder(z_posterior)["image"].mode()

        # observed image is given until 5 steps
        z_model_recon = torch.cat([recon[:, :5], openl], 1)
        truth = data["image"][:6] + 0.5
        z_model_recon = z_model_recon + 0.5
        z_c_recon = z_c_recon + 0.5
        c_posterior_recon = c_posterior_recon + 0.5
        z_posterior_recon = z_posterior_recon + 0.5

        error1 = (z_model_recon - truth + 1.0) / 2.0
        error2 = (z_c_recon - truth + 1.0) / 2.0
        error3 = (c_posterior_recon - truth + 1.0) / 2.0
        error4 = (z_posterior_recon - truth + 1.0) / 2.0

        return torch.cat(
            [
                truth,
                z_model_recon,
                error1,
                z_c_recon,
                error2,
                c_posterior_recon,
                error3,
                z_posterior_recon,
                error4,
            ],
            2,
        )


class ImagBehavior(nn.Module):
    """Actor-critic trained on trajectories imagined by the world model.

    Both the actor and the value network read the ``latent_z`` features of
    the world model's dynamics. Only configurations with a truthy
    ``config.z_dyn_discrete`` are supported.
    """

    def __init__(self, config, world_model, stop_grad_actor=True, reward=None):
        super(ImagBehavior, self).__init__()
        self._use_amp = config.precision == 16
        self._config = config
        self._world_model = world_model
        self._stop_grad_actor = stop_grad_actor
        self._reward = reward
        if not config.z_dyn_discrete:
            # The continuous-latent variant has not been implemented here.
            raise NotImplementedError(
                "ImagBehavior only supports discrete latents (z_dyn_discrete)"
            )
        z_feat_size = config.z_dyn_stoch * config.z_dyn_discrete + config.z_dyn_deter
        self.actor = networks.ActionHead(
            z_feat_size,
            config.num_actions,
            config.actor_layers,
            config.units,
            config.act,
            config.norm,
            config.actor_dist,
            config.actor_init_std,
            config.actor_min_std,
            config.actor_max_std,
            config.actor_temp,
            outscale=1.0,
            unimix_ratio=config.action_unimix_ratio,
        )
        # A "symlog_disc" value head gets a (255,) output shape; any other
        # head type gets an empty shape.
        value_shape = (255,) if config.value_head == "symlog_disc" else []
        self.value = networks.MLP(
            z_feat_size,
            value_shape,
            config.value_layers,
            config.units,
            config.act,
            config.norm,
            config.value_head,
            outscale=0.0,
            device=config.device,
        )
        if config.slow_value_target:
            self._slow_value = copy.deepcopy(self.value)
            self._updates = 0
        kw = dict(wd=config.weight_decay, opt=config.opt, use_amp=self._use_amp)
        self._actor_opt = tools.Optimizer(
            "actor",
            self.actor.parameters(),
            config.actor_lr,
            config.ac_opt_eps,
            config.actor_grad_clip,
            **kw,
        )
        self._value_opt = tools.Optimizer(
            "value",
            self.value.parameters(),
            config.value_lr,
            config.ac_opt_eps,
            config.value_grad_clip,
            **kw,
        )
        if self._config.reward_EMA:
            self.reward_ema = RewardEMA(device=self._config.device)

    def _train(
        self,
        start,
        objective=None,
        action=None,
        reward=None,
        imagine=None,
        tape=None,
        repeats=None,
    ):
        """Imagine from ``start``, then update the actor and the value network.

        ``objective(feat, state, action)`` supplies the imagined reward; it
        defaults to the ``reward`` callable given at construction. The
        ``action``, ``reward``, ``imagine`` and ``tape`` arguments are
        accepted for interface compatibility and ignored.

        Returns:
            (imag_feat, imag_state, imag_action, weights, metrics)
        """
        objective = objective or self._reward
        self._update_slow_target()
        metrics = {}

        with tools.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(self._use_amp):
                imag_feat, imag_state, imag_action = self._imagine(
                    start, self.actor, self._config.imag_horizon, repeats
                )
                reward = objective(imag_feat, imag_state, imag_action)
                actor_ent = self.actor(imag_feat).entropy()
                state_ent = self._world_model.dynamics.get_dist(
                    imag_state, latent_type="latent_z"
                ).entropy()
                # this target is not scaled
                target, weights, base = self._compute_target(
                    imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
                )
                actor_loss, mets = self._compute_actor_loss(
                    imag_feat,
                    imag_state,
                    imag_action,
                    target,
                    actor_ent,
                    state_ent,
                    weights,
                    base,
                )
                metrics.update(mets)
                value_input = imag_feat

        with tools.RequiresGrad(self.value):
            with torch.cuda.amp.autocast(self._use_amp):
                value = self.value(value_input[:-1].detach())
                target = torch.stack(target, dim=1)
                # (time, batch, 1), (time, batch, 1) -> (time, batch)
                value_loss = -value.log_prob(target.detach())
                if self._config.slow_value_target:
                    # _slow_value only exists when slow_value_target is enabled,
                    # so it must not be evaluated outside this branch.
                    slow_target = self._slow_value(value_input[:-1].detach())
                    value_loss = value_loss - value.log_prob(
                        slow_target.mode().detach()
                    )
                if self._config.value_decay:
                    # NOTE(review): value.mode() may carry a trailing unit axis
                    # that value_loss lacks; confirm shapes before enabling.
                    value_loss += self._config.value_decay * value.mode()
                # (time, batch, 1), (time, batch, 1) -> (1,)
                value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])

        metrics.update(tools.tensorstats(value.mode(), "value"))
        metrics.update(tools.tensorstats(target, "target"))
        metrics.update(tools.tensorstats(reward, "imag_reward"))
        if self._config.actor_dist in ["onehot"]:
            metrics.update(
                tools.tensorstats(
                    torch.argmax(imag_action, dim=-1).float(), "imag_action"
                )
            )
        else:
            metrics.update(tools.tensorstats(imag_action, "imag_action"))
        metrics["actor_entropy"] = to_np(torch.mean(actor_ent))
        with tools.RequiresGrad(self):
            metrics.update(self._actor_opt(actor_loss, self.actor.parameters()))
            metrics.update(self._value_opt(value_loss, self.value.parameters()))
        return imag_feat, imag_state, imag_action, weights, metrics

    def _imagine(self, start, policy, horizon, repeats=None):
        """Roll ``policy`` out in the world model for ``horizon`` steps.

        The leading two dims of every ``start`` tensor are merged into one.

        Returns:
            (feats, states, actions): features f_0..f_{horizon-1}, states
            s_0..s_{horizon-1} (the start state followed by the first
            horizon-1 successors) and actions a_0..a_{horizon-1}.

        Raises:
            NotImplementedError: if ``repeats`` is truthy.
        """
        if repeats:
            raise NotImplementedError("repeats is not implemented in this version")
        dynamics: networks.RSSM = self._world_model.dynamics

        def flatten(x):
            return x.reshape([-1] + list(x.shape[2:]))

        start = {k: flatten(v) for k, v in start.items()}

        def step(prev, _):
            state, _, _ = prev
            feat = dynamics.get_feat(state, latent_type="latent_z")
            inp = feat.detach() if self._stop_grad_actor else feat
            action = policy(inp).sample()
            succ = dynamics.z_img_step(state, action, sample=self._config.imag_sample)
            return succ, feat, action

        # ([f_0~f_{horizon-1}], [s_1~s_{horizon}],[a_0~a_{horizon-1}])
        succ, feats, actions = tools.static_scan(
            step, [torch.arange(horizon)], (start, None, None)
        )
        # Prepend the start state and drop the last successor: s_0~s_{horizon-1}.
        states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}

        return feats, states, actions

    @staticmethod
    def _match_rank(x, ref):
        """Append trailing unit axes to ``x`` until it has as many dims as ``ref``."""
        while x.dim() < ref.dim():
            x = x.unsqueeze(-1)
        return x

    def _compute_target(
        self, imag_feat, imag_state, imag_action, reward, actor_ent, state_ent
    ):
        """Compute lambda-return targets, per-step weights and the value baseline.

        The discount is ``config.discount`` times the cont head's mean when
        the world model has a "cont" head, otherwise a constant. The
        weights are the cumulative product of discounts shifted by one step
        (first weight 1), detached. When ``future_entropy`` is set, entropy
        bonuses are added to a copy of ``reward``; the caller's tensor is
        not modified.

        Returns:
            (target, weights, value[:-1])
        """
        if "cont" in self._world_model.heads:
            inp = self._world_model.dynamics.get_feat(imag_state, latent_type="latent_z")
            discount = self._config.discount * self._world_model.heads["cont"](inp).mean
        else:
            discount = self._config.discount * torch.ones_like(reward)
        # Out-of-place adds, with the entropies expanded to reward's rank, so
        # broadcasting lines up and the logged reward stays untouched.
        if self._config.future_entropy and self._config.actor_entropy > 0:
            reward = reward + self._config.actor_entropy * self._match_rank(
                actor_ent, reward
            )
        if self._config.future_entropy and self._config.actor_state_entropy > 0:
            reward = reward + self._config.actor_state_entropy * self._match_rank(
                state_ent, reward
            )
        value = self.value(imag_feat).mode()
        target = tools.lambda_return(
            reward[1:],
            value[:-1],
            discount[1:],
            bootstrap=value[-1],
            lambda_=self._config.discount_lambda,
            axis=0,
        )
        weights = torch.cumprod(
            torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0
        ).detach()
        return target, weights, value[:-1]

    def _compute_actor_loss(
        self,
        imag_feat,
        imag_state,
        imag_action,
        target,
        actor_ent,
        state_ent,
        weights,
        base,
    ):
        """Return ``(actor_loss, metrics)`` for the configured gradient estimator.

        With ``reward_EMA`` the advantage uses targets and baseline
        normalised by the running percentile EMA; otherwise it is the raw
        ``target - base``. ``imag_gradient`` selects "dynamics" (advantage
        directly), "reinforce" (log-prob times detached advantage) or "both"
        (a ``imag_gradient_mix`` blend).

        Raises:
            NotImplementedError: for an unknown ``imag_gradient``.
        """
        metrics = {}
        inp = imag_feat.detach() if self._stop_grad_actor else imag_feat
        policy = self.actor(inp)
        actor_ent = policy.entropy()
        # Q-val for actor is not transformed using symlog
        target = torch.stack(target, dim=1)
        if self._config.reward_EMA:
            offset, scale = self.reward_ema(target)
            normed_target = (target - offset) / scale
            normed_base = (base - offset) / scale
            adv = normed_target - normed_base
            metrics.update(tools.tensorstats(normed_target, "normed_target"))
            values = self.reward_ema.values
            metrics["EMA_005"] = to_np(values[0])
            metrics["EMA_095"] = to_np(values[1])
        else:
            # Without return normalisation the advantage is the raw
            # lambda-return minus the value baseline.
            adv = target - base

        if self._config.imag_gradient == "dynamics":
            actor_target = adv
        elif self._config.imag_gradient == "reinforce":
            actor_target = (
                policy.log_prob(imag_action)[:-1][:, :, None]
                * (target - self.value(imag_feat[:-1]).mode()).detach()
            )
        elif self._config.imag_gradient == "both":
            actor_target = (
                policy.log_prob(imag_action)[:-1][:, :, None]
                * (target - self.value(imag_feat[:-1]).mode()).detach()
            )
            mix = self._config.imag_gradient_mix
            actor_target = mix * target + (1 - mix) * actor_target
            metrics["imag_gradient_mix"] = mix
        else:
            raise NotImplementedError(self._config.imag_gradient)
        if not self._config.future_entropy and self._config.actor_entropy > 0:
            actor_entropy = self._config.actor_entropy * actor_ent[:-1][:, :, None]
            actor_target = actor_target + actor_entropy
        if not self._config.future_entropy and (self._config.actor_state_entropy > 0):
            state_entropy = self._config.actor_state_entropy * self._match_rank(
                state_ent[:-1], actor_target
            )
            actor_target = actor_target + state_entropy
            metrics["actor_state_entropy"] = to_np(torch.mean(state_entropy))
        actor_loss = -torch.mean(weights[:-1] * actor_target)
        return actor_loss, metrics

    def _update_slow_target(self):
        """Every ``slow_target_update`` calls, blend the online value weights
        into the slow copy: ``slow = mix * online + (1 - mix) * slow``.
        Does nothing unless ``slow_value_target`` is enabled."""
        if not self._config.slow_value_target:
            return
        if self._updates % self._config.slow_target_update == 0:
            mix = self._config.slow_target_fraction
            for src, dst in zip(self.value.parameters(), self._slow_value.parameters()):
                dst.data = mix * src.data + (1 - mix) * dst.data
        self._updates += 1
