import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import transformers
import sys
print(sys.path)

from GDT.decision_transformer2.models.model import TrajectoryModel
from GDT.decision_transformer2.models.trajectory_gpt2 import GPT2Model

from torch.distributions.normal import Normal
import torch.nn as nn
class DiagGaussianActor(nn.Module):
    """Diagonal Gaussian policy head built on ``torch.distributions``.

    Two independent linear layers map a feature vector to the mean and to the
    log standard deviation of a ``Normal`` distribution, one entry per action
    dimension.

    Args:
        hidden_dim: Size of the last dimension of the input features.
        act_dim: Number of action dimensions (size of the output).
        log_std_bounds: ``(min, max)`` pair for the log standard deviation.
            It is stored on the module as a list, but ``forward`` does not
            currently apply it, so the predicted log-std is unbounded.
    """

    def __init__(self, hidden_dim, act_dim, log_std_bounds=(-5.0, 2.0)):
        super().__init__()

        self.mu = torch.nn.Linear(hidden_dim, act_dim)
        self.log_std = torch.nn.Linear(hidden_dim, act_dim)
        # Keep a private copy: an immutable default plus a per-instance list
        # means mutating one actor's bounds can never leak into another actor
        # or into the caller's own sequence.
        self.log_std_bounds = list(log_std_bounds)

        def weight_init(m):
            """Orthogonal weights and zero bias for every Linear layer."""
            if isinstance(m, torch.nn.Linear):
                nn.init.orthogonal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

        self.apply(weight_init)

    def forward(self, obs):
        """Return the ``Normal`` action distribution for features ``obs``.

        Args:
            obs: Tensor whose last dimension is ``hidden_dim``.

        Returns:
            ``torch.distributions.Normal`` with mean ``mu(obs)`` and standard
            deviation ``exp(log_std(obs))``.
        """
        mu, log_std = self.mu(obs), self.log_std(obs)
        std = log_std.exp()
        return Normal(mu, std)

    # NOTE(review): a variant of ``forward`` that squashes ``log_std`` with
    # tanh into ``log_std_bounds`` and returns a ``SquashedNormal`` was
    # previously sketched here. ``SquashedNormal`` is not imported in the
    # visible part of this file, so the unbounded version above is what runs.
    
