import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import warnings
from models.embedding import Embedding


class sin_cos_positional_encoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    A ``[1, max_len, d_model]`` table is precomputed once. Even feature indices
    hold ``sin(pos * w_k)`` and odd indices hold ``cos(pos * w_k)``, using
    ``d_model // 2`` frequencies ``w_k = exp(-k * ln(10000) / (d_model // 2 - 1))``.
    When ``d_model`` is odd, the last feature column stays zero.

    Args:
        d_model: Width of the positional encoding.
        max_len: Longest sequence length the precomputed table covers.
        concatenate: If True, ``forward`` appends the encoding along the last
            dimension; otherwise it adds it to the input.
    """

    def __init__(self, d_model: int, max_len: int = 5000, concatenate=False):
        super().__init__()

        half = d_model // 2
        # With a single frequency the divisor (half - 1) would be zero; the
        # fallback value is irrelevant then because it is multiplied by k = 0.
        if half > 1:
            step = -math.log(10000.0) / (half - 1)
        else:
            step = -math.log(10000.0)

        freqs = torch.exp(torch.arange(half) * step)  # [half]
        angles = torch.arange(max_len).unsqueeze(1) * freqs  # [max_len, half]

        table = torch.zeros(1, max_len, d_model)
        table[0, :, 0 : 2 * half : 2] = torch.sin(angles)  # even columns
        table[0, :, 1 : 2 * half : 2] = torch.cos(angles)  # odd columns

        # A (persistent) buffer follows .to(device) and is stored in the state_dict.
        self.register_buffer("pe", table)
        self.concatenate = concatenate

    def forward(self, x):
        """Add or concatenate the positional table to ``x``.

        Arguments:
            x: Tensor, shape ``[batch_size, seq_length, embedding_dim]``
        Returns:
            ``x + pe`` (same shape as ``x``) or, when concatenating, a tensor of
            shape ``[batch_size, seq_length, embedding_dim + d_model]``.
        """
        table = self.pe[:, : x.size(1), :]
        if not self.concatenate:
            return x + table
        batch_table = table.expand(x.shape[0], -1, -1)
        return torch.cat((x, batch_table), dim=-1)

class one_hot_positional_encoding(nn.Module):
    """Positional encoding where position ``t`` is the one-hot vector ``e_t`` of width ``max_len``."""

    def __init__(self, max_len: int = 5000, concatenate: bool = False):
        """
        Initializes the OneHotPositionalEncoding class.
        Args:
            max_len (int, optional): The maximum length of the positional encodings. Defaults to 5000.
            concatenate (bool, optional): Whether to concatenate the positional encodings
                                          with the input embeddings (True) or add them (False).
        """
        super().__init__()
        self.max_len = max_len
        self.concatenate = concatenate

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor of shape [batch_size, seq_length, embedding_dim]
        Returns:
            A Tensor of the same shape (if concatenate=False) or
            shape [batch_size, seq_length, embedding_dim + max_len] (if concatenate=True).
        Raises:
            ValueError: if seq_length > max_len (no one-hot slot exists for the later
                positions), or if concatenate=False and embedding_dim != max_len
                (the sum would otherwise broadcast to a different shape or fail).
        """
        bsz, seq_len, emb_dim = x.shape

        # Validate up front: an out-of-range scatter index would otherwise trigger an
        # opaque error (a device-side assert on CUDA).
        if seq_len > self.max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_len {self.max_len} "
                "of the one-hot positional encoding"
            )
        if not self.concatenate and emb_dim != self.max_len:
            raise ValueError(
                f"When summing one-hot positional encodings, the embedding dimension "
                f"({emb_dim}) must equal max_len ({self.max_len})"
            )

        # Row t of eye(seq_len, max_len) is the one-hot vector for position t.
        # Shape [1, seq_len, max_len]; same dtype/device as x. It is shared across the
        # batch via broadcasting instead of being materialized per sample.
        pos_one_hot = torch.eye(seq_len, self.max_len, dtype=x.dtype, device=x.device).unsqueeze(0)

        if self.concatenate:
            # Concatenate along the last dimension => [bsz, seq_len, emb_dim + max_len]
            return torch.cat([x, pos_one_hot.expand(bsz, -1, -1)], dim=-1)
        # Add the one-hot vectors => [bsz, seq_len, emb_dim]
        return x + pos_one_hot


