import math
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal as Orthogonal
from timm.models.vision_transformer import Mlp, Attention
from einops import rearrange, repeat
import torch.nn.functional as F
from typing import Final
from timm.layers import use_fused_attn
from torch import Tensor
from torch.nn import init
from collections.abc import Iterable
from itertools import repeat
from packaging import version
    

class SVDLinearOptimized(nn.Module):
    """Linear layer built as a sum of up to three low-rank branches.

    For every branch ``b`` in ``("gene", "other_1", "other_2")`` the module
    owns two bias-free linear maps, ``u_b: in_features -> rank_b`` and
    ``v_b: rank_b -> out_features``. The output is::

        v_gene(u_gene(x))
        + v_other_1(u_other_1(x) * gate1)
        + v_other_2(u_other_2(x) * gate2)
        + bias

    where branches whose rank is not positive are skipped.

    One factor of every pair carries an orthogonal (Cayley map)
    parametrization: the ``u`` factor when ``out_features > in_features``,
    otherwise the ``v`` factor. The unconstrained tensor behind that
    parametrization is zeroed at initialisation, which makes the Cayley map
    return a (truncated) identity matrix.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        gene_size: Rank of the ungated ``gene`` branch.
        other_size_1: Rank of the branch scaled by ``gate1``.
        other_size_2: Rank of the branch scaled by ``gate2``.
        bias: If ``True`` a learnable bias of shape ``(out_features,)`` is added.
        device: Forwarded to every parameter constructor.
        dtype: Forwarded to every parameter constructor.
        init: ``'xavier'`` selects Xavier-uniform initialisation of the
            non-orthogonal factors; any other value selects Kaiming-uniform
            with ``a=sqrt(5)``.
    """

    # Branch suffixes, in the order the sub-modules are created and summed.
    _BRANCHES = ('gene', 'other_1', 'other_2')

    def __init__(self, in_features: int, out_features: int, gene_size: int, other_size_1: int, other_size_2: int, bias: bool = True, device=None, dtype=None, init='kaiming') -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.gene_size = gene_size
        self.other_size_1 = other_size_1
        self.other_size_2 = other_size_2

        self.init = init

        # The orthogonal constraint goes on the input-side factor when the
        # layer expands, and on the output-side factor otherwise.
        orthogonal_u = out_features > in_features
        ranks = (gene_size, other_size_1, other_size_2)
        for tag, rank in zip(self._BRANCHES, ranks):
            # Creation order (u before v, branch by branch) is kept stable so
            # the RNG stream and the state_dict key order do not change.
            u = self._make_factor(in_features, rank, orthogonal_u, factory_kwargs)
            v = self._make_factor(rank, out_features, not orthogonal_u, factory_kwargs)
            setattr(self, f'u_{tag}', u)
            setattr(self, f'v_{tag}', v)

        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    @staticmethod
    def _make_factor(fan_in: int, fan_out: int, orthogonal: bool, factory_kwargs: dict) -> nn.Module:
        """Create one bias-free linear factor, optionally Cayley-orthogonal."""
        layer = nn.Linear(fan_in, fan_out, bias=False, **factory_kwargs)
        if orthogonal:
            layer = Orthogonal(layer, orthogonal_map='cayley', use_trivialization=False)
        return layer

    def reset_parameters(self) -> None:
        """(Re)initialise every parameter of the layer.

        The orthogonal factors are reset by zeroing their unconstrained
        tensors, the free factors are drawn with the scheme selected by
        ``self.init`` and the bias is drawn uniformly from
        ``[-1/sqrt(in_features), 1/sqrt(in_features)]``.
        """
        if self.out_features > self.in_features:
            orth_prefix, free_prefix = 'u', 'v'
        else:
            orth_prefix, free_prefix = 'v', 'u'

        # Zeroing does not consume random numbers, so doing it here (rather
        # than only in __init__) keeps the initial state identical while
        # making a later reset_parameters() call a complete reset.
        with torch.no_grad():
            for tag in self._BRANCHES:
                factor = getattr(self, f'{orth_prefix}_{tag}')
                factor.parametrizations.weight.original.zero_()

        for tag in self._BRANCHES:
            weight = getattr(self, f'{free_prefix}_{tag}').weight
            if self.init == 'xavier':
                init.xavier_uniform_(weight)
            else:
                init.kaiming_uniform_(weight, a=math.sqrt(5))

        if self.bias is not None:
            # fan_in of an (out_features, in_features) matrix is in_features;
            # no need to allocate a dummy tensor to read it.
            fan_in = self.in_features
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor, gate1=None, gate2=None) -> Tensor:
        """Apply the layer.

        Args:
            input: Tensor whose last dimension is ``in_features``.
            gate1: Multiplier applied to the ``other_1`` hidden activation
                (must broadcast against it). ``None`` leaves the branch
                ungated.
            gate2: Same as ``gate1`` for the ``other_2`` branch.

        Returns:
            Tensor whose last dimension is ``out_features``.

        Raises:
            ValueError: If all three branch ranks are non-positive.
        """
        branches = (
            (self.gene_size, self.u_gene, self.v_gene, None),
            (self.other_size_1, self.u_other_1, self.v_other_1, gate1),
            (self.other_size_2, self.u_other_2, self.v_other_2, gate2),
        )

        out = None
        for rank, u, v, gate in branches:
            if rank <= 0:
                continue
            hidden = u(input)
            if gate is not None:
                hidden = hidden * gate
            term = v(hidden)
            out = term if out is None else out + term

        if out is None:
            raise ValueError("At least one of gene_size, other_size_1, or other_size_2 must be > 0.")

        if self.bias is not None:
            out = out + self.bias

        return out
        

