"""Runs a sanity check experiment to estimate adversarial transferability
from the first model to the second one.

To run, please specify:
1) The proxy (first) model checkpoint to load: --model_proxy=<path/to/checkpoint/file>
2) The target (second) model checkpoint to load: --model_target=<path/to/checkpoint/file>
3) Dataset name: --data_name=<dataset_name>
4) Adversarial attack algorithm: --attack_name=<algorithm>
5) Experiment group name: --experiment_group=<group>
6) Experiment name: --experiment_name=<experiment_name>

The models' backbone configurations are derived from the checkpoint names, so there is no
separate backbone-name flag. Optionally, pass --mode_eval_proxy and/or --mode_eval_target to
evaluate the proxy and/or target model on the test dataset.

Check README.md for more information.

The script computes the adversarial transferability score from model_proxy to model_target.
"""
import argparse
import logging
import os
from datetime import datetime

import torch
import torchvision.transforms as vision_transforms
from classy_vision.generic.util import load_checkpoint
from torch import Tensor
from vissl.models import build_model
from vissl.models.base_ssl_model import BaseSSLMultiInputOutputModel
from vissl.utils.checkpoint import init_model_from_consolidated_weights
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
from vissl.utils.logger import setup_logging, shutdown_logging

from data.get_mnist import ROOT, IDX2LABEL
from ats.utils import ProxyModel, set_seeds, show, get_config_from_name, find_final_model, get_random_key, name_parser
from ats.eval import evaluate
from ats.attacks import attack_transferability, get_data_and_attack

# Prefer the default CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def checkpoint_key_to_model_key(check_key: str, params_file: str) -> str:
    """Map a checkpoint trunk parameter name to the matching key in ``model.state_dict()``.

    Everything up to and including the first ``base_model.`` in ``check_key`` is dropped.
    The key is then re-prefixed with ``trunk.`` for "fulltune" checkpoints or with
    ``trunk.base_model.`` for "finetune" checkpoints, judged from ``params_file``. Other keys
    are returned unchanged.

    Args:
        check_key: Parameter name as stored in the checkpoint's trunk state dict.
        params_file: Checkpoint path whose name tells whether it is fulltuned or finetuned.

    Returns:
        The corresponding parameter name in the built model's state dict.
    """
    _, separator, tail = check_key.partition("base_model.")
    key = tail if separator else check_key
    # "fulltune" is checked first, mirroring the original precedence.
    if "fulltune" in params_file:
        return f"trunk.{key}"
    if "finetune" in params_file:
        return f"trunk.base_model.{key}"
    return key


def check_weights_loaded(weights: dict, model: torch.nn.Module, params_file: str) -> bool:
    """Spot-check that one random trunk tensor of ``weights`` was copied into ``model``.

    Args:
        weights: Checkpoint dict as returned by ``load_checkpoint``. Trunk tensors are read from
            ``weights["classy_state_dict"]["base_model"]["model"]["trunk"]``.
        model: The model the checkpoint was loaded into.
        params_file: Checkpoint path. It is used to derive the model-side key name and appears
            in log messages.

    Returns:
        True if the sampled tensor equals the model's tensor exactly. False otherwise; a warning
        is logged on a mismatch or when the key is absent from the model.
    """
    trunk_weights = weights["classy_state_dict"]["base_model"]["model"]["trunk"]
    check_key = get_random_key(trunk_weights)
    model_key = checkpoint_key_to_model_key(check_key, params_file)
    model_state = model.state_dict()

    if model_key not in model_state:
        logging.warning(f"Key {model_key!r} (from checkpoint key {check_key!r}) not found in model; "
                        f"checkpoint: {params_file}")
        logging.debug(f"Model keys: {list(model_state)}")
        return False

    # Exact element-wise comparison; comparing sums could miss permuted or offsetting values.
    expected = trunk_weights[check_key].detach().cpu()
    actual = model_state[model_key].detach().cpu()
    if not torch.equal(expected, actual):
        logging.warning(f"Weights are not equal; checkpoint: {params_file}")
        return False
    return True


