import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_remaining_self_loops, remove_self_loops, dense_to_sparse, to_scipy_sparse_matrix, to_dense_adj
from torch.nn import Sequential, Linear, ReLU
import scipy.sparse as sp 
from utils import *
from encoders import *
from edge_evaluator import *


class HeteModel(nn.Module):
    def __init__(self, args, g, se, sum_adj, in_dim, st_dim, device):
        """Store hyper-parameters and build the sub-modules of the model.

        Args:
            args: Configuration object. The attributes read here are
                ``emb_dim``, ``loss_batch_size``, ``tau``, ``ncolors``,
                ``sparse``, ``alpha``, ``beta``, ``gamma``, ``eta``,
                ``dropout``, ``eps``, ``nlayers`` and ``proj_dim``.
            g: Forwarded unchanged to ``AGCN`` (presumably the graph --
                confirm against ``encoders``).
            se: Forwarded unchanged to ``AGCN`` (presumably structural
                encodings -- confirm against ``encoders``).
            sum_adj: Tensor multiplied element-wise into the same-color mask
                in ``get_mask``; assumed to be (nnodes, nnodes) -- TODO confirm.
            in_dim: Node feature dimension.
            st_dim: Dimension of the structural encodings that ``forward``
                concatenates to the node features for the edge evaluator.
            device: Device used for the evaluator, the encoder and the helper
                tensors created inside the loss functions.
        """
        super(HeteModel, self).__init__()
        self.sum_adj = sum_adj
        self.emb_dim = args.emb_dim
        # Node batch size for the contrastive loss; ``forward`` treats 0 (or a
        # value larger than the node count) as "use all nodes at once".
        self.batch_size = args.loss_batch_size
        # Temperature passed to ``F.gumbel_softmax`` in ``forward``.
        self.tau = args.tau
        self.ncolors = args.ncolors
        # Selects the dense or the sparse code path in the loss functions.
        self.sparse = args.sparse
        # Weights of the individual terms in the total loss built by ``forward``:
        # alpha -> number-of-colors loss, beta -> edge-evaluation loss,
        # gamma -> contrastive loss, eta -> entropy loss.
        self.alpha= args.alpha
        self.beta = args.beta
        self.gamma = args.gamma
        self.eta = args.eta
        self.device = device

        # Edge scorer; ``forward`` feeds it the concatenation [x, str_encodings].
        self.evaluator = Edge_Evaluator(in_dim + st_dim, self.sparse, device=device).to(device)

        # Node encoder producing ``emb_dim``-dimensional embeddings.
        self.agcnconv = AGCN(g, se, in_dim, self.emb_dim, args.dropout, eps=args.eps, layer_num=args.nlayers).to(device)

        # NOTE(review): ``mlp`` and ``edge_mlp`` are not referenced by any other
        # method visible in this file; they are kept because they are part of
        # the registered parameters / state dict.
        self.mlp = Linear(in_dim + st_dim, self.emb_dim)
        self.edge_mlp = Linear(self.emb_dim * 2, 1)
        # Maps an embedding to one logit per color.
        self.fc = Linear(self.emb_dim, self.ncolors)
        self.softmax = nn.Softmax(dim=1)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.logsigmoid = nn.LogSigmoid()
        # Projection head applied to the embeddings before the contrastive loss.
        self.proj = Linear(self.emb_dim, args.proj_dim)
        

    def get_embedding(self, x):

        emb = self.agcnconv(x)
        return emb

    def forward(self, x, edge_index, str_encodings):
        """Compute the total training loss for one pass over the graph.

        The loss combines five terms: the coloring loss, the number-of-colors
        loss (weight ``alpha``), the edge-evaluation loss (weight ``beta``),
        the contrastive loss (weight ``gamma``) and the entropy loss
        (weight ``eta``).

        Args:
            x: Node feature tensor, one row per node.
            edge_index: Edge index tensor whose two rows hold the endpoint
                node ids of every edge.
            str_encodings: Per-node encodings concatenated to ``x`` along
                dim 1 before being passed to the edge evaluator.

        Returns:
            Scalar tensor holding the weighted sum of all loss terms.
        """
        # Imported locally: this file never imports ``random`` explicitly, so
        # the batched branch below would otherwise depend on a star import
        # happening to expose it.
        import random

        z = self.agcnconv(x)

        color_logits = self.fc(z)
        color_probs = self.softmax(color_logits)

        W = self.evaluator(torch.cat((x, str_encodings), 1), edge_index)

        # coloring loss
        lc = self.loss_func_col(x, edge_index, W, color_probs)

        # entropy loss
        le = self.loss_entropy(color_logits)

        # edge-evaluation loss
        lr = self.loss_edge_evaluation(W, color_probs, edge_index)

        # Sample one hard color per node. ``F.gumbel_softmax`` expects
        # unnormalised log-probabilities, so it must receive the logits:
        # feeding it the already-normalised probabilities would sample from
        # softmax(probs), a much flatter distribution than the predicted one.
        hard_probs = F.gumbel_softmax(color_logits, tau=self.tau, hard=True)

        # loss of number of coloring
        lnc = self.loss_num_col(hard_probs)

        # loss of contrastive learning
        pos_mask, neg_mask = self.get_mask(hard_probs)
        zp = F.relu(self.proj(z))

        nnodes = z.shape[0]
        if (self.batch_size == 0) or (self.batch_size > nnodes):
            # No batching requested (or batch larger than the graph).
            lcl = self.loss_cl(zp, pos_mask, neg_mask)
        else:
            # Evaluate the contrastive loss on random node batches and combine
            # the results weighted by the share of nodes in each batch.
            node_idxs = list(range(nnodes))
            random.shuffle(node_idxs)
            batches = split_batch(node_idxs, self.batch_size)
            lcl = 0
            for b in batches:
                weight = len(b) / nnodes
                lcl += self.loss_cl(zp[b], pos_mask[:, b][b, :], neg_mask[:, b][b, :]) * weight

        # total loss
        loss = lc + lnc * self.alpha + lr * self.beta + lcl * self.gamma + le * self.eta

        return loss
    
    def get_mask(self, probs):
        _, y = torch.max(probs, dim=1)
        pos_mask = torch.eq(y, y.unsqueeze(dim=1)).to(self.device) 
        pos_mask = pos_mask * self.sum_adj 
        neg_mask = 1 - pos_mask

        return pos_mask, neg_mask

    
    def loss_func_col(self, x, edges, W, probs):

        nnodes = x.shape[0]

        if not self.sparse:
            adj = to_dense_adj(edges)[0]

            sim = self.similarity(probs, probs)
            
            loss = torch.mul(adj, (1 - W) * sim)
            loss_ = loss.sum(1) / (adj.sum(1) + 1e-12)
        else:

            sim = F.cosine_similarity(probs[edges[0]], probs[edges[1]], dim=1)

            loss = (1 - W.edata['w']) * sim  
            loss_full = torch.zeros(nnodes, device=self.device)  
            loss_full.index_add_(0, edges[0], loss)  

            adj_sum = torch.zeros(nnodes, device=self.device)
            adj_sum.index_add_(0, edges[0], torch.ones_like(loss))  
            loss_ = loss_full / (adj_sum + 1e-12)

        return loss_.mean()
    
    def loss_entropy(self, logits):

        loss_ = - torch.mul(self.softmax(logits), self.logsoftmax(logits))

        return loss_.mean()
    
    def loss_edge_evaluation(self, W, probs, edges):
        """Ranking-style loss that ties the edge weights to color similarity.

        With ``d = sim(random pair) - sim(edge)`` the per-edge term is
        ``-((1 - w) * logsigmoid(d) + w * logsigmoid(-d))``. The terms are
        averaged over the edges of each source node and then over all nodes.

        Args:
            W: Edge weights. In dense mode it is combined element-wise with the
                ``(nnodes, nnodes)`` similarity matrix; in sparse mode the
                per-edge weights are read from ``W.edata['w']``.
            probs: Per-node color probabilities, one row per node.
            edges: Edge index tensor; row 0 / row 1 are used as source / target
                node ids.

        Returns:
            Scalar tensor with the mean per-node loss.
        """
        nnodes = probs.shape[0]
        # Presumably a (2, num_edges) tensor of random node ids -- the sparse
        # branch relies on one random pair per edge; confirm against ``utils``.
        rand_np = generate_random_node_pairs(nnodes, edges.shape[1], device=self.device)

        if not self.sparse:
            # Pin the adjacency to nnodes x nnodes so it matches the
            # similarity matrix even when the highest-numbered nodes are
            # isolated (the default size follows the largest id in ``edges``).
            adj = to_dense_adj(edges, max_num_nodes=nnodes)[0]
            edge_prob_sim = self.similarity(probs, probs)
            # NOTE(review): entries are compared position-wise here, so an edge
            # (i, j) only sees a random-pair similarity when (i, j) itself was
            # sampled (otherwise 0), whereas the sparse branch pairs the k-th
            # edge with the k-th random pair. Behaviour left unchanged --
            # confirm which of the two is intended.
            rnp_prob_sim = torch.zeros_like(edge_prob_sim)
            rnp_prob_sim[rand_np[0], rand_np[1]] = edge_prob_sim[rand_np[0], rand_np[1]]

            margin = rnp_prob_sim - edge_prob_sim
            # log(1 - sigmoid(d)) == logsigmoid(-d); the latter cannot hit
            # log(0) when sigmoid saturates.
            R = -((1 - W) * self.logsigmoid(margin) + W * self.logsigmoid(-margin))

            loss = torch.mul(adj, R)
            loss_ = loss.sum(1) / (adj.sum(1) + 1e-12)
        else:
            src = edges[0]
            edge_prob_sim = F.cosine_similarity(probs[src], probs[edges[1]])
            rnp_prob_sim = F.cosine_similarity(probs[rand_np[0]], probs[rand_np[1]])

            w = W.edata['w']
            margin = rnp_prob_sim - edge_prob_sim
            # Same stable rewrite of log(1 - sigmoid(d)) as in the dense branch.
            R = -((1 - w) * self.logsigmoid(margin) + w * self.logsigmoid(-margin))

            # Accumulators take dtype and device from ``R`` so that
            # ``index_add_`` also works for non-float32 inputs.
            loss_full = R.new_zeros(nnodes)
            loss_full.index_add_(0, src, R)

            # Number of outgoing edges per source node.
            adj_sum = R.new_zeros(nnodes)
            adj_sum.index_add_(0, src, torch.ones_like(R))
            loss_ = loss_full / (adj_sum + 1e-12)

        return loss_.mean()
    
    def loss_num_col(self, probs):
        
        max_values = probs.max(dim=0).values
        loss_ = max_values
        return loss_.mean()
    
    def loss_cl(self, x, pos_mask, neg_mask, tau=0.2): 
        
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(pos_mask),
            1,
            torch.arange(pos_mask.shape[0]).view(-1, 1).to(self.device),
            0
        )

        neg_mask = neg_mask * logits_mask

        sim = self.similarity(x, x) / tau
        # exp_sim = torch.exp(sim) * logits_mask  
        exp_sim = torch.exp(sim) * neg_mask 
        log_prob = sim - torch.log(exp_sim.sum(1, keepdim=True) + 1e-12)  
        loss_ = - (log_prob * pos_mask).sum(1) / (pos_mask.sum(1) + 1e-12)   

        return loss_.mean()
    
    def similarity(self, h1: torch.Tensor, h2: torch.Tensor):
        h1 = F.normalize(h1)
        h2 = F.normalize(h2)
        return h1 @ h2.t()
    
    












        