import csv
import os.path
import time
import gc
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim

from src.models import feat_loss
from src.utils import config
from src.utils.convert2pyg import convert_stellargraph_to_pyg
from src.utils.mending_graph import fill_graph_pytorch

from src.utils.save_and_load_models import save_experiment_results


def train_benign_gen(owner, optimizer, classifier=None, pre_train=False):
    """Forward one owner's generator and assemble its weighted benign loss.

    The model is put in train mode and ``optimizer``'s gradients are zeroed,
    but no backward pass or optimizer step happens here -- the caller does both.

    Args:
        owner: data holder providing the model(s), graph inputs and targets.
        optimizer: optimizer whose gradients are zeroed before the forward pass.
        classifier: unused; kept for call-site compatibility.
        pre_train: use ``owner.aux_model`` instead of ``owner.fed_model``.

    Returns:
        ``(output_missing, output_feat, loss_train_feat, loss, loss_train_missing)``
        where ``output_missing`` is flattened, ``output_feat`` is viewed as
        ``(len(owner.all_ids), owner.num_pred, owner.feat_shape)`` and
        ``loss = config.a * missing + config.b * feat + config.c * label``.
    """
    model = owner.fed_model if not pre_train else owner.aux_model

    model.train()
    optimizer.zero_grad()

    _, raw_missing, raw_feat, logits = model(owner.all_feat, owner.edges, owner.adj)
    class_scores = F.relu(logits)
    pred_missing = torch.flatten(raw_missing)
    pred_feat = raw_feat.view(len(owner.all_ids), owner.num_pred, owner.feat_shape)

    train_idx = owner.train_ilocs

    # Regression of the predicted missing-neighbour count on training nodes.
    missing_loss = F.smooth_l1_loss(
        pred_missing[train_idx].float(),
        owner.all_targets_missing[train_idx].reshape(-1).float(),
    )

    # Feature-generation loss computed by ``feat_loss.greedy_loss``.
    feat_loss_value = feat_loss.greedy_loss(
        pred_feat[train_idx],
        owner.all_targets_feat[train_idx],
        pred_missing[train_idx],
        owner.all_targets_missing[train_idx],
    ).unsqueeze(0).mean().float()

    # Node-classification loss; class ids come from an argmax over dim 1 of the targets.
    labels = torch.argmax(owner.all_targets_subj[train_idx], dim=1).view(-1)
    if config.cuda:
        labels = labels.cuda()
    label_loss = F.cross_entropy(class_scores[train_idx], labels)

    weighted = config.a * missing_loss + config.b * feat_loss_value + config.c * label_loss
    return pred_missing, pred_feat, feat_loss_value, weighted.float(), missing_loss


def train_mal_gen(owner, classifier, optimizer):
    """Build the attacker's generator objective for one owner.

    The owner's ``fed_model`` runs in train mode after ``optimizer`` is zeroed.
    Its predictions mend the owner's subgraph via ``fill_graph_pytorch``. The
    returned loss is the *negated* cross-entropy of ``classifier.model`` on the
    mended graph's training nodes, so minimising it raises the classifier's
    training loss. No backward pass happens here. If ``classifier.model``'s
    parameters require grad, a later ``backward()`` also accumulates gradients
    into them.

    Returns:
        ``(None, None, None, loss, None)`` -- same arity as ``train_benign_gen``.
    """
    generator = owner.fed_model
    generator.train()
    optimizer.zero_grad()

    _, pred_missing, pred_feat, _ = generator(owner.all_feat, owner.edges, owner.adj)

    base_graph = convert_stellargraph_to_pyg(
        owner.subG,
        classifier.train_subjects,
        classifier.test_subjects,
        classifier.train_targets,
        classifier.test_targets,
    ).cuda()
    mended = fill_graph_pytorch(base_graph, pred_missing, pred_feat, owner.feat_shape).cuda()

    logits = classifier.model(mended.x, mended.edge_index)
    mask = mended.train_mask
    adversarial_loss = -F.cross_entropy(logits[mask], mended.y[mask])

    return None, None, None, adversarial_loss, None


