# Copyright (c) 2024, Sukjun Hwang, Aakash Lahoti, Ratish Puduppully, Tri Dao, Albert Gu.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class Cauchy(nn.Module):
    """Sequence mixer whose mixing weights form a Cauchy-style kernel.

    For every head the weight between output position ``l`` and input
    position ``t`` is ``sum_d 1 / (q[l, d] + k[t, d] + tol)``, where ``q`` and
    ``k`` are made positive with ``exp``. The weighted sum of ``v`` over ``t``
    is scaled by ``std_dev`` and added to the unmodified input (residual).

    Two variants are supported:

    * data-dependent (``is_data_dependent=True``): ``q`` and ``k`` are passed
      to :meth:`forward` and a learnable scalar ``bias`` is added after the
      ``exp``.
    * data-independent (``is_data_dependent=False``): ``q`` and ``k`` are
      learnable parameters of shape ``(max_seq_len, nheads, qk_dim)``; the
      input sequence length must then equal ``max_seq_len``.

    Args:
        is_data_dependent: Selects the variant described above.
        d_model: Model width; the mixer operates on ``expand * d_model``
            channels.
        qk_dim: Per-head size of the ``q``/``k`` feature axis.
        max_seq_len: Sequence length used to size the learnable ``q``/``k``
            matrices and to compute the output scale
            ``1 / sqrt(max_seq_len * qk_dim)``. Required for the
            data-independent variant. If omitted for the data-dependent
            variant, the scale is computed from the actual sequence length
            of ``v`` on every forward call.
        expand: Channel expansion factor (``d_inner = expand * d_model``).
        headdim: Channels per head; must divide ``d_inner``.
        device: Device for the created parameters.
        dtype: Dtype for the created parameters.

    Raises:
        ValueError: If ``max_seq_len`` is ``None`` for the data-independent
            variant.
    """

    def __init__(
        self,
        is_data_dependent,
        d_model,
        qk_dim,
        max_seq_len=None,   # max_seq_len is necessary for data-independent version.
        expand=2,
        headdim=128,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.is_data_dependent = is_data_dependent
        self.d_model = d_model
        self.qk_dim = qk_dim
        self.max_seq_len = max_seq_len
        self.expand = expand
        self.d_inner = self.expand * self.d_model
        self.headdim = headdim
        assert self.d_inner % self.headdim == 0, "headdim must divide expand * d_model"
        self.nheads = self.d_inner // self.headdim
        self.d_state = self.nheads * qk_dim

        # Added to every kernel denominator before taking the reciprocal.
        self.tol = 1e-8

        if self.max_seq_len is None:
            if not self.is_data_dependent:
                raise ValueError(
                    "max_seq_len is required when is_data_dependent is False"
                )
            # No fixed length known: forward() derives the scale from the input.
            self.std_dev = None
        else:
            self.std_dev = 1 / np.sqrt(self.max_seq_len * self.qk_dim)

        if self.is_data_dependent:
            # Created with factory_kwargs so it honours the requested
            # device/dtype like the other parameters.
            self.bias = nn.Parameter(torch.tensor(0.5, **factory_kwargs))
        else:
            self.q_matrix = nn.Parameter(
                torch.empty(self.max_seq_len, self.nheads, self.qk_dim, **factory_kwargs))
            self.k_matrix = nn.Parameter(
                torch.empty(self.max_seq_len, self.nheads, self.qk_dim, **factory_kwargs))
            nn.init.xavier_normal_(self.q_matrix)
            nn.init.xavier_normal_(self.k_matrix)

    def _split_heads(self, x):
        """Reshape ``(b, l, nheads * c)`` into ``(b, l, nheads, c)``."""
        batch, seqlen, width = x.shape
        return x.reshape(batch, seqlen, self.nheads, width // self.nheads)

    def forward(self, v, q=None, k=None):
        """Mix ``v`` along the sequence axis and add the residual.

        Args:
            v: Tensor of shape ``(batch, seqlen, nheads * headdim)``.
            q: Tensor of shape ``(batch, seqlen, nheads * qk_dim)``; required
                for the data-dependent variant, ignored otherwise.
            k: Same shape and role as ``q``.

        Returns:
            Tensor with the same shape as ``v``.

        Raises:
            ValueError: If the module is data-dependent and ``q`` or ``k`` is
                ``None``.
        """
        residual = v
        seqlen = v.shape[1]
        v = self._split_heads(v)  # (b, t, n, h)

        if self.is_data_dependent:
            if q is None or k is None:
                raise ValueError(
                    "q and k must be provided when is_data_dependent is True"
                )
            # q: (b, n, l, 1, d), k: (b, n, 1, t, d) so that q + k broadcasts
            # to every (l, t) pair.
            q = self._split_heads(q).permute(0, 2, 1, 3).unsqueeze(3)
            k = self._split_heads(k).permute(0, 2, 1, 3).unsqueeze(2)
            q = torch.exp(q) + self.bias
            k = torch.exp(k) + self.bias

            inv_cauchy_matrix = q + k + self.tol
            # Sum the reciprocals over the qk feature axis -> (b, n, l, t).
            cauchy_matrix = torch.sum(1 / inv_cauchy_matrix, dim=-1)

            output = torch.einsum('b t n h, b n l t -> b l n h', v, cauchy_matrix)
        else:
            # q, k: (max_seq_len, nheads, qk_dim)
            q = torch.exp(self.q_matrix)
            k = torch.exp(self.k_matrix)

            # (l, 1, n, d) + (1, t, n, d) -> (l, t, n, d)
            inv_cauchy_matrix = (q.unsqueeze(1) + k.unsqueeze(0)) + self.tol
            # Sum the reciprocals over the qk feature axis -> (l, t, n).
            cauchy_matrix = torch.sum(1 / inv_cauchy_matrix, dim=-1)

            output = torch.einsum('b t n h, l t n -> b l n h', v, cauchy_matrix)

        if self.std_dev is not None:
            scale = self.std_dev
        else:
            # max_seq_len was not given: use the current sequence length.
            scale = (seqlen * self.qk_dim) ** -0.5

        output = scale * output
        # Merge heads back: (b, l, n, h) -> (b, l, n * h), then add residual.
        output = output.flatten(2) + residual

        return output
