import copy
import time
import numpy as np
from config import args
from models.sgl_models import SGC, MLP
from models.nlgnn import NLMLP, NLGCN
from models.ggcn import GGCN, GGCNSP
from models.gcn import GCN, ChebNet
from tasks.node_cls import SGLNodeClassification

class ClientsManager:
    """Create one ``Client`` per local subgraph and optionally evaluate them in isolation.

    Args:
        model_name: architecture name forwarded to every ``Client``.
        datasets: container read for ``global_data``, ``input_dim``,
            ``output_dim``, ``subgraphs`` and ``name``.
        num_clients: number of clients to build; ``datasets.subgraphs`` must
            have at least this many entries.
        device: device forwarded to ``SGLNodeClassification`` during evaluation.
        eval_single_client: if True, run ``evaluate_data_isolate`` right after
            the clients are built.
    """

    # Hidden width per dataset name. Every dataset tuned so far uses 64;
    # datasets not listed fall back to the default instead of leaving
    # ``hidden_dim`` unset (which would make ``initClient`` fail).
    _HIDDEN_DIMS = {
        "Cora": 64,
        "CiteSeer": 64,
        "PubMed": 64,
        "Squirrel": 64,
        "Chameleon": 64,
    }
    _DEFAULT_HIDDEN_DIM = 64

    def __init__(self, model_name, datasets, num_clients, device, eval_single_client=False):
        self.model_name = model_name
        self.global_data = datasets.global_data
        self.input_dim = datasets.input_dim
        self.output_dim = datasets.output_dim
        self.subgraphs = datasets.subgraphs
        self.device = device

        self.num_clients = num_clients
        self.clients = []
        self.tot_nodes = 0

        self.hidden_dim = self._HIDDEN_DIMS.get(datasets.name, self._DEFAULT_HIDDEN_DIM)

        self.initClient()
        if eval_single_client:
            self.evaluate_data_isolate()

    def initClient(self):
        """Build ``num_clients`` clients, the i-th one owning ``self.subgraphs[i]``.

        Each client is appended to ``self.clients`` and its node count is
        added to ``self.tot_nodes``.
        """
        for client_id in range(self.num_clients):
            client = Client(
                model_name=self.model_name,
                input_dim=self.input_dim,
                output_dim=self.output_dim,
                client_id=client_id,
                local_subgraph=self.subgraphs[client_id],
                hidden_dim=self.hidden_dim,
            )
            self.clients.append(client)
            self.tot_nodes += client.num_nodes

    def _run_isolated_trials(self, client, num_runs, lr, weight_decay, epochs):
        """Train ``client`` ``num_runs`` times on its own subgraph only.

        Returns:
            ``(val_acc_list, test_acc_list)`` with one entry per run.
        """
        val_acc_list = []
        test_acc_list = []
        for _run in range(num_runs):
            # ``init_model`` builds a brand-new model object, so every run
            # starts from a fresh initialisation.
            client.init_model()
            val_acc, test_acc, _extra = SGLNodeClassification(
                dataset=client.local_subgraph,
                model=client.model,
                lr=lr,
                weight_decay=weight_decay,
                epochs=epochs,
                device=self.device,
            ).execute()
            val_acc_list.append(val_acc)
            test_acc_list.append(test_acc)
        return val_acc_list, test_acc_list

    def evaluate_data_isolate(self, normalize_trains=10, lr=args.evalute_lr, weight_decay=5e-4, epochs=200):
        """Evaluate every client trained on its local data alone and print the results.

        For each client the model is re-initialised and trained
        ``normalize_trains`` times. The mean and sample standard deviation
        (``ddof=1``) of validation and test accuracy are printed. After all
        clients, node-count-weighted aggregates over all clients are printed.

        Args:
            normalize_trains: independent training runs per client (>= 1).
                With exactly one run the sample std is undefined and NumPy
                reports ``nan``.
            lr: learning rate for ``SGLNodeClassification``.
            weight_decay: weight decay for ``SGLNodeClassification``.
            epochs: training epochs for ``SGLNodeClassification``.

        Raises:
            ValueError: if ``normalize_trains`` is smaller than 1.
        """
        if normalize_trains < 1:
            raise ValueError("normalize_trains must be >= 1, got {}".format(normalize_trains))

        print("| ★  Evaluate Isolate Data")
        global_acc_test_mean = 0
        global_acc_test_std = 0
        global_acc_val_mean = 0
        global_acc_val_std = 0
        for i, client in enumerate(self.clients):
            t_total = time.time()
            val_acc_list, test_acc_list = self._run_isolated_trials(
                client, normalize_trains, lr, weight_decay, epochs
            )
            val_mean = np.mean(val_acc_list)
            val_std = np.std(val_acc_list, ddof=1)
            test_mean = np.mean(test_acc_list)
            test_std = np.std(test_acc_list, ddof=1)

            print("| ◯  Client ID: {}".format(i+1))
            print("| Normalize Train: {}, Total Time Elapsed: {:.4f}s".format(normalize_trains, time.time() - t_total))
            print("| Mean Val ± Std Val: {}±{}, Mean Test ± Std Test: {}±{}".format(round(val_mean, 4), round(val_std, 4), round(test_mean, 4), round(test_std, 4)))

            # Weight each client by its share of the global node count.
            # NOTE(review): weights only sum to 1 if the subgraphs partition
            # the global node set -- confirm against the data splitter.
            # The "std" aggregate is a weighted average of per-client stds,
            # not the std of the pooled runs.
            weight = client.local_subgraph.num_nodes / self.global_data.num_nodes
            global_acc_test_mean += test_mean * weight
            global_acc_test_std += test_std * weight
            global_acc_val_mean += val_mean * weight
            global_acc_val_std += val_std * weight
        print("| ")
        print("| Global Eval Mean Val ± Std Val: {}±{}, Mean Test ± Std Test: {}±{}".format(round(global_acc_val_mean, 4), round(global_acc_val_std, 4), round(global_acc_test_mean, 4), round(global_acc_test_std, 4)))
        print("| ")