def _sample_neighbor_targets(owner, other):
    """Sample neighbour features from ``other.subG`` as cross-client feature targets for ``owner``.

    One node of ``other.subG`` is drawn (with replacement) per training node of
    ``owner``. A node without neighbours is re-drawn until one with at least
    one neighbour is found. ``config.num_pred`` of its neighbours are then
    drawn (with replacement) and their features collected.

    Returns:
        np.ndarray reshaped to ``(len(owner.train_ilocs), config.num_pred, owner.feat_shape)``.
    """
    graph = other.subG
    num_nodes = len(list(graph.nodes()))
    num_train = len(owner.train_ilocs)
    sampled_ids = graph.nodes()[np.random.choice(num_nodes, num_train)]
    target_feat = []
    for node_id in sampled_ids:
        neighbors_ids = graph.neighbors(node_id)
        while len(neighbors_ids) == 0:
            # Re-draw a position and map it back to a node id.
            node_id = graph.nodes()[np.random.choice(num_nodes, 1)[0]]
            neighbors_ids = graph.neighbors(node_id)
        for neighbor_id in np.random.choice(neighbors_ids, config.num_pred):
            target_feat.append(graph.node_features([neighbor_id])[0])
    return np.asarray(target_feat).reshape((num_train, config.num_pred, owner.feat_shape))


def _cross_client_feat_loss(owner, output_feat, output_missing, target_feat):
    """Greedy feature loss of ``owner``'s predictions (training nodes only) against ``target_feat``."""
    idx = owner.train_ilocs
    return feat_loss.greedy_loss(output_feat[idx],
                                 target_feat,
                                 output_missing[idx],
                                 owner.all_targets_missing[idx]
                                 ).unsqueeze(0).mean().float()


def _trainable_named_params(model):
    """Return ``(name, param)`` pairs of ``model`` with ``requires_grad`` set, in registration order."""
    return [(name, p) for name, p in model.named_parameters() if p.requires_grad]


def _grad_distance(a, b):
    """Euclidean distance between two gradient tensors after flattening them."""
    return torch.sqrt(torch.sum((a.flatten() - b.flatten()) ** 2)).item()


def _train_gen_clean(local_classifiers, local_owners, optim_list):
    """Benign federated generator training (clean run, or target attack trained from scratch)."""
    for epoch in range(config.gen_epochs):
        for i in range(config.num_owners):
            owner = local_owners[i]
            output_missing, output_feat, loss_train_feat, loss, loss_train_missing = train_benign_gen(
                owner, optim_list[i], local_classifiers[i])

            # Add the feature loss against neighbours sampled from every other owner.
            for j in range(config.num_owners):
                if i == j:
                    continue
                target_feat = _sample_neighbor_targets(owner, local_owners[j])
                loss_train_feat_other = _cross_client_feat_loss(owner, output_feat, output_missing, target_feat)
                # NOTE(review): the term is detached, so it changes the reported loss
                # but contributes no gradient to the generator -- confirm intended.
                loss += config.b * loss_train_feat_other.detach()

            # Average over owners exactly once, after all cross-client terms are added
            # (scaling inside the j-loop would multiply by (1/num_owners)**num_owners).
            loss = 1.0 / config.num_owners * loss
            loss.backward()
            print('Data owner ' + str(i),
                  ' Epoch: {:04d}'.format(epoch + 1),
                  'loss: {:.4f}'.format(loss.data),
                  'degree loss: {:.4f}'.format(loss_train_missing.data),
                  'feature loss: {:.4f}'.format(loss_train_feat.data))
            optim_list[i].step()


