import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F


class MultiHeadFourier(nn.Module):
    """Multi-head token mixer built on a gated, FFT-based causal convolution.

    The input is first passed through a depthwise causal convolution
    (kernel size 3, left padding only) and a LayerNorm.  Two projections are
    then computed per head: a value sequence (``W_V``) and a gate/filter
    sequence (``W_G1`` -> SiLU -> grouped 1x1 conv ``W_G2``).  The two
    sequences are convolved along the sequence axis by multiplying their
    real FFTs.  Zero-padding the FFT to ``2 * seq_len`` prevents circular
    wrap-around, so keeping only the first ``seq_len`` outputs gives a
    causal linear convolution.

    Args:
        d_model: Channel width of the input and output.
        n_heads: Number of heads; must divide ``d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.d_model = d_model
        self.n_heads = n_heads
        # Integer division: avoids the float round-trip of int(d_model / n_heads).
        self.d_head = d_model // n_heads
        # Depthwise conv (one filter per channel); padding is applied manually
        # in forward() so that it is strictly causal.
        self.pre_conv = nn.Conv1d(in_channels=self.d_model, out_channels=self.d_model, kernel_size=3, groups=self.d_model, bias=False)
        self.ln = nn.LayerNorm(self.d_model)
        self.silu = nn.SiLU()
        self.W_V = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        self.W_G1 = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        # Kernel-size-1 conv grouped by head: mixes channels within each head only.
        self.W_G2 = nn.Conv1d(in_channels=self.d_model, out_channels=self.d_model, kernel_size=1, groups=self.n_heads)
        self.linear = nn.Linear(in_features=self.d_model, out_features=self.d_model)
        # Marker attribute checked via hasattr() by the model-level weight init.
        self.linear.NEED_SCALE_INIT = 1

    def forward(self, x):
        """Mix tokens causally along the sequence axis.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_model].

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        # x: [batch_size, seq_len, d_model]
        causal_pad = self.pre_conv.kernel_size[0] - 1
        # Pad only on the left so position t never sees positions > t.
        x_padded = F.pad(x.permute(0, 2, 1), (causal_pad, 0))
        # x_padded: [batch_size, d_model, seq_len + causal_pad]
        x_norm = self.ln(self.pre_conv(x_padded).permute(0, 2, 1))
        # x_norm: [batch_size, seq_len, d_model]
        batch_size, seq_len = x_norm.size(0), x_norm.size(1)
        n_fft = 2 * seq_len

        x_v = self.W_V(x_norm).reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        # x_v: [batch_size, n_heads, seq_len, d_head]

        x_g = self.W_G1(x_norm).transpose(1, 2)
        # x_g: [batch_size, d_model, seq_len]
        x_g = self.W_G2(self.silu(x_g)).transpose(1, 2)
        # x_g: [batch_size, seq_len, d_model]
        x_g = x_g.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        # x_g: [batch_size, n_heads, seq_len, d_head]

        # The FFT is always evaluated in float32, whatever the activation dtype.
        G_fft = torch.fft.rfft(x_g.to(torch.float32), n=n_fft, dim=2)
        V_fft = torch.fft.rfft(x_v.to(torch.float32), n=n_fft, dim=2)
        # G_fft, V_fft: [batch_size, n_heads, n_fft//2+1, d_head]
        x_fft = torch.fft.irfft(G_fft * V_fft, n=n_fft, dim=2)
        # x_fft: [batch_size, n_heads, n_fft, d_head]
        x_fft = x_fft[:, :, :seq_len, :]
        # x_fft: [batch_size, n_heads, seq_len, d_head]
        x_fft = x_fft.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
        # x_fft: [batch_size, seq_len, d_model]

        # Cast back to the projection dtype so the output layer does not hit a
        # float32-vs-parameter dtype mismatch when the module is not float32.
        x_fft = x_fft.to(x_v.dtype)
        return self.linear(x_fft)


# # MLP (for ablation study)
# class MLP(nn.Module):

#     def __init__(self, d_model):
#         super(MLP, self).__init__()
#         self.d_model = d_model
#         self.c_1 = nn.Linear(in_features=self.d_model, out_features=4*self.d_model)
#         self.silu = nn.SiLU()
#         self.c_2 = nn.Linear(in_features=4*self.d_model, out_features=self.d_model)
#         self.c_2.NEED_SCALE_INIT = 1

#     def forward(self, x):
#         # x: [batch_size, seq_len, d_model]
#         x = self.c_1(x)
#         # x: [batch_size, seq_len, 4*d_model]
#         x = self.silu(x)
#         # x: [batch_size, seq_len, 4*d_model]
#         x = self.c_2(x)
#         # x: [batch_size, seq_len, d_model]
#         return x


# SwiGLU
class MLP(nn.Module):
    """Bias-free SwiGLU feed-forward network with a 4x hidden expansion.

    Computes ``fc_2(fc_1(x) * SiLU(fc_gate(x)))``.

    Args:
        d_model: Channel width of the input and output.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = 4 * d_model
        self.fc_1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.fc_gate = nn.Linear(d_model, hidden_dim, bias=False)
        self.fc_2 = nn.Linear(hidden_dim, d_model, bias=False)
        # Marker attribute checked via hasattr() by the model-level weight init.
        self.fc_2.NEED_SCALE_INIT = 1
        self.silu = nn.SiLU()

    def forward(self, x):
        """Apply the gated MLP position-wise.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_model].

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        gate = self.silu(self.fc_gate(x))   # [batch_size, seq_len, 4*d_model]
        value = self.fc_1(x)                # [batch_size, seq_len, 4*d_model]
        return self.fc_2(value * gate)      # [batch_size, seq_len, d_model]


