import torch
import torch.nn as nn
from xformers.components.attention.feature_maps import SMReg
from typing import Optional, Literal
from scipy.fft import dct
import numpy as np
import math

class FeatureNet(nn.Module):
    """Per-head feature map applied to queries/keys before linear attention.

    The behaviour is selected by ``config.feature_type``:

    * ``"positive"``: learnable positive random features
      ``exp(omega . x - |x|^2 / 2 - shift / 2)`` combined with a learnable
      coefficient (multiplicative for ``coef_type == "standard"``, additive
      inside the exponent for ``coef_type == "exp"``).
    * ``"dijiang"``: softmax over a fixed DCT-based projection, scaled by
      learnable per-head, per-position coefficients.
    * ``"performer"``: delegates to xformers' ``SMReg`` feature map.
    * ``"fourier"``: not implemented.

    Args:
        config: Object exposing ``num_attention_heads``, ``hidden_size``,
            ``n_positions``, ``feature_type``, ``coef_type`` and
            ``num_features`` (indexable by layer).
        layer_idx: Index used to pick this layer's entry of
            ``config.num_features``.

    Raises:
        NotImplementedError: If ``feature_type`` is ``"fourier"``.
        ValueError: If ``feature_type`` or ``coef_type`` is unknown.
    """

    dijiang_proj: torch.Tensor

    def __init__(self, config, layer_idx):
        super().__init__()
        num_heads = config.num_attention_heads
        head_dim = config.hidden_size // config.num_attention_heads

        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_positions = config.n_positions
        self.feature_type = config.feature_type
        self.coef_type = config.coef_type
        self.num_features = config.num_features[layer_idx]

        if self.feature_type == "fourier":
            raise NotImplementedError

        elif self.feature_type == "positive":
            _omega = torch.randn(num_heads, self.num_features, head_dim)
            self.positive_omega = nn.Parameter(_omega)
            if self.coef_type == "standard":
                # Multiplicative coefficient: starts as the identity scale.
                _coef = torch.ones(num_heads, self.num_features)
            elif self.coef_type == "exp":
                # Additive coefficient inside the exponent: starts at zero.
                _coef = torch.zeros(num_heads, self.num_features)
            else:
                raise ValueError(f"Unknown coef type: {self.coef_type}")
            self.coef = nn.Parameter(_coef)

        elif self.feature_type == "dijiang":
            _uniform = torch.rand(self.head_dim)
            # torch.rand samples from [0, 1); an exact 0 would be mapped to
            # -inf by the inverse CDF and poison the projection buffer.
            _uniform = _uniform.clamp(min=torch.finfo(_uniform.dtype).eps)
            icdf_w = torch.distributions.Normal(0, 1).icdf(_uniform)
            _dct = dct(np.eye(self.head_dim, self.head_dim), axis=0, norm='ortho') # (Note: same as the published implementation)
            _dct = torch.from_numpy(_dct).float()
            _dijiang_proj = torch.einsum('fd,d->fd', _dct, icdf_w)
            self.register_buffer("dijiang_proj", _dijiang_proj)

            # (Note: These have O(L) learnable parameters, where L is the sequence length, which is same as the published implementation)
            # (Note: The parameters are common between the feature coordinates, which is same as the published implementation)
            # Per-head decay base in (0, 1), raised to every position index.
            _decay = 1 - torch.exp(torch.linspace(math.log(1/32), math.log(1/512), self.num_heads))
            _decay_pow = (_decay.view(self.num_heads, 1) ** torch.arange(self.num_positions)).float()
            # NOTE(review): the key coefficients grow geometrically with the
            # position and may overflow float32 when num_positions is large;
            # confirm against the intended context length.
            _coef_key = 1 / _decay_pow
            self._coef_query = nn.Parameter(_decay_pow)
            self._coef_key = nn.Parameter(_coef_key)

        elif self.feature_type == "performer":
            self.performer_feature_map = SMReg(dim_features=self.num_features, iter_before_redraw=None, normalize_inputs=False)

        else:
            raise ValueError(f"Unknown feature type: {self.feature_type}")

    def forward(self, X: torch.Tensor,
                shift_value: Optional[torch.Tensor] = None,
                korq: Optional[Literal["k", "q"]] = None):
        """Map ``X`` to its feature representation.

        Args:
            X: Input indexed as ``(b, h, l, d)`` by the einsums used in the
                ``"positive"`` and ``"dijiang"`` branches.
            shift_value: Optional shift subtracted (halved) inside the
                exponent of the ``"positive"`` features; treated as 0 when
                omitted. Ignored by the other feature types.
            korq: ``"q"`` or ``"k"``; selects the query or key coefficients
                of the ``"dijiang"`` features. Ignored by the other types.

        Returns:
            The feature tensor produced by the selected feature map.

        Raises:
            NotImplementedError: If ``feature_type`` is ``"fourier"``.
            ValueError: If ``feature_type`` is unknown, or ``korq`` is not
                ``"q"``/``"k"`` for ``"dijiang"`` features.
        """
        if self.feature_type == "fourier":
            # No Fourier feature map exists in this class; fail explicitly
            # instead of hitting a missing-attribute error.
            raise NotImplementedError("The 'fourier' feature type is not implemented")

        if self.feature_type == "positive":
            if shift_value is None:
                shift_value = torch.tensor(0., device=X.device)
            # Only this branch rescales the input by head_dim ** 0.25.
            return self._positive(X / (self.head_dim ** 0.25), shift_value)

        if self.feature_type == "dijiang":
            return self._dijiang(X, korq)

        if self.feature_type == "performer":
            return self.performer_feature_map(X)

        raise ValueError(f"Unknown feature type: {self.feature_type}")

    def _log_kernel(self, X: torch.Tensor) -> torch.Tensor:
        """Return ``omega . x - |x|^2 / 2`` for every head and feature."""
        omega_x = torch.einsum("bhld,hfd->bhlf", X, self.positive_omega)
        return omega_x - 0.5 * (X.norm(dim=-1, keepdim=True) ** 2)

    def _positive(self, X, shift_value: torch.Tensor):
        """Compute the positive random features of the (already scaled) ``X``.

        Raises:
            ValueError: If ``coef_type`` is neither ``"standard"`` nor ``"exp"``.
        """
        log_kernel = self._log_kernel(X)

        if self.coef_type == "standard":
            out = torch.exp(log_kernel - 0.5 * shift_value)
            # Per-head, per-feature scale, then 1 / sqrt(num_features).
            out = torch.einsum("hf,bhlf->bhlf", self.coef, out) / (self.num_features ** 0.5)
        elif self.coef_type == "exp":
            # The coefficient is added in log space; broadcast over positions.
            _coef = self.coef.unsqueeze(-2)
            out = torch.exp(log_kernel - 0.5 * shift_value + _coef)
        else:
            raise ValueError(f"Unknown coef type: {self.coef_type}")

        return out

    def _dijiang(self, X, korq: Literal["k", "q"]):
        """Compute the DCT-projected softmax features scaled per position.

        The coefficients are sliced to the sequence length of ``X``, so that
        length must not exceed ``num_positions``.

        Raises:
            ValueError: If ``korq`` is neither ``"q"`` nor ``"k"``.
        """
        # (Note: Original implementation also normalize the input in the direction of the features, not the tokens)
        proj = nn.functional.softmax(torch.einsum("bhld,fd->bhlf", X, self.dijiang_proj), dim=-1)
        seq_len = proj.size(-2)
        if korq == "q":
            position_coef = self._coef_query[:, :seq_len]
        elif korq == "k":
            position_coef = self._coef_key[:, :seq_len]
        else:
            raise ValueError(f"Unknown korq: {korq}")

        return torch.einsum('bhlf,hl->bhlf', proj, position_coef)

    def get_log_features(self, X, generator: bool=False):
        """Return the logarithm of the ``"positive"``/``"exp"`` features.

        No shift is applied here; ``X`` is rescaled by ``head_dim ** 0.25``
        exactly as in ``forward``.

        Args:
            X: Input indexed as ``(b, h, l, d)`` by the einsum.
            generator: When true, return a generator yielding one slice per
                feature (the last axis) instead of the full tensor.

        Raises:
            ValueError: If ``feature_type`` is not ``"positive"`` or
                ``coef_type`` is not ``"exp"``.
        """
        if self.feature_type != "positive":
            raise ValueError("This method is only available for the positive kernel")
        if self.coef_type != "exp":
            raise ValueError("This method is only available for the exponential coefficient")

        _X = X / (self.head_dim ** 0.25)
        log_features = self._log_kernel(_X) + self.coef.unsqueeze(-2)
        if generator:
            return (log_features[..., i] for i in range(self.num_features))
        return log_features