def _train_gen_global_attack(local_classifiers, local_owners, optim_list):
    """Global-attack generator training; saves per-epoch gradient distances to ``config.save_path``."""
    sum_eu = np.zeros((config.num_owners, config.num_owners))
    all_eu = np.zeros((config.gen_epochs, config.num_owners, config.num_owners))
    for epoch in range(config.gen_epochs):
        euclidean_results = np.zeros((config.num_owners, config.num_owners))

        # 1) Attackers take a local step on the malicious objective.
        for i in range(config.num_attacker):
            _, _, _, loss, _ = train_mal_gen(local_owners[i], local_classifiers[i], optim_list[i])
            loss.backward()
            optim_list[i].step()
            print('Data owner ' + str(i),
                  ' Epoch: {:04d}'.format(epoch + 1),
                  'loss: {:.4f}'.format(loss.data))

        # 2) Benign owners: local gradient, plus distances to the gradients their
        #    cross-client feature loss against each other benign owner would produce.
        for i in range(config.num_attacker, config.num_owners):
            owner = local_owners[i]
            gen = owner.fed_model.gen
            output_missing, output_feat, loss_train_feat, loss, loss_train_missing = train_benign_gen(
                owner, optim_list[i], local_classifiers[i])

            loss.backward(retain_graph=True)
            original_grads = {name: param.grad.detach().clone()
                              for name, param in gen.named_parameters() if param.grad is not None}
            owner.fed_model.zero_grad()

            for j in range(config.num_attacker, config.num_owners):
                if i == j:
                    continue
                target_feat = _sample_neighbor_targets(owner, local_owners[j])
                loss_train_feat_other = _cross_client_feat_loss(owner, output_feat, output_missing, target_feat)
                loss_train_feat_other.backward(retain_graph=True)

                # Mean per-parameter Euclidean distance between local and received gradients.
                total, count = 0, 0
                for name, param in gen.named_parameters():
                    if param.grad is not None:
                        total += _grad_distance(original_grads[name], param.grad.detach())
                        count += 1
                euclidean_results[j, i] = total / count if count else 0.0
                owner.fed_model.zero_grad()

            # Re-populate the local gradient (cleared above) for the later update.
            loss.backward()
            print('Data owner ' + str(i),
                  ' Epoch: {:04d}'.format(epoch + 1),
                  'loss: {:.4f}'.format(loss.data),
                  'degree loss: {:.4f}'.format(loss_train_missing.data),
                  'feature loss: {:.4f}'.format(loss_train_feat.data))

        # 3) Craft the malicious gradients sent to each benign owner. The same
        #    trainable-parameter order is used everywhere so list indices line up.
        attacker_params = [p for _, p in _trainable_named_params(local_owners[0].fed_model)]
        sum_benign_grads = [torch.zeros_like(p) for p in attacker_params]
        for j in range(config.num_attacker, config.num_owners):
            for acc, (_, param) in zip(sum_benign_grads, _trainable_named_params(local_owners[j].fed_model)):
                acc += param.grad.detach()

        mal_grads = []
        for j in range(config.num_attacker, config.num_owners):
            benign_params = [p for _, p in _trainable_named_params(local_owners[j].fed_model)]
            mal_grads.append([
                (config.num_owners * (bp.data - mp.data) / (config.lr * config.b) - bg).detach()
                for bp, mp, bg in zip(benign_params, attacker_params, sum_benign_grads)
            ])

        # 4) Inject the crafted gradients and record their distance to the local ones.
        for i in range(config.num_attacker, config.num_owners):
            target_model = local_owners[i].fed_model
            mal_grad = mal_grads[i - config.num_attacker]
            original_grads, received_grads = {}, {}
            for idx, (name, param) in enumerate(_trainable_named_params(target_model)):
                grad_mal = mal_grad[idx]
                original_grads[name] = param.grad.detach().clone()
                received_grads[name] = config.attack_intensity * config.b * grad_mal.detach()
                param.grad += config.attack_intensity * config.b * (grad_mal + sum_benign_grads[idx]) / config.num_owners

            total, count = 0, 0
            for name, param in target_model.named_parameters():
                if param.grad is not None and 'gen' in name:
                    total += _grad_distance(original_grads[name], received_grads[name])
                    count += 1
            euclidean_results[0, i] = total / count if count else 0.0

        sum_eu += euclidean_results
        all_eu[epoch] = euclidean_results

        # 5) Benign owners apply their (perturbed) gradients.
        for i in range(config.num_attacker, config.num_owners):
            optim_list[i].step()

    np.save(os.path.join(config.save_path, 'eu.npy'), sum_eu / config.gen_epochs)
    np.save(os.path.join(config.save_path, 'all_eu.npy'), all_eu)


