import argparse
import collections
import wandb
import numpy as np
import torch
import torch.optim as optim

from conformal import *

from utils import (average_models, client_update, local_gen_update, local_gvae_update, local_edge_pred_update, evaluate_model, graph_mend_vgae,
                   make_gnn_model, prepareData_oneDS, generate_node_feat_from_cluster_centers, print_set_sizes_corrected)

from torch_kmeans import KMeans
from torch_geometric.utils import train_test_split_edges
from torch_geometric.loader import DataLoader

from opacus import PrivacyEngine
import warnings
import gc
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from generator import *            
from utils import device
from gen_utils import hide_graph, train_fedgen, train_fedSagePlusClassifier


def parse_args(argv=None):
    """Parse command-line options for the FedGNN (FedAvg) + FedSage+ experiment.

    Args:
        argv: list of argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list makes
            the function usable from tests and notebooks.

    Returns:
        dict mapping each option's destination name (e.g. ``"num_client"``,
        ``"local_lr_stage1"``) to its parsed value.
    """
    parser = argparse.ArgumentParser(description="FedGNN (FedAvg) + FedSage+ experiments")

    # main
    parser.add_argument("--dataset", default="Cora", type=str)
    parser.add_argument("--num_client", default=5, type=int)
    # NOTE(review): --central is parsed but not read anywhere in this file.
    parser.add_argument("--central", action="store_true")
    parser.add_argument("--architecture", default="gcn", type=str)
    parser.add_argument("--partition", default="metis", type=str)
    parser.add_argument("--dropout", default=0.5, type=float)

    # general
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--weight_decay", type=float, default=5e-4,
                        help="Weight decay (L2 loss on parameters).")
    parser.add_argument("--data_dir", default="../datasets/", type=str)
    parser.add_argument("--save_dir", default="experiments", type=str)

    # fedgnn (stage 1 FedAvg training)
    parser.add_argument("--rounds_stage1", default=400, type=int)
    parser.add_argument("--num_pred", default=5, type=int)
    parser.add_argument("--local_epochs_stage1", default=1, type=int)
    parser.add_argument("--local_lr_stage1", default=0.01, type=float)

    # vaegen
    parser.add_argument("--enc_hidden_channels", default=64, type=int)
    parser.add_argument("--enc_out_channels", default=16, type=int)
    parser.add_argument("--t_digest", action="store_true")

    return vars(parser.parse_args(argv))

