import os
import torch
import logging
from tqdm import tqdm
from torch import Tensor
from ats.utils import show
from torchattacks import FGSM
from torch.utils.data import DataLoader
from data.get_mnist import ROOT, IDX2LABEL
from torchvision.datasets import ImageFolder
from ats.attacks_meta  import get_attack_instance
import torchvision.transforms as vision_transforms
from vissl.data.ssl_transforms.mnist_img_pil_to_rgb_mode import MNISTImgPil2RGB
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, OxfordIIITPet, Flowers102, INaturalist


# Set up module-level logging at INFO level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Attach the stream handler only once: `logging.getLogger` returns the same object on
# every import/reload of this module, so adding a handler unconditionally would stack
# duplicates and print every record multiple times.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())


def _build_transforms(leading_transforms):
    """Build the forward/inverse transformation pair shared by all dataset branches.

    Args:
        leading_transforms: PIL-level transforms applied before `ToTensor` (e.g. resizing).

    Returns:
        mean: Per-channel mean tensor used for normalization.
        std: Per-channel std tensor used for normalization.
        data_transform: `leading_transforms` + `ToTensor` + `Normalize(mean, std)`.
        inverse_data_transform: Composition that undoes the normalization (for visualization).
    """
    mean = torch.tensor([0.485, 0.456, 0.406])
    std = torch.tensor([0.229, 0.224, 0.225])
    data_transform = vision_transforms.Compose([
        *leading_transforms,
        vision_transforms.ToTensor(),
        vision_transforms.Normalize(mean=mean, std=std),
    ])
    # Normalize computes (x - mean) / std, so it is undone in two steps:
    # first multiply by std (divide by 1/std), then add the mean back (subtract -mean).
    inverse_data_transform = vision_transforms.Compose([
        vision_transforms.Normalize(mean=[0., 0., 0.], std=1.0 / std),
        vision_transforms.Normalize(mean=-mean, std=[1., 1., 1.]),
    ])
    return mean, std, data_transform, inverse_data_transform


def _build_loader(dataset, num_workers):
    """Wrap `dataset` in a non-shuffled DataLoader with batch size 64 and pinned memory."""
    return DataLoader(dataset, batch_size=64, shuffle=False, num_workers=num_workers, pin_memory=True)


def get_data_and_attack(model, data_name: str, attack_name: str = "fgsm", attack_meta_path: str = None):
    """Helper function to load data, transformations and initialize adversarial attack method.
    Supported datasets: ('mnist', 'cifar10', 'cifar100', 'oxford_flowers', 'oxford_pets',
    'inaturalist2018'). 'imagenet1k' is accepted, but no data pipeline is defined for it yet,
    so only the attack is built and the other returned values are None.
    Supported attacks: whatever `get_attack_instance` can build from `attack_name`.

    Args:
        model: The model the attack is built against.
        data_name: Dataset name.
        attack_name: Adversarial attack algorithm.
        attack_meta_path: Forwarded unchanged to `get_attack_instance`.

    Returns:
        attack: Attacking algorithm.
        data_transform: Data transformations.
        inverse_data_transform: Inverse data transformations used for visualization.
        dataloader: Dataloader over the selected dataset.
        If a RuntimeError occurs, it is logged and the values built so far are returned
        (the remaining ones are None).

    Raises:
        ValueError: If `data_name` is not one of the recognized dataset names.
    """
    log = logging.getLogger(__name__)
    attack, data_transform, inverse_data_transform, dataloader = None, None, None, None
    # Stay None for datasets without a normalization pipeline (currently 'imagenet1k').
    mean, std = None, None
    try:
        # Pick the same transformations as used for training procedure, i.e. from config.DATA.TRANSFORMS.
        if data_name == "mnist":
            mean, std, data_transform, inverse_data_transform = _build_transforms(
                [MNISTImgPil2RGB(size=32, box=[2, 2]), vision_transforms.Resize(224)])
            dataset = MNIST(root=ROOT, train=False, download=True, transform=data_transform)
            dataloader = _build_loader(dataset, num_workers=16)
        elif data_name == "cifar10":
            mean, std, data_transform, inverse_data_transform = _build_transforms(
                [vision_transforms.Resize(224)])
            dataset = CIFAR10(root=ROOT, train=False, download=True, transform=data_transform)
            dataloader = _build_loader(dataset, num_workers=16)
        elif data_name == "cifar100":
            mean, std, data_transform, inverse_data_transform = _build_transforms(
                [vision_transforms.Resize(224)])
            dataset = CIFAR100(root=ROOT, train=False, download=True, transform=data_transform)
            dataloader = _build_loader(dataset, num_workers=16)
        elif data_name in ("oxford_flowers", "oxford_pets"):
            # Both datasets are read from `<ROOT>/<data_name>/test` laid out as an ImageFolder.
            mean, std, data_transform, inverse_data_transform = _build_transforms(
                [vision_transforms.Resize((224, 224))])
            dataset = ImageFolder(root=os.path.join(ROOT, data_name, "test"), transform=data_transform)
            dataloader = _build_loader(dataset, num_workers=32)
        elif data_name == "inaturalist2018":
            mean, std, data_transform, inverse_data_transform = _build_transforms(
                [vision_transforms.Resize(224)])
            dataset = INaturalist(root=ROOT, version="2018", download=True, transform=data_transform)
            dataloader = _build_loader(dataset, num_workers=16)
        elif data_name == "imagenet1k":
            # No pipeline is implemented for this dataset; previously this branch crashed
            # below because `mean`/`std` were never assigned.
            log.warning("No data pipeline is defined for 'imagenet1k'; returning the attack only.")
        else:
            raise ValueError(f"Dataset '{data_name}' is not supported.")

        # Define the attack method and let the attack algorithm know when normalization is used as a data transformation.
        attack = get_attack_instance(attack_name, model, attack_meta_path)
        if mean is not None and std is not None:
            attack.set_normalization_used(mean=mean.tolist(), std=std.tolist())
        attack.set_return_type(type="float")
    except RuntimeError as e:
        # Best effort: report the failure and hand back whatever was built before it.
        log.error("Error in getting data and attack: %s", e)
    return attack, data_transform, inverse_data_transform, dataloader


