import time
import numpy as np
import random
import torch
import os.path as osp
from config import args
from collections import OrderedDict
from models.sgl_models import SGC, MLP
from models.nlgnn import NLMLP, NLGCN
from models.ggcn import GGCN, GGCNSP
from models.gcn import GCN, ChebNet
from tasks.node_cls import SGLNodeClassification, SGLNodeClassificationKDEL, SGLEvaluateModelClients



class ServerManager:
    """Server side of a federated node-classification experiment.

    The server owns the global model, samples clients each round, collects the
    locally trained models, averages their parameters (weighted by each
    client's ``num_nodes``) and evaluates the resulting global model on every
    client's local subgraph.

    Client objects passed to the training method are expected to expose
    ``num_nodes``, ``local_subgraph`` (with a ``num_nodes`` attribute),
    ``model``, ``clear_record()``, ``init_model()``, ``set_state_dict(model)``
    and ``add_record(model)``.
    """

    def __init__(self, model_name, datasets, num_clients, device, num_rounds, client_sample_ratio, eval_global):
        """Store the experiment configuration and build the initial global model.

        Args:
            model_name: One of the names handled by :meth:`init_model`.
            datasets: Object exposing ``name``, ``input_dim``, ``output_dim``,
                ``global_data`` and ``subgraphs``.
            num_clients: Number of participating clients.
            device: Device handed to the training / evaluation tasks.
            num_rounds: Number of communication rounds per training run.
            client_sample_ratio: Fraction of clients sampled in each round.
            eval_global: If truthy, immediately train and evaluate the model on
                the global (non-federated) data via :meth:`evaluate_global`.
        """
        self.model_name = model_name
        self.datasets = datasets
        self.input_dim = datasets.input_dim
        self.output_dim = datasets.output_dim
        self.global_data = datasets.global_data
        self.subgraphs = datasets.subgraphs
        self.num_clients = num_clients
        self.device = device
        self.client_sample_ratio = client_sample_ratio
        self.state_dict_records = []
        self.num_rounds = num_rounds

        # Every dataset previously special-cased here used 64 hidden units,
        # while any other dataset left ``hidden_dim`` undefined and crashed
        # later in ``init_model``. Use 64 unconditionally.
        self.hidden_dim = 64

        self.init_model()
        if eval_global:
            self.evaluate_global()

    @staticmethod
    def _mean_std(values):
        """Return ``(mean, sample_std)`` of ``values`` as plain floats.

        The sample standard deviation (``ddof=1``) is undefined for fewer than
        two values; in that case 0.0 is reported for a single value and NaN
        for an empty sequence instead of emitting a NumPy warning.
        """
        if len(values) == 0:
            return float("nan"), float("nan")
        mean = float(np.mean(values))
        std = float(np.std(values, ddof=1)) if len(values) > 1 else 0.0
        return mean, std

    def init_model(self):
        """(Re)create ``self.model`` from scratch according to ``self.model_name``.

        Raises:
            ValueError: If ``self.model_name`` is not a supported model name.
        """
        name = self.model_name

        # SGC takes neither a hidden size nor the dropout / norm settings.
        if name == "SGC":
            self.model = SGC(prop_steps=3, feat_dim=self.input_dim, output_dim=self.output_dim)
            return

        # Looked up lazily (inside the method) so the class definition itself
        # does not depend on these names.
        plain_models = {"NLMLP": NLMLP, "NLGCN": NLGCN, "GCN": GCN, "ChebNet": ChebNet}
        ggcn_models = {"GGCN": GGCN, "GGCNSP": GGCNSP}
        if name != "MLP" and name not in plain_models and name not in ggcn_models:
            supported = ["SGC", "MLP"] + list(plain_models) + list(ggcn_models)
            raise ValueError("Unsupported model name: {!r}; expected one of {}".format(name, supported))

        # Keyword arguments shared by every model except SGC.
        common = dict(
            feat_dim=self.input_dim,
            hidden_dim=self.hidden_dim,
            output_dim=self.output_dim,
            dropout=args.gmlp_drop,
            bn=args.gmlp_bn,
            ln=args.gmlp_ln,
        )

        if name == "MLP":
            self.model = MLP(num_layers=args.gmlp_num_layers, **common)
        elif name in plain_models:
            self.model = plain_models[name](**common)
        else:
            self.model = ggcn_models[name](num_layers=2, decay_rate=1.0, exponent=3.0, **common)

    def evaluate_global(self, normalize_trains=10, lr=args.evalute_lr, weight_decay=5e-4, epochs=200):
        """Train a fresh model on the global data several times and print statistics.

        Args:
            normalize_trains: Number of independent train/evaluate repetitions.
            lr: Learning rate passed to the training task.
            weight_decay: Weight decay passed to the training task.
            epochs: Number of training epochs per repetition.
        """
        test_acc_list = []
        val_acc_list = []
        t_total = time.time()
        for _ in range(normalize_trains):
            # Start every repetition from a freshly initialised model.
            self.init_model()
            val_acc, test_acc, _ = SGLNodeClassification(
                dataset=self.global_data,
                model=self.model,
                lr=lr,
                weight_decay=weight_decay,
                epochs=epochs,
                device=self.device,
            ).execute()
            test_acc_list.append(test_acc)
            val_acc_list.append(val_acc)

        val_mean, val_std = self._mean_std(val_acc_list)
        test_mean, test_std = self._mean_std(test_acc_list)
        print("| ★  Evaluate Global Data")
        print("| Normalize Train: {}, Total Time Elapsed: {:.4f}s".format(normalize_trains, time.time() - t_total))
        print("| Mean Val ± Std Val: {}±{}, Mean Test ± Std Test: {}±{}".format(round(val_mean, 4), round(val_std, 4), round(test_mean, 4), round(test_std, 4)))

    def set_state_dict(self, state_dict):
        """Load ``state_dict`` into the global model."""
        self.model.load_state_dict(state_dict=state_dict)

    def model_aggregation(self, server_input, mixing_coefficients):
        """Return the weighted sum of the given models' parameters.

        Args:
            server_input: Iterable of objects exposing ``state_dict()``.
            mixing_coefficients: One weight per model, indexable by position
                in the same order as ``server_input``.

        Returns:
            OrderedDict mapping each state-dict key to
            ``sum(mixing_coefficients[i] * state_dict_i[key])``.

        Raises:
            ValueError: If no models are given or the number of coefficients
                does not match the number of models.
        """
        state_dicts = [model.state_dict() for model in server_input]
        if not state_dicts:
            raise ValueError("model_aggregation requires at least one model")
        if len(mixing_coefficients) != len(state_dicts):
            raise ValueError(
                "Got {} mixing coefficients for {} models".format(len(mixing_coefficients), len(state_dicts))
            )

        aggregated_model = OrderedDict()
        for it, state_dict in enumerate(state_dicts):
            coefficient = mixing_coefficients[it]
            if it == 0:
                # The first model seeds the accumulator; the product is a new
                # value, so the source state dict is not modified.
                for key, value in state_dict.items():
                    aggregated_model[key] = coefficient * value
            else:
                for key, value in state_dict.items():
                    aggregated_model[key] += coefficient * value

        return aggregated_model

    def collaborative_training_global_model_eval_multi_clients_no_trick(self, clients, data_name, num_clients, sampling, normalize_trains=args.gmlp_normalize_train, lr=args.gmlp_lr, weight_decay=args.gmlp_weight_decay, epochs=args.gmlp_num_epochs, eval=False):
        """Run federated training and report the best global accuracy per run.

        For each of ``normalize_trains`` independent runs the global and client
        models are re-initialised, then ``self.num_rounds`` rounds are executed:
        sample clients, train each sampled client locally starting from the
        global model, aggregate the local models weighted by client
        ``num_nodes``, and evaluate the aggregated model on every client. The
        round with the highest node-weighted validation accuracy defines the
        run's reported validation / test accuracy.

        Args:
            clients: Indexable collection of client objects (see class docstring).
            data_name: Dataset name, used only in the checkpoint file name.
            num_clients: Client count, used only in the checkpoint file name.
            sampling: Sampling strategy label, used only in the checkpoint file name.
            normalize_trains: Number of independent training runs.
            lr: Learning rate for local training.
            weight_decay: Weight decay for local training.
            epochs: Local epochs per round.
            eval: If falsy and ``normalize_trains == 1``, the best global model
                is saved under ``./model_weights``.
        """
        import os  # local import: only needed to create the checkpoint directory

        print("| ★  Start Training Federated GNN Model...")
        # Report the values actually used by this call (they equal the
        # ``args`` values whenever the defaults are not overridden).
        print("| Global Model Layers: {}, Training Round: {}, Local Epochs: {}".format(args.gmlp_num_layers, self.num_rounds, epochs))
        print("| Global Model Lr: {}, Weight decay: {}, Dropout: {}".format(lr, weight_decay, args.gmlp_drop))

        normalize_record = {"val_acc": [], "test_acc": []}
        t_total = time.time()
        for _ in range(normalize_trains):
            # Best (by validation accuracy) per-client results seen in this run.
            clients_test_acc = [0] * self.num_clients
            clients_val_acc = [0] * self.num_clients
            self.init_model()
            for client_id in range(self.num_clients):
                clients[client_id].clear_record()
                clients[client_id].init_model()

            round_global_record = {"global_val_acc": 0, "global_test_acc": 0}

            for round_id in range(self.num_rounds):
                # [1] sample clients for aggregating and updating
                all_client_idx = list(range(self.num_clients))
                random.shuffle(all_client_idx)
                # Always keep at least one client; a small ratio would
                # otherwise truncate to zero and leave nothing to aggregate.
                sample_num = max(1, int(len(all_client_idx) * self.client_sample_ratio))
                sample_idx = sorted(all_client_idx[:sample_num])
                sampled_node_counts = [clients[idx].num_nodes for idx in sample_idx]
                total_sampled_nodes = sum(sampled_node_counts)
                mixing_coefficients = [count / total_sampled_nodes for count in sampled_node_counts]
                aggregated_model_list = []

                # [2] collaborative_training
                for client_id in sample_idx:
                    clients[client_id].set_state_dict(self.model)
                    # From the second round on, record the global model the
                    # client starts from.
                    if round_id + 1 != 1:
                        clients[client_id].add_record(self.model)

                    _, _, local_model = SGLNodeClassification(
                        dataset=clients[client_id].local_subgraph,
                        model=clients[client_id].model,
                        lr=lr,
                        weight_decay=weight_decay,
                        epochs=epochs,
                        device=self.device,
                    ).execute()

                    aggregated_model_list.append(local_model)

                # [3] server executes
                aggregated_model = self.model_aggregation(aggregated_model_list, mixing_coefficients)
                self.set_state_dict(aggregated_model)

                # record last round result
                if round_id + 1 == self.num_rounds:
                    for client_id in sample_idx:
                        clients[client_id].set_state_dict(self.model)
                        clients[client_id].add_record(self.model)

                # [4] global model eval process
                global_val_acc = 0
                global_test_acc = 0

                for client_id in range(self.num_clients):
                    val_acc, test_acc = SGLEvaluateModelClients(
                        dataset=clients[client_id].local_subgraph,
                        model=self.model,
                        device=self.device,
                    ).execute()

                    if val_acc > clients_val_acc[client_id]:
                        clients_val_acc[client_id] = val_acc
                        clients_test_acc[client_id] = test_acc

                    # Weight each client's accuracy by its share of the
                    # global node count.
                    global_val_acc += (val_acc * clients[client_id].local_subgraph.num_nodes / self.datasets.global_data.num_nodes)
                    global_test_acc += (test_acc * clients[client_id].local_subgraph.num_nodes / self.datasets.global_data.num_nodes)

                if global_val_acc > round_global_record["global_val_acc"]:
                    round_global_record["global_val_acc"] = global_val_acc
                    round_global_record["global_test_acc"] = global_test_acc
                    if normalize_trains == 1 and not eval:
                        # Make sure the checkpoint directory exists before saving.
                        os.makedirs("./model_weights", exist_ok=True)
                        torch.save(self.model, osp.join("./model_weights", "{}_Client{}_{}_model.pt".format(data_name, num_clients, sampling)))

            if normalize_trains == 1:
                for client_id in range(self.num_clients):
                    print("| Client id: {}, Val Acc: {}, Test Acc: {}".format(client_id+1, round(clients_val_acc[client_id], 4), round(clients_test_acc[client_id], 4)))

            normalize_record["val_acc"].append(round_global_record["global_val_acc"])
            normalize_record["test_acc"].append(round_global_record["global_test_acc"])

        val_mean, val_std = self._mean_std(normalize_record["val_acc"])
        test_mean, test_std = self._mean_std(normalize_record["test_acc"])
        print("| ★  Normalize Train Completed")
        print("| Normalize Train: {}, Total Time Elapsed: {:.4f}s".format(normalize_trains, time.time() - t_total))
        print("| Mean Val ± Std Val: {}±{}, Mean Test ± Std Test: {}±{}".format(round(val_mean, 4), round(val_std, 4), round(test_mean, 4), round(test_std, 4)))

