 
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import distributions as pyd
import einops
from einops.layers.torch import Rearrange
import pdb
import transformers
import diffuser.utils as utils
from .helpers import (
    DiagGaussian,
    Gaussian)

from .gpt import GPT2Model

class TrajectoryModel(nn.Module):
    """Minimal trajectory-model interface meant to be overridden.

    Holds the state/action dimensions and an optional maximum length; the
    default ``forward`` and ``get_predictions`` are placeholders.
    """

    def __init__(self, state_dim, act_dim, max_length=None):
        super().__init__()
        self.state_dim, self.act_dim, self.max_length = state_dim, act_dim, max_length

    def forward(self, states, actions, rewards, masks=None, attention_mask=None):
        """Placeholder forward pass that yields three ``None`` predictions.

        Masked or unspecified inputs may be passed as ``None``.
        """
        return (None,) * 3

    def get_predictions(self, states, actions, rewards, **kwargs):
        """Return zero tensors shaped like the last element of each input.

        The inputs are expected to already be tensors on the target device.
        """
        return tuple(torch.zeros_like(seq[-1]) for seq in (states, actions, rewards))
    
class FutureTransformer(TrajectoryModel):
    """Anti-causal aggregator: a second transformer that encodes a trajectory into ``z``.

    A GPT-2 style transformer (``GPT2Model``) reads state/action embeddings
    (plus time embeddings), and ``predict_z`` (a ``DiagGaussian`` head) maps
    its hidden states to a distribution over the latent ``z``.

    Args:
        state_dim: Observation dimension, stored by ``TrajectoryModel``.
        act_dim: Action dimension, stored by ``TrajectoryModel``.
        hidden_size: Width of the input embeddings and of the transformer.
        z_dim: Size of the latent produced by ``predict_z``.
        ordering: Stored as ``self.ordering``; not read within this class.
        horizon: Stored as ``self.horizon``; not read within this class.
        max_length: Stored as ``self.max_length`` (by the base class).
        **kwargs: Extra ``transformers.GPT2Config`` options (e.g. ``n_layer``).
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            hidden_size,
            z_dim,
            ordering=1,
            horizon=100,
            max_length=1000,
            **kwargs
    ):
        # The base class already stores state_dim, act_dim and max_length.
        super().__init__(state_dim, act_dim=act_dim, max_length=max_length)

        self.hidden_size = hidden_size
        config = transformers.GPT2Config(
            vocab_size=1,  # unused: the model is fed embeddings, not token ids
            n_embd=hidden_size,
            n_head=4,
            **kwargs
        )
        self.transformer = GPT2Model(config)
        self.ordering = ordering
        self.horizon = horizon

        self.embed_ln = nn.LayerNorm(hidden_size)
        self.predict_z = DiagGaussian(hidden_size, z_dim)

    def forward(self, state_embeddings, action_embeddings, time_embeddings, attention_mask=None, token_mode=1):
        """Encode a trajectory into the output of ``predict_z``.

        Args:
            state_embeddings: Tensor whose first two dims are read as
                (batch, seq); the last dim is assumed to be ``hidden_size``.
            action_embeddings: Same shape as ``state_embeddings``.
            time_embeddings: Added to the token embeddings like positional
                embeddings; must broadcast against them.
            attention_mask: (batch, seq) mask, 1 = attend, 0 = ignore.
                Defaults to all ones on the inputs' device.
            token_mode: 1 = interleaved (s_0, a_0, s_1, a_1, ...) tokens,
                2 = state tokens only, any other value = action tokens only.

        Returns:
            ``self.predict_z`` applied to one hidden state per step
            (presumably a diagonal Gaussian over ``z`` — see ``DiagGaussian``).
        """
        batch_size, seq_length = state_embeddings.shape[0], state_embeddings.shape[1]
        if attention_mask is None:
            # Create the default mask on the inputs' device; a CPU mask would
            # clash with CUDA inputs inside the transformer.
            attention_mask = torch.ones(
                (batch_size, seq_length), dtype=torch.long, device=state_embeddings.device
            )

        if token_mode == 1:
            # Time embeddings act like positional embeddings for both streams.
            state_embeddings = state_embeddings + time_embeddings
            action_embeddings = action_embeddings + time_embeddings

            # Interleave into (s_0, a_0, s_1, a_1, ...): (batch, 2 * seq, hidden).
            stacked_inputs = (
                torch.stack((state_embeddings, action_embeddings), dim=1)
                .permute(0, 2, 1, 3)
                .reshape(batch_size, 2 * seq_length, self.hidden_size)
            )
            stacked_inputs = self.embed_ln(stacked_inputs)

            # Each step's mask value applies to both its state and action token.
            stacked_attention_mask = (
                torch.stack((attention_mask, attention_mask), dim=1)
                .permute(0, 2, 1)
                .reshape(batch_size, 2 * seq_length)
            )

            # NOTE(review): the transformer runs under no_grad, so no gradient
            # reaches embed_ln or the input embeddings through this path; only
            # predict_z is trained by losses on the output. Confirm intended.
            with torch.no_grad():
                x = self.transformer(
                    inputs_embeds=stacked_inputs,
                    attention_mask=stacked_attention_mask,
                )['last_hidden_state']

            # (batch, 2 * seq, hidden) -> (batch, 2, seq, hidden): x[:, 0, t] is
            # the output at s_t's token, x[:, 1, t] the output at a_t's token.
            x = x.reshape(batch_size, seq_length, 2, self.hidden_size).permute(0, 2, 1, 3)
            return self.predict_z(x[:, 1])

        # Single-stream modes: 2 = state tokens only, anything else = action tokens.
        # NOTE(review): unlike token_mode 1, embed_ln is not applied here.
        tokens = state_embeddings if token_mode == 2 else action_embeddings
        embeddings = tokens + time_embeddings
        with torch.no_grad():
            x = self.transformer(
                inputs_embeds=embeddings,
                attention_mask=attention_mask,
            )['last_hidden_state']
        return self.predict_z(x)

    
class FutureDiffusion(nn.Module):
    """Diffusion model conditioned on a latent ``z`` that summarizes the future.

    In ``loss``, a posterior encoder reads future (or partly future) trajectory
    embeddings and a sample ``z`` is drawn from it. A prior network, fed only
    the current states, is trained toward that posterior, and the posterior is
    regularized toward a fixed zero-mean prior. ``z`` is appended to every step
    of the trajectory and its first step to the step-0 condition before calling
    ``diffusion_model``. ``forward`` follows the same path but samples ``z``
    from the prior.

    Trajectories are laid out as ``[actions | states]`` along the last axis
    (``x[..., :action_dim]`` are actions).
    """

    def __init__(
            self,
            prior_model,
            posterior_model,
            diffusion_model,
            observation_dim,
            action_dim,
            hidden_dim,
            z_dim,
            z_reg,
            horizon,
            max_length,
            future_mode=1,
            token_mode=1,
            cond_z=0,
    ):
        super().__init__()

        self.prior_model = prior_model
        self.posterior_model = posterior_model
        self.diffusion_model = diffusion_model

        self.state_dim = observation_dim
        self.action_dim = action_dim
        self.hidden_dim = hidden_dim
        self.z_dim = z_dim
        self.z_reg = z_reg  # weight of the KL(posterior || fixed prior) term

        self.horizon = horizon
        self.max_length = max_length

        self.embed_state = torch.nn.Linear(self.state_dim, self.hidden_dim)
        # Timesteps index this table, so they must lie in [0, max_length).
        self.embed_timestep = nn.Embedding(max_length, self.hidden_dim)
        self.embed_action = torch.nn.Linear(self.action_dim, self.hidden_dim)

        self.future_mode = future_mode  # 1: future window only; 2/3/4: mixed window
        self.token_mode = token_mode    # forwarded to posterior_model
        self.cond_z = cond_z            # non-zero: prior also sees `returns`

        # Freeze the posterior's transformer; its other parameters stay trainable.
        for param in self.posterior_model.transformer.parameters():
            param.requires_grad = False

    def _embed(self, trajectories, timesteps):
        """Embed ``[actions | states]`` trajectories and integer timesteps."""
        actions = trajectories[:, :, :self.action_dim]
        states = trajectories[:, :, self.action_dim:]
        return self.embed_state(states), self.embed_action(actions), self.embed_timestep(timesteps)

    def _prior_distribution(self, state_embeddings, returns=None):
        """Return the prior over ``z`` given per-step state embeddings.

        With ``cond_z != 0`` the broadcast return is appended as one extra
        feature (before the ReLU), so ``prior_model`` sees ``hidden_dim + 1``
        input features.
        """
        if self.cond_z == 0:
            prior_input = state_embeddings
        else:
            if returns is None:
                raise ValueError("returns must be provided when cond_z != 0")
            batch_size, seq_length = state_embeddings.shape[0], state_embeddings.shape[1]
            returns_feature = returns.expand(batch_size, seq_length).reshape(batch_size, seq_length, 1)
            prior_input = torch.cat((state_embeddings, returns_feature), dim=-1)
        return self.prior_model(F.relu(prior_input))

    def _posterior_distribution(self, current, future):
        """Return the posterior over ``z``.

        ``current`` and ``future`` are ``(state_emb, action_emb, time_emb, mask)``
        tuples. ``future_mode == 1`` encodes the future window alone. Modes
        2/3/4 encode a window that starts ``seq // 2`` / ``seq - 10`` /
        ``seq - 3`` steps into the current window and is completed with the
        first steps of the future window.
        """
        if self.future_mode == 1:
            return self.posterior_model(*future, self.token_mode)

        seq_length = current[0].shape[1]
        if self.future_mode == 2:
            split = seq_length // 2
        elif self.future_mode == 3:
            split = seq_length - 10
        else:  # future_mode == 4
            split = seq_length - 3
        # Short windows: a negative split would slice from the end and
        # misalign the window, so fall back to the whole current window.
        split = max(split, 0)
        window = [
            torch.cat((cur[:, split:], fut[:, :split]), dim=1)
            for cur, fut in zip(current, future)
        ]
        return self.posterior_model(*window, self.token_mode)

    def _condition_on_z(self, x, cond, z):
        """Append ``z`` to every step of ``x`` and ``z``'s first step to ``cond[0]``."""
        x_new = torch.cat((x, z), dim=-1)
        cond_new = {0: torch.cat((cond[0], z[:, 0, :].view(-1, self.z_dim)), dim=-1)}
        return x_new, cond_new

    def forward(self, x, cond, returns=None):
        """Run ``diffusion_model`` on ``x`` with ``z`` sampled from the prior.

        Mirrors ``loss``:
        - the prior is evaluated on the embedded states of ``x``;
        - a sample ``z`` is appended to every step of ``x``;
        - ``z``'s first step is appended to ``cond[0]``.
        This matches the inputs the diffusion model saw during training.
        """
        states = x[:, :, self.action_dim:]
        state_embeddings = self.embed_state(states)
        z_prior = self._prior_distribution(state_embeddings, returns).sample()
        x_new, cond_new = self._condition_on_z(x, cond, z_prior)
        return self.diffusion_model(x_new, cond_new, returns)

    def loss(self, x, cond, returns=None, timesteps=None, padding_mask=None, ftrajectories=None, ftimesteps=None, fpadding_mask=None):
        """Compute the combined future-latent and diffusion loss.

        Args:
            x: Current trajectory window, ``[actions | states]`` on the last
                axis, indexed as (batch, seq, ...).
            cond: Dict whose entry ``0`` is the step-0 condition; ``z`` at step
                0 is appended to it.
            returns: Passed to ``diffusion_model.loss``; also fed to the prior
                when ``cond_z != 0``.
            timesteps: Integer timesteps of ``x`` (cast to long).
            padding_mask: Step mask of ``x``. Padded steps (mask 0) get
                ``z = 0``, and only steps with mask > 0 enter the KL terms.
            ftrajectories, ftimesteps, fpadding_mask: The same for the future
                window.

        Returns:
            ``(loss, info)``, where ``info`` is the diffusion model's info dict
            plus ``'future_loss'`` and ``'total_loss'``.

        Raises:
            ValueError: If any timestep, mask or future argument is missing.
        """
        required = {
            'timesteps': timesteps,
            'padding_mask': padding_mask,
            'ftrajectories': ftrajectories,
            'ftimesteps': ftimesteps,
            'fpadding_mask': fpadding_mask,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(f"loss() requires: {', '.join(missing)}")

        timesteps = timesteps.to(torch.long)
        ftimesteps = ftimesteps.to(torch.long)
        padding_mask = padding_mask.to(torch.long)
        fpadding_mask = fpadding_mask.to(torch.long)

        state_embeddings, action_embeddings, time_embeddings = self._embed(x, timesteps)
        fstate_embeddings, faction_embeddings, ftime_embeddings = self._embed(ftrajectories, ftimesteps)

        z_prior = self._prior_distribution(state_embeddings, returns)
        # Zero mean and zero log-sigma: a standard normal, assuming log_sigma
        # is the log standard deviation (see Gaussian in .helpers).
        z_fixed_prior = Gaussian(torch.zeros_like(z_prior.mu), torch.zeros_like(z_prior.log_sigma))

        z_posterior = self._posterior_distribution(
            (state_embeddings, action_embeddings, time_embeddings, padding_mask),
            (fstate_embeddings, faction_embeddings, ftime_embeddings, fpadding_mask),
        )

        # Zero z at padded steps (the mask broadcasts over the z dimension).
        z_random = z_posterior.sample() * padding_mask.unsqueeze(-1)

        valid = padding_mask > 0
        reg_kl = z_posterior.kl_divergence(z_fixed_prior)[valid].mean()
        # The posterior is detached so this term only trains the prior.
        prior_kl = z_posterior.detach().kl_divergence(z_prior)[valid].mean()
        loss_future = self.z_reg * reg_kl + prior_kl

        x_new, cond_new = self._condition_on_z(x, cond, z_random)
        loss_diff, info = self.diffusion_model.loss(x_new, cond_new, returns)

        loss = (loss_future + 2 * loss_diff) / 3

        info['future_loss'] = loss_future
        info['total_loss'] = loss
        return loss, info