import copy
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
from torch_geometric.utils import dropout_adj

from models import NormLayer
from utils import drop_features


def index_to_mask(index, size=None):
    """Convert a tensor of integer indices into a boolean mask.

    Args:
        index: integer tensor of any shape; it is flattened before use.
        size: length of the returned mask. Defaults to ``max(index) + 1``,
            or 0 when ``index`` is empty.

    Returns:
        1-D ``torch.bool`` tensor of length ``size`` (same device as
        ``index``) that is ``True`` at every position listed in ``index``.
    """
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposes.
    index = index.reshape(-1)
    if size is None:
        # max() raises on an empty tensor, so an empty index yields an empty mask.
        size = int(index.max()) + 1 if index.numel() > 0 else 0
    mask = index.new_zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask


class LogReg(nn.Module):
    """Linear classification head mapping embeddings to per-class logits.

    Args:
        hid_dim: width of the input embeddings.
        n_classes: number of output logits.
    """

    def __init__(self, hid_dim, n_classes):
        super().__init__()
        self.fc = nn.Linear(hid_dim, n_classes)

    def forward(self, x):
        """Return the unnormalised logits ``fc(x)``."""
        return self.fc(x)


class GRACE(torch.nn.Module):
    """Two-view graph contrastive model with a projection head and InfoNCE loss.

    ``forward`` draws two randomly perturbed views of the graph (edge dropout
    plus feature dropout) and encodes both with the shared ``encoder``.
    ``infonce_loss`` projects the two embeddings and averages a symmetric
    InfoNCE objective over nodes.

    Args:
        encoder: module called as ``encoder(x, edge_index)``; it must also
            provide ``get_embeddings(x, edge_index)``.
        input_dim: input width of the projection head (``fc1``). This is
            presumably the encoder's output width; TODO confirm.
        num_hidden: output width of the projection head (``fc2``).
        num_proj_hidden: hidden width of the projection head, and the number
            of leading columns kept by ``projection_direct``.
        tau: temperature that divides the cosine similarities.
        drop_rate: indexable of four rates. ``[0]``/``[1]`` are passed to
            ``dropout_adj`` for view 1/2, and ``[2]``/``[3]`` to
            ``drop_features`` for view 1/2.
        args: forwarded to ``NormLayer``, which is applied inside the
            projection head.
    """

    def __init__(self, encoder, input_dim, num_hidden, num_proj_hidden, tau, drop_rate, args):
        super().__init__()
        self.encoder = encoder
        self.tau = tau
        self.drop_rate = drop_rate
        self.num_proj_hidden = num_proj_hidden
        # Layers are created in this order; reordering would change the RNG
        # draws used for their initialisation.
        self.fc1 = torch.nn.Linear(input_dim, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)
        self.norm_layer = NormLayer(args, num_proj_hidden)

    def augmentation(self, x, edge_index):
        """Build two perturbed views and return ``(x1, edge_index_1, x2, edge_index_2)``.

        Both edge views are drawn before both feature views, so the random
        draws happen in a fixed order.
        """
        rates = self.drop_rate
        edge_views = [dropout_adj(edge_index, p=rates[k])[0] for k in (0, 1)]
        feat_views = [drop_features(x, rates[k]) for k in (2, 3)]
        return feat_views[0], edge_views[0], feat_views[1], edge_views[1]

    def forward(self, x, edge_index):
        """Encode two augmented views with the shared encoder; returns ``(z1, z2)``."""
        feat_a, edges_a, feat_b, edges_b = self.augmentation(x, edge_index)
        return self.encoder(feat_a, edges_a), self.encoder(feat_b, edges_b)

    def get_embedding(self, x, edge_index):
        """Return ``encoder.get_embeddings(x, edge_index)`` detached from autograd."""
        return self.encoder.get_embeddings(x, edge_index).detach()

    def projection_mlp(self, z: torch.Tensor) -> torch.Tensor:
        """Projection head: fc1 -> ELU -> norm_layer -> fc2."""
        hidden = self.norm_layer(F.elu(self.fc1(z)))
        return self.fc2(hidden)

    def projection_direct(self, z: torch.Tensor) -> torch.Tensor:
        """Parameter-free projection: keep the first ``num_proj_hidden`` columns of ``z``."""
        return z[:, :self.num_proj_hidden]

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        """Cosine-similarity matrix between the rows of ``z1`` and ``z2``."""
        return torch.mm(F.normalize(z1), F.normalize(z2).t())

    def infonce(self, z1, z2):
        """Per-node InfoNCE loss with row ``i`` of ``z2`` as the positive for row ``i`` of ``z1``.

        With ``e(a, b) = exp(cos(a, b) / tau)``, node ``i`` gets
        ``-log e(z1_i, z2_i) + log(sum_j e(z1_i, z1_j) + sum_j e(z1_i, z2_j) - e(z1_i, z1_i))``.
        The self-pair is removed from the within-view sum.
        """
        scaled_exp = lambda s: torch.exp(s / self.tau)
        cross = scaled_exp(self.sim(z1, z2))
        within = scaled_exp(self.sim(z1, z1))
        positive_term = -torch.log(cross.diag())
        normaliser_term = torch.log(within.sum(1) + cross.sum(1) - within.diag())
        return positive_term + normaliser_term

    def infonce_loss(self, z1, z2):
        """Project both views and return the node-averaged symmetric InfoNCE loss."""
        proj_a = self.projection_mlp(z1)
        proj_b = self.projection_mlp(z2)
        symmetric = (self.infonce(proj_a, proj_b) + self.infonce(proj_b, proj_a)) * 0.5
        return symmetric.mean()


