import numpy as np
import torch
import torch.nn as nn

import transformers
from bert import BertModel

import math

from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

from wiserl.module.net.attention.base import BaseTransformer
from wiserl.module.net.attention.gpt2 import GPT2
from wiserl.module.net.attention.positional_encoding import get_pos_encoding


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of a batch of scalar positions.

    For an input ``x`` indexed as ``x[:, None]`` (a 1-D tensor in the usual
    case) the result has shape ``(len(x), 2 * (dim // 2))``: the first half
    holds ``sin(x * f_k)`` and the second half ``cos(x * f_k)``, with
    frequencies ``f_k`` spaced geometrically from ``1`` down to ``1 / 10000``.

    Args:
        dim: Requested embedding width. An odd ``dim`` yields ``dim - 1``
            output features because the width is split into two equal halves.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the ``[sin | cos]`` embedding of the positions in ``x``."""
        half_dim = self.dim // 2
        # With a single frequency (dim of 2 or 3) ``half_dim - 1`` is zero and
        # the unguarded division raised ZeroDivisionError. Clamping the
        # denominator leaves that single frequency at exp(0) == 1 and does not
        # change the result for any larger ``dim``.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Structured2Transformer(BaseTransformer):
    """Causal transformer over interleaved (state, action) tokens.

    States and actions are embedded separately, positionally encoded with the
    supplied timesteps, interleaved as ``s_0, a_0, s_1, a_1, ...`` and passed
    through a ``GPT2`` backbone built with ``causal=True``. The backbone output
    at every action token is then mapped to one scalar per step, either by a
    small MLP or, when ``use_weighted_sum`` is set, by a single attention-style
    weighted sum over the per-step scalars.

    Args:
        obs_dim: Input width of the state embedding layer.
        action_dim: Input width of the action embedding layer.
        embed_dim: Token width used by the embeddings and the backbone.
        pref_embed_dim: Query/key width of the weighted-sum head, or hidden
            width of the MLP head.
        seq_len: Passed to ``get_pos_encoding``.
        num_layers: Forwarded to the backbone.
        num_heads: Forwarded to the backbone.
        reward_act: ``"identity"`` selects ``nn.Identity``; any other value
            selects ``nn.Sigmoid``.
        attention_dropout: Forwarded to the backbone.
        residual_dropout: Forwarded to the backbone.
        embed_dropout: Forwarded to the backbone.
        pos_encoding: Kind of positional encoding, passed to
            ``get_pos_encoding``.
        use_weighted_sum: Selects the weighted-sum head instead of the MLP head.
    """

    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        embed_dim: int,
        pref_embed_dim: int,
        seq_len: int,
        num_layers: int = 2,
        num_heads: int = 1,
        reward_act: str = "identity",
        attention_dropout: Optional[float] = 0.1,
        residual_dropout: Optional[float] = 0.1,
        embed_dropout: Optional[float] = 0.1,
        pos_encoding: str = "embed",
        use_weighted_sum: bool = False,
    ) -> None:
        super().__init__()

        # Positional encoding is applied by this class before interleaving, so
        # the backbone is built with its own positional encoding switched off.
        backbone_kwargs = dict(
            input_dim=embed_dim,
            embed_dim=embed_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            causal=True,
            attention_dropout=attention_dropout,
            residual_dropout=residual_dropout,
            embed_dropout=embed_dropout,
            pos_encoding="none",
            seq_len=0,
        )
        self.backbone = GPT2(**backbone_kwargs)
        self.pos_encoding = get_pos_encoding(pos_encoding, embed_dim, seq_len)

        # One linear embedding per modality, sharing a LayerNorm afterwards.
        self.obs_embed = nn.Linear(obs_dim, embed_dim)
        self.act_embed = nn.Linear(action_dim, embed_dim)
        self.embed_ln = nn.LayerNorm(embed_dim)

        self.use_weighted_sum = use_weighted_sum
        self.pref_embed_dim = pref_embed_dim

        # Output head: fused query/key/value projection, or a two-layer MLP.
        if use_weighted_sum:
            kqv_width = 2 * pref_embed_dim + 1
            self.to_kqv = nn.Linear(embed_dim, kqv_width, bias=False)
        else:
            self.output_layer = nn.Sequential(
                nn.Linear(embed_dim, pref_embed_dim),
                nn.GELU(),
                nn.Linear(pref_embed_dim, 1),
            )

        # Activation applied to the per-step scalar.
        if reward_act == "identity":
            self.reward_act = nn.Identity()
        else:
            self.reward_act = nn.Sigmoid()

    def forward(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        timesteps: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """Compute per-step scalars for a batch of trajectories.

        Args:
            states: State batch; its first two dimensions are read as
                (batch, length). Presumably (batch, length, obs_dim) -- confirm
                with callers.
            actions: Action batch aligned with ``states``.
            timesteps: Passed to the positional encoding for both modalities.
            attention_mask: Forwarded unchanged to the backbone.
            key_padding_mask: Optional per-step mask; it is duplicated so that
                both tokens of a step share the same entry before being handed
                to the backbone.

        Returns:
            A ``(value, out)`` tuple. With the weighted-sum head ``value`` is
            the activated per-step scalar and ``out`` its attention-weighted
            sum; with the MLP head ``value`` is the raw MLP output and ``out``
            is that output after ``reward_act``.
        """
        batch, length = states.shape[:2]

        obs_tokens = self.pos_encoding(self.obs_embed(states), timesteps)
        act_tokens = self.pos_encoding(self.act_embed(actions), timesteps)
        width = obs_tokens.shape[-1]

        # Stacking on a new axis 2 and flattening yields s_0, a_0, s_1, a_1, ...
        tokens = torch.stack((obs_tokens, act_tokens), dim=2)
        tokens = self.embed_ln(tokens.reshape(batch, 2 * length, width))

        if key_padding_mask is not None:
            doubled_mask = torch.stack((key_padding_mask, key_padding_mask), dim=2)
            key_padding_mask = doubled_mask.reshape(batch, 2 * length)

        hidden = self.backbone(
            inputs=tokens,
            timesteps=None,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            do_embedding=False,
        )
        # Odd positions are the action tokens.
        hidden = hidden[:, 1::2]

        if not self.use_weighted_sum:
            raw = self.output_layer(hidden)
            return raw, self.reward_act(raw)

        projected = self.to_kqv(hidden)
        split_sizes = [self.pref_embed_dim, self.pref_embed_dim, 1]
        query, key, value = projected.split(split_sizes, dim=2)

        # Scaling query and key by d ** 0.25 each scales their product by sqrt(d).
        scale = self.pref_embed_dim ** 0.25
        query = query / scale
        key = key / scale
        value = self.reward_act(value)

        # NOTE(review): key_padding_mask is not applied to this softmax --
        # confirm that padded steps are meant to receive weight here.
        scores = torch.bmm(query, key.transpose(1, 2))
        weights = torch.softmax(scores, dim=2)
        return value, torch.bmm(weights, value)


class StructuredTransformer(nn.Module):
    """Scores a (state, action) trajectory with one scalar via a ``BertModel``.

    States and actions are embedded by separate linear layers, a learned
    timestep embedding is added to both, the tokens are interleaved as
    ``s_0, a_0, s_1, a_1, ...`` and layer-normalised, and the transformer's
    last hidden states are summed over every token before a final linear
    layer maps the result to a single value per trajectory.

    Args:
        state_dim: Input width of the state embedding layer.
        act_dim: Input width of the action embedding layer.
        hidden_size: Token width; also passed to ``BertConfig``.
        max_length: Stored on the instance; not used by the code in this class.
        max_ep_len: Number of rows in the timestep embedding table.
        n_bins: Accepted for interface compatibility; currently unused.
        device: Stored as ``self.device``. Tensors created internally are
            placed on the device of the incoming data instead, so the model
            also works when it does not live on this device.
        **kwargs: Forwarded to ``transformers.BertConfig``.
    """

    def __init__(
        self,
        state_dim,
        act_dim,
        hidden_size,
        max_length=None,
        max_ep_len=1000,
        n_bins=10,
        device='cuda',
        **kwargs
    ):
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.max_length = max_length
        self.hidden_size = hidden_size
        self.device = device
        self.max_ep_len = max_ep_len
        config = transformers.BertConfig(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            hidden_size=hidden_size,
            **kwargs
        )

        # NOTE(review): BertModel is imported from the project-local ``bert``
        # module. The comment previously here (which spoke of a GPT2Model)
        # said its positional embeddings are removed because timestep
        # embeddings are added by this class -- confirm against bert.py.
        self.transformer = BertModel(config)

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)
        # Final head: maps the pooled hidden state to a single scalar.
        # A separate ``rew_output`` head is currently disabled;
        # ``get_main_parameters`` still filters on that name.
        self.output = nn.Linear(self.hidden_size, 1)

    def get_main_parameters(self):
        """Yield every parameter whose name does not contain ``rew_output``."""
        for name, param in self.named_parameters():
            if 'rew_output' not in name:
                yield param

    def embed_forward(
        self,
        states,
        actions,
        timesteps=None,
    ):
        """Embed and interleave states and actions.

        Args:
            states: Tensor whose first two dimensions are (batch, seq_length).
            actions: Tensor aligned with ``states``.
            timesteps: Integer indices into the timestep embedding table, one
                per step. Defaults to ``0 .. seq_length - 1`` for every
                trajectory.

        Returns:
            Layer-normalised tokens of shape
            ``(batch, 2 * seq_length, hidden_size)``.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]
        if timesteps is None:
            # The default must match the input length. Using max_ep_len here
            # only worked when seq_length happened to equal max_ep_len.
            timesteps = (
                torch.arange(seq_length, device=states.device)
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings

        # this makes the sequence look like (s_1, a_1, s_2, a_2, ...)
        stacked_inputs = (
            torch.stack((state_embeddings, action_embeddings), dim=1)
            .permute(0, 2, 1, 3)
            .reshape(batch_size, 2 * seq_length, self.hidden_size)
        )
        return self.embed_ln(stacked_inputs)

    def disc_forward(
        self,
        stacked_inputs,
        attention_mask=None
    ):
        """Run the transformer on pre-embedded tokens and pool to one scalar.

        Args:
            stacked_inputs: Output of ``embed_forward``, shape
                ``(batch, 2 * seq_length, hidden_size)``.
            attention_mask: Optional ``(batch, seq_length)`` per-step mask.
                Defaults to all ones. It is duplicated so both tokens of a
                step share the same entry.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        batch_size, seq_length = stacked_inputs.shape[0], stacked_inputs.shape[1] // 2
        if attention_mask is None:
            # All-ones mask; by the Hugging Face convention 1 means "may be
            # attended to". Created on the input's device rather than on
            # self.device so CPU / relocated models keep working.
            attention_mask = torch.ones(
                (batch_size, seq_length),
                dtype=torch.long,
                device=stacked_inputs.device,
            )

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = (
            torch.stack((attention_mask, attention_mask), dim=1)
            .permute(0, 2, 1)
            .reshape(batch_size, 2 * seq_length)
        )

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs["last_hidden_state"]

        # reshape x so that the second dimension corresponds to the original
        # states (0), or actions (1); i.e. x[:,0,t] is the token for s_t
        x = x.reshape(
            batch_size,
            seq_length,
            2,
            self.hidden_size,
        ).permute(0, 2, 1, 3)

        # Pool by summing over time steps, then over the two modalities.
        x = x.sum(dim=2).sum(dim=1)

        return self.output(x)

    def forward(
        self,
        states,
        actions,
        timesteps=None,
        attention_mask=None,
    ):
        """Score a batch of trajectories.

        Equivalent to ``disc_forward(embed_forward(states, actions, timesteps),
        attention_mask)``; see those methods for argument details.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        stacked_inputs = self.embed_forward(states, actions, timesteps)
        return self.disc_forward(stacked_inputs, attention_mask)

    def gradient_penalty(self, expert_states, expert_actions, offline_states, offline_actions, lambda_=10, device='cuda'):
        """Gradient penalty on random interpolations of expert and offline data.

        One interpolation weight is drawn per trajectory, the two batches are
        mixed with it, and the squared deviation from 1 of the gradient norm
        of the score with respect to the embedded tokens is averaged.

        Args:
            expert_states: 3-D tensor of expert states.
            expert_actions: 3-D tensor of expert actions.
            offline_states: Tensor matching ``expert_states``.
            offline_actions: Tensor matching ``expert_actions``.
            lambda_: Multiplier applied to the penalty.
            device: Retained for backward compatibility and no longer used;
                the interpolation weights are created on the device of
                ``expert_states``.

        Returns:
            Scalar tensor ``lambda_ * mean((||grad||_2 - 1) ** 2)``.
        """
        # Random per-trajectory interpolation weights, broadcast over the
        # step and feature dimensions.
        alpha = torch.rand(expert_states.size(0), 1, 1, device=expert_states.device)

        mixup_states = alpha * expert_states + (1 - alpha) * offline_states
        mixup_actions = alpha * expert_actions + (1 - alpha) * offline_actions

        stacked_inputs = self.embed_forward(mixup_states, mixup_actions)
        stacked_inputs.requires_grad_(True)

        disc = self.disc_forward(stacked_inputs)
        # create_graph=True keeps the penalty itself differentiable.
        gradients = torch.autograd.grad(
            outputs=disc,
            inputs=stacked_inputs,
            grad_outputs=torch.ones_like(disc),
            create_graph=True,
            retain_graph=True,
        )[0]

        # 2-norm of the gradient per trajectory (epsilon keeps sqrt
        # differentiable at zero), then the penalty itself.
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=[1, 2]) + 1e-12)
        gp = torch.mean((gradients_norm - 1) ** 2)

        return lambda_ * gp