def attack_transferability(model_proxy, model_target, dataloader, atk, device, checkpoint_dir,
                           inverse_transform, IDX2LABEL, same_dataset=False):
    """Estimate how often adversarial examples built with `atk` also fool `model_target`.

    Args:
        model_proxy: Model whose predictions decide which adversarial examples are counted
            (only consulted for filtering when `same_dataset` is True).
        model_target: Model on which transferability is measured.
        dataloader: Iterable yielding `(images, labels)` batches.
        atk: Callable `atk(images, labels)` returning the adversarial images.
        device: Device the batches are moved to.
        checkpoint_dir: Directory prefix; example figures go to `<checkpoint_dir>/results/`.
        inverse_transform: Callable applied to image batches before they are passed to `show`.
        IDX2LABEL: Label mapping forwarded to `show`.
        same_dataset: If True, only samples that `model_proxy` classifies correctly are
            attacked, and only adversarial examples that fool `model_proxy` are counted.
            If False, every sample is attacked and counted.

    Returns:
        Percentage of counted adversarial examples that `model_target` misclassifies,
        0.0 if no example was counted, or None if a RuntimeError occurred.
    """
    log = logging.getLogger(__name__)
    # Iterating over the data and constructing adversarial examples.
    n_adv_examples, n_examples, total_attacks = 0, 0, 0
    log.info("Running adversarial transferability experiment...\n")

    try:
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader)):
            images, labels = images.to(device), labels.to(device)

            # Selecting those examples on which model_proxy has succeeded in classification.
            if same_dataset:
                with torch.no_grad():
                    pred_1 = model_proxy(images).argmax(dim=1)
                correctly_classified = (pred_1 == labels)
                images, labels = images[correctly_classified], labels[correctly_classified]
                if images.shape[0] == 0:
                    # Nothing left to attack in this batch.
                    continue

            # Creating adversarial examples (the attack manages gradients itself).
            adv_images = atk(images, labels)
            total_attacks += adv_images.shape[0]
            # Plain inference: no autograd graph is needed for the predictions.
            with torch.no_grad():
                adv_predict_proxy = model_proxy(adv_images).argmax(dim=1)
                adv_predict_target = model_target(adv_images).argmax(dim=1)

            # Picking successful adversarial examples on model_proxy to estimate transferability on model_target.
            # is_adversarial is a binary mask representing whether adv_images[i] is adversarial or not.
            if same_dataset:
                is_adversarial = (adv_predict_proxy != labels)
            else:
                # Built with ones_like so the mask lives on the same device as the predictions.
                is_adversarial = torch.ones_like(adv_predict_proxy, dtype=torch.bool)

            n_adv_examples += (adv_predict_target[is_adversarial] != labels[is_adversarial]).sum().item()
            n_examples += adv_predict_target[is_adversarial].shape[0]

            # Save some predictions on both original and adversarial samples per 100 iterations.
            if batch_idx % 100 == 99:
                with torch.no_grad():
                    pred = model_target(images[:5]).argmax(dim=1)
                show(inverse_transform(images[:5]), inverse_transform(adv_images[:5]), pred, adv_predict_target[:5], IDX2LABEL,
                     figsize=(20, 7), save_figure=True, filename=f"{checkpoint_dir}/results/attack_example_{batch_idx}")

        # Logging the final results. Guard the ratios: an empty dataloader or an attack that
        # never fools model_proxy would otherwise raise ZeroDivisionError.
        transferability_acc = (n_adv_examples / n_examples) * 100 if n_examples else 0.0
        model_proxy_adv_acc = (n_examples / total_attacks) * 100 if total_attacks else 0.0
        if not n_examples:
            log.warning("No adversarial examples were counted; reporting 0.0 transferability.")
        log.info("Total attacks on model_proxy: %s", total_attacks)
        log.info("The number of successful attacks on model_proxy: %s (%s %%)",
                 n_examples, round(model_proxy_adv_acc, 4))
        log.info("The number of attacks that were also successful on model_target (transferability): %s (%s %%)",
                 n_adv_examples, round(transferability_acc, 4))

    except RuntimeError as e:
        log.error("Error in attack: %s", e)
        return None
    return transferability_acc