class learnable_positional_encoding(nn.Module):
    """Learned absolute positional embedding (one trainable vector per position).

    Args:
        block_size: Maximum sequence length (number of learnable positions).
        embedding_dim: Width of each positional vector.
        concatenate: If True, ``forward`` appends the positional vectors along the
            last dimension; otherwise it adds them to the input.
    """

    def __init__(self, block_size: int = 5000, embedding_dim=512, concatenate=False):
        super().__init__()
        self.pos_embedding = nn.Embedding(block_size, embedding_dim)
        # Position indices [0, ..., block_size - 1], shape [1, block_size]. Kept as a
        # (persistent) buffer so it follows .to(device) and stays in the state_dict
        # under the key "pos" (existing checkpoints contain it).
        pos = torch.arange(block_size).unsqueeze(0)
        self.register_buffer("pos", pos)
        self.concatenate = concatenate

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size,seq_length, embedding_dim]``
        Returns:
            ``x`` plus the positional vectors of positions ``0..seq_length-1``, or,
            when concatenating, shape ``[batch_size, seq_length, embedding_dim + pos_dim]``.
        Raises:
            ValueError: if ``seq_length`` exceeds ``block_size``.
        """
        seq_len = x.size(1)
        block_size = self.pos.size(1)
        if seq_len > block_size:
            raise ValueError(
                f"Sequence length {seq_len} exceeds block_size {block_size} "
                "of the learnable positional encoding"
            )
        # Embed only the positions actually present; shape [1, seq_len, pos_dim].
        # Broadcasting over the batch gives the same values and gradients as
        # repeating the lookup per sample, without the extra copies.
        pos = self.pos_embedding(self.pos[:, :seq_len])
        if self.concatenate:
            return torch.cat((x, pos.expand(x.size(0), -1, -1)), dim=-1)
        return x + pos


def get_pos_encoder_embedding(
    pos_dim,
    embedding_dim,
    cat_pos,
    block_size,
    pos_enc,
    att_mask,
    vocab_size,
    one_hot_emb,
    skip_embedding,
    freeze_emb,
    embedding_type="embedding",
):
    """Build the positional encoder and token embedding and resolve their widths.

    Args:
        pos_dim: Requested positional-encoding width. Replaced by ``embedding_dim``
            when ``cat_pos`` is False, and by ``block_size`` for ``"one_hot"``.
        embedding_dim: Requested token-embedding width. Replaced by ``vocab_size``
            when ``one_hot_emb`` is True, and by ``block_size`` for a summed
            ``"one_hot"`` positional encoding.
        cat_pos: If True the positional encoding is concatenated to the embedding,
            otherwise it is added.
        block_size: Maximum sequence length.
        pos_enc: One of ``"one_hot"``, ``"sin_cos"``, ``"learned"``, ``"none"``,
            ``"relative"`` (case-insensitive).
        att_mask: Attention mask; modified in place only for ``"relative"``.
        vocab_size: Vocabulary size (input width of the embedding).
        one_hot_emb: Forwarded to ``Embedding``; also forces ``embedding_dim = vocab_size``.
        skip_embedding: If True, the returned embedding is the identity function.
        freeze_emb: Forwarded to ``Embedding``.
        embedding_type: ``"embedding"`` (project ``Embedding``) or ``"linear"`` (``nn.Linear``).

    Returns:
        Tuple ``(pos_encoder, pos_dim, embedding_dim, att_mask, model_dim, embedding)``
        where ``pos_encoder`` is None for ``"none"``/``"relative"`` and ``model_dim`` is
        ``embedding_dim + pos_dim`` when concatenating, else ``embedding_dim``.

    Raises:
        ValueError: on an unknown ``pos_enc`` or ``embedding_type``, on concatenation
            with ``"none"``/``"relative"``, or on summed one-hot positions combined
            with one-hot embeddings.
    """
    # Normalize once so every check below (including the one-hot guard) is
    # case-insensitive, matching the dispatch.
    pos_enc_key = pos_enc.lower()

    # Summing requires equal widths, but one-hot embeddings are vocab_size wide while
    # one-hot positions are block_size wide, so the two cannot be summed.
    if pos_enc_key == "one_hot" and one_hot_emb and not cat_pos:
        raise ValueError(
            "It is not possible to have one hot positional encoding and one hot embedding and sum them because the dimensions can't match"
        )
    if one_hot_emb:
        embedding_dim = vocab_size
    # When summing, the positional width must match the embedding width.
    if not cat_pos and pos_dim != embedding_dim:
        warnings.warn(
            "If concatenation of positional encoding is disabled, the positional encoding dimension must be the same as the embedding dimension, therefore the positional encoding dimension is changed to the embedding dimension"
        )
        pos_dim = embedding_dim

    if pos_enc_key == "one_hot":
        # One-hot positions need one slot per possible position.
        pos_dim = block_size
        if not cat_pos:
            # Summing then also forces the embedding width to block_size.
            if embedding_dim != block_size:
                warnings.warn(
                    "For one_hot positional encoding, the embedding dimension must be the same as the block size, therefore the embedding dimension is changed to the block size"
                )
            embedding_dim = block_size

        pos_encoder = one_hot_positional_encoding(max_len=pos_dim, concatenate=cat_pos)

    elif pos_enc_key == "sin_cos":
        pos_encoder = sin_cos_positional_encoding(d_model=pos_dim, concatenate=cat_pos)

    elif pos_enc_key == "learned":
        pos_encoder = learnable_positional_encoding(
            block_size=block_size, embedding_dim=pos_dim, concatenate=cat_pos
        )

    elif pos_enc_key == "none":
        pos_encoder = None
        if cat_pos:
            raise ValueError(
                "Concatenation of positional encoding is not possible with relative or no positional encoding, set cat_pos to False"
            )

    elif pos_enc_key == "relative":
        pos_encoder = None
        if cat_pos:
            raise ValueError(
                "Concatenation of positional encoding is not possible with relative or no positional encoding, set cat_pos to False"
            )
        # Subtract the signed distance (i - j) from entry [i, j]: keys further in the
        # past (j < i) receive a larger penalty. Built on the mask's device so a CUDA
        # mask does not clash with a CPU index tensor. att_mask is updated in place.
        # NOTE(review): assumes att_mask is (or broadcasts with) a
        # [block_size, block_size] tensor — confirm with callers.
        # TODO: Debug to see if this is actually working
        positions = torch.arange(block_size, device=getattr(att_mask, "device", None))
        att_mask -= positions.unsqueeze(1) - positions.unsqueeze(0)
    else:
        raise ValueError(
            "pos_enc must be one of the following: one_hot, sin_cos, learned, none, relative"
        )

    if skip_embedding:
        # NOTE(review): a lambda cannot be pickled; if this ends up as a module
        # attribute, torch.save(model) may fail — consider nn.Identity().
        embedding = lambda x: x
    else:
        if embedding_type == "embedding":
            embedding = Embedding(vocab_size, embedding_dim, one_hot_emb, freeze_emb)
        elif embedding_type == "linear":
            embedding = nn.Linear(vocab_size, embedding_dim)
        else:
            raise ValueError("embedding_type must be one of the following: embedding, linear")

    model_dim = embedding_dim + pos_dim if cat_pos else embedding_dim

    return pos_encoder, pos_dim, embedding_dim, att_mask, model_dim, embedding
