import torch
import torch.nn as nn
import torch.nn.functional as F
import os.path as osp
import argparse
from utils.basic_utils import *
from flcore.fedgfm.server import Server
from flcore.fedgfm.client import Client
import yaml
from utils.partition_utils import graph_set_partition, single_graph_partition
from flcore.fedgfm.client import GPFplusAtt

import copy

import torch
import torch.nn as nn
import torch.nn.functional as F

class GPFEnsemble(nn.Module):
    """Ensemble of trained GPF-style prompt modules applied as one additive prompt.

    Every member must provide ``get_score(x)``, returning a 2-D tensor of
    scores (one column per prompt vector of that member), and a ``p_list``
    tensor with one row per score column. The scores of all members are
    concatenated and passed through a single softmax, so the weights are
    normalised jointly across the whole ensemble rather than per member.

    Args:
        prompt_models: prompt modules to combine. They are registered as
            submodules via ``nn.ModuleList``, so they follow ``.to()`` and
            appear in ``state_dict()``.
        freeze: when True, ``requires_grad`` is disabled on every parameter
            of the ensemble (i.e. on all member parameters).
    """

    def __init__(self, prompt_models: list, freeze=False):
        super(GPFEnsemble, self).__init__()
        self.prompt_models = nn.ModuleList(prompt_models)

        if freeze:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the jointly-weighted prompt of every member.

        Args:
            x: input features. Each member contributes
                ``weights.mm(p_list)``, which is added to ``x``, so ``x`` must
                be broadcast-compatible with that product.

        Returns:
            A new tensor; ``x`` itself is never modified in place.
        """
        score_all = [prompt.get_score(x) for prompt in self.prompt_models]
        norm_score_all = F.softmax(torch.hstack(score_all), dim=1)

        # Split the jointly-normalised weights back into per-member column
        # groups using each member's OWN score width. This keeps the weights
        # aligned with the right p_list even when members have different
        # numbers of prompt vectors.
        widths = [score.shape[1] for score in score_all]
        per_member_weights = torch.split(norm_score_all, widths, dim=1)

        result = x
        for weights, prompt in zip(per_member_weights, self.prompt_models):
            # Out-of-place add: the caller's ``x`` is left untouched.
            result = result + weights.mm(prompt.p_list)
        return result



if __name__ == "__main__":
    # Fine-tuning entry point:
    #   1. partition every global dataset into per-client local datasets,
    #   2. build one Client per partition and load its pretrained
    #      encoder / vq / prompts checkpoints,
    #   3. initialise every client from an ensemble of all clients' prompts,
    #   4. fine-tune each client in isolation on k-shot data and report, per
    #      dataset, the test score weighted by each client's test-set size.
    from itertools import accumulate

    parser = argparse.ArgumentParser()
    parser.add_argument("--gpu_id", type=int, default=1)
    parser.add_argument("--seed", type=int, default=2025)
    parser.add_argument("--data_config", type=str, default='/home/ai/xkli/FedGFM-main/config/pretrain_config.json')
    parser.add_argument("--partition_root", type=str, default="/home/ai/xkli/FedGFM-main/partition")
    parser.add_argument("--k_shot", type=int, default=2) # default rich label
    # Per-dataset fine-tuning hyper-parameters (previously a hard-coded path;
    # the default is unchanged).
    parser.add_argument("--finetune_config", type=str, default='/home/ai/xkli/FedGFM-main/config/finetune.yaml')

    args = parser.parse_args()

    seed_everything(args.seed)
    device = torch.device(f"cuda:{args.gpu_id}")
    torch.cuda.set_device(args.gpu_id)

    # load global dataset
    num_clients_each_dataset, dataset_list = load_pretrain_mydataset_ofa(args, mode="finetune")

    # Partition every global dataset. Partitions are appended in dataset
    # order, so all clients of one dataset occupy a contiguous index range.
    local_data_list = []
    for it, glb_data in enumerate(dataset_list):
        data_name = glb_data.name.lower()
        print(f"loading partition for {data_name}...")
        if data_name in ["cora", "pubmed", "wikics", "arxiv"]:
            local_data_list += single_graph_partition(glb_data, num_partitions=num_clients_each_dataset[it], task="node_cls", root=args.partition_root)
        elif data_name in ["wn18rr", "fb15k237"]:
            local_data_list += single_graph_partition(glb_data, num_partitions=num_clients_each_dataset[it], task="link_pre", root=args.partition_root)
        elif data_name in ["chemhiv", "chempcba", "chemblpre"]:
            local_data_list += graph_set_partition(glb_data, num_partitions=num_clients_each_dataset[it], root=args.partition_root, mode="finetune")
        else:
            # Skipping silently would shift every later client's index and
            # surface much later as an opaque IndexError.
            raise ValueError(f"unsupported dataset for partitioning: {glb_data.name}")

    # Read the pretraining config ONCE: it names both the init cache
    # directory and the pretrained-checkpoint directory.
    with open(args.data_config, 'r', encoding='utf-8') as file:
        dataset_config_pretrain = json.load(file)
    num_clients_each_dataset_pretrain = dataset_config_pretrain["pretrain_num_clients"]
    dataset_name_list_pretrain = dataset_config_pretrain["datasets"]

    # NOTE(review): the init-cache suffix is hard-coded to "_3", while
    # model_path below uses num_clients_each_dataset_pretrain[i] — confirm
    # this asymmetry is intended.
    init_cache_root = osp.join(osp.dirname(__file__), 'ckpts', 'init', "_".join([f"{dataset_name.lower()}_3" for dataset_name in dataset_name_list_pretrain]))
    args.init_cache_root = init_cache_root
    check_path(args.init_cache_root)
    model_path = osp.join(osp.dirname(__file__), 'ckpts', 'fedgfm', "_".join([f"{dataset.lower()}_{num_clients_each_dataset_pretrain[i]}" for i, dataset in enumerate(dataset_name_list_pretrain)]))

    # free memory for global datasets
    del dataset_list

    # load finetuning params for each dataset
    with open(args.finetune_config, 'r') as file:
        params = yaml.safe_load(file)

    params_list = []
    for data_tag in local_data_list:
        if data_tag.name in ["Cora", "Pubmed", "arxiv", "wikics"]:
            client_params = params["node"][data_tag.name.lower()]
        elif data_tag.name in ["FB15K237", "WN18RR"]:
            client_params = params["link"][data_tag.name.upper()]
        else:
            client_params = params["graph"][data_tag.name]
        params_list.append(client_params)

    # Build one client per partition; each gets its own copy of the prompt module.
    prompts = GPFplusAtt(in_channels=768, p_num=3).to(device)
    num_clients = sum(num_clients_each_dataset)
    clients = []
    for client_id in range(num_clients):
        client = Client(args, client_id, local_data_list[client_id], copy.deepcopy(prompts), device, params_list[client_id], mode="finetune")
        clients.append(client)
    server = Server(args, device)

    # load pretrained model (per-client checkpoints)
    for client_id, client in enumerate(clients):
        client_ckpt_dir = osp.join(model_path, f'client_{client_id}')
        # map_location avoids requiring the GPU the checkpoint was saved from;
        # load_state_dict then copies the tensors into the client's modules.
        # clients[client_id].encoder.load_state_dict(torch.load(os.path.join(model_path, f'server/encoder.pt')))
        # clients[client_id].vq.load_state_dict(torch.load(os.path.join(model_path, f'server/vq.pt')))
        client.encoder.load_state_dict(torch.load(osp.join(client_ckpt_dir, 'encoder.pt'), map_location=device))
        client.vq.load_state_dict(torch.load(osp.join(client_ckpt_dir, 'vq.pt'), map_location=device))
        client.prompts.load_state_dict(torch.load(osp.join(client_ckpt_dir, 'prompts.pt'), map_location=device))

    # initialization
    # 1. ensemble prompts: for each client, build a fresh ensemble of deep
    # copies of all clients' current prompt modules and hand it to the client.
    for client_id in range(num_clients):
        ensemble_prompt = GPFEnsemble(prompt_models=[copy.deepcopy(c.prompts) for c in clients])
        # if client_id < 3:
        #     ensemble_prompt = GPFEnsemble(prompt_models=[copy.deepcopy(client.prompts) for client in clients[:3]])
        # else:
        #     ensemble_prompt = GPFEnsemble(prompt_models=[copy.deepcopy(client.prompts) for client in clients[3:]])
        clients[client_id].initialization(ensemble_prompt)

    # Index of the last client of each dataset: a dataset's score is reported
    # once all of its (contiguous) clients have been fine-tuned.
    last_client_ids = {end - 1 for end in accumulate(num_clients_each_dataset)}

    # final results: result[dataset_name][standard_id] -> [(best_metrics, num_test_samples), ...]
    result = {}

    # isolated finetune
    for client_id in range(num_clients):
        client = clients[client_id]
        client.data_tag = get_k_shot(client.data_tag, k=args.k_shot)
        client.finetune()
        name = client.data_tag.name.lower()
        per_standard = result.setdefault(name, {})
        for standard_id in range(1):
            single_best = client.logger.best[standard_id]
            num_samples = client.data_tag.test_mask.nonzero().shape[0]
            per_standard.setdefault(standard_id, []).append((single_best, num_samples))

        if client_id in last_client_ids:
            standard_result = []
            for values in result[name].values():
                # Test score averaged over this dataset's clients, weighted by
                # each client's number of test samples.
                result_test = sum(pair[0]['test'] * pair[1] for pair in values) / sum(pair[1] for pair in values)
                standard_result.append(result_test)
            print(f"{name}: {np.mean(standard_result):.2f}±{np.std(standard_result):.2f}")