class SVDMultiheadAttention(nn.Module):
    """Multi-head attention whose q/k/v/output projections are ``SVDLinearOptimized``.

    Inputs are sequence-first: ``xq`` is unpacked as ``(N, B, C)`` and the
    projected tensors are viewed as ``(length, B * num_heads, head_dim)``
    before attention.

    Args:
        d_model: Embedding size ``C``; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        gene_size: Forwarded to every ``SVDLinearOptimized`` projection.
        other_size_1: Forwarded to every ``SVDLinearOptimized`` projection.
        other_size_2: Forwarded to every ``SVDLinearOptimized`` projection.
        dropout: Attention dropout probability, applied in training mode only.
    """

    def __init__(self, d_model, num_heads, gene_size, other_size_1, other_size_2, dropout=0.):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        rank_kwargs = dict(gene_size=gene_size, other_size_1=other_size_1, other_size_2=other_size_2)

        # q, k, v use Xavier initialisation; the output projection keeps the
        # SVDLinearOptimized default.
        self.q = SVDLinearOptimized(d_model, d_model, bias=True, init='xavier', **rank_kwargs)
        self.k = SVDLinearOptimized(d_model, d_model, bias=True, init='xavier', **rank_kwargs)
        self.v = SVDLinearOptimized(d_model, d_model, bias=True, init='xavier', **rank_kwargs)

        self.proj = SVDLinearOptimized(d_model, d_model, **rank_kwargs)

        self.dropout = dropout

    @staticmethod
    def _additive_mask(mask, dtype):
        """Turn a boolean mask (True = blocked) into a 0 / -inf additive mask.

        Non-boolean masks are assumed to be additive already and are only cast.
        """
        if mask.dtype == torch.bool:
            return torch.zeros_like(mask, dtype=dtype).masked_fill_(mask, float("-inf"))
        return mask.to(dtype)

    def _split_heads(self, x, length, batch):
        """Reshape ``(length, batch, C)`` to ``(batch * heads, length, head_dim)``."""
        return x.contiguous().view(length, batch * self.num_heads, self.head_dim).transpose(0, 1)

    def forward(self, xq, xk, xv, mask=None, src_mask=None, gate1=None, gate2=None):
        """Compute attention.

        Args:
            xq: Query input of shape ``(N, B, C)``.
            xk: Key input of shape ``(S, B, C)``.
            xv: Value input of shape ``(S, B, C)``.
            mask: Optional key-padding mask with ``B * S`` elements (viewed as
                ``(B, S)``); boolean ``True`` entries are excluded.
            src_mask: Optional 2-D attention mask shared by every batch entry
                and head, expected to be ``(N, S)``; boolean ``True`` entries
                are excluded.
            gate1: Forwarded unchanged to every projection.
            gate2: Forwarded unchanged to every projection.

        Returns:
            Tuple ``(out, None)`` where ``out`` has shape ``(N, B, C)``.
        """
        N, B, C = xq.shape
        # Keys/values may have a different length than the queries.
        S = xk.shape[0]

        H = self.num_heads

        q = self._split_heads(self.q(xq, gate1, gate2), N, B)
        k = self._split_heads(self.k(xk, gate1, gate2), S, B)
        v = self._split_heads(self.v(xv, gate1, gate2), S, B)

        attn_mask = None
        if src_mask is not None:
            # Convert before expanding: filling a bool tensor with -inf and
            # casting afterwards cannot produce -inf entries.
            attn_mask = self._additive_mask(src_mask, q.dtype).unsqueeze(0).expand(B * H, -1, -1)

        if mask is not None:
            # Head index varies fastest, matching the (B * H) layout of q/k/v.
            key_mask = mask.reshape(B, 1, 1, S).expand(-1, H, -1, -1).reshape(B * H, 1, S)
            key_mask = self._additive_mask(key_mask, q.dtype)
            # Combine with src_mask instead of discarding it.
            attn_mask = key_mask if attn_mask is None else attn_mask + key_mask

        # The functional API applies dropout whenever dropout_p > 0, so it has
        # to be switched off explicitly outside training.
        dropout_p = self.dropout if self.training else 0.0

        if version.parse(torch.__version__) >= version.parse("2.0.0"):
            out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)
        else:
            out, _ = F._scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p)
        out = out.transpose(0, 1).contiguous().view(N, B, C)

        out = self.proj(out, gate1, gate2)
        return out, None