import argparse
import gc
import os
import sys
import warnings
import torch
from sklearn import preprocessing
from stellargraph import datasets

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

from src.global_task import Global
from src.classifiers import Classifier
from src.dataowner import DataOwner
from src.mal_locsage_plus_GraphSage import train_gen_fed, test_classifier_plus_graphsage, \
    train_federated_classifier_graphsage, mend_graph_pyg, train_gen_local
from src.mal_locsage_plus_GraphSage import test_classifier_graphsage

from src.mal_locsage_plus_GraphSage import train_classifier_graphsage
from src.train_locSagePlus import LocalOwner
from src.utils import config
from src.utils.load_ms import load_from_npz
from src.utils.louvain_networkx import louvain_graph_cut
from src.utils.save_and_load_models import save_model, load_model, save_config_to_txt
from src.utils.seed_utils import set_seed_torch

warnings.filterwarnings('ignore')

# Command-line interface: one experiment is run per value given to --intensities.
parser = argparse.ArgumentParser(description="Run experiments with multiple param values")
# The default is written as a float to match ``type=float``: argparse does not
# pass non-string defaults through ``type``, so a default of ``[1]`` stayed an
# int and formatted as ``intensity=1`` in the checkpoint path, whereas an
# explicit ``--intensities 1`` formats as ``intensity=1.0``.
parser.add_argument('--intensities', type=float, nargs='+', default=[1.0],
                    help='attack intensities to sweep; one experiment per value')
args = parser.parse_args()

# Datasets bundled with StellarGraph, keyed by their ``config.dataset`` name.
_STELLARGRAPH_DATASETS = {
    'cora': datasets.Cora,
    'citeseer': datasets.CiteSeer,
    'pubmeddiabetes': datasets.PubMedDiabetes,
}
# Every ``config.dataset`` value that set_up_system() knows how to load.
_SUPPORTED_DATASETS = frozenset(_STELLARGRAPH_DATASETS) | {'msacademic'}


def _load_dataset(name):
    """Load dataset *name* and return whatever its loader returns, i.e. ``(G, node_subjects)``.

    *name* must be a member of ``_SUPPORTED_DATASETS``. ``'msacademic'`` is
    read with ``load_from_npz`` from ``other_datasets/ms_academic.npz`` under
    ``config.root_path``; the others come from ``stellargraph.datasets``.
    """
    if name == 'msacademic':
        # os.path.join works whether or not root_path ends with a separator.
        return load_from_npz(os.path.join(config.root_path, 'other_datasets', 'ms_academic.npz'))
    return _STELLARGRAPH_DATASETS[name]().load()


def _build_data_owners(local_G, local_subj, local_target, local_nodes_ids):
    """Create one DataOwner per client partition and run its setup hooks.

    The four arguments are the per-client sequences returned by
    ``louvain_graph_cut``; each is indexed by client id ``0 .. config.num_owners - 1``.

    Returns:
        List of ``config.num_owners`` DataOwner objects, ordered by client id.
    """
    dataowner_list = []
    for owner_i in range(config.num_owners):
        do_i = DataOwner(do_id=owner_i, subG=local_G[owner_i], sub_ids=local_nodes_ids[owner_i],
                         node_subj=local_subj[owner_i],
                         node_target=local_target[owner_i])
        do_i.get_edge_nodes()
        do_i.set_classifier_path()
        do_i.set_gan_path()
        do_i.save_do_info()
        dataowner_list.append(do_i)
    return dataowner_list


