import os
import torch
import logging
from tqdm import tqdm
from torch import Tensor
from torch.nn import functional as F
from torchattacks import FGSM
from torch.utils.data import DataLoader
from ats.data.get_mnist import ROOT, IDX2LABEL
from torchvision.datasets import ImageFolder
from ats.attacks_meta  import get_attack_instance
import torchvision.transforms as vision_transforms
from ats.utils import show, entropy, cross_entropy
from vissl.data.ssl_transforms.mnist_img_pil_to_rgb_mode import MNISTImgPil2RGB
from vissl.data.ssl_transforms.img_pil_to_lab_tensor import ImgPil2LabTensor
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, OxfordIIITPet, Flowers102, INaturalist
from vissl.trainer.train_task import SelfSupervisionTask
from vissl.data import (
    build_dataloader,
    build_dataset,
    print_sampler_config,
)
from vissl.data.ssl_transforms import get_transform


import pickle
import gzip
import numpy as np
from collections import defaultdict

# Module-level logger: emits INFO and above through its own stream handler.
logger = logging.getLogger(__name__)
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.INFO)


def check_normalization(cfg):
    """Report whether the TEST transform list of ``cfg`` contains a normalization step.

    A transform entry counts as normalization when its ``name`` contains the
    substring "normal" (case-insensitive).

    Args:
        cfg: Mapping exposing ``cfg["DATA"]["TEST"]["TRANSFORMS"]``, an iterable
            of mappings that each carry a ``name`` key.

    Returns:
        True as soon as one matching transform is found, otherwise False.
    """
    test_transforms = cfg["DATA"]["TEST"]["TRANSFORMS"]
    return any('normal' in spec['name'].lower() for spec in test_transforms)


def get_data_and_attack(model, data_name: str, attack_name: str = "fgsm", attack_meta_path: str = None, cfg: dict = None):
    """Build the TEST dataset, its transforms, a dataloader and the adversarial attack.

    Supported datasets: 'mnist', 'cifar10', 'cifar100', 'oxford_flowers',
    'oxford_pets', 'inaturalist2018', 'imagenet1k'.

    Args:
        model: Model handed to ``get_attack_instance`` to build the attack.
        data_name: Dataset name (one of the supported names above).
        attack_name: Attack identifier forwarded to ``get_attack_instance``.
        attack_meta_path: Forwarded unchanged to ``get_attack_instance``.
        cfg: Config passed to ``build_dataset``; ``cfg["DATA"]["TEST"]["TRANSFORMS"]``
            is inspected to decide whether the attack must account for normalization.

    Returns:
        Tuple ``(attack, data_transform, inverse_data_transform, dataloader)``:
            attack: Attack instance (None if a RuntimeError occurred before it was built).
            data_transform: torchvision ``Compose`` of the dataset's transforms.
            inverse_data_transform: ``Compose`` undoing the mean/std normalization (for visualization).
            dataloader: ``DataLoader`` over the TEST dataset, or None (see note below).

    Raises:
        ValueError: If ``data_name`` is not supported.
    """
    # Datasets for which a DataLoader is constructed here.
    # NOTE(review): 'mnist' and 'imagenet1k' get no DataLoader from this function
    # (dataloader stays None), exactly as before -- confirm callers build their own.
    datasets_with_loader = ("cifar10", "cifar100", "oxford_flowers", "oxford_pets", "inaturalist2018")
    supported_datasets = ("mnist", "imagenet1k") + datasets_with_loader

    # Fail fast, before any dataset is built.
    if data_name not in supported_datasets:
        raise ValueError(f"Dataset '{data_name}' is not supported.")

    attack, inverse_data_transform, dataloader = None, None, None

    dataset = build_dataset(cfg=cfg,
                split="TEST",
                current_train_phase_idx=0,
            )

    # Re-wrap the dataset's transform list in a torchvision Compose.
    data_transform = vision_transforms.Compose(dataset.transform.transforms)

    try:
        # Every supported dataset uses the same (ImageNet) channel statistics, so they
        # are defined once. Previously 'imagenet1k' left them undefined and crashed.
        mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])

        # Inverse of Normalize(mean, std): first multiply by std, then add mean.
        inverse_data_transform = vision_transforms.Compose([vision_transforms.Normalize(mean=[0., 0., 0.], std=1.0/std),
                                                            vision_transforms.Normalize(mean=-mean, std=[1., 1., 1.])])

        if data_name in datasets_with_loader:
            dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=16, pin_memory=True)

        # Define the attack method and let the attack algorithm know when
        # normalization is used as a data transformation.
        attack = get_attack_instance(attack_name, model, attack_meta_path)

        if check_normalization(cfg=cfg):
            attack.set_normalization_used(mean=mean.tolist(), std=std.tolist())
        attack.set_return_type(type="float")
    except RuntimeError as e:
        # Best effort: report the failure and return whatever was built so far.
        logging.error(f"Error in getting data and attack: {e}")
    return attack, data_transform, inverse_data_transform, dataloader


