import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool
from torch_scatter import scatter
from GCL.augmentors import Compose
from models.GraphHD.gnn import GNN, GNN_graphpred
from models.GraphHD.losses import InfoNCE
from GCL.models import SameScaleSampler
from torch_geometric.nn.inits import glorot, zeros

def normalized_mse_loss(h1, h2, reduction='mean'):
    """Mean-squared error between L2-normalised rows of ``h1`` and ``h2``.

    Both inputs are normalised with ``F.normalize`` (default ``dim=1``)
    before ``F.mse_loss`` is applied with the given ``reduction``.
    """
    unit_a, unit_b = F.normalize(h1), F.normalize(h2)
    return F.mse_loss(unit_a, unit_b, reduction=reduction)

def normalized_l1_loss(h1, h2, reduction='mean'):
    """L1 distance between L2-normalised rows of ``h1`` and ``h2``.

    Both inputs are normalised with ``F.normalize`` (default ``dim=1``)
    before ``F.l1_loss`` is applied with the given ``reduction``.
    """
    unit_a, unit_b = F.normalize(h1), F.normalize(h2)
    return F.l1_loss(unit_a, unit_b, reduction=reduction)

def add_extra_mask(pos_mask, neg_mask=None, extra_pos_mask=None, extra_neg_mask=None):
    """Combine sampler masks with optional extra positive/negative masks.

    Args:
        pos_mask: positive mask; replaced outright by ``extra_pos_mask`` when
            that is given.
        neg_mask: negative mask; only used (multiplied by ``extra_neg_mask``)
            when ``extra_neg_mask`` is given.
        extra_pos_mask: optional replacement positive mask.
        extra_neg_mask: optional mask multiplied into ``neg_mask``.

    Returns:
        ``(pos_mask, neg_mask)``. Without ``extra_neg_mask`` the negative
        mask is recomputed as ``1 - pos_mask`` and the passed ``neg_mask`` is
        ignored.
    """
    chosen_pos = pos_mask if extra_pos_mask is None else extra_pos_mask
    if extra_neg_mask is None:
        return chosen_pos, 1. - chosen_pos
    return chosen_pos, neg_mask * extra_neg_mask

def add_extra_mask_neg(neg_mask, extra_neg_mask=None):
    """Return ``neg_mask`` filtered element-wise by ``extra_neg_mask``.

    ``extra_neg_mask`` is required despite its ``None`` default; an
    ``AssertionError`` is raised when it is missing. Only the (single)
    filtered negative mask is returned.
    """
    assert extra_neg_mask is not None
    filtered_neg = neg_mask * extra_neg_mask
    return filtered_neg

def _similarity(h1: torch.Tensor, h2: torch.Tensor):
    """Cosine-similarity matrix between the rows of ``h1`` and ``h2``.

    Rows are L2-normalised (``F.normalize``, default ``dim=1``) and the
    result is ``h1_norm @ h2_norm.T``, i.e. shape ``[len(h1), len(h2)]`` for
    2-D inputs.
    """
    left = F.normalize(h1)
    right = F.normalize(h2)
    return torch.matmul(left, right.t())