class SupCon(nn.Module):
    """Supervised contrastive model: one encoder pass plus a label-aware loss.

    Args:
        encoder: graph encoder called as ``encoder(x, edge_index)``; it must
            also expose ``get_embeddings(x, edge_index)``.
        input_dim: input width of the projection head (``fc1``).
        num_hidden: output width of the projection head (``fc2``).
        num_proj_hidden: hidden width of the projection head.
        tau: temperature dividing the cosine similarities in the loss.
    """

    def __init__(self, encoder, input_dim, num_hidden, num_proj_hidden, tau):
        super(SupCon, self).__init__()
        self.encoder = encoder
        self.tau = tau
        self.fc1 = nn.Linear(input_dim, num_proj_hidden)
        self.fc2 = nn.Linear(num_proj_hidden, num_hidden)

    def forward(self, x, edge_index):
        """Encode the graph once and return the node embeddings."""
        z = self.encoder(x, edge_index)
        return z

    def projection(self, z):
        """Two-layer projection head: fc1 -> ELU -> fc2."""
        z = F.elu(self.fc1(z))
        return self.fc2(z)

    def get_embedding(self, x, edge_index):
        """Return ``encoder.get_embeddings(x, edge_index)`` detached from autograd."""
        z = self.encoder.get_embeddings(x, edge_index)
        return z.detach()

    def sim(self, z1, z2):
        """Cosine-similarity matrix between the rows of ``z1`` and ``z2``."""
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def infonce_loss(self, z, y, num_classes):
        """Label-aware contrastive loss averaged over all labelled nodes.

        Let ``p = projection(z)`` and ``e_ij = exp(cos(p_i, p_j) / tau)``.
        Every node ``i`` with label ``c`` in ``range(num_classes)`` contributes
        ``-log(sum_{j: y_j == c} e_ij) + log(sum_{j: y_j != c} e_ij)``. The
        same-class sum includes ``j == i``. Nodes whose label lies outside
        ``range(num_classes)`` contribute no term of their own, but they still
        appear as negatives.

        Args:
            z: node embeddings, one row per node.
            y: label tensor aligned with the rows of ``z``.
            num_classes: number of class ids to iterate over (``0..num_classes-1``).

        Returns:
            Scalar tensor, the mean of the per-node terms.

        Raises:
            ValueError: if ``num_classes < 1``, since there would be nothing to average.
        """
        if num_classes < 1:
            raise ValueError(f"num_classes must be at least 1, got {num_classes}")
        z = self.projection(z)

        f = lambda x: torch.exp(x / self.tau)
        # Gather per-class terms and concatenate once, rather than re-copying a
        # growing tensor on every iteration.
        per_class = []
        for c in range(num_classes):
            c_mask = y == c
            intraclass_sim = f(self.sim(z[c_mask], z[c_mask]))
            interclass_sim = f(self.sim(z[c_mask], z[~c_mask]))
            # If every node has label c the inter-class sum is empty (0) and log yields -inf.
            per_class.append(-torch.log(intraclass_sim.sum(1)) + torch.log(interclass_sim.sum(1)))
        return torch.cat(per_class, dim=0).mean()


