#!/usr/bin/env python3
import argument
import composition
import dataset
import header
import json
import logger
import model
import os
import sklearn.metrics
import spn
import torch
import tqdm
import type
import utility
import wandb
import csv

# NEW: import necessary packages
import sys
sys.path.append('../adversarial-attacks-pytorch')
from torchattacks import PGD, PGDL2, CW, CWBS


def test(model_decomposed, spn_joint, spn_marginal, spn_settings_joint, data_loader, device, batch_step, marginal_probabilities_counted = None):
    utility.loadCheckpointBest(header.config_decomposed["dir_checkpoints"], header.config_decomposed["file_name_checkpoint_best"], model_decomposed)
    utility.loadCheckpointBestSPN(spn_joint, header.config_spn["dir_checkpoints"], header.config_spn["file_name_checkpoint_best"])
    utility.loadCheckpointBestSPN(spn_marginal, header.config_spn["dir_checkpoints"], header.config_spn["file_name_checkpoint_best"])
    csv_results = {}

    # NEW for robust model
    # get each class's centers and neighbors
    with open(header.config_decomposed["file_path_neighbor_config"], 'r') as file:
        center_neighbor = json.load(file)
    # get all combinations/nodes
    # e.g.,
    # tensor([[0., 0., 0.],
    #         [0., 0., 1.],
    #         [0., 0., 2.],
    #         ...,
    #         [9., 9., 7.],
    #         [9., 9., 8.],
    #         [9., 9., 9.]], device='cuda:0')
    all_nodes = spn_settings_joint[spn_settings_joint[:, -1] == 0][:, :-1]

    accuracy_epoch_composed = 0
    accuracy_epoch_list_decomposed = []
    config_dataset = data_loader.dataset.config
    ground_truths_epoch_composed = []
    ground_truths_epoch_list_decomposed = []
    mpes = {}
    output_list_composed = []
    output_list_decomposed = []
    output_list_decomposed_accuracy = []
    output_list_decomposed_precision = []
    output_list_decomposed_recall = []
    predictions_epoch_composed = []
    predictions_epoch_list_decomposed = []
    progress_bar = tqdm.tqdm(total = len(data_loader), position = 0, leave = False)
    spn_output_rows = len(data_loader.dataset.classes_original)
    spn_output_cols = 1

    for attribute in config_dataset["attributes"]:
        attribute_labels = attribute["labels"]
        spn_output_cols *= len(attribute_labels)
        accuracy_epoch_list_decomposed.append(0)
        ground_truths_epoch_list_decomposed.append([])
        predictions_epoch_list_decomposed.append([])

    model_decomposed.eval()
    progress_bar.set_description_str("[INFO]: Testing progress")

    # NEW: configure attack ids and type
    attack_ids = torch.tensor(header.config_decomposed["attack_ids"]).to(device, non_blocking=True)

    if header.config_decomposed["attack_name"] == "pgd":
        atk = PGD(model_decomposed, eps=header.config_decomposed["attack_bound"], alpha=2/255, steps=50, random_start=True)
    elif header.config_decomposed["attack_name"] == "pgdl2":
        atk = PGDL2(model_decomposed, eps=header.config_decomposed["attack_bound"], alpha=header.config_decomposed["attack_bound"]/10, steps=50, random_start=True)
    elif header.config_decomposed["attack_name"] == "cw":
        atk = CWBS(model_decomposed, init_c=1.0, kappa=0, steps=10, lr=0.01, binary_search_steps=int(header.config_decomposed["attack_bound"]))
    else:
        raise ValueError(f"Unknown attack name: {header.config_decomposed['attack_name']}")

    if header.config_decomposed["targeted_attack"]:
        atk.set_mode_targeted_by_label() # New for targeted attack
    print(atk)
    adv_norm_list = []

    # NEW: set grad to be True
    with torch.set_grad_enabled(True):
        for (batch_index, (input, labels_decomposed, labels_original, input_file_paths)) in enumerate(data_loader):
            input = input.to(device, non_blocking = True)
            labels_decomposed = labels_decomposed.to(device, non_blocking = True)
            labels_original = labels_original.to(device, non_blocking = True)

            # NEW: add perturbations on inputs
            if header.config_decomposed["targeted_attack"]:
                target_labels = labels_decomposed.clone()
                target_labels[:, attack_ids[0]] = 0
                if header.config_decomposed["attack_name"] == "cw":
                    adv_images = atk(input, target_labels, attack_ids)
                    original_flat = input.view(input.size(0), -1)
                    perturbed_flat = adv_images.view(adv_images.size(0), -1)
                    adv_norm = (torch.sqrt(torch.sum((original_flat - perturbed_flat) ** 2, dim=1))).mean().item()
                else:
                    adv_images = atk(input, target_labels, attack_ids)
                    adv_norm = header.config_decomposed["attack_bound"]
            else:
                if header.config_decomposed["attack_name"] == "cw":
                    adv_images = atk(input, labels_decomposed, attack_ids)
                    original_flat = input.view(input.size(0), -1)
                    perturbed_flat = adv_images.view(adv_images.size(0), -1)
                    adv_norm = (torch.sqrt(torch.sum((original_flat - perturbed_flat) ** 2, dim=1))).mean().item()
                else:
                    adv_images = atk(input, labels_decomposed, attack_ids)
                    adv_norm = header.config_decomposed["attack_bound"]
            adv_norm_list.append(adv_norm)

            (outputs_decomposed, _) = model_decomposed(adv_images)  # len(outputs_decomposed)=3, outputs_decomposed[0].shape=torch.Size([384, 10])

            outputs_decomposed = utility.applySoftmaxDecomposed(outputs_decomposed)

            for (i, dataset_entry) in enumerate(config_dataset["attributes"]):
                (_, predictions_decomposed) = torch.max(outputs_decomposed[i], 1)

                corrects_decomposed = torch.sum(predictions_decomposed == labels_decomposed[:, i].data).item()
                accuracy_batch_decomposed = corrects_decomposed / input.size(0)
                accuracy_epoch_list_decomposed[i] += corrects_decomposed

                wandb.log({"testing/batch/" + dataset_entry["name"] + "/accuracy": accuracy_batch_decomposed})

                ground_truths_epoch_list_decomposed[i] += labels_decomposed[:, i].data.tolist()
                predictions_epoch_list_decomposed[i] += predictions_decomposed.tolist()

            (matrix_a, matrix_b, _) = composition.Composition.spn(outputs_decomposed, spn_joint, spn_marginal, spn_output_rows, spn_output_cols, device, marginal_probabilities_counted)

            outputs_composed_list = []

            # NEW for robust model
            for class_name in os.listdir(header.config_decomposed['dir_dataset_test']):
                center_list = center_neighbor[class_name]["center"]  # e.g., [0,0,0]
                centers = torch.tensor(center_list, dtype=matrix_a.dtype, device=matrix_a.device)
                centers_mask = (all_nodes.unsqueeze(1) == centers).all(-1).any(1)

                neighbor_list = center_neighbor[class_name]["neighbor"]  # e.g. [[0,0,0], [0,0,1], ...]
                neighbors = torch.tensor(neighbor_list, dtype=matrix_a.dtype, device=matrix_a.device)
                neighbors_mask = (all_nodes.unsqueeze(1) == neighbors).all(-1).any(1)

                vector_a = torch.sum(matrix_a[:, centers_mask], dim=1, keepdim=True)  # shape: torch.Size([10,1])
                vector_b = torch.sum(matrix_b[neighbors_mask, :], dim=0, keepdim=True)  # shape: torch.Size([1, 384])
                matrix_c = torch.matmul(vector_a, vector_b).t()  # shape: torch.Size([384, 10])

                outputs_composed_list.append(matrix_c)

            stacked_matrices = torch.stack(outputs_composed_list)  # Shape: [num_matrices, 384, 10]
            outputs_composed_unnml = torch.sum(stacked_matrices, dim=0)  # Shape: [384, 10]
            partitions = outputs_composed_unnml.sum(dim=1, keepdim=True)  # Shape: [384, 1]
            outputs_composed = outputs_composed_unnml / partitions

            (_, predictions_composed) = torch.max(outputs_composed, 1)

            corrects_composed = torch.sum(predictions_composed == labels_original.data).item()

            accuracy_batch_composed = corrects_composed / input.size(0)
            accuracy_epoch_composed += corrects_composed

            # mpe_attributes = utility.findMPEs(matrix_a, matrix_b, spn_settings_joint, labels_original)
            # utility.saveMPEs(input_file_paths, mpes, mpe_attributes, data_loader.dataset, labels_decomposed, labels_original, outputs_decomposed, outputs_composed)

            progress_bar.n = batch_index + 1
            progress_bar.refresh()

            wandb.log({"testing/batch/accuracy": accuracy_batch_composed})
            wandb.log({"testing/batch/step": batch_step})

            ground_truths_epoch_composed += labels_original.data.tolist()
            predictions_epoch_composed += predictions_composed.tolist()

            batch_step += 1

    progress_bar.close()

    csv_results["attack_name"] = header.config_decomposed["attack_name"]
    csv_results["attack_ids"] = " ".join(str(id) for id in header.config_decomposed["attack_ids"])
    csv_results["attack_norm"] = round(sum(adv_norm_list) / len(adv_norm_list), 2)

    # if not os.path.isdir(header.dir_output_mpe):
    #     os.makedirs(header.dir_output_mpe, exist_ok = True)
    #
    # with open(os.path.join(header.dir_output_mpe, header.file_name_mpe), "w") as file_mpe:
    #     json.dump(mpes, file_mpe, indent = 4)

    for (i, dataset_entry) in enumerate(config_dataset["attributes"]):
        accuracy_epoch_list_decomposed[i] /= len(data_loader.dataset)
        precision_epoch_decomposed = sklearn.metrics.precision_score(ground_truths_epoch_list_decomposed[i], predictions_epoch_list_decomposed[i], average = "macro", zero_division = 0)
        recall_epoch_decomposed = sklearn.metrics.recall_score(ground_truths_epoch_list_decomposed[i], predictions_epoch_list_decomposed[i], average = "macro", zero_division = 0)

        output_list_decomposed_accuracy.append(accuracy_epoch_list_decomposed[i])
        output_list_decomposed_precision.append(precision_epoch_decomposed)
        output_list_decomposed_recall.append(recall_epoch_decomposed)

        wandb.log({"testing/epoch/" + dataset_entry["name"] + "/accuracy": accuracy_epoch_list_decomposed[i]})
        wandb.log({"testing/epoch/" + dataset_entry["name"] + "/precision": precision_epoch_decomposed})
        wandb.log({"testing/epoch/" + dataset_entry["name"] + "/recall": recall_epoch_decomposed})
        wandb.summary["testing/epoch/" + dataset_entry["name"] + "/accuracy"] = accuracy_epoch_list_decomposed[i]
        wandb.summary["testing/epoch/" + dataset_entry["name"] + "/precision"] = precision_epoch_decomposed
        wandb.summary["testing/epoch/" + dataset_entry["name"] + "/recall"] = recall_epoch_decomposed

        logger.log_info("Decomposed testing accuracy for \"" + dataset_entry["name"] + "\": " + str(accuracy_epoch_list_decomposed[i]) + ".")
        logger.log_trace("Decomposed testing precision for \"" + dataset_entry["name"] + "\": " + str(precision_epoch_decomposed) + ".")
        logger.log_trace("Decomposed testing recall for \"" + dataset_entry["name"] + "\": " + str(recall_epoch_decomposed) + ".")
        csv_results[dataset_entry["name"]] = accuracy_epoch_list_decomposed[i]

    accuracy_epoch_composed /= len(data_loader.dataset)

    precision_epoch_composed = sklearn.metrics.precision_score(ground_truths_epoch_composed, predictions_epoch_composed, average = "macro", zero_division = 0)
    recall_epoch_composed = sklearn.metrics.recall_score(ground_truths_epoch_composed, predictions_epoch_composed, average = "macro", zero_division = 0)
    output_list_composed += [accuracy_epoch_composed, precision_epoch_composed, recall_epoch_composed]
    output_list_decomposed += output_list_decomposed_accuracy
    output_list_decomposed += output_list_decomposed_precision
    output_list_decomposed += output_list_decomposed_recall

    wandb.log({"testing/epoch/accuracy": accuracy_epoch_composed})
    wandb.log({"testing/epoch/precision": precision_epoch_composed})
    wandb.log({"testing/epoch/recall": recall_epoch_composed})
    wandb.summary["testing/epoch/accuracy"] = accuracy_epoch_composed
    wandb.summary["testing/epoch/precision"] = precision_epoch_composed
    wandb.summary["testing/epoch/recall"] = recall_epoch_composed

    logger.log_info("Composed testing accuracy: " + str(accuracy_epoch_composed) + ".")
    logger.log_trace("Composed testing precision: " + str(precision_epoch_composed) + ".")
    logger.log_trace("Composed testing recall: " + str(recall_epoch_composed) + ".")
    csv_results["class"] = accuracy_epoch_composed

    utility.logTestOutput(output_list_composed, header.config_baseline, config_dataset, True)
    utility.logTestOutput(output_list_decomposed, header.config_decomposed, config_dataset)

    # Write csv results
    fieldnames = csv_results.keys()
    csv_dir = os.path.join(header.config_decomposed["dir_results"], "test_composed_robust_attack")
    os.makedirs(csv_dir, exist_ok=True)
    csv_file = os.path.join(csv_dir, f"{csv_results['attack_name']}_{csv_results['attack_ids']}_{header.config_decomposed['attack_bound']}.csv")
    with open(csv_file, mode="w", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerow(csv_results)

    return

def main():
    """Script entry point: configure from the command line, build the pipeline, run ``test``.

    Steps:
    - parse arguments into ``header``;
    - copy decomposed-run settings into ``header.config_baseline`` (later passed to
      ``utility.logTestOutput`` for the composed results);
    - seed, set the TF32 matmul flag, and start a disabled wandb run;
    - build the test loader, the DataParallel decomposed model and the joint/marginal SPNs;
    - evaluate everything under attack.
    """
    argument.processArgumentsTestComposed()

    # NOTE(review): run_name_baseline comes from config_decomposed["run_name"], while the
    # baseline checkpoint/run names below use header.run_name_decomposed; confirm intended.
    header.run_name_baseline = header.config_decomposed["run_name"]
    header.config_baseline.update({
        "dir_dataset_test": header.config_decomposed["dir_dataset_test"],
        "file_name_checkpoint": header.run_name_decomposed + ".tar",
        "file_name_checkpoint_best": header.run_name_decomposed + ".best.tar",
        "run_name": header.run_name_decomposed,
    })

    utility.setSeed(header.seed)
    torch.backends.cuda.matmul.allow_tf32 = header.cuda_allow_tf32

    wandb.init(config = {"decomposed": header.config_decomposed, "spn": header.config_spn}, mode = "disabled")

    transforms_test = utility.createTransform(header.config_decomposed)
    test_set = dataset.VISATDataset(header.config_decomposed["dir_dataset_test"], transforms_test)
    test_set_config = test_set.config
    loader = torch.utils.data.DataLoader(
        test_set,
        batch_size = header.config_decomposed["data_loader_batch_size"],
        shuffle = False,
        num_workers = header.config_decomposed["data_loader_worker_count"],
        pin_memory = True,
    )
    cuda_device = torch.device("cuda")
    counted_marginals = None
    net = torch.nn.DataParallel(model.createModelDecomposed(cuda_device)).to(cuda_device)
    joint = spn.SPN(cuda_device)
    marginal = spn.SPN(cuda_device)

    if header.config_spn["optimizer"] == type.OptimizerSPN.cccp_discriminative.name:
        (counted_marginals, _) = utility.countAttributeJointProbabilities(test_set_config, cuda_device)

    spn_path = header.config_spn["file_path_spn"]
    logger.log_info(f"Loading SPN from \"{spn_path}\"...")
    for network in (joint, marginal):
        network.load(spn_path)

    logger.log_info("Loading SPN leaf node settings...")

    joint_settings = utility.generateSPNSettings(test_set_config, cuda_device)
    # The marginal SPN gets the same settings with the last column forced to -1.
    marginal_settings = joint_settings.clone()
    marginal_settings[:, -1] = -1

    logger.log_info("Setting SPN leaf nodes...")

    for (network, settings) in ((joint, joint_settings), (marginal, marginal_settings)):
        network.set_leaf_nodes(settings)

    if header.show_model_summary:
        logger.log_info(f"Number of nodes: {len(joint.nodes)}.")
        logger.log_info(f"Number of sum nodes: {len(joint.sum_nodes)}.")
        logger.log_info(f"Number of product nodes: {len(joint.product_nodes)}.")
        logger.log_info(f"Number of leaf nodes: {len(joint.leaf_nodes)}.")
        logger.log_info(f"SPN depths: {joint.depth}.")
        logger.log_info(f"SPN leaf node setting dimension: ({int(joint_settings.shape[0])}, {int(joint_settings.shape[1])}).")

    test(net, joint, marginal, joint_settings, loader, cuda_device, 1, counted_marginals)


if __name__ == "__main__":
    # Run the attack evaluation only when executed as a script.
    main()