class Block(nn.Module):
    """Pre-norm residual block: Fourier token mixing followed by an MLP.

    Args:
        d_model: Channel width of the residual stream.
        n_heads: Number of heads handed to the Fourier mixing layer.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.ln_1 = nn.LayerNorm(d_model)
        self.mul_head_fft = MultiHeadFourier(d_model, n_heads)
        self.ln_2 = nn.LayerNorm(d_model)
        self.pos_ffn = MLP(d_model)

    def forward(self, x):
        """Run both residual sub-layers.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_model].

        Returns:
            Tensor of shape [batch_size, seq_len, d_model].
        """
        # Each sub-layer sees a normalised copy and is added back to the stream.
        mixed = x + self.mul_head_fft(self.ln_1(x))
        # mixed: [batch_size, seq_len, d_model]
        out = mixed + self.pos_ffn(self.ln_2(mixed))
        # out: [batch_size, seq_len, d_model]
        return out

class Transfourier(nn.Module):
    """Language model made of a token embedding, a stack of ``Block`` layers,
    a final LayerNorm and an output projection tied to the embedding.

    Args:
        d_model: Channel width of the residual stream.
        n_layers: Number of ``Block`` layers.
        n_heads: Number of heads per block; must divide ``d_model``.
        vocab_size: Number of token ids.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, d_model, n_layers, n_heads, vocab_size, **kwargs):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.wte = nn.Embedding(self.vocab_size, self.d_model)
        self.h = nn.ModuleList([Block(self.d_model, self.n_heads) for _ in range(self.n_layers)])
        self.ln_f = nn.LayerNorm(self.d_model)
        self.lm_head = nn.Linear(in_features=self.d_model, out_features=self.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's Parameter.
        self.wte.weight = self.lm_head.weight
        self.apply(self._init_weights)

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: Token ids of shape [batch_size, seq_len].
            targets: Optional token ids of shape [batch_size, seq_len].

        Returns:
            Tuple ``(logits, loss)`` where ``logits`` has shape
            [batch_size, seq_len, vocab_size] and ``loss`` is a scalar tensor,
            or ``None`` when ``targets`` is not given.
        """
        # x: [batch_size, seq_len]
        x = self.wte(x)
        # x: [batch_size, seq_len, d_model]
        for layer in self.h:
            x = layer(x)
            # x: [batch_size, seq_len, d_model]
        x = self.ln_f(x)
        # x: [batch_size, seq_len, d_model]
        logits = self.lm_head(x)
        # logits: [batch_size, seq_len, vocab_size]
        loss = None
        if targets is not None:
            # reshape() rather than view(): targets are frequently a
            # non-contiguous slice (e.g. tokens[:, 1:]), on which view() raises.
            loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))
        return logits, loss

    def _init_weights(self, module):
        """Initialise one sub-module; used as the callback for ``self.apply``.

        Linear and Conv1d weights are drawn from N(0, 0.02^2); for modules
        carrying a ``NEED_SCALE_INIT`` attribute the std is additionally
        multiplied by ``(2 * n_layers) ** -0.5``.  Their biases, when
        present, are zeroed.  Embedding weights are drawn from N(0, 0.02^2).
        Other module types are left untouched.
        """
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            std = 0.02
            if hasattr(module, 'NEED_SCALE_INIT'):
                std *= (2 * self.n_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

if __name__ == '__main__':
    # Smoke test: build a model, print its parameter breakdown and run a few
    # optimisation steps on a two-sentence toy corpus.

    def print_model(model):
        """Print the module tree and the size of every named parameter."""
        print("model_size: ", sum(p.numel() for p in model.parameters() if p.requires_grad))
        print("Model Architecture:")
        print(model)
        print("\nParameter Count for Each Part:")
        total_params = 0
        for name, param in model.named_parameters():
            num_params = param.numel()
            total_params += num_params
            print(f"{name}: {param.shape} -> {num_params} parameters")
        print(f"\nTotal Parameters: {total_params}")

    vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'S': 5, 'E': 6}
    # Each entry is [input sentence, target sentence]; the target is the input
    # shifted left by one token and padded with 'P'.
    data_root = [['S i want a beer E', 'i want a beer E P'], ['S want a beer i E', 'want a beer i E P']]
    vocab_size = len(vocab)

    # Available model sizes; the head width is d_model // n_heads.
    model_configs = {
        'mini': {'n_layers': 12, 'd_model': 512, 'n_heads': 8},
        'small': {'n_layers': 12, 'd_model': 768, 'n_heads': 12},
        'medium': {'n_layers': 24, 'd_model': 1024, 'n_heads': 16},
        'large': {'n_layers': 24, 'd_model': 1536, 'n_heads': 16},
    }
    size = 'mini'
    # size = 'small'
    # size = 'medium'
    # size = 'large'
    # A lookup fails immediately on an unknown size instead of leaving the
    # hyper-parameters undefined.
    config = model_configs[size]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = Transfourier(config['d_model'], config['n_layers'], config['n_heads'], vocab_size).to(device)

    print_model(model)

    input_batch = [[vocab[token] for token in source.split()] for source, _ in data_root]
    target_batch = [[vocab[token] for token in target.split()] for _, target in data_root]
    inputs = torch.LongTensor(input_batch).to(device)
    targets = torch.LongTensor(target_batch).to(device)

    # The model computes its own cross-entropy loss when given targets.
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(20):
        optimizer.zero_grad()
        _, loss = model(inputs, targets=targets)
        print(f"Epoch: {epoch + 1:04d} cost = {loss.item():.6f}")
        loss.backward()
        optimizer.step()