import argparse
import collections
import wandb
import numpy as np
import torch
import torch.optim as optim

from conformal import *

from utils import (average_models, client_update, local_gen_update, local_gvae_update, local_edge_pred_update, evaluate_model, graph_mend_vgae,
                   make_gnn_model, prepareData_oneDS, generate_node_feat_from_cluster_centers, print_set_sizes_corrected)

from vae_models import DeepVGAE, VAE_FeatureGen_Sage, VAE_FeatureGen_Linear, VAE_FeatureGen_GCN, AE_FeatureGen_Linear, GraphSAGEEdgePredictor

from torch_kmeans import KMeans
from torch_geometric.utils import train_test_split_edges
from torch_geometric.loader import DataLoader

from opacus import PrivacyEngine
import warnings
import gc
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

from utils import device

def parse_args():
    """Parse command-line options into a plain ``dict``.

    Returns:
        dict mapping each option name (e.g. ``"dataset"``, ``"num_client"``)
        to its parsed value.

    Note:
        ``--central``, ``--dropout``, ``--conf_score`` and ``--use_vae_gen`` are
        not read by ``main()`` in this file.
    """
    parser = argparse.ArgumentParser()
    # main
    parser.add_argument("--dataset", default="Cora", type=str)
    parser.add_argument("--num_client", default=10, type=int)
    parser.add_argument("--central", action="store_true")
    parser.add_argument("--architecture", default="gcn", type=str)
    parser.add_argument("--partition", default="metis", type=str)
    parser.add_argument("--dropout", default=0.5, type=float)

    # options
    # NOTE(review): store_true combined with default=True means this flag is
    # always True; it cannot be switched off from the command line.
    parser.add_argument("--use_vae_gen", default=True, action="store_true")

    # general
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument("--data_dir", default="../datasets/", type=str)
    parser.add_argument("--save_dir", default="experiments", type=str)

    # fedgnn
    parser.add_argument("--rounds_stage1", default=500, type=int)
    parser.add_argument("--local_epochs_stage1", default=1, type=int)
    parser.add_argument("--local_lr_stage1", default=0.01, type=float)

    # sparsity loss
    parser.add_argument("--sparsity_level", default=0.05, type=float)
    parser.add_argument("--beta", default=3, type=float)

    # vaegen
    parser.add_argument("--enc_hidden_channels", default=64, type=int)
    parser.add_argument("--enc_out_channels", default=16, type=int)
    parser.add_argument("--n_clusters", default=20, type=int)
    parser.add_argument("--feat_loss_weight", default=1, type=float)
    parser.add_argument("--conf_score", default="APS", type=str)
    parser.add_argument("--decoder", default="linear", type=str)
    parser.add_argument("--t_digest", action="store_true")
    parser.add_argument("--add_percentage", default=0.05, type=float)

    # differential-privacy
    # Epsilon is a real-valued privacy budget; parsing it as int rejected
    # common values such as 0.5.
    parser.add_argument("--target_epsilon", default=25.0, type=float)
    parser.add_argument("--target_delta", default=1e-5, type=float)
    parser.add_argument("--max_grad_norm", default=10, type=float)

    return vars(parser.parse_args())

# Significance levels at which conformal prediction-set sizes are reported.
_CONFORMAL_ALPHAS = (0.05, 0.1, 0.2, 0.3, 0.4, 0.5)


def _make_federation(architecture, in_channels, num_classes, num_clients, lr, weight_decay):
    """Build a global GNN, one client GNN per client and a per-client Adam optimizer.

    Returns:
        (global_model, client_models, optimizers); the two lists are index-aligned.
    """
    global_model = make_gnn_model(architecture=architecture, in_channels=in_channels,
                                  num_classes=num_classes).to(device)
    client_models = [make_gnn_model(architecture=architecture, in_channels=in_channels,
                                    num_classes=num_classes).to(device)
                     for _ in range(num_clients)]
    optimizers = [optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
                  for model in client_models]
    return global_model, client_models, optimizers


