import math
import torch
import torch.nn as nn
from torch.nn import functional as F


class MultiHeadFourier(nn.Module):
    """Causal sequence mixer built on an FFT-based convolution.

    The input goes through a left-padded depthwise convolution and a
    LayerNorm. Two projections of the normalized activations -- a value
    branch and a gate branch -- are then convolved with each other along the
    sequence axis, independently for every channel. The convolution is
    evaluated in the frequency domain with zero padding to twice the
    sequence length, so the first ``seq_len`` outputs equal the linear
    (non-circular) causal convolution.

    Args:
        d_model: channel width of the input and output.
        n_heads: number of heads; also the group count of the gate's
            1x1 convolution, so it must divide ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Depthwise (one filter per channel) convolution; forward() pads on
        # the left only, which keeps it causal.
        self.pre_conv = nn.Conv1d(
            in_channels=self.d_model,
            out_channels=self.d_model,
            kernel_size=3,
            groups=self.d_model,
            bias=False,
        )
        self.ln = nn.LayerNorm(self.d_model)
        self.silu = nn.SiLU()
        self.W_V = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_G1 = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        # Pointwise convolution that mixes channels within each head only.
        self.W_G2 = nn.Conv1d(
            in_channels=self.d_model,
            out_channels=self.d_model,
            kernel_size=1,
            groups=self.n_heads,
        )
        self.linear = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        # Flag read by CaracalForCausalLM._init_weights to shrink the init std.
        self.linear.NEED_SCALE_INIT = 1

    def forward(self, x):
        """Mix tokens along the sequence axis.

        Args:
            x: tensor of shape [batch_size, seq_len, d_model].

        Returns:
            Tensor of shape [batch_size, seq_len, d_model] in the dtype of
            the normalized activations.
        """
        # Left-pad by kernel_size - 1 so position t only sees positions <= t.
        left_pad = self.pre_conv.kernel_size[0] - 1
        x = F.pad(x.permute(0, 2, 1), (left_pad, 0))
        # x: [batch_size, d_model, seq_len + left_pad]
        x = self.pre_conv(x).permute(0, 2, 1)
        # x: [batch_size, seq_len, d_model]
        x_norm = self.ln(x)
        batch_size, seq_len = x_norm.size(0), x_norm.size(1)
        # Zero-padding to 2 * seq_len turns the circular FFT convolution into
        # a linear one over the first seq_len outputs.
        n_fft = 2 * seq_len
        head_shape = (batch_size, seq_len, self.n_heads, self.d_head)

        x_v = self.W_V(x_norm).reshape(head_shape).transpose(1, 2)
        # x_v: [batch_size, n_heads, seq_len, d_head]
        x_g = self.W_G1(x_norm).transpose(1, 2)
        # x_g: [batch_size, d_model, seq_len]
        x_g = self.W_G2(self.silu(x_g)).transpose(1, 2)
        # x_g: [batch_size, seq_len, d_model]
        x_g = x_g.reshape(head_shape).transpose(1, 2)
        # x_g: [batch_size, n_heads, seq_len, d_head]

        # Run the FFT in at least float32: reduced-precision inputs are
        # upcast, float64 inputs keep their precision.
        fft_dtype = torch.promote_types(x_norm.dtype, torch.float32)
        g_freq = torch.fft.rfft(x_g.to(fft_dtype), n=n_fft, dim=2)
        v_freq = torch.fft.rfft(x_v.to(fft_dtype), n=n_fft, dim=2)
        # g_freq, v_freq: [batch_size, n_heads, n_fft // 2 + 1, d_head]
        mixed = torch.fft.irfft(g_freq * v_freq, n=n_fft, dim=2)
        # mixed: [batch_size, n_heads, n_fft, d_head]
        mixed = mixed[:, :, :seq_len, :]
        mixed = mixed.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        # mixed: [batch_size, seq_len, d_model]
        # Cast back to the activation dtype so the output projection does not
        # receive an input whose dtype differs from its weights.
        mixed = mixed.to(x_norm.dtype)
        return self.linear(mixed)


class SlidingWindowAttention(nn.Module):
    """Causal multi-head self-attention restricted to a local window.

    Each query position ``t`` attends to key positions in
    ``[t - window_size + 1, t]`` (clipped at 0).

    Args:
        d_model: channel width of the input and output.
        n_heads: number of attention heads; must divide ``d_model``.
        window_size: number of most recent positions (including the current
            one) each query may attend to; must be at least 1.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or
            ``window_size`` is smaller than 1.
    """

    def __init__(self, d_model, n_heads, window_size):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        # With window_size <= 0 every key would be masked for every query and
        # the softmax would fall back to uniform weights over all positions,
        # future ones included. Reject that configuration up front.
        if window_size < 1:
            raise ValueError("window_size must be at least 1")
        self.d_model = d_model
        self.n_heads = n_heads
        self.window_size = window_size
        self.d_head = d_model // n_heads
        self.c_attn = nn.Linear(d_model, 3 * d_model, bias=False)
        self.c_proj = nn.Linear(d_model, d_model, bias=False)
        # Flag read by CaracalForCausalLM._init_weights to shrink the init std.
        self.c_proj.NEED_SCALE_INIT = 1

    def forward(self, x):
        """Apply windowed causal attention.

        Args:
            x: tensor of shape [batch_size, seq_len, d_model].

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        batch_size, seq_len, _ = x.size()
        qkv = self.c_attn(x)
        # qkv: [batch_size, seq_len, 3 * d_model]
        # The fused projection is laid out per head as [q | k | v].
        qkv = qkv.view(batch_size, seq_len, self.n_heads, 3 * self.d_head)
        q, k, v = qkv.permute(0, 2, 1, 3).chunk(3, dim=-1)
        # q, k, v: [batch_size, n_heads, seq_len, d_head]

        # offset[i, j] = i - j: how far key j lies behind query i.
        positions = torch.arange(seq_len, device=x.device)
        offset = positions.unsqueeze(1) - positions.unsqueeze(0)
        visible = (offset >= 0) & (offset < self.window_size)
        # visible: [seq_len, seq_len], True where attention is allowed

        q = q * (1.0 / math.sqrt(self.d_head))
        scores = q @ k.transpose(-2, -1)
        # scores: [batch_size, n_heads, seq_len, seq_len]
        # The most negative finite value drives masked weights to zero in the
        # softmax without ever producing NaN.
        scores = scores.masked_fill(~visible, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=-1)
        y = weights @ v
        # y: [batch_size, n_heads, seq_len, d_head]
        y = y.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        # y: [batch_size, seq_len, d_model]
        return self.c_proj(y)


