import argparse
import logging
import os
from datetime import datetime

import torch
import torchvision.transforms as vision_transforms
from classy_vision.generic.util import load_checkpoint
from torch import Tensor
from vissl.models import build_model
from vissl.models.base_ssl_model import BaseSSLMultiInputOutputModel
from vissl.utils.checkpoint import init_model_from_consolidated_weights
from vissl.utils.hydra_config import compose_hydra_configuration, convert_to_attrdict
from vissl.utils.logger import setup_logging, shutdown_logging
from vissl.utils.checkpoint import CheckpointLoader

from ats.data.get_mnist import ROOT, IDX2LABEL
from ats.utils import ProxyModel, set_seeds, show, get_config_from_name, find_final_model, get_random_key, name_parser
from ats.eval import evaluate
from ats.attacks import attack_blackbox, get_data_and_attack

from collections import defaultdict

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _model_key_for(check_key: str, params_file: str) -> str:
    """Translate a checkpoint trunk key into the key used by ``model.state_dict()``.

    A leading ``base_model.`` segment is dropped from the checkpoint key, and a
    prefix is then chosen from the checkpoint file name: ``trunk.`` when it
    contains "fulltune", ``trunk.base_model.`` when it contains "finetune",
    and no prefix otherwise.
    """
    _, separator, remainder = check_key.partition('base_model.')
    key = remainder if separator else check_key

    if "fulltune" in params_file:
        return f"trunk.{key}"
    if "finetune" in params_file:
        return f"trunk.base_model.{key}"
    return key


def _report_weight_mismatch(model, classy_state, params_file: str) -> None:
    """Spot-check one trunk tensor of the checkpoint against the loaded model.

    Only the sums of the two tensors are compared, so this is a cheap sanity
    check rather than a proof of equality. A mismatch is logged as a warning;
    nothing is raised.
    """
    trunk_weights = classy_state["base_model"]["model"]["trunk"]
    # presumably returns one randomly chosen key of the mapping -- verify
    # against ats.utils.get_random_key.
    check_key = get_random_key(trunk_weights)
    model_key = _model_key_for(check_key, params_file)

    checkpoint_tensor = trunk_weights[check_key]
    model_tensor = model.state_dict()[model_key]

    if checkpoint_tensor.sum() != model_tensor.sum():
        logging.warning("Weights are not equal")
        logging.warning(params_file)


def init_model(config: list, *, check_validity: bool = False):
    """Initializes model from configs and loads its checkpoint weights.

    Args:
        config: Configuration arguments.
        check_validity: If True, compare one tensor of the checkpoint trunk
            with the corresponding tensor of the loaded model and log a
            warning when they differ. Defaults to False.
    Returns:
        Initialized model, wrapped in ``ProxyModel`` together with the
        module-level ``device``.
    Raises:
        FileNotFoundError: If no checkpoint could be loaded from the configured
            ``MODEL.WEIGHTS_INIT.PARAMS_FILE``.
        KeyError: If the checkpoint has no ``"classy_state_dict"`` entry.
    """
    config = compose_hydra_configuration(config)
    _, config = convert_to_attrdict(config)
    params_file = config.MODEL.WEIGHTS_INIT.PARAMS_FILE

    # The model class is instantiated directly and the weights are restored
    # through set_classy_state from the checkpoint's "base_model" entry.
    model = BaseSSLMultiInputOutputModel(config.MODEL, config.OPTIMIZER)
    weights = load_checkpoint(checkpoint_path=params_file)

    # Guard against a None result so that a bad path fails with a clear message
    # instead of an AttributeError further down.
    if weights is None:
        raise FileNotFoundError(f"Could not load a checkpoint from {params_file!r}")
    if "classy_state_dict" not in weights:
        raise KeyError(f"Checkpoint {params_file!r} has no 'classy_state_dict' entry")

    classy_state = weights["classy_state_dict"]
    model.set_classy_state(classy_state["base_model"])

    if check_validity:
        _report_weight_mismatch(model, classy_state, params_file)

    return ProxyModel(model, device)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the attack experiment."""
    parser = argparse.ArgumentParser(description="Sanity check experiment",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--model_target", required=True, action="store",
                        help="The relative path to the second model weights from ./vissl_tools/checkpoints/. For \
                        example: 'ats/tuned_models/sanity_check_model_target/'")
    parser.add_argument("--data_name", required=True, action="store",
                        help="Dataset name. Supported datasets are ('mnist', ).")
    parser.add_argument("--attack_name", required=True, action="store",
                        help="Attack name. Supported attacks are ('fgsm', ).")
    parser.add_argument("--experiment_group", required=True, action="store",
                        help="Experiment group (directory name to be created if needed in ./vissl_tools/experiments/).")
    parser.add_argument("--experiment_name", required=True, action="store",
                        help="Experiment name in the experiment group.")
    parser.add_argument("--mode_eval_target", required=False, action="store_true",
                        help="Evaluate the performance of the target model on the test dataset")
    parser.add_argument("--targeted", required=False, action="store_true",
                        help="Whether to perform a targeted attack or not.")
    parser.add_argument("--white", required=False, action="store_true",
                        help="Whether to perform a white-box attack or not.")
    return parser.parse_args()


def main() -> None:
    """Run the attack experiment described by the command-line arguments.

    Steps: seed the RNGs, create the experiment directory, build the target
    model, optionally evaluate it on the test data, and run ``attack_blackbox``
    with ``checkpoint_dir`` set to the experiment directory.
    """
    set_seeds(1997)  # Fixed seed for reproducibility.
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Experiment directory; "results" is created inside it.
    checkpoint_dir = os.path.join(f"ats/experiments/{args.experiment_group}/", f"{args.experiment_name}")
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(f"{checkpoint_dir}/results", exist_ok=True)

    setup_logging(__name__, output_dir=checkpoint_dir)
    try:
        logging.info(f"Model Target: {args.model_target}")

        # Resolve the model name into configuration arguments and build the model.
        cfg_target = get_config_from_name(name_parser(args.model_target))
        model_target = init_model(cfg_target)
        model_target.eval()

        # get_data_and_attack receives the composed configuration object,
        # not the raw configuration arguments.
        cfg_target = compose_hydra_configuration(cfg_target)
        _, cfg_target = convert_to_attrdict(cfg_target)

        # Attack algorithm, inverse transformation and dataloader; the forward
        # transform returned in second position is not needed here.
        data_name = args.data_name.lower().strip()
        atk, _, inverse_transform, dataloader = get_data_and_attack(
            model_target, data_name=data_name, attack_name=args.attack_name, cfg=cfg_target)

        if args.mode_eval_target:
            logging.info("Evaluating the target model on the test dataset.")
            correctly_classified = evaluate(model_target, dataloader, device)
            logging.info(f"Correctly classified: {correctly_classified}")

        # Indices missing from IDX2LABEL map to the placeholder "class_name".
        idx2label = defaultdict(lambda: "class_name", IDX2LABEL)
        # The return value of attack_blackbox is not used here.
        attack_blackbox(model=model_target, dataloader=dataloader, atk=atk, device=device,
                        checkpoint_dir=checkpoint_dir, inverse_transform=inverse_transform,
                        IDX2LABEL=idx2label, same_dataset=False, targeted=args.targeted,
                        white=args.white)
    finally:
        # Always release the logging handlers, even when the run fails.
        shutdown_logging()


if __name__ == "__main__":
    main()

# !python ats/attack_transferability_score.py --model_target "ats/tuned_models/sanity_check_model_target/" --data_name "mnist" --attack_name "fgsm" --experiment_group "sanity_check" --experiment_name "sanity_check"