def _run_fedgnn(global_model, client_models, optimizers, loaders, datas, glob_loader,
                val_mask, tst_mask, rounds, local_epochs, tag, stage_label):
    """Run ``rounds`` federated GNN rounds and log per-round metrics to wandb.

    Each round broadcasts the global weights to every client, runs
    ``client_update`` locally, aggregates with ``average_models`` and then
    evaluates the client models and the global model.

    Args:
        global_model: model passed to ``average_models`` as the aggregation target.
        client_models, optimizers, loaders, datas: index-aligned per-client lists.
        glob_loader: loader used with ``val_mask`` / ``tst_mask`` for global evaluation.
        val_mask, tst_mask: masks passed to ``evaluate_model`` for global val/test.
        rounds: number of communication rounds.
        local_epochs: forwarded to ``client_update`` as ``epoch``.
        tag: wandb key suffix, e.g. ``"fedgnn"`` -> ``"..._after_fedgnn"``.
        stage_label: prefix of the periodic console print (e.g. ``"stage 1"``).
    """
    num_clients = len(client_models)
    for r in range(1, rounds + 1):
        global_state = global_model.state_dict()
        for model in client_models:
            model.load_state_dict(global_state)

        for model, opt, loader, data in zip(client_models, optimizers, loaders, datas):
            client_update(model, opt, loader, train_mask=data.train_mask, epoch=local_epochs)
        average_models(global_model, client_models)

        # Unweighted mean of the client models' train loss / accuracy.
        clients_train_loss, clients_train_acc = 0, 0
        for model, loader, data in zip(client_models, loaders, datas):
            train_loss, train_acc = evaluate_model(model=model, data_loader=loader, mask=data.train_mask)
            clients_train_loss += train_loss
            clients_train_acc += train_acc
        wandb.log({f'clients_tr_loss_after_{tag}': clients_train_loss / num_clients,
                   f'clients_train_acc_after_{tag}': clients_train_acc / num_clients})

        clients_test_loss = 0
        weighted_acc_sum = 0.0
        total_num_test_samples = 0
        for idx, (model, loader, data) in enumerate(zip(client_models, loaders, datas)):
            test_loss, test_acc, logits, _ = evaluate_model(model, data_loader=loader, mask=data.test_mask,
                                                            return_logits=True)
            num_test_samples = torch.sum(data.test_mask).item()
            clients_test_loss += test_loss
            weighted_acc_sum += test_acc * num_test_samples
            total_num_test_samples += num_test_samples
            max_scores = torch.softmax(logits, 1).max(1).values
            wandb.log({f"cli_{idx}_test_loss_after_{tag}": test_loss,
                       f"cli_{idx}_test_acc_after_{tag}": test_acc,
                       f"cli_{idx}_test_max_score_after_{tag}": max_scores})
        # Test-sample-weighted mean accuracy over clients. This value is also
        # what "client_test_acc_*" reports (it used to log the raw, unnormalised
        # sum of acc * n_test).
        clients_test_acc = weighted_acc_sum / total_num_test_samples
        wandb.log({f"averaged_test_acc_after_{tag}": clients_test_acc})
        wandb.log({f"client_test_loss_after_{tag}": clients_test_loss / num_clients,
                   f"client_test_acc_after_{tag}": clients_test_acc})

        # Global model: mean train metrics over the clients' data, plus global test.
        global_train_loss, global_train_acc = 0, 0
        for loader, data in zip(loaders, datas):
            train_loss, train_acc = evaluate_model(global_model, data_loader=loader, mask=data.train_mask)
            global_train_loss += train_loss
            global_train_acc += train_acc
        global_train_loss /= num_clients
        global_train_acc /= num_clients
        global_test_loss, global_test_acc = evaluate_model(global_model, data_loader=glob_loader, mask=tst_mask)
        # The doubled "after_after" in the train-acc key is kept unchanged so
        # existing dashboards keep matching.
        wandb.log({f"tr_loss_after_{tag}": global_train_loss, f"tst_loss_after_{tag}": global_test_loss,
                   f"tr_acc_after_after_{tag}": global_train_acc, f"tst_acc_after_{tag}": global_test_acc})

        if r % 50 == 0:
            val_loss, val_acc = evaluate_model(global_model, data_loader=glob_loader, mask=val_mask)
            test_loss, test_acc = evaluate_model(global_model, data_loader=glob_loader, mask=tst_mask)
            print(f"{stage_label} val: ", val_acc, f", {stage_label} test: ", test_acc)
            wandb.log({f"val_loss_after_{tag}": val_loss, f"tst_loss_after_{tag}": test_loss,
                       f"val_acc_after_{tag}": val_acc, f"tst_acc_after_{tag}": test_acc})


