import torch
import torch.nn
import torch.nn.functional as F
from IPython import embed
import numpy as np
import torch.nn as nn
import transformers
from transformers import BertConfig, BertModel 
from transformers import DistilBertConfig, DistilBertModel, GPT2Config, GPT2Model
import matplotlib.pyplot as plt
import time
import argparse
import torch.nn.functional as F
import time
import os
import pickle
from dataset import TrajDataset

# Run models on the default CUDA device when one is available, otherwise on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class Net(torch.nn.Module):
    """MLP baseline that predicts an action from a flattened roll-in context.

    Timestep 0 of the context is ``(state, 0, 0, 0)``, with the padding taken
    from ``x['zeros']``. The following timesteps are the roll-in transitions
    ``(rollin_x, rollin_u, rollin_xp, rollin_r)``. All timesteps are flattened
    into one vector, optionally extended with the flattened ``x['Qs']``
    matrix, and passed through a 4-layer ReLU MLP with ``du`` outputs.

    ``config`` must provide ``H``, ``n_embd``, ``n_layer``, ``n_head``, ``dx``,
    ``du`` and ``Q``. ``n_layer`` and ``n_head`` are stored but not used here.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.H = config['H']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.dx = config['dx']
        self.du = config['du']
        self.useQ = config['Q']

        # One timestep = state + action + next state + scalar reward.
        token_width = 2 * self.dx + self.du + 1
        input_dim = token_width * (self.H + 1)
        if self.useQ:
            # The flattened dx-by-dx Q matrix is appended to the context.
            input_dim += self.dx ** 2

        # Creation order matters: it fixes parameter names and seeded init.
        self.embed_state_action = torch.nn.Linear(input_dim, self.n_embd)
        self.ln1 = torch.nn.Linear(self.n_embd, self.n_embd)
        self.ln2 = torch.nn.Linear(self.n_embd, self.n_embd)
        self.ln3 = torch.nn.Linear(self.n_embd, self.du)

    def forward(self, x):
        """Return a ``(batch, du)`` action prediction for the batch dict ``x``."""
        pad = x['zeros'][:, None, :]
        columns = [
            torch.cat([x['states'][:, None, :], x['rollin_xs']], dim=1),
            torch.cat([pad[:, :, :self.du], x['rollin_us']], dim=1),
            torch.cat([pad[:, :, :self.dx], x['rollin_xps']], dim=1),
            torch.cat([pad[:, :, :1], x['rollin_rs']], dim=1),
        ]
        seq = torch.cat(columns, dim=-1)
        n_batch, n_steps = seq.shape[0], seq.shape[1]

        if self.useQ:
            # In Q mode only the current state is kept; the roll-in is blanked.
            seq[:, 1:] = 0.0

        token_width = 2 * self.dx + self.du + 1
        hidden = seq.reshape(n_batch, token_width * n_steps)
        if self.useQ:
            q_flat = x['Qs'].reshape((n_batch, self.dx ** 2))
            hidden = torch.cat((hidden, q_flat), dim=-1)

        for layer in (self.embed_state_action, self.ln1, self.ln2):
            hidden = F.relu(layer(hidden))
        return self.ln3(hidden)



class Transformer(torch.nn.Module):
    """GPT-2 policy that reads the roll-in context as per-timestep tokens.

    Timestep 0 holds the current state; its action, next-state and reward
    slots are filled from ``x['zeros']``. The following timesteps hold the
    roll-in transitions.

    * Without ``Q``, each timestep contributes four tokens
      ``(state, action, next_state, reward)``.
    * With ``Q``, each timestep contributes two tokens ``(state, Q)``. The
      first Q-token embeds the flattened ``x['Qs']`` and the remaining ones
      embed ``x['zerosQ']``.

    Tokens are interleaved timestep by timestep, passed through a LayerNorm
    and fed to GPT-2. Actions are decoded from the hidden state of the last
    token of every timestep.

    ``config`` must provide ``H``, ``n_embd``, ``n_layer``, ``n_head``, ``dx``,
    ``du``, ``Q`` and ``dropout``. ``forward`` additionally reads ``full``.
    NOTE(review): ``n_head`` is stored, but the backbone is built with
    ``n_head=1``; confirm that this is intentional.
    """

    def __init__(self, config):
        super(Transformer, self).__init__()

        self.config = config
        self.H = self.config['H']
        self.n_embd = self.config['n_embd']
        self.n_layer = self.config['n_layer']
        self.n_head = self.config['n_head']
        self.dx = self.config['dx']
        self.du = self.config['du']
        self.useQ = self.config['Q']
        self.dropout = self.config['dropout']

        gpt2_config = GPT2Config(
            n_positions=4 * (1 + self.H),  # room for 4 tokens per timestep
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt2_config)

        # Creation order is kept stable so parameter names and seeded init do not change.
        self.embed_state = torch.nn.Linear(self.dx, self.n_embd)
        self.embed_action = torch.nn.Linear(self.du, self.n_embd)
        self.embed_state_prime = torch.nn.Linear(self.dx, self.n_embd)
        self.embed_reward = torch.nn.Linear(1, self.n_embd)
        if self.useQ:
            self.embed_Q = torch.nn.Linear(self.dx**2, self.n_embd)

        self.embed_ln = nn.LayerNorm(self.n_embd)
        self.pred_actions = nn.Linear(self.n_embd, self.du)

    def _context_sequences(self, x):
        """Build the state, action, next-state and reward sequences.

        Each returned tensor has the current step prepended at index 0 of
        dim 1, followed by the corresponding ``rollin_*`` entries.
        """
        states = x['states'][:, None, :]
        zeros = x['zeros'][:, None, :]

        x_seq = torch.cat([states, x['rollin_xs']], dim=1)
        u_seq = torch.cat([zeros[:, :, :self.du], x['rollin_us']], dim=1)
        xp_seq = torch.cat([zeros[:, :, :self.dx], x['rollin_xps']], dim=1)
        r_seq = torch.cat([zeros[:, :, :1], x['rollin_rs']], dim=1)
        return x_seq, u_seq, xp_seq, r_seq

    def forward(self, x):
        """Predict actions from the batch dict ``x``.

        Returns the predictions for every timestep except the first when
        ``config['full']`` is truthy, otherwise only the prediction for the
        final timestep.
        """
        x_seq, u_seq, xp_seq, r_seq = self._context_sequences(x)
        batch_size, seq_length = x_seq.shape[0], x_seq.shape[1]

        state_embeds = self.embed_state(x_seq)
        if self.useQ:
            Q = x['Qs'].reshape((batch_size, 1, self.dx**2))
            Q_seq = torch.cat([Q, x['zerosQ']], dim=1)
            # Only state and Q tokens are used in this mode, so the action,
            # next-state and reward projections are skipped instead of being
            # computed and discarded.
            token_embeds = (state_embeds, self.embed_Q(Q_seq))
        else:
            token_embeds = (
                state_embeds,
                self.embed_action(u_seq),
                self.embed_state_prime(xp_seq),
                self.embed_reward(r_seq),
            )
        n_tokens = len(token_embeds)

        # (B, n_tokens, T, E) -> (B, T, n_tokens, E) -> (B, T * n_tokens, E),
        # so that the tokens of one timestep are adjacent in the sequence.
        stacked_inputs = (
            torch.stack(token_embeds, dim=1)
            .permute(0, 2, 1, 3)
            .reshape(batch_size, n_tokens * seq_length, self.n_embd)
        )
        stacked_inputs = self.embed_ln(stacked_inputs)

        transformer_outputs = self.transformer(inputs_embeds=stacked_inputs)
        hidden = transformer_outputs['last_hidden_state']
        # Undo the interleaving: (B, T * n_tokens, E) -> (B, n_tokens, T, E).
        hidden = hidden.reshape(batch_size, seq_length, n_tokens, self.n_embd).permute(0, 2, 1, 3)

        # Decode from the last token kind of each timestep (reward, or Q in Q mode).
        preds = self.pred_actions(hidden[:, -1, :])
        if self.config['full']:
            return preds[:, 1:, :]
        return preds[:, -1, :]





class TransformerTall(torch.nn.Module):
    """GPT-2 policy that embeds each whole timestep as a single token.

    Timestep 0 is ``(state, 0, 0, 0)``, with the padding taken from
    ``x['zeros']``. The following timesteps are the roll-in transitions
    ``(rollin_x, rollin_u, rollin_xp, rollin_r)``. Each concatenated
    transition is linearly embedded, the token sequence runs through GPT-2,
    and every position's hidden state is decoded into a ``du``-dim action.

    ``config`` must provide ``H``, ``n_embd``, ``n_layer``, ``n_head``, ``dx``,
    ``du``, ``Q`` and ``dropout``. ``forward`` additionally reads ``full``.

    NOTE(review):
    * ``n_head`` and ``Q`` are stored but unused.
    * The backbone is built with ``n_head=1``.
    * ``n_positions`` is sized for ``4 * (1 + H)`` tokens, although this
      model feeds one token per timestep.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.H = config['H']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.dx = config['dx']
        self.du = config['du']
        self.useQ = config['Q']
        self.dropout = config['dropout']

        backbone_config = GPT2Config(
            n_positions=4 * (1 + self.H),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(backbone_config)

        # One token = state + action + next state + scalar reward.
        transition_dim = 2 * self.dx + self.du + 1
        self.embed_transition = torch.nn.Linear(transition_dim, self.n_embd)

        # Not used by forward(); kept so parameter names and checkpoints stay unchanged.
        self.embed_ln = nn.LayerNorm(self.n_embd)
        self.pred_actions = nn.Linear(self.n_embd, self.du)

    def forward(self, x):
        """Predict actions from the batch dict ``x``.

        Returns the predictions for every timestep except the first when
        ``config['full']`` is truthy, otherwise only the prediction for the
        final timestep.
        """
        pad = x['zeros'][:, None, :]
        transitions = torch.cat(
            [
                torch.cat([x['states'][:, None, :], x['rollin_xs']], dim=1),
                torch.cat([pad[:, :, :self.du], x['rollin_us']], dim=1),
                torch.cat([pad[:, :, :self.dx], x['rollin_xps']], dim=1),
                torch.cat([pad[:, :, :1], x['rollin_rs']], dim=1),
            ],
            dim=2,
        )

        outputs = self.transformer(inputs_embeds=self.embed_transition(transitions))
        preds = self.pred_actions(outputs['last_hidden_state'])
        return preds[:, 1:, :] if self.config['full'] else preds[:, -1, :]



class TransformerBERT(torch.nn.Module):
    """BERT-encoder policy that embeds each whole timestep as a single token.

    Timestep 0 is ``(state, 0, 0, 0)``, with the padding taken from
    ``x['zeros']``. The following timesteps are the roll-in transitions
    ``(rollin_x, rollin_u, rollin_xp, rollin_r)``. Each concatenated
    transition is linearly embedded and the sequence is fed to a BERT
    encoder (no attention mask is passed). Every position's hidden state is
    decoded into a ``du``-dim action.

    ``config`` must provide ``H``, ``n_embd``, ``n_layer``, ``n_head``, ``dx``,
    ``du``, ``Q`` and ``dropout``. ``forward`` additionally reads ``full``.
    The encoder's position table has ``1 + H`` entries.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.H = config['H']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.dx = config['dx']
        self.du = config['du']
        self.useQ = config['Q']
        self.dropout = config['dropout']

        encoder_config = BertConfig(
            max_position_embeddings=(1 + self.H),
            hidden_size=self.n_embd,
            num_hidden_layers=self.n_layer,
            num_attention_heads=self.n_head,
            intermediate_size=self.n_embd * 4,
            hidden_dropout_prob=self.dropout,
            attention_probs_dropout_prob=self.dropout,
            use_cache=False,
        )
        self.transformer = BertModel(encoder_config)

        # One token = state + action + next state + scalar reward.
        transition_dim = 2 * self.dx + self.du + 1
        self.embed_transition = torch.nn.Linear(transition_dim, self.n_embd)

        # Not used by forward(); kept so parameter names and checkpoints stay unchanged.
        self.embed_ln = nn.LayerNorm(self.n_embd)
        self.pred_actions = nn.Linear(self.n_embd, self.du)

    def forward(self, x):
        """Predict actions from the batch dict ``x``.

        Returns the predictions for every timestep except the first when
        ``config['full']`` is truthy, otherwise only the prediction for the
        final timestep.
        """
        pad = x['zeros'][:, None, :]
        transitions = torch.cat(
            [
                torch.cat([x['states'][:, None, :], x['rollin_xs']], dim=1),
                torch.cat([pad[:, :, :self.du], x['rollin_us']], dim=1),
                torch.cat([pad[:, :, :self.dx], x['rollin_xps']], dim=1),
                torch.cat([pad[:, :, :1], x['rollin_rs']], dim=1),
            ],
            dim=2,
        )

        outputs = self.transformer(inputs_embeds=self.embed_transition(transitions))
        preds = self.pred_actions(outputs['last_hidden_state'])
        return preds[:, 1:, :] if self.config['full'] else preds[:, -1, :]

if __name__ == '__main__':
    # Smoke test: load the training split, take one batch and run a forward
    # pass through every model defined in this file.
    dataset_config = {}
    n_envs = 1000
    n_hists = 1
    n_samples = 1
    H = 10
    dim = 1
    path_train = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_train.pkl'
    path_test = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_test.pkl'
    ds = TrajDataset(path_train, dataset_config)

    config = {
        'H': H,
        'dx': dim,
        'du': dim,
        'n_layer': 3,
        'n_embd': 32,
        'n_head': 1,
        'Q': False,
        # Read by the transformer constructors; without it they raise KeyError.
        'dropout': 0.0,
        # Read by the transformer forward passes; False returns only the
        # final-step prediction.
        'full': False,
    }

    batch = ds[:64]
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for model_cls in (Net, Transformer, TransformerTall, TransformerBERT):
            model = model_cls(config).to(device)
            preds = model(batch)
            print(f'{model_cls.__name__}: output shape {tuple(preds.shape)}')