"""Runs a sanity check experiment to estimate adversarial transferability
from the first model to the second one.

To run, please, specify
1) The first model checkpoint to load: --model_proxy=<path/to/checkpoint/file>
2) The second model checkpoint to load: --model_target=<path/to/checkpoint/file>
3) Dataset name: --data_name=<dataset_name>
4) The first model's backbone name: --model_name=<model_name>
5) Adversarial attack algorithm: --attack_name=<algorithm>
6) Experiment group name: --experiment_group=<group>
7) Experiment name: --experiment_name=<experiment_name>

Check README.md for more information.

The script will compute adversarial transferability score from model_proxy to model_target.
"""
import argparse
import logging
import os
from datetime import datetime

import torch
import torchvision.transforms as vision_transforms
from classy_vision.generic.util import load_checkpoint
from torch import Tensor
from vissl.models import build_model
from vissl.models.base_ssl_model import BaseSSLMultiInputOutputModel
from vissl.utils.checkpoint import init_model_from_consolidated_weights
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
from vissl.utils.logger import setup_logging, shutdown_logging

from data.get_mnist import ROOT, IDX2LABEL
from ats.utils import ProxyModel, set_seeds, show
from ats.eval import evaluate
from ats.attacks import attack_blackbox, get_data_and_attack

# Hydra overrides applied when the model is fine-tuned: evaluation mode is on,
# FREEZE_TRUNK_ONLY is set and both trunk and head are evaluated.
cfg_finetune = [
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=True",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=True",
]

# Hydra overrides applied when the model is fulltuned: evaluation mode is on
# and every freeze / trunk-only / trunk-and-head flag is switched off.
cfg_fulltune = [
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_MODE_ON=True",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_ONLY=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.FREEZE_TRUNK_AND_HEAD=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EXTRACT_TRUNK_FEATURES_ONLY=False",
    "config.MODEL.FEATURE_EVAL_SETTINGS.EVAL_TRUNK_AND_HEAD=False",
]


def init_model(config: list, is_fulltuned: bool, device=None):
    """Initializes model from configs.
    Args:
        config: Hydra override strings (config file selector and "key=value" overrides).
            Must resolve config.MODEL.WEIGHTS_INIT.PARAMS_FILE to the checkpoint to load.
        is_fulltuned: True if a model is fulltuned (including backbone) otherwise False.
            Not read by this function; kept for interface compatibility.
        device: Device handed to ProxyModel. When None, "cuda:0" is used if CUDA is
            available and "cpu" otherwise.
    Returns:
        The model with checkpoint weights restored, wrapped in ProxyModel.
    Raises:
        FileNotFoundError: If no checkpoint could be loaded from PARAMS_FILE.
        KeyError: If the checkpoint lacks the "classy_state_dict" or "base_model" entry.
    """
    # Resolve the device here instead of relying on a module-level global, which only
    # exists when this file is executed as a script.
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    _, hydra_cfg = convert_to_attrdict(compose_hydra_configuration(config))
    hydra_cfg["MODEL"]["WEIGHTS_INIT"]["APPEND_PREFIX"] = "trunk._feature_blocks."

    # If classification head is changed, then define its structure by uncommenting the line below.
    # hydra_cfg["MODEL"]["HEAD"]["PARAMS"] = [["eval_mlp", {"in_channels": 2048, "dims": [2048, 1024, 10]}], ]

    model = BaseSSLMultiInputOutputModel(hydra_cfg.MODEL, hydra_cfg.OPTIMIZER)

    params_file = hydra_cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE
    weights = load_checkpoint(checkpoint_path=params_file)
    # NOTE(review): load_checkpoint is expected to return None (not raise) when the path
    # is empty or missing; fail here with a clear message instead of an AttributeError.
    if weights is None:
        raise FileNotFoundError(f"Could not load a checkpoint from '{params_file}'.")

    try:
        base_model_state = weights["classy_state_dict"]["base_model"]
    except KeyError as err:
        raise KeyError(f"Checkpoint '{params_file}' has no {err} entry.") from err
    model.set_classy_state(base_model_state)

    return ProxyModel(model, device)