def _seed_everything(seed):
    """Seed torch (CPU and every CUDA device) and numpy, and force deterministic cuDNN."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _train_fedgnn(global_model, client_models, optimizers, client_loaders, client_datas,
                  global_loader, val_mask, tst_mask, rounds, local_epochs):
    """Stage 1: FedAvg training of ``global_model`` for ``rounds`` communication rounds.

    Each round:
    1. Copy the global weights into every client model.
    2. Run ``client_update`` on each client.
    3. Merge the client models back into ``global_model`` via ``average_models``.

    The per-round client/global metrics are only consumed by the (commented-out)
    wandb logging. Every 50 rounds the global validation/test accuracy is printed.
    """
    num_clients = len(client_models)
    for r in range(1, rounds + 1):
        global_state = global_model.state_dict()
        for model in client_models:
            model.load_state_dict(global_state)

        loss = 0
        for model, opt, loader, data in zip(client_models, optimizers, client_loaders, client_datas):
            loss += client_update(model, opt, loader, train_mask=data.train_mask, epoch=local_epochs)
        loss /= num_clients
        average_models(global_model, client_models)

        # Client models (still holding their local-update weights) on their own training nodes.
        clients_train_loss, clients_train_acc = 0, 0
        for model, loader, data in zip(client_models, client_loaders, client_datas):
            train_loss, train_acc = evaluate_model(model=model, data_loader=loader, mask=data.train_mask)
            clients_train_loss += train_loss
            clients_train_acc += train_acc
        clients_train_loss /= num_clients
        clients_train_acc /= num_clients
        # wandb.log({'clients_tr_loss_after_fedgnn' : clients_train_loss , 'clients_train_acc_after_fedgnn' : clients_train_acc})

        clients_test_loss = 0
        clients_test_acc, clients_test_max_score = [], []
        for model, loader, data in zip(client_models, client_loaders, client_datas):
            test_loss, test_acc, logits, _ = evaluate_model(model, data_loader=loader, mask=data.test_mask, return_logits=True)
            clients_test_loss += test_loss
            clients_test_acc.append(test_acc)
            # Highest softmax probability per test node.
            clients_test_max_score.append(torch.softmax(logits, 1).max(1).values)
        clients_test_loss /= num_clients
        clients_test_max_score = torch.cat(clients_test_max_score).tolist()
        # wandb.log({"client_test_loss_after_fedgnn" : clients_test_loss, "client_test_acc_after_fedgnn" : clients_test_acc})

        # Averaged global model on every client's training nodes and on the global test mask.
        global_train_loss, global_train_acc = 0, 0
        for loader, data in zip(client_loaders, client_datas):
            train_loss, train_acc = evaluate_model(global_model, data_loader=loader, mask=data.train_mask)
            global_train_loss += train_loss
            global_train_acc += train_acc
        global_train_loss /= num_clients
        global_train_acc /= num_clients
        global_test_loss, global_test_acc = evaluate_model(global_model, data_loader=global_loader, mask=tst_mask)
        # wandb.log({ "tr_loss_after_fedgnn":global_train_loss, "tst_loss_after_fedgnn":global_test_loss, "tr_acc_after_after_fedgnn":global_train_acc, "tst_acc_after_fedgnn":global_test_acc})

        if r % 50 == 0:
            val_loss, val_acc = evaluate_model(global_model, data_loader=global_loader, mask=val_mask)
            test_loss, test_acc = evaluate_model(global_model, data_loader=global_loader, mask=tst_mask)
            print("stage 1 val: ", val_acc, ", stage 1 test: ", test_acc)
            # wandb.log({ "val_loss_after_fedgnn":val_loss, "tst_loss_after_fedgnn":test_loss, "val_acc_after_fedgnn":val_acc, "tst_acc_after_fedgnn":test_acc})


def _report_conformal_set_sizes(num_clients, client_datas, model, client_loaders, t_digest,
                                alphas=(0.05, 0.1, 0.2, 0.3, 0.4, 0.5)):
    """Print the LAC/APS/RAPS values returned by ``print_set_sizes_corrected`` for each alpha.

    ``alpha`` is presumably the conformal miscoverage level — TODO confirm in ``utils``.
    """
    for alpha in alphas:
        lac, aps, raps = print_set_sizes_corrected(num_clients, client_datas, model, client_loaders,
                                                   alpha=alpha, t_digest=t_digest)
        print('lac_after_fedgnn', lac, 'aps_after_fedgnn', aps, 'raps_after_fedgnn', raps)
        # wandb.log({f'lac_after_fedgnn_{alpha}': lac, f'aps_after_fedgnn_{alpha}': aps, f'raps_after_fedgnn_{alpha}': raps})


def _train_fedsage_plus(client_datas, global_data, architecture, in_channels, lr, weight_decay, args):
    """Stage 2: train FedSage+ neighbour generators, then classifiers on the mended graphs.

    Returns:
        ``(mended_graph_loaders, global_model, hidden_client_datas)``, where
        ``global_model`` is the model returned by ``train_fedSagePlusClassifier``.
    """
    num_clients = len(client_datas)
    # Same objects as splitedData[i]['client_data'].
    original_client_datas = list(client_datas)
    hidden_client_datas = [hide_graph(data) for data in original_client_datas]

    unique_class_labels = global_data.y.unique().tolist()
    num_classes = len(unique_class_labels)

    generators, gen_optimizers = [], []
    for data in client_datas:
        feat_shape, node_len = data.x.shape[1], data.x.shape[0]
        neighbor_gen = FedSage_Plus(feat_shape, node_len, num_classes, args).to(device)
        generators.append(neighbor_gen)
        gen_optimizers.append(optim.Adam(neighbor_gen.parameters(), lr=lr, weight_decay=weight_decay))

    train_fedgen(generators, gen_optimizers, in_channels, hidden_client_datas, original_client_datas,
                 unique_class_labels, args)

    classifiers = [
        make_gnn_model(architecture=architecture, in_channels=in_channels, num_classes=num_classes).to(device)
        for _ in range(num_clients)
    ]
    # Feature dimension taken from the last client, as before (presumably identical for all clients).
    feat_dim = client_datas[-1].x.shape[1]
    mended_graph_loaders, _mended_graph_list, global_model, _ = train_fedSagePlusClassifier(
        classifiers, generators, hidden_client_datas, original_client_datas, feat_dim, unique_class_labels, args)
    return mended_graph_loaders, global_model, hidden_client_datas


def _evaluate_mended(global_model, mended_graph_loaders, hidden_client_datas):
    """Print val/test accuracy of ``global_model`` on every client's mended graph."""
    for loader, hidden in zip(mended_graph_loaders, hidden_client_datas):
        # torch.where(bool_tensor) returns a tuple of index tensors (the True positions).
        val_loss, val_acc = evaluate_model(model=global_model, data_loader=loader, mask=torch.where(hidden.val_mask))
        test_loss, test_acc = evaluate_model(model=global_model, data_loader=loader, mask=torch.where(hidden.test_mask))
        print("Val Acc:", val_acc, " Test Acc:", test_acc)