def _build_local_models(dataowner_list, all_classes):
    """Build each client's generator wrapper (LocalOwner) and classifier.

    After construction, the GENs are initialised: the first
    ``config.num_attacker`` clients additionally call ``set_aux_model``, and
    every client calls ``set_fed_model``.

    Returns:
        ``(local_owners, local_classifiers)``, two lists aligned by client id.
    """
    local_owners, local_classifiers = [], []
    for owner_i, do_i in enumerate(dataowner_list):
        local_classifier = Classifier(hasG=do_i.hasG,
                                      all_classes=all_classes,
                                      has_node_subjects=do_i.has_subj,
                                      acc_path=do_i.test_acc_path, classifier_path=do_i.classifier_path,
                                      downstream_task_path=do_i.downstream_task_path)
        local_gen = LocalOwner(do_id=owner_i, subG=do_i.hasG, node_subjects=do_i.has_subj,
                               all_classes=all_classes,
                               num_samples=config.num_samples,
                               model_path=[do_i.fedgen_model_path, do_i.gen_model_path],
                               reg_model_type=config.reg_model_type)
        # The classifier is set up against the generator's hidden-graph view.
        local_classifier.set_classifiers_torch(classifier_path=do_i.classifier_path, dataowner=do_i,
                                               hasG_hide=local_gen.hasG_hide)
        local_classifier.model = local_classifier.build_classifier_torch()
        local_owners.append(local_gen)
        local_classifiers.append(local_classifier)

    # Initialise GENs; attackers occupy the lowest client ids.
    for idx, local_owner in enumerate(local_owners):
        if idx < config.num_attacker:
            local_owner.set_aux_model()
        local_owner.set_fed_model()
    return local_owners, local_classifiers


def _load_generator(local_owner, client_idx):
    """Load client *client_idx*'s saved generator from ``config.save_path`` and move it to CUDA."""
    local_owner.fed_model = load_model(
        local_owner.fed_model,
        os.path.join(config.save_path, f'client_{client_idx}_gen.pt'),
    ).cuda()


def _run_training_pipeline(local_owners, local_classifiers):
    """Default training phase: attacker pre-training, GEN training, federated classifier training.

    With ``config.plus`` and at least one attacker, the first
    ``config.num_attacker`` clients first obtain a generator. Either client 0
    loads ``config.benign_gen_path``, or all of them run ``train_gen_local``
    with ``pre_train=True``. They then call ``mend_graph_pyg`` and train their
    classifiers (the victim models). If the generator was trained here, client
    0's generator is saved as ``attacker_benign_gen.pt``, and client 0's
    generator is then re-initialised via ``set_fed_model``.

    With ``config.plus``, all generators are next trained federatedly and
    ``mend_graph_pyg`` runs for every client. Classifiers are always trained
    federatedly at the end.
    """
    num_attacker = config.num_attacker
    if config.plus and num_attacker > 0:
        print('==malicious clients train classifiers==============')
        attacker_classifiers = local_classifiers[:num_attacker]
        attacker_owners = local_owners[:num_attacker]
        # NOTE(review): only client 0 loads benign_gen_path / is reset below,
        # even when num_attacker > 1 -- confirm this is intended.
        if config.benign_gen_path != '':
            local_owners[0].fed_model = load_model(local_owners[0].fed_model,
                                                   config.benign_gen_path).cuda()
        else:
            train_gen_local(attacker_classifiers, attacker_owners, pre_train=True)
        mend_graph_pyg(attacker_classifiers, attacker_owners)
        train_classifier_graphsage(attacker_classifiers)

        if config.benign_gen_path == '':
            save_model(local_owners[0].fed_model, os.path.join(config.save_path, 'attacker_benign_gen.pt'))
        local_owners[0].set_fed_model()

    if config.plus:
        print('==train gen=================')
        train_gen_fed(local_classifiers, local_owners)

        print('==federated train classifiers=================')
        mend_graph_pyg(local_classifiers, local_owners)
    train_federated_classifier_graphsage(local_classifiers)


