from time import time
import os.path as osp
import math
import torch as th
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
import GCL.augmentors as A

from models.GraphHD.model import UnbiasGCL
from models.GraphHD.h_cluster import run_hkmeans_faiss, run_hkmeans

from utils.util_funcs import print_log
import wandb

class Teacher_Trainer:
    """Train the UnbiasGCL teacher: contrastive warm-up, prototype clustering, prototype-aware training.

    Every attribute of ``cf`` is copied onto the instance. Besides the fields read explicitly
    in ``__init__``, the methods also read ``w_epochs``, ``p_epochs``, ``save_freq``,
    ``lamg``, ``laml``, ``lamp``, ``lamh`` and ``clip`` from it.
    """

    def __init__(self, dataset, cf):
        self.__dict__.update(cf.__dict__)
        self.cf = cf
        self.res_file = cf.res_file
        self.device = cf.compute_dev
        self.batch_size = cf.batch_size
        self.JK = cf.JK
        # Width of the graph embedding: n_hidden * (n_layer + 1) when JK == 'concat'.
        self.emb_dim = cf.n_hidden * (cf.n_layer + 1) if self.JK == 'concat' else cf.n_hidden
        # n_protos is underscore-separated, e.g. "100_50" -> [100, 50]; one hierarchy level per entry.
        self.n_protos = [int(x) for x in cf.n_protos.split('_')]
        cf.h_level = len(self.n_protos)
        self.h_level = cf.h_level
        self.local_view = self.laml > 0
        self.proto_up = cf.proto_up
        self.use_faiss = cf.use_faiss
        self.momentum = cf.momentum
        self.mo_budget = cf.mo_budget

        self.teacher_file_path = cf.teacher_file
        self.warmup_file_path = cf.warmup_file

        self.dataloader = DataLoader(dataset, batch_size=cf.batch_size, shuffle=True, num_workers=4)
        aug1 = A.Identity()
        aug2 = A.RandomChoice([A.NodeDropping(pn=0.2), A.EdgeRemoving(pe=0.2)], 1)
        self.encoder_model = UnbiasGCL(augmentor=(aug1, aug2), n_protos=self.n_protos, cf=cf).to(self.device)

        self.data_len = len(dataset)
        self.pretrain_optimizer = th.optim.Adam(self.encoder_model.parameters(), lr=cf.lr,
                                                weight_decay=cf.weight_decay)

    def run_preatrain(self):
        """Run the full teacher schedule and return the trained encoder model.

        1. ``w_epochs`` warm-up epochs with the contrastive losses only (``init``).
        2. Cluster graph embeddings into ``h_level`` prototype levels (``run_cluster``).
        3. ``p_epochs`` epochs of prototype-aware training (``train``). A checkpoint is
           saved every ``save_freq`` epochs; ``save_freq == 0`` disables checkpointing.
        """
        print('Initializing...')
        for epoch in range(1, self.w_epochs + 1):
            t0 = time()
            loss, global_loss, local_loss = self.init()
            print_log({'Epoch': epoch, 'Time': time() - t0, 'Loss': loss,
                       'GlobalLoss': global_loss, 'LocalLoss': local_loss})
            wandb.log({'Epoch': epoch, 'Loss': loss,
                       'GlobalLoss': global_loss, 'LocalLoss': local_loss})

        print('Clustering...')
        self.cluster_result = self.run_cluster()

        model_param_group = [{"params": self.encoder_model.encoder.parameters(), "lr": self.cf.lr},
                             {"params": self.encoder_model.proj.parameters(), "lr": self.cf.lr}]

        if self.proto_up == 'bp':
            centroids = self.cluster_result['centroids']
            for level in range(self.h_level):
                # Move to the device *before* wrapping. Parameter(t).to(dev) returns a
                # non-leaf copy whenever t lives on another device, and the optimizer
                # rejects non-leaf tensors.
                centroids[level] = th.nn.Parameter(centroids[level].to(self.device), requires_grad=True)
                model_param_group.append({'params': centroids[level], 'lr': self.cf.lr,
                                          'weight_decay': self.cf.weight_decay})
        self.optimizer = th.optim.Adam(model_param_group, lr=self.cf.lr, weight_decay=self.cf.weight_decay)

        print('Start training...')
        for epoch in range(1, self.p_epochs + 1):
            t0 = time()
            loss, global_loss, local_loss, proto_loss, r_loss = self.train(epoch)
            print_log({'Epoch': epoch, 'Time': time() - t0, 'Loss': loss, 'GlobalLoss': global_loss,
                       'LocalLoss': local_loss, 'ProtoLoss': proto_loss, 'ReguLoss': r_loss})

            wandb.log({'Epoch': epoch, 'Loss': loss, 'GlobalLoss': global_loss,
                       'LocalLoss': local_loss, 'ProtoLoss': proto_loss, 'ReguLoss': r_loss})

            # Guard against save_freq == 0, which would otherwise raise ZeroDivisionError.
            if self.save_freq and epoch % self.save_freq == 0:
                th.save({
                    'encoder': self.encoder_model.encoder.state_dict(),
                    'proj': self.encoder_model.proj.state_dict(),
                    'prototypes': self.cluster_result['centroids'],
                }, self.teacher_file_path + f"_tcp{epoch}" + ".pth")

        return self.encoder_model

    def run_cluster(self):
        """Cluster the current graph embeddings into hierarchical prototypes.

        ``compute_features`` is always called first. Besides producing the faiss input,
        it leaves the encoder in eval mode, which is also the mode ``run_hkmeans`` receives it in.
        """
        graph_features = self.compute_features()
        if self.use_faiss:
            cluster_result = run_hkmeans_faiss(graph_features, self.n_protos, self.cf)
        else:
            cluster_result = run_hkmeans(self.encoder_model, self.dataloader, self.n_protos, self.cf)

        return cluster_result

    def _prepare_batch(self, data):
        """Move a batch to the compute device, adding a constant node feature if ``x`` is missing."""
        data = data.to(self.device)
        if data.x is None:
            num_nodes = data.batch.size(0)
            data.x = th.ones((num_nodes, 1), dtype=th.float32, device=self.device)
        return data

    def _contrastive_losses(self, data):
        """Encode both augmented views and return ``(g1, global_loss, local_loss)``.

        ``local_loss`` is the sentinel ``th.tensor(-1)`` when the local view is disabled.
        """
        z1, z2, g1, g2 = self.encoder_model(data)
        z1, z2, g1, g2 = [self.encoder_model.proj(x) for x in (z1, z2, g1, g2)]
        global_loss = self.encoder_model.cal_instance_loss(g1, g2)
        if self.local_view:
            local_loss = self.encoder_model.cal_local_loss(z1, z2, batch=data.batch)
        else:
            local_loss = th.tensor(-1)
        return g1, global_loss, local_loss

    def init(self):
        """Run one warm-up epoch that uses only the contrastive objectives.

        Returns ``(loss, global_loss, local_loss)`` summed over batches. ``local_loss``
        accumulates -1 per batch when the local view is disabled (``laml <= 0``).
        """
        epoch_loss = epoch_global_loss = epoch_local_loss = 0
        self.encoder_model.train()

        for data in self.dataloader:
            data = self._prepare_batch(data)
            _, global_loss, local_loss = self._contrastive_losses(data)
            loss = self.lamg * global_loss + self.laml * local_loss

            self.pretrain_optimizer.zero_grad()
            loss.backward()
            self.pretrain_optimizer.step()
            epoch_loss += loss.item()
            epoch_global_loss += global_loss.item()
            epoch_local_loss += local_loss.item()

        return epoch_loss, epoch_global_loss, epoch_local_loss

    def train(self, epoch):
        """Run one epoch of prototype-aware training.

        ``epoch`` is accepted for interface compatibility and is not used here.
        Returns ``(loss, global_loss, local_loss, proto_loss, r_loss)`` summed over batches.
        Disabled terms (local view off, ``lamh <= 0``) contribute the sentinel -1 per batch.
        """
        epoch_loss = epoch_global_loss = epoch_local_loss = 0
        epoch_proto_loss = epoch_r_loss = 0
        # Buffer of per-batch centroids for momentum updates. It is reset every epoch, so
        # with mo_budget > 0 any not-yet-flushed remainder is dropped at epoch end.
        budget_list = []

        for data in self.dataloader:
            data = self._prepare_batch(data)
            self.encoder_model.train()
            g1, global_loss, local_loss = self._contrastive_losses(data)

            if self.proto_up == 'bp':
                proto_loss, proto_probs, _ = self.encoder_model.cal_proto_loss_bp(g1, self.cluster_result)
            elif self.proto_up == 'momentum':
                # NOTE(review): the keyword is spelled 'ptoto_update' exactly as in the original
                # call; it must match UnbiasGCL.cal_proto_loss_bp's signature. Confirm before renaming.
                proto_loss, proto_probs, batch_centroids = self.encoder_model.cal_proto_loss_bp(
                    g1, self.cluster_result, ptoto_update=self.proto_up)
                budget_list = self.update_budget(budget_list, batch_centroids, g1)
            else:
                raise NotImplementedError

            loss = self.lamg * global_loss + self.laml * local_loss + self.lamp * proto_loss

            if self.lamh > 0:
                # Per level this adds log(K) - H(p), i.e. KL(p || uniform), because
                # log(p ** -p) == -p * log(p). Since 0 ** 0 == 1, zero entries stay finite.
                # Assumes proto_probs[level] is a 1-D probability vector over K prototypes. TODO confirm.
                rloss = 0
                for level in range(self.h_level):
                    avg_prob = proto_probs[level]
                    rloss = rloss - th.sum(th.log(avg_prob ** (-avg_prob))) + math.log(float(len(avg_prob)))
                loss += self.lamh * rloss
            else:
                rloss = th.tensor(-1)

            self.optimizer.zero_grad()
            loss.backward()
            if self.clip > 0:
                # Only the prototypes are clipped, each level separately; encoder/proj grads are not.
                for level in range(self.h_level):
                    th.nn.utils.clip_grad_norm_(self.cluster_result['centroids'][level], self.clip)
            self.optimizer.step()

            epoch_loss += loss.item()
            epoch_global_loss += global_loss.item()
            epoch_local_loss += local_loss.item()
            epoch_proto_loss += proto_loss.item()
            epoch_r_loss += rloss.item()

        if self.proto_up == 'momentum' and self.mo_budget == 0:
            # mo_budget == 0 means "flush once per epoch" instead of every mo_budget batches.
            self.update_centroid(budget_list)

        return epoch_loss, epoch_global_loss, epoch_local_loss, epoch_proto_loss, epoch_r_loss

    @th.no_grad()
    def compute_features(self):
        """Embed every graph with the encoder + projection head (eval mode) and return a numpy array.

        Rows are placed by ``data.id``. This presumably holds each graph's dataset index, in
        ``[0, len(dataset))``. Confirm against the dataset. The loader shuffles, so the
        row order does not depend on iteration order.
        """
        self.encoder_model.eval()
        graph_features = th.zeros(self.data_len, self.emb_dim, device=self.device)
        for data in self.dataloader:
            # Same feature fallback as training; without it feature-less datasets pass x=None.
            data = self._prepare_batch(data)
            x = self.encoder_model.encoder(data.x, data.edge_index, data.edge_attr)
            g = self.encoder_model.pretrain_pool(x, data.batch)
            g = self.encoder_model.proj(g)
            graph_features[data.id] = g

        return graph_features.cpu().numpy()

    @th.no_grad()
    def update_centroid(self, budget_list):
        """Blend buffered batch centroids into the stored prototypes with momentum.

        For each level, the buffered centroids are summed and row-normalised. Then
        ``new = momentum * old + (1 - momentum) * normalised_sum``. An empty buffer is a
        no-op; this happens, for example, when every batch was smaller than ``batch_size``.
        """
        if not budget_list:
            return
        centroids = self.cluster_result['centroids']
        for level in range(self.h_level):
            temp_proto = F.normalize(sum(b[level] for b in budget_list))
            assert temp_proto.shape == centroids[level].shape
            centroids[level] = centroids[level] * self.momentum + (1. - self.momentum) * temp_proto

    @th.no_grad()
    def update_budget(self, budget_list, batch_centroids, graph_rep):
        """Buffer one batch's centroids and flush them into the prototypes every ``mo_budget`` batches.

        Batches with fewer than ``batch_size`` graphs are skipped, e.g. a trailing partial batch.
        With ``mo_budget == 0`` nothing is flushed here; ``train`` flushes once per epoch.
        Returns the buffer, which is a fresh empty list right after a flush.
        """
        if graph_rep.shape[0] < self.batch_size:
            return budget_list
        budget_list.append(batch_centroids)

        if self.mo_budget == 0 or len(budget_list) % self.mo_budget != 0:
            return budget_list
        self.update_centroid(budget_list)
        return []