class Client:
    """A single client: one local subgraph plus the model trained on it.

    Args:
        model_name: architecture to build; one of ``SUPPORTED_MODELS``.
        input_dim: input feature dimension passed to the model.
        output_dim: output dimension passed to the model.
        client_id: index of this client.
        local_subgraph: the client's data; must expose ``num_nodes``.
        hidden_dim: hidden width passed to every model except "SGC".

    Raises:
        ValueError: if ``model_name`` is not in ``SUPPORTED_MODELS``.
    """

    SUPPORTED_MODELS = ("SGC", "MLP", "NLMLP", "NLGCN", "GCN", "ChebNet", "GGCN", "GGCNSP")

    def __init__(self, model_name, input_dim, output_dim, client_id, local_subgraph, hidden_dim):
        self.model_name = model_name
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.client_id = client_id
        self.local_subgraph = local_subgraph
        self.model_records = []
        self.num_nodes = self.local_subgraph.num_nodes
        self.hidden_dim = hidden_dim
        self.init_model()

    def init_model(self):
        """(Re)build ``self.model`` as a new instance of ``self.model_name``.

        Hyper-parameters not set explicitly here come from the global ``args``
        config (``gmlp_num_layers``, ``gmlp_drop``, ``gmlp_bn``, ``gmlp_ln``).

        Raises:
            ValueError: if ``self.model_name`` is not a supported architecture.
                Failing here avoids leaving ``self.model`` unset and hitting
                an unrelated ``AttributeError`` much later.
        """
        if self.model_name == "SGC":
            self.model = SGC(prop_steps=3, feat_dim=self.input_dim, output_dim=self.output_dim)

        elif self.model_name == "MLP":
            self.model = MLP(feat_dim=self.input_dim, hidden_dim=self.hidden_dim, num_layers=args.gmlp_num_layers, output_dim=self.output_dim, dropout=args.gmlp_drop, bn=args.gmlp_bn)

        elif self.model_name == "NLMLP":
            self.model = NLMLP(feat_dim=self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.output_dim, dropout=args.gmlp_drop, bn=args.gmlp_bn)

        elif self.model_name == "NLGCN":
            self.model = NLGCN(feat_dim=self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.output_dim, dropout=args.gmlp_drop, bn=args.gmlp_bn, ln=args.gmlp_ln)

        elif self.model_name == "GCN":
            self.model = GCN(feat_dim=self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.output_dim, dropout=args.gmlp_drop, bn=args.gmlp_bn, ln=args.gmlp_ln)

        elif self.model_name == "ChebNet":
            self.model = ChebNet(feat_dim=self.input_dim, hidden_dim=self.hidden_dim, output_dim=self.output_dim, dropout=args.gmlp_drop, bn=args.gmlp_bn, ln=args.gmlp_ln)

        elif self.model_name == "GGCN":
            self.model = GGCN(feat_dim=self.input_dim, num_layers=2, hidden_dim=self.hidden_dim, output_dim=self.output_dim, dropout=args.gmlp_drop, decay_rate=1.0, exponent=3.0)

        elif self.model_name == "GGCNSP":
            self.model = GGCNSP(feat_dim=self.input_dim, num_layers=2, hidden_dim=self.hidden_dim, output_dim=self.output_dim, dropout=args.gmlp_drop, decay_rate=1.0, exponent=3.0, bn=args.gmlp_bn, ln=args.gmlp_ln)

        else:
            raise ValueError(
                "Unsupported model_name {!r}; expected one of {}".format(self.model_name, self.SUPPORTED_MODELS)
            )

    def clear_record(self):
        """Drop all stored model snapshots."""
        self.model_records = []

    def add_record(self, model):
        """Store a deep copy of ``model`` so later changes to it do not alter the record."""
        self.model_records.append(copy.deepcopy(model))

    def set_state_dict(self, model):
        """Load ``model``'s state dict into this client's model.

        The state dict is deep-copied first, so the values loaded here are
        independent copies of ``model``'s.
        """
        self.model.load_state_dict(state_dict=copy.deepcopy(model.state_dict()))








