import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import transformers
import sys
print(sys.path)
from ICRL.models.model import TrajectoryModel
from ICRL.models.trajectory_gpt2 import GPT2Model

from torch.distributions.normal import Normal
import torch.nn as nn
class DiagGaussianActor(nn.Module):
    """Diagonal Gaussian policy head built on ``torch.distributions``.

    Two linear layers map a feature vector to the mean and to the log
    standard deviation of an independent ``Normal`` per action dimension.

    Args:
        hidden_dim: Size of the last dimension of the input features.
        act_dim: Number of action dimensions (output size of both heads).
        log_std_bounds: ``(min, max)`` pair stored on the module as
            ``self.log_std_bounds``.
            NOTE(review): ``forward`` does not apply these bounds to the
            predicted log-std; confirm whether clamping was intended.
    """

    def __init__(self, hidden_dim, act_dim, log_std_bounds=(-5.0, 2.0)):
        super().__init__()

        self.mu = torch.nn.Linear(hidden_dim, act_dim)
        self.log_std = torch.nn.Linear(hidden_dim, act_dim)
        # Copy into a fresh list so instances never share (and cannot
        # accidentally mutate) a single default object.
        self.log_std_bounds = list(log_std_bounds)

        def weight_init(m):
            """Orthogonal weights and zero bias for every Linear sub-module."""
            if isinstance(m, torch.nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

        self.apply(weight_init)

    def forward(self, obs):
        """Return ``Normal(mu(obs), exp(log_std(obs)))``.

        Args:
            obs: Tensor whose last dimension is ``hidden_dim``.

        Returns:
            A ``torch.distributions.Normal`` with one independent component
            per action dimension.
        """
        mean = self.mu(obs)
        std = self.log_std(obs).exp()
        return Normal(mean, std)

class CostAttentionLayer(nn.Module):
    """Attention head that maps hidden states to per-step scalar values.

    A single linear layer projects every hidden state to a query and a key of
    size ``pre_attn_embd_dim`` plus a one-dimensional value. ``forward``
    returns the attention-weighted sum of those values together with the raw
    values and the attention weights.

    Args:
        pre_attn_embd_dim: Size of the query and key vectors.
        hidden_size: Size of the last dimension of the incoming hidden states.
        num_heads: Stored as ``self.num_heads``. ``forward`` always uses one
            head because the projection emits a single query/key/value triple.
        dropout_rate: Dropout probability applied to the attention weights.
        scale_attn_weights: Whether to divide the attention scores by
            ``sqrt(value.shape[-1])``.
    """

    def __init__(self, pre_attn_embd_dim,hidden_size, num_heads, dropout_rate=0.0, scale_attn_weights=True):
        super().__init__()
        self.pre_attn_embd_dim = pre_attn_embd_dim
        self.num_heads = num_heads
        self.dropout = nn.Dropout(dropout_rate)
        self.scale_attn_weights = scale_attn_weights
        # Joint projection: [query | key | scalar value].
        self.X = torch.nn.Linear(hidden_size, 2 * self.pre_attn_embd_dim + 1)

    def get_attention_mask(self,attn_mask,batch_size):
        """Convert a 0/1 mask into an additive attention bias.

        The mask is reshaped to ``(batch_size, -1)``, given two singleton
        dimensions at positions 1 and 2, and mapped so that entries equal to 1
        become 0 and entries equal to 0 become -10000.
        """
        assert batch_size > 0, 'batch_size should be > 0.'
        attn_mask = torch.reshape(attn_mask, shape=(batch_size, -1))
        attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
        attn_mask = (1.0 - attn_mask) * -10000.0
        return attn_mask

    def split_heads(self,x,num_heads,head_dim):
        """Split the last dimension into ``(num_heads, head_dim)`` and move heads forward.

        Args:
            x (tensor): Rank-3 ``[B, seq_len, embd_dim]`` or rank-4
                ``[B, blocks, block_len, embd_dim]`` input.
            num_heads (int): Number of heads.
            head_dim (int): Dimension of embedding for each head.

        Returns:
            (tensor): ``[B, num_heads, seq_len, head_dim]`` or
            ``[B, blocks, num_heads, block_len, head_dim]``.

        Raises:
            ValueError: If the reshaped tensor is not rank 4 or 5.
        """
        newshape = x.shape[:-1] + (num_heads, head_dim)
        x = torch.reshape(x, newshape)
        if x.ndim == 5:
            return x.permute(0, 1, 3, 2, 4)
        elif x.ndim == 4:
            return x.permute(0, 2, 1, 3)
        else:
            raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')

    def merge_heads(self,x, num_heads, head_dim):
        """
        Merge embeddings for different heads (inverse of ``split_heads``).

        Args:
            x (tensor): Input tensor, shape [B, num_head, seq_len, head_dim] or [B, blocks, num_head, block_len, head_dim].
            num_heads (int): Number of heads. Kept for interface compatibility;
                the sizes are read from ``x`` itself.
            head_dim (int): Dimension of embedding for each head. Kept for
                interface compatibility; the sizes are read from ``x`` itself.

        Returns:
            (tensor): Output tensor, shape [B, seq_len, embd_dim] or [B, blocks, block_len, embd_dim].

        Raises:
            ValueError: If ``x`` is not rank 4 or 5.
        """
        # ``permute`` returns a new view and is NOT in-place, so the result
        # must be assigned; otherwise heads and positions get interleaved
        # incorrectly whenever there is more than one head.
        if x.ndim == 5:
            x = x.permute(0, 1, 3, 2, 4)
        elif x.ndim == 4:
            x = x.permute(0, 2, 1, 3)
        else:
            raise ValueError(f'Input tensor should have rank 4 or 5, but has rank {x.ndim}.')

        # After the permute the last two dims are (num_head, head_dim); fold
        # them into one embedding dim. ``reshape`` copies when the permuted
        # view is not contiguous.
        embd_dim = x.shape[-2] * x.shape[-1]
        return x.reshape(x.shape[:-2] + (embd_dim,))

    def forward(self,hidden_output, masked_bias, training, attention_mask=None, batch_size=32,seq_length=20, head_mask=None, feedback=None):
        """Compute the attention-weighted sum of per-step scalar values.

        Args:
            hidden_output: Rank-3 tensor whose last dimension is ``hidden_size``
                (the projection is split along dim 2).
            masked_bias: Fill value used where the position mask is False.
            training: Accepted for interface compatibility; not read here.
            attention_mask: Optional 0/1 mask reshaped to ``(batch_size, -1)``;
                positions with 0 receive a -10000 additive bias. ``None``
                disables the additive bias.
            batch_size: Batch size used to reshape ``attention_mask``.
            seq_length: Side length of the square position mask before slicing.
            head_mask: Optional multiplicative mask on the attention weights.
            feedback: Accepted for interface compatibility; not read here.

        Returns:
            Tuple ``(output, value, attn_weights)``: the merged weighted sum,
            the per-head values, and the softmax weights before dropout.
        """
        projected = self.X(hidden_output)
        num_heads = 1
        embd = self.pre_attn_embd_dim
        query, key, value = torch.split(projected, [embd, embd, 1], dim=2)
        query = self.split_heads(query, num_heads, embd)
        key = self.split_heads(key, num_heads, embd)
        value = self.split_heads(value, num_heads, 1)

        q_len, k_len = query.shape[-2], key.shape[-2]
        # NOTE(review): this mask is all ones (no lower-triangular structure),
        # so every position can attend to every other one; confirm that the
        # absence of causal masking is intended.
        casual_mask = torch.ones(
            (1, 1, seq_length, seq_length), dtype=torch.bool, device=projected.device
        )[:, :, k_len - q_len:k_len, :k_len]

        # Only build the additive padding bias when a mask was supplied; the
        # signature allows ``attention_mask=None``.
        attn_mask = None
        if attention_mask is not None:
            attn_mask = self.get_attention_mask(attention_mask, batch_size)

        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            # NOTE(review): ``value`` has head_dim 1 here, so this divides by
            # 1; confirm whether scaling by the key dimension was intended.
            attn_weights = attn_weights / (float(value.shape[-1]) ** 0.5)

        # Use a tensor fill value so this also works on torch versions whose
        # ``torch.where`` does not accept Python scalars.
        fill_value = torch.as_tensor(
            masked_bias, dtype=attn_weights.dtype, device=attn_weights.device
        )
        attn_weights = torch.where(casual_mask, attn_weights, fill_value)

        if attn_mask is not None:
            attn_weights = attn_weights + attn_mask

        _attn_weights = torch.softmax(attn_weights, dim=-1)
        attn_weights = _attn_weights.to(value.dtype)
        attn_weights = self.dropout(attn_weights)

        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        out = torch.matmul(attn_weights, value)

        output = self.merge_heads(out, num_heads, 1)
        return output, value, _attn_weights

    
class DecisionTransformer_icrl(TrajectoryModel):

    """
    GPT-based cost model over interleaved (state_1, action_1, state_2, action_2, ...) tokens.

    States and actions are embedded separately, combined with a timestep
    embedding, interleaved, and fed through a GPT-2 backbone. The hidden
    states at the action positions are passed to a ``CostAttentionLayer``
    that produces per-step values and their attention-weighted sum.

    Args:
        state_dim: Dimension of a single state vector.
        act_dim: Dimension of a single action vector.
        hidden_size: Embedding size used by the transformer.
        pre_attn_embd_dim: Query/key size of the cost attention layer.
        use_weighted_sum: Stored as ``self.use_weighted_sum``.
        max_length: Optional context length used by ``get_cost`` for
            truncation and left-padding.
        max_ep_len: Number of distinct timestep embeddings.
        action_tanh: Accepted for interface compatibility; not read here.
        **kwargs: Forwarded to ``transformers.GPT2Config``.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            hidden_size,
            pre_attn_embd_dim,
            use_weighted_sum,
            max_length=None,
            max_ep_len=4096,
            action_tanh=True,
            **kwargs
    ):
        super().__init__(state_dim, act_dim, max_length=max_length)

        self.hidden_size = hidden_size
        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            **kwargs
        )

        self.transformer = GPT2Model(config)
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)
        self.embed_ln = nn.LayerNorm(hidden_size)

        self.pre_attn_embd_dim = pre_attn_embd_dim
        self.use_weighted_sum = use_weighted_sum

        self.apply(self._init_weights)

        # Created after ``apply`` above, so this layer keeps PyTorch's default
        # initialisation. Attribute assignment already registers the
        # sub-module; no explicit ``add_module`` call is needed.
        self.cost_atten_layer = CostAttentionLayer(pre_attn_embd_dim,hidden_size, 1, dropout_rate=0.0, scale_attn_weights=True)

    @staticmethod
    def _init_weights(module: nn.Module):
        """Normal(0, 0.02) weights for Linear/Embedding, zero bias; identity LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, states, actions, timesteps, attention_mask=None,training=False,target_idx=1):
        """Run the transformer and the cost attention layer on a batch of sequences.

        Args:
            states: Tensor indexed as ``(batch, seq, state_dim)``.
            actions: Tensor embedded by ``embed_action``; assumes
                ``(batch, seq, act_dim)`` -- TODO confirm with callers.
            timesteps: Integer indices into the timestep embedding; assumes
                ``(batch, seq)`` -- TODO confirm with callers.
            attention_mask: Optional ``(batch, seq)`` mask, 1 where a position
                may be attended to and 0 otherwise. Defaults to all ones.
            training: Forwarded to the cost attention layer.
            target_idx: Accepted for interface compatibility; not read here.

        Returns:
            Tuple ``(preds, attentions)`` where ``preds`` is a dict with keys
            ``"weighted_sum"`` and ``"value"`` from the cost attention layer,
            and ``attentions`` is the transformer's attention outputs with the
            cost layer's attention weights appended.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if can be attended to, 0 if not
            attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long,device=states.device)

        # embed each modality with a different head; the timestep embedding
        # plays the role of a positional embedding
        time_embeddings = self.embed_timestep(timesteps)
        state_embeddings = self.embed_state(states) + time_embeddings
        action_embeddings = self.embed_action(actions) + time_embeddings

        # interleave so the sequence looks like (s_1, a_1, s_2, a_2, ...)
        stacked_inputs = torch.stack(
            (state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 2*seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 2*seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
            output_attentions = True,
        )
        x = transformer_outputs['last_hidden_state']
        attn_weights_list = transformer_outputs['attentions']

        # undo the interleaving: x[:, 0] holds state tokens, x[:, 1] action tokens
        x = x.reshape(batch_size, seq_length, 2, self.hidden_size).permute(0, 2, 1, 3)
        hidden_output = x[:,1] # cost is predicted from the action tokens

        # add attention layer
        output,value,_attn_weights = self.cost_atten_layer(hidden_output, -1e-4, training=training, attention_mask=attention_mask,batch_size=batch_size,seq_length=seq_length, head_mask=None)
        attn_weights_list += (_attn_weights,)

        return {"weighted_sum": output, "value": value},attn_weights_list

    def get_cost(self, states, actions, timesteps, **kwargs):
        """Return the mean predicted cost for a single trajectory.

        The inputs are reshaped to a batch of one. When ``self.max_length`` is
        set, the trajectory is truncated to its last ``max_length`` steps and
        left-padded with zeros, with a matching attention mask.

        Args:
            states: Tensor reshapeable to ``(1, -1, state_dim)``.
            actions: Tensor reshapeable to ``(1, -1, act_dim)``.
            timesteps: Tensor reshapeable to ``(1, -1)``.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(1, 1)``: the mean of the ``"weighted_sum"``
            output over the (possibly padded) sequence.
            NOTE(review): the mean runs over padded positions as well; confirm
            that this is intended.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            states = states[:,-self.max_length:]
            actions = actions[:,-self.max_length:]
            timesteps = timesteps[:,-self.max_length:]

            seq_len = states.shape[1]
            pad_len = self.max_length - seq_len

            # left-pad all tokens to the context length; 0 marks padding
            attention_mask = torch.cat([torch.zeros(pad_len), torch.ones(seq_len)])
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
            states = torch.cat(
                [torch.zeros((1, pad_len, self.state_dim), dtype=states.dtype, device=states.device), states],
                dim=1).to(dtype=torch.float32)
            actions = torch.cat(
                [torch.zeros((1, pad_len, self.act_dim), dtype=actions.dtype, device=actions.device), actions],
                dim=1).to(dtype=torch.float32)
            timesteps = torch.cat(
                [torch.zeros((1, pad_len), dtype=timesteps.dtype, device=timesteps.device), timesteps],
                dim=1
            ).to(dtype=torch.long)
        else:
            attention_mask = None

        pred,_ = self.forward(
            states, actions, timesteps, attention_mask=attention_mask, training=False)

        B,T,_ = actions.shape

        pred = pred["weighted_sum"]

        sum_pred_e = torch.mean(pred.reshape(B, T), dim=1).reshape(-1, 1)

        return sum_pred_e

    def evaluation_cost(self,use_weighted_sum,train_type,states_e, actions_e,timesteps_e,attention_mask_e,):
        """Predict costs in eval mode and aggregate them per trajectory.

        Args:
            use_weighted_sum: If true, use the ``"weighted_sum"`` output,
                otherwise the raw ``"value"`` output.
            train_type: Aggregation over the time axis: ``"mean"``, ``"sum"``,
                ``"last"`` (final step only) or ``"every"`` (no aggregation).
            states_e, actions_e, timesteps_e, attention_mask_e: Batch passed
                to ``forward``; ``actions_e`` must be rank 3 ``(B, T, _)``.

        Returns:
            Tensor of shape ``(B, 1)`` for ``"mean"``/``"sum"``/``"last"`` and
            ``(B, T)`` for ``"every"``.

        Raises:
            ValueError: If ``train_type`` is not one of the supported values.
        """
        if train_type not in ("mean", "sum", "last", "every"):
            raise ValueError(
                f"Unknown train_type {train_type!r}; expected 'mean', 'sum', 'last' or 'every'."
            )

        # Toggle this module itself (it has no ``model`` attribute) and put
        # the previous mode back even if the forward pass raises.
        was_training = self.training
        self.eval()
        try:
            B,T,_ = actions_e.shape
            trans_pred_e,_ = self.forward(
                states_e, actions_e, timesteps_e, attention_mask=attention_mask_e,training=False
            )
            key = "weighted_sum" if use_weighted_sum else "value"
            per_step = trans_pred_e[key].reshape(B, T)

            if train_type == "mean":
                results = torch.mean(per_step, dim=1).reshape(-1, 1)
            elif train_type == "sum":
                results = torch.sum(per_step, dim=1).reshape(-1, 1)
            elif train_type == "last":
                results = per_step[:, -1].reshape(-1, 1)
            else:  # "every"
                results = per_step
        finally:
            self.train(was_training)
        return results