class SupGRACE(torch.nn.Module):
    """Supervised two-view contrastive model: GRACE-style views, label-aware loss.

    Args:
        encoder: module called as ``encoder(x, edge_index)``; it must also
            provide ``get_embeddings(x, edge_index)``.
        input_dim: input width of the projection head (``fc1``).
        num_hidden: output width of the projection head (``fc2``).
        num_proj_hidden: hidden width of the projection head, and the number
            of leading columns kept by ``projection_direct``.
        tau: temperature dividing the cosine similarities.
        drop_rate: indexable of four rates. ``[0]``/``[1]`` are passed to
            ``dropout_adj`` for view 1/2, and ``[2]``/``[3]`` to
            ``drop_features`` for view 1/2.
        args: forwarded to ``NormLayer`` for the projection head (as in GRACE).
    """

    def __init__(self, encoder, input_dim, num_hidden, num_proj_hidden, tau, drop_rate, args):
        super(SupGRACE, self).__init__()
        self.encoder = encoder

        self.tau = tau
        self.num_proj_hidden = num_proj_hidden
        self.drop_rate = drop_rate
        self.fc1 = torch.nn.Linear(input_dim, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)
        # projection_mlp applies self.norm_layer; without this line every call
        # to it raised AttributeError. Built exactly as GRACE builds it.
        self.norm_layer = NormLayer(args, num_proj_hidden)

    def augmentation(self, edge_index, x):
        """Build two perturbed views; returns ``(edge_index_1, edge_index_2, x1, x2)``.

        Note the argument and return order differs from ``GRACE.augmentation``.
        """
        edge_index_1 = dropout_adj(edge_index, p=self.drop_rate[0])[0]
        edge_index_2 = dropout_adj(edge_index, p=self.drop_rate[1])[0]
        x1 = drop_features(x, self.drop_rate[2])
        x2 = drop_features(x, self.drop_rate[3])
        return edge_index_1, edge_index_2, x1, x2

    def forward(self, x, edge_index):
        """Encode two augmented views with the shared encoder; returns ``(z1, z2)``."""
        edge_index_1, edge_index_2, x1, x2 = self.augmentation(edge_index, x)
        z1 = self.encoder(x1, edge_index_1)
        z2 = self.encoder(x2, edge_index_2)
        return z1, z2

    def projection_mlp(self, z: torch.Tensor) -> torch.Tensor:
        """Projection head: fc1 -> ELU -> norm_layer -> fc2."""
        z = F.elu(self.fc1(z))
        z = self.norm_layer(z)
        return self.fc2(z)

    def projection_direct(self, z: torch.Tensor) -> torch.Tensor:
        """Parameter-free projection: keep the first ``num_proj_hidden`` columns of ``z``."""
        return z[:, :self.num_proj_hidden]

    def get_embedding(self, x, edge_index):
        """Return ``encoder.get_embeddings(x, edge_index)`` detached from autograd."""
        z = self.encoder.get_embeddings(x, edge_index)
        return z.detach()

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        """Cosine-similarity matrix between the rows of ``z1`` and ``z2``."""
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def infonce(self, z1, z2, y, num_classes):
        """Per-node cross-view supervised contrastive loss.

        For each class ``c`` in ``range(num_classes)`` and each node ``i`` of
        that class, with ``e_ij = exp(cos(z1_i, z2_j) / tau)``, the term is
        ``-log(sum_{j: y_j == c} e_ij) + log(sum_{j: y_j != c} e_ij)``.

        Returns:
            1-D tensor of per-node terms, grouped by class in ascending class
            order (class 0 nodes first, then class 1, ...).

        Raises:
            ValueError: if ``num_classes < 1``. Previously this case surfaced
                as an UnboundLocalError.
        """
        if num_classes < 1:
            raise ValueError(f"num_classes must be at least 1, got {num_classes}")
        f = lambda x: torch.exp(x / self.tau)
        per_class = []
        for c in range(num_classes):
            c_mask = y == c
            intraclass_sim = f(self.sim(z1[c_mask], z2[c_mask]))
            interclass_sim = f(self.sim(z1[c_mask], z2[~c_mask]))
            per_class.append(-torch.log(intraclass_sim.sum(1)) + torch.log(interclass_sim.sum(1)))
        # A single concatenation instead of re-copying the running result per class.
        return torch.cat(per_class, dim=0)

    def infonce_loss(self, z1, z2, y, num_classes):
        """Symmetric supervised loss on directly-projected embeddings, averaged over nodes.

        Both directions order nodes identically (by class), so the element-wise
        sum pairs each node with itself.
        """
        z1 = self.projection_direct(z1)
        z2 = self.projection_direct(z2)
        l1 = self.infonce(z1, z2, y=y, num_classes=num_classes)
        l2 = self.infonce(z2, z1, y=y, num_classes=num_classes)
        ret = (l1 + l2) * 0.5
        ret = ret.mean()
        return ret