def main(config=None):
    """Run stage-1 FedGNN (FedAvg), report conformal set sizes, then run the FedSage+ stage.

    Args:
        config: mapping with the keys produced by ``parse_args`` (a dict, or a
            ``wandb.config`` when running a sweep). When ``None``, the command
            line is parsed with ``parse_args()``.
    """
    # wandb.init( project = 'clean code', entity='akgul')  -> then call main(wandb.config)
    args = config if config is not None else parse_args()

    num_clients = args["num_client"]
    architecture = args["architecture"]
    weight_decay = args["weight_decay"]
    lr = args["local_lr_stage1"]
    # NOTE(review): args["dropout"] is never passed to make_gnn_model — confirm this is intended.

    # Honour --seed (previously overwritten by a hard-coded 42; the default is still 42).
    _seed_everything(args["seed"])

    splitedData, val_mask, tst_mask, client_label_map, in_channels, num_classes = prepareData_oneDS(
        args["data_dir"], args["dataset"], num_client=num_clients,
        batchSize=args["batch_size"], partition=args["partition"])
    print("Federated Data loading complete!")
    print(client_label_map)

    client_datas = [splitedData[i]['client_data'] for i in range(num_clients)]
    client_loaders = [splitedData[i]['loader'] for i in range(num_clients)]
    # The global loader/data were previously read through a leaked loop index
    # (always the last client); that is now explicit. Presumably every entry
    # holds the same global objects — TODO confirm in prepareData_oneDS.
    global_loader = splitedData[num_clients - 1]['glob_loader']
    global_data = splitedData[num_clients - 1]['global_data']

    global_model = make_gnn_model(architecture=architecture, in_channels=in_channels, num_classes=num_classes).to(device)
    client_models = [
        make_gnn_model(architecture=architecture, in_channels=in_channels, num_classes=num_classes).to(device)
        for _ in range(num_clients)
    ]
    optimizers = [optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) for model in client_models]

    _train_fedgnn(global_model, client_models, optimizers, client_loaders, client_datas,
                  global_loader, val_mask, tst_mask,
                  rounds=args["rounds_stage1"], local_epochs=args["local_epochs_stage1"])
    print("Completed FedGNN Training.")
    _report_conformal_set_sizes(num_clients, client_datas, global_model, client_loaders,
                                t_digest=args["t_digest"])

    mended_graph_loaders, global_model, hidden_client_datas = _train_fedsage_plus(
        client_datas, global_data, architecture, in_channels, lr, weight_decay, args)
    # Evaluates every client instead of a hard-coded five.
    _evaluate_mended(global_model, mended_graph_loaders, hidden_client_datas)

    print(end="\n\n\n")
    torch.cuda.empty_cache()
    gc.collect()


if __name__ == "__main__":
    print('cuda', torch.cuda.is_available())
    print(torch.cuda.device_count())
    args = parse_args()

    # wandb sweep definition; only consumed by the commented-out sweep/agent calls below.
    sweep_config = {
        "method": "random",
        "metric": {"goal": "minimize", "name": "size"},
        "parameters": {
            # general params
            "num_client": {'value': 5},
            "dataset": {"value": "Cora"},
            "architecture": {'values' : ['sage']},
            "central": {'value': False},
            "dropout": {'value' : 0.5} ,
            "conf_score": {'values' : ['APS', 'RAPS' ]},

            # fedgnn params
            "rounds_stage1": {'value' : 400},
            "local_epochs_stage1": {'value' : 1},
            "local_lr_stage1": {'value' : 0.01},
            "batch_size": {'value': 32},

            # vaegen params
            "enc_hidden_channels": {'value' : 64},
            "num_pred": {'values' : [2, 5, 10]},
            "enc_out_channels": {'value' : 16},
            "t_digest": {'value': False},
        }
    }

    # sweep_config['parameters'].update({key :{"value" : val} for key, val in args.items() if key not in sweep_config['parameters'].keys()})
    # sweep_config['name'] = f"{args['dataset']}-different_sparsity_and_beta-{args['num_client']}clients"
    # sweep_id = wandb.sweep(sweep = sweep_config, project="gcn_linear_sage_decoder_ve_different_parametreler",  entity = 'akgul')
    # wandb.agent(sweep_id, function=main, count=30)
    # wandb.finish()

    # Pass the parsed options explicitly instead of relying on the module-level global.
    main(args)
