import math
from collections import OrderedDict

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D
from dm_env import specs

import utils
from agent.dreamer import DreamerAgent, stop_gradient
import agent.dreamer_utils as common


def get_feat_ac(seq):
    """Build the actor/critic input from an imagined sequence.

    Reads the ``'feat'``, ``'context'`` and ``'skill'`` entries of ``seq`` and
    concatenates them, in that order, along the last dimension.
    """
    parts = [seq[key] for key in ('feat', 'context', 'skill')]
    return torch.cat(parts, dim=-1)


class PEACDIAYNAgent(DreamerAgent):
    def __init__(self, update_skill_every_step, skill_dim, diayn_scale, num_init_frames,
                 task_number, task_scale, **kwargs):
        """Store skill/task hyper-parameters and build the world model and behaviors.

        The hyper-parameters are stored before the parent constructor is called
        with ``**kwargs``. Afterwards the world model, the skill-conditioned
        low-level actor-critic and the high-level skill selector are created
        and moved to ``self.device``. Gradients are disabled on the whole agent
        at the end of construction.
        """
        self.update_skill_every_step = update_skill_every_step
        self.num_init_frames = num_init_frames
        self.skill_dim = skill_dim
        self.diayn_scale = diayn_scale
        self.task_number = task_number
        # The context vector has one entry per task.
        self.context_dim = task_number
        self.task_scale = task_scale
        super().__init__(**kwargs)

        world_model = WorldModel(self.cfg, self.obs_space, self.act_dim, self.tfstep,
                                 task_number=self.task_number,
                                 skill_dim=skill_dim)
        self.wm = world_model.to(self.device)
        self.hidden_dim = self.wm.inp_size

        # Low-level policy: (feat, context, skill) -> action.
        executor = ContextSkillActorCritic(self.cfg, self.act_spec, self.tfstep,
                                           self.context_dim, self.skill_dim,
                                           discrete_skills=True)
        self._task_behavior = executor.to(self.device)

        # High-level policy: (feat, context) -> skill; it drives the executor above.
        selector = ContextMetaCtrlAC(self.cfg, self.context_dim, self.skill_dim,
                                     self.tfstep, self._task_behavior,
                                     frozen_skills=self.cfg.freeze_skills,
                                     skill_len=1)
        self._skill_behavior = selector.to(self.device)

        # Overrides the value assigned from the world model input size above.
        self.hidden_dim = self.cfg.diayn_hidden
        self.reward_free = True
        self.solved_meta = None
        self.requires_grad_(requires_grad=False)

    def finetune_mode(self):
        """Switch the agent from reward-free pre-training to fine-tuning.

        Flags the agent as fine-tuning, installs fresh ``StreamNorm`` reward
        normalizers on both behaviors and sets the ``actor_ent`` and
        ``skill_actor_ent`` config entries to ``1e-4``.
        """
        self.is_ft = True
        self.reward_free = False
        for behavior in (self._skill_behavior, self._task_behavior):
            behavior.rewnorm = common.StreamNorm(momentum=1.00, scale=1.0, eps=1e-8,
                                                 device=self.device)
        self.cfg.actor_ent = 1e-4
        # NOTE(review): the low-level actor loss in this file reads
        # `cfg.context_skill_actor_ent`, not `cfg.skill_actor_ent` -- confirm
        # that this entry is still consumed somewhere.
        self.cfg.skill_actor_ent = 1e-4

    def get_meta_specs(self):
        """Return a one-element tuple with the spec of the ``'skill'`` meta entry."""
        skill_spec = specs.Array((self.skill_dim,), np.float32, 'skill')
        return (skill_spec,)

    def init_meta(self):
        if self.solved_meta is not None:
            return self.solved_meta
        skill = np.zeros(self.skill_dim, dtype=np.float32)
        skill[np.random.choice(self.skill_dim)] = 1.0
        meta = OrderedDict()
        meta['skill'] = skill
        return meta

    def update_meta(self, meta, global_step, time_step):
        if global_step % self.update_skill_every_step == 0:
            return self.init_meta()
        return meta

    def act(self, obs, meta, step, eval_mode, state):
        obs = {k: torch.as_tensor(np.copy(v), device=self.device).unsqueeze(0) for k, v in obs.items()}
        meta = {k: torch.as_tensor(np.copy(v), device=self.device).unsqueeze(0) for k, v in meta.items()}

        if state is None:
            latent = self.wm.rssm.initial(len(obs['reward']))
            action = torch.zeros((len(obs['reward']),) + self.act_spec.shape, device=self.device)
        else:
            latent, action = state
        embed = self.wm.encoder(self.wm.preprocess(obs))
        should_sample = (not eval_mode) or (not self.cfg.eval_state_mean)
        latent, _ = self.wm.rssm.obs_step(latent, action, embed, obs['is_first'], should_sample)
        feat = self.wm.rssm.get_feat(latent)
        context, skill_pred = self.wm.task_skill_model(feat)
        context = F.softmax(context, dim=-1)

        # pretrain
        if self.reward_free:
            skill = meta['skill']
            if eval_mode:
                action = self._task_behavior.actor(torch.cat([feat, context, skill], dim=-1))
                action = action.mean
            else:
                action = self._task_behavior.actor(torch.cat([feat, context, skill], dim=-1))
                action = action.sample()
            new_state = (latent, action)
            return action.cpu().numpy()[0], new_state

        # fine tune
        if eval_mode:
            skill = self._skill_behavior.actor(torch.cat([feat, context], dim=-1))
            skill = skill.mode()
            action = self._task_behavior.actor(torch.cat([feat, context, skill], dim=-1))
            action = action.mean
        else:
            skill = self._skill_behavior.actor(torch.cat([feat, context], dim=-1))
            skill = skill.sample()
            action = self._task_behavior.actor(torch.cat([feat, context, skill], dim=-1))
            action = action.sample()
        new_state = (latent, action)
        return action.cpu().numpy()[0], new_state

    def init_from(self, other):
        """Initialise this agent's networks from a pre-trained ``other`` agent.

        The RSSM, encoder and decoder are always copied. The remaining copies
        are controlled by optional config entries:

        * ``init_task`` (default ``1.0``): copy the task/skill model when ``> 0``.
        * ``init_actor`` (default ``True``): copy the low-level actor.
        * ``init_critic`` (default ``False``): copy the low-level critic and,
          when ``cfg.slow_target`` is set, its target network.

        Args:
            other: agent exposing the same ``wm`` and ``_task_behavior`` layout.
        """
        init_task = self.cfg.get('init_task', 1.0)
        init_critic = self.cfg.get('init_critic', False)
        init_actor = self.cfg.get('init_actor', True)

        print("Copying the pretrained world model")
        utils.hard_update_params(other.wm.rssm, self.wm.rssm)
        utils.hard_update_params(other.wm.encoder, self.wm.encoder)
        utils.hard_update_params(other.wm.heads['decoder'], self.wm.heads['decoder'])

        if init_task > 0.0:
            print("Copying the task model")
            utils.hard_update_params(other.wm.task_skill_model, self.wm.task_skill_model)

        # `init_actor` used to be read and then ignored, so the actor was copied
        # unconditionally. The default (True) keeps that behaviour.
        if init_actor:
            print("Copying the pretrained actor")
            utils.hard_update_params(other._task_behavior.actor, self._task_behavior.actor)

        if init_critic:
            print("Copying the pretrained critic")
            utils.hard_update_params(other._task_behavior.critic, self._task_behavior.critic)
            if self.cfg.slow_target:
                utils.hard_update_params(other._task_behavior._target_critic,
                                         self._task_behavior._target_critic)

    def compute_skill_reward(self, seq):
        B, T, _ = seq['skill'].shape
        skill = seq['skill'].reshape(B * T, -1)
        next_obs = seq['feat'].reshape(B * T, -1)

        z_hat = torch.argmax(skill, dim=1)
        context_pred, skill_pred = self.wm.task_skill_model(next_obs)  # B*T, skill_dim
        skill_pred_log_softmax = F.log_softmax(skill_pred, dim=1)
        _, pred_z = torch.max(skill_pred_log_softmax, dim=1, keepdim=True)
        skill_rew = skill_pred_log_softmax[torch.arange(skill_pred.shape[0]), z_hat] \
                    - math.log(1 / self.skill_dim)

        return skill_rew.reshape(B, T, 1) * self.diayn_scale

    def compute_context_reward(self, seq):
        B = seq['skill'].shape[0]
        T = seq['skill'].shape[1]
        next_obs = seq['feat'].reshape(B * T, -1)

        task_pred, skill_pred = self.wm.task_skill_model(next_obs)  # B*T, skill_dim
        task_truth = seq['task_id'].repeat(B, 1, 1).to(dtype=torch.int64)
        if self.cfg.reward_type == 1:
            task_pred = F.softmax(task_pred, dim=1)
            intr_rew = torch.zeros(task_pred.shape, device=self.device)  # 16, 2500, task_number
            # intr_rew = intr_rew.reshape(B*T, -1)
            intr_rew[torch.arange(B*T), task_truth.reshape(-1)] = 1.0
            # intr_rew = intr_rew.reshape(B, T, -1)
            task_rew = torch.sum(torch.square(intr_rew - task_pred), dim=1, keepdim=True)
            task_rew = task_rew.reshape(B, T, -1)
        # calculate the task model predict prob
        elif self.cfg.reward_type == 2:
            task_pred = F.log_softmax(task_pred, dim=1)
            intr_rew = task_pred[torch.arange(B*T), task_truth.reshape(-1)] # B*T
            task_rew = -intr_rew.reshape(B, T, 1)
        # calculate the task model predict prob - entropy
        elif self.cfg.reward_type == 3:
            task_pred = F.log_softmax(task_pred, dim=1)
            entropy = task_pred.sum(dim=1, keepdim=True) / task_pred.shape[1] # B*T, 1
            intr_rew = task_pred[torch.arange(B*T), task_truth.reshape(-1)] # B*T
            task_rew = - (intr_rew.reshape(-1, 1) - entropy).reshape(B, T, 1)
        else:
            raise Exception('Current reward type is {}, which is not supported'.
                            format(self.cfg.agent.reward_type))

        return task_rew * self.task_scale

    def update(self, data, step):
        """Run one training step: world model first, then one behavior.

        While ``self.reward_free`` is set, the low-level behavior is trained on
        the sum of the context and skill intrinsic rewards. Otherwise the
        high-level skill selector is trained on the mean of the world model's
        reward head.

        Args:
            data: batch dict; ``'task_id'`` and ``'is_terminal'`` are read here
                and the whole dict is passed to the world model update.
            step: unused.

        Returns:
            ``(state, metrics)`` with the state returned by the world model
            update and the merged metrics dict.
        """
        metrics = {}
        state, outputs, wm_metrics = self.wm.update(data, state=None)
        metrics.update(wm_metrics)

        # NOTE: `start` aliases `outputs['post']` until it is rebuilt below, so
        # the extra keys are written into that dict as well.
        start = outputs['post']
        if self.cfg.reward_type == 4:
            prior = outputs['prior']
            with torch.no_grad():
                start['prior_feat'] = self.wm.rssm.get_feat(prior)
        start['task_id'] = data['task_id']
        start['context'] = outputs['context']
        start = {key: stop_gradient(value) for key, value in start.items()}

        if self.reward_free:
            def reward_fn(seq):
                return self.compute_context_reward(seq) + self.compute_skill_reward(seq)
            behavior = self._task_behavior
        else:
            def reward_fn(seq):
                return self.wm.heads['reward'](seq['feat']).mean
            behavior = self._skill_behavior

        metrics.update(behavior.update(self.wm, start, data['is_terminal'], reward_fn))
        return state, metrics