class DGI(nn.Module):
    """Deep-Graph-Infomax-style model scoring node embeddings against a graph summary.

    Args:
        encoder: module called as ``encoder(x=..., edge_index=...)``.
        out_dim: embedding width expected by the bilinear scorer.
    """

    def __init__(self, encoder, out_dim):
        super().__init__()
        self.encoder = encoder
        # Bilinear scorer: one logit per (node embedding, summary) pair.
        self.fn = nn.Bilinear(out_dim, out_dim, 1)
        self.act_fn = nn.ReLU()
        # Not used inside this class; presumably applied to forward()'s logits by the trainer.
        self.loss_fn = nn.BCEWithLogitsLoss()

    def get_embedding(self, feat, edge_index):
        """Encode the graph and return the embeddings detached from autograd."""
        return self.encoder(x=feat, edge_index=edge_index).detach()

    def forward(self, edge_index, feat, shuf_feat):
        """Return logits for real-feature nodes followed by shuffled-feature nodes.

        The summary vector is ReLU(mean over rows of the real embeddings). The
        result has length ``2 * num_rows``: scores for ``feat`` first, then
        for ``shuf_feat``.
        """
        real = self.encoder(x=feat, edge_index=edge_index)
        corrupted = self.encoder(x=shuf_feat, edge_index=edge_index)
        summary = self.act_fn(real.mean(dim=0))
        summary_rows = summary.expand_as(real).contiguous()

        real_scores = self.fn(real, summary_rows).squeeze(1)
        corrupted_scores = self.fn(corrupted, summary_rows).squeeze(1)
        return torch.cat((real_scores, corrupted_scores))
    

class Discriminator(nn.Module):
    """Bilinear discriminator scoring node embeddings against cross-view summaries.

    Args:
        n_h: embedding width; the bilinear form is ``n_h x n_h -> 1``.
    """

    def __init__(self, n_h):
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        for module in self.modules():
            self.weights_init(module)

    def weights_init(self, m):
        """Xavier-uniform weights and zero bias for ``nn.Bilinear`` modules; other modules are left untouched."""
        if not isinstance(m, nn.Bilinear):
            return
        torch.nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            m.bias.data.fill_(0.0)

    def forward(self, c1, c2, h1, h2, h3, h4):
        """Return logits concatenated along dim 1 as ``[pos_1, pos_2, neg_1, neg_2]``.

        ``c1``/``c2`` gain a new dim 1 and are broadcast to the shapes of
        ``h1``/``h2``. Each summary is scored against the opposite view's
        embeddings: ``c1`` with ``h2``/``h4`` and ``c2`` with ``h1``/``h3``.
        ``h3``/``h4`` supply the negative scores.
        """
        summary_1 = c1.unsqueeze(1).expand_as(h1).contiguous()
        summary_2 = c2.unsqueeze(1).expand_as(h2).contiguous()

        def score(h, summary):
            return self.f_k(h, summary).squeeze(2)

        positives = (score(h2, summary_1), score(h1, summary_2))
        negatives = (score(h4, summary_1), score(h3, summary_2))
        return torch.cat(positives + negatives, 1)
    

