import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as sp
import numpy as np
from torch import Tensor
from config import args
from scipy.sparse import csr_matrix
from models.label_propagation_models import NonParaLP
from models.message_op.laplacian_graph_op import LaplacianGraphOp
from models.message_op.last_message_op import LastMessageOp
from models.message_op.over_smooth_distance_message_op import OverSmoothDistanceWeightedOp
from models.message_op.mean_message_op import MeanMessageOp
from models.message_op.sum_message_op import SumMessageOp
from models.message_op.concat_message_op import ConcatMessageOp

# Hyper-parameters copied from the global ``args`` config once, at import time.
# The classes and functions below read these module-level names directly.
(
    homo_layers,          # num_layers of HomoPropagateModel (see MyModel.homo_init)
    homo_drop,            # dropout probability of HomoPropagateModel
    hete_drop,            # dropout probability of HetePropagateModel
    message_op,           # name of the pre-propagation message operator ("last", "over", "mean", "sum", "concat")
    hete_negative_slope,  # negative_slope of the LeakyReLU in HetePropagateLayer
    hete_re_smooth,       # weight of the feature similarity blended into learnable_re
    hete_re_temperature,  # default temperature of getre_scale
) = (
    args.homo_layers,
    args.homo_drop,
    args.hete_drop,
    args.message_op,
    args.hete_negative_slope,
    args.hete_re_smooth,
    args.hete_re_temperature,
)