# feat + context + skill  -> action
class ContextSkillActorCritic(common.Module):
    def __init__(self, config, act_spec, tfstep, context_dim, skill_dim,
                 solved_meta=None, discrete_skills=True):
        """Create the actor, critic, optional slow target critic and optimizers.

        Both networks take the RSSM feature concatenated with a context vector
        of size ``context_dim`` and a skill vector of size ``skill_dim``.
        """
        super().__init__()
        self.cfg = config
        self.act_spec = act_spec
        self.tfstep = tfstep
        self.device = config.device
        self._use_amp = (config.precision == 16)

        self.solved_meta = solved_meta
        self.discrete_skills = discrete_skills
        self.context_dim = context_dim
        self.skill_dim = skill_dim

        # RSSM feature size = deterministic part + (flattened) stochastic part.
        rssm_cfg = config.rssm
        if rssm_cfg.discrete:
            stoch_size = rssm_cfg.stoch * rssm_cfg.discrete
        else:
            stoch_size = rssm_cfg.stoch
        inp_size = rssm_cfg.deter + stoch_size + context_dim + skill_dim

        self.actor = common.MLP(inp_size, act_spec.shape[0], **self.cfg.actor)
        self.critic = common.MLP(inp_size, (1,), **self.cfg.critic)
        if not self.cfg.slow_target:
            # Without a slow target the critic serves as its own target.
            self._target_critic = self.critic
        else:
            self._target_critic = common.MLP(inp_size, (1,), **self.cfg.critic)
            self._updates = 0

        self.actor_opt = common.Optimizer(
            'context_skill_actor', self.actor.parameters(),
            **self.cfg.actor_opt, use_amp=self._use_amp)
        self.critic_opt = common.Optimizer(
            'context_skill_critic', self.critic.parameters(),
            **self.cfg.critic_opt, use_amp=self._use_amp)
        self.rewnorm = common.StreamNorm(**self.cfg.context_skill_reward_norm, device=self.device)

    def update(self, world_model, start, is_terminal, reward_fn):
        """Run one actor step and one critic step on imagined rollouts.

        Args:
            world_model: object whose ``imagine`` method produces the rollout.
            start: dict of start states; ``'deter'`` (3-D) and ``'context'`` are read here.
            is_terminal: forwarded unchanged to ``world_model.imagine``.
            reward_fn: callable mapping the imagined sequence dict to a reward tensor.

        Returns:
            Dict of metrics from both optimizers, the reward normalizer, the
            target computation and both losses.
        """
        metrics = {}
        hor = self.cfg.imag_horizon
        with common.RequiresGrad(self.actor):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                B, T, _ = start['deter'].shape
                context_pred = start['context'].reshape(B*T, -1)

                # Pick one skill per flattened start state: the fixed solved
                # skill if available, otherwise a random one-hot (discrete) or
                # a random unit-norm vector (continuous).
                if self.solved_meta is not None:
                    img_skill = torch.from_numpy(self.solved_meta['skill']).repeat(B * T, 1).to(self.device)
                else:
                    if self.discrete_skills:
                        img_skill = F.one_hot(torch.randint(0, self.skill_dim,
                                                            size=(B * T,), device=self.device),
                                              num_classes=self.skill_dim).float()
                    else:
                        img_skill = torch.randn((B * T, self.skill_dim), device=self.device)
                        img_skill = img_skill / torch.norm(img_skill, dim=-1, keepdim=True)

                # Context and skill are passed to the rollout as one conditioning vector.
                task_cond = torch.cat([context_pred, img_skill], dim=-1)

                seq = world_model.imagine(self.actor, start, is_terminal, hor,
                                          task_cond=task_cond)
                # Split the conditioning back into its context and skill parts.
                popped_task = seq.pop('task')
                seq['context'] = popped_task[:, :, :self.context_dim]
                seq['skill'] = popped_task[:, :, self.context_dim:]
                reward = reward_fn(seq)
                seq['reward'], mets1 = self.rewnorm(reward)
                mets1 = {f'context_skill_reward_{k}': v for k, v in mets1.items()}
                target, mets2 = self.target(seq)
                actor_loss, mets3 = self.actor_loss(seq, target)
            metrics.update(self.actor_opt(actor_loss, self.actor.parameters()))
        with common.RequiresGrad(self.critic):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                # The critic is trained on the same rollout with gradients cut,
                # reusing the `target` computed during the actor pass.
                seq = {k: stop_gradient(v) for k, v in seq.items()}
                critic_loss, mets4 = self.critic_loss(seq, target)

                # Disabled alternative: additionally fit the critic on the
                # replayed start states.
                # start = {k: stop_gradient(v.transpose(0,1)) for k,v in start.items()}
                # start_target, _ = self.target(start)
                # critic_loss_start, _ = self.critic_loss(start, start_target)
                # critic_loss += critic_loss_start
            metrics.update(self.critic_opt(critic_loss, self.critic.parameters()))
        metrics.update(**mets1, **mets2, **mets3, **mets4)
        self.update_slow_target()  # No-op unless cfg.slow_target is set.
        return metrics

    def actor_loss(self, seq, target):
        """Compute the low-level actor loss for an imagined sequence.

        ``cfg.actor_grad`` selects the objective: ``'dynamics'`` uses the value
        targets directly, ``'reinforce'`` uses log-probability times advantage,
        and ``'both'`` mixes the two with the ``cfg.actor_grad_mix`` schedule.
        An entropy bonus scaled by the ``cfg.context_skill_actor_ent`` schedule
        is added in every mode.

        Args:
            seq: imagined sequence dict, time-major; ``'feat'``, ``'context'``,
                ``'skill'``, ``'action'`` and ``'weight'`` are read.
            target: value targets as returned by ``self.target(seq)``. This
                tensor is not modified.

        Returns:
            ``(actor_loss, metrics)``.

        Raises:
            NotImplementedError: for an unknown ``cfg.actor_grad``.
        """
        metrics = {}
        # Two states are lost at the end of the trajectory, one for the bootstrap
        # value prediction and one because the corresponding action does not lead
        # anywhere anymore. One target is lost at the start of the trajectory
        # because the initial state comes from the replay buffer.
        actor_inp = get_feat_ac(seq)[:-2]
        policy = self.actor(stop_gradient(actor_inp))

        grad_mode = self.cfg.actor_grad
        if grad_mode == 'dynamics':
            objective = target[1:]
        elif grad_mode in ('reinforce', 'both'):
            baseline = self._target_critic(actor_inp).mean  # .mode()
            advantage = stop_gradient(target[1:] - baseline)
            objective = policy.log_prob(stop_gradient(seq['action'][1:-1]))[:, :, None] * advantage
            if grad_mode == 'both':
                mix = utils.schedule(self.cfg.actor_grad_mix, self.tfstep)
                objective = mix * target[1:] + (1 - mix) * objective
                metrics['context_skill_actor_grad_mix'] = mix
        else:
            raise NotImplementedError(self.cfg.actor_grad)

        ent = policy.entropy()[:, :, None]
        ent_scale = utils.schedule(self.cfg.context_skill_actor_ent, self.tfstep)
        # Out-of-place on purpose: in 'dynamics' mode `objective` is a view of
        # `target`, so `objective += ...` would also add the entropy bonus to
        # the targets that the caller later feeds to the critic loss.
        objective = objective + ent_scale * ent
        weight = stop_gradient(seq['weight'])
        actor_loss = -(weight[:-2] * objective).mean()
        metrics['context_skill_actor_ent'] = ent.mean()
        metrics['context_skill_actor_ent_scale'] = ent_scale
        return actor_loss, metrics

    def critic_loss(self, seq, target):
        """Weighted negative log-likelihood of the value targets under the critic.

        The last time step of ``seq`` is dropped so the critic inputs line up
        with ``target``. Gradients are cut on both the target and the weights.

        Returns:
            ``(critic_loss, metrics)``.
        """
        value_dist = self.critic(get_feat_ac(seq)[:-1])
        fixed_target = stop_gradient(target)
        step_weight = stop_gradient(seq['weight'])[:-1]
        log_likelihood = value_dist.log_prob(fixed_target)[:, :, None]
        critic_loss = -(log_likelihood * step_weight).mean()
        return critic_loss, {'context_skill_critic': value_dist.mean.mean()}  # .mode().mean()

    def target(self, seq):
        """Compute lambda-return value targets from the target critic.

        Returns:
            ``(target, metrics)`` where ``target`` comes from
            ``common.lambda_return`` over all but the last time step.
        """
        rewards = seq['reward'][:-1]
        discounts = seq['discount'][:-1]
        value = self._target_critic(get_feat_ac(seq)).mean  # .mode()
        # The last time step is held back and used only as the bootstrap value.
        target = common.lambda_return(
            rewards, value[:-1], discounts,
            bootstrap=value[-1],
            lambda_=self.cfg.discount_lambda,
            axis=0)
        metrics = {
            'context_skill_critic_slow': value.mean(),
            'context_skill_critic_target': target.mean(),
        }
        return target, metrics

    def update_slow_target(self):
        if self.cfg.slow_target:
            if self._updates % self.cfg.slow_target_update == 0:
                mix = 1.0 if self._updates == 0 else float(
                    self.cfg.slow_target_fraction)
                for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1  # .assign_add(1)