class UnbiasGCL(nn.Module):
    """Graph contrastive model with instance, node-level and prototype losses.

    A ``GNN`` encoder produces node embeddings which ``global_mean_pool``
    turns into graph embeddings. ``forward`` encodes two augmented views of a
    batch; the ``cal_*`` methods compute the individual loss terms from them.

    Args:
        augmentor: pair ``(aug1, aug2)``; each is called as
            ``aug(x, edge_index, edge_attr)`` and returns the same triple.
        n_protos: number of prototypes at each hierarchy level; its length is
            the number of levels.
        cf: config object. Reads ``n_hidden``, ``compute_dev``, ``batch_size``,
            ``intra_negative``, ``tau``, ``n_layer``, ``JK``, ``dropout`` and
            ``gnn_type``.
    """

    def __init__(self, augmentor, n_protos, cf):
        super(UnbiasGCL, self).__init__()
        self.n_hidden = cf.n_hidden
        self.device = cf.compute_dev
        self.batch_size = cf.batch_size
        self.augmentor = augmentor
        self.pretrain_pool = global_mean_pool
        self.sampler = SameScaleSampler(intraview_negs=cf.intra_negative)
        self.loss = InfoNCE(cf.tau)
        self.proto_loss = torch.nn.CrossEntropyLoss().to(self.device)
        self.tau = cf.tau
        self.encoder = GNN(num_layer=cf.n_layer, emb_dim=cf.n_hidden, JK=cf.JK, drop_ratio=cf.dropout, gnn_type=cf.gnn_type)

        self.n_protos = n_protos
        self.h_level = len(self.n_protos)

        # Projection-head width. For JK == 'concat' this assumes the encoder
        # concatenates one n_hidden block per layer (TODO confirm against GNN).
        if cf.JK == 'concat':
            project_dim = cf.n_hidden * cf.n_layer
        else:
            project_dim = cf.n_hidden

        self.proj = nn.Sequential(
            nn.Linear(project_dim, project_dim),
            nn.ReLU(),
            nn.Linear(project_dim, project_dim))

        self.init_emb()

    def init_emb(self):
        """Xavier-uniform init for every ``nn.Linear`` weight in the module tree; zero the biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                th.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def _proto_id_grid(self, level, n_rows):
        """Return an ``[n_rows, n_protos[level]]`` LongTensor whose every row is ``0 .. n_protos[level]-1``."""
        return th.arange(int(self.n_protos[level]), device=self.device).expand(n_rows, -1)

    def cal_instance_loss(self, g1, g2, extra_neg_mask=None):
        """Symmetric InfoNCE loss between two views of graph embeddings.

        Each view is used once as anchor and once as sample (masks come from
        ``self.sampler``); the two directional losses are averaged.

        Args:
            g1, g2: graph embeddings of the two views (same number of rows).
            extra_neg_mask: optional mask multiplied into both negative masks
                to drop some negatives; positive masks are left unchanged.

        Returns:
            ``0.5 * (loss(g1 -> g2) + loss(g2 -> g1))``.
        """
        assert g1.size(0) == g2.size(0)
        anchor1, sample1, pos_mask1, neg_mask1 = self.sampler(anchor=g1, sample=g2)
        anchor2, sample2, pos_mask2, neg_mask2 = self.sampler(anchor=g2, sample=g1)

        if extra_neg_mask is not None:
            # add_extra_mask_neg returns only the filtered negative mask, so it
            # must not be unpacked into (pos_mask, neg_mask).
            neg_mask1 = add_extra_mask_neg(neg_mask1, extra_neg_mask)
            neg_mask2 = add_extra_mask_neg(neg_mask2, extra_neg_mask)

        loss1 = self.loss(anchor=anchor1, sample=sample1, pos_mask=pos_mask1, neg_mask=neg_mask1)
        loss2 = self.loss(anchor=anchor2, sample=sample2, pos_mask=pos_mask2, neg_mask=neg_mask2)
        return (loss1 + loss2) * 0.5

    def cal_local_loss(self, z1, z2, batch):
        """Symmetric node-level InfoNCE loss where same-graph nodes are positives.

        Args:
            z1, z2: node embeddings of the two views (same number of rows).
            batch: graph index of each node.

        Returns:
            Average of the two directional losses.
        """
        assert z1.size(0) == z2.size(0)

        # pos_mask[i, j] == 1 iff nodes i and j belong to the same graph
        # (the diagonal is therefore always positive).
        pos_mask = (batch.unsqueeze(1) == batch.unsqueeze(0)).float()
        neg_mask = 1. - pos_mask

        loss1 = self.loss(anchor=z1, sample=z2, pos_mask=pos_mask, neg_mask=neg_mask)
        loss2 = self.loss(anchor=z2, sample=z1, pos_mask=pos_mask, neg_mask=neg_mask)
        return (loss1 + loss2) * 0.5

    def cal_proto_loss(self, g, cluster_result, index):
        """Prototype contrastive loss summed over all hierarchy levels.

        At level ``l`` the positive prototype of each graph is
        ``cluster_result['g2cluster'][l][index]``; similarities to
        ``cluster_result['centroids'][l]`` are divided per prototype by
        ``cluster_result['density'][l]``. At every level except the last the
        negative prototypes are thinned by ``sample_neg_protos``.

        Returns:
            ``(loss, proto_prob_list)`` where ``proto_prob_list[l]`` is the
            batch-mean softmax over the level-``l`` prototypes.
        """
        loss = 0
        proto_prob_list = []
        last_level = len(self.n_protos) - 1

        for l in range(len(self.n_protos)):
            pos_proto_id = cluster_result['g2cluster'][l][index].unsqueeze(-1)
            temp_proto = cluster_result['centroids'][l]
            density = cluster_result['density'][l]
            temp_map = density.repeat([g.shape[0], 1])

            all_proto_id = self._proto_id_grid(l, pos_proto_id.size(0))
            pos_proto_mask = (all_proto_id == pos_proto_id).float()
            neg_proto_mask = 1. - pos_proto_mask

            if l != last_level:
                neg_proto_mask *= self.sample_neg_protos(pos_proto_id, cluster_result, l)

            sim = _similarity(g, temp_proto) / temp_map
            proto_prob_list.append(torch.mean(F.softmax(sim, dim=-1), dim=0))

            exp_sim = torch.exp(sim)
            positive = (exp_sim * pos_proto_mask).sum(1)
            negative = (exp_sim * neg_proto_mask).sum(1)
            loss += -torch.log(positive / (positive + negative)).mean()

        return loss, proto_prob_list

    def cal_proto_loss_bp(self, g, cluster_result, ptoto_update='bp'):
        """Prototype loss where each graph's positive is its most similar prototype.

        Levels are processed from the top (last index) down to 0. Below the
        top level, similarities are restricted to prototypes whose
        ``cluster2cluster`` parent equals the graph's argmax prototype at the
        level above. An additional prototype-to-parent loss is also added.

        Args:
            g: graph embeddings.
            cluster_result: dict with ``'centroids'`` and ``'cluster2cluster'``
                lists indexed by level.
            ptoto_update: when ``'momentum'``, also compute per-prototype
                batch centroids (mean of ``g`` over argmax assignment,
                L2-normalised and detached).

        Returns:
            ``(loss, proto_prob_list, batch_centroid_list)``; both lists are
            ordered from level 0 upwards, and ``batch_centroid_list`` is empty
            unless ``ptoto_update == 'momentum'``.
        """
        exp_sim_list = []
        mask_list = []
        proto_prob_list = []
        batch_centroid_list = []
        loss = 0
        top_level = len(self.n_protos) - 1

        for l in range(top_level, -1, -1):
            temp_proto = cluster_result['centroids'][l]
            sim = _similarity(g, temp_proto) / self.tau
            proto_prob_list = [torch.mean(F.softmax(sim, dim=-1), dim=0)] + proto_prob_list
            exp_sim = torch.exp(sim)

            if ptoto_update == 'momentum':
                max_id = th.argmax(exp_sim, dim=1)
                batch_centroid = scatter(g, max_id, dim=0, dim_size=self.n_protos[l], reduce='mean')
                batch_centroid = F.normalize(batch_centroid).detach()
                batch_centroid_list = [batch_centroid] + batch_centroid_list

            if l != top_level:
                # exp_sim_list[-1] holds the (already masked) level-(l + 1) similarities.
                idx_last = torch.argmax(exp_sim_list[-1], dim=1).unsqueeze(-1)
                cluster2luster = cluster_result['cluster2cluster'][l]
                cluster2luster_mask = (cluster2luster.unsqueeze(0) == idx_last.float()).float()
                exp_sim = exp_sim * cluster2luster_mask  # only keep the sim within the same larger cluster

                upper_proto = cluster_result['centroids'][l + 1]
                proto_sim = _similarity(temp_proto, upper_proto) / self.tau
                proto_exp_sim = torch.exp(proto_sim)
                # proto_exp_sim[j, parent(j)] for every level-l prototype j.
                parent_idx = cluster2luster.long().to(proto_exp_sim.device).unsqueeze(-1)
                proto_positive = proto_exp_sim.gather(1, parent_idx).squeeze(-1)
                proto_positive_ratio = proto_positive / proto_exp_sim.sum(1)
                loss += -torch.log(proto_positive_ratio).mean()

            # Positive = the maximal (possibly masked) similarity in each row.
            temp_pos_mask = (exp_sim == exp_sim.max(dim=1).values.unsqueeze(-1)).float()
            exp_sim_list.append(exp_sim)
            mask_list.append(temp_pos_mask)

        for exp_sim, temp_pos_mask in zip(exp_sim_list, mask_list):
            positive = (exp_sim * temp_pos_mask).sum(1)
            negative = (exp_sim * (1 - temp_pos_mask)).sum(1)
            loss += -torch.log(positive / (positive + negative)).mean()

        return loss, proto_prob_list, batch_centroid_list

    def forward(self, data):
        """Encode two augmented views of ``data``.

        Returns:
            ``(z1, z2, g1, g2)``: node embeddings of both views and their
            mean-pooled graph embeddings.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        aug1, aug2 = self.augmentor
        x1, edge_index1, edge_attr1 = aug1(x, edge_index, edge_attr)
        x2, edge_index2, edge_attr2 = aug2(x, edge_index, edge_attr)
        z1 = self.encoder(x1, edge_index1, edge_attr1)
        z2 = self.encoder(x2, edge_index2, edge_attr2)

        # NOTE(review): both views are pooled with the original ``batch``
        # vector, which assumes the augmentations keep the node set
        # unchanged — confirm.
        g1 = self.pretrain_pool(z1, batch)
        g2 = self.pretrain_pool(z2, batch)

        return z1, z2, g1, g2

    def inference(self, x, edge_index, edge_attr, batch):
        """Encode an un-augmented batch and return mean-pooled graph embeddings."""
        z = self.encoder(x, edge_index, edge_attr)
        return self.pretrain_pool(z, batch)

    def get_protos(self, index, cluster_result=None):
        """Collect per-level prototype masks, centroids and densities.

        Returns:
            ``(pos_masks, neg_masks, centroids, densities)``. Each is a list
            with one entry per level. ``pos_masks[h]`` is a float
            ``[len(index), n_protos[h]]`` one-hot of
            ``g2cluster[h][index]`` and ``neg_masks[h] = 1 - pos_masks[h]``.
            All four are ``None`` when ``cluster_result`` is ``None``.
        """
        if cluster_result is None:
            return None, None, None, None

        proto_selecteds = []
        temp_protos = []
        proto_pos_masks = []
        proto_neg_masks = []

        levels = zip(cluster_result['g2cluster'], cluster_result['centroids'], cluster_result['density'])
        for h, (g2cluster, prototypes, density) in enumerate(levels):
            pos_proto_id = g2cluster[index].unsqueeze(-1)
            proto_selecteds.append(prototypes)
            temp_protos.append(density)

            all_proto_id = self._proto_id_grid(h, pos_proto_id.size(0))
            pos_proto_mask = (all_proto_id == pos_proto_id).float()
            proto_pos_masks.append(pos_proto_mask)
            proto_neg_masks.append(1. - pos_proto_mask)

        return proto_pos_masks, proto_neg_masks, proto_selecteds, temp_protos

    def sample_neg_protos(self, pos_proto_id, cluster_results, h):
        """Randomly thin the level-``h`` negative prototypes of each query.

        Let ``p`` be a query's positive prototype (``pos_proto_id[q]``) and
        ``u = cluster2cluster[h][p]`` its parent at level ``h + 1``. Every
        level-``h`` prototype ``k`` is sampled (kept) with probability
        ``1 - softmax(logits[h][k] / densities)[u]``, clamped to
        ``[0.0001, 0.999]``. Here ``densities`` are the level-(h+1) densities
        divided by their mean and scaled by ``tau``.

        Args:
            pos_proto_id: positive prototype ids, shape ``[N_q, 1]``.
            cluster_results: dict with ``'cluster2cluster'``, ``'logits'`` and
                ``'density'`` lists indexed by level.
            h: level index; level ``h + 1`` must exist in ``'density'``.

        Returns:
            Float 0/1 mask of shape ``[N_q, n_protos[h]]``.
        """
        cluster2cluster = cluster_results['cluster2cluster'][h]  # cluster -> upper-level cluster
        prot_logits = cluster_results['logits'][h]  # cluster logits over upper-level clusters

        all_proto_id = self._proto_id_grid(h, pos_proto_id.size(0))
        upper_pos_proto_id = cluster2cluster[pos_proto_id].squeeze(-1)  # [N_q]
        densities = cluster_results['density'][h + 1] / cluster_results['density'][h + 1].mean() * self.tau

        # Assuming logits[h] is [n_protos[h], n_protos[h + 1]]: gather to
        # [N_q, n_protos[h], n_protos[h + 1]], pick each query's parent column,
        # then flatten to [N_q * n_protos[h], 1].
        sampling_prob = (prot_logits / densities).softmax(-1)[all_proto_id, :]
        sampling_prob = 1. - sampling_prob[th.arange(sampling_prob.size(0)), :, upper_pos_proto_id].reshape(-1, 1)

        neg_sampler = th.distributions.bernoulli.Bernoulli(sampling_prob.clamp(0.0001, 0.999))
        selected_mask = neg_sampler.sample().reshape(pos_proto_id.size(0), -1)  # [N_q, N_neg]

        return selected_mask


