from matplotlib.pyplot import cla
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import models.latent_generators.latent_generator as latent_generator
from torch.distributions import Categorical

import models.libraries.mingpt.model as mingpt_model
import models.libraries.mingpt.trainer as mingpt_trainer
from models.libraries.loss_fn import FocalLoss, soft_cross_entropy

from typing import Optional, Tuple
from utils import batch_indexing


class MinGPT(latent_generator.AbstractLatentGenerator):
    def __init__(
        self,
        input_dim: int,
        n_layer: int = 12,
        n_head: int = 12,
        n_embd: int = 768,
        n_embd_is_per_head: bool = False,
        embd_pdrop: float = 0.1,
        resid_pdrop: float = 0.1,
        attn_pdrop: float = 0.1,
        vocab_size: int = 50257,
        latent_dim: int = 768,  # Ignore, used for compatibility with other models.
        action_dim: int = 0,
        discrete_input: bool = False,
        predict_offsets: bool = False,
        offset_loss_scale: float = 1.0,
        focal_loss_gamma: float = 0.0,
        use_reg: bool = False,
        reg_type: str = 'kl',
        use_prior_weight: bool = False,
        prior_weight_scale: float = 1.0,
        coef_start: float = 0.0,
        coef_end: float = 0.0,
        coef_decay_start: int = 0,
        coef_decay_end: int = 0,
        goal_type: Optional[str] = None,
        use_skill_head: bool = False,
        use_double_kl: bool = False,
        fix_prior: bool = False,
        goal_conditional: Optional[str] = None,
        obs_seq_lenth: Optional[int] = None,
        goal_seq_lenth: Optional[int] = None,
        subgoal_seq_lenth: Optional[int] = None,
        goal_dim: Optional[int] = None,
        use_img_obs: Optional[bool] = None,
        **kwargs
    ):
        """Store the hyper-parameters, build the GPT and set up the coefficient schedule.

        The GPT output width is ``vocab_size * (1 + action_dim)``: the first
        ``vocab_size`` channels are read as bin logits and the remaining
        ``vocab_size * action_dim`` channels as one offset vector per bin.

        Arguments:
            input_dim: forwarded to ``GPTConfig`` as ``input_size``.
            n_layer, n_head: forwarded to ``GPTConfig``.
            n_embd: embedding width; multiplied by ``n_head`` when
                ``n_embd_is_per_head`` is True.
            embd_pdrop, resid_pdrop, attn_pdrop: forwarded to ``GPTConfig``.
            vocab_size: number of discrete bins.
            latent_dim: ignored, kept for compatibility with other models.
            action_dim: length of the offset vector predicted for each bin.
            discrete_input: accepted but not used by this constructor.
            predict_offsets, goal_conditional: stored on the instance.
            offset_loss_scale: multiplier applied to the offset MSE loss.
            focal_loss_gamma: ``gamma`` handed to ``FocalLoss``.
            use_reg: enables the prior forward pass and the regularisation terms.
            reg_type: ``'kl'`` or ``'hidden'``; selects the regulariser in ``_calc_loss``.
            use_prior_weight, prior_weight_scale: weighting of the skill loss.
            coef_start, coef_end: regularisation coefficient at the start / end
                of the schedule.
            coef_decay_start, coef_decay_end: epochs between which the
                coefficient moves linearly from ``coef_start`` to ``coef_end``.
            goal_type, goal_dim, goal_seq_lenth, obs_seq_lenth, use_img_obs,
                use_skill_head: forwarded to ``GPTConfig``.
            subgoal_seq_lenth: stored as ``subgoal_seq_lenth + obs_seq_lenth``
                (``None`` when either value is ``None``).
            use_double_kl: when False the prior hidden states are detached in
                the hidden-state regulariser.
            fix_prior: when True the prior's own supervised losses are only
                applied before ``coef_decay_start``.
            **kwargs: every entry is set as an instance attribute.
        """
        super().__init__()
        self.input_size = input_dim
        self.n_layer = n_layer
        self.n_head = n_head
        if n_embd_is_per_head:
            self.n_embd = n_head * n_embd
        else:
            self.n_embd = n_embd
        self.embd_pdrop = embd_pdrop
        self.resid_pdrop = resid_pdrop
        self.attn_pdrop = attn_pdrop
        self.vocab_size = vocab_size  # n_bins
        self.action_dim = action_dim
        self.predict_offsets = predict_offsets
        self.offset_loss_scale = offset_loss_scale
        self.focal_loss_gamma = focal_loss_gamma
        self.use_reg = use_reg
        self.reg_type = reg_type
        self.use_prior_weight = use_prior_weight
        self.prior_weight_scale = prior_weight_scale
        self.coef_start = coef_start
        self.coef_end = coef_end
        self.coef_decay_start = coef_decay_start
        self.coef_decay_end = coef_decay_end
        self.goal_conditional = goal_conditional
        self.obs_seq_lenth = obs_seq_lenth
        self.goal_seq_lenth = goal_seq_lenth
        # Both lengths default to None; adding them unguarded raised TypeError
        # for every caller relying on the defaults.
        if subgoal_seq_lenth is None or obs_seq_lenth is None:
            self.subgoal_seq_lenth = None
        else:
            self.subgoal_seq_lenth = subgoal_seq_lenth + obs_seq_lenth
        self.goal_dim = goal_dim
        self.goal_type = goal_type
        self.use_img_obs = use_img_obs
        self.use_skill_head = use_skill_head
        self.use_double_kl = use_double_kl
        self.fix_prior = fix_prior
        self.aug_noise = 0.1

        # Extra keyword arguments become attributes (and may override the ones above).
        for k, v in kwargs.items():
            setattr(self, k, v)

        # One logit per bin plus an action_dim-sized offset per bin.
        effective_vocab_size = self.vocab_size
        effective_vocab_size += self.vocab_size * self.action_dim
        effective_input_size = self.input_size

        gpt_config = mingpt_model.GPTConfig(
            input_size=effective_input_size,
            goal_size=goal_dim,
            vocab_size=effective_vocab_size,
            n_layer=self.n_layer,
            n_head=self.n_head,
            n_embd=self.n_embd,
            embd_pdrop=embd_pdrop,
            resid_pdrop=resid_pdrop,
            attn_pdrop=attn_pdrop,
            goal_type=self.goal_type,
            goal_seq_lenth=self.goal_seq_lenth,
            obs_seq_lenth=self.obs_seq_lenth,
            subgoal_seq_lenth=self.subgoal_seq_lenth,
            use_img_obs=self.use_img_obs,
            use_skill_head=self.use_skill_head,
        )

        self.model = mingpt_model.GPT(gpt_config)
        self.cur_coef = self.coef_start
        # A zero-length decay window (the default) used to raise
        # ZeroDivisionError. With no window there is no slope, so the
        # coefficient simply stays at coef_start.
        decay_span = self.coef_decay_end - self.coef_decay_start
        if decay_span == 0:
            self.coef_step = 0.0
        else:
            self.coef_step = (self.coef_end - self.coef_start) / decay_span

    def _predict(
        self,
        self_obs_rep: torch.Tensor,
        img_obs_rep: torch.Tensor,
        obs_mask: torch.Tensor,
        goal: torch.Tensor,
        goal_mask: torch.Tensor,
    ):
        """
        Run the policy head and split its output into bin logits and per-bin offsets.
        Arguments:
            self_obs_rep, img_obs_rep, obs_mask, goal, goal_mask: passed through
                unchanged to ``self.model.forward_policy``.
        Returns:
            ``(policy_logits, policy_offsets)``. The logits are the first
            ``vocab_size`` output channels; the offsets are the remaining channels
            regrouped as ``[N, T, vocab_size, action_dim]``.
            Presumably N is the batch size and T the observation length once the
            goal prefix is dropped -- TODO confirm against ``forward_policy``.
        """
        head_output, _hidden, _skill = self.model.forward_policy(
            self_obs_rep, img_obs_rep, obs_mask, goal, goal_mask
        )

        # Drop the leading goal positions: [N, T_GOAL+T_OBS, *] --> [N, T_OBS, *]
        n_bins = self.vocab_size
        obs_output = head_output[:, self.goal_seq_lenth:, :]

        bin_logits = obs_output[:, :, :n_bins]
        bin_offsets = einops.rearrange(
            obs_output[:, :, n_bins:],
            "N T (V A) -> N T V A",
            V=n_bins,
            A=self.action_dim,
        )
        return bin_logits, bin_offsets

    def _optimize(
        self,
        self_obs_rep: torch.Tensor,
        img_obs_rep: torch.Tensor,
        obs_mask: torch.Tensor,
        subgoal:torch.Tensor,
        subgoal_mask: torch.Tensor,
        goal: torch.Tensor,
        goal_mask: torch.Tensor,
    ):
        """
        Training-time forward pass with Gaussian noise added to the observations.
        Arguments:
            self_obs_rep: observation representation; noise scaled by
                ``aug_noise * 0.1`` is always added.
            img_obs_rep: image observation representation; noise scaled by
                ``aug_noise`` is added only when ``use_img_obs`` is truthy.
            obs_mask, goal, goal_mask: forwarded to ``model.forward_policy``.
            subgoal, subgoal_mask: forwarded to ``model.forward_prior`` when
                ``use_reg`` is set.
        Returns:
            A 7-tuple ``(policy_logits, policy_offsets, policy_hidden,
            prior_logits, prior_offsets, prior_hidden, skill_logits)``.
            The three prior entries are ``None`` when ``use_reg`` is False and
            ``skill_logits`` is ``None`` when ``use_skill_head`` is False.
            Offsets are shaped ``[N, T, vocab_size, action_dim]``.
        """
        n_bins = self.vocab_size

        def split_head(head_output, prefix_len):
            # Drop the conditioning prefix, then separate bin logits from offsets.
            trimmed = head_output[:, prefix_len:, :]  # [N, T_PREFIX+T_OBS, *] --> [N, T_OBS, *]
            bin_logits = trimmed[:, :, :n_bins]
            bin_offsets = einops.rearrange(
                trimmed[:, :, n_bins:],
                "N T (V A) -> N T V A",
                V=n_bins,
                A=self.action_dim,
            )
            return bin_logits, bin_offsets

        if self.use_img_obs:
            img_obs_rep = img_obs_rep + torch.randn_like(img_obs_rep) * self.aug_noise
        # NOTE(review): the original else-branch assigned an unused local
        # (`img_rep = None`), possibly meant to clear `img_obs_rep`. The
        # runtime behaviour is kept: `img_obs_rep` is forwarded unchanged.
        self_obs_rep = self_obs_rep + (torch.randn_like(self_obs_rep) * self.aug_noise * 0.1)
        policy_output, policy_hidden, skill_outputs = self.model.forward_policy(
            self_obs_rep, img_obs_rep, obs_mask, goal, goal_mask
        )

        policy_logits, policy_offsets = split_head(policy_output, self.goal_seq_lenth)
        policy_hidden = policy_hidden[:, self.goal_seq_lenth:, :]

        if self.use_skill_head:
            skill_outputs = skill_outputs[:, self.goal_seq_lenth:, :]
            skill_logits = skill_outputs[:, :, :n_bins]
        else:
            skill_logits = None

        if self.use_reg:
            # The prior is conditioned on the subgoal and sees the same noisy observations.
            prior_output, prior_hidden = self.model.forward_prior(
                self_obs_rep, img_obs_rep, obs_mask, subgoal, subgoal_mask
            )
            prior_logits, prior_offsets = split_head(prior_output, self.subgoal_seq_lenth)
            prior_hidden = prior_hidden[:, self.subgoal_seq_lenth:, :]
        else:
            prior_logits = None
            prior_offsets = None
            prior_hidden = None

        return policy_logits, policy_offsets, policy_hidden, prior_logits, prior_offsets, prior_hidden, skill_logits

    def _cal_skill_loss(self, logits, target_latents, obs_mask, prior_logits=None):
        """
        Masked focal loss of the skill head, optionally weighted by the prior.

        Arguments:
            logits: skill-head logits; flattened to ``[-1, logits.size(-1)]``.
            target_latents: ``(target_bins, target_offsets)``; only the bins are used.
            obs_mask: per-step weights; masked-out steps contribute nothing.
            prior_logits: when given and ``use_prior_weight`` is set, each step
                is weighted by ``(1 - p_prior(target)) * prior_weight_scale``,
                with the prior probability detached from the graph.
        Returns:
            Scalar loss: weighted sum over steps divided by ``obs_mask.sum()``.
            An all-zero mask yields 0 instead of NaN.
        """
        target_bins, _ = target_latents
        if prior_logits is not None and self.use_prior_weight:
            prior_prob = F.softmax(prior_logits, dim=-1).detach()
            target_prior_prob = torch.gather(prior_prob, dim=-1, index=target_bins)
            prior_weight = (1 - target_prior_prob).reshape(-1) * self.prior_weight_scale
        else:
            prior_weight = self.prior_weight_scale

        criterion = FocalLoss(gamma=self.focal_loss_gamma, reduction="none")
        # reshape (not view) so non-contiguous targets and masks are accepted,
        # matching how the logits are flattened.
        per_step_loss = criterion(
            logits.reshape(-1, logits.size(-1)), target_bins.reshape(-1)
        )  # one loss value per (batch, step) pair
        flat_mask = obs_mask.reshape(-1)

        # Guard against 0/0 when every step in the batch is masked out.
        valid_steps = flat_mask.sum()
        valid_steps = torch.where(valid_steps > 0, valid_steps, torch.ones_like(valid_steps))
        return (prior_weight * per_step_loss * flat_mask).sum() / valid_steps

    def _cal_class_loss(self, logits, target_latents, obs_mask):
        """
        Masked focal classification loss over the discrete bins.

        Arguments:
            logits: bin logits; flattened to ``[-1, logits.size(-1)]``.
            target_latents: ``(target_bins, target_offsets)``; only the bins are used.
            obs_mask: per-step weights; masked-out steps contribute nothing.
        Returns:
            Scalar loss: masked sum over steps divided by ``obs_mask.sum()``.
            An all-zero mask yields 0 instead of NaN.
        """
        target_bins, _ = target_latents
        criterion = FocalLoss(gamma=self.focal_loss_gamma, reduction="none")
        # reshape (not view) so non-contiguous targets and masks are accepted,
        # matching how the logits are flattened.
        per_step_loss = criterion(
            logits.reshape(-1, logits.size(-1)), target_bins.reshape(-1)
        )  # one loss value per (batch, step) pair
        flat_mask = obs_mask.reshape(-1)

        # Guard against 0/0 when every step in the batch is masked out.
        valid_steps = flat_mask.sum()
        valid_steps = torch.where(valid_steps > 0, valid_steps, torch.ones_like(valid_steps))
        return (per_step_loss * flat_mask).sum() / valid_steps

    def _cal_offset_loss(self, offsets, target_latents, obs_mask):
        """
        Masked, scaled MSE between the offsets of the target bin and the target offsets.

        Arguments:
            offsets: per-bin offsets, ``[N, T, V, A]``.
            target_latents: ``(target_bins, target_offsets)``; the bins (with
                their trailing singleton axis squeezed) pick one offset vector
                per step.
            obs_mask: per-step weights, broadcast over the offset axis.
        Returns:
            Scalar loss: ``offset_loss_scale`` times the masked sum of squared
            errors, divided by ``obs_mask.sum()``. An all-zero mask yields 0
            instead of NaN.
        """
        target_bins, target_offsets = target_latents
        # [N, T, V, A] index [N T](0..V-1) => [N T A]
        selected_offsets = batch_indexing(offsets, target_bins.squeeze(-1))
        squared_error = self.offset_loss_scale * F.mse_loss(
            selected_offsets, target_offsets, reduction="none"
        )
        step_mask = obs_mask.unsqueeze(-1)

        # Guard against 0/0 when every step in the batch is masked out.
        valid_steps = step_mask.sum()
        valid_steps = torch.where(valid_steps > 0, valid_steps, torch.ones_like(valid_steps))
        return (squared_error * step_mask).sum() / valid_steps

    def _cal_kl_loss(self, policy_logits, policy_offsets, prior_logits, prior_offsets, obs_mask):
        prior_logits = prior_logits.detach()
        prior_offsets = prior_offsets.detach()

        unsqueezed_obs_mask = obs_mask.unsqueeze(-1)
        logits_loss = F.mse_loss(policy_logits, prior_logits, reduction="none")
        logits_loss = (logits_loss * unsqueezed_obs_mask).sum() / unsqueezed_obs_mask.sum()

        unsqueezed_obs_mask = obs_mask.unsqueeze(-1).unsqueeze(-1)
        offsets_loss = F.mse_loss(policy_offsets, prior_offsets, reduction="none")
        offsets_loss = (offsets_loss * unsqueezed_obs_mask).sum() / unsqueezed_obs_mask.sum()

        return logits_loss, offsets_loss, logits_loss + offsets_loss

    def _cal_hidden_loss(self, policy_hidden, prior_hidden, obs_mask):
        if not self.use_double_kl:
            prior_hidden = prior_hidden.detach()

        unsqueezed_obs_mask = obs_mask.unsqueeze(-1)
        hidden_loss = F.mse_loss(policy_hidden, prior_hidden, reduction="none")
        hidden_loss = (hidden_loss * unsqueezed_obs_mask).sum() / unsqueezed_obs_mask.sum()

        return hidden_loss

    def _calc_loss(self, output: torch.Tensor, obs_mask: torch.Tensor, target_latents: torch.Tensor, cur_epoch: int):
        policy_logits, policy_offsets, policy_hidden, prior_logits, prior_offsets, prior_hidden, skill_logits = output
        target_logits, target_offsets = target_latents
        target_latents = (target_logits, target_offsets)

        loss = 0
        loss_components = {}

        policy_class_loss = self._cal_class_loss(policy_logits, target_latents, obs_mask)
        loss += policy_class_loss
        loss_components["policy_class_loss"] = policy_class_loss

        policy_offset_loss = self._cal_offset_loss(policy_offsets, target_latents, obs_mask)
        loss += policy_offset_loss
        loss_components["policy_offset_loss"] = policy_offset_loss

        if self.use_reg:
            if not self.fix_prior or cur_epoch < self.coef_decay_start:
                prior_class_loss = self._cal_class_loss(prior_logits, target_latents, obs_mask)
                loss += prior_class_loss
                loss_components["prior_class_loss"] = prior_class_loss

                prior_offset_loss = self._cal_offset_loss(prior_offsets, target_latents, obs_mask)
                loss += prior_offset_loss
                loss_components["prior_offset_loss"] = prior_offset_loss

        if self.use_reg:
            if self.reg_type == "kl":
                logits_kl_loss, offsets_kl_loss, policy_kl_loss = self._cal_kl_loss(policy_logits, policy_offsets, prior_logits, prior_offsets, obs_mask)
                loss += policy_kl_loss * self.cur_coef
                loss_components["logits_kl_loss"] = logits_kl_loss
                loss_components["offsets_kl_loss"] = offsets_kl_loss
            elif self.reg_type == 'hidden':
                hidden_loss = self._cal_hidden_loss(policy_hidden, prior_hidden, obs_mask)
                loss += hidden_loss * self.cur_coef
                loss_components["hidden_loss"] = hidden_loss

            if self.use_skill_head:
                skill_loss = self._cal_skill_loss(skill_logits, target_latents, obs_mask, prior_logits)
                loss += skill_loss
                loss_components["skill_loss"] = skill_loss

            loss_components['coef'] = torch.tensor(self.cur_coef)
        
        return loss, loss_components

    def get_latent_and_loss(
        self,
        self_obs_rep: torch.Tensor,
        img_obs_rep: torch.Tensor,
        obs_mask: torch.Tensor,
        target_latents: torch.Tensor,
        subgoal: torch.Tensor,
        subgoal_mask: torch.Tensor,
        goal: torch.Tensor ,
        goal_mask: torch.Tensor,
        cur_epoch: int,
        return_loss_components: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        output = self._optimize(
            self_obs_rep=self_obs_rep,
            img_obs_rep= img_obs_rep,
            obs_mask=obs_mask, 
            subgoal=subgoal, 
            subgoal_mask=subgoal_mask,
            goal=goal, 
            goal_mask=goal_mask)

        if cur_epoch < self.coef_decay_start:
            self.cur_coef = self.coef_start
        else:
            coef_min = min(self.coef_start, self.coef_end)
            coef_max = max(self.coef_start, self.coef_end)
            cur_coef = (self.coef_start + self.coef_step * (cur_epoch - self.coef_decay_start))
            self.cur_coef = min(max(coef_min, cur_coef), coef_max)
        loss, loss_components = self._calc_loss(output=output, obs_mask=obs_mask, target_latents=target_latents, cur_epoch=cur_epoch)
        if return_loss_components:
            return output, loss, loss_components
        else:
            return output, loss

    def generate_latents(
        self,
        self_obs_rep: torch.Tensor,
        img_obs_rep: torch.Tensor,
        obs_mask: torch.Tensor,
        goal: torch.Tensor,
        goal_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Greedily decode a latent for every step.

        Returns:
            ``(bins, offsets)``: the highest-logit bin per step with a trailing
            singleton axis added, and the offset vector predicted for that bin.
        """
        bin_logits, bin_offsets = self._predict(
            self_obs_rep=self_obs_rep,
            img_obs_rep=img_obs_rep,
            obs_mask=obs_mask,
            goal=goal,
            goal_mask=goal_mask,
        )

        # Deterministic arg-max decoding; sampling from Categorical(logits=...)
        # would be the stochastic alternative and is not used here.
        _, best_bins = bin_logits.max(dim=-1)  # batch x seq
        chosen_offsets = batch_indexing(bin_offsets, best_bins)  # N T A

        return best_bins.unsqueeze(-1), chosen_offsets

    def evaluate_latent(
        self,
        self_obs_rep: torch.Tensor,
        img_obs_rep: torch.Tensor,
        obs_mask: torch.Tensor,
        target_latents: torch.Tensor,
        goal: torch.Tensor ,
        goal_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        '''
        self.aug_noise = 0
        policy_logits, _, prior_logits, _ = self._optimize(obs_rep=self_obs_rep, img_obs_rep=img_obs_rep, obs_mask=obs_mask, goal=goal, goal_mask=goal_mask)
        policy_pi = F.softmax(policy_logits, dim=-1)
        prior_pi = F.softmax(prior_logits, dim=-1)
        target_logits, _ = target_latents

        action_policy_prob = torch.gather(policy_pi, dim=-1, index=target_logits).detach().cpu()
        action_prior_prob = torch.gather(prior_pi, dim=-1, index=target_logits).detach().cpu()

        return action_policy_prob, action_prior_prob
        '''
        pass

    def get_optimizer(
        self, weight_decay: float, learning_rate: float, betas: Tuple[float, float]
    ) -> torch.optim.Optimizer:
        """
        Build the optimizer for the wrapped GPT.

        The three hyper-parameters are packed into a minGPT ``TrainerConfig``
        and handed to ``self.model.configure_optimizers``, whose result is
        returned as-is.
        """
        optim_settings = dict(
            weight_decay=weight_decay,
            learning_rate=learning_rate,
            betas=betas,
        )
        return self.model.configure_optimizers(
            mingpt_trainer.TrainerConfig(**optim_settings)
        )