class MVGRL(nn.Module):
    """Two-encoder multi-view model (adjacency view and diffusion view) with a bilinear discriminator.

    Args:
        encoder1: encoder applied with ``adj``.
        encoder2: encoder applied with ``diff``.
        n_h: embedding width handed to ``Discriminator``.
    """

    def __init__(self, encoder1, encoder2, n_h):
        super().__init__()
        self.encoder1 = encoder1
        self.encoder2 = encoder2
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)
        # Not used inside this class; presumably applied to forward()'s logits by the trainer.
        self.loss_fn = nn.BCEWithLogitsLoss()

    def _readout(self, h):
        """Summary vector: sigmoid of the mean over dim 1 (node dim if inputs are (batch, nodes, dim) — TODO confirm)."""
        return self.sigm(h.mean(1))

    def forward(self, seq1, seq2, adj, diff):
        """Return discriminator logits for both views of ``seq1`` and of ``seq2``.

        Summaries come only from the ``seq1`` encodings. The ``seq2`` encodings
        occupy the slots the discriminator scores as negatives.
        """
        # Encoder call order matches the original (enc1/seq1, enc2/seq1, enc1/seq2, enc2/seq2).
        view_adj = self.encoder1(seq1, adj)
        view_diff = self.encoder2(seq1, diff)
        corrupt_adj = self.encoder1(seq2, adj)
        corrupt_diff = self.encoder2(seq2, diff)
        return self.disc(
            self._readout(view_adj),
            self._readout(view_diff),
            view_adj,
            view_diff,
            corrupt_adj,
            corrupt_diff,
        )

    def get_embedding(self, seq, adj, diff):
        """Sum of both views' encodings with a leading size-1 dim squeezed, detached."""
        combined = self.encoder1(seq, adj) + self.encoder2(seq, diff)
        return combined.squeeze(0).detach()
    


class MLP_Predictor(nn.Module):
    """Two-layer MLP (Linear -> PReLU -> Linear) used as a predictor head.

    Args:
        input_size: input width.
        output_size: output width.
        hidden_size: width of the hidden layer (default 512).
    """

    def __init__(self, input_size, output_size, hidden_size=512):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size, bias=True),
            nn.PReLU(1),
            nn.Linear(hidden_size, output_size, bias=True),
        ]
        self.net = nn.Sequential(*layers)
        self.reset_parameters()

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)

    def reset_parameters(self):
        """Re-initialise every ``nn.Linear`` submodule; the PReLU slope is not reset."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            layer.reset_parameters()


class BGRL(torch.nn.Module):
    """Online encoder + predictor trained against a frozen, EMA-updated target encoder.

    Args:
        encoder: online encoder, called as ``encoder(x, edge_index)``. It must
            provide ``reset_parameters()``, because the target copy is
            re-initialised with it.
        predictor: head applied to the online embeddings.
    """

    def __init__(self, encoder, predictor):
        super().__init__()
        self.online_encoder = encoder
        self.predictor = predictor
        # The target starts as a structural copy of the online encoder, but with
        # freshly reset weights; it is never trained by gradients.
        self.target_encoder = copy.deepcopy(encoder)

        self.target_encoder.reset_parameters()
        for param in self.target_encoder.parameters():
            param.requires_grad = False

    def trainable_parameters(self):
        """Parameters to optimise: online encoder followed by predictor (target excluded)."""
        return list(self.online_encoder.parameters()) + list(self.predictor.parameters())

    @torch.no_grad()
    def update_target_network(self, mm):
        """EMA update of the target: ``target = mm * target + (1 - mm) * online``.

        Args:
            mm: momentum in ``[0.0, 1.0]``.

        Raises:
            ValueError: if ``mm`` is outside ``[0.0, 1.0]`` (or NaN). An explicit
                raise is used because ``assert`` is stripped under ``python -O``.
        """
        if not 0.0 <= mm <= 1.0:
            raise ValueError("Momentum needs to be between 0.0 and 1.0, got %.5f" % mm)
        for param_q, param_k in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            param_k.data.mul_(mm).add_(param_q.data, alpha=1. - mm)

    def forward(self, online_data, target_data):
        """Return ``(predictor(online_encoder(online_data)), target_encoder(target_data))``.

        Each data argument needs ``.x`` and ``.edge_index``. The target branch
        runs without gradient tracking.
        """
        online_y = self.online_encoder(online_data.x, online_data.edge_index)
        online_q = self.predictor(online_y)
        with torch.no_grad():
            target_y = self.target_encoder(target_data.x, target_data.edge_index).detach()
        return online_q, target_y

    def get_embeddings(self, data):
        """Online-encoder embeddings of ``data`` (needs ``.x``/``.edge_index``), detached."""
        x = self.online_encoder(data.x, data.edge_index)
        return x.detach()
    

class CCA_SSG(nn.Module):
    """Encode two graph views and return column-standardised embeddings.

    Args:
        encoder: module called as ``encoder(feat, edge_index)``.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    @staticmethod
    def _standardize(h):
        """Per-column z-score over dim 0 (unbiased std).

        A column with zero standard deviation divides by zero here and yields
        inf/NaN.
        """
        centred = h - h.mean(0)
        return centred / h.std(0)

    def forward(self, feat1, edge_index1, feat2, edge_index2):
        """Return ``(z1, z2)``: each view's encoding, standardised column-wise."""
        h_first = self.encoder(feat1, edge_index1)
        h_second = self.encoder(feat2, edge_index2)
        return self._standardize(h_first), self._standardize(h_second)

    def get_embeddings(self, feat, graph):
        """Raw (unstandardised) encoder output, detached; ``graph`` is passed straight to the encoder."""
        return self.encoder(feat, graph).detach()
    


