import math
import os
import numpy as np
import lightning.pytorch as pl
import torch
import wandb

from einops import rearrange
from collections import namedtuple
from lightning.pytorch.utilities.rank_zero import rank_zero_info
from torch import nn
from torch.nn import functional


class Projection(nn.Module):
    """Small MLP lifting a single scalar feature to ``args.hid_dim`` channels.

    ``input`` maps the trailing size-1 feature to ``proj_dim``. Then come
    ``args.layer`` rounds of (LayerNorm -> bias-free Linear -> SiLU), and
    ``output`` finally maps to ``hid_dim``.

    Args:
        args: config object providing ``proj_dim``, ``hid_dim`` and ``layer``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.input = torch.nn.Linear(1, args.proj_dim, bias=False)
        self.output = torch.nn.Linear(args.proj_dim, args.hid_dim, bias=False)
        # One LayerNorm whose parameters are shared by every hidden layer.
        self.norm = torch.nn.LayerNorm(args.proj_dim)
        self.blocks = torch.nn.ModuleList(
            [
                torch.nn.Linear(args.proj_dim, args.proj_dim, bias=False)
                for _ in range(args.layer)
            ]
        )

    def forward(self, pos):
        """Apply the MLP; ``pos`` must have a trailing dimension of size 1."""
        pos = self.input(pos)
        for layer in self.blocks:
            pos = torch.nn.functional.silu(layer(self.norm(pos)))
        return self.output(pos)


class Conv(nn.Module):
    """Input-gated, decay-shaped causal long convolution over the sequence axis.

    A per-position profile is built from a learnable scalar base
    (``ex_decay``) raised to the position index, plus a learned ``Projection``
    of it. The profile is gated by ``self.key(x)`` and causally smoothed by the
    learnable ``kernel``, normalised by how many kernel taps are in range at
    each position. It is then blended with the ungated profile
    (``rate1``/``rate2``) and used as a per-sample kernel convolved with ``x``.

    ``x`` is expected as (batch, seq, hid_dim): ``self.key`` normalises the
    last axis with ``LayerNorm(hid_dim)`` and the sequence axis is -2.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Learnable decay base; its initial value depends on layer depth.
        self.ex_decay = torch.nn.Parameter(torch.tensor([
            (args.decay_rate * args.r1 + args.r2 * (layer_id / args.n_layer))
        ]))
        self.projection = Projection(args)
        self.key = torch.nn.Sequential(
            torch.nn.LayerNorm(args.hid_dim),
            torch.nn.Linear(args.hid_dim, args.hid_dim),
            torch.nn.Sigmoid(),
        )
        self.kernel = torch.nn.Parameter(torch.ones(1, args.windows, args.hid_dim))
        # Constant all-ones kernel used to count the in-range taps per position.
        # As a non-persistent buffer it follows the module across devices
        # without adding a key to state_dict.
        self.register_buffer(
            "one", torch.ones(1, args.windows, args.hid_dim), persistent=False
        )

    def forward(self, x):
        """Convolve ``x`` with its input-dependent causal kernel; output matches ``x``'s shape."""
        args = self.args
        # Use the actual sequence length (axis -2) rather than args.ctx_len so
        # inputs shorter than ctx_len work. Every per-position quantity below
        # depends only on its own index, so results for length == ctx_len are
        # unchanged.
        seq_len = x.shape[-2]
        steps = torch.arange(seq_len, device=x.device).unsqueeze(-1)
        pos = steps * 1.0

        ex_decay = self.ex_decay ** steps
        ex_decay = ex_decay + self.projection(ex_decay * pos)
        ex_decay = torch.nn.functional.silu(ex_decay).squeeze(0)
        decay = ex_decay

        ex_decay = ex_decay * self.key(x)
        ones = torch.ones_like(ex_decay)
        smoothed = self.compute_batch(ex_decay, self.kernel)
        # Per-position count of kernel taps that fall inside the sequence.
        counts = self.compute_batch(ones, self.one.to(x.device))
        ex_decay = decay * args.rate1 + args.rate2 * smoothed / counts

        return self.compute_batch(x, ex_decay)

    def compute_batch(self, x, a, dim=-2):
        """Causal linear convolution of ``x`` with kernel ``a`` along ``dim`` via FFT.

        For each position ``t < n`` (``n = x.shape[dim]``) the result is
        ``sum_{k<=t} a[k] * x[t - k]``; the other axes broadcast.

        Returns:
            Tensor with length ``n`` along ``dim``.
        """
        n = x.shape[dim]
        # Taps at index >= n cannot contribute to the first n causal outputs,
        # but with a 2n-point FFT they can wrap around and alias into them.
        if a.shape[dim] > n:
            a = a.narrow(dim, 0, n)
        # Zero-padding to 2n makes circular convolution equal linear
        # convolution over the first n outputs.
        y = torch.fft.rfft(x, 2 * n, dim=dim)
        v = torch.fft.rfft(a, 2 * n, dim=dim)
        return torch.fft.irfft(v * y, 2 * n, dim=dim).narrow(dim, 0, n)


