from re import X
from networkx.algorithms.centrality.degree_alg import in_degree_centrality, out_degree_centrality
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import one_hot
from torch.nn.modules.activation import GELU

from model_utils import NormalizedAggr, SparseAttention, Meanprop

from torch_sparse import SparseTensor

from einops import rearrange
import numpy as np
import pdb
import pickle
import math
from functools import wraps

##################################################
from torch_geometric.nn.conv.gcn_conv import gcn_norm

def exists(val):
    """Return ``True`` for any value other than ``None`` (falsy values count as existing)."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``; in that case return the fallback ``d``."""
    if exists(val):
        return val
    return d

def cache_fn(f):
    """Wrap ``f`` so that its first non-``None`` result is stored and reused.

    The wrapper keeps a single stored result that is shared by every call,
    regardless of the arguments passed. Passing ``_cache=False`` bypasses the
    stored result entirely (it is neither read nor written). A ``None`` result
    is never treated as stored, so ``f`` is invoked again on the next call.
    """
    memo = None

    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        nonlocal memo
        if _cache:
            if memo is None:
                memo = f(*args, **kwargs)
            return memo
        return f(*args, **kwargs)

    return cached_fn

class PreNorm(nn.Module):
    """Apply ``LayerNorm`` to the input before delegating to the wrapped module ``fn``.

    When ``key_dim`` is given, a second ``LayerNorm`` is created and applied to
    the ``key`` keyword argument before it is forwarded to ``fn``.
    """

    def __init__(self, dim, fn, key_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        # The context norm only exists when a key width was supplied.
        if exists(key_dim):
            self.norm_context = nn.LayerNorm(key_dim)
        else:
            self.norm_context = None

    def forward(self, x, **kwargs):
        """Normalise ``x`` (and ``kwargs['key']`` if a context norm exists), then call ``fn``."""
        normed_x = self.norm(x)
        if not exists(self.norm_context):
            return self.fn(normed_x, **kwargs)

        # With a context norm configured, the caller must supply ``key``.
        kwargs['key'] = self.norm_context(kwargs['key'])
        return self.fn(normed_x, **kwargs)
    
class GEGLU(nn.Module):
    """GELU-gated linear unit: split the last dimension in half and gate one half with the other."""

    def forward(self, x):
        """Return ``first_half * gelu(second_half)`` where the halves split ``x`` on its last dim."""
        value, gate = torch.chunk(x, 2, dim=-1)
        activated_gate = F.gelu(gate)
        return value * activated_gate

class FeedForward(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) that maps ``dim`` back to ``dim``.

    Args:
        dim: input and output feature width.
        mult: multiplier giving the hidden width ``dim * mult``.
    """

    def __init__(self, dim, mult = 4):
        super().__init__()
        inner_dim = dim * mult
        # Layers are created in the order expand -> contract and stored under
        # ``net`` at positions 0 and 2, with the activation at position 1.
        expand = nn.Linear(dim, inner_dim)
        activation = nn.GELU()
        contract = nn.Linear(inner_dim, dim)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        out = self.net(x)
        return out
    
        
class MLP(torch.nn.Module):
    """Plain multi-layer perceptron with dropout after every non-final layer.

    The stack is ``Linear(in_dim, hidden_dim)``, then ``n_layers - 2`` layers of
    ``Linear(hidden_dim, hidden_dim)``, then ``Linear(hidden_dim, out_dim)``.
    For ``n_layers <= 2`` this is exactly two linear layers.

    Args:
        in_dim: input feature width.
        hidden_dim: width of every hidden layer.
        out_dim: output feature width.
        n_layers: requested number of linear layers (see above).
        act_layer: zero-argument factory for the activation module.
        dropout: dropout probability applied after each activation.
        evaluate: when true, ``forward`` returns ``log_softmax`` of the output.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, n_layers, act_layer=torch.nn.ReLU,
                 dropout=0.5, evaluate=False):
        super().__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_dim, hidden_dim))
        for _ in range(n_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_dim, hidden_dim))
        self.lins.append(torch.nn.Linear(hidden_dim, out_dim))

        # One activation per non-final linear layer. Previously only a single
        # activation was registered while forward() indexed one per layer,
        # which raised IndexError as soon as n_layers >= 3.
        self.acts = torch.nn.ModuleList(
            [act_layer() for _ in range(len(self.lins) - 1)]
        )

        self.dropout = dropout
        self.evaluate = evaluate

    def reset_parameters(self):
        """Re-initialise every linear layer."""
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x, **kwargs):
        """Apply the stack to ``x``; extra keyword arguments are accepted and ignored."""
        for lin, act in zip(self.lins[:-1], self.acts):
            x = act(lin(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        if self.evaluate:
            x = x.log_softmax(dim=-1)

        return x

class PositionalEncoding(torch.nn.Module):
    """Learned node positional encoding built from powers of the adjacency matrix.

    At construction the sparse matrices ``A``, ``A @ A`` and ``A @ A @ A`` are
    computed from ``data.edge_index``. ``forward`` mixes their dense forms with
    the learnable weights ``alpha``, zeroes the diagonal and projects each row
    with a linear layer.

    Args:
        data: graph object providing ``num_nodes``, ``edge_index`` and ``x``
            (only ``x.device`` is read).
        hidden_channels: stored on the module; not used to size the projection.
        pe_type: stored on the module.
        train_mask, data_name, lap_pe: accepted for interface compatibility and unused here.
        args: must provide ``bottleneck``, the output width of the projection.
    """

    def __init__(self, data, hidden_channels, pe_type, train_mask=None, data_name=None, args=None, lap_pe=None):
        super().__init__()
        self.pe_type = pe_type
        self.data = data
        self.hidden_channels = hidden_channels

        self.alpha = nn.Parameter(torch.tensor([1, 0.00001, 0.00001]))
        # The third positional argument of nn.Linear is ``bias``, so the
        # previous call ``nn.Linear(num_nodes, args.bottleneck, hidden_channels)``
        # only used ``hidden_channels`` as a truthy bias flag. The flag is now
        # passed explicitly; the output width is ``args.bottleneck``.
        # NOTE(review): if a projection to ``hidden_channels`` was intended, a
        # second layer is missing — confirm against the caller, which adds this
        # encoding to hidden features.
        self.mlp = nn.Linear(self.data.num_nodes, args.bottleneck, bias=True)
        self.dist_mat = [a.to(data.x.device) for a in self.gen_dist_mat(data.edge_index)]

    def gen_dist_mat(self, edge_index):
        """Return ``(A, A @ A, A @ A @ A)`` as sparse tensors built on CPU with unit edge weights."""
        edge_index = edge_index.to('cpu')
        edge_weight = torch.ones(edge_index.size(1)).to(edge_index.device)
        num_nodes = self.data.num_nodes
        A = SparseTensor.from_edge_index(edge_index, edge_weight, [num_nodes, num_nodes])
        A2 = A @ A
        A3 = A2 @ A
        return (A, A2, A3)

    def forward(self, degree=None):
        """Return the projected encoding; ``degree`` is accepted and unused."""
        A, A2, A3 = self.dist_mat
        # Dense (num_nodes x num_nodes) weighted sum of the three hop matrices.
        pe = self.alpha[0] * A.to_dense() + self.alpha[1] * A2.to_dense() + self.alpha[2] * A3.to_dense()
        # Remove self-connections without allocating index tensors every call.
        pe.fill_diagonal_(0)
        return self.mlp(pe.float())
    

class DeformableGT(torch.nn.Module):
    """Node classifier: linear encoder, a stack of ``DeformableBlock``s and a pre-normed linear head.

    ``args`` must provide ``interpolate_mode`` and ``bandwidth`` (read once per
    block) and, when ``pe_type != "no"``, ``data_name``. ``forward`` returns
    log-probabilities over ``out_dim`` classes.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, data, dropout, pe_type, num_blocks=1, rel_pe="no", args=None, num_heads=4, num_axis=1, spd_mat=None, lap_pe=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_blocks = num_blocks

        self.enc = nn.Linear(in_dim, hidden_dim)

        classifier = torch.nn.Linear(hidden_dim, out_dim)
        self.out_linear = PreNorm(hidden_dim, classifier)
        self.pe_type = pe_type

        self.blocks = nn.ModuleList([
            DeformableBlock(
                hidden_dim,
                num_heads,
                data=data,
                rel_pe=rel_pe,
                interpolate_mode=args.interpolate_mode,
                bandwidth=args.bandwidth,
                num_axis=num_axis,
                args=args,
            )
            for _ in range(num_blocks)
        ])

        if self.pe_type != "no":
            # Two encodings are constructed; only ``self.pe`` is used by forward() below.
            self.pe = PositionalEncoding(data, hidden_dim, pe_type, data_name=args.data_name, args=args, lap_pe=lap_pe)
            self.pe2 = PositionalEncoding(data, hidden_dim, pe_type, data_name=args.data_name, args=args, lap_pe=lap_pe)

        self.dropout = dropout
        self.spd_mat = spd_mat

        self.args = args

    def decay_bandwidth(self, decaying_factor, min_bandwidth):
        """Multiply each block's attention bandwidth by ``decaying_factor`` and floor it.

        Iteration stops at the first block whose bandwidth is already at or
        below ``min_bandwidth``; later blocks are left untouched.
        """
        for block in self.blocks:
            attention = block.attn.fn
            current = attention.bandwidth
            if current <= min_bandwidth:
                break
            attention.bandwidth = math.floor(current * decaying_factor)

    def forward(self, x, **kwargs):
        """Classify nodes.

        Required keyword arguments: ``reference_points``, ``edge_index`` and
        ``test``; all three are passed through to every block.
        """
        reference_points = kwargs['reference_points']
        edge_index = kwargs['edge_index']
        test = kwargs['test']

        hidden = self.enc(x)
        if self.pe_type != "no":
            hidden = hidden + self.pe()
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)

        for block in self.blocks:
            hidden = block(
                hidden,
                reference_points=reference_points,
                spd_mat=self.spd_mat,
                edge_index=edge_index,
                test=test,
            )

        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        logits = self.out_linear(hidden)

        return logits.log_softmax(dim=-1)



#https://github.com/Meituan-AutoML/Twins/blob/main/gvt.py 
class DeformableAttention(torch.nn.Module):
    """Deformable attention: each node attends to a few sampled positions per head.

    For every node and head, ``num_points`` sampling locations are predicted
    from the query features, values are interpolated around those locations
    with a kernel weighting, and the interpolated values are combined using
    softmax attention weights that are also predicted from the query.

    ``args`` must provide ``num_points``, ``bandwidth``, ``eps``, ``query_gnn``,
    ``value_gnn``, ``change_weight``, ``change_bias``, ``multi_bandwidth``
    (comma-separated string, ``''`` to disable) and ``remove_negative_points``;
    ``bias`` / ``weight`` are read only when the matching ``change_*`` flag is
    set. The ``bandwidth`` constructor argument is not used; the value comes
    from ``args.bandwidth``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
                data=None, rel_pe=None, interpolate_mode="gaussian", bandwidth=8, num_axis=1, args=None):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5
        self.attn_drop = torch.nn.Dropout(attn_drop)
        self.proj = torch.nn.Linear(dim*num_axis, dim)
        self.proj_drop = torch.nn.Dropout(proj_drop)
        self.data = data
        ################################################
        self.num_points = args.num_points
        self.v = torch.nn.Linear(dim, dim * num_axis, bias=qkv_bias)
        self.sampling_offsets = nn.Linear(dim, self.num_points*self.num_heads*num_axis)
        self.attention_weights = nn.Linear(dim, self.num_points*self.num_heads*num_axis)

        self.num_axis = num_axis
        self.bandwidth = args.bandwidth
        self.eps = args.eps
        self.query_gnn = args.query_gnn
        self.value_gnn = args.value_gnn
        self.change_weight = args.change_weight
        self.change_bias = args.change_bias
        if args.change_bias:
            self.bias = args.bias
        if args.change_weight:
            self.weight = args.weight

        self.interpolate_mode = interpolate_mode
        # Registered as a non-persistent buffer so it follows the module's
        # device (``.to()`` / ``.cuda()``) instead of being pinned to 'cuda',
        # which made construction fail on CPU-only hosts. Non-persistent keeps
        # the state_dict unchanged.
        self.register_buffer(
            "hop_coef",
            torch.tensor([1, 0.9, 0.7, 0.5, 0.3, 0.2], dtype=torch.float32),
            persistent=False,
        )
        # multi-bandwidth
        self.multi_bandwidth = None
        if args.multi_bandwidth != '':
            multi_bandwidth = args.multi_bandwidth.split(',')
            multi_bandwidth = [float(bandwidth) for bandwidth in multi_bandwidth]
            # Shape (1, 1, 1, 1, B) so it broadcasts against (N, h, k, a, 1).
            self.multi_bandwidth = torch.from_numpy(np.array(multi_bandwidth))[None,None,None,None,:]
            self.multi_bandwidth = self.multi_bandwidth.float()
            self.multi_bandwidth_attn = nn.Parameter(torch.zeros(self.multi_bandwidth.shape[-1]))
        # remove_negative_points
        self.remove_negative_points = args.remove_negative_points
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``proj`` and ``attention_weights``; optionally override the offset layer."""
        self.proj.reset_parameters()
        if self.change_bias:
            # Built in float32 so that an integer ``args.bias`` does not yield
            # an int64 tensor, which cannot be wrapped in a trainable Parameter.
            # NOTE(review): the values are laid out as (point, head*axis) while
            # forward() unpacks the offsets as '(h n_p a)' — confirm the layout.
            steps = torch.arange(self.num_points, dtype=torch.float32) * self.bias
            self.sampling_offsets.bias = nn.Parameter(
                steps.unsqueeze(1).repeat(1, self.num_heads*self.num_axis).reshape(-1)
            )
        if self.change_weight:
            self.sampling_offsets.weight = nn.Parameter(self.sampling_offsets.weight*self.weight)

        self.attention_weights.reset_parameters()

    def interpolate(self, sampling_locations, reference_points, v, N, spd_mat, test):
        """Gather and kernel-weight values around each sampling location.

        Args:
            sampling_locations: float tensor indexed as (N, heads, points, axis).
            reference_points: indexable by axis; each entry is gathered along
                dim 1 with integer positions clamped to ``[0, N - 1]``
                (presumably per-node ranked node indices — TODO confirm).
            v: values indexed as (N, heads * axis, head_dim).
            N: number of nodes.
            spd_mat: unused here.
            test: when true and multi-bandwidth is enabled, only the bandwidth
                with the largest mixing weight is used.

        Returns:
            Tensor of shape (N, heads * axis, points, head_dim).
        """
        start_points = (sampling_locations-self.eps).ceil().long()
        # NOTE(review): this cycles 0..heads-1 for consecutive flat entries,
        # whereas the flattened positions below are ordered (n, h, k). The two
        # orders only agree when heads == 1 or points == 1 — confirm intent
        # before changing, since it alters trained-model numerics.
        index_heads = np.array(list(np.arange(self.num_heads)) * (N*self.num_points))

        values = [None]*self.num_axis
        # multi-bandwidth
        use_multi_bandwidth = self.multi_bandwidth is not None
        if use_multi_bandwidth:
            multi_bandwidth_attn = torch.softmax(self.multi_bandwidth_attn, dim=0)
            multi_bandwidth = self.multi_bandwidth.to(sampling_locations.device)
            if test:
                # Pick the dominant bandwidth once. reshape(-1) keeps this
                # valid when a single bandwidth is configured, where the
                # previous squeeze() produced a 0-dim tensor that could not
                # be indexed.
                selected_bandwidth = multi_bandwidth.reshape(-1)[multi_bandwidth_attn.argmax()]
            else:
                multi_bandwidth_attn = multi_bandwidth_attn[None,None,None,None,:]

        # Visit the 2 * eps integer positions starting at ceil(location - eps).
        for i in range(self.eps*2):
            points = start_points+i
            offset = points-sampling_locations
            if self.bandwidth == 0:
                coef = 1
            elif use_multi_bandwidth:
                if not test:
                    # Mixture of Gaussian kernels weighted by the learned softmax.
                    coef = torch.sum(multi_bandwidth_attn * (-offset.unsqueeze(-1).pow(2)/multi_bandwidth).exp(), dim=-1)
                else:
                    coef = (-offset.pow(2)/selected_bandwidth).exp()
            else:
                coef = (-offset.pow(2)/self.bandwidth).exp()

            # Keep only positions strictly closer than eps to the location.
            mask = (((offset.abs()-self.eps)<0)*((offset.abs()+self.eps)>0)).float()
            coef = coef*mask
            points = torch.clamp(points, min=0, max=N-1)
            for j in range(self.num_axis):
                points_changed = rearrange(points[:,:,:,j], 'n h k -> n (h k)')
                points_changed = torch.gather(reference_points[j], 1, points_changed)
                temp_value = v[points_changed.flatten(),index_heads+j*self.num_heads].view(N, self.num_heads, self.num_points, self.head_dim)* coef[:,:,:,j].unsqueeze(-1)
                values[j] = temp_value if i==0 else values[j] + temp_value

        if self.num_axis > 1:
            points_value = torch.cat(values, 1)
        else:
            points_value = values[0]

        return points_value

    def forward(self, x, reference_points, spd_mat, edge_index, test):
        '''
        :input x                        (N, C)
        :input reference_points         sequence with one entry per axis; a
                                        sequence of any length other than 1 is
                                        stacked into a single tensor
        :input spd_mat                  forwarded to interpolate (unused there)
        :input edge_index               used only when query_gnn / value_gnn is set
        :input test                     forwarded to interpolate
        :return                         (N, dim)
        '''
        q = x
        if self.query_gnn:
            q = NormalizedAggr()(q, edge_index)

        if self.value_gnn:
            v = rearrange(self.v(x), 'n (h d) -> (n h) d', h=self.num_heads*self.num_axis)
            v = NormalizedAggr()(v, edge_index)
            v = rearrange(v, '(n h) d -> n h d', n=x.size(0))
        else:
            v = rearrange(self.v(x), 'n (h d) -> n h d', h=self.num_heads*self.num_axis)

        N = x.shape[0]
        sampling_offsets = rearrange(self.sampling_offsets(q), 'N (h n_p a) -> N h n_p a', h=self.num_heads, a=self.num_axis, n_p=self.num_points)
        attention_weights = rearrange(self.attention_weights(q), 'N (h n_p) -> N h n_p', h=self.num_heads*self.num_axis)
        attention_weights = F.softmax(attention_weights, -1)

        # Sampling locations are the softplus of the predicted offsets, so they
        # are always positive; reference points are only used for the gather.
        if len(reference_points) != 1:
            reference_points = torch.stack(reference_points)
        sampling_locations = F.softplus(sampling_offsets) #/ (N, h, k, a)
        # interpolate
        points_value = self.interpolate(sampling_locations, reference_points, v, N, spd_mat, test)
        # aggregating values of points for each head according to attention weights
        x = (attention_weights.unsqueeze(-1) * points_value).sum(-2)
        x = rearrange(x, 'n h d -> n (h d)', h=self.num_heads*self.num_axis)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class DeformableBlock(torch.nn.Module):
    """One transformer-style block: pre-normed deformable attention, then a pre-normed feed-forward.

    Both sub-layers are wrapped in a residual connection. ``drop_path`` and
    ``norm_layer`` are accepted but not used by this block.
    """

    def __init__(self, hidden_dim, num_heads, qkv_bias=False, qk_scale=None, dropout=0., attn_drop=0.,
                 drop_path=0, norm_layer=torch.nn.LayerNorm, data=None, rel_pe="no", interpolate_mode="original", bandwidth=8, num_axis=1, args=None):
        super().__init__()
        attention = DeformableAttention(
            hidden_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=dropout,
            data=data,
            rel_pe=rel_pe,
            interpolate_mode=interpolate_mode,
            bandwidth=bandwidth,
            num_axis=num_axis,
            args=args,
        )
        self.attn = PreNorm(hidden_dim, attention)

        feed_forward = FeedForward(dim=hidden_dim, mult=1)
        self.ffn = PreNorm(hidden_dim, feed_forward)

    def forward(self, x, reference_points=None, spd_mat=None, edge_index=None, test=False):
        """Return ``x`` after the residual attention and residual feed-forward sub-layers."""
        attended = self.attn(
            x,
            reference_points=reference_points,
            spd_mat=spd_mat,
            edge_index=edge_index,
            test=test,
        )
        x = attended + x

        return self.ffn(x) + x