def _log_conformal_set_sizes(num_clients, datas, model, loaders, t_digest, print_tag, key_tag):
    """Print and wandb-log LAC/APS/RAPS set sizes for every alpha in ``_CONFORMAL_ALPHAS``.

    Keys are ``f"lac_{key_tag}{alpha}"`` (likewise ``aps_`` / ``raps_``).
    """
    for alpha in _CONFORMAL_ALPHAS:
        lac, aps, raps = print_set_sizes_corrected(num_clients, datas, model, loaders,
                                                   alpha=alpha, t_digest=t_digest)
        print(f'lac_{print_tag}', lac, f'aps_{print_tag}', aps, f'raps_{print_tag}', raps)
        wandb.log({f'lac_{key_tag}{alpha}': lac, f'aps_{key_tag}{alpha}': aps,
                   f'raps_{key_tag}{alpha}': raps})


def _make_feat_generators(decoder, num_clients, in_channels, hidden_channels, out_channels, feat_loss_weight):
    """Create one feature-generator model per client; unknown ``decoder`` values fall back to the GCN variant."""
    gen_cls = {
        'linear': VAE_FeatureGen_Linear,
        'sage': VAE_FeatureGen_Sage,
        'ae': AE_FeatureGen_Linear,
    }.get(decoder, VAE_FeatureGen_GCN)
    return [gen_cls(enc_in_channels=in_channels, enc_hidden_channels=hidden_channels,
                    enc_out_channels=out_channels, weight=feat_loss_weight).to(device)
            for _ in range(num_clients)]


def _generate_cluster_features(feat_gen_models, optimizers, train_loaders, privacy_engines,
                               rounds, sparsity_level, beta, n_clusters, in_channels):
    """Train each client's feature generator and summarise its output with k-means.

    Returns:
        Tensor of shape ``(num_clients * n_clusters, in_channels)`` on ``device``.
        Rows ``[i * n_clusters:(i + 1) * n_clusters]`` hold client ``i``'s cluster centres.
    """
    num_clients = len(feat_gen_models)
    generated = torch.zeros((num_clients, n_clusters, in_channels))
    for i, (gen_model, gen_optim) in enumerate(zip(feat_gen_models, optimizers)):
        _, _, generated_x = local_gen_update(gen_model, gen_optim, train_loaders[i],
                                             privacy_engine=privacy_engines[i], epoch=rounds,
                                             sparsity_level=sparsity_level, beta=beta)
        kmeans = KMeans(n_clusters=n_clusters)
        # KMeans is called on a batch of one: (1, num_generated, feat_dim).
        clusters_assigned = kmeans(generated_x.reshape(1, generated_x.shape[0], generated_x.shape[1]))
        generated[i] = clusters_assigned.centers.reshape(n_clusters, -1)
    # Round-trip through a Python list to obtain a plain tensor detached from any autograd graph.
    generated = torch.Tensor(generated.tolist())
    return generated.reshape(-1, in_channels).to(device)