# feat + context -> (choose) skill
class ContextMetaCtrlAC(common.Module):
    def __init__(self, config, context_dim, skill_dim, tfstep, skill_executor, frozen_skills=False, skill_len=1):
        """Create the skill-selection actor, its critic and the optimizers.

        The selector networks take the RSSM feature concatenated with a context
        vector of size ``context_dim``; the actor outputs a one-hot skill of
        size ``skill_dim``. A separate optimizer is created for the actor of
        ``skill_executor``. ``frozen_skills`` is accepted but not used here.
        """
        super().__init__()
        self.cfg = config
        self.tfstep = tfstep
        self.device = config.device
        self._use_amp = (config.precision == 16)

        self.context_dim = context_dim
        self.skill_dim = skill_dim
        self.skill_executor = skill_executor
        self.skill_len = skill_len
        self.termination = False

        # RSSM feature size = deterministic part + (flattened) stochastic part.
        rssm_cfg = config.rssm
        if rssm_cfg.discrete:
            stoch_size = rssm_cfg.stoch * rssm_cfg.discrete
        else:
            stoch_size = rssm_cfg.stoch
        inp_size = rssm_cfg.deter + stoch_size + self.context_dim

        actor_config = {'layers': 4, 'units': 400, 'norm': 'none', 'dist': 'onehot'}
        self.actor = common.MLP(inp_size, skill_dim, **actor_config)
        self.critic = common.MLP(inp_size, (1,), **self.cfg.critic)
        if not self.cfg.slow_target:
            # Without a slow target the critic serves as its own target.
            self._target_critic = self.critic
        else:
            self._target_critic = common.MLP(inp_size, (1,), **self.cfg.critic)
            self._updates = 0

        self.selector_opt = common.Optimizer(
            'selector_actor', self.actor.parameters(),
            **self.cfg.actor_opt, use_amp=self._use_amp)
        self.executor_opt = common.Optimizer(
            'executor_actor', self.skill_executor.actor.parameters(),
            **self.cfg.actor_opt, use_amp=self._use_amp)
        self.critic_opt = common.Optimizer(
            'selector_critic', self.critic.parameters(),
            **self.cfg.critic_opt, use_amp=self._use_amp)
        self.rewnorm = common.StreamNorm(**self.cfg.reward_norm, device=self.device)

    def update(self, world_model, start, is_terminal, reward_fn):
        """Run one optimisation step for the selector, the executor actor and the critic.

        Args:
            world_model: world model passed on to ``selector_imagine``.
            start: dict of start states passed on to ``selector_imagine``.
            is_terminal: passed on to ``selector_imagine``.
            reward_fn: callable mapping the imagined sequence dict to a reward tensor.

        Returns:
            Dict of metrics from the three optimizers, the reward normalizer,
            the target computation and the losses.
        """
        metrics = {}
        hor = self.cfg.imag_horizon
        # Gradients are enabled for both the skill selector and the low-level
        # executor actor while the rollout and the actor losses are computed.
        with common.RequiresGrad(self.actor):
            with common.RequiresGrad(self.skill_executor.actor):
                with torch.cuda.amp.autocast(enabled=self._use_amp):
                    seq = self.selector_imagine(world_model, self.actor, start, is_terminal, hor)
                    reward = reward_fn(seq)
                    seq['reward'], mets1 = self.rewnorm(reward)
                    mets1 = {f'reward_{k}': v for k, v in mets1.items()}
                    target, mets2 = self.target(seq)
                    high_actor_loss, low_actor_loss, mets3 = self.actor_loss(seq, target)
                metrics.update(self.selector_opt(high_actor_loss, self.actor.parameters()))
                metrics.update(self.executor_opt(low_actor_loss, self.skill_executor.actor.parameters()))
        with common.RequiresGrad(self.critic):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                # The critic is trained on the same rollout with gradients cut,
                # reusing the `target` computed during the actor pass.
                seq = {k: stop_gradient(v) for k, v in seq.items()}
                critic_loss, mets4 = self.critic_loss(seq, target)
            metrics.update(self.critic_opt(critic_loss, self.critic.parameters()))
        metrics.update(**mets1, **mets2, **mets3, **mets4)
        self.update_slow_target()  # No-op unless cfg.slow_target is set.
        return metrics

    def actor_loss(self, seq, target):
        """Compute the losses of the skill selector and of the low-level actor.

        The selector is trained with a REINFORCE objective (log-probability of
        the chosen skill times advantage) plus an entropy bonus. The low-level
        actor is trained directly on the value targets, which is only defined
        for ``cfg.actor_grad == 'dynamics'``.

        Args:
            seq: imagined sequence dict, time-major; ``'feat'``, ``'context'``,
                ``'skill'`` and ``'weight'`` are read.
            target: value targets as returned by ``self.target(seq)``. This
                tensor is not modified.

        Returns:
            ``(high_actor_loss, low_actor_loss, metrics)``.

        Raises:
            NotImplementedError: if ``cfg.actor_grad`` is not ``'dynamics'``.
                Previously this case failed later with an unbound
                ``low_objective``.
        """
        if self.cfg.actor_grad != 'dynamics':
            raise NotImplementedError(self.cfg.actor_grad)

        # NOTE(review): the step counter is reset on every call, so the entropy
        # schedule below is always evaluated at step 0 -- confirm this is intended.
        self.tfstep = 0
        metrics = {}
        skill = stop_gradient(seq['skill'])
        high_inp = torch.cat([seq['feat'][:-2], seq['context'][:-2]], dim=-1)
        policy = self.actor(stop_gradient(high_inp))
        low_inp = stop_gradient(torch.cat([seq['feat'][:-2], seq['context'][:-2],
                                          skill[:-2]], dim=-1))
        low_policy = self.skill_executor.actor(low_inp)
        low_objective = target[1:]

        ent_scale = utils.schedule(self.cfg.actor_ent, self.tfstep)
        weight = stop_gradient(seq['weight'])

        # The low-level entropy is only reported, not added to the objective.
        low_ent = low_policy.entropy()[:, :, None]
        high_ent = policy.entropy()[:, :, None]

        baseline = self._target_critic(high_inp).mean
        advantage = stop_gradient(target[1:] - baseline)
        log_probs = policy.log_prob(skill[1:-1])[:, :, None]

        # Keep only the steps where a new skill is selected. This only matters
        # when skill_len > 1.
        indices = torch.arange(0, log_probs.shape[0], step=self.skill_len, device=self.device)
        advantage = torch.index_select(advantage, 0, indices)
        log_probs = torch.index_select(log_probs, 0, indices)
        high_ent = torch.index_select(high_ent, 0, indices)
        high_weight = torch.index_select(weight[:-2], 0, indices)

        high_objective = log_probs * advantage
        if getattr(self, 'reward_smoothing', False):
            # Out-of-place on purpose: `low_objective` is a view of `target`, so
            # `low_objective *= 0` would also zero the targets that the caller
            # later feeds to the critic loss.
            high_objective = high_objective * 0
            low_objective = low_objective * 0

        high_objective = high_objective + ent_scale * high_ent
        high_actor_loss = -(high_weight * high_objective).mean()
        low_actor_loss = -(weight[:-2] * low_objective).mean()

        metrics['high_actor_ent'] = high_ent.mean()
        metrics['low_actor_ent'] = low_ent.mean()
        metrics['skills_updated'] = len(torch.unique(torch.argmax(skill, dim=-1)))
        return high_actor_loss, low_actor_loss, metrics

    def critic_loss(self, seq, target):
        """Weighted negative log-likelihood of the value targets under the critic.

        The critic sees ``'feat'`` and ``'context'`` concatenated, with the last
        time step dropped so the inputs line up with ``target``. Gradients are
        cut on both the target and the weights.

        Returns:
            ``(critic_loss, metrics)``.
        """
        critic_inp = torch.cat([seq['feat'][:-1], seq['context'][:-1]], dim=-1)
        value_dist = self.critic(critic_inp)
        fixed_target = stop_gradient(target)
        step_weight = stop_gradient(seq['weight'])[:-1]
        log_likelihood = value_dist.log_prob(fixed_target)[:, :, None]
        critic_loss = -(log_likelihood * step_weight).mean()
        return critic_loss, {'critic': value_dist.mean.mean()}

    def target(self, seq):
        """Compute lambda-return value targets from the target critic.

        Returns:
            ``(target, metrics)`` where ``target`` comes from
            ``common.lambda_return`` over all but the last time step.
        """
        rewards = seq['reward'][:-1]
        discounts = seq['discount'][:-1]
        critic_inp = torch.cat([seq['feat'], seq['context']], dim=-1)
        value = self._target_critic(critic_inp).mean
        # The last time step is held back and used only as the bootstrap value.
        target = common.lambda_return(
            rewards, value[:-1], discounts,
            bootstrap=value[-1],
            lambda_=self.cfg.discount_lambda,
            axis=0)
        metrics = {
            'critic_slow': value.mean(),
            'critic_target': target.mean(),
        }
        return target, metrics

    def update_slow_target(self):
        if self.cfg.slow_target:
            if self._updates % self.cfg.slow_target_update == 0:
                mix = 1.0 if self._updates == 0 else float(
                    self.cfg.slow_target_fraction)
                for s, d in zip(self.critic.parameters(), self._target_critic.parameters()):
                    d.data = mix * s.data + (1 - mix) * d.data
            self._updates += 1

    def selector_imagine(self, wm, policy, start, is_terminal, horizon, eval_policy=False):
        """Roll out the world model with the skill selector and the skill executor.

        Args:
            wm: world model; its ``rssm``, ``task_skill_model``, ``heads``,
                ``cfg`` and ``device`` are used.
            policy: skill-selection actor mapping ``(feat, context)`` to a skill distribution.
            start: dict of start-state tensors; the first two dimensions of
                every entry are merged into one.
            is_terminal: optional terminal flags for the start states, used to
                override the first predicted discount; may be ``None``.
            horizon: number of imagination steps.
            eval_policy: if true, use the skill mode and action mean instead of sampling.

        Returns:
            Dict of tensors stacked along a new leading time dimension, with
            ``'discount'`` and ``'weight'`` added.
        """
        # Merge the two leading dimensions of a tensor into one.
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}
        start['feat'] = wm.rssm.get_feat(start)
        context, skill_pred = wm.task_skill_model(start['feat'])
        start['context'] = F.softmax(context, dim=-1)
        inp = torch.cat([start['feat'], start['context']], dim=-1)
        # These two forward passes only provide templates for the zero
        # placeholders stored as the step-0 skill and action.
        fake_skill = policy(inp).mean
        fake_action = self.skill_executor.actor(torch.cat([inp, fake_skill], dim=-1)).mean
        B, _ = fake_action.shape  # B is unused below
        start['skill'] = torch.zeros_like(fake_skill, device=wm.device)
        start['action'] = torch.zeros_like(fake_action, device=wm.device)
        seq = {k: [v] for k, v in start.items()}
        for h in range(horizon):
            inp = stop_gradient(torch.cat([seq['feat'][-1], seq['context'][-1]], dim=-1))
            # A new skill is selected every `skill_len` steps and kept in between.
            if h % self.skill_len == 0:
                skill = policy(inp)
                if not eval_policy:
                    skill = skill.sample()
                else:
                    skill = skill.mode()

            executor_inp = stop_gradient(torch.cat([inp, skill], dim=-1))
            action = self.skill_executor.actor(executor_inp)
            action = action.sample() if not eval_policy else action.mean
            state = wm.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
            feat = wm.rssm.get_feat(state)
            # Re-predict the context from the imagined features at every step.
            context, skill_pred = wm.task_skill_model(feat)
            context = F.softmax(context, dim=-1)
            for key, value in {**state, 'action': action, 'feat': feat, 'skill': skill,
                               'context': context, }.items():
                seq[key].append(value)
        # shape will be (T, B, *DIMS)
        seq = {k: torch.stack(v, 0) for k, v in seq.items()}
        if 'discount' in wm.heads:
            # NOTE(review): `.mean()` is called as a method here while other
            # heads in this file use the `.mean` attribute -- confirm that the
            # discount head's distribution exposes `mean` as a callable.
            disc = wm.heads['discount'](seq['feat']).mean()
            if is_terminal is not None:
                # Override discount prediction for the first step with the true
                # discount factor from the replay buffer.
                true_first = 1.0 - flatten(is_terminal)
                true_first *= wm.cfg.discount
                disc = torch.cat([true_first[None], disc[1:]], 0)
        else:
            # No discount head: use the constant configured discount everywhere.
            disc = wm.cfg.discount * torch.ones(list(seq['feat'].shape[:-1]) + [1], device=wm.device)
        seq['discount'] = disc
        # Shift discount factors because they imply whether the following state
        # will be valid, not whether the current state is valid.
        seq['weight'] = torch.cumprod(
            torch.cat([torch.ones_like(disc[:1], device=wm.device), disc[:-1]], 0), 0)
        return seq


