import time
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tasks.base_task import BaseTask
from tasks.utils import train, kd_el_train, evaluate, accuracy, kd_el_evaluate



class SGLNodeClassification(BaseTask):
    """Train a single model for node classification and report val/test accuracy.

    ``execute`` trains ``model`` for ``epochs`` epochs on ``dataset.train_idx``,
    keeps the test accuracy seen at the best validation accuracy, and finally
    also considers the accuracy of ``model.postprocess`` applied to the
    full-graph output.
    """

    def __init__(self, dataset, model, lr, weight_decay, epochs, device, show_epoch_info=20, loss_fn=None):
        """
        Args:
            dataset: object exposing ``y``, ``adj``, ``x``, ``num_nodes`` and
                ``train_idx`` / ``val_idx`` / ``test_idx``.
            model: object exposing ``parameters``, ``preprocess``,
                ``model_forward`` and ``postprocess``.
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            epochs: number of training epochs run by ``execute``.
            device: device the model and labels are moved to.
            show_epoch_info: stored on the instance; not currently read by ``execute``.
            loss_fn: training loss passed to ``train``; defaults to a fresh
                ``nn.CrossEntropyLoss()`` per instance.
        """
        super().__init__()

        self.dataset = dataset
        self.labels = self.dataset.y

        self.model = model
        self.optimizer = Adam(model.parameters(), lr=lr,
                              weight_decay=weight_decay)
        self.epochs = epochs
        self.show_epoch_info = show_epoch_info
        # Build the default loss per instance rather than sharing one module
        # object created when the function was defined.
        self.loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        self.device = device

    def execute(self):
        """Train and evaluate the model.

        Returns:
            ``(best_val, best_test, model)``, where ``best_test`` is the test
            accuracy recorded together with the highest validation accuracy,
            taken over all epochs and the post-processed output.
        """
        self.model.preprocess(self.dataset.adj, self.dataset.x)
        self.model = self.model.to(self.device)
        self.labels = self.labels.to(self.device)

        best_val = 0.
        best_test = 0.
        for _ in range(self.epochs):
            # Training loss/accuracy are returned but not used here.
            _loss_train, _acc_train = train(self.model, self.dataset.train_idx, self.labels, self.device,
                                            self.optimizer, self.loss_fn)
            acc_val, acc_test = evaluate(self.model, self.dataset.val_idx, self.dataset.test_idx,
                                         self.labels, self.device)
            # Strict ">" keeps the earliest epoch among ties.
            if acc_val > best_val:
                best_val, best_test = acc_val, acc_test

        acc_val, acc_test = self.postprocess()
        if acc_val > best_val:
            best_val, best_test = acc_val, acc_test

        return best_val, best_test, self.model

    def postprocess(self):
        """Return ``(acc_val, acc_test)`` of ``model.postprocess`` over all nodes."""
        self.model.eval()

        # Inference only: avoid building an autograd graph over every node.
        with torch.no_grad():
            outputs = self.model.model_forward(
                range(self.dataset.num_nodes), self.device).to("cpu")

        final_output = self.model.postprocess(self.dataset.adj, outputs)
        acc_val = accuracy(
            final_output[self.dataset.val_idx], self.labels[self.dataset.val_idx])
        acc_test = accuracy(
            final_output[self.dataset.test_idx], self.labels[self.dataset.test_idx])
        return acc_val, acc_test


class SGLNodeClassificationKDEL(BaseTask):
    """Train ``model`` with ``kd_el_train`` alongside the models in ``model_list``.

    ``kd_el_train`` receives ``model``, ``model_list`` and both losses
    (cross-entropy and KL-divergence by default); how the auxiliary models are
    used is defined there — presumably distillation/ensemble, confirm in
    ``tasks.utils``.
    """

    def __init__(self, dataset, model, lr, weight_decay, epochs, device, model_list,
                 loss_ce=None, loss_kd=None):
        """
        Args:
            dataset: object exposing ``y``, ``adj``, ``x``, ``num_nodes`` and
                ``train_idx`` / ``val_idx`` / ``test_idx``.
            model: model being optimized (its ``parameters`` go to Adam).
            lr: Adam learning rate.
            weight_decay: Adam weight decay.
            epochs: number of training epochs run by ``execute``.
            device: device the models and labels are moved to.
            model_list: auxiliary models passed to ``kd_el_train``.
            loss_ce: defaults to a fresh ``nn.CrossEntropyLoss(reduction='mean')``.
            loss_kd: defaults to a fresh ``nn.KLDivLoss(reduction='batchmean')``.
        """
        super().__init__()

        self.dataset = dataset
        self.labels = self.dataset.y
        self.model_list = model_list

        self.model = model
        self.optimizer = Adam(model.parameters(), lr=lr,
                              weight_decay=weight_decay)
        self.epochs = epochs
        # Per-instance loss modules instead of definition-time shared defaults.
        self.loss_ce = nn.CrossEntropyLoss(reduction='mean') if loss_ce is None else loss_ce
        self.loss_kd = nn.KLDivLoss(reduction='batchmean') if loss_kd is None else loss_kd
        self.device = device

    def execute(self):
        """Train and evaluate ``model``.

        Returns:
            ``(best_val, best_test, model)``, where ``best_test`` is the test
            accuracy recorded together with the highest validation accuracy,
            taken over all epochs and the post-processed output.
        """
        self.model.preprocess(self.dataset.adj, self.dataset.x)
        self.model = self.model.to(self.device)
        self.labels = self.labels.to(self.device)
        for aux_model in self.model_list:
            aux_model.preprocess(self.dataset.adj, self.dataset.x)
            # Result not reassigned: relies on ``.to`` moving in place (true for nn.Module).
            aux_model.to(self.device)

        best_val = 0.
        best_test = 0.
        for _ in range(self.epochs):
            # Training loss/accuracy are returned but not used here.
            _loss_train, _acc_train = kd_el_train(self.model, self.model_list, self.dataset.train_idx,
                                                  self.labels, self.device,
                                                  self.optimizer, self.loss_ce, self.loss_kd)
            acc_val, acc_test = evaluate(self.model, self.dataset.val_idx, self.dataset.test_idx,
                                         self.labels, self.device)
            # Strict ">" keeps the earliest epoch among ties.
            if acc_val > best_val:
                best_val, best_test = acc_val, acc_test

        acc_val, acc_test = self.postprocess()
        if acc_val > best_val:
            best_val, best_test = acc_val, acc_test

        return best_val, best_test, self.model

    def postprocess(self):
        """Return ``(acc_val, acc_test)`` of ``model.postprocess`` over all nodes."""
        self.model.eval()

        # Inference only: avoid building an autograd graph over every node.
        with torch.no_grad():
            outputs = self.model.model_forward(
                range(self.dataset.num_nodes), self.device).to("cpu")

        final_output = self.model.postprocess(self.dataset.adj, outputs)
        acc_val = accuracy(
            final_output[self.dataset.val_idx], self.labels[self.dataset.val_idx])
        acc_test = accuracy(
            final_output[self.dataset.test_idx], self.labels[self.dataset.test_idx])
        return acc_val, acc_test