def attack_transferability(model_proxy, model_target, dataloader, atk, device, checkpoint_dir,
                           inverse_transform, IDX2LABEL, same_dataset=False, targeted=False):
    """Measure how often adversarial examples produced by ``atk`` also fool ``model_target``.

    For every batch the clean images are classified by both models, adversarial
    images are produced with ``atk``, and entropy / cross-entropy statistics of
    ``model_target``'s softmax outputs are collected. Entropies and summary
    counts are pickled to ``{checkpoint_dir}/results/``.

    Args:
        model_proxy: Model whose clean predictions provide the attack labels in
            the untargeted case and, with ``same_dataset``, the sample filter.
        model_target: Model on which transferability is evaluated.
        dataloader: Iterable of batches; each batch is a mapping with ``'data'``
            and ``'label'`` entries whose first element is the tensor to use.
        atk: Callable ``atk(images, labels) -> adv_images``; must expose
            ``targeted`` and ``set_mode_targeted_random()`` when ``targeted`` is True.
            (Presumably bound to ``model_proxy`` -- confirm at the call site.)
        device: Device the batch tensors are moved to.
        checkpoint_dir: Directory under which ``results/`` is written.
        inverse_transform: Callable applied to image tensors before plotting.
        IDX2LABEL: Mapping from class index to label; unknown indices fall back
            to ``"class_name"``.
        same_dataset: If True, only samples ``model_proxy`` classifies correctly
            are attacked, and only attacks that fooled ``model_proxy`` are counted.
        targeted: If True, random target labels are used and the attack is
            switched to its random-target mode.

    Returns:
        Transferability rate in percent (0.0 when nothing could be counted), or
        None if a RuntimeError occurred.
    """
    n_adv_examples, n_examples, total_attacks = 0, 0, 0
    logging.info(f"Running adversarial transferability experiment...\n")

    # Make sure the output directory exists before anything is written to it.
    results_dir = f"{checkpoint_dir}/results"
    os.makedirs(results_dir, exist_ok=True)

    etropies = {
        "original_entropy": [],
        "adv_entropy": [],
        "cross_entropy_orig_to_adv": [],
        "cross_entropy_adv_to_orig": [],
    }
    try:
        IDX2LABEL = defaultdict(lambda: "class_name", IDX2LABEL)
        for batch_idx, batch in enumerate(tqdm(dataloader)):
            images, labels = batch['data'], batch['label']
            images, labels = images[0].to(device), labels[0].to(device)

            original_predict_target, proxy_predict_target = model_target(images), model_proxy(images)

            if same_dataset:
                # Keep only the samples model_proxy classifies correctly. The clean
                # predictions must be filtered with the same mask, otherwise their
                # batch size no longer matches images/labels further down.
                correctly_classified = proxy_predict_target.argmax(dim=1) == labels
                images, labels = images[correctly_classified], labels[correctly_classified]
                original_predict_target = original_predict_target[correctly_classified]
                proxy_predict_target = proxy_predict_target[correctly_classified]
                if images.shape[0] == 0:
                    # Nothing left to attack in this batch.
                    continue

            if targeted:
                output_shape_proxy = proxy_predict_target.shape
                labels_transferred = torch.randint(0, output_shape_proxy[1], (output_shape_proxy[0],)).to(device)

                # Per the original author's note: the attack's random-target mode
                # excludes the labels it is given when drawing a target. Feeding
                # random labels instead of the ground truth therefore allows the
                # drawn target to be the true class as well.
                if not atk.targeted:
                    atk.set_mode_targeted_random()
            else:
                labels_transferred = proxy_predict_target.argmax(dim=1).to(device)

            adv_images = atk(images, labels_transferred)
            total_attacks += adv_images.shape[0]

            adv_predict_proxy, adv_predict_target = model_proxy(adv_images), model_target(adv_images)

            original_predict_target_softmax = F.softmax(original_predict_target, dim=1)
            adv_predict_target_softmax = F.softmax(adv_predict_target, dim=1)
            original_entropy = entropy(original_predict_target_softmax)
            adv_entropy = entropy(adv_predict_target_softmax)
            cross_entropy_orig_to_adv = cross_entropy(original_predict_target_softmax, adv_predict_target_softmax)
            cross_entropy_adv_to_orig = cross_entropy(adv_predict_target_softmax, original_predict_target_softmax)

            etropies["original_entropy"].extend(i.item() for i in original_entropy)
            etropies["adv_entropy"].extend(i.item() for i in adv_entropy)
            etropies["cross_entropy_orig_to_adv"].extend(i.item() for i in cross_entropy_orig_to_adv)
            etropies["cross_entropy_adv_to_orig"].extend(i.item() for i in cross_entropy_adv_to_orig)

            adv_predict_proxy, adv_predict_target = adv_predict_proxy.argmax(dim=1), adv_predict_target.argmax(dim=1)
            original_predict_target = original_predict_target.argmax(dim=1)

            if same_dataset:
                # Count only the attacks that actually fooled model_proxy.
                is_adversarial = adv_predict_proxy != labels
            else:
                is_adversarial = torch.ones_like(adv_predict_proxy, dtype=torch.bool)

            n_adv_examples += (adv_predict_target[is_adversarial] != labels[is_adversarial]).sum().item()
            n_examples += int(is_adversarial.sum().item())

            # Save some predictions on both original and adversarial samples per 100 iterations.
            if batch_idx % 100 == 99:
                # Full-batch mask (counted samples whose target prediction changed),
                # so it can index the full-batch image and prediction tensors.
                condition_of_change = is_adversarial & (adv_predict_target != original_predict_target)
                try:
                    show(original_inputs=inverse_transform(images[condition_of_change][:5]),
                         fakes=inverse_transform(adv_images[condition_of_change][:5]),
                         original_labels=original_predict_target[condition_of_change][:5],
                         fake_labels=adv_predict_target[condition_of_change][:5],
                         idx2label=IDX2LABEL, figsize=(10, 6), save_figure=True,
                         filename=f"{results_dir}/attack_example_{batch_idx}")
                except Exception as e:
                    # Plotting is best effort and must not abort the experiment.
                    logging.error(f"Error in plotting the results: {e}")

        # Save the entropy values
        with open(f"{results_dir}/entropy.pkl", "wb") as f:
            pickle.dump(etropies, f)

        if n_examples == 0 or total_attacks == 0:
            logging.warning("No adversarial examples were counted; reporting rates as 0.")
        transferability_acc = (n_adv_examples / n_examples) * 100 if n_examples else 0.0
        model_proxy_adv_acc = (n_examples / total_attacks) * 100 if total_attacks else 0.0

        # Logging the final results.
        logging.info(f"Total attacks on model_proxy: {total_attacks}")
        logging.info(f"The number of successful attacks on model_proxy: {n_examples} ({round(model_proxy_adv_acc, 4)} %)")
        logging.info("The number of attacks that were also successful on model_target (transferability): "
                     f"{n_adv_examples} ({round(transferability_acc, 4)} %)")
        for key, description in (
            ("original_entropy", "Averaged entropy of original predictions"),
            ("adv_entropy", "Averaged entropy of adversarial predictions"),
            ("cross_entropy_orig_to_adv", "Averaged cross entropy of original to adversarial predictions"),
            ("cross_entropy_adv_to_orig", "Averaged cross entropy of adversarial to original predictions"),
        ):
            values = etropies[key]
            logging.info(f"{description}: {np.mean(values) if values else float('nan')}")

        # Save the results
        with open(f"{results_dir}/transferability.pkl", "wb") as f:
            pickle.dump({"total_attacks": total_attacks, "n_examples": n_examples,
                         "model_proxy_adv_acc": model_proxy_adv_acc,
                         "n_adv_examples": n_adv_examples, "transferability_acc": transferability_acc}, f)

        logging.info(f"Adversarial transferability experiment finished.\n")

    except RuntimeError as e:
        logging.error("Error in attack: %s", e)
        return None
    return transferability_acc


