import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import faiss
import numpy as np
from torch_geometric.nn import TransformerConv

from utils_afgrl import loss_fn, update_moving_average, EMA

def set_requires_grad(model, val):
    """Set ``requires_grad`` to ``val`` on every parameter yielded by ``model.parameters()``.

    Args:
        model: Object exposing a ``parameters()`` iterable (e.g. an ``nn.Module``).
        val: Value assigned to each parameter's ``requires_grad`` attribute.
    """
    for param in model.parameters():
        param.requires_grad = val

class GraphEncoder(nn.Module):
    """Stack of five ``TransformerConv`` layers that use a 1-dimensional edge feature.

    Layers 1-4 are built with ``num_heads`` heads and ``hidden_channels`` output
    channels and are each followed by a ReLU. Layer 5 is built with a single
    head and ``out_channels`` output channels and has no activation.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads):
        super().__init__()
        # Input width of layers 2-5 (presumably the concatenated heads of the
        # preceding multi-head layer — confirm against TransformerConv docs).
        wide = hidden_channels * num_heads
        # Settings shared by every layer.
        shared = {"edge_dim": 1, "dropout": 0.1}

        self.conv1 = TransformerConv(in_channels, hidden_channels, heads=num_heads, **shared)
        self.conv2 = TransformerConv(wide, hidden_channels, heads=num_heads, **shared)
        self.conv3 = TransformerConv(wide, hidden_channels, heads=num_heads, **shared)
        self.conv4 = TransformerConv(wide, hidden_channels, heads=num_heads, **shared)
        self.conv5 = TransformerConv(wide, out_channels, heads=1, **shared)

    def forward(self, x, edge_index, edge_attr):
        """Run the five layers in order and return the output of the last one.

        ``edge_index`` and ``edge_attr`` are passed unchanged to every layer.
        """
        hidden = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = F.relu(layer(hidden, edge_index, edge_attr))
        # The final layer's output is returned without an activation.
        return self.conv5(hidden, edge_index, edge_attr)

class Neighbor(nn.Module):
    """Mine positive node pairs from k-NN candidates.

    For every node the ``top_k`` highest student-vs-teacher dot products give
    the candidate neighbours. Two sparse N x N matrices are built from them:

    * *locality*: the k-NN indicator matrix multiplied elementwise with ``adj``;
    * *globality*: the k-NN pairs whose two nodes were assigned to the same
      k-means cluster in at least one of ``num_kmeans`` faiss runs on the
      teacher embeddings.

    The indices of their (coalesced) sum are returned as the positive pairs.
    """

    def __init__(self, device, num_centroids, num_kmeans, clus_num_iters):
        """
        Args:
            device: Stored as ``self.device`` for backward compatibility. The
                methods create their tensors on the device of their inputs.
            num_centroids: Number of centroids passed to ``faiss.Kmeans``.
            num_kmeans: Number of k-means runs (seeds ``1234``, ``1235``, ...).
            clus_num_iters: ``niter`` passed to ``faiss.Kmeans``.
        """
        super().__init__()
        self.device = device
        self.num_centroids = num_centroids
        self.num_kmeans = num_kmeans
        self.clus_num_iters = clus_num_iters

    def forward(self, adj, student, teacher, top_k, epoch):
        """Return the positive-pair indices and the neighbour count used.

        Args:
            adj: Sparse N x N tensor multiplied elementwise with the k-NN matrix.
            student: ``[N, d]`` student embeddings.
            teacher: ``[N, d]`` teacher embeddings (detached here).
            top_k: Requested neighbours per node; clamped to ``N``.
            epoch: Unused; kept for interface compatibility.

        Returns:
            ``(indices, k)``: ``indices`` is the ``[2, P]`` index tensor of the
            coalesced ``locality + globality`` matrix, ``k`` is the number of
            neighbours per node that was actually used.
        """
        device = student.device
        n_data, d = student.shape
        teacher = teacher.detach()
        # A graph with fewer than top_k nodes cannot supply top_k neighbours.
        k = min(top_k, n_data)

        # Only the neighbour indices are consumed, so no autograd graph is needed.
        with torch.no_grad():
            similarity = torch.matmul(student, teacher.T)
            # Add 10 to each self-similarity (for unit-norm inputs this ranks
            # every node first among its own neighbours) without allocating a
            # dense identity matrix.
            similarity.diagonal().add_(10)
            _, I_knn = similarity.topk(k=k, dim=1, largest=True, sorted=True)  # [N, k]

        # locality
        locality = self.create_sparse(I_knn) * adj

        # globality: clustering on teacher (faiss works on float32 numpy arrays)
        teacher_np = np.ascontiguousarray(teacher.cpu().numpy(), dtype=np.float32)
        same_cluster = torch.zeros_like(I_knn, dtype=torch.bool)  # [N, k]
        for run in range(self.num_kmeans):
            kmeans = faiss.Kmeans(d, self.num_centroids, niter=self.clus_num_iters,
                                  gpu=False, seed=run + 1234)
            kmeans.train(teacher_np)
            _, assignment = kmeans.index.search(teacher_np, 1)
            labels = torch.from_numpy(np.ascontiguousarray(assignment[:, 0])).long().to(device)  # [N]
            # A pair counts as "close" if any run puts both nodes in one cluster.
            same_cluster |= labels.unsqueeze(1) == labels[I_knn]
        globality = self.create_sparse_revised(I_knn, same_cluster)

        positives = (locality + globality).coalesce()
        return positives.indices(), I_knn.shape[1]

    def create_sparse(self, I):
        """Build the coalesced N x N sparse matrix with a 1 at every ``(row, I[row, j])``."""
        n_data, k = I.shape
        rows = torch.arange(n_data, device=I.device).repeat_interleave(k)
        indices = torch.stack([rows, I.reshape(-1)], dim=0)
        values = torch.ones(indices.shape[1], device=I.device)
        return torch.sparse_coo_tensor(indices, values, (n_data, n_data)).coalesce()

    def create_sparse_revised(self, I, mask):
        """Like ``create_sparse`` but keep only the entries where ``mask`` is true.

        Args:
            I: ``[N, k]`` integer tensor of column indices.
            mask: ``[N, k]`` tensor; truthy entries select which pairs to keep.

        Returns:
            Coalesced N x N sparse tensor (empty when nothing is selected).
        """
        n_data, k = I.shape
        keep = mask.to(device=I.device, dtype=torch.bool)
        # Boolean indexing selects the kept (row, col) pairs in one vectorised step.
        rows = torch.arange(n_data, device=I.device).unsqueeze(1).expand(-1, k)[keep]
        cols = I[keep]
        indices = torch.stack([rows, cols], dim=0)
        values = torch.ones(indices.shape[1], device=I.device)
        return torch.sparse_coo_tensor(indices, values, (n_data, n_data)).coalesce()

class AFGRLModel(nn.Module):
    """Student/teacher graph encoders trained on positives mined by ``Neighbor``.

    The student ``GraphEncoder`` and a predictor MLP receive gradients. The
    teacher is a frozen deep copy of the student that, within this class, is
    only modified through :meth:`update_teacher`.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads, args):
        """
        Args:
            in_channels, hidden_channels, out_channels, num_heads: Forwarded to
                ``GraphEncoder``.
            args: Namespace read for ``mad`` and ``num_epochs`` (passed to
                ``EMA``), ``device``, ``num_centroids``, ``num_kmeans`` and
                ``clus_num_iters`` (passed to ``Neighbor``), ``pred_hid``
                (predictor hidden width) and ``topk``.
        """
        super().__init__()
        self.student_encoder = GraphEncoder(in_channels, hidden_channels, out_channels, num_heads)
        # The teacher starts as an exact copy of the student and never receives gradients.
        self.teacher_encoder = copy.deepcopy(self.student_encoder)
        set_requires_grad(self.teacher_encoder, False)
        # NOTE(review): EMA semantics live in utils_afgrl — presumably a decay schedule; confirm there.
        self.teacher_ema_updater = EMA(args.mad, args.num_epochs)
        self.neighbor = Neighbor(device=args.device,
                                 num_centroids=args.num_centroids,
                                 num_kmeans=args.num_kmeans,
                                 clus_num_iters=args.clus_num_iters)
        rep_dim = out_channels  # encoder output
        self.predictor = nn.Sequential(
            nn.Linear(rep_dim, args.pred_hid),
            nn.BatchNorm1d(args.pred_hid),
            nn.PReLU(),
            nn.Linear(args.pred_hid, rep_dim)
        )
        # predictor init: Xavier-uniform weights, constant 0.01 biases
        for m in self.predictor:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.01)
        self.topk = args.topk

    def update_teacher(self):
        """Update the teacher encoder from the student via ``update_moving_average``."""
        update_moving_average(self.teacher_ema_updater, self.teacher_encoder, self.student_encoder)

    @staticmethod
    def _build_adjacency(edge_index, edge_attr, num_nodes, device=None):
        """Build the sparse ``num_nodes x num_nodes`` adjacency used for neighbour mining.

        Args:
            edge_index: ``[2, E]`` index tensor.
            edge_attr: ``None`` for unit weights, otherwise per-edge weights; a
                trailing dimension is squeezed when ``edge_attr`` has more than
                one dimension.
            num_nodes: Number of nodes N.
            device: Target device; defaults to ``edge_index.device``.

        Returns:
            A sparse COO tensor of shape ``(num_nodes, num_nodes)``.
        """
        if device is None:
            device = edge_index.device
        if edge_attr is None:
            weight = torch.ones(edge_index.shape[1], device=device)
        else:
            weight = edge_attr.squeeze(-1) if edge_attr.dim() > 1 else edge_attr
        # One constructor for both cases; the legacy torch.sparse.FloatTensor
        # constructor is deprecated and CPU-typed.
        return torch.sparse_coo_tensor(edge_index, weight, (num_nodes, num_nodes), device=device)

    def forward(self, x, edge_index, edge_attr, epoch):
        """Encode the graph and compute the symmetric student/teacher loss.

        Args:
            x: Node feature tensor; ``x.shape[0]`` is taken as the node count.
            edge_index: ``[2, E]`` edge index tensor.
            edge_attr: Edge attributes passed to both encoders and, if not
                ``None``, used as adjacency weights.
            epoch: Forwarded to ``Neighbor``.

        Returns:
            ``(student, loss, ind, k)``: the student embeddings, the mean of
            ``loss_fn`` applied in both pair directions, the positive-pair
            indices and the neighbour count, the last two as returned by
            ``Neighbor``.
        """
        student = self.student_encoder(x, edge_index, edge_attr)
        pred = self.predictor(student)
        with torch.no_grad():
            teacher = self.teacher_encoder(x, edge_index, edge_attr)

        adj = self._build_adjacency(edge_index, edge_attr, x.shape[0], device=x.device)

        # Neighbour mining runs on L2-normalised embeddings; the loss uses the raw ones.
        ind, k = self.neighbor(adj, F.normalize(student, dim=-1), F.normalize(teacher, dim=-1), self.topk, epoch)
        # Symmetric loss: predict each pair member's teacher embedding from the other.
        loss1 = loss_fn(pred[ind[0]], teacher[ind[1]].detach())
        loss2 = loss_fn(pred[ind[1]], teacher[ind[0]].detach())
        loss = (loss1 + loss2).mean()
        return student, loss, ind, k