def _build_arg_parser():
    """Builds the command-line parser declaring every option this script reads.
    Returns:
        The configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(description="Sanity check experiment",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model", required=True, action="store",
                        help="The relative path to the first model weights from ./vissl_tools/checkpoints/. For "
                             "example: 'sanity_check_model_proxy/model_final_checkpoint_phase0.torch'")
    parser.add_argument("--model_extra_config", required=False, action="store",
                        help="The relative path to the extra config for loading proxy: file from congifs/config/models/")
    parser.add_argument("--model_is_fulltuned", action="store_true",
                        help="Wheather the second model is fulltuned or not. False is not specified.")
    parser.add_argument("--data_name", required=True, action="store",
                        help="Dataset name. Supported datasets are ('mnist', ).")
    parser.add_argument("--model_name", required=True, action="store",
                        help="First model name. Supported models are ('resnet50', 'deit' ).")
    parser.add_argument("--attack_name", required=True, action="store",
                        help="Attack name. Supported attacks are ('fgsm', etc.).")
    parser.add_argument("--experiment_group", required=True, action="store",
                        help="Experiment group (directory name to be created if needed in ./vissl_tools/experiments/).")
    parser.add_argument("--experiment_name", required=True, action="store",
                        help="Experiment name in the experiment group.")
    parser.add_argument("--model_eval", required=False, action="store_true", default=False,
                        help="Evaluate the performance of the proxy model on the test dataset")
    parser.add_argument("--blackbox_budget", required=False, action="store", type=int, default=1000,
                        help="Number of queries to the blackbox model.")
    return parser


def _select_config_filename(data_name, model_name):
    """Maps a (dataset, backbone) pair to its hydra config selector.
    Args:
        data_name: Dataset name given on the command line.
        model_name: Backbone name given on the command line.
    Returns:
        The "config=<file>.yaml" override for the pair.
    Raises:
        ValueError: If the pair has no known config.
    """
    known_configs = {
        ("mnist", "resnet50"): "config=eval_resnet_8gpu_transfer_mnist_linear.yaml",
        ("cifar10", "deit"): "config=eval_resnet_8gpu_transfer_cifar10_linear.yaml",
    }
    try:
        return known_configs[(data_name, model_name)]
    except KeyError:
        raise ValueError(
            f"config is unknown for the given dataset '{data_name}' and model '{model_name}'.") from None


if __name__ == "__main__":
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    set_seeds(5)  # Ensuring reproducibility.
    args = _build_arg_parser().parse_args()

    # Creating checkpoint directory.
    # `device` is deliberately kept as a module-level name so it stays visible to the rest of the file.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    root = f"ats/experiments/{args.experiment_group}/"
    models_root = "ats/tuned_models/"
    checkpoint_dir = os.path.join(root, f"{args.experiment_name}_{timestamp}")
    # exist_ok avoids a crash if two runs with the same name start within the same second.
    os.makedirs(f"{checkpoint_dir}/results", exist_ok=True)

    # Setting up a logger.
    setup_logging(__name__, output_dir=checkpoint_dir)
    try:
        logging.info(f"Model For Blackbox attack: {args.model}")

        # Getting a config file and filling in the required information.
        config_filename = _select_config_filename(args.data_name, args.model_name)

        # Copy the template so the module-level lists are never modified.
        cfg = cfg_fulltune[:] if args.model_is_fulltuned else cfg_finetune[:]

        if args.model_extra_config:
            # Turn "<dir>/<name>.<ext>" into the hydra override "+<dir>=<name>".
            path_to_file, file_name = os.path.split(args.model_extra_config)
            cfg = cfg + [f"+{path_to_file}={file_name.split('.')[0]}"]
            logging.info(f"Config overrides with the extra model config: {cfg}")

        cfg = cfg + [f"config.MODEL.WEIGHTS_INIT.PARAMS_FILE={os.path.join(models_root, args.model)}",
                     config_filename, f"config.CHECKPOINT.DIR={checkpoint_dir}"]

        # Initializing models.
        model = init_model(cfg, args.model_is_fulltuned)
        model.eval()

        # Initializing attacking algorithm, dataloaders and transformations.
        atk, transform, inverse_transform, dataloader = get_data_and_attack(model, data_name=args.data_name,
                                                                            model_name=args.model_name,
                                                                            attack_name=args.attack_name)

        if args.model_eval:
            logging.info("Evaluating the proxy model on the test dataset.")
            correctly_classified = evaluate(model, dataloader, device)
            logging.info(f"Correctly classified: {correctly_classified}")

        # NOTE(review): args.blackbox_budget is parsed but never forwarded to the attack;
        # confirm whether attack_blackbox should receive it.
        attack_transferability_score = attack_blackbox(model, dataloader, atk, device,
                                                       checkpoint_dir, inverse_transform, IDX2LABEL)
        # Assumes attack_blackbox returns the score it computed - TODO confirm.
        logging.info(f"Attack transferability score: {attack_transferability_score}")
    finally:
        # Always close the log handlers, even when the experiment fails midway.
        shutdown_logging()
