import os
from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import transformers
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

import numpy as np
from torch.nn import CrossEntropyLoss, MSELoss
from torch.distributions import Normal, Independent
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import TanhTransform
from transformers.activations import ACT2FN
from transformers.modeling_utils import (
    Conv1D,
    PreTrainedModel,
)
from transformers.utils import logging
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

# Module-level logger, created through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)

class Convolution(nn.Module):
    """Token mixer built from per-token-type depthwise 1D convolutions.

    The input sequence interleaves token types (return, observation and,
    unless ``config.remove_act_embs`` is set, action).  Each token type owns
    its own depthwise filter (``groups=hidden_size``); the output at a given
    position is taken from the filter that belongs to that position's type.

    Args:
        config: object exposing ``window_size`` (kernel length) and
            ``remove_act_embs`` (drop the action filter when truthy).
        hidden_size: number of channels of every convolution.
        block_index: accepted for interface compatibility; not used here.
    """

    def __init__(self, config, hidden_size, block_index):
        super().__init__()
        self.window_size = config.window_size
        self.remove_act_embs = config.remove_act_embs

        def depthwise():
            # One independent filter of length `window_size` per channel.
            return nn.Conv1d(
                in_channels=hidden_size,
                out_channels=hidden_size,
                kernel_size=self.window_size,
                groups=hidden_size,
            )

        # Creation order (rtg, obs, act) is kept so that seeded
        # initialisation produces the same parameters.
        self.rtg_conv1d = depthwise()
        self.obs_conv1d = depthwise()
        if not self.remove_act_embs:
            self.act_conv1d = depthwise()

    def forward(self, x):
        """Apply the type-specific filters and re-interleave their outputs.

        Args:
            x: tensor whose dim 1 is the interleaved sequence and dim 2 the
                channels (assumed ``(batch, seq, hidden)``).  The sequence
                length must be a multiple of the number of token types.

        Returns:
            Tensor with the same layout as ``x``.
        """
        filters = [self.rtg_conv1d, self.obs_conv1d]
        if not self.remove_act_embs:
            filters.append(self.act_conv1d)
        n_types = len(filters)

        # Left-pad the sequence axis so that each output position only sees
        # itself and the `window_size - 1` positions before it, then move
        # channels to dim 1 as Conv1d expects.
        left_padded = torch.nn.functional.pad(x, (0, 0, self.window_size - 1, 0))
        channels_first = left_padded.transpose(1, 2)

        # Filter number `offset` is responsible for positions offset, offset + n_types, ...
        per_type = [
            conv(channels_first)[:, :, offset::n_types]
            for offset, conv in enumerate(filters)
        ]

        # Stack on a trailing axis and flatten it back into the sequence axis
        # to restore the original interleaved ordering.
        interleaved = torch.stack(per_type, dim=3)
        merged = interleaved.reshape(interleaved.shape[0], interleaved.shape[1], -1)

        return merged.transpose(1, 2)


class Block(nn.Module):
    """One residual layer: convolutional token mixing followed by an MLP.

    Both sub-layers are applied in pre-norm residual form:
    ``h = h + conv(ln_1(h))`` and then ``h = h + mlp(ln_2(h))``.

    Args:
        config: object exposing ``n_embd``, ``n_inner``,
            ``layer_norm_epsilon`` and ``drop_p`` (plus whatever
            ``Convolution`` reads from it).
        index: position of this block in the stack; stored as ``self.index``
            and forwarded to ``Convolution``.
        scale: accepted for interface compatibility; not used.
    """

    def __init__(self, config, index, scale=False):
        super().__init__()
        hidden_size = config.n_embd
        # Width of the MLP's hidden layer.  An explicit `config.n_inner` is
        # honoured; when it is None the conventional 4x expansion is used.
        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.conv = Convolution(config, hidden_size, index)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)

        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, hidden_size),
            nn.Dropout(config.drop_p),
        )

        self.index = index

    def forward(
            self,
            hidden_states,
    ):
        """Run the block on ``hidden_states`` and return a tensor of the same shape."""
        # Token-mixing sub-layer with residual connection.
        hidden_states = self.conv(self.ln_1(hidden_states)) + hidden_states

        # Position-wise feed-forward sub-layer with residual connection.
        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))

        return hidden_states