def _prepare_vgae_client_data(client_datas, batch_size):
    """Split every client's edges for link prediction.

    Returns:
        List of ``(split_data, all_edge_index, loader)`` per client, where
        ``all_edge_index`` is the client's edge set captured before the split.
    """
    prepared = []
    for client_data in client_datas:
        # Capture the full edge set first: train_test_split_edges rewrites the
        # Data object's edge attributes (in place in the PyG versions checked;
        # confirm against the installed version).
        all_edge_index = client_data.edge_index
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # check if this affects the results especially for large number of clients
            split_data = train_test_split_edges(client_data, 0.2, 0.2)
        loader = DataLoader(dataset=[split_data], batch_size=batch_size, pin_memory=False)
        prepared.append((split_data, all_edge_index, loader))
    return prepared


def _train_vgae(global_vgae_model, vgae_models, vgae_optimizers, prepared, rounds, local_epochs):
    """Federated VGAE training: broadcast, local ``local_gvae_update``, ``average_models``; prints loss every 100 rounds."""
    num_clients = len(vgae_models)
    for r in range(1, rounds + 1):
        global_state = global_vgae_model.state_dict()
        for model in vgae_models:
            model.load_state_dict(global_state)

        loss = 0
        for vgae_model, vgae_opt, (client_data, all_edge_index, client_loader) in zip(
                vgae_models, vgae_optimizers, prepared):
            loss += local_gvae_update(vgae_model, vgae_opt, client_loader, client_data, all_edge_index,
                                      epoch=local_epochs)
        loss /= num_clients
        average_models(global_vgae_model, vgae_models)

        if r % 100 == 0:
            print("Loss: ", loss)


def _mend_client_graphs(global_vgae_model, prepared, generated_node_features, n_clusters, add_percentage):
    """Add the top-scoring predicted links between each client and the other clients' generated nodes.

    For client ``i`` the ``n_clusters`` rows it generated itself are excluded.
    ``global_vgae_model.predict`` scores candidate links; presumably it returns
    a (client nodes x foreign generated nodes) score matrix (TODO confirm in
    vae_models). The top ``add_percentage * |E_i|`` entries are selected as a
    boolean mask and handed to ``graph_mend_vgae``.

    Returns:
        List with one mended graph per client, in client order.
    """
    new_client_datas = []
    for i, (split_data, all_edge_index, _) in enumerate(prepared):
        print("CLIENT ", i)
        client_data = split_data.to(device)
        edge_index = all_edge_index.to(device)

        foreign_features = torch.cat((generated_node_features[:i * n_clusters],
                                      generated_node_features[(i + 1) * n_clusters:]))

        pred_client = global_vgae_model.predict(client_data.x, edge_index, foreign_features)
        wandb.log({f'Prediction Distribution {i}': wandb.Histogram(pred_client.cpu().detach().numpy())})

        # Clamp k: torch.topk fails if asked for more entries than the score matrix holds.
        num_edges_to_add = min(int(add_percentage * edge_index.shape[1]), pred_client.numel())
        top_k_indices = torch.topk(pred_client.flatten(), num_edges_to_add).indices
        top_k_2d_indices = torch.vstack(torch.unravel_index(top_k_indices, pred_client.shape)).T
        mask = torch.zeros_like(pred_client, dtype=torch.bool)
        mask[top_k_2d_indices[:, 0], top_k_2d_indices[:, 1]] = True

        print("Number of newly formed edges to add:", num_edges_to_add)

        new_client_datas.append(graph_mend_vgae(split_data, all_edge_index, mask, foreign_features, device))
    return new_client_datas