class SGLEvaluateModelClients(BaseTask):
    """Evaluate an already-trained ``model`` (no training is performed)."""

    def __init__(self, dataset, model, device):
        """
        Args:
            dataset: object exposing ``y``, ``adj``, ``x``, ``num_nodes`` and
                ``val_idx`` / ``test_idx``.
            model: object exposing ``preprocess``, ``model_forward`` and ``postprocess``.
            device: device the model and labels are moved to.
        """
        super().__init__()
        self.dataset = dataset
        self.labels = self.dataset.y

        self.model = model
        self.device = device

    def execute(self):
        """Prepare the model on ``device`` and return ``(acc_val, acc_test)``.

        Only the post-processed accuracy is reported. No separate raw-output
        ``evaluate`` pass is run, because its result would be discarded.
        """
        self.model.preprocess(self.dataset.adj, self.dataset.x)
        self.model = self.model.to(self.device)
        self.labels = self.labels.to(self.device)

        return self.postprocess()

    def postprocess(self):
        """Return ``(acc_val, acc_test)`` of ``model.postprocess`` over all nodes."""
        self.model.eval()

        # Inference only: avoid building an autograd graph over every node.
        with torch.no_grad():
            outputs = self.model.model_forward(
                range(self.dataset.num_nodes), self.device).to("cpu")

        final_output = self.model.postprocess(self.dataset.adj, outputs)
        acc_val = accuracy(
            final_output[self.dataset.val_idx], self.labels[self.dataset.val_idx])
        acc_test = accuracy(
            final_output[self.dataset.test_idx], self.labels[self.dataset.test_idx])
        return acc_val, acc_test


class SGLEvaluateModelClientsKDEL(BaseTask):
    """Evaluate the element-wise mean of the outputs of several trained models."""

    def __init__(self, dataset, model_list, device):
        """
        Args:
            dataset: object exposing ``y``, ``adj``, ``x``, ``num_nodes`` and
                ``val_idx`` / ``test_idx``.
            model_list: non-empty list of models exposing ``preprocess``,
                ``model_forward`` and ``postprocess``.
            device: device the models and labels are moved to.
        """
        super().__init__()
        self.dataset = dataset
        self.labels = self.dataset.y
        self.model_list = model_list
        self.device = device

    def execute(self):
        """Prepare every model on ``device`` and return ``(acc_val, acc_test)``.

        Only the post-processed ensemble accuracy is reported. No separate
        ``kd_el_evaluate`` pass is run, because its result would be discarded.
        """
        for member in self.model_list:
            member.preprocess(self.dataset.adj, self.dataset.x)
            # Result not reassigned: relies on ``.to`` moving in place (true for nn.Module).
            member.to(self.device)
        self.labels = self.labels.to(self.device)

        return self.postprocess()

    def postprocess(self):
        """Average all models' full-graph outputs, post-process, return ``(acc_val, acc_test)``.

        The averaged output goes through ``postprocess`` of the LAST model in
        ``model_list`` (the author describes this step as adaptive label
        propagation).

        Raises:
            ValueError: if ``model_list`` is empty.
        """
        if not self.model_list:
            raise ValueError("model_list must contain at least one model")

        outputs_list = []
        # Inference only: avoid building autograd graphs over every node.
        with torch.no_grad():
            for member in self.model_list:
                member.eval()
                outputs_list.append(
                    member.model_forward(range(self.dataset.num_nodes), self.device))

        # Element-wise mean. sum() adds in list order, like a running total,
        # without mutating outputs_list[0] in place. Moved to CPU to match
        # the input the other postprocess paths in this module hand to
        # ``model.postprocess``.
        outputs = (sum(outputs_list) / len(outputs_list)).to("cpu")

        final_output = self.model_list[-1].postprocess(self.dataset.adj, outputs)
        acc_val = accuracy(
            final_output[self.dataset.val_idx], self.labels[self.dataset.val_idx])
        acc_test = accuracy(
            final_output[self.dataset.test_idx], self.labels[self.dataset.test_idx])
        return acc_val, acc_test
