import copy
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj

from model import MLP, GCN, MLP_Diag, Linear
from utils import normalize, symmetrize, contrastive_loss, subgraph


class Server():
    """Central coordinator that aggregates client models round by round.

    Each round the server hands copies of its global GNN and MLP to every
    client, combines the returned state dicts into a weighted average and
    loads the result back into the global models. The weights are the
    clients' shares of the total number of validation nodes.
    """

    def __init__(self, client_list, client_ids, trainIdx, valIdx, dataset, device):
        """Store the clients, compute aggregation weights and build the global models.

        Args:
            client_list: Client objects exposing ``local_update`` and ``stats``.
            client_ids: Keys used to look up each client's split in
                ``trainIdx`` / ``valIdx``.
            trainIdx: Per-client training node indices, indexable by client id.
            valIdx: Per-client validation node indices, indexable by client id.
            dataset: Object providing ``num_node_features`` and ``num_classes``.
            device: Device the global models are moved to.
        """
        self.client_list = client_list
        self.client_ids = client_ids
        self.valIdx = valIdx
        self.dataset = dataset
        # Stored for reference; not read anywhere else in this class.
        self.num_train_nodes = np.sum([len(trainIdx[client_id]) for client_id in self.client_ids])

        # Aggregation weight of a client = its share of all validation nodes.
        # NOTE(review): coefficients are later indexed by position in
        # client_list, so client_ids is assumed to be in the same order.
        self.num_val_nodes = [len(self.valIdx[client_id]) for client_id in self.client_ids]
        total_val_nodes = sum(self.num_val_nodes)
        self.coefficients = [num_val_nodes / total_val_nodes for num_val_nodes in self.num_val_nodes]

        self.gnn = GCN(in_channel=dataset.num_node_features, out_channel=dataset.num_classes, hidden=16).to(device)
        self.mlp = MLP(in_channel=dataset.num_node_features, out_channel=dataset.num_classes, hidden=16).to(device)

    @staticmethod
    def _accumulate(averaged, keys, weights, coefficient):
        """Add ``coefficient * weights[key]`` into ``averaged[key]`` for every key."""
        for key in keys:
            contribution = coefficient * weights[key]
            if key in averaged:
                averaged[key] += contribution
            else:
                averaged[key] = contribution

    def train(self, rounds):
        """Run federated training and return the selected mean test accuracy.

        Args:
            rounds: Number of communication rounds to run.

        Returns:
            The mean test accuracy over clients, taken from the round with the
            highest mean validation accuracy (the earliest such round on
            ties). ``nan`` if no round was recorded, e.g. ``rounds < 1``.
        """
        # Start below any reachable accuracy so the first completed round is
        # always recorded; starting at 0 left the result undefined whenever
        # validation accuracy never rose above 0.
        best_val_acc = float('-inf')
        best_test_acc = float('nan')

        # The parameter names do not change between rounds.
        gnn_keys = list(self.gnn.state_dict().keys())
        mlp_keys = list(self.mlp.state_dict().keys())

        for round_idx in range(1, rounds + 1):
            gnn_averaged_weights = {}
            mlp_averaged_weights = {}

            for i, client in enumerate(self.client_list):
                # collect updated parameters from client i
                gnn_weight, mlp_weight = client.local_update(
                    copy.deepcopy(self.gnn), copy.deepcopy(self.mlp), round_idx)

                # fold them into the running weighted averages
                self._accumulate(gnn_averaged_weights, gnn_keys, gnn_weight, self.coefficients[i])
                self._accumulate(mlp_averaged_weights, mlp_keys, mlp_weight, self.coefficients[i])

            self.gnn.load_state_dict(gnn_averaged_weights)
            self.mlp.load_state_dict(mlp_averaged_weights)

            loss_list = []
            val_loss_list = []
            val_acc_list = []
            test_acc_list = []

            for client in self.client_list:
                # The per-client label/prediction lists returned by stats()
                # are not needed for the aggregated metrics.
                loss, val_loss, _labels, _pred, acc_val, acc_test = client.stats(copy.deepcopy(self.gnn))
                loss_list.append(loss)
                val_loss_list.append(val_loss)
                val_acc_list.append(acc_val)
                test_acc_list.append(acc_test)

            mean_val_acc = np.mean(val_acc_list)
            if mean_val_acc > best_val_acc:
                best_val_acc = mean_val_acc
                best_test_acc = np.mean(test_acc_list)

                # Progress is only reported for rounds that improve validation accuracy.
                print('Round: {:5d} | train_loss: {:9.4f} | val_loss: {:9.4f} | acc_val: {:9.4f} | acc_test: {:9.4f}'
                      .format(round_idx, np.mean(loss_list), np.mean(val_loss_list), mean_val_acc, best_test_acc))

        return best_test_acc