class ChannelMixing(nn.Module):
    """Gated feed-forward sub-layer acting on each position independently.

    Two parallel projections widen ``n_embd`` to ``n_embd * c``. The
    SiLU-activated ``value`` branch is multiplied by the linear ``receptance``
    branch, and ``output`` projects the product back to ``n_embd``.
    """

    def __init__(self, args):
        super().__init__()
        width = args.n_embd
        inner = width * args.c
        self.value = torch.nn.Linear(width, inner)
        self.receptance = torch.nn.Linear(width, inner)
        self.output = torch.nn.Linear(inner, width)

    def forward(self, x):
        gated = torch.nn.functional.silu(self.value(x)) * self.receptance(x)
        return self.output(gated)


class TokenMixing(nn.Module):
    """Sequence-mixing sub-layer built around the ``Conv`` operator.

    ``x`` is projected from ``n_embd`` to ``hid_dim``, mixed along the
    sequence by ``Conv`` and passed through SiLU. The result is gated by a
    separate linear ``receptance`` projection of ``x``, normalised, and
    projected back to ``n_embd``.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        # Projections between the model width and the mixer's hidden width.
        self.input = torch.nn.Linear(args.n_embd, args.hid_dim)
        self.output = torch.nn.Linear(args.hid_dim, args.n_embd)
        self.receptance = torch.nn.Linear(args.n_embd, args.hid_dim)
        # NOTE(review): SimpleRMSNorm is neither defined nor imported in the
        # visible part of this file — confirm it is provided elsewhere.
        self.norm = SimpleRMSNorm(args.hid_dim)
        self.conv = Conv(args, layer_id)

    def forward(self, x):
        mixed = self.conv(self.input(x))
        gate = self.receptance(x)
        output = torch.nn.functional.silu(mixed) * gate
        return self.output(self.norm(output))


class Block(nn.Module):
    """Pre-norm residual layer: token mixing, then channel mixing.

    Each sub-layer receives a LayerNorm-ed copy of the running activations,
    and its output is added back onto them.
    """

    def __init__(self, args, layer_id):
        super().__init__()
        self.layer_id = layer_id
        width = args.n_embd
        self.ln1, self.ln2 = torch.nn.LayerNorm(width), torch.nn.LayerNorm(width)
        self.token_mixing = TokenMixing(args, layer_id)
        self.channel_mixing = ChannelMixing(args)

    def forward(self, x):
        # x: b, t, e
        hidden = x + self.token_mixing(self.ln1(x))
        return hidden + self.channel_mixing(self.ln2(hidden))


class MyModule(pl.LightningModule):
    """Language model: token embedding -> stack of ``Block`` layers -> vocabulary head.

    Only ``args.use_model == "conv"`` is implemented. Any other value raises
    ``ValueError`` at construction, instead of leaving ``emb``/``head``
    undefined.

    Args:
        args: config object. It is read for ``use_model``, ``ctx_len``,
            ``n_embd``, ``n_layer`` and the optimizer settings, and is
            passed through to ``Block``.

    Raises:
        ValueError: if ``args.use_model`` is not ``"conv"``.
    """

    # Vocabulary size of the embedding/head and the pad id handed to the
    # tokenizer loaded from src/tokenizer.json.
    # NOTE(review): these look like GPT-2 BPE values — confirm against the tokenizer.
    VOCAB_SIZE = 50257
    PAD_ID = 50256

    def __init__(self, args):
        super().__init__()
        self.args = args
        if args.use_model != "conv":
            raise ValueError(
                f"Unsupported use_model {args.use_model!r}; only 'conv' is implemented"
            )

        from tokenizers import Tokenizer
        # Pads/truncates every encoding to exactly ctx_len + 1 tokens
        # (presumably so inputs and next-token targets can be shifted — confirm).
        self.pad_tokenizer = Tokenizer.from_file("src/tokenizer.json")
        self.pad_tokenizer.enable_padding(pad_id=self.PAD_ID, length=int(args.ctx_len) + 1)
        self.pad_tokenizer.enable_truncation(int(args.ctx_len) + 1)

        self.emb = torch.nn.Embedding(self.VOCAB_SIZE, args.n_embd)
        # Token mixing, channel mixing and layer norms, one Block per layer.
        self.blocks = torch.nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
        # Final layer norm before the output projection.
        self.ln_out = torch.nn.LayerNorm(args.n_embd)
        self.head = torch.nn.Linear(args.n_embd, self.VOCAB_SIZE, bias=False)

    def configure_optimizers(self):
        """Return Adam over all parameters, configured from ``args``."""
        args = self.args
        optimizer = torch.optim.Adam(self.parameters(),
                                     lr=args.lr_init,
                                     eps=args.adam_eps,
                                     betas=(args.beta1, args.beta2),
                                     weight_decay=args.w)
        return optimizer

    def forward(self, x):
        """Map token ids ``x`` to logits over ``VOCAB_SIZE`` classes."""
        x = self.emb(x)
        for block in self.blocks:
            x = block(x)
        x = self.ln_out(x)
        return self.head(x)