class MLP(nn.Module):
    """Gated feed-forward network: ``fc_2(fc_1(x) * silu(fc_gate(x)))``.

    Args:
        d_model: channel width of the input and output.
        intermediate_size: width of the hidden (gated) representation.
    """

    def __init__(self, d_model, intermediate_size):
        super().__init__()
        # Value and gate branches both expand d_model -> intermediate_size.
        self.fc_1 = nn.Linear(d_model, intermediate_size, bias=False)
        self.fc_gate = nn.Linear(d_model, intermediate_size, bias=False)
        # Down-projection back to d_model.
        self.fc_2 = nn.Linear(intermediate_size, d_model, bias=False)
        self.silu = nn.SiLU()
        # Flag read by CaracalForCausalLM._init_weights to shrink the init std.
        self.fc_2.NEED_SCALE_INIT = 1

    def forward(self, x):
        """Map [batch_size, seq_len, d_model] to the same shape."""
        gate = self.silu(self.fc_gate(x))
        # gate: [batch_size, seq_len, intermediate_size]
        hidden = self.fc_1(x) * gate
        # hidden: [batch_size, seq_len, intermediate_size]
        return self.fc_2(hidden)


class Block(nn.Module):
    """Pre-norm residual block: a sequence mixer followed by a gated MLP.

    Args:
        d_model: channel width.
        n_heads: number of heads passed to the mixer.
        intermediate_size: hidden width of the feed-forward network.
        window_size: attention window; only used when ``layer_type`` is
            ``"attn"``.
        layer_type: ``"attn"`` selects SlidingWindowAttention; any other
            value selects MultiHeadFourier.
    """

    def __init__(self, d_model, n_heads, intermediate_size, window_size, layer_type):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.intermediate_size = intermediate_size
        self.window_size = window_size

        self.ln_1 = nn.LayerNorm(d_model)
        self.mixer = (
            SlidingWindowAttention(d_model, n_heads, window_size)
            if layer_type == "attn"
            else MultiHeadFourier(d_model, n_heads)
        )
        self.ln_2 = nn.LayerNorm(d_model)
        self.pos_ffn = MLP(d_model, intermediate_size)

    def forward(self, x):
        """Map [batch_size, seq_len, d_model] to the same shape."""
        # Token mixing with a residual connection.
        mixed = self.mixer(self.ln_1(x))
        x = x + mixed
        # Position-wise feed-forward with a residual connection.
        ffn_out = self.pos_ffn(self.ln_2(x))
        return x + ffn_out


class CaracalForCausalLM(nn.Module):
    """Decoder-only language model made of a stack of ``Block`` layers.

    Layers whose index appears in ``attn_layers`` use sliding-window
    attention as their mixer; all other layers use the FFT-based mixer. The
    token embedding and the output projection share one weight tensor.

    Args:
        d_model: channel width; must be divisible by ``n_heads``.
        n_layers: number of blocks.
        n_heads: number of heads per mixer.
        vocab_size: size of the token vocabulary.
        intermediate_size: hidden width of each block's feed-forward network.
        attn_layers: iterable of layer indices that use attention.
        window_size: attention window passed to every block.
        **kwargs: accepted and ignored.
    """

    def __init__(self, d_model, n_layers, n_heads, vocab_size, intermediate_size, attn_layers=(), window_size=256, **kwargs):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.intermediate_size = intermediate_size
        self.window_size = window_size
        self.attn_layers_set = set(attn_layers)

        self.wte = nn.Embedding(vocab_size, d_model)
        self.h = nn.ModuleList([
            Block(
                d_model=d_model,
                n_heads=n_heads,
                intermediate_size=intermediate_size,
                window_size=window_size,
                layer_type="attn" if idx in self.attn_layers_set else "fft",
            )
            for idx in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's weight.
        self.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def forward(self, x, targets=None):
        """Compute next-token logits.

        Args:
            x: integer token ids of shape [batch_size, seq_len].
            targets: accepted but not used by this method.

        Returns:
            Logits of shape [batch_size, seq_len, vocab_size].
        """
        hidden = self.wte(x)
        # hidden: [batch_size, seq_len, d_model]
        for block in self.h:
            hidden = block(hidden)
        return self.lm_head(self.ln_f(hidden))

    def _init_weights(self, module):
        """Initialize one submodule; intended to be used with ``nn.Module.apply``.

        Embedding, Linear and Conv1d weights are drawn from N(0, 0.02^2).
        Modules carrying a ``NEED_SCALE_INIT`` attribute use a std scaled by
        ``(2 * n_layers) ** -0.5``. Biases, when present, are zeroed. Other
        module types are left untouched.
        """
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            return
        if not isinstance(module, (nn.Linear, nn.Conv1d)):
            return
        scaled = hasattr(module, 'NEED_SCALE_INIT')
        std = 0.02 * (2 * self.n_layers) ** -0.5 if scaled else 0.02
        torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