class GCA(torch.nn.Module):
    """Encoder wrapper with a projection head and a symmetric InfoNCE-style loss.

    Args:
        encoder: module called as ``encoder(x, edge_index)``.
        num_hidden: embedding width; input of ``fc1`` and output of ``fc2``.
        num_proj_hidden: hidden width of the projection head.
        tau: temperature dividing the cosine similarities (default 0.5).
    """

    def __init__(self, encoder, num_hidden: int, num_proj_hidden: int, tau: float = 0.5):
        super().__init__()
        self.encoder = encoder
        self.tau: float = tau

        self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)
        self.num_hidden = num_hidden

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Return the encoder's output for ``(x, edge_index)``."""
        return self.encoder(x, edge_index)

    def projection(self, z: torch.Tensor) -> torch.Tensor:
        """Projection head: fc1 -> ELU -> fc2."""
        hidden = F.elu(self.fc1(z))
        return self.fc2(hidden)

    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        """Cosine-similarity matrix between the rows of ``z1`` and ``z2``."""
        return torch.mm(F.normalize(z1), F.normalize(z2).t())

    def _exp_sim(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """``exp(sim(a, b) / tau)``, the temperature-scaled similarity kernel."""
        return torch.exp(self.sim(a, b) / self.tau)

    def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor):
        """Per-node loss, with row ``i`` of ``z2`` as the positive for row ``i`` of ``z1``.

        The denominator sums over both views and drops the self-pair
        ``(z1_i, z1_i)``.
        """
        within = self._exp_sim(z1, z1)
        across = self._exp_sim(z1, z2)
        denominator = within.sum(1) + across.sum(1) - within.diag()
        return -torch.log(across.diag() / denominator)

    def batched_semi_loss(self, z1: torch.Tensor, z2: torch.Tensor, batch_size: int):
        """Same values as ``semi_loss``, computed ``batch_size`` rows of ``z1`` at a time.

        Each step materialises two ``[B, N]`` similarity blocks instead of
        ``[N, N]`` matrices, which bounds peak memory.
        """
        num_nodes = z1.size(0)
        num_batches = (num_nodes - 1) // batch_size + 1
        chunks = []

        for b in range(num_batches):
            start, stop = b * batch_size, (b + 1) * batch_size
            rows = z1[start:stop]
            within = self._exp_sim(rows, z1)  # [B, N]
            across = self._exp_sim(rows, z2)  # [B, N]
            # Rows start..stop-1 sit on the diagonal of the start:stop column slice.
            positives = across[:, start:stop].diag()
            denominator = within.sum(1) + across.sum(1) - within[:, start:stop].diag()
            chunks.append(-torch.log(positives / denominator))
        return torch.cat(chunks)

    def loss(self, z1: torch.Tensor, z2: torch.Tensor, mean: bool = True, batch_size= None):
        """Symmetric loss over projected embeddings.

        Args:
            z1, z2: embeddings of the two views (row-aligned).
            mean: average over nodes if True, otherwise sum.
            batch_size: if given, use the row-batched computation.
        """
        h1, h2 = self.projection(z1), self.projection(z2)

        if batch_size is None:
            forward_dir, backward_dir = self.semi_loss(h1, h2), self.semi_loss(h2, h1)
        else:
            forward_dir = self.batched_semi_loss(h1, h2, batch_size)
            backward_dir = self.batched_semi_loss(h2, h1, batch_size)

        combined = (forward_dir + backward_dir) * 0.5
        if mean:
            return combined.mean()
        return combined.sum()