def set_up_system():
    """Build the federated experiment described by ``config`` and run ``config.phase``.

    The function normalises the attack settings, loads the dataset and splits
    it across ``config.num_owners`` clients with ``louvain_graph_cut``. It then
    builds every client's generator and classifier and dispatches on
    ``config.phase``:

    * ``'test'``: load each client's saved generator (if ``config.plus``) and
      classifier weights from ``config.save_path``.
    * ``'train_cls'``: if ``config.plus``, load the saved generators and call
      ``mend_graph_pyg``; then train the classifiers federatedly.
    * any other value: run the full training pipeline.

    The classifiers are then evaluated. When ``config.phase`` is
    ``'train_global_attack'`` and ``config.save_model`` is truthy, the
    generators (if ``config.plus``) and the classifier state dicts are saved
    to ``config.save_path``. Loading or saving generators requires CUDA
    (``.cuda()``).

    Side effects:
        May zero ``config.num_attacker`` / ``config.attack_intensity`` in
        place. Unless the phase is ``'test'``, creates ``config.save_path`` and
        writes ``experiment_config.txt`` there. If ``config.dataset`` is not
        supported, prints a message and returns before touching the
        filesystem.
    """
    # An attack needs both attackers and a non-zero intensity; otherwise run benign.
    if config.attack_intensity == 0 or config.num_attacker == 0:
        config.num_attacker = 0
        config.attack_intensity = 0

    # Validate before any side effect so a bad name leaves no empty checkpoint dir behind.
    if config.dataset not in _SUPPORTED_DATASETS:
        print("dataset name does not exist!")
        return

    if config.phase != 'test':
        os.makedirs(config.save_path, exist_ok=True)
        save_config_to_txt(os.path.join(config.save_path, 'experiment_config.txt'))
    set_seed_torch(seed=config.seed)
    G, node_subjects = _load_dataset(config.dataset)

    target_encoding = preprocessing.LabelBinarizer()
    global_targets = target_encoding.fit_transform(node_subjects)
    all_classes = target_encoding.classes_
    global_task = Global(G, node_subjects, global_targets)

    local_G, local_subj, local_target, local_nodes_ids = louvain_graph_cut(
        G, node_subjects, graph_split_seed=config.graph_split_seed)
    dataowner_list = _build_data_owners(local_G, local_subj, local_target, local_nodes_ids)
    local_owners, local_classifiers = _build_local_models(dataowner_list, all_classes)

    # Global test ids: every client's test node ids, concatenated in client order.
    global_task.set_test_ids([node_id
                              for local_classifier in local_classifiers
                              for node_id in local_classifier.test_subjects.index])

    if config.phase == 'test':
        for i, (local_owner, local_classifier) in enumerate(zip(local_owners, local_classifiers)):
            if config.plus:
                _load_generator(local_owner, i)
            state_dict = torch.load(os.path.join(config.save_path, f'client_{i}_classifier.pt'))
            local_classifier.model.load_state_dict(state_dict)
    elif config.phase == 'train_cls':
        if config.plus:
            for i, local_owner in enumerate(local_owners):
                _load_generator(local_owner, i)
            mend_graph_pyg(local_classifiers, local_owners)
        train_federated_classifier_graphsage(local_classifiers)
    else:
        # NOTE(review): any unrecognised phase also trains here, but weights are
        # only saved below when the phase is exactly 'train_global_attack'.
        _run_training_pipeline(local_owners, local_classifiers)

    print('==evaluate classifiers===================')
    if config.plus:
        test_classifier_plus_graphsage(local_classifiers, local_owners)
    else:
        test_classifier_graphsage(local_classifiers)

    if config.phase == 'train_global_attack' and config.save_model:
        if config.plus:
            for i, local_owner in enumerate(local_owners):
                save_model(local_owner.fed_model, os.path.join(config.save_path, f'client_{i}_gen.pt'))
        for i, local_classifier in enumerate(local_classifiers):
            torch.save(local_classifier.model.state_dict(),
                       os.path.join(config.save_path, f'client_{i}_classifier.pt'))

# set_up_system() zeroes config.num_attacker / config.attack_intensity in place
# for benign runs. Snapshot the configured values and restore them every
# iteration; otherwise one zero intensity in the sweep would silently disable
# the attack for every later intensity as well.
_base_num_attacker = config.num_attacker
_base_attack_intensity = config.attack_intensity

for intensity in args.intensities:
    # Release memory still held by the previous experiment before building the next one.
    gc.collect()
    if torch.cuda.is_available():
        # Empty the CUDA cache, wait for all pending CUDA work, then empty it again.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

    config.num_attacker = _base_num_attacker
    config.attack_intensity = _base_attack_intensity
    if config.phase != 'test':
        config.attack_intensity = intensity
        # os.path.join instead of a hard-coded '\\' so the run directory is
        # nested under checkpoints/ on every OS, not only on Windows.
        config.save_path = os.path.join(
            'checkpoints',
            f'sadtest_resetGen_{config.dataset}_{config.num_owners}clients_{config.num_attacker}attacker_intensity={config.attack_intensity}_PreMalEpoch={config.pre_benign_epoch}_genEpoch={config.gen_epochs}',
        )
    set_up_system()