def train_gen_fed(local_classifiers:list, local_owners:list, target_index=None, attacker_classifier=None):
    """Train every owner's generator (``fed_model``) in the federated setting.

    The mode is selected from ``config`` and ``target_index``:

    * Clean/target-from-scratch. Runs when ``config.num_attacker == 0``, or when
      ``target_index`` is given and ``config.phase == 'target_train_from0'``.
      Every owner trains on its benign loss plus feature losses against
      neighbours sampled from the other owners.
    * No training. Runs when ``target_index`` is given with attackers present
      and any other phase.
    * Global attack. Applies in every other case. The first
      ``config.num_attacker`` owners train on the malicious objective. Benign
      owners' gradients are perturbed with attacker-crafted gradients before
      stepping. The averaged and per-epoch gradient distances are saved to
      ``eu.npy`` and ``all_eu.npy`` under ``config.save_path``.

    Args:
        local_classifiers: per-owner classifier wrappers.
        local_owners: per-owner data holders with a ``fed_model`` generator.
        target_index: when not None, selects the clean/target-attack path.
        attacker_classifier: currently unused.
    """
    optim_list = [optim.Adam(owner.fed_model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
                  for owner in local_owners]

    gc.collect()
    torch.cuda.empty_cache()

    if config.num_attacker == 0 or target_index is not None:
        if config.num_attacker == 0 or config.phase == 'target_train_from0':
            _train_gen_clean(local_classifiers, local_owners, optim_list)
    else:
        _train_gen_global_attack(local_classifiers, local_owners, optim_list)


def train_gen_local(local_classifiers:list, local_owners:list, pre_train=False, is_aux=False):
    """Train generators locally, with no cross-client exchange.

    With ``pre_train=True`` every owner's ``aux_model`` is trained on the benign
    objective for ``config.pre_benign_epoch`` epochs. Otherwise only the first
    ``config.num_attacker`` owners train their ``fed_model`` on the malicious
    objective for ``config.pre_malicious_epoch`` epochs.

    ``is_aux`` is accepted for interface compatibility but is not used.
    """
    start = time.time()
    optimizers = []
    for owner in local_owners:
        model = owner.aux_model if pre_train else owner.fed_model
        optimizers.append(optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay))

    gc.collect()
    torch.cuda.empty_cache()

    if not pre_train:
        for epoch in range(config.pre_malicious_epoch):
            # range() is empty when there are no attackers, so nothing trains then.
            for i in range(config.num_attacker):
                _, _, _, loss, _ = train_mal_gen(local_owners[i], local_classifiers[i], optimizers[i])
                loss.backward()
                optimizers[i].step()
                print('Data owner ' + str(i),
                      ' Epoch: {:04d}'.format(epoch + 1),
                      'loss: {:.4f}'.format(loss.item()),
                      'time: {:.4f}s'.format(time.time() - start))
        return

    for epoch in range(config.pre_benign_epoch):
        for i, classifier in enumerate(local_classifiers):
            _, _, loss_feat, loss, _ = train_benign_gen(local_owners[i], optimizers[i], classifier, pre_train=True)
            loss.backward()
            optimizers[i].step()
            print('Data owner ' + str(i),
                  ' Epoch: {:04d}'.format(epoch + 1),
                  'loss: {:.4f}'.format(loss.item()),
                  'loss_feat: {:.8f}'.format(loss_feat.item()))


def train_classifier_graphsage(classifier_list):
    """Train every classifier independently on the training nodes of its own ``pyg_data``.

    There are ``config.epoch_classifier`` outer epochs. In each one, every
    classifier takes ``config.epochs_local`` Adam steps (lr ``config.lr``) on
    cross-entropy over its training mask, and the loss is printed after each
    step.
    """
    optimizers = [optim.Adam(clf.model.parameters(), lr=config.lr) for clf in classifier_list]

    for epoch in range(config.epoch_classifier):
        for idx, (clf, opt) in enumerate(zip(classifier_list, optimizers)):
            for _ in range(config.epochs_local):
                clf.model.train()
                opt.zero_grad()

                data = clf.pyg_data
                mask = data.train_mask
                logits = clf.model(data.x, data.edge_index)
                loss = F.cross_entropy(logits[mask], data.y[mask])

                loss.backward()
                opt.step()
                # ``!s`` keeps the tensor's str() form, matching the historical log format.
                print(f"local do = {idx} epoch = {epoch} loss = {loss.data!s}")