class MyModel(nn.Module):
    """Wrapper that selects a homophilous or heterophilous propagation model.

    The methods share state through attributes, so they are meant to be
    called in this order: ``non_para_lp`` (sets ``homo``, ``nodes_embedding``,
    ``ori_feature`` and builds the sub-model), ``preprocess`` (computes
    ``smoothed_feature`` / ``processed_feature``), then ``homo_forward`` or
    ``hete_forward`` matching ``self.homo``, and finally ``postprocess``.

    Args:
        prop_steps: Propagation steps of the pre-processing graph operator.
        feat_dim: Width of the raw node features.
        hidden_dim: Hidden width passed to the chosen sub-model.
        output_dim: Output width; also passed to ``NonParaLP`` as ``num_class``.
        r: Forwarded to both ``LaplacianGraphOp`` instances.
    """

    def __init__(self, prop_steps, feat_dim, hidden_dim, output_dim, r=0.5):
        super(MyModel, self).__init__()
        self.prop_steps = prop_steps
        self.feat_dim = feat_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        self.NonPLP = NonParaLP(prop_steps=100, num_class=self.output_dim, alpha=0.5)
        self.pre_graph_op = LaplacianGraphOp(prop_steps=self.prop_steps, r=r)
        self.post_graph_op = LaplacianGraphOp(prop_steps=100, r=r)
        self.post_msg_op = LastMessageOp()

    @staticmethod
    def _trainable_params_millions(module):
        """Return the trainable-parameter count of ``module`` in millions, rounded to 3 places."""
        return round(sum(p.numel() for p in module.parameters() if p.requires_grad) / 1000000, 3)

    def homo_init(self):
        """Build ``self.homo_model`` (a HomoPropagateModel) and record its size."""
        self.homo_model = HomoPropagateModel(num_layers=homo_layers,
                                             feat_dim=self.feat_dim,
                                             hidden_dim=self.hidden_dim,
                                             output_dim=self.output_dim,
                                             dropout=homo_drop,
                                             prop_steps=self.prop_steps,
                                             bn=False,
                                             ln=False)

        self.total_trainable_params = self._trainable_params_millions(self.homo_model)

    def hete_init(self, x):
        """Build ``self.hete_model`` (a HetePropagateModel over raw features ``x``) and record its size."""
        self.hete_model = HetePropagateModel(feat_dim=self.feat_dim,
                                             hidden_dim=self.hidden_dim,
                                             output_dim=self.output_dim,
                                             ori_feature=x,
                                             dropout=hete_drop,
                                             prop_steps=self.prop_steps,
                                             bn=False,
                                             ln=False)

        self.total_trainable_params = self._trainable_params_millions(self.hete_model)

    def non_para_lp(self, subgraph, nodes_embedding, x, device):
        """Score ``subgraph`` with label propagation and build the matching sub-model.

        Stores ``nodes_embedding`` and ``x`` for ``preprocess``, runs
        ``NonParaLP`` on ``subgraph`` and keeps its ``eval()`` score in
        ``self.reliability_acc``. The homophilous model is chosen when the
        score is >= 0.6, or >= 0.4 for subgraphs with at most 200 nodes;
        otherwise the heterophilous model is built.
        """
        self.nodes_embedding = nodes_embedding
        self.ori_feature = x
        self.homo = False
        self.NonPLP.preprocess(subgraph, device)
        self.NonPLP.propagate(adj=subgraph.adj)
        self.reliability_acc = self.NonPLP.eval()
        # Small subgraphs get a lower reliability threshold.
        if self.reliability_acc >= 0.6 or (subgraph.x.shape[0] <= 200 and self.reliability_acc >= 0.4):
            self.homo = True
            self.homo_init()
        else:
            self.hete_init(x=x)
        print("| Non Parameter Label propagation Reliability Value: {}, Homo: {}".format(round(self.reliability_acc, 4), self.homo))

    def _build_pre_msg_op(self):
        """Instantiate the pre-propagation message operator named by ``message_op``.

        Raises:
            ValueError: if ``message_op`` is not one of the supported names.
        """
        end = self.prop_steps + 1
        if message_op == "last":
            return LastMessageOp()
        if message_op == "over":
            return OverSmoothDistanceWeightedOp()
        if message_op == "mean":
            return MeanMessageOp(start=0, end=end)
        if message_op == "sum":
            return SumMessageOp(start=0, end=end)
        if message_op == "concat":
            return ConcatMessageOp(start=0, end=end)
        raise ValueError(
            "Unsupported message_op {!r}; expected one of "
            "'last', 'over', 'mean', 'sum', 'concat'".format(message_op))

    def _build_similarity_adj(self, universal_re):
        """Return a binary CSR adjacency over node pairs whose similarity exceeds 0.999.

        Also stores the edge list as a 2 x E numpy array in
        ``self.universal_re_smooth`` (row 0: sources, row 1: targets).
        """
        num_nodes = self.nodes_embedding.shape[0]
        edge_u, edge_v = torch.nonzero(universal_re > 0.999, as_tuple=True)
        # .cpu() so this also works when the embeddings live on a GPU.
        edge_u = edge_u.cpu().numpy()
        edge_v = edge_v.cpu().numpy()
        self.universal_re_smooth = np.vstack((edge_u, edge_v))
        # torch.ones(...).numpy() keeps the same weight dtype as torch's default float dtype.
        weights = torch.ones(edge_u.shape[0]).numpy()
        return csr_matrix((weights, (edge_u, edge_v)), shape=(num_nodes, num_nodes))

    def preprocess(self, adj):
        """Compute ``smoothed_feature`` and ``processed_feature`` for the chosen sub-model.

        In the homophilous case ``ori_feature`` is propagated over ``adj``.
        Otherwise ``adj`` is ignored: ``getre_scale(nodes_embedding)`` is kept
        in ``self.universal_re``, and ``ori_feature`` is propagated over a
        graph linking node pairs whose similarity exceeds 0.999.

        Raises:
            ValueError: if the configured ``message_op`` is not supported.
        """
        self.pre_msg_op = self._build_pre_msg_op()

        if self.homo:
            self.processed_feat_list = self.pre_graph_op.propagate(adj, self.ori_feature)
        else:
            self.universal_re = getre_scale(self.nodes_embedding)
            self.adj = self._build_similarity_adj(self.universal_re)
            self.processed_feat_list = self.pre_graph_op.propagate(self.adj, self.ori_feature)
        self.smoothed_feature = self.pre_msg_op.aggregate(self.processed_feat_list)

        self.processed_feature = self.nodes_embedding

    def homo_forward(self, device):
        """Run ``homo_model``; returns ``(local_smooth_emb, local_emb)``."""
        local_smooth_emb, local_emb = self.homo_model(
            smoothed_feature=self.smoothed_feature,
            processed_feature=self.processed_feature,
            device=device
        )
        return local_smooth_emb, local_emb

    def hete_forward(self, device):
        """Run ``hete_model``; returns ``(local_smooth_emb, local_message_propagation, local_emb)``."""
        local_smooth_emb, local_message_propagation, local_emb = self.hete_model(
            smoothed_feature=self.smoothed_feature,
            processed_feature=self.processed_feature,
            universal_re=self.universal_re,
            device=device)

        return local_smooth_emb, local_message_propagation, local_emb

    def postprocess(self, adj, output):
        """Softmax ``output`` over dim 1, then propagate it over ``adj`` and aggregate.

        The softmaxed scores are detached and converted to numpy before
        ``post_graph_op.propagate``. The result of ``post_msg_op.aggregate``
        is returned.
        """
        if self.post_graph_op is not None:
            output = F.softmax(output, dim=1)
            # .cpu() is a no-op on CPU tensors and required before .numpy() on GPU tensors.
            output = output.detach().cpu().numpy()
            output = self.post_graph_op.propagate(adj, output)
            output = self.post_msg_op.aggregate(output)
        return output


