"""Common layers to use in different models """

import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)`` where ``2i`` is
    the even column index of the pair.

    Args:
        d_model: feature dimension of the inputs. Odd values are supported;
            the final (even-indexed) column then has no cosine partner.
        max_len: longest sequence the encoding table is precomputed for.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        idxs = torch.arange(0, d_model, 2).float()  # even column indices 2i
        div_term = torch.exp(idxs * (-math.log(10000.0) / d_model))
        positions = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        pe = torch.zeros(size=(max_len, d_model))
        pe[:, 0::2] = torch.sin(positions * div_term)
        # With an odd d_model there is one fewer odd column than there are
        # frequencies, so only the first d_model // 2 frequencies get a cosine.
        pe[:, 1::2] = torch.cos(positions * div_term[: d_model // 2])
        # add dummy for batch dimension -> (1, max_len, d_model)
        self.register_buffer("PE", pe.unsqueeze(0))

    def forward(self, X):
        """Return ``X`` plus the encodings for its first ``X.size(1)`` positions.

        Args:
            X: tensor indexed as (batch, seq, d_model); the sequence axis is
                dim 1 and the last dim must equal ``d_model``.

        Raises:
            ValueError: if the sequence is longer than ``max_len``.
        """
        seq_len = X.size(1)
        max_len = self.PE.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )
        return X + self.PE[:, :seq_len].to(X.device)


def create_rope(n_pos, d_model, base=10000.0):
    """Build the sine/cosine table used by rotary position embeddings.

    Row ``p`` holds ``sin(p * theta_i)`` in its first ``d_model // 2`` columns
    and ``cos(p * theta_i)`` in the remaining ``d_model // 2`` columns, where
    ``theta_i = base ** (-2i / d_model)`` for ``i = 0 .. d_model // 2 - 1``.

    Args:
        n_pos: max number of positional embeddings to precompute.
        d_model: hidden dim of the embedding/model; must be even because the
            sin and cos halves each need ``d_model // 2`` columns.
        base: frequency base (default 10000.0).

    Returns:
        torch.Tensor of shape ``(n_pos, d_model)`` (float32).

    Raises:
        ValueError: if ``d_model`` is odd.
    """
    if d_model % 2:
        raise ValueError(f"d_model must be even for rotary embeddings, got {d_model}")
    sub_space = torch.arange(start=0, end=d_model, step=2)  # 2i
    inv_freq = torch.exp(sub_space * (-math.log(base) / d_model))  # theta_i
    # Broadcast (n_pos, 1) * (d_model // 2,) -> (n_pos, d_model // 2).
    angles = torch.arange(n_pos)[:, None] * inv_freq
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class RopeEmbeds(nn.Module):
    """Rotary position embeddings (RoPE) applied to a query/key tensor.

    Features are rotated in pairs ``(k, k + D/2)`` (the "rotate-half"
    layout). For sequence position ``p`` and ``k < D/2``::

        out[k]       = x[k] * cos(p*theta_k) - x[k + D/2] * sin(p*theta_k)
        out[k + D/2] = x[k + D/2] * cos(p*theta_k) + x[k] * sin(p*theta_k)

    Each pair therefore undergoes a 2-D rotation. That preserves the vector
    norm, and the dot product of two rotated vectors depends only on their
    relative position.

    Args:
        n_pos: maximum sequence length supported.
        d_model: size of the last (feature) dimension of the inputs.
    """

    def __init__(self, n_pos, d_model):
        super().__init__()
        self.n_pos = n_pos
        self.d_model = d_model

        # (n_pos, d_model) table: first half sin, second half cos.
        rel_pos = create_rope(self.n_pos, self.d_model)
        self.register_buffer("rel_pos", rel_pos, persistent=False)

    def forward(self, mat):
        """Rotate ``mat`` according to each row's sequence position.

        Args:
            mat: tensor of shape ``(..., T, d_model)``, e.g. ``(B, H, T, D)``;
                the sequence axis is the second-to-last dimension.

        Returns:
            Tensor with the same shape (and dtype) as ``mat``.

        Raises:
            ValueError: if ``T`` exceeds the precomputed ``n_pos``.
        """
        seq_len = mat.shape[-2]
        if seq_len > self.rel_pos.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds n_pos {self.rel_pos.shape[0]}"
            )
        table = self.rel_pos[:seq_len].to(device=mat.device, dtype=mat.dtype)
        sin, cos = torch.chunk(table, chunks=2, dim=-1)
        # Tile the D/2 frequencies over both halves -> (T, D): feature k and
        # feature k + D/2 share frequency theta_k.
        sin = sin.repeat((1, 2))
        cos = cos.repeat((1, 2))
        # rotate_half: (x1, x2) -> (-x2, x1) on the two feature halves, so each
        # feature is paired with the one D/2 away, matching the tiling above.
        half = mat.shape[-1] // 2
        rotate_half_mat = torch.cat([-mat[..., half:], mat[..., :half]], dim=-1)
        # (T, D) tables broadcast over any leading dims of mat.
        return mat * cos + rotate_half_mat * sin

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    LlamaRMSNorm is equivalent to T5LayerNorm. Inputs are divided by the RMS of
    their last dimension (computed in float32, with ``eps`` added inside the
    square root) and then multiplied elementwise by ``weight``.

    Args:
        hidden_size: size of the last dimension; length of ``weight``.
        eps: value added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` along its last axis and apply ``weight``."""
        original_dtype = hidden_states.dtype
        # Do the reduction in float32 for stability, then cast back.
        x32 = hidden_states.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        normed = x32 * (mean_sq + self.variance_epsilon).rsqrt()
        return self.weight * normed.to(original_dtype)
    
    
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    The inner width is fixed at ``4 * hidden_dim``; both linear layers have biases.

    Args:
        hidden_dim: input and output feature size.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, hidden_dim, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        inner_dim = 4 * hidden_dim
        self.layer1 = nn.Linear(hidden_dim, inner_dim)
        self.act = nn.GELU()
        self.layer2 = nn.Linear(inner_dim, hidden_dim)
        self.drop = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the expand/activate/project/dropout pipeline to ``x``."""
        return self.drop(self.layer2(self.act(self.layer1(x))))

class GatedMLP(nn.Module):
    """SwiGLU-style gated feed-forward block without biases.

    Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))`` with an inner
    width of ``4 * hidden_dim``.

    Args:
        hidden_dim: input and output feature size.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.hidden_size = hidden_dim
        self.intermediate_size = 4 * hidden_dim

        width, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)
        self.act_fn = F.silu

    def forward(self, x):
        """Gate the up-projection of ``x`` with SiLU(gate) and project back down."""
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))