def test_classifier_graphsage(classifier_list):
    """Report each classifier's test accuracy on its own ``pyg_data`` and save the results.

    Each model is switched to eval mode and left there. The per-client results
    are written with ``save_experiment_results`` to ``<config.save_path>/log.json``.
    """
    results = []
    for idx, clf in enumerate(classifier_list):
        data = clf.pyg_data
        clf.model.eval()
        predictions = clf.model(data.x, data.edge_index).argmax(dim=1)

        mask = data.test_mask
        n_correct = (predictions[mask] == data.y[mask]).sum()
        accuracy = float(n_correct) / int(mask.sum())

        print(f"Classifier {idx}, test num: {mask.sum()}")
        print(f"Classifier {idx}, Test Accuracy: {accuracy:.4f}")
        results.append({'client_id': idx, 'test_acc': accuracy})

    save_experiment_results(config, results, os.path.join(config.save_path, 'log.json'))


def mend_graph_pyg(classifier_list, local_owners):
    """Replace each classifier's ``pyg_data`` with its owner's generator-mended graph.

    For every index ``i``, ``local_owners[i].fed_model`` is switched to eval mode
    and run once. Its predicted missing counts and features are passed to
    ``fill_graph_pytorch`` together with the owner's subgraph. The resulting
    graph, with ``x``, ``edge_index``, ``y``, ``train_mask`` and ``test_mask``
    detached, becomes ``classifier_list[i].pyg_data``.
    """
    for idx, classifier in enumerate(classifier_list):
        owner = local_owners[idx]
        owner.fed_model.eval()
        _, pred_missing, pred_feats, _ = owner.fed_model(owner.all_feat, owner.edges, owner.adj)

        base_graph = convert_stellargraph_to_pyg(
            owner.subG,
            classifier.train_subjects,
            classifier.test_subjects,
            classifier.train_targets,
            classifier.test_targets,
        ).cuda()
        mended = fill_graph_pytorch(base_graph, pred_missing, pred_feats, owner.feat_shape).cuda()

        classifier.pyg_data = mended
        # Cut the stored tensors loose from the generator's autograd graph.
        for key in ('x', 'edge_index', 'y', 'train_mask', 'test_mask'):
            setattr(classifier.pyg_data, key, getattr(classifier.pyg_data, key).detach())


def train_federated_classifier_graphsage(classifier_list):
    """Train the classifiers with FedAvg-style rounds and log per-round loss/accuracy to CSV.

    Each of ``config.epoch_classifier`` rounds runs these steps:

    * Every client is reset to the current global weights, taken from
      ``classifier_list[0]``.
    * Every client takes one Adam step on cross-entropy over its training mask.
    * The client weights are averaged with equal weight and copied into every
      model.

    Training loss and training-mask accuracy (measured before the step) are
    written to ``<config.save_path>/log/classifier_<i>_{loss,accuracy}.csv``.
    """
    num_clients = len(classifier_list)
    num_tensors = sum(1 for _ in classifier_list[0].model.parameters())
    optimizers = [optim.Adam(clf.model.parameters(), lr=config.lr) for clf in classifier_list]

    os.makedirs(os.path.join(config.save_path, 'log'), exist_ok=True)

    loss_history = {idx: [] for idx in range(num_clients)}
    acc_history = {idx: [] for idx in range(num_clients)}

    for _round in range(config.epoch_classifier):
        global_weights = [p.clone().detach() for p in classifier_list[0].model.parameters()]
        client_weights = []

        for idx, clf in enumerate(classifier_list):
            model, data, opt = clf.model, clf.pyg_data, optimizers[idx]

            # Start local training from the global model.
            with torch.no_grad():
                for param, g_w in zip(model.parameters(), global_weights):
                    param.copy_(g_w)

            model.train()
            opt.zero_grad()

            logits = model(data.x, data.edge_index)
            mask = data.train_mask
            hits = logits.argmax(dim=1)[mask] == data.y[mask]
            train_acc = float(hits.sum()) / int(mask.sum())

            loss = F.cross_entropy(logits[mask], data.y[mask])
            loss.backward()
            opt.step()

            loss_history[idx].append(loss.item())
            acc_history[idx].append(train_acc)
            client_weights.append([p.clone().detach() for p in model.parameters()])

        # Equal-weight average: sequential sum, then a single scale.
        averaged = [torch.zeros_like(w) for w in global_weights]
        for weights in client_weights:
            for k in range(num_tensors):
                averaged[k] += weights[k]
        scale = 1.0 / num_clients
        for k in range(num_tensors):
            averaged[k] *= scale

        for clf in classifier_list:
            with torch.no_grad():
                for param, a_w in zip(clf.model.parameters(), averaged):
                    param.copy_(a_w)

    for idx in range(num_clients):
        for metric, history in (('loss', loss_history), ('accuracy', acc_history)):
            path = os.path.join(config.save_path, f'log/classifier_{idx}_{metric}.csv')
            with open(path, 'w', newline='') as fh:
                writer = csv.writer(fh)
                writer.writerow(['epoch', metric])
                writer.writerows(enumerate(history[idx]))