def init_model(config: list, *, check_validity: bool = False):
    """Build a VISSL model from Hydra overrides and load its checkpoint weights.

    Args:
        config: Hydra configuration overrides passed to ``compose_hydra_configuration``.
        check_validity: If True, spot-check that a random trunk tensor from the checkpoint
            ended up in the built model. A warning is logged on mismatch.

    Returns:
        The loaded model wrapped in ``ProxyModel`` on the module-level ``device``.

    Raises:
        FileNotFoundError: If no checkpoint could be loaded from
            ``config.MODEL.WEIGHTS_INIT.PARAMS_FILE``.
        KeyError: If the checkpoint has no ``classy_state_dict`` entry.
    """
    cfg = compose_hydra_configuration(config)
    _, cfg = convert_to_attrdict(cfg)
    params_file = cfg.MODEL.WEIGHTS_INIT.PARAMS_FILE
    model = BaseSSLMultiInputOutputModel(cfg.MODEL, cfg.OPTIMIZER)

    weights = load_checkpoint(checkpoint_path=params_file)
    # classy_vision's load_checkpoint can return None (only logging a warning) for an empty or
    # missing path; fail loudly here instead of with an opaque AttributeError further down.
    if weights is None:
        raise FileNotFoundError(f"Could not load checkpoint: {params_file!r}")
    if "classy_state_dict" not in weights:
        raise KeyError(f"Checkpoint {params_file!r} has no 'classy_state_dict' entry")

    model.set_classy_state(weights["classy_state_dict"]["base_model"])

    if check_validity:
        check_weights_loaded(weights, model, params_file)

    return ProxyModel(model, device)

    

def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the transferability sanity-check experiment."""
    parser = argparse.ArgumentParser(description="Sanity check experiment",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model_proxy", required=True, action="store",
                        help="The relative path to the proxy model weights from ats/tuned_models/. "
                             "For example: 'ats/tuned_models/sanity_check_model_proxy/'")
    parser.add_argument("--model_target", required=True, action="store",
                        help="The relative path to the target model weights from ats/tuned_models/. "
                             "For example: 'ats/tuned_models/sanity_check_model_target/'")
    parser.add_argument("--data_name", required=True, action="store",
                        help="Dataset name. Supported datasets are ('mnist', ).")
    parser.add_argument("--attack_name", required=True, action="store",
                        help="Attack name. Supported attacks are ('fgsm', ).")
    parser.add_argument("--experiment_group", required=True, action="store",
                        help="Experiment group (directory name to be created if needed in ats/experiments/).")
    parser.add_argument("--experiment_name", required=True, action="store",
                        help="Experiment name in the experiment group.")
    parser.add_argument("--mode_eval_proxy", required=False, action="store_true",
                        help="Evaluate the performance of the proxy model on the test dataset")
    parser.add_argument("--mode_eval_target", required=False, action="store_true",
                        help="Evaluate the performance of the target model on the test dataset")
    return parser


def main() -> None:
    """Run the sanity check: load both models, attack the proxy, and log transferability."""
    timestamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    set_seeds(1997)  # Ensuring reproducibility.
    args = build_arg_parser().parse_args()

    # Creating the experiment directory. os.makedirs raises if it already exists, so two runs
    # never silently share (and overwrite) one results directory.
    checkpoint_dir = os.path.join("ats", "experiments", args.experiment_group,
                                  f"{args.experiment_name}_{timestamp}")
    os.makedirs(os.path.join(checkpoint_dir, "results"))

    # Setting up a logger; the finally clause guarantees it is shut down even on failure.
    setup_logging(__name__, output_dir=checkpoint_dir)
    try:
        logging.info(f"Model Proxy: {args.model_proxy}")
        logging.info(f"Model Target: {args.model_target}")

        # Resolving each checkpoint name into a VISSL configuration.
        c_proxy = name_parser(args.model_proxy)
        c_target = name_parser(args.model_target)
        cfg_proxy = get_config_from_name(c_proxy)
        cfg_target = get_config_from_name(c_target)

        # Initializing models in evaluation mode.
        model_proxy = init_model(cfg_proxy)
        model_proxy.eval()
        model_target = init_model(cfg_target)
        model_target.eval()

        # Initializing the attacking algorithm, dataloader and inverse transformation.
        atk, _, inverse_transform, dataloader = get_data_and_attack(
            model_proxy, data_name=args.data_name, attack_name=args.attack_name)

        if args.mode_eval_proxy:
            logging.info("Evaluating the proxy model on the test dataset.")
            correctly_classified = evaluate(model_proxy, dataloader, device)
            logging.info(f"Correctly classified: {correctly_classified}")

        if args.mode_eval_target:
            logging.info("Evaluating the target model on the test dataset.")
            correctly_classified = evaluate(model_target, dataloader, device)
            logging.info(f"Correctly classified: {correctly_classified}")

        attack_transferability_score = attack_transferability(
            model_proxy, model_target, dataloader, atk, device,
            checkpoint_dir, inverse_transform, IDX2LABEL)
        # Record the experiment's headline result in the run log.
        logging.info(f"Attack transferability score: {attack_transferability_score}")
    finally:
        shutdown_logging()


if __name__ == "__main__":
    main()