def main(config=None):
    """Run the whole pipeline for one wandb (sweep) run.

    Steps:
        1. Stage 1: federated GNN training on the partitioned graph, then
           conformal set sizes are logged.
        2. Per-client feature generators are trained and summarised with k-means.
        3. A federated VGAE is trained.
        4. Each client graph is mended with the top predicted links to the
           other clients' generated nodes.
        5. Stage 2: the federated GNN is retrained from scratch on the mended
           graphs, then conformal set sizes are logged again.

    Args:
        config: unused; hyper-parameters are read from ``wandb.config`` after ``wandb.init``.
    """
    wandb.init(project='clean code', entity='akgul')
    args = wandb.config

    dataset = args["dataset"]
    num_clients = args["num_client"]
    architecture = args["architecture"]
    partition = args["partition"]
    seed = args["seed"]
    batch_size = args["batch_size"]
    weight_decay = args["weight_decay"]
    data_dir = args["data_dir"]
    epoch_stage1 = args["rounds_stage1"]
    local_epochs_stage1 = args["local_epochs_stage1"]
    lr = args["local_lr_stage1"]
    hidden_channels = args["enc_hidden_channels"]
    out_channels = args["enc_out_channels"]
    n_clusters = args["n_clusters"]
    feat_loss_weight = args["feat_loss_weight"]
    sparsity_level = args["sparsity_level"]
    beta = args["beta"]
    add_percentage = args["add_percentage"]
    decoder = args["decoder"]

    # Use the configured seed. It used to be overwritten with a hard-coded 42,
    # which silently ignored --seed.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    splitedData, val_mask, tst_mask, client_label_map, in_channels, num_classes = prepareData_oneDS(
        data_dir, dataset, num_client=num_clients, batchSize=batch_size, partition=partition)

    print("Federated Data loading complete!")
    print(client_label_map)

    client_datas = [splitedData[i]['client_data'] for i in range(num_clients)]
    client_loaders = [splitedData[i]['loader'] for i in range(num_clients)]
    # Global val/test evaluation goes through the last client's 'glob_loader'.
    # Presumably every client shares the same global loader; TODO confirm in prepareData_oneDS.
    glob_loader = splitedData[num_clients - 1]['glob_loader']

    # ---- Stage 1: federated GNN on the original client graphs ----
    global_model, client_models, optimizers = _make_federation(
        architecture, in_channels, num_classes, num_clients, lr, weight_decay)
    _run_fedgnn(global_model, client_models, optimizers, client_loaders, client_datas, glob_loader,
                val_mask, tst_mask, epoch_stage1, local_epochs_stage1, tag="fedgnn", stage_label="stage 1")

    print("Completed FedGNN Training.")
    _log_conformal_set_sizes(num_clients, client_datas, global_model, client_loaders, args["t_digest"],
                             print_tag="after_fedgnn", key_tag="after_fedgnn_")

    # ---- Node generation ----
    local_feat_gen_models = _make_feat_generators(decoder, num_clients, in_channels, hidden_channels,
                                                  out_channels, feat_loss_weight)
    optimizers = [optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
                  for model in local_feat_gen_models]
    privacy_engines = [None for _ in range(num_clients)]
    # privacy_engines = [PrivacyEngine() for _ in range(num_clients)]
    train_loaders = list(client_loaders)

    # Uncomment this block and above privacy_engine definition to activate DP training
    # for i in range(num_clients):
    #     local_feat_gen_models[i], optimizers[i], train_loaders[i] = privacy_engines[i].make_private_with_epsilon(
    #         module=local_feat_gen_models[i],
    #         optimizer=optimizers[i],
    #         data_loader=train_loaders[i],
    #         epochs=epoch_stage1,
    #         target_epsilon=args["target_epsilon"],
    #         target_delta=args["target_delta"],
    #         max_grad_norm=args["max_grad_norm"],
    #     )

    generated_node_features = _generate_cluster_features(
        local_feat_gen_models, optimizers, train_loaders, privacy_engines,
        epoch_stage1, sparsity_level, beta, n_clusters, in_channels)
    print("Completed node generation part.")

    # ---- Edge prediction (federated VGAE) ----
    global_vgae_model = DeepVGAE(enc_in_channels=in_channels, enc_hidden_channels=hidden_channels,
                                 enc_out_channels=out_channels).to(device)
    vgae_models = [DeepVGAE(enc_in_channels=in_channels, enc_hidden_channels=hidden_channels,
                            enc_out_channels=out_channels).to(device) for _ in range(num_clients)]
    vgae_optimizers = [optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay) for model in vgae_models]

    client_datas_with_negative = _prepare_vgae_client_data(client_datas, batch_size)
    # The VGAE is trained for five times the GNN round budget.
    _train_vgae(global_vgae_model, vgae_models, vgae_optimizers, client_datas_with_negative,
                5 * epoch_stage1, local_epochs_stage1)

    new_client_datas = _mend_client_graphs(global_vgae_model, client_datas_with_negative,
                                           generated_node_features, n_clusters, add_percentage)

    # ---- Stage 2: federated GNN on the mended client graphs ----
    new_loaders = [DataLoader(dataset=[client_data], batch_size=batch_size, pin_memory=False)
                   for client_data in new_client_datas]

    global_model, client_models, optimizers = _make_federation(
        architecture, in_channels, num_classes, num_clients, lr, weight_decay)
    _run_fedgnn(global_model, client_models, optimizers, new_loaders, new_client_datas, glob_loader,
                val_mask, tst_mask, epoch_stage1, local_epochs_stage1, tag="fedgnn2", stage_label="stage 2")

    print("Completed VAE_GEN Training.")
    _log_conformal_set_sizes(num_clients, new_client_datas, global_model, new_loaders, args["t_digest"],
                             print_tag="after_vae", key_tag="after_vae")

    print(end="\n\n\n")
    torch.cuda.empty_cache()
    gc.collect()