class Teacher_Model(nn.Module):
    """Teacher GNN encoder whose node embeddings are mean-pooled per graph.

    Args:
        teacher_model: stored as ``self.teahcer_model``. The misspelled
            attribute name is kept because renaming it would change
            state_dict keys if it is an ``nn.Module``.
        cf: config; reads ``compute_dev``, ``n_layer``, ``n_hidden``, ``JK``,
            ``graph_pooling``, ``gnn_type`` and ``dropout``.
    """

    def __init__(self, teacher_model, cf):
        super(Teacher_Model, self).__init__()
        self.teahcer_model = teacher_model
        self.device = cf.compute_dev
        self.pretrain_pool = global_mean_pool
        self.n_layer = cf.n_layer
        self.n_hidden = cf.n_hidden
        self.JK = cf.JK
        self.graph_pooling = cf.graph_pooling
        self.gnn_type = cf.gnn_type

        self.encoder = GNN(num_layer=cf.n_layer, emb_dim=self.n_hidden,
                           JK=cf.JK, drop_ratio=cf.dropout, gnn_type=cf.gnn_type)

        # NOTE(review): init_emb() walks self.modules(), which also includes
        # teacher_model when it is an nn.Module, so its Linear layers get
        # re-initialised here too — confirm this is intended.
        self.init_emb()

    def init_emb(self):
        """Xavier-uniform init for every ``nn.Linear`` weight in the module tree; zero the biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                th.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def from_pretrained(self, model_file, key=None):
        """Load encoder weights from ``model_file``.

        Args:
            model_file: path passed to ``torch.load``.
            key: if given, the file holds a dict and ``checkpoint[key]`` is
                the encoder state_dict.

        Tensors are first loaded to CPU so that a checkpoint saved on GPU can
        be read on a CPU-only machine. ``load_state_dict`` then copies them
        into the encoder's existing parameters on whatever device they live.
        """
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        state = torch.load(model_file, map_location='cpu')
        if key is not None:
            state = state[key]
        self.encoder.load_state_dict(state)

    def forward(self, x, edge_index, edge_attr, batch):
        """Encode nodes and return mean-pooled graph embeddings."""
        z = self.encoder(x, edge_index, edge_attr)
        g = self.pretrain_pool(z, batch)
        return g



class Student_Model(nn.Module):
    """One student GNN encoder per hierarchy level, each followed by mean pooling.

    Args:
        h_level: number of encoders to build (one per level).
        cf: config; reads ``n_hidden``, ``compute_dev``, ``batch_size``,
            ``student_layer``, ``dim_align``, ``JK``, ``gnn_type`` and either
            ``dropout`` or ``ft_dropout``.
        pred: when True the encoders use ``cf.ft_dropout`` instead of
            ``cf.dropout``.

    Raises:
        ValueError: if ``cf.dim_align`` is neither 0 nor 1.
    """

    def __init__(self, h_level, cf, pred=False):
        super(Student_Model, self).__init__()
        self.n_hidden = cf.n_hidden
        self.device = cf.compute_dev
        self.batch_size = cf.batch_size
        self.pretrain_pool = global_mean_pool
        self.h_level = h_level
        self.student_layer = cf.student_layer
        self.pred = pred

        # dim_align == 1 splits n_hidden evenly across the levels;
        # dim_align == 0 gives every level the full width.
        if cf.dim_align == 1:
            assert cf.n_hidden % self.h_level == 0
            self.student_n_hidden = cf.n_hidden // self.h_level
        elif cf.dim_align == 0:
            self.student_n_hidden = cf.n_hidden
        else:
            raise ValueError

        uses_concat = cf.JK == 'concat'
        width = self.student_n_hidden * self.student_layer if uses_concat else self.student_n_hidden

        # NOTE(review): emb_dim receives the JK-adjusted width here, whereas
        # UnbiasGCL passes n_hidden as emb_dim and uses the concat width only
        # for its projection head — confirm for JK == 'concat'.
        self.encoders = nn.ModuleList([
            GNN(num_layer=cf.student_layer, emb_dim=width, JK=cf.JK,
                drop_ratio=cf.ft_dropout if self.pred else cf.dropout,
                gnn_type=cf.gnn_type)
            for _ in range(self.h_level)
        ])

        self.init_emb()

    def init_emb(self):
        """Xavier-uniform init for every ``nn.Linear`` weight in the module tree; zero the biases."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            th.nn.init.xavier_uniform_(layer.weight.data)
            if layer.bias is not None:
                layer.bias.data.fill_(0.0)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return a list with one mean-pooled graph embedding per level."""
        return [
            self.pretrain_pool(self.encoders[level](x, edge_index, edge_attr), batch)
            for level in range(self.h_level)
        ]


class Student_Model_graphpred(nn.Module):
    def __init__(self, h_level, num_tasks, cf, pred_type='cat'):
        super(Student_Model_graphpred, self).__init__()
        self.n_hidden = cf.n_hidden
        self.device = cf.compute_dev
        self.pretrain_pool = global_mean_pool
        self.h_level = h_level
        self.num_tasks = num_tasks
        self.student_layer = cf.student_layer
        self.pred_type = pred_type
        self.dropout = nn.Dropout(cf.ft_dropout)

        if cf.dim_align == 1:
            assert cf.n_hidden % self.h_level == 0
            self.student_n_hidden = cf.n_hidden // self.h_level
        elif cf.dim_align == 0:
            self.student_n_hidden = cf.n_hidden
        else:
            raise ValueError

        if cf.JK == 'concat':
            project_dim = self.student_n_hidden * self.student_layer
        else:
            project_dim = self.student_n_hidden

        self.encoders = Student_Model(h_level=self.h_level, cf=cf, pred=True)

        if self.pred_type == 'cat':
            self.graph_pred_linear = torch.nn.Linear(project_dim * self.h_level, num_tasks)
        elif self.pred_type == 'mean':
            self.graph_pred_linear = nn.ModuleList()
            for l in range(self.h_level):
                self.graph_pred_linear.append(torch.nn.Linear(project_dim, num_tasks))
        elif self.pred_type == 'ensemble':
            self.graph_pred_linear = nn.ModuleList()
            for l in range(self.h_level):
                self.graph_pred_linear.append(torch.nn.Linear(project_dim, num_tasks))
        else:
            raise ValueError

        self.init_emb()

    def init_emb(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                th.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)

    def from_pretrained(self, model_file, key=None):
        if key is not None:
            checkpoint = torch.load(model_file)
            self.encoders.load_state_dict(checkpoint[key])
        else:
            self.encoders.load_state_dict(torch.load(model_file))

    def forward(self, x, edge_index, edge_attr, batch):
        g_list = self.encoders(x, edge_index, edge_attr, batch)
        if self.pred_type == 'cat':
            g = torch.cat(g_list, dim=-1)
            g = self.graph_pred_linear(g)
            return g
        elif self.pred_type == 'mean':
            pred_list = []
            for l in range(self.h_level):
                pred_list.append(self.graph_pred_linear[l](g_list[l]).unsqueeze(0))
            g = torch.mean(torch.cat(pred_list, dim=0), dim=0)
            return g
        elif self.pred_type == 'ensemble':
            pred_list = []
            for l in range(self.h_level):
                pred_list.append(self.graph_pred_linear[l](g_list[l]).unsqueeze(0))
            return pred_list
        else:
            raise ValueError

class ProjectNet(torch.nn.Module):
    """Constant-width two-layer MLP projection head: Linear -> ReLU -> Linear.

    Args:
        rep_dim: input, hidden and output width.
    """

    def __init__(self, rep_dim):
        super(ProjectNet, self).__init__()
        self.rep_dim = rep_dim
        width = self.rep_dim
        # Kept as one Sequential named ``proj`` so the state_dict keys
        # (proj.0.*, proj.2.*) are unchanged.
        layers = [
            torch.nn.Linear(width, width),
            torch.nn.ReLU(),
            torch.nn.Linear(width, width),
        ]
        self.proj = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x``; its last dimension must equal ``rep_dim``."""
        return self.proj(x)

class Student_ProjectNet(torch.nn.Module):
    """One ``ProjectNet`` head per hierarchy level, each of width ``cf.n_hidden``.

    NOTE(review): Student_Model emits ``n_hidden // h_level``-wide embeddings
    when ``cf.dim_align == 1``, which would not match ``rep_dim`` here if
    its outputs are fed to this module — confirm the intended inputs.
    """

    def __init__(self, h_level, cf):
        super(Student_ProjectNet, self).__init__()
        self.rep_dim = cf.n_hidden
        self.h_level = h_level
        self.projs = nn.ModuleList([ProjectNet(self.rep_dim) for _ in range(self.h_level)])

    def forward(self, x_list):
        """Apply the level-``l`` head to ``x_list[l]`` for every level; return the list."""
        projected = []
        for level in range(self.h_level):
            head = self.projs[level]
            projected.append(head(x_list[level]))
        return projected