class HetePropagateLayer(nn.Module):
    """Single propagation layer weighted by a node-similarity matrix.

    ``forward`` projects ``feature`` with a linear layer and blends the
    incoming similarity matrix ``learnable_re`` with the similarity of the
    projected features (weight ``hete_re_smooth``). It then aggregates the
    projected features with two LeakyReLU-transformed copies of that blended
    matrix, plus a residual of the projection itself.

    Args:
        feat_dim: Input width.
        output_dim: Output width.
    """

    def __init__(self, feat_dim, output_dim):
        super(HetePropagateLayer, self).__init__()
        self.lr_hete_trans = nn.Linear(feat_dim, output_dim)

        # softmax, prelu and relu are not used in forward(); they are kept so the
        # module tree and state_dict (prelu owns a parameter) stay unchanged.
        self.softmax = nn.Softmax(dim=1)
        self.prelu = nn.PReLU()
        self.relu = nn.ReLU()
        # Attribute name (sic) kept for compatibility with existing references.
        self.leakyrely = nn.LeakyReLU(negative_slope=hete_negative_slope)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform (ReLU gain) weight and zero bias for the projection."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_uniform_(self.lr_hete_trans.weight, gain=gain)
        nn.init.zeros_(self.lr_hete_trans.bias)

    def forward(self, feature, device, learnable_re=None):
        """Propagate projected ``feature`` over the blended similarity matrix.

        Args:
            feature: Node features whose last dimension is ``feat_dim``.
            device: Kept for interface compatibility. The constant offset is
                now applied as a scalar, so no tensor has to be moved here.
            learnable_re: Similarity matrix, expected (N, N) with N the
                number of rows of ``feature``. Required despite the ``None``
                default.

        Returns:
            ``pos @ emb + neg @ emb + emb``, where ``emb`` is the projected
            feature, ``pos = leaky(re)`` and ``neg = -leaky(re - 1)``.

        Raises:
            TypeError: if ``learnable_re`` is ``None``.
        """
        if learnable_re is None:
            raise TypeError("HetePropagateLayer.forward requires a learnable_re similarity matrix")
        feature_emb = self.lr_hete_trans(feature)
        feature_emb_re = getre_scale(feature_emb)
        # Weighted average: weight 1 on the incoming matrix, hete_re_smooth on the feature similarity.
        learnable_re = (learnable_re + hete_re_smooth * feature_emb_re) / (1 + hete_re_smooth)
        pos_signal = self.leakyrely(learnable_re)
        # Scalar subtraction; previously a dense N x N ones matrix was built on
        # the CPU and copied to `device` on every call for the same result.
        neg_signal = -self.leakyrely(learnable_re - 1)

        prop_pos = torch.mm(pos_signal, feature_emb)
        prop_neg = torch.mm(neg_signal, feature_emb)

        local_message_propagation = prop_pos + prop_neg + feature_emb

        return local_message_propagation

