import math

import torch
import torch.nn as nn

from transformers import (
    GPT2Config,
    GPT2Model,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block


class NoPositionalEncoding(nn.Module):
    """Positional-encoding placeholder whose output is always all zeros."""

    def __init__(self, embed_dim):
        """
        Store the width of the zero tensor produced by :meth:`forward`.

        :param embed_dim: Size of the last dimension of the returned tensor.
        :raises ValueError: If ``embed_dim`` is not even.
        """
        super().__init__()
        is_odd = embed_dim % 2 != 0
        if is_odd:
            raise ValueError(f"Embedding dimension must be even, got {embed_dim}.")
        self.embed_dim = embed_dim

    def forward(self, position_ids):
        """
        Return a zero tensor shaped like a positional embedding.

        Only the first two dimension sizes and the device of ``position_ids``
        are used; the index values themselves are ignored.

        :param position_ids: Tensor with at least two dimensions,
            e.g. (batch_size, seq_length).
        :return: Zero tensor of shape (batch_size, seq_length, embed_dim) on the
            same device as ``position_ids``, in torch's default dtype.
        """
        batch_size = position_ids.size(0)
        seq_length = position_ids.size(1)
        out_shape = (batch_size, seq_length, self.embed_dim)
        return torch.zeros(out_shape, device=position_ids.device)


class AbsolutePositionalEncoding(nn.Module):
    """Parameter-free sinusoidal positional encoding computed on the fly."""

    def __init__(self, embed_dim):
        """
        Store the embedding width used to build the sinusoidal table.

        :param embed_dim: Dimensionality of the model (must be even so that the
            sin and cos halves interleave exactly).
        :raises ValueError: If ``embed_dim`` is not even.
        """
        super().__init__()
        is_odd = embed_dim % 2 != 0
        if is_odd:
            raise ValueError(f"Embedding dimension must be even, got {embed_dim}.")
        self.embed_dim = embed_dim

    def forward(self, position_ids):
        """
        Compute sinusoidal embeddings for the given position IDs.

        Even output channels hold ``sin(pos * w_k)`` and odd channels hold
        ``cos(pos * w_k)``, where ``w_k = exp(-2k * ln(10000) / embed_dim)``.

        :param position_ids: Tensor of shape (batch_size, seq_length) containing position indices.
        :return: Tensor of shape (batch_size, seq_length, embed_dim) with sinusoidal embeddings.
        """
        # (batch_size, seq_length, 1) so it broadcasts against the frequency vector.
        pos = position_ids.unsqueeze(-1).float()
        target_device = pos.device

        # One frequency per (sin, cos) channel pair; shape (embed_dim / 2,).
        even_channels = torch.arange(0, self.embed_dim, 2, dtype=torch.float, device=target_device)
        frequencies = torch.exp(even_channels * -(math.log(10000.0) / self.embed_dim))

        # Shared argument of sin and cos; shape (batch_size, seq_length, embed_dim / 2).
        angles = pos * frequencies

        # Interleave: sin on even channels, cos on odd channels.
        encoding = torch.zeros(pos.size(0), pos.size(1), self.embed_dim, device=target_device)
        encoding[..., ::2] = torch.sin(angles)
        encoding[..., 1::2] = torch.cos(angles)
        return encoding


class CustomGPT2Config(GPT2Config):
    """GPT-2 configuration carrying two extra fields: ``pos_enc`` and ``is_causal``."""

    def __init__(self, pos_enc='learned', is_causal=True, **kwargs):
        """
        Build a GPT-2 configuration extended with positional-encoding and masking options.

        Args:
            pos_enc (str, optional): Name of the positional-encoding scheme.
                ``CustomEmbedGPT2Model`` in this file accepts ``'learned'``,
                ``'absolute'`` and ``'none'``. Defaults to ``'learned'``.
            is_causal (bool, optional): Stored as-is; read by
                ``CustomGPT2Attention`` to decide whether to apply the causal
                mask. Defaults to True.
            **kwargs: Forwarded unchanged to ``GPT2Config``.
        """
        # Let the parent populate every standard GPT-2 setting first ...
        super().__init__(**kwargs)
        # ... then attach the custom fields as plain attributes.
        self.pos_enc, self.is_causal = pos_enc, is_causal


class CustomGPT2Attention(GPT2Attention):
    """GPT-2 attention whose causal mask can be disabled through ``config.is_causal``."""

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        """
        Build the parent attention layer and record the causality switch.

        :param config: Configuration object; must expose ``is_causal``.
        :param is_cross_attention: Forwarded to ``GPT2Attention``.
        :param layer_idx: Forwarded to ``GPT2Attention``.
        """
        super().__init__(config, is_cross_attention, layer_idx)
        self.is_causal = config.is_causal

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """
        Scaled dot-product attention with an optional causal mask.

        NOTE(review): this override only takes effect if the parent's
        ``forward`` dispatches to ``self._attn`` — confirm against the
        installed transformers version.

        :param query: Query tensor; its second-to-last dim is the query length.
        :param key: Key tensor; its second-to-last dim is the key length.
        :param value: Value tensor; its last dim sets the scaling factor.
        :param attention_mask: Optional additive mask added to the raw scores.
        :param head_mask: Optional multiplicative mask applied after dropout.
        :return: Tuple ``(attn_output, attn_weights)``.
        """
        scores = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            # A 0-dim tensor keeps the division in the scores' dtype and device.
            sqrt_head_dim = torch.full(
                [], value.size(-1) ** 0.5, dtype=scores.dtype, device=scores.device
            )
            scores = scores / sqrt_head_dim

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            scores = scores / float(self.layer_idx + 1)

        use_causal_mask = self.is_causal and not self.is_cross_attention
        if use_causal_mask:
            q_len = query.size(-2)
            k_len = key.size(-2)
            # ``self.bias`` comes from the parent class — presumably its
            # lower-triangular boolean buffer; verify against GPT2Attention.
            keep = self.bias[:, :, k_len - q_len : k_len, :k_len]
            # The fill value must be a tensor of matching dtype and device,
            # otherwise torch.where raises on dtype/device mismatch.
            fill = torch.full(
                [], torch.finfo(scores.dtype).min, dtype=scores.dtype, device=scores.device
            )
            scores = torch.where(keep, scores, fill)

        if attention_mask is not None:
            # Apply the attention mask
            scores = scores + attention_mask

        probs = nn.functional.softmax(scores, dim=-1)

        # Cast back to the value dtype (no-op unless running mixed precision),
        # then apply attention dropout.
        probs = self.attn_dropout(probs.type(value.dtype))

        # Mask heads if we want to
        if head_mask is not None:
            probs = probs * head_mask

        return torch.matmul(probs, value), probs


class CustomGPT2Block(GPT2Block):
    """GPT-2 block whose attention layers are ``CustomGPT2Attention`` instances."""

    def __init__(self, config, layer_idx=None):
        """
        Build the standard block, then swap in the custom attention modules.

        :param config: Configuration passed to the parent block and to every
            attention layer created here.
        :param layer_idx: Index of this block, forwarded to the attention layers.
        """
        super().__init__(config, layer_idx)

        shared_kwargs = dict(config=config, layer_idx=layer_idx)
        self.attn = CustomGPT2Attention(**shared_kwargs)

        if not config.add_cross_attention:
            return
        # Cross-attention is only replaced when the config asks for it.
        self.crossattention = CustomGPT2Attention(is_cross_attention=True, **shared_kwargs)


class CustomEmbedGPT2Model(GPT2Model):
    """GPT-2 backbone with custom blocks and a selectable positional encoding."""

    def __init__(self, config):
        """
        Build the GPT-2 model, then replace its blocks and positional embedding.

        :param config: Configuration exposing ``num_hidden_layers``,
            ``max_position_embeddings``, ``n_embd`` and ``pos_enc``
            (``'learned'``, ``'absolute'`` or ``'none'``).
        :raises ValueError: If ``config.pos_enc`` is not one of the supported values.
        """
        super().__init__(config)
        self.h = nn.ModuleList([CustomGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        if config.pos_enc == 'learned':
            # Standard learned positional embeddings
            self.wpe = nn.Embedding(config.max_position_embeddings, config.n_embd)
        elif config.pos_enc == 'absolute':
            # Sinusoidal positional embeddings
            self.wpe = AbsolutePositionalEncoding(config.n_embd)
        elif config.pos_enc == 'none':
            # No positional encoding
            self.wpe = NoPositionalEncoding(config.n_embd)
        else:
            raise ValueError(f"Invalid pos_enc: {config.pos_enc}. Choose 'none', 'learned', or 'absolute'.")

        # The modules assigned above were created after the parent constructor
        # finished, so any weight initialisation it performed did not cover
        # them (a fresh nn.Embedding, for one, keeps torch's default init).
        # Re-run the model's init hook so the replacement layers are
        # initialised like the ones they replace.
        self.post_init()


class TransformerModel(nn.Module):
    """Linear read-in -> custom GPT-2 backbone -> linear scalar read-out."""

    def __init__(self, n_dims, n_positions, n_embd=128, n_layer=12, n_head=4, pos_enc='absolute', is_causal=True):
        """
        Assemble the read-in layer, the GPT-2 backbone and the read-out layer.

        :param n_dims: Size of the last dimension of the input ``xs``.
        :param n_positions: Passed to the backbone config as ``n_positions``.
        :param n_embd: Hidden width of the backbone.
        :param n_layer: Number of transformer blocks.
        :param n_head: Number of attention heads per block.
        :param pos_enc: Positional-encoding scheme forwarded to the config.
        :param is_causal: Whether attention applies the causal mask.
        """
        super().__init__()

        # All dropout is disabled and caching is turned off for the backbone.
        backbone_config = CustomGPT2Config(
            n_positions=n_positions,
            n_embd=n_embd,
            n_layer=n_layer,
            n_head=n_head,
            resid_pdrop=0.0,
            embd_pdrop=0.0,
            attn_pdrop=0.0,
            use_cache=False,
            pos_enc=pos_enc,
            is_causal=is_causal,
        )

        self.name = f"gpt2_embd={n_embd}_layer={n_layer}_head={n_head}"
        self.n_positions, self.n_dims, self.n_embd = n_positions, n_dims, n_embd

        # Creation order is kept: read-in, backbone, read-out.
        self._read_in = nn.Linear(n_dims, n_embd)
        self._backbone = CustomEmbedGPT2Model(backbone_config)
        self._read_out = nn.Linear(n_embd, 1)

    def _run_backbone(self, xs, **backbone_flags):
        """Project ``xs`` to the hidden width and run the backbone on the result."""
        return self._backbone(inputs_embeds=self._read_in(xs), **backbone_flags)

    def get_attentions(self, xs):
        """
        Return the backbone's ``attentions`` output for ``xs``.

        :param xs: Input tensor whose last dimension is ``n_dims``.
        """
        return self._run_backbone(xs, output_attentions=True).attentions

    def get_hidden_states(self, xs):
        """
        Return the backbone's ``hidden_states`` output for ``xs``.

        :param xs: Input tensor whose last dimension is ``n_dims``.
        """
        return self._run_backbone(xs, output_hidden_states=True).hidden_states

    def forward(self, xs):
        """
        Predict one scalar per position.

        :param xs: Input tensor whose last dimension is ``n_dims``.
        :return: Tensor whose last dimension has size 1.
        """
        final_hidden = self._run_backbone(xs).last_hidden_state
        return self._read_out(final_hidden)