if __name__ == "__main__":
    print('cuda', torch.cuda.is_available())
    print(torch.cuda.device_count())
    args = parse_args()

    sweep_config = {
        "method": "random",
        # NOTE(review): main() never logs a metric named "size". That is harmless
        # for method="random", but it would break "bayes" search.
        "metric": {"goal": "minimize", "name": "size"},
        "parameters": {
            # general params
            "num_client": {'value': 3},
            "dataset": {"value": "Cora"},
            "architecture": {'values': ['sage']},
            "central": {'value': False},
            "dropout": {'value': 0.5},
            "conf_score": {'values': ['APS', 'RAPS']},

            # fedgnn params
            "rounds_stage1": {'value': 500},
            "local_epochs_stage1": {'value': 1},
            "local_lr_stage1": {'value': 0.01},
            "batch_size": {'value': 32},

            # models to train
            "use_vae_gen": {'value': True},

            # vaegen params
            "enc_hidden_channels": {'value': 64},
            "enc_out_channels": {'value': 16},
            "n_clusters": {'values': [2, 5, 10, 20, 30]},
            "sparsity_level": {'values': [0.01, 0.05, 0.1, 0.2, 0.4]},
            "beta": {'values': [0.1, 1, 3, 5, 10]},
            "feat_loss_weight": {'values': [1., 10.]},
            "add_percentage": {'values': [0.01, 0.02, 0.04, 0.06, 0.08, 0.10]},
            "t_digest": {'value': False},
            "decoder": {'values': ['sage', 'linear', 'gcn']},

            # differential privacy
            # "target_epsilon": {'values' : [1, 5, 10, 50, 100, 200]},
            # "target_delta": {'values' : [1e-5, 1e-4, 1e-3]},
            # "max_grad_norm": {'values' : [1.5, 5.5, 8.5]},
        }
    }

    sweep_params = sweep_config['parameters']
    # CLI options only fill in keys that the sweep above does not already pin.
    # The sweep's own values win for keys such as "dataset" and "num_client".
    sweep_params.update({key: {"value": val} for key, val in args.items() if key not in sweep_params})
    # Name the sweep after the dataset and client count the runs actually use
    # (the sweep's fixed values), not the possibly overridden CLI values.
    sweep_config['name'] = f"{sweep_params['dataset']['value']}-coverage-{sweep_params['num_client']['value']}clients"
    sweep_id = wandb.sweep(sweep=sweep_config, project="coverage", entity='akgul')
    # main() reads its hyper-parameters from wandb.config, so it must run under the sweep agent.
    wandb.agent(sweep_id, function=main, count=1)
    wandb.finish()