class GraphConvolution(Module):
    """Graph convolution layer computing ``act(adj @ (dropout(input) @ weight))``.

    Args:
        in_features: input feature width.
        out_features: output feature width.
        dropout: dropout probability applied to the input (active only in training mode).
        act: output activation (default ``F.relu``).
    """

    def __init__(self, in_features, out_features, dropout=0., act=F.relu):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.act = act
        # Allocated uninitialised, then filled by reset_parameters() below.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialisation of ``weight``."""
        torch.nn.init.xavier_uniform_(self.weight)

    def forward(self, input, adj):
        """Propagate ``input`` over ``adj``.

        ``adj`` goes to ``torch.spmm``, so it is presumably a sparse
        ``(N, N)`` adjacency matrix; TODO confirm with callers.
        """
        dropped = F.dropout(input, p=self.dropout, training=self.training)
        transformed = dropped.mm(self.weight)
        return self.act(torch.spmm(adj, transformed))

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"


class GCNModelAE(nn.Module):
    """Graph auto-encoder: two graph-convolution layers plus an inner-product decoder.

    Args:
        input_feat_dim: input feature width.
        hidden_dim1: width after the first (ReLU) layer.
        hidden_dim2: embedding width after the second (linear) layer.
        dropout: dropout probability for the layers and the decoder.
    """

    def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, dropout):
        super().__init__()
        identity = lambda x: x
        self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu)
        self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=identity)
        self.dc = InnerProductDecoder(dropout, act=identity)

    def encode(self, x, adj):
        """Return ``(embeddings, None)``.

        The ``None`` slot appears to mirror the ``(mu, logvar)`` pair returned
        by ``GCNModelVAE.encode``.
        """
        return self.gc2(self.gc1(x, adj), adj), None

    def forward(self, x, adj):
        """Return ``(reconstructed_adjacency_logits, embeddings, None)``."""
        embeddings = self.encode(x, adj)[0]
        return self.dc(embeddings), embeddings, None


class GCNModelVAE(nn.Module):
    """Variational graph auto-encoder: shared GCN layer, mean/log-scale heads, inner-product decoder.

    Args:
        input_feat_dim: input feature width.
        hidden_dim1: width of the shared first (ReLU) layer.
        hidden_dim2: latent width produced by both the ``mu`` and ``logvar`` heads.
        dropout: dropout probability for the layers and the decoder.
    """

    def __init__(self, input_feat_dim, hidden_dim1, hidden_dim2, dropout):
        super().__init__()
        identity = lambda x: x
        self.gc1 = GraphConvolution(input_feat_dim, hidden_dim1, dropout, act=F.relu)
        self.gc2 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=identity)
        self.gc3 = GraphConvolution(hidden_dim1, hidden_dim2, dropout, act=identity)
        self.dc = InnerProductDecoder(dropout, act=identity)

    def encode(self, x, adj):
        """Return ``(mu, logvar)`` computed from a shared first-layer representation."""
        shared = self.gc1(x, adj)
        mu = self.gc2(shared, adj)
        logvar = self.gc3(shared, adj)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar)`` in training mode; return ``mu`` unchanged in eval mode.

        NOTE(review): the scale is ``exp(logvar)``, not ``exp(0.5 * logvar)``, so
        ``logvar`` is effectively used as a log standard deviation. Any KL
        term computed elsewhere must use the same convention; confirm against
        the loss.
        """
        if not self.training:
            return mu
        std = torch.exp(logvar)
        noise = torch.randn_like(std)
        return noise.mul(std).add_(mu)

    def forward(self, x, adj):
        """Return ``(reconstructed_adjacency_logits, mu, logvar)``."""
        mu, logvar = self.encode(x, adj)
        latent = self.reparameterize(mu, logvar)
        return self.dc(latent), mu, logvar


class InnerProductDecoder(nn.Module):
    """Decode node embeddings into a dense pairwise score matrix ``act(z @ z.T)``.

    Args:
        dropout: dropout probability applied to ``z`` in training mode.
        act: activation applied to the score matrix (default ``torch.sigmoid``).
    """

    def __init__(self, dropout, act=torch.sigmoid):
        super().__init__()
        self.dropout = dropout
        self.act = act

    def forward(self, z):
        """Return ``act(dropout(z) @ dropout(z).T)``, one score per node pair."""
        dropped = F.dropout(z, self.dropout, training=self.training)
        scores = dropped.mm(dropped.t())
        return self.act(scores)