def test_classifier_plus_graphsage(local_classifiers:list, local_owners:list, target_index=None):
    """Evaluate each owner's classifier on its generator-mended graph and save the results.

    For every owner in ``range(config.num_owners)``:

    * The owner's ``fed_model`` predicts missing-neighbour counts and features.
    * The subgraph is mended with ``fill_graph_pytorch``.
    * The classifier's test accuracy on the mended graph is recorded together
      with the number of added nodes.

    Both models run in eval mode under ``torch.no_grad()``, so dropout-style
    layers are deterministic and no autograd graph is kept. Their previous
    train/eval mode is restored afterwards. Results are written with
    ``save_experiment_results`` to ``<config.save_path>/log.json``.

    Args:
        local_classifiers: per-owner classifier wrappers.
        local_owners: per-owner data holders with a ``fed_model`` generator.
        target_index: currently unused.

    Returns:
        ``target_accs``. NOTE(review): this list is never populated, so the
        function always returns ``[]``; it is kept for interface compatibility.
    """
    global_accs, target_accs, client_results = [], [], []
    for i in range(config.num_owners):
        local_owner = local_owners[i]
        classifier = local_classifiers[i]

        gen_was_training = local_owner.fed_model.training
        clf_was_training = classifier.model.training
        local_owner.fed_model.eval()
        classifier.model.eval()
        try:
            with torch.no_grad():
                _, pred_missing, pred_feats, _ = local_owner.fed_model(
                    local_owner.all_feat, local_owner.edges, local_owner.adj)

                impaired_graph_pyg = convert_stellargraph_to_pyg(local_owner.subG, classifier.train_subjects,
                                                                 classifier.test_subjects, classifier.train_targets,
                                                                 classifier.test_targets).cuda()
                fillG_pyg = fill_graph_pytorch(impaired_graph_pyg, pred_missing, pred_feats,
                                               local_owner.feat_shape).cuda()

                out = classifier.model(fillG_pyg.x, fillG_pyg.edge_index)
        finally:
            # Leave the models in the mode the caller had them in.
            local_owner.fed_model.train(gen_was_training)
            classifier.model.train(clf_was_training)

        pred = out.argmax(dim=1)
        test_mask = fillG_pyg.test_mask
        test_correct = pred[test_mask] == fillG_pyg.y[test_mask]
        global_acc = float(test_correct.sum()) / int(test_mask.sum())
        print(f'global acc: {global_acc}')
        global_accs.append(global_acc)

        # Count added nodes without moving the whole graph off the GPU.
        added_len = int(torch.count_nonzero(fillG_pyg.added_mask))
        all_num_nodes = fillG_pyg.num_nodes
        avg_added = 0
        if added_len > 0 and all_num_nodes != added_len:
            avg_added = added_len / (all_num_nodes - added_len)

        client_results.append({'client_id': i, 'test_acc': round(global_acc, 4),
                               'added_all': added_len, 'added_avg': avg_added})

    # Client 0 is excluded from the printed average (presumably the attacker slot;
    # NOTE(review): in clean runs this still drops client 0 -- confirm intended).
    benign_accs = global_accs[1:]
    if benign_accs:
        avg_global_acc = round(sum(benign_accs) / len(benign_accs), 4)
        print(f'avg global acc: {avg_global_acc}')

    save_experiment_results(config, client_results, os.path.join(config.save_path, 'log.json'))
    return target_accs

