import torch
import torch.nn as nn
import os.path as osp
import argparse
from flcore.fedgfm_final.server import Server
from flcore.fedgfm_final.client import Client
from utils.basic_utils import *
from utils.basic_utils import check_path, index_to_mask
from utils.partition_utils import graph_set_partition, single_graph_partition
from torch_geometric.nn.inits import glorot
import copy
from flcore.fedgfm_final.client import GPFplusAtt





if __name__ == "__main__":
    # Script entry point: partitions the pre-training datasets across clients,
    # runs the federated pre-training rounds between the clients and the
    # server, and finally writes client and server checkpoints under
    # ``<this file's directory>/ckpts``.
    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_id", type=int, default=0)
    parser.add_argument("--seed", type=int, default=2025)
    parser.add_argument("--data_config", type=str, default='')
    parser.add_argument("--partition_root", type=str, default="")
    # Number of federated rounds; the default keeps the previously
    # hard-coded value so existing invocations behave the same.
    parser.add_argument("--num_rounds", type=int, default=50)

    args = parser.parse_args()

    seed_everything(args.seed)
    device = torch.device(f"cuda:{args.gpu_id}")
    torch.cuda.set_device(args.gpu_id)

    # load global datasets
    num_clients_each_dataset, dataset_list = load_pretrain_mydataset_ofa(args)

    # Partition every global dataset into per-client local data. The list is
    # later indexed by client id, so the order of the datasets is preserved.
    local_data_list = []
    for it, glb_data in enumerate(dataset_list):
        name = glb_data.name.lower()
        num_partitions = num_clients_each_dataset[it]
        print(f"loading partition for {name}...")
        if name in ["cora", "pubmed", "wikics", "arxiv"]:
            local_data_list += single_graph_partition(glb_data, num_partitions=num_partitions, task="node_cls", root=args.partition_root)
        elif name in ["wn18rr", "fb15k237"]:
            local_data_list += single_graph_partition(glb_data, num_partitions=num_partitions, task="link_pre", root=args.partition_root)
        elif name in ["chemhiv", "chempcba", "chemblpre"]:
            local_data_list += graph_set_partition(glb_data, num_partitions=num_partitions, root=args.partition_root, mode="pretrain")
        else:
            # No partition rule matches: nothing is appended for this dataset,
            # although its client count is still part of ``num_clients`` below.
            # Surface this instead of skipping silently.
            print(f"Warning: no partition rule for dataset '{name}'; it contributes no local data")

    # free memory for global datasets (only their names are needed from here on)
    dataset_name_list = [i.name.lower() for i in dataset_list]
    del dataset_list

    # Directory tag of the form "<dataset>_<num clients>_<dataset>_<num clients>...",
    # shared by the init cache and the checkpoint directories.
    partition_tag = "_".join([f"{dataset_name}_{num_clients_each_dataset[i]}" for i, dataset_name in enumerate(dataset_name_list)])

    prompts = GPFplusAtt(in_channels=768, p_num=3).to(device)
    init_cache_root = osp.join(osp.dirname(__file__), 'ckpts', 'init', partition_tag)
    args.init_cache_root = init_cache_root
    # presumably check_path creates the directory when missing - TODO confirm
    check_path(args.init_cache_root)

    # One client per partition; each client receives its own deep copy of the
    # prompt module so clients do not share prompt parameters.
    num_clients = sum(num_clients_each_dataset)
    clients = []
    for client_id in range(num_clients):
        client = Client(args, client_id, local_data_list[client_id], copy.deepcopy(prompts), device)
        clients.append(client)

    server = Server(args, device)

    # initialization: every client initializes first, then the server is
    # initialized with the stacked per-client domain prototypes.
    for client in clients:
        client.initialization()
    domain_prototypes = torch.vstack([i.domain_prototype for i in clients])
    server.initialization(domain_prototypes)

    # pre-training
    # The same dict is reused across rounds; every round overwrites every
    # "client_<id>" entry before the server reads it.
    local_messages = {}

    for round_id in range(args.num_rounds):
        # Hand the current global message to every client. The getter is
        # called once per client on purpose, exactly as before.
        for client in clients:
            client.set_pretrain_model(server.get_global_message())

        for client_id, client in enumerate(clients):
            # client execute:
            print(f"Round {round_id}, Client {client_id}")
            client.pretrain()
            local_messages[f"client_{client_id}"] = client.get_pretrain_model()

        # server execute:
        server.execute(local_messages)

    # save models
    model_path = osp.join(osp.dirname(__file__), 'ckpts', 'fedgfm', partition_tag)
    for client_id, client in enumerate(clients):
        client_save_path = osp.join(model_path, f"client_{client_id}")
        check_path(client_save_path)
        try:
            client.pretrain_model.save_encoder(osp.join(client_save_path, "encoder.pt"))
            client.pretrain_model.save_vq(osp.join(client_save_path, "vq.pt"))
            torch.save(client.prompts.state_dict(), osp.join(client_save_path, "prompts.pt"))  # local knowledge
            print("Client Save the model")
        except Exception as exc:
            # Saving stays best-effort (one failing client must not prevent
            # the remaining saves), but the cause is reported, and
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            print(f"Failed to save the model for client {client_id}: {exc!r}")
    server_save_path = osp.join(model_path, "server")
    check_path(server_save_path)
    try:
        server.global_model.save_encoder(osp.join(server_save_path, "encoder.pt"))
        server.global_model.save_vq(osp.join(server_save_path, "vq.pt"))
        print("Server Save the model")
    except Exception as exc:
        # Best-effort as above: report the reason instead of hiding it.
        print(f"Failed to save the model for server: {exc!r}")
    