class GPT2PreTrainedModel(PreTrainedModel):
    """
    Abstract base that supplies weight initialisation on top of
    ``PreTrainedModel`` and binds the model to ``GPT2Config``.
    """
    config_class = GPT2Config
    base_model_prefix = "transformer"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(self, module):
        """Initialise ``module`` in place according to its type.

        * ``nn.LayerNorm``: bias set to 0, weight set to 1.
        * ``nn.Linear`` / ``nn.Embedding``: weight drawn from
          ``N(0, config.initializer_range)``; a Linear bias, when present,
          is zeroed.
        * Any other module type is left untouched.
        """
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            return

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        # Plain normal initialisation; the TF version uses truncated_normal
        # instead (cf https://github.com/pytorch/pytorch/pull/5617).
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        has_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_bias:
            module.bias.data.zero_()


class GPT2Model(GPT2PreTrainedModel):
    """Stack of ``Block`` layers operating directly on input embeddings.

    The model applies dropout, then ``config.n_layer`` blocks, then a final
    LayerNorm.

    Args:
        config: configuration exposing ``embd_pdrop``, ``n_layer``,
            ``n_embd`` and ``layer_norm_epsilon`` (plus whatever ``Block``
            reads from it).
    """

    def __init__(self, config):
        super().__init__(config)

        self.drop = nn.Dropout(config.embd_pdrop)
        layers = [Block(config, layer_idx) for layer_idx in range(config.n_layer)]
        self.h = nn.ModuleList(layers)
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()

    def forward(self, inputs_embeds=None):
        """Return the final-layer features for ``inputs_embeds``.

        ``inputs_embeds`` is fed straight into dropout; the ``None`` default
        is not handled specially, so callers are expected to pass a tensor.
        """
        features = self.drop(inputs_embeds)
        for layer in self.h:
            features = layer(features)
        return self.ln_f(features)