def attack_blackbox(model, dataloader, atk, device, checkpoint_dir,
                    inverse_transform, IDX2LABEL, same_dataset: bool = False, targeted: bool = False,
                    whitebox: bool = False, **kwargs):
    """Attack ``model`` with ``atk`` and report the fraction of fooled predictions.

    For every batch the clean images are classified, adversarial images are
    produced with ``atk`` and classified again, and entropy / cross-entropy
    statistics of the two softmax outputs are collected. Entropies and summary
    counts are pickled to ``{checkpoint_dir}/results/``.

    Args:
        model: Model that is evaluated on clean and adversarial images.
        dataloader: Iterable of batches; each batch is a mapping with ``'data'``
            and ``'label'`` entries whose first element is the tensor to use.
        atk: Callable ``atk(images, labels) -> adv_images``; must expose
            ``targeted`` and ``set_mode_targeted_random()`` when ``targeted`` is True.
        device: Device the batch tensors are moved to.
        checkpoint_dir: Directory under which ``results/`` is written.
        inverse_transform: Callable applied to image tensors before plotting.
        IDX2LABEL: Mapping from class index to label; unknown indices fall back
            to ``"class_name"``.
        same_dataset: If True, only samples ``model`` classifies correctly are counted.
        targeted: If True (and not ``whitebox``), random target labels are used and
            the attack is switched to its random-target mode.
        whitebox: If True, the batch's own labels are passed to the attack.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        Fooling rate in percent (0.0 when nothing could be counted), or None if a
        RuntimeError occurred.
    """
    n_adv_examples, n_examples, total_attacks = 0, 0, 0
    logging.info(f"Running adversarial blackbox experiment...\n")

    # Make sure the output directory exists before anything is written to it.
    results_dir = f"{checkpoint_dir}/results"
    os.makedirs(results_dir, exist_ok=True)

    try:
        etropies = {
            "original_entropy": [],
            "adv_entropy": [],
            "cross_entropy_orig_to_adv": [],
            "cross_entropy_adv_to_orig": [],
        }
        IDX2LABEL = defaultdict(lambda: "class_name", IDX2LABEL)

        for batch_idx, batch in enumerate(tqdm(dataloader)):
            images, labels = batch['data'], batch['label']
            images, labels = images[0].to(device), labels[0].to(device)

            predictions = model(images)

            if whitebox:
                labels_transferred = labels
            elif targeted:
                output_shape_proxy = predictions.shape
                labels_transferred = torch.randint(0, output_shape_proxy[1], (output_shape_proxy[0],)).to(device)

                # Per the original author's note: the attack's random-target mode
                # excludes the labels it is given when drawing a target. Feeding
                # random labels instead of the ground truth therefore allows the
                # drawn target to be the true class as well.
                if not atk.targeted:
                    atk.set_mode_targeted_random()
            else:
                labels_transferred = predictions.argmax(dim=1).to(device)

            adv_images = atk(images, labels_transferred)
            adv_predict = model(adv_images)

            # Feed probabilities (not raw logits) to the entropy helpers, matching
            # attack_transferability in this module.
            # NOTE(review): confirm ats.utils.entropy / cross_entropy do not apply
            # softmax themselves.
            predictions_softmax = F.softmax(predictions, dim=1)
            adv_predict_softmax = F.softmax(adv_predict, dim=1)
            original_entropy = entropy(predictions_softmax)
            adv_entropy = entropy(adv_predict_softmax)
            cross_entropy_orig_to_adv = cross_entropy(predictions_softmax, adv_predict_softmax)
            cross_entropy_adv_to_orig = cross_entropy(adv_predict_softmax, predictions_softmax)

            etropies["original_entropy"].extend(i.item() for i in original_entropy)
            etropies["adv_entropy"].extend(i.item() for i in adv_entropy)
            etropies["cross_entropy_orig_to_adv"].extend(i.item() for i in cross_entropy_orig_to_adv)
            etropies["cross_entropy_adv_to_orig"].extend(i.item() for i in cross_entropy_adv_to_orig)

            predict = predictions.argmax(dim=1)
            adv_predict = adv_predict.argmax(dim=1)

            if same_dataset:
                # Only samples the model got right on clean input are counted.
                is_counted = predict == labels
            else:
                is_counted = torch.ones_like(adv_predict, dtype=torch.bool)

            n_counted = int(is_counted.sum().item())
            n_adv_examples += (adv_predict[is_counted] != labels[is_counted]).sum().item()
            n_examples += n_counted
            # Unchanged semantics: total_attacks counts the same samples as n_examples.
            total_attacks += n_counted

            if batch_idx % 100 == 99:
                try:
                    # Reuse the clean predictions instead of a second forward pass.
                    show(inverse_transform(images[:5]), inverse_transform(adv_images[:5]), predict[:5], adv_predict[:5], IDX2LABEL,
                         figsize=(20, 7), save_figure=True, filename=f"{results_dir}/attack_example_{batch_idx}")
                except Exception as e:
                    # Plotting is best effort and must not abort the experiment.
                    logging.error(f"Error in plotting the results: {e}")

        if n_examples == 0:
            logging.warning("No examples were counted; reporting the fooling rate as 0.")
        fooling_rate = (n_adv_examples / n_examples) * 100 if n_examples else 0.0

        with open(f"{results_dir}/entropy.pkl", "wb") as f:
            pickle.dump(etropies, f)

        logging.info(f"Total attacks on model_proxy: {total_attacks}")
        logging.info(f"The number of successful blackbox attacks on model: {n_adv_examples} ({round(fooling_rate, 4)} %)")

        with open(f"{results_dir}/transferability.pkl", "wb") as f:
            pickle.dump({"total_attacks": total_attacks, "n_examples": n_examples, "fooling_rate": fooling_rate}, f)

    except RuntimeError as e:
        logging.error("Error in attack: %s", e)
        return None
    return fooling_rate