import copy

import hydra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

import utils
from agent.peac import PEACAgent, stop_gradient
import agent.dreamer_utils as common


class PEAC_LBSAgent(PEACAgent):
    """PEAC agent with LBS-style intrinsic rewards.

    In reward-free mode, the reward used for imagined rollouts is the sum of:

    * ``self.lbs``: an MLP head trained to regress the world model's
      ``outs['kl']`` term from the latent feature. ``outs['kl']`` is
      presumably KL(posterior || prior) — TODO confirm.
    * ``compute_task_reward``: a task/context term selected by
      ``cfg.reward_type`` (1-5). It is derived from ``self.wm.task_model``
      predictions or, for type 5, from the learned ``self.context_lbs`` head.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.reward_free = True

        # LBS head: predicts the world model's ``outs['kl']`` from the feature.
        self.lbs = common.MLP(self.wm.inp_size, (1,), **self.cfg.reward_head).to(self.device)
        self.lbs_opt = common.Optimizer('lbs', self.lbs.parameters(), **self.cfg.model_opt, use_amp=self._use_amp)
        self.lbs.train()

        # Context LBS head (used when reward_type == 5). It predicts the KL
        # between task-model predictions made from posterior and prior features.
        self.context_lbs = common.MLP(self.wm.inp_size, (1,),
                                      **self.cfg.reward_head).to(self.device)
        self.context_lbs_opt = common.Optimizer('context_lbs', self.context_lbs.parameters(),
                                                **self.cfg.model_opt, use_amp=self._use_amp)
        self.context_lbs.train()

        # Disable grads by default. ``update`` wraps each head's own update in
        # ``common.RequiresGrad(...)``.
        self.requires_grad_(requires_grad=False)

    def update_lbs(self, outs):
        """Fit ``self.lbs`` to the world model's per-step KL.

        Args:
            outs: world-model outputs. Reads ``'feat'`` (unpacked as
                ``(B, T, D)``) and ``'kl'``.

        Returns:
            Dict of optimizer metrics plus ``'lbs_loss'``.
        """
        metrics = dict()
        B, T, _ = outs['feat'].shape
        # Both tensors are targets/inputs only; the world model is not trained here.
        feat = outs['feat'].detach().reshape(B * T, -1)
        kl = outs['kl'].detach().reshape(B * T, -1)

        # Negative log-likelihood of the observed KL under the head's output distribution.
        loss = -self.lbs(feat).log_prob(kl).mean()
        metrics.update(self.lbs_opt(loss, self.lbs.parameters()))
        metrics['lbs_loss'] = loss.item()
        return metrics

    def update_context_lbs(self, outs):
        """Fit ``self.context_lbs`` to a posterior-vs-prior task-prediction KL.

        The regression target, per flattened step, is::

            KL( Cat(softmax(task_model(feat))) || Cat(softmax(task_model(prior_feat))) )

        Here ``feat`` is ``outs['feat']`` and ``prior_feat`` is
        ``rssm.get_feat(outs['prior'])``.

        Args:
            outs: world-model outputs. Reads ``'feat'`` (unpacked as
                ``(B, T, D)``) and ``'prior'``.

        Returns:
            Dict of optimizer metrics plus ``'context_lbs_loss'``.
        """
        metrics = dict()
        B, T, _ = outs['feat'].shape
        feat = outs['feat'].detach().reshape(B * T, -1)

        # The target needs values only, so skip building an autograd graph
        # through the world model.
        with torch.no_grad():
            # Use the PRIOR latent here. ``outs['feat']`` is presumably
            # ``get_feat(outs['post'])``, so comparing against the posterior
            # again would give a KL of ~0 (TODO confirm against wm.update).
            prior_feat = self.wm.rssm.get_feat(outs['prior']).reshape(B * T, -1)
            out_dist = D.Categorical(F.softmax(self.wm.task_model(feat), dim=-1))
            prior_out_dist = D.Categorical(F.softmax(self.wm.task_model(prior_feat), dim=-1))
            kl = D.kl_divergence(out_dist, prior_out_dist).reshape(-1, 1)

        loss = -self.context_lbs(feat).log_prob(kl).mean()
        metrics.update(self.context_lbs_opt(loss, self.context_lbs.parameters()))
        metrics['context_lbs_loss'] = loss.item()
        return metrics

    def update(self, data, step):
        """Run one world-model, LBS-head and behavior update on a replay batch.

        Args:
            data: replay batch. Passed to ``self.wm.update``; ``'task_id'`` and
                ``'is_terminal'`` are also read here.
            step: current training step (unused in this method).

        Returns:
            ``(state, metrics)``: the state returned by ``self.wm.update`` and a
            metrics dict.
        """
        metrics = {}
        state, outputs, mets = self.wm.update(data, state=None)
        metrics.update(mets)

        # Imagination start states. Copy the posterior dict so the extra keys
        # added below do not leak into ``outputs['post']``.
        start = dict(outputs['post'])
        if self.cfg.reward_type == 4:
            # compute_task_reward(reward_type=4) needs the prior features.
            with torch.no_grad():
                start['prior_feat'] = self.wm.rssm.get_feat(outputs['prior'])
        start['task_id'] = data['task_id']
        start['context'] = outputs['context']
        start = {k: stop_gradient(v) for k, v in start.items()}

        if self.reward_free:
            with common.RequiresGrad(self.lbs):
                with torch.cuda.amp.autocast(enabled=self._use_amp):
                    metrics.update(self.update_lbs(outputs))
            if self.cfg.reward_type == 5:
                with common.RequiresGrad(self.context_lbs):
                    with torch.cuda.amp.autocast(enabled=self._use_amp):
                        metrics.update(self.update_context_lbs(outputs))

            def reward_fn(seq):
                # Intrinsic reward: LBS head estimate plus the task/context term.
                return self.lbs(seq['feat']).mean + self.compute_task_reward(seq)
        else:
            def reward_fn(seq):
                # Reward predicted by the world model's reward head.
                return self.wm.heads['reward'](seq['feat']).mean  # .mode()

        metrics.update(self._task_behavior.update(
            self.wm, start, data['is_terminal'], reward_fn))
        return state, metrics

    def compute_task_reward(self, seq):
        """Return the task/context reward term for an imagined trajectory.

        The term is selected by ``self.cfg.reward_type``:

        1. Squared error between the one-hot true task and the predicted task
           distribution, summed over tasks.
        2. Negative log-probability of the true task.
        3. Mean log-probability over tasks minus log-probability of the true
           task.
        4. KL(prior_pred || pred) between task predictions from
           ``seq['prior_feat']`` and from ``seq['feat']``.
        5. ``self.context_lbs`` estimate (mean of its output distribution).

        Args:
            seq: imagined-rollout dict. Reads ``'feat'`` (unpacked as
                ``(B, T, D)``), ``'task_id'`` (types 1-4) and ``'prior_feat'``
                (type 4).

        Returns:
            Tensor of shape ``(B, T, 1)`` for types 1-4. For type 5, whatever
            ``self.context_lbs(...).mean`` yields (presumably the same shape).

        Raises:
            ValueError: if ``cfg.reward_type`` is not one of 1-5.
        """
        reward_type = self.cfg.reward_type
        B, T, _ = seq['feat'].shape

        if reward_type == 5:
            # Learned estimate; no task-model forward pass is needed.
            return self.context_lbs(seq['feat']).mean
        if reward_type not in (1, 2, 3, 4):
            raise ValueError('Current reward type is {}, which is not supported'.
                             format(reward_type))

        task_pred = self.wm.task_model(seq['feat'])  # (B, T, num_tasks)
        # seq['task_id'] is tiled B times along the leading axis. Its exact
        # incoming shape is presumably (T, 1) or (1, T, 1) — TODO confirm.
        task_truth = seq['task_id'].repeat(B, 1, 1).to(dtype=torch.int64)  # (B, T, 1)

        if reward_type == 1:
            probs = F.softmax(task_pred, dim=2)
            one_hot = F.one_hot(task_truth.reshape(B, T),
                                num_classes=probs.shape[2]).to(probs)
            return torch.sum(torch.square(one_hot - probs), dim=2, keepdim=True)

        log_probs = F.log_softmax(task_pred, dim=2)

        if reward_type in (2, 3):
            # Log-probability assigned to the true task at every (b, t).
            true_log_prob = log_probs.reshape(B * T, -1)[
                torch.arange(B * T), task_truth.reshape(-1)].reshape(B, T, 1)
            if reward_type == 2:
                return -true_log_prob
            # Mean log-probability over tasks (not the Shannon entropy).
            mean_log_prob = log_probs.sum(dim=2, keepdim=True) / log_probs.shape[2]
            return -(true_log_prob - mean_log_prob)

        # reward_type == 4
        prior_log_probs = F.log_softmax(
            self.wm.task_model(seq['prior_feat'].repeat(B, 1, 1)), dim=2)
        # F.kl_div(input, target, log_target=True) computes, elementwise,
        # exp(target) * (target - input). Summed over tasks, that is
        # KL(prior_pred || pred).
        kl_divergence_value = F.kl_div(log_probs.reshape(B * T, -1),
                                       prior_log_probs.reshape(B * T, -1),
                                       reduction='none', log_target=True)
        return kl_divergence_value.sum(dim=1, keepdim=True).reshape(B, T, 1)