class HetePropagateModel(nn.Module):
    """Heterophily branch: two HetePropagateLayer hops plus two linear heads.

    ``forward`` returns three tensors:

    * a linear projection of the pre-smoothed features,
    * the output of the two-layer similarity-driven propagation, and
    * a linear projection of ``processed_feature``.

    Args:
        feat_dim: Width of the raw node features ``ori_feature``.
        hidden_dim: Width between the two propagation layers.
        output_dim: Output width of the propagation and both heads.
        ori_feature: Raw node-feature tensor; stored and moved to ``device``
            on each forward call.
        prop_steps: Number of smoothing steps. With ``message_op == "concat"``
            the smoothed input is ``(prop_steps + 1) * feat_dim`` wide.
        dropout: Dropout probability.
        bn, ln: Add a BatchNorm1d / LayerNorm over the hidden width. Only
            ``self.norms[0]`` is applied in ``forward``.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, ori_feature, prop_steps, dropout=0.5, bn=False, ln=False):
        super(HetePropagateModel, self).__init__()
        self.feat_dim = feat_dim
        self.prop_steps = prop_steps
        self.ori_feature = ori_feature
        self.bn = bn
        self.ln = ln

        # NOTE: submodule creation order fixes both the RNG draws of the default
        # initialisation and the parameter order seen by optimizers; keep it.
        self.dropout = nn.Dropout(dropout)
        self.prelu = nn.PReLU()
        self.softmax = nn.Softmax(dim=1)

        smooth_in_dim = (prop_steps + 1) * feat_dim if message_op == "concat" else feat_dim
        self.lr_smooth_trans = nn.Linear(smooth_in_dim, output_dim)
        self.lr_local_trans1 = nn.Linear(feat_dim, hidden_dim)
        self.lr_global_trans = nn.Linear(output_dim, output_dim)
        self.hete_propagations = nn.ModuleList([
            HetePropagateLayer(feat_dim, hidden_dim),
            HetePropagateLayer(hidden_dim, output_dim),
        ])

        norm_layers = []
        if bn:
            norm_layers.append(nn.BatchNorm1d(hidden_dim))
        if ln:
            norm_layers.append(nn.LayerNorm(hidden_dim))
        self.norms = nn.ModuleList(norm_layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform (ReLU gain) weights and zero biases for the three linear maps."""
        gain = nn.init.calculate_gain("relu")
        # Iteration order is part of seeded reproducibility.
        for linear in (self.lr_local_trans1, self.lr_smooth_trans, self.lr_global_trans):
            nn.init.xavier_uniform_(linear.weight, gain=gain)
            nn.init.zeros_(linear.bias)

    def forward(self, smoothed_feature, processed_feature, universal_re, device):
        """Return ``(local_smooth_emb, local_message_propagation, local_emb)``."""
        smoothed = smoothed_feature.to(device)
        processed = processed_feature.to(device)
        similarity = universal_re.to(device)

        smooth_emb = self.lr_smooth_trans(smoothed)
        raw = self.ori_feature.to(device)
        # Skip branch: raw features -> hidden width, added after the first hop.
        skip = self.dropout(self.prelu(self.lr_local_trans1(raw)))

        if message_op == "concat":
            # Only the last feat_dim columns of the concatenation are used
            # (presumably the final propagation step; verify against ConcatMessageOp).
            hop_input = smoothed[:, self.prop_steps * self.feat_dim:] + raw
        else:
            hop_input = smoothed + raw

        first_layer, second_layer = self.hete_propagations
        hidden = first_layer(hop_input, device, similarity)
        if self.bn or self.ln:
            hidden = self.norms[0](hidden)
        hidden = self.dropout(self.prelu(hidden))
        propagated = second_layer(hidden + skip, device, similarity)

        global_emb = self.lr_global_trans(processed)

        return smooth_emb, propagated, global_emb