class DecisionConvFormer(nn.Module):
    """Action predictor over interleaved (return, state[, action]) token sequences.

    Every timestep contributes one "return" token (a linear embedding of
    ``distributions``), one state token and, unless ``remove_act_embs`` is
    set, one action token.  A learned timestep embedding is added to each of
    them.  The interleaved sequence is layer-normalised, passed through
    ``GPT2Model`` and the action is predicted from the state-token outputs.

    Args:
        state_dim: size of the last dimension of ``states``.
        act_dim: size of the last dimension of ``actions`` and of the output.
        hidden_size: embedding width (passed to the config as ``n_embd``).
        dist_dim: size of the last dimension of ``distributions``.
        max_length: context length used by ``get_action`` for cropping and
            left-padding; ``None`` disables both.
        max_ep_len: number of rows of the timestep embedding table.
        action_tanh: append ``nn.Tanh`` to the action head when true.
        remove_act_embs: leave action tokens out of the input sequence.
        **kwargs: forwarded to ``transformers.GPT2Config``.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            hidden_size,
            dist_dim,
            max_length=None,
            max_ep_len=4096,
            action_tanh=True,
            remove_act_embs=False,
            **kwargs
    ):
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.dist_dim = dist_dim
        self.hidden_size = hidden_size
        self.max_length = max_length

        # vocab_size is a placeholder: the backbone is fed embeddings directly.
        config = transformers.GPT2Config(
            vocab_size=1,
            n_embd=hidden_size,
            remove_act_embs=remove_act_embs,
            max_length=max_length,
            **kwargs
        )

        self.transformer = GPT2Model(config)
        self.remove_act_embs = remove_act_embs

        self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
        self.embed_return = torch.nn.Linear(self.dist_dim, hidden_size)
        self.embed_state = torch.nn.Linear(self.state_dim, hidden_size)
        self.embed_action = torch.nn.Linear(self.act_dim, hidden_size)

        self.embed_ln = nn.LayerNorm(hidden_size)

        self.predict_action = nn.Sequential(
            *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh else []))
        )

    def forward(self, states, actions, rewards, returns_to_go, timesteps, distributions):
        """Predict one action per timestep.

        Args:
            states: tensor whose first two dims are (batch, seq); last dim
                is expected to be ``state_dim``.
            actions: same leading dims, last dim ``act_dim``.  Only read when
                action tokens are enabled.
            rewards: unused; kept for interface compatibility.
            returns_to_go: unused; kept for interface compatibility.
            timesteps: integer indices into the timestep embedding table,
                one per (batch, seq) position.
            distributions: same leading dims, last dim ``dist_dim``; embedded
                as the "return" token.

        Returns:
            Output of the action head applied to the state-token features,
            one prediction per (batch, seq) position.
        """
        batch_size, seq_length = states.shape[0], states.shape[1]

        time_embeddings = self.embed_timestep(timesteps)
        state_embeddings = self.embed_state(states) + time_embeddings
        returns_embeddings = self.embed_return(distributions) + time_embeddings

        # Token order inside one timestep: return, state, (action).
        token_streams = [returns_embeddings, state_embeddings]
        if not self.remove_act_embs:
            token_streams.append(self.embed_action(actions) + time_embeddings)
        num_token_type = len(token_streams)

        # (batch, type, seq, hidden) -> (batch, seq, type, hidden) -> one
        # interleaved sequence of length num_token_type * seq_length.
        stacked_inputs = (
            torch.stack(token_streams, dim=1)
            .permute(0, 2, 1, 3)
            .reshape(batch_size, num_token_type * seq_length, self.hidden_size)
        )
        stacked_inputs = self.embed_ln(stacked_inputs)

        x = self.transformer(inputs_embeds=stacked_inputs)

        # Undo the interleaving so that dim 1 indexes the token type.
        x = x.reshape(batch_size, seq_length, num_token_type, self.hidden_size).permute(0, 2, 1, 3)

        # Index 1 is the state token in both the 2- and 3-token layouts.
        return self.predict_action(x[:, 1])

    def _fit_context(self, sequence, dtype):
        """Crop ``sequence`` to the last ``max_length`` steps and left-pad with zeros.

        When ``self.max_length`` is None the sequence length is left as is.
        The result is always cast to ``dtype``.
        """
        if self.max_length is not None:
            sequence = sequence[:, -self.max_length:]
            missing = self.max_length - sequence.shape[1]
            if missing > 0:
                pad_shape = (sequence.shape[0], missing) + tuple(sequence.shape[2:])
                # new_zeros keeps the padding on the same device as the data.
                sequence = torch.cat([sequence.new_zeros(pad_shape), sequence], dim=1)
        return sequence.to(dtype=dtype)

    def get_action(self, states, actions, rewards, returns_to_go, timesteps, distributions, **kwargs):
        """Return the predicted action for the most recent timestep of one trajectory.

        The inputs are reshaped to a batch of size one, cropped / left-padded
        to ``max_length`` steps (skipped when ``max_length`` is None) and run
        through ``forward``.  ``rewards``, ``returns_to_go`` and ``**kwargs``
        are ignored.

        Returns:
            The prediction at the last sequence position, i.e.
            ``forward(...)[0, -1]``.
        """
        states = states.reshape(1, -1, self.state_dim)
        actions = actions.reshape(1, -1, self.act_dim)
        distributions = distributions.reshape(1, -1, self.dist_dim)
        timesteps = timesteps.reshape(1, -1)

        states = self._fit_context(states, torch.float32)
        actions = self._fit_context(actions, torch.float32)
        distributions = self._fit_context(distributions, torch.float32)
        timesteps = self._fit_context(timesteps, torch.long)

        action_preds = self.forward(states, actions, None, None, timesteps, distributions)
        return action_preds[0, -1]
    