import torch
import torch.nn as nn
import os
import sys
import numpy as np
import math
import warnings
warnings.filterwarnings("ignore")
from typing import Optional, Union, Callable, Tuple
from torch import Tensor
from torch import _VF
from torch.utils.checkpoint import checkpoint
from torch.nn import Module, Linear, Dropout, LayerNorm
from torch.nn.init import constant_, xavier_uniform_
from torch.nn.parameter import Parameter
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from models.modules.MHA_SVD import AttentionLayer
import torch.nn.functional as F




class AGF_layer(Module):
    """Pre-LayerNorm layer combining two ``AttentionLayer`` modules with a feed-forward block.

    The (normalised) input is split along dimension 1 at ``single_eval_pos`` into
    a leading "context" part and a trailing "target" part.  The context attends
    to itself through ``self_SVD_attn`` and the target attends to the context
    through ``cross_SVD_attn``; both results are concatenated back along
    dimension 1, added to the residual stream and followed by a position-wise
    feed-forward block with its own residual connection.

    Both sub-blocks run under ``torch.utils.checkpoint.checkpoint``, so their
    intermediate activations are recomputed during the backward pass instead of
    being stored.

    Args:
        d_model: Size of the last input dimension (used by the LayerNorms,
            the feed-forward block and forwarded to ``AttentionLayer``).
        nhead: Forwarded to ``AttentionLayer`` as ``n_heads``.
        dim_feedforward: Hidden width of the feed-forward block.
        dropout: Dropout probability for the feed-forward block and both
            residual branches; also forwarded to ``AttentionLayer`` as
            ``attention_dropout``.
        activation: Feed-forward activation, either a callable or one of the
            strings ``"relu"`` / ``"gelu"``.
        layer_norm_eps: ``eps`` of both LayerNorm modules.
        bias: Whether the Linear and LayerNorm modules learn a bias.
        poly_type: Forwarded to both ``AttentionLayer`` modules.
        K, alpha, beta, fixI: Forwarded unchanged to both ``AttentionLayer``
            modules (presumably the polynomial order and its parameters --
            TODO confirm against ``models.modules.MHA_SVD``).
        device, dtype: Factory arguments for the Linear and LayerNorm modules.
            NOTE(review): they are not passed to ``AttentionLayer``.

    Raises:
        RuntimeError: If ``activation`` is a string other than ``"relu"`` or
            ``"gelu"``.
    """

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, bias: bool = True, poly_type: str = "jacobi",
                 K=5, alpha=2.0, beta=-1.0, fixI=True, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        # Both attention modules share one configuration.  ``poly_type`` is
        # forwarded from the caller rather than being pinned to "jacobi", so
        # the constructor argument actually takes effect.
        attn_kwargs = dict(d_model=d_model, n_heads=nhead, attention_dropout=dropout,
                           poly_type=poly_type, K=K, alpha=alpha, beta=beta, fixI=fixI)
        self.self_SVD_attn = AttentionLayer(**attn_kwargs)
        self.cross_SVD_attn = AttentionLayer(**attn_kwargs)

        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, bias=bias, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function; callables are stored as given.
        if isinstance(activation, str):
            activation = self._get_activation_fn(activation)
        self.activation = activation

    def __setstate__(self, state):
        """Restore pickled state, defaulting ``activation`` to ReLU when it is absent."""
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def _get_activation_fn(self, activation: str) -> Callable[[Tensor], Tensor]:
        """Map the legacy activation names ``"relu"`` / ``"gelu"`` to their functions.

        Raises:
            RuntimeError: For any other name.
        """
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError(f"activation should be relu/gelu, not {activation}")

    def forward(self, src: Tensor, single_eval_pos: int) -> Tuple[Tensor, Tensor]:
        """Apply the attention and feed-forward sub-blocks with residual connections.

        Args:
            src: 3-D input whose last dimension is ``d_model``; dimension 1 is
                the one being split (presumably ``(batch, seq, d_model)`` --
                TODO confirm with the caller).
            single_eval_pos: Index along dimension 1 separating context
                positions (``[:single_eval_pos]``) from target positions
                (``[single_eval_pos:]``).

        Returns:
            ``(x, ortho_loss)`` where ``x`` has the same shape as ``src`` and
            ``ortho_loss`` is the sum of the second return values of the two
            attention modules.
        """
        residual = src
        # pre-LN: normalise first, the residual branch keeps the raw input.
        normed = self.norm1(src)
        src_ctx = normed[:, :single_eval_pos, :]
        src_trg = normed[:, single_eval_pos:, :]

        # ``use_reentrant`` is passed explicitly as PyTorch requests; the
        # non-reentrant variant is the recommended one and also yields
        # parameter gradients when no checkpoint input requires grad.
        attn_out, ortho_loss = checkpoint(self._attn_block, src_ctx, src_trg, use_reentrant=False)
        x = residual + attn_out
        x = x + checkpoint(self._ff_block, self.norm2(x), use_reentrant=False)

        return x, ortho_loss

    # self-attention block
    def _attn_block(self, src_ctx: Tensor, src_trg: Tensor) -> Tuple[Tensor, Tensor]:
        """Run context self-attention and target-to-context attention.

        Returns:
            The dropout-regularised concatenation (along dimension 1) of the
            context and target outputs, and the summed auxiliary losses of
            both attention modules.
        """
        x_ctx, ortho_loss_self = self.self_SVD_attn(q=src_ctx, kv=src_ctx)
        # Targets query the context only; they never attend to each other here.
        x_trg, ortho_loss_cross = self.cross_SVD_attn(q=src_trg, kv=src_ctx)
        x = torch.cat([x_ctx, x_trg], dim=1)
        ortho_loss = ortho_loss_self + ortho_loss_cross

        return self.dropout1(x), ortho_loss

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward block: linear -> activation -> dropout -> linear -> dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)