class DecisionTransformer(TrajectoryModel):
    """GPT-based model of (Return_1, state_1, action_1, Return_2, state_2, ...).

    Each timestep contributes three tokens (return-to-go, state, action) that
    are embedded separately, offset by a learned timestep embedding,
    interleaved and passed through a GPT-2 backbone. Actions are predicted as
    a diagonal Gaussian from the state tokens; next state and return are
    predicted from the concatenated state and action tokens.

    Args:
        state_dim: Size of a single state vector.
        act_dim: Size of a single action vector.
        act_max: Upper clamp for actions returned by ``get_action``. Either a
            scalar, or an indexable whose first element is used.
        hidden_size: Embedding width of every token (``n_embd`` of GPT-2).
        max_length: Context length used by ``get_action`` for truncation and
            padding; ``None`` disables both.
        max_ep_len: Number of entries in the timestep embedding table, so
            timestep indices must be smaller than this.
        action_tanh: Accepted for interface compatibility; not used by the
            current Gaussian action head.
        vent: Stored on the instance; does not change the architecture here.
        **kwargs: Forwarded to ``transformers.GPT2Config``.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            act_max,
            hidden_size,
            max_length=None,
            max_ep_len=4096,
            action_tanh=True,
            vent=False,
            **kwargs
    ):
        # The base class is expected to set self.state_dim, self.act_dim and
        # self.max_length, which are read below -- TODO confirm in TrajectoryModel.
        super().__init__(state_dim, act_dim, max_length=max_length)

        self.hidden_size = hidden_size
        config = transformers.GPT2Config(
            vocab_size=1,  # doesn't matter -- we don't use the vocab
            n_embd=hidden_size,
            **kwargs
        )
        # Accept any scalar maximum (int, float, NumPy scalar), not only int;
        # anything else is treated as a per-dimension sequence whose first
        # entry is used.
        if np.isscalar(act_max):
            self.act_max = act_max
        else:
            self.act_max = act_max[0]

        # note: the only difference between this GPT2Model and the default Huggingface version
        # is that the positional embeddings are removed (since we'll add those ourselves)
        self.transformer = GPT2Model(config)
        # Embed every input modality as a vector of size hidden_size.
        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_return = torch.nn.Linear(1, hidden_size)  # linear layer
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)

        # State and return heads read the concatenated (state, action) tokens,
        # hence the 2 * hidden_size input width.
        self.predict_state = torch.nn.Linear(2*hidden_size, self.state_dim)
        self.vent = vent
        # NOTE: a deterministic head for ``vent`` (Linear + optional Tanh with
        # a cross-entropy loss) was previously sketched here; the Gaussian
        # head below is used unconditionally.
        self.action_head = DiagGaussianActor(hidden_size, self.act_dim)
        self.predict_return = torch.nn.Linear(2*hidden_size, 1)
        init_temperature = 0.1
        # NOTE(review): this is a plain tensor with requires_grad=True, not an
        # nn.Parameter, so it is not part of parameters()/state_dict() and is
        # not moved by .to(device). Callers must optimise and move it
        # explicitly -- confirm this is intended.
        self.log_temperature = torch.tensor(np.log(init_temperature))
        self.log_temperature.requires_grad = True
        self.target_entropy = -act_dim

        # Applied to every submodule, so it also re-initialises the Linear
        # layers inside the action head after their own orthogonal init.
        self.apply(self._init_weights)

    def temperature(self):
        """Return the temperature, i.e. ``exp(log_temperature)``."""
        return self.log_temperature.exp()

    @staticmethod
    def _init_weights(module: nn.Module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    @staticmethod
    def _left_pad(seq, length, dtype):
        """Left-pad ``seq`` with zeros along dim 1 up to ``length``, then cast to ``dtype``."""
        pad_shape = (seq.shape[0], length - seq.shape[1]) + tuple(seq.shape[2:])
        padding = torch.zeros(pad_shape, device=seq.device)
        return torch.cat([padding, seq], dim=1).to(dtype=dtype)

    def forward(self, states, actions, rewards, returns_to_go, timesteps, attention_mask=None):
        """Run the transformer over a batch of trajectories.

        Args:
            states: Tensor of shape (batch, seq, state_dim).
            actions: Tensor of shape (batch, seq, act_dim).
            rewards: Unused by this model.
            returns_to_go: Tensor of shape (batch, seq, 1).
            timesteps: Integer tensor of shape (batch, seq) indexing the
                timestep embedding table.
            attention_mask: Optional (batch, seq) tensor, 1 where a timestep
                may be attended to and 0 where not. Defaults to all ones.

        Returns:
            Tuple ``(state_preds, action_preds, return_preds)`` where
            ``state_preds`` is (batch, seq, state_dim), ``action_preds`` is a
            ``Normal`` distribution over (batch, seq, act_dim) and
            ``return_preds`` is (batch, seq, 1).
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        if attention_mask is None:
            # attention mask for GPT: 1 if can be attended to, 0 if not.
            # Built on the inputs' device so GPU batches do not hit a
            # CPU/GPU mismatch inside the transformer.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=states.device
            )

        # embed each modality with a different head
        state_embeddings = self.embed_state(states)
        action_embeddings = self.embed_action(actions)
        returns_embeddings = self.embed_return(returns_to_go)
        time_embeddings = self.embed_timestep(timesteps)

        # time embeddings are treated similar to positional embeddings: every
        # token of a timestep gets that timestep's embedding added
        state_embeddings = state_embeddings + time_embeddings
        action_embeddings = action_embeddings + time_embeddings
        returns_embeddings = returns_embeddings + time_embeddings

        # this makes the sequence look like (R_1, s_1, a_1, R_2, s_2, a_2, ...)
        # which works nice in an autoregressive sense since states predict actions
        stacked_inputs = torch.stack(
            (returns_embeddings, state_embeddings, action_embeddings), dim=1
        ).permute(0, 2, 1, 3).reshape(batch_size, 3*seq_length, self.hidden_size)
        stacked_inputs = self.embed_ln(stacked_inputs)

        # to make the attention mask fit the stacked inputs, have to stack it as well:
        # each timestep's mask value is repeated for its three tokens
        stacked_attention_mask = torch.stack(
            (attention_mask, attention_mask, attention_mask), dim=1
        ).permute(0, 2, 1).reshape(batch_size, 3*seq_length)

        # we feed in the input embeddings (not word indices as in NLP) to the model
        transformer_outputs = self.transformer(
            inputs_embeds=stacked_inputs,
            attention_mask=stacked_attention_mask,
        )
        x = transformer_outputs['last_hidden_state']

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        state_output = x[:, 1]
        action_output = x[:, 2]

        # Return and state heads share the same input: the state token and the
        # action token of each timestep concatenated along the feature axis.
        state_action_output = torch.cat((state_output, action_output), dim=2)
        return_preds = self.predict_return(state_action_output)
        state_preds = self.predict_state(state_action_output)

        # The action distribution is predicted from the state token alone.
        action_preds = self.action_head(state_output)

        return state_preds, action_preds, return_preds

    def get_action(self, states, actions, rewards, returns_to_go, timesteps, **kwargs):
        """Predict for a single trajectory and score the given actions.

        The inputs are reshaped to a batch of one. When ``max_length`` is set
        they are truncated to the last ``max_length`` steps and left-padded
        with zeros, with the padding masked out.

        Args:
            states: Tensor reshapeable to (1, T, state_dim).
            actions: Tensor reshapeable to (1, T, act_dim).
            rewards: Unused; past rewards are not an input of this model.
            returns_to_go: Tensor reshapeable to (1, T, 1).
            timesteps: Tensor reshapeable to (1, T).
            **kwargs: Forwarded to ``forward``.

        Returns:
            Tuple ``(state_preds, action_preds, return_preds, act_loss)``.
            ``action_preds`` is the mean of the predicted distribution,
            averaged over the batch dimension (kept, size 1) and clamped to
            ``[0, act_max]``. ``act_loss`` is
            ``-(log_likelihood + entropy_reg * entropy)`` over unmasked steps.
        """
        # we don't care about the past rewards in this model

        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        returns_to_go = returns_to_go.reshape(1, -1, 1)
        timesteps = timesteps.reshape(1, -1)

        if self.max_length is not None:
            states = states[:, -self.max_length:]
            actions = actions[:, -self.max_length:]
            returns_to_go = returns_to_go[:, -self.max_length:]
            timesteps = timesteps[:, -self.max_length:]

            # pad all tokens to sequence length; zeros mark the left padding
            valid_len = states.shape[1]
            attention_mask = torch.cat(
                [torch.zeros(self.max_length - valid_len), torch.ones(valid_len)]
            )
            attention_mask = attention_mask.to(dtype=torch.long, device=states.device).reshape(1, -1)
            states = self._left_pad(states, self.max_length, torch.float32)
            actions = self._left_pad(actions, self.max_length, torch.float32)
            returns_to_go = self._left_pad(returns_to_go, self.max_length, torch.float32)
            timesteps = self._left_pad(timesteps, self.max_length, torch.long)
        else:
            # No padding, so every step is valid. An explicit all-ones mask is
            # required because the loss below indexes with ``attention_mask > 0``,
            # which would fail on None.
            attention_mask = torch.ones(
                (1, states.shape[1]), dtype=torch.long, device=states.device
            )

        state_preds, action_preds, return_preds = self.forward(
            states, actions, None, returns_to_go, timesteps, attention_mask=attention_mask, **kwargs)

        valid = attention_mask > 0
        log_likelihood = action_preds.log_prob(actions)[valid].mean()
        entropy = action_preds.entropy()[valid].mean()
        # NOTE(review): with a negative coefficient the loss grows with
        # entropy, i.e. entropy is penalised -- confirm the sign is intended.
        entropy_reg = -2  # lambda

        act_loss = -(log_likelihood + entropy_reg * entropy)

        # Use the distribution mean as the point prediction: [1, seq_len, act_dim]
        action_preds = action_preds.mean
        action_preds = torch.mean(action_preds, dim=0, keepdim=True)
        action_preds = action_preds.clamp(0, self.act_max)

        return state_preds, action_preds, return_preds, act_loss