class HomoPropagateModel(nn.Module):
    """Homophily branch: an MLP over smoothed features plus a linear head.

    Args:
        num_layers: Number of Linear layers applied to ``smoothed_feature``
            (must be >= 1). PReLU, dropout and optional normalisation are
            applied between consecutive layers.
        feat_dim: Width of the raw node features. With
            ``message_op == "concat"`` the smoothed input is
            ``(prop_steps + 1) * feat_dim`` wide.
        hidden_dim: Width of the intermediate layers.
        output_dim: Output width of both returned tensors.
        prop_steps: Number of smoothing steps (only affects the input width).
        dropout: Dropout probability.
        bn, ln: Add BatchNorm1d / LayerNorm after each hidden layer. If both
            are set, both lists are created but ``forward`` only applies the
            first ``num_layers - 1`` entries, i.e. the BatchNorms.

    Raises:
        ValueError: if ``num_layers`` < 1.
    """

    def __init__(self, num_layers, feat_dim, hidden_dim, output_dim, prop_steps, dropout=0.5, bn=False, ln=False):
        super(HomoPropagateModel, self).__init__()
        if num_layers < 1:
            # Previously this silently built a 2-layer stack and only used its last layer.
            raise ValueError("num_layers must be >= 1, got {}".format(num_layers))
        self.num_layers = num_layers

        # Created first: construction order fixes the RNG draws of the default init.
        self.lr_global_trans = nn.Linear(output_dim, output_dim)

        smooth_in_dim = (prop_steps + 1) * feat_dim if message_op == "concat" else feat_dim
        if self.num_layers == 1:
            self.lr_smooth_trans = nn.Linear(smooth_in_dim, output_dim)
        else:
            self.lr_smooth_trans = nn.ModuleList()
            self.lr_smooth_trans.append(nn.Linear(smooth_in_dim, hidden_dim))
            for _ in range(num_layers - 2):
                self.lr_smooth_trans.append(nn.Linear(hidden_dim, hidden_dim))
            self.lr_smooth_trans.append(nn.Linear(hidden_dim, output_dim))

        self.bn = bn
        self.ln = ln
        self.norms = nn.ModuleList()
        if self.num_layers != 1:
            if bn:
                for _ in range(num_layers - 1):
                    self.norms.append(nn.BatchNorm1d(hidden_dim))
            if ln:
                for _ in range(num_layers - 1):
                    self.norms.append(nn.LayerNorm(hidden_dim))

        self.dropout = nn.Dropout(dropout)
        self.prelu = nn.PReLU()
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform (ReLU gain) weights and zero biases for every Linear layer."""
        gain = nn.init.calculate_gain("relu")
        # The order differs per branch; it is kept as-is so seeded runs reproduce.
        if self.num_layers == 1:
            layers = [self.lr_smooth_trans, self.lr_global_trans]
        else:
            layers = [self.lr_global_trans, *self.lr_smooth_trans]
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight, gain=gain)
            nn.init.zeros_(layer.bias)

    def forward(self, smoothed_feature, processed_feature, device):
        """Return ``(local_smooth_emb, local_emb)``.

        ``local_smooth_emb`` is the MLP applied to ``smoothed_feature``;
        ``local_emb`` is ``lr_global_trans(processed_feature)``.
        """
        smoothed_feature = smoothed_feature.to(device)
        processed_feature = processed_feature.to(device)

        if self.num_layers == 1:
            local_smooth_emb = self.lr_smooth_trans(smoothed_feature)
        else:
            # Truthiness matches __init__, which creates the norms on `if bn:` / `if ln:`.
            use_norm = bool(self.bn or self.ln)
            hidden = smoothed_feature
            for i, layer in enumerate(self.lr_smooth_trans[:-1]):
                hidden = layer(hidden)
                if use_norm:
                    hidden = self.norms[i](hidden)
                hidden = self.prelu(hidden)
                hidden = self.dropout(hidden)
            local_smooth_emb = self.lr_smooth_trans[-1](hidden)

        local_emb = self.lr_global_trans(processed_feature)

        return local_smooth_emb, local_emb

def getre_scale(emb, temperature=hete_re_temperature):
    """Return a normalised pairwise similarity matrix between the rows of ``emb``.

    Each row is softmax-normalised along dim 1. Pairwise inner products
    ``s_ij`` are divided by ``temperature``, and each entry is then divided
    by ``sqrt(s_ii * s_jj)`` (floored at 1e-9). The diagonal of the result
    is zeroed.

    Note: for a positive ``temperature`` the same factor divides both the
    numerator and ``sqrt(s_ii * s_jj)``, so it cancels out apart from its
    effect on the 1e-9 floor.
    """
    probs = F.softmax(emb, dim=1)
    sim = (probs @ probs.transpose(0, 1)) / temperature
    self_sim = torch.diag(sim).unsqueeze(1)
    norm = torch.sqrt(self_sim @ self_sim.transpose(0, 1))
    sim = sim / torch.max(norm, torch.full_like(norm, 1e-9))
    return sim - torch.diag(torch.diag(sim))