import math
import torch
import torch.nn as nn
from einops import rearrange
from flashlla.ops.fwd import fwd_kernel
from flashlla.ops.bwd import bwd_kernel


class LLAFunction(torch.autograd.Function):
    """Autograd wrapper that routes forward/backward through the flashlla kernels.

    ``forward`` calls ``fwd_kernel`` and ``backward`` calls ``bwd_kernel``.
    Gradients are produced for ``q``, ``k`` and ``v`` only; every other
    argument (including ``ridge_lambda``) receives ``None``.
    """

    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        ridge_lambda: torch.Tensor,
        qk_scale: float,
        delta_eps: float,
        cg_atol: float,
        cg_rtol: float,
        cg_max_iters: int,
        cg_use_preconditioner: bool,
    ) -> torch.Tensor:
        """Run ``fwd_kernel`` and stash what ``backward`` needs.

        Args:
            ctx: Autograd context object.
            q, k, v: Tensors handed to the kernel unchanged.
            ridge_lambda: Tensor handed to both kernels unchanged.
            qk_scale: Scalar forwarded to both kernels.
            delta_eps: Scalar forwarded to the forward kernel only.
            cg_atol, cg_rtol, cg_max_iters, cg_use_preconditioner:
                Solver settings forwarded to both kernels.

        Returns:
            The first output of ``fwd_kernel``.
        """
        o, r, d, m = fwd_kernel(
            q, k, v,
            ridge_lambda,
            qk_scale,
            delta_eps,
            cg_atol,
            cg_rtol,
            cg_max_iters,
            cg_use_preconditioner,
        )
        # Only real tensors go through save_for_backward; r, d, m are the
        # auxiliary kernel outputs that bwd_kernel consumes.
        ctx.save_for_backward(q, k, v, r, d, m, ridge_lambda)
        # Non-tensor settings live on ctx as plain Python values. Wrapping
        # them in tensors of q.dtype would round them under fp16/bf16 (small
        # tolerances can underflow) and reading them back with .item() would
        # force a device synchronisation in backward.
        ctx.qk_scale = float(qk_scale)
        ctx.cg_atol = float(cg_atol)
        ctx.cg_rtol = float(cg_rtol)
        ctx.cg_max_iters = int(cg_max_iters)
        ctx.cg_use_preconditioner = bool(cg_use_preconditioner)
        return o

    @staticmethod
    def backward(ctx, grad_o):
        """Run ``bwd_kernel`` and return one gradient slot per forward input.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_o: Gradient with respect to the forward output.

        Returns:
            ``(grad_q, grad_k, grad_v)`` followed by seven ``None`` entries
            for the non-differentiated arguments.
        """
        q, k, v, r, d, m, ridge_lambda = ctx.saved_tensors
        grad_q, grad_k, grad_v = bwd_kernel(
            q, k, v,
            r, d, m, grad_o,
            ridge_lambda,
            ctx.qk_scale,
            ctx.cg_atol,
            ctx.cg_rtol,
            ctx.cg_max_iters,
            ctx.cg_use_preconditioner,
        )
        # forward takes 10 inputs after ctx, so 10 gradient slots are required.
        return grad_q, grad_k, grad_v, None, None, None, None, None, None, None


class LLA(nn.Module):
    """Multi-head layer that projects the input and applies ``LLAFunction``.

    The input is projected to queries, keys and values with bias-free linear
    layers, split into ``nhead`` heads, passed through ``LLAFunction`` and
    projected back to ``dim`` features.

    Args:
        dim: Feature size of the input and output.
        nhead: Number of heads; each head has ``dim // nhead`` features.
        ridge_lambda: Tensor forwarded unchanged to ``LLAFunction``. It is
            stored as a plain attribute, so ``Module.to()`` does not move it.
        delta_eps: Scalar forwarded to ``LLAFunction``.
        cg_atol: Solver setting forwarded to ``LLAFunction``.
        cg_rtol: Solver setting forwarded to ``LLAFunction``.
        cg_max_iters: Solver setting forwarded to ``LLAFunction``.
        cg_use_preconditioner: Solver setting forwarded to ``LLAFunction``.

    NOTE(review): the default values of the keyword-only settings are
    placeholders chosen so the module is callable; confirm them against the
    kernel's recommended configuration.
    """

    def __init__(
        self,
        dim: int,
        nhead: int,
        ridge_lambda: torch.Tensor,
        *,
        delta_eps: float = 1e-5,
        cg_atol: float = 1e-5,
        cg_rtol: float = 1e-5,
        cg_max_iters: int = 32,
        cg_use_preconditioner: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.nhead = nhead
        self.head_dim = dim // nhead
        self.ridge_lambda = ridge_lambda
        self.q_proj = nn.Linear(dim, nhead * self.head_dim, bias=False)
        self.k_proj = nn.Linear(dim, nhead * self.head_dim, bias=False)
        self.v_proj = nn.Linear(dim, nhead * self.head_dim, bias=False)
        self.o_proj = nn.Linear(nhead * self.head_dim, dim, bias=False)
        self.qk_scale = 1.0 / math.sqrt(self.head_dim)
        # LLAFunction.forward requires these five values in addition to
        # q, k, v, ridge_lambda and qk_scale.
        self.delta_eps = delta_eps
        self.cg_atol = cg_atol
        self.cg_rtol = cg_rtol
        self.cg_max_iters = cg_max_iters
        self.cg_use_preconditioner = cg_use_preconditioner

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` of shape ``(batch, seq, dim)``.

        Returns:
            The output of ``o_proj``, with ``dim`` features in the last axis.
        """
        # einops axis names must be space-separated; heads are folded into
        # the leading axis, giving (batch * nhead, seq, head_dim).
        q, k, v = (
            rearrange(t, "b s (h d) -> (b h) s d", h=self.nhead)
            for t in (self.q_proj(x), self.k_proj(x), self.v_proj(x))
        )
        o = LLAFunction.apply(
            q, k, v,
            self.ridge_lambda,
            self.qk_scale,
            self.delta_eps,
            self.cg_atol,
            self.cg_rtol,
            self.cg_max_iters,
            self.cg_use_preconditioner,
        )
        # Undo the head folding before the output projection.
        o = rearrange(o, "(b h) s d -> b s (h d)", h=self.nhead)
        return self.o_proj(o)