class Client():
    """A federated participant holding a slice of the nodes and local models.

    A client keeps a GNN and an MLP. Clients created with ``graphless=True``
    additionally own a graph learner that produces the adjacency from node
    features; the others build a fixed normalized adjacency from the edges
    among their own nodes.
    """

    def __init__(self, client_id, dataset, trainIdx, valIdx, testIdx, lr, epochs, device, graphless=False, graph_learner=None, k=10):
        """Slice the client's data out of ``dataset`` and build its models.

        Args:
            client_id: Identifier of this client.
            dataset: Dataset whose first element provides ``x``, ``y``,
                ``edge_index`` and ``num_nodes``; also provides
                ``num_node_features`` and ``num_classes``.
            trainIdx: Global node ids of this client's training nodes.
            valIdx: Global node ids of this client's validation nodes.
            testIdx: Global node ids of this client's test nodes.
            lr: Learning rate for the GNN and MLP optimizers.
            epochs: Local epochs per round, for the GNN and for the MLP.
            device: Device for features, labels and models.
            graphless: If True, the adjacency is produced by a graph learner.
            graph_learner: ``'Attentive'`` or ``'MLP'``; required when
                ``graphless`` is True.
            k: Passed to the graph learner as its ``k`` argument.

        Raises:
            ValueError: If ``graphless`` is True and ``graph_learner`` is not
                one of the supported names.
        """
        # Fail fast with a clear message; otherwise the error only surfaced
        # later as a missing ``graph_learner`` attribute.
        if graphless and graph_learner not in ('Attentive', 'MLP'):
            raise ValueError(
                "graph_learner must be 'Attentive' or 'MLP' when graphless=True, got {!r}".format(graph_learner))

        self.client_id = client_id
        self.node_list = trainIdx + valIdx + testIdx
        self.data = dataset[0]

        # Local indices: nodes are laid out as [train | val | test] in node_list.
        num_train, num_val, num_test = len(trainIdx), len(valIdx), len(testIdx)
        self.trainIdx = list(range(0, num_train))
        self.valIdx = list(range(num_train, num_train + num_val))
        self.testIdx = list(range(num_train + num_val, num_train + num_val + num_test))

        self.features = self.data.x[self.node_list].to(device)
        self.labels = self.data.y[self.node_list].squeeze().to(device)
        self.lr = lr
        self.epochs = epochs
        # The graph learner is trained on rounds divisible by this value.
        self.gl_update = 2
        self.device = device
        self.graphless = graphless

        # initialize the gnn model
        self.gnn = GCN(in_channel=dataset.num_node_features, out_channel=dataset.num_classes, hidden=16).to(device)
        self.optimizer1 = torch.optim.Adam(self.gnn.parameters(), lr=self.lr)

        # initialize the mlp model
        self.mlp = MLP(in_channel=dataset.num_node_features, out_channel=dataset.num_classes, hidden=16).to(device)
        self.optimizer2 = torch.optim.Adam(self.mlp.parameters(), lr=self.lr)
        self.criterion2 = nn.KLDivLoss(reduction="batchmean", log_target=True)

        if self.graphless:
            if graph_learner == 'Attentive':
                self.graph_learner = MLP_Diag(2, isize=dataset.num_node_features, k=k, knn_metric='cosine', non_linearity='relu', i=6, mlp_act='relu').to(device)
            else:  # 'MLP', validated above
                self.graph_learner = Linear(2, isize=dataset.num_node_features, k=k, knn_metric='cosine', non_linearity='relu', i=6, mlp_act='relu').to(device)
            self.optimizer3 = torch.optim.Adam(self.graph_learner.parameters(), lr=0.001)
        else:
            self.subgraph = subgraph(subset=torch.tensor(self.node_list, dtype=torch.long), edge_index=self.data.edge_index, num_nodes=self.data.num_nodes)
            # Pin the adjacency size to the number of local nodes so that
            # trailing nodes without edges still get a row/column.
            # NOTE(review): assumes `subgraph` relabels node ids to
            # 0..len(node_list)-1, as the feature indexing implies - confirm.
            self.A = to_dense_adj(self.subgraph, max_num_nodes=len(self.node_list)).squeeze(0)
            identity = torch.eye(self.A.size(0), device=self.A.device)
            self.A_hat = normalize(self.A + identity, 'sym').to(device)

    def _update_graph_learner(self):
        """Take one optimizer step on the graph learner with the contrastive loss."""
        self.graph_learner.train()
        learned_adj = symmetrize(self.graph_learner(self.features))
        adj_hat = normalize(learned_adj, 'sym')
        self.mlp.eval()
        output_mlp = self.mlp(self.features)
        output_gnn = self.gnn(self.features, adj_hat)
        self.optimizer3.zero_grad()
        loss = contrastive_loss(output_mlp, output_gnn, tau=0.2)
        loss.backward()
        # Only the graph learner is stepped here; the GNN/MLP gradients this
        # backward pass produces are cleared by zero_grad() before their own steps.
        self.optimizer3.step()

    def local_update(self, gnn, mlp, round):
        """Run one round of local training starting from the global models.

        Args:
            gnn: Global GNN whose state dict is loaded into the local GNN.
            mlp: Global MLP whose state dict is loaded into the local MLP.
            round: Current round number; graphless clients train their graph
                learner when it is divisible by ``self.gl_update``.

        Returns:
            Tuple ``(gnn_state_dict, mlp_state_dict)`` of the local models
            after training.
        """
        self.mlp.load_state_dict(mlp.state_dict())
        self.gnn.load_state_dict(gnn.state_dict())

        if self.graphless:
            if round % self.gl_update == 0:
                self._update_graph_learner()
            # The adjacency used for GNN training is a constant, so no
            # autograd graph is needed for the graph learner here.
            with torch.no_grad():
                learned_adj = symmetrize(self.graph_learner(self.features))
                self.A_hat = normalize(learned_adj, 'sym')

        # Supervised training of the GNN on the training nodes.
        self.gnn.train()
        for _ in range(self.epochs):
            self.optimizer1.zero_grad()
            log_probs = F.log_softmax(self.gnn(self.features, self.A_hat), dim=1)
            loss = F.nll_loss(log_probs[self.trainIdx], self.labels[self.trainIdx])
            loss.backward()
            self.optimizer1.step()

        # Teacher targets for distillation are fixed across the MLP epochs, so
        # compute them once and without gradient tracking.
        # NOTE(review): the GNN is still in train mode here, as in the
        # original flow - confirm whether eval mode is intended for the teacher.
        with torch.no_grad():
            teacher_log_probs = F.log_softmax(self.gnn(self.features, self.A_hat), dim=1)

        # Distill the GNN predictions into the MLP on all local nodes.
        self.mlp.train()
        for _ in range(self.epochs):
            self.optimizer2.zero_grad()
            student_log_probs = F.log_softmax(self.mlp(self.features), dim=1)
            loss = self.criterion2(student_log_probs, teacher_log_probs)
            loss.backward()
            self.optimizer2.step()

        return self.gnn.state_dict(), self.mlp.state_dict()

    def stats(self, gnn):
        """Evaluate the given GNN weights on this client's data.

        Args:
            gnn: Model whose state dict is loaded into the local GNN before
                evaluation.

        Returns:
            Tuple ``(train_loss, val_loss * num_val_nodes, test_labels,
            test_predictions, val_accuracy, test_accuracy)``; the labels and
            predictions are plain lists.
        """
        self.gnn.load_state_dict(gnn.state_dict())

        # Evaluate in eval mode so train-only layers do not perturb the
        # metrics, then restore the previous mode for the next training round.
        was_training = self.gnn.training
        self.gnn.eval()
        try:
            with torch.no_grad():
                output = F.log_softmax(self.gnn(self.features, self.A_hat), dim=1)
                loss = F.nll_loss(output[self.trainIdx], self.labels[self.trainIdx])
                val_loss = F.nll_loss(output[self.valIdx], self.labels[self.valIdx])
                pred = output.argmax(dim=1)
                val_correct = (pred[self.valIdx] == self.labels[self.valIdx]).sum().item()
                acc_val = val_correct / len(self.valIdx)
                test_correct = (pred[self.testIdx] == self.labels[self.testIdx]).sum().item()
                acc_test = test_correct / len(self.testIdx)
        finally:
            self.gnn.train(was_training)

        return loss.item(), val_loss.item() * len(self.valIdx), \
               self.labels[self.testIdx].tolist(), pred[self.testIdx].tolist(), acc_val, acc_test