def attack_blackbox(model, dataloader, atk, device, checkpoint_dir,
                    inverse_transform, IDX2LABEL, same_dataset: bool = False, attack_budget: int = 1000):
    """Measure how often adversarial examples built with `atk` are misclassified by `model`.

    Args:
        model: Model that is queried on the adversarial images.
        dataloader: Iterable yielding `(images, labels)` batches.
        atk: Callable `atk(images, labels)` returning the adversarial images.
        device: Device the batches are moved to.
        checkpoint_dir: Directory prefix; example figures go to `<checkpoint_dir>/results/`.
        inverse_transform: Callable applied to image batches before they are passed to `show`.
        IDX2LABEL: Label mapping forwarded to `show`.
        same_dataset: Currently unused; kept for interface compatibility.
        attack_budget: Currently unused; kept for interface compatibility.

    Returns:
        Fooling rate in percent, i.e. the share of adversarial images whose prediction
        differs from the label (0.0 if the dataloader yielded no samples), or None if a
        RuntimeError occurred.
    """
    log = logging.getLogger(__name__)
    # Iterating over the data and constructing adversarial examples.
    n_adv_examples, total_attacks = 0, 0
    log.info("Running adversarial blackbox experiment...\n")

    try:
        for batch_idx, (images, labels) in enumerate(tqdm(dataloader)):
            images, labels = images.to(device), labels.to(device)

            # The attack manages gradients itself; predictions below are plain inference.
            adv_images = atk(images, labels)
            total_attacks += adv_images.shape[0]
            with torch.no_grad():
                adv_predict = model(adv_images).argmax(dim=1)

            n_adv_examples += (adv_predict != labels).sum().item()

            # Save some predictions on both original and adversarial samples per 100 iterations.
            if batch_idx % 100 == 99:
                with torch.no_grad():
                    pred = model(images[:5]).argmax(dim=1)
                show(inverse_transform(images[:5]), inverse_transform(adv_images[:5]), pred, adv_predict[:5], IDX2LABEL,
                     figsize=(20, 7), save_figure=True, filename=f"{checkpoint_dir}/results/attack_example_{batch_idx}")

        # Guard the ratio: an empty dataloader would otherwise raise ZeroDivisionError.
        fooling_rate = (n_adv_examples / total_attacks) * 100 if total_attacks else 0.0

        log.info("Total attacks on model: %s", total_attacks)
        log.info("The number of successful blackbox attacks on model: %s (%s %%)",
                 n_adv_examples, round(fooling_rate, 4))

    except RuntimeError as e:
        log.error("Error in attack: %s", e)
        return None
    return fooling_rate
    