import torch
import warnings
from utils import set_seed
from config import args
from roles.server import ServerManager
from roles.client import ClientsManager
from datasets.graph_fl_dataset import GraphFLDataset
from models.label_propagation_models import NonParaLP
warnings.filterwarnings('ignore')


def select_model_name(homo_mode, all_homo):
    """Return the global model name suggested by the homophily check.

    Args:
        homo_mode: Truthy when the homophily-aware selection is enabled
            (the script passes ``args.homo``).
        all_homo: True when no client's reliability value fell at or below
            the threshold used by the caller.

    Returns:
        ``"SGC"`` when ``homo_mode`` is set and ``all_homo`` is true,
        ``"MLP"`` when ``homo_mode`` is set and ``all_homo`` is false,
        ``"ChebNet"`` when ``homo_mode`` is not set.
    """
    if not homo_mode:
        # Other names that were previously tried here: SGC, GCN, MLP, NLMLP, NLGCN.
        return "ChebNet"
    # Alternatives previously tried: GCN / ChebNet / NLGCN for the all-homo
    # case, NLMLP / GGCN / GGCNSP otherwise.
    return "SGC" if all_homo else "MLP"


def trainable_params_in_millions(params_per_model, num_models):
    """Return the combined trainable-parameter count of all models, in millions.

    The product is rounded once, after scaling, so the result carries three
    decimals of the *total* instead of multiplying an already-rounded value
    (which both loses precision and produces artefacts such as
    ``0.30000000000000004``).

    Args:
        params_per_model: Number of trainable parameters in one model.
        num_models: How many copies of the model are counted.

    Returns:
        ``params_per_model * num_models / 1e6`` rounded to 3 decimals.
    """
    return round(params_per_model * num_models / 1000000, 3)


def main():
    """Run the federated graph-learning experiment.

    Steps, in order: seed the RNGs, build the ``Squirrel`` dataset split over
    ``num_clients`` clients, compute a label-propagation reliability value for
    every client subgraph, pick a model name, build the server and client
    managers, report the trainable-parameter total and start collaborative
    training.
    """
    gpu_id = 0
    num_clients = 10
    # Use the selected GPU when CUDA is available, otherwise fall back to CPU.
    device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu')
    set_seed(args.seed)
    datasets = GraphFLDataset(
        root='./datasets',
        name='Squirrel',
        sampling='Noniid',
        num_clients=num_clients,
        analysis_local_subgraph=False,
        analysis_global_graph=False,
    )

    # A single client with a reliability value <= 0.5 is enough to mark the
    # whole federation as "not all homophilous".
    all_homo = True
    for client_id, subgraph in enumerate(datasets.subgraphs, start=1):
        subgraph.y = subgraph.y.to(device)
        label_prop = NonParaLP(prop_steps=100, num_class=datasets.output_dim, alpha=0.5)
        label_prop.preprocess(subgraph=subgraph, device=device)
        label_prop.propagate(adj=subgraph.adj)
        reliability_acc = label_prop.eval()
        print("| ◯  Client ID: {}, Homo Reliability Value: {}".format(client_id, round(reliability_acc, 4)))
        if reliability_acc <= 0.5:
            all_homo = False

    model_name = select_model_name(args.homo, all_homo)
    # NOTE(review): this banner says "Hete Datasets" even when args.homo is
    # set; it is kept verbatim because the intended wording cannot be
    # confirmed from this file.
    print("| Hete Datasets, Choose MLP-Based or Convolution-Based Model As Global Model")
    if args.homo:
        if all_homo:
            print("| Multi Clients Hold Homo Subgraphs, Choose SGC Model As Global Model")
        else:
            print("| Some Clients Hold Hete Subgraphs, Choose MLP-Based or Convolution-Based Model As Global Model")
    print("| ")

    # NOTE(review): the selection above is unconditionally overridden here, so
    # every run trains NLGCN regardless of the reliability check. Kept to
    # preserve current behaviour; delete this assignment to use the selected
    # model instead.
    model_name = "NLGCN"

    server = ServerManager(
        model_name=model_name,
        datasets=datasets,
        num_clients=num_clients,
        device=device,
        num_rounds=args.gmlp_num_rounds,
        # NOTE(review): 10 is passed through as-is; confirm that
        # ServerManager expects the sample ratio on this scale.
        client_sample_ratio=10,
        eval_global=False,
    )

    # Count only parameters that receive gradients, then scale by the number
    # of clients actually instantiated (the local num_clients, which is what
    # the dataset and both managers are built with).
    trainable_params = sum(p.numel() for p in server.model.parameters() if p.requires_grad)
    print("| ★  Training model parameters: {}M".format(
        trainable_params_in_millions(trainable_params, num_clients)))
    print("| ")

    client_manager = ClientsManager(
        model_name=model_name,
        datasets=datasets,
        num_clients=num_clients,
        device=device,
        eval_single_client=False,
    )

    print("| ★  Data method: {}, Model name: {}".format(datasets.sampling, model_name))
    # eval=True -> do not save the global knowledge extractor (per the
    # original author's note). The meaning of normalize_trains is not
    # documented in this file.
    server.collaborative_training_global_model_eval_multi_clients_no_trick(
        client_manager.clients,
        datasets.name,
        datasets.num_clients,
        datasets.sampling,
        normalize_trains=1,
        eval=True,
    )

    # Alternative training entry point that was previously used here:
    # server.collaborative_training_global_model_eval_multi_clients_KD_EL(client_manager.clients)


if __name__ == "__main__":
    main()






    