def train_fedgls(dataset, num_clients, trainIdx, valIdx, testIdx, args):
    """Build all clients and the server, run training, and return its result.

    Clients with ids below ``args.num_graphless`` are created as graphless
    clients using ``args.graph_learner``; the remaining ids up to
    ``num_clients`` are created as regular clients. The value returned by
    ``Server.train`` is passed through unchanged.
    """
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    def make_client(cid, **mode):
        # Every client shares the same data/optimisation settings; only the
        # graphless-related keyword arguments differ.
        return Client(cid, dataset, trainIdx[cid], valIdx[cid], testIdx[cid],
                      args.lr, args.epochs, device, k=args.k, **mode)

    # graphless clients come first ...
    graphless_clients = []
    for cid in range(args.num_graphless):
        graphless_clients.append(make_client(cid, graphless=True, graph_learner=args.graph_learner))

    # ... followed by the clients that keep their own graph structure
    regular_clients = []
    for cid in range(args.num_graphless, num_clients):
        regular_clients.append(make_client(cid, graphless=False))

    # the central server sees every client id
    server = Server(
        client_list=graphless_clients + regular_clients,
        client_ids=list(range(num_clients)),
        trainIdx=trainIdx,
        valIdx=valIdx,
        dataset=dataset,
        device=device,
    )

    # start training
    return server.train(rounds=args.rounds)