# feat -> context, skill
class PEACDIAYN(nn.Module):
    """Two-headed MLP that predicts task logits and skill logits from features.

    A shared two-layer ReLU trunk feeds two linear heads, one with
    ``task_number`` outputs and one with ``skill_dim`` outputs. All layers are
    initialised through ``utils.weight_init``.
    """

    def __init__(self, obs_dim, skill_dim, task_number, hidden_dim):
        super().__init__()
        trunk_layers = [
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        ]
        self.feat_net = nn.Sequential(*trunk_layers)
        self.skill_pred_net = nn.Linear(hidden_dim, skill_dim)
        self.task_pred_net = nn.Linear(hidden_dim, task_number)

        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return ``(task_pred, skill_pred)`` for the input features ``obs``."""
        hidden = self.feat_net(obs)
        skill_pred = self.skill_pred_net(hidden)
        task_pred = self.task_pred_net(hidden)
        return task_pred, skill_pred


class WorldModel(common.Module):
    """RSSM world model with decoder/reward (and optional discount) heads.

    When ``task_number > 1`` an additional ``PEACDIAYN`` network is trained on
    the posterior features to predict the task id and the skill; its softmaxed
    task prediction is exposed as ``context`` in the outputs of :meth:`loss`.
    """

    def __init__(self, config, obs_space, act_dim, tfstep, task_number=1, skill_dim=1):
        """Build encoder, RSSM, prediction heads and the model optimizer.

        Args:
            config: configuration object; this constructor reads ``device``,
                ``encoder``, ``rssm``, ``precision``, ``decoder``,
                ``reward_head``, ``task_head.units``, ``pred_discount``,
                ``discount_head``, ``grad_heads`` and ``model_opt``.
            obs_space: mapping from observation key to an object with ``.shape``.
            act_dim: action dimensionality forwarded to the RSSM.
            tfstep: stored on the instance as ``self.tfstep``.
            task_number: number of tasks; the task/skill predictor is only
                created when this is greater than 1.
            skill_dim: output size of the skill prediction head.
        """
        super().__init__()
        shapes = {k: tuple(v.shape) for k, v in obs_space.items()}
        self.cfg = config
        self.device = config.device
        self.tfstep = tfstep
        self.encoder = common.Encoder(shapes, **config.encoder)
        # Infer the embedding size by pushing a dummy batch of one through the encoder.
        with torch.no_grad():
            zeros = {k: torch.zeros((1,) + v) for k, v in shapes.items()}
            outs = self.encoder(zeros)
            embed_dim = outs.shape[1]
        self.embed_dim = embed_dim
        self.rssm = common.EnsembleRSSM(**config.rssm, action_dim=act_dim, embed_dim=embed_dim, device=self.device)
        self.heads = {}
        self._use_amp = (config.precision == 16)
        # Feature size = deterministic state + (flattened) stochastic state.
        inp_size = config.rssm.deter
        if config.rssm.discrete:
            inp_size += config.rssm.stoch * config.rssm.discrete
        else:
            inp_size += config.rssm.stoch
        self.inp_size = inp_size
        self.heads['decoder'] = common.Decoder(shapes, **config.decoder, embed_dim=inp_size)
        self.heads['reward'] = common.MLP(inp_size, (1,), **config.reward_head)

        self.task_number = task_number
        if self.task_number > 1:
            self.task_skill_model = PEACDIAYN(inp_size, skill_dim, task_number,
                                              config.task_head.units)
            self.task_skill_criterion = nn.CrossEntropyLoss()
        else:
            self.task_skill_model = None
        if config.pred_discount:
            self.heads['discount'] = common.MLP(inp_size, (1,), **config.discount_head)
        for name in config.grad_heads:
            assert name in self.heads, name
        self.grad_heads = config.grad_heads
        # Wrap the heads so that their parameters are registered on this module.
        self.heads = nn.ModuleDict(self.heads)
        self.model_opt = common.Optimizer('model', self.parameters(), **config.model_opt, use_amp=self._use_amp)

    def update(self, data, state=None):
        """Run one optimisation step of the world model.

        Returns:
            ``(state, outputs, metrics)`` where ``state`` and ``outputs`` come
            from :meth:`loss` and ``metrics`` additionally contains whatever
            the optimizer call returns.
        """
        with common.RequiresGrad(self):
            with torch.cuda.amp.autocast(enabled=self._use_amp):
                model_loss, state, outputs, metrics = self.loss(data, state)
            metrics.update(self.model_opt(model_loss, self.parameters()))
        return state, outputs, metrics

    def loss(self, data, state=None):
        """Compute the world-model loss on a batch of sequences.

        Args:
            data: mapping of batch tensors; must provide ``action``,
                ``is_first``, ``reward``, ``is_terminal`` and a target for
                every head output. When ``task_number > 1`` it must also
                provide ``task_id`` and ``skill``.
            state: optional initial RSSM state forwarded to ``rssm.observe``.

        Returns:
            ``(model_loss, last_state, outs, metrics)``. ``outs['context']`` is
            the softmax of the task prediction when ``task_number > 1`` and
            ``None`` otherwise.
        """
        data = self.preprocess(data)
        embed = self.encoder(data)
        post, prior = self.rssm.observe(
            embed, data['action'], data['is_first'], state)
        kl_loss, kl_value = self.rssm.kl_loss(post, prior, **self.cfg.kl)
        assert len(kl_loss.shape) == 0 or (len(kl_loss.shape) == 1 and kl_loss.shape[0] == 1), kl_loss.shape
        likes = {}
        losses = {'kl': kl_loss}
        feat = self.rssm.get_feat(post)
        for name, head in self.heads.items():
            # Heads not listed in grad_heads are trained on detached features.
            grad_head = (name in self.grad_heads)
            inp = feat if grad_head else stop_gradient(feat)
            out = head(inp)
            dists = out if isinstance(out, dict) else {name: out}
            for key, dist in dists.items():
                like = dist.log_prob(data[key])
                likes[key] = like
                losses[key] = -like.mean()

        # Defined up front so the outputs dict below can always be built; the
        # single-task configuration has no task predictor and hence no context.
        context = None
        if self.task_number > 1:
            task_pred, d_pred = self.task_skill_model(feat)
            B, T, _ = task_pred.shape
            task_pred = task_pred.reshape(B * T, -1)
            task_id_key = data['task_id'].reshape(task_pred.shape[0]).to(torch.int64)
            losses['task_id'] = F.cross_entropy(task_pred, task_id_key)
            context = F.softmax(task_pred.reshape(B, T, -1), dim=2)

            # Reduce over the last axis so there is one class index per
            # (batch, time) element, matching the B*T rows of flattened logits.
            # NOTE(review): assumes data['skill'] is one-hot with the skill
            # dimension last - confirm against the replay buffer layout.
            z_hat = torch.argmax(data['skill'], dim=-1)
            losses['skill'] = self.task_skill_criterion(
                d_pred.reshape(B * T, -1), z_hat.reshape(-1))

        # Losses without an explicit scale in cfg.loss_scales are weighted by 1.
        model_loss = sum(
            self.cfg.loss_scales.get(k, 1.0) * v for k, v in losses.items())
        outs = dict(
            embed=embed, feat=feat, post=post,
            prior=prior, likes=likes, kl=kl_value, context=context)
        metrics = {f'{name}_loss': value for name, value in losses.items()}
        metrics['model_kl'] = kl_value.mean()
        metrics['prior_ent'] = self.rssm.get_dist(prior).entropy().mean()
        metrics['post_ent'] = self.rssm.get_dist(post).entropy().mean()
        # Keep only the final step (index -1 along dim 1) of every posterior entry.
        last_state = {k: v[:, -1] for k, v in post.items()}
        # last_state.keys(): stoch, deter, logits
        return model_loss, last_state, outs, metrics

    def imagine(self, policy, start, is_terminal, horizon, task_cond=None, eval_policy=False):
        """Roll the RSSM forward for ``horizon`` steps under ``policy``.

        Args:
            policy: callable mapping a feature tensor to a distribution that
                exposes ``.mean`` and ``.sample()``.
            start: dict of state tensors; the first two dims of every entry
                are merged into one before the rollout.
            is_terminal: optional tensor used to overwrite the first predicted
                discount; only consulted when a ``discount`` head exists.
            horizon: number of imagined steps.
            task_cond: optional tensor concatenated to the features fed to the
                policy and recorded under ``seq['task']``.
            eval_policy: use the policy mean instead of sampling.

        Returns:
            Dict of stacked tensors (time first) including ``feat``,
            ``action``, ``discount`` and ``weight``.
        """
        flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
        start = {k: flatten(v) for k, v in start.items()}
        start['feat'] = self.rssm.get_feat(start)
        inp = start['feat'] if task_cond is None else torch.cat([start['feat'], task_cond], dim=-1)
        # The first action is a zero tensor shaped like the policy mean.
        start['action'] = torch.zeros_like(policy(inp).mean, device=self.device)  # .mode())
        seq = {k: [v] for k, v in start.items()}
        if task_cond is not None:
            seq['task'] = [task_cond]
        for _ in range(horizon):
            inp = seq['feat'][-1] if task_cond is None else torch.cat([seq['feat'][-1], task_cond], dim=-1)
            action = policy(stop_gradient(inp)).sample() if not eval_policy else policy(stop_gradient(inp)).mean
            state = self.rssm.img_step({k: v[-1] for k, v in seq.items()}, action)
            feat = self.rssm.get_feat(state)
            for key, value in {**state, 'action': action, 'feat': feat}.items():
                seq[key].append(value)
            if task_cond is not None:
                seq['task'].append(task_cond)
        # shape will be (T, B, *DIMS)
        seq = {k: torch.stack(v, 0) for k, v in seq.items()}
        if 'discount' in self.heads:
            disc = self.heads['discount'](seq['feat']).mean()
            if is_terminal is not None:
                # Override discount prediction for the first step with the true
                # discount factor from the replay buffer.
                true_first = 1.0 - flatten(is_terminal)
                true_first *= self.cfg.discount
                disc = torch.cat([true_first[None], disc[1:]], 0)
        else:
            disc = self.cfg.discount * torch.ones(list(seq['feat'].shape[:-1]) + [1], device=self.device)
        seq['discount'] = disc
        # Shift discount factors because they imply whether the following state
        # will be valid, not whether the current state is valid.
        seq['weight'] = torch.cumprod(
            torch.cat([torch.ones_like(disc[:1], device=self.device), disc[:-1]], 0), 0)
        return seq

    def preprocess(self, obs):
        """Return a shallow copy of ``obs`` prepared for the model.

        * ``uint8`` entries are rescaled to ``[-0.5, 0.5]``; keys starting
          with ``log_`` are left untouched.
        * ``reward`` is transformed according to ``cfg.clip_rewards``
          (``identity``, ``sign`` or ``tanh``).
        * ``discount`` is set to ``(1 - is_terminal) * cfg.discount``.
        """
        obs = obs.copy()
        for key, value in obs.items():
            if key.startswith('log_'):
                continue
            if value.dtype in [np.uint8, torch.uint8]:
                value = value / 255.0 - 0.5
            obs[key] = value
        obs['reward'] = {
            'identity': nn.Identity(),
            'sign': torch.sign,
            'tanh': torch.tanh,
        }[self.cfg.clip_rewards](obs['reward'])
        obs['discount'] = 1.0 - obs['is_terminal'].float()
        obs['discount'] *= self.cfg.discount
        return obs

    def video_pred(self, data, key, nvid=8):
        """Build a comparison video of ground truth, model prediction and error.

        The first 5 steps of up to ``nvid`` sequences are reconstructed from
        the posterior; the remaining steps are predicted open-loop from the
        recorded actions. The three videos are concatenated along dim 3.

        If the instance has a truthy ``recon_skills`` attribute, the error
        panel is replaced by a reconstruction obtained through
        ``self.skill_module`` (neither attribute is set by this class itself).
        """
        decoder = self.heads['decoder']  # B, T, C, H, W
        truth = data[key][:nvid] + 0.5
        embed = self.encoder(data)
        states, _ = self.rssm.observe(
            embed[:nvid, :5], data['action'][:nvid, :5], data['is_first'][:nvid, :5])
        recon = decoder(self.rssm.get_feat(states))[key].mean[:nvid]  # mode
        init = {k: v[:, -1] for k, v in states.items()}
        prior = self.rssm.imagine(data['action'][:nvid, 5:], init)
        prior_recon = decoder(self.rssm.get_feat(prior))[key].mean  # mode
        # Undo the -0.5 offset applied in preprocessing and clamp to [0, 1].
        model = torch.clip(torch.cat([recon[:, :5] + 0.5, prior_recon + 0.5], 1), 0, 1)
        # Map the signed difference from [-1, 1] into [0, 1].
        error = (model - truth + 1) / 2

        if getattr(self, 'recon_skills', False):
            prior_feat = self.rssm.get_feat(prior)
            if self.skill_module.discrete_skills:
                B, T, _ = prior['deter'].shape
                z_e = self.skill_module.skill_encoder(prior['deter'].reshape(B * T, -1)).mean
                z_q, _ = self.skill_module.emb(z_e, weight_sg=True)
                latent_skills = z_q.reshape(B, T, -1)
            else:
                latent_skills = self.skill_module.skill_encoder(prior['deter']).mean
                latent_skills = latent_skills / torch.norm(latent_skills, dim=-1, keepdim=True)

            x = deter = self.skill_module.skill_decoder(latent_skills).mean

            # Sample the stochastic state from one randomly chosen ensemble member.
            stats = self.rssm._suff_stats_ensemble(x)
            index = torch.randint(0, self.rssm._ensemble, ())
            stats = {k: v[index] for k, v in stats.items()}
            dist = self.rssm.get_dist(stats)
            stoch = dist.sample()
            prior = {'stoch': stoch, 'deter': deter, **stats}
            skill_recon = decoder(self.rssm.get_feat(prior))[key].mean  # mode
            error = torch.clip(torch.cat([recon[:, :5] + 0.5, skill_recon + 0.5], 1), 0, 1)

        video = torch.cat([truth, model, error], 3)
        # Unpacking doubles as a check that the video tensor is 5-dimensional.
        B, T, C, H, W = video.shape
        return video
