import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from typing import Optional, Tuple
from torch.nn.parameter import Parameter
from torch.nn.init import constant_
from torch import Tensor
from transformers.models.t5.configuration_t5 import T5Config
import math

class Conv_DyN(nn.Module):
    """2-D convolution whose kernel is rebuilt from latent point sets on each call.

    For every kernel tap the module stores ``num_CHs`` sets of ``q_dim``-dimensional
    points, one point per input channel (``Conv_inputQs``) and one per output
    channel (``Conv_outputQs``).  The kernel entry for output channel ``o``,
    input channel ``c`` and tap ``(j, i)`` is::

        sum_h coeff[h] * || Conv_outputQs[i, h, o] - Conv_inputQs[j, h, c] ||_p

    where ``coeff[h] = SCALE_FACTOR_conv * (-1) ** h`` is a fixed (non-trainable)
    alternating-sign coefficient.

    Args:
        kernel_size: Side length of the square kernel.
        input_channels: Number of input channels.
        output_channels: Number of output channels.
        stride: Forwarded to ``F.conv2d``.
        padding: Forwarded to ``F.conv2d``.
        num_CHs: Number of point sets that are summed with alternating sign.
        q_dim: Dimensionality of each latent point.
        norm_p: ``p`` of the p-norm used by ``torch.cdist``.
        SCALE_FACTOR_conv: Magnitude of the alternating coefficients.
    """

    def __init__(
        self,
        kernel_size,
        input_channels,
        output_channels,
        stride,
        padding,
        num_CHs,
        q_dim,
        norm_p,
        SCALE_FACTOR_conv=0.01
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.norm_p = norm_p
        self.num_CHs = num_CHs

        # Latent points, indexed as [tap, point set, channel, coordinate].
        self.Conv_inputQs = torch.nn.Parameter(torch.rand(kernel_size, num_CHs, input_channels, q_dim))
        self.Conv_outputQs = torch.nn.Parameter(torch.rand(kernel_size, num_CHs, output_channels, q_dim))

        # Fixed +scale / -scale pattern, shaped (num_CHs, 1, 1) so it broadcasts
        # over the pairwise-distance matrices returned by torch.cdist.
        signs = torch.tensor([(-1) ** h_id for h_id in range(self.num_CHs)])
        self.shared_coeff_conv = torch.nn.Parameter(
            SCALE_FACTOR_conv * signs.unsqueeze(1).unsqueeze(2),
            requires_grad=False,
        )

    def pathIntegrals(self, _inputQs, _outputQs, _coeffs):
        """Return the coefficient-weighted sum of pairwise p-norm distances.

        Args:
            _inputQs: Tensor of shape ``(H, P, q_dim)``.
            _outputQs: Tensor of shape ``(H, R, q_dim)``.
            _coeffs: Tensor broadcastable to ``(H, P, R)``.

        Returns:
            Tensor of shape ``(P, R)``: the distances summed over the first axis.
        """
        dist_W = torch.cdist(_inputQs, _outputQs, p=self.norm_p)
        return torch.sum(dist_W * _coeffs, 0)

    def rebuild_conv(self, conv_inputQs, conv_outputQs, _coeffs):
        """Assemble the ``(out_ch, in_ch, k, k)`` kernel from the latent points.

        Args:
            conv_inputQs: Tensor of shape ``(k, H, in_ch, q_dim)``.
            conv_outputQs: Tensor of shape ``(k, H, out_ch, q_dim)``.
            _coeffs: Per-point-set coefficients, broadcastable to ``(H, out_ch, in_ch)``.

        Returns:
            The kernel tensor, created with the dtype and device of ``conv_outputQs``.
        """
        kernel_size = conv_inputQs.shape[0]
        input_ch = conv_inputQs.shape[2]
        output_ch = conv_outputQs.shape[2]

        # Allocate with the parameters' dtype/device so a module converted with
        # .double() / .to(device) yields a kernel that matches its input.
        res_conv = torch.zeros(
            output_ch, input_ch, kernel_size, kernel_size,
            dtype=conv_outputQs.dtype, device=conv_outputQs.device,
        )

        for _i in range(kernel_size):
            for _j in range(kernel_size):
                # Output points are taken from tap _i and input points from tap _j;
                # the result lands at kernel position [:, :, _j, _i].
                res_conv[:, :, _j, _i] = self.pathIntegrals(conv_outputQs[_i], conv_inputQs[_j], _coeffs)
        return res_conv

    def forward(self, x):
        """Convolve ``x`` with the freshly rebuilt kernel (no bias term)."""
        W_conv = self.rebuild_conv(self.Conv_inputQs, self.Conv_outputQs, self.shared_coeff_conv)
        return F.conv2d(x, W_conv, stride=self.stride, padding=self.padding)


class Linear_DyN(nn.Module):
    """Linear layer whose weight matrix is derived from latent point sets.

    The layer keeps ``num_Hs`` sets of ``q_dim``-dimensional points, one point
    per input feature (``In_Qs``) and one per output feature (``Out_Qs``).
    The effective ``(in_features, out_features)`` matrix is the sum over point
    sets of the pairwise p-norm distances, each weighted by the fixed
    coefficient ``SCALE_FACTOR_fc * (-1) ** h``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        num_Hs: Number of point sets.
        q_dim: Dimensionality of each latent point.
        norm_p: ``p`` of the p-norm used by ``torch.cdist``.
        SCALE_FACTOR_fc: Magnitude of the alternating coefficients.
        bias: Accepted but not consulted; a zero-initialised bias is always created.
    """

    def __init__(self, in_features, out_features, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc=0.1, bias=None):
        super().__init__()
        self.norm_p = norm_p
        self.num_Hs = num_Hs

        # Latent points, indexed as [point set, feature, coordinate].
        self.In_Qs = Parameter(torch.rand(num_Hs, in_features, q_dim))
        self.Out_Qs = Parameter(torch.rand(num_Hs, out_features, q_dim))

        # Non-trainable alternating coefficients of shape (num_Hs, 1, 1).
        alternating = torch.tensor([(-1) ** h_id for h_id in range(num_Hs)])
        self.shared_coeff_fc = Parameter(
            SCALE_FACTOR_fc * alternating[:, None, None],
            requires_grad=False,
        )

        # The bias is drawn with torch.rand and then overwritten with zeros.
        self.bias = Parameter(torch.rand(out_features))
        nn.init.zeros_(self.bias)

    def pathIntegrals(self):
        """Return the ``(in_features, out_features)`` matrix implied by the points."""
        pairwise = torch.cdist(self.In_Qs, self.Out_Qs, p=self.norm_p)
        return (pairwise * self.shared_coeff_fc).sum(dim=0)

    def forward(self, x):
        """Apply the affine map ``x @ pathIntegrals() + bias`` via ``F.linear``."""
        # F.linear expects (out_features, in_features), hence the transpose.
        weight = self.pathIntegrals().T
        return F.linear(x, weight, self.bias)


class MultiheadAttention_DyN(nn.Module):
    """Multi-head attention whose projection matrices come from ``Linear_DyN`` modules.

    The layout follows ``torch.nn.MultiheadAttention``, but ``in_proj_weight``
    (or ``q_proj_weight`` / ``k_proj_weight`` / ``v_proj_weight`` when ``kdim``
    or ``vdim`` differ from ``embed_dim``) and ``out_proj`` are ``Linear_DyN``
    modules.  On every call their ``pathIntegrals()`` result, which is laid out
    as ``(in_features, out_features)``, is transposed into the
    ``(out_features, in_features)`` layout expected by the attention kernels.

    Args:
        embed_dim: Model dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        num_Hs, q_dim, norm_p, SCALE_FACTOR_fc: Forwarded to every ``Linear_DyN``.
        dropout: Dropout probability forwarded to ``F.multi_head_attention_forward``.
        bias: If True, create a zero-initialised ``in_proj_bias``.  ``out_proj``
            carries a bias either way because ``Linear_DyN`` always creates one.
        add_bias_kv: If True, create learnable ``bias_k`` / ``bias_v``.
        add_zero_attn: Forwarded to ``F.multi_head_attention_forward``.
        kdim, vdim: Feature sizes of key / value; default to ``embed_dim``.
        batch_first: If True, batched inputs and outputs are ``(batch, seq, feature)``.
        device, dtype: Used only for ``in_proj_bias``, ``bias_k`` and ``bias_v``.
    """
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc=0.1,
                 dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            # Separate generators for the q/k/v projections; each one yields an
            # (in_features, embed_dim) matrix from pathIntegrals().
            self.q_proj_weight = Linear_DyN(embed_dim, embed_dim, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc)
            self.k_proj_weight = Linear_DyN(self.kdim, embed_dim, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc)
            self.v_proj_weight = Linear_DyN(self.vdim, embed_dim, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc)
            self.register_parameter('in_proj_weight', None)
        else:
            # One packed generator yielding an (embed_dim, 3 * embed_dim) matrix.
            self.in_proj_weight = Linear_DyN(embed_dim, 3 * embed_dim, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc)
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = Linear_DyN(embed_dim, embed_dim, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Zero the projection biases and Xavier-initialise ``bias_k`` / ``bias_v``."""
        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        # Accessed through nn.init: only constant_ is imported by name in this file.
        if self.bias_k is not None:
            nn.init.xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            nn.init.xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Supply a default for states that do not carry this flag.
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super().__setstate__(state)

    def forward(
            self,
            query: Tensor,
            key: Tensor,
            value: Tensor,
            key_padding_mask: Optional[Tensor] = None,
            need_weights: bool = True,
            attn_mask: Optional[Tensor] = None,
            average_attn_weights: bool = True,
            is_causal : bool = False) -> Tuple[Tensor, Optional[Tensor]]:
        """Compute attention using projection matrices rebuilt for this call.

        Args:
            query, key, value: Attention inputs.  A 3-D ``query`` is treated as
                batched; with ``batch_first`` the first two axes are swapped
                before and after the generic attention routine.
            key_padding_mask: Optional bool or floating-point mask over keys.
            need_weights: Forwarded to the attention routine.
            attn_mask: Optional attention mask.
            average_attn_weights: Forwarded to the attention routine.
            is_causal: Accepted for signature compatibility; not used here.

        Returns:
            ``(attn_output, attn_output_weights)`` as produced by
            ``torch._native_multi_head_attention`` (fast path) or
            ``F.multi_head_attention_forward``.

        Raises:
            AssertionError: If ``key_padding_mask`` is neither bool nor floating
                point, or if a nested tensor is supplied outside the fast path.
        """
        is_batched = query.dim() == 3
        if key_padding_mask is not None:
            _kpm_dtype = key_padding_mask.dtype
            if _kpm_dtype != torch.bool and not torch.is_floating_point(key_padding_mask):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported")

        why_not_fast_path = ''
        if not is_batched:
            why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
        elif query is not key or key is not value:
            why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
        elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
        elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.In_Qs.dtype:
            # this case will fail anyway, but at least they'll get a useful error message.
            why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.In_Qs.dtype}) don't match"
        elif self.training:
            why_not_fast_path = "training is enabled"
        elif not self.batch_first:
            why_not_fast_path = "batch_first was not True"
        elif self.bias_k is not None:
            why_not_fast_path = "self.bias_k was not None"
        elif self.bias_v is not None:
            why_not_fast_path = "self.bias_v was not None"
        elif self.add_zero_attn:
            why_not_fast_path = "add_zero_attn was enabled"
        elif not self._qkv_same_embed_dim:
            why_not_fast_path = "_qkv_same_embed_dim was not True"
        elif self.in_proj_bias is None:
            # The device / requires_grad checks below dereference every entry
            # of tensor_args, so a missing bias has to use the generic path.
            why_not_fast_path = "in_proj_bias was None"
        elif query.is_nested and (key_padding_mask is not None or attn_mask is not None):
            why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
                                 is not supported with NestedTensor input"
        elif torch.is_autocast_enabled():
            why_not_fast_path = "autocast is enabled"

        # Rebuild the projection matrices.  in_proj_weight is None when separate
        # q/k/v generators are in use (kdim or vdim differs from embed_dim).
        if self.in_proj_weight is not None:
            in_proj_weight = self.in_proj_weight.pathIntegrals().T
        else:
            in_proj_weight = None
        out_proj_weight = self.out_proj.pathIntegrals().T

        if not why_not_fast_path:
            tensor_args = (
                query,
                key,
                value,
                in_proj_weight,
                self.in_proj_bias,
                out_proj_weight,
                self.out_proj.bias,
            )
            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_fast_path = "some Tensor argument has_torch_function"
            elif not all([(x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args]):
                why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any([x.requires_grad for x in tensor_args]):
                why_not_fast_path = ("grad is enabled and at least one of query or the "
                                     "input/output projection weights or biases requires_grad")
            if not why_not_fast_path:
                # Combine both masks and tell the native kernel which kind of
                # mask it receives instead of passing a bare mask.
                merged_mask, mask_type = self.merge_masks(attn_mask, key_padding_mask, query)
                return torch._native_multi_head_attention(
                    query,
                    key,
                    value,
                    self.embed_dim,
                    self.num_heads,
                    in_proj_weight,
                    self.in_proj_bias,
                    out_proj_weight,
                    self.out_proj.bias,
                    merged_mask,
                    need_weights,
                    average_attn_weights,
                    mask_type,
                )

        any_nested = query.is_nested or key.is_nested or value.is_nested
        assert not any_nested, ("MultiheadAttention does not support NestedTensor outside of its fast path. " +
                                f"The fast path was not hit because {why_not_fast_path}")

        if self.batch_first and is_batched:
            # make sure that the transpose op does not affect the "is" property
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            # Transposed like in_proj_weight / out_proj_weight so that each
            # matrix is (embed_dim, source_dim).
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, out_proj_weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight.pathIntegrals().T,
                k_proj_weight=self.k_proj_weight.pathIntegrals().T,
                v_proj_weight=self.v_proj_weight.pathIntegrals().T,
                average_attn_weights=average_attn_weights,
                )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, out_proj_weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                average_attn_weights=average_attn_weights,
                )
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights


    def merge_masks(self, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                    query: Tensor) -> Tuple[Optional[Tensor], Optional[int]]:
        r"""
        Determine mask type and combine masks if necessary. If only one mask is provided, that mask
        and the corresponding mask type will be returned. If both masks are provided, they will be both
        expanded to shape ``(batch_size, num_heads, seq_len, seq_len)``, combined by addition
        and mask type 2 will be returned
        Args:
            attn_mask: attention mask of shape ``(seq_len, seq_len)``, mask type 0
            key_padding_mask: padding mask of shape ``(batch_size, seq_len)``, mask type 1
            query: query embeddings of shape ``(batch_size, seq_len, embed_dim)``
        Returns:
            merged_mask: merged mask
            mask_type: merged mask type (1 or 2), or ``None`` when no mask is given
        """
        mask_type: Optional[int] = None
        merged_mask: Optional[Tensor] = None

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=None,
            other_name="",
            target_type=query.dtype,
            check_other=False,
        )

        if key_padding_mask is not None:
            mask_type = 1
            merged_mask = key_padding_mask

        if attn_mask is not None:
            # In this branch query can't be a nested tensor, so it has a shape
            batch_size, seq_len, _ = query.shape
            mask_type = 2

            # Always expands attn_mask to 4D
            if attn_mask.dim() == 3:
                attn_mask_expanded = attn_mask.view(batch_size, -1, seq_len, seq_len)
            else:  # attn_mask.dim() == 2:
                attn_mask_expanded = attn_mask.view(1, 1, seq_len, seq_len).expand(batch_size, self.num_heads, -1, -1)
            merged_mask = attn_mask_expanded

            if key_padding_mask is not None:
                key_padding_mask_expanded = key_padding_mask.view(batch_size, 1, 1, seq_len).expand(-1, self.num_heads, -1, -1)
                merged_mask = attn_mask_expanded + key_padding_mask_expanded

        # no attn_mask and no key_padding_mask, returns None, None
        return merged_mask, mask_type


class Linear_DyN_NoMat(nn.Module):
    """Latent-point linear layer that never builds its weight matrix in ``forward``.

    ``pathIntegrals()`` defines the ``(in_features, out_features)`` matrix as the
    sum over point sets of ``shared_coeff[h]`` times the squared pairwise
    ``torch.cdist`` distance (default norm) between ``In_Qs[h]`` and
    ``Out_Qs[h]``.  ``forward`` expands that squared distance into
    ``|a|^2 + |b|^2 - 2 a.b`` and contracts each term with the input directly,
    which is algebraically ``x @ pathIntegrals() + bias``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        num_Hs: Number of point sets.
        q_dim: Dimensionality of each latent point.
        norm_p: Stored on the module; the distance computation here does not read it.
        SCALE_FACTOR_fc: Initial magnitude of the alternating coefficients.
        bias: Accepted but not consulted; a zero-initialised bias is always created.
    """

    def __init__(self, in_features, out_features, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc=0.01, bias=None):
        super().__init__()
        self.num_Hs = num_Hs
        self.norm_p = norm_p
        self.in_features = in_features
        self.out_features = out_features

        # Latent points, indexed as [point set, feature, coordinate].
        self.In_Qs = Parameter(torch.rand(num_Hs, in_features, q_dim))
        self.Out_Qs = Parameter(torch.rand(num_Hs, out_features, q_dim))

        # Alternating-sign coefficients; unlike the points' layout they are 1-D
        # and left trainable (requires_grad is not disabled).
        alternating = torch.tensor([(-1) ** h_id for h_id in range(num_Hs)], dtype=torch.float32)
        self.shared_coeff = Parameter(SCALE_FACTOR_fc * alternating)

        self.bias = Parameter(torch.rand(out_features))
        nn.init.zeros_(self.bias)

        # Re-initialise both point sets after the bias has been drawn.
        for latent in (self.In_Qs, self.Out_Qs):
            nn.init.kaiming_uniform_(latent, a=math.sqrt(5))

    def pathIntegrals(self):
        """Return the explicit ``(in_features, out_features)`` matrix."""
        squared = torch.cdist(self.In_Qs, self.Out_Qs).pow(2)
        per_set = self.shared_coeff[:, None, None] * squared
        return per_set.sum(dim=0)

    def forward_(self):
        """Return the explicit matrix transposed to ``(out_features, in_features)``."""
        return self.pathIntegrals().t()

    def forward(self, x):
        """Apply the layer to ``x`` (last dimension must equal ``in_features``).

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``out_features``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``in_features``.
        """
        in_shape = x.shape
        assert in_shape[-1] == self.in_features
        flat = x.reshape(-1, self.in_features)          # (B, in_features)
        coeff_col = self.shared_coeff.unsqueeze(1)      # (num_Hs, 1)

        # Term 1: sum_n x[b, n] * sum_h c_h * |In[h, n]|^2, one value per row.
        in_norms = (self.In_Qs ** 2).sum(-1) * coeff_col
        term_in = (flat @ in_norms.t()).sum(-1, keepdim=True)

        # Term 2: (sum_n x[b, n]) * sum_h c_h * |Out[h, m]|^2.
        out_norms = (self.Out_Qs ** 2).sum(-1) * coeff_col
        row_totals = flat.sum(-1, keepdim=True).repeat(1, self.num_Hs)
        term_out = row_totals @ out_norms

        # Term 3: sum_h c_h * Out[h] @ (In[h]^T @ x^T), transposed back to (B, out).
        projected = self.In_Qs.transpose(1, 2) @ flat.T
        cross = self.Out_Qs @ projected
        cross = cross * self.shared_coeff[:, None, None]
        cross = cross.sum(0).t()

        target_shape = (*in_shape[:-1], self.out_features)
        result = (term_in + term_out - 2 * cross).reshape(*target_shape)
        result += self.bias
        return result

    def extra_repr(self) -> str:
        """Summarise the latent-point shapes and whether a bias is present."""
        fields = (self.In_Qs.shape, self.Out_Qs.shape, self.bias is not None)
        return 'In_Qs={}, Out_Qs={}, bias={}'.format(*fields)