import os
import warnings

import numpy as np
import torch
import torchattacks
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage

from architectures import get_architecture
from datasets import get_dataset, normalize_images, denormalize_images
import device
from device import DEVICE

import ssl
# NOTE(review): in CPython, ``ssl._create_stdlib_context`` is an alias of
# ``ssl._create_unverified_context``. Installing it as the default HTTPS
# context factory turns OFF certificate and hostname verification for every
# HTTPS connection in this process that relies on the default context
# (http.client / urllib). Presumably a workaround for certificate errors when
# downloading datasets — confirm, and prefer fixing the local CA store.
ssl._create_default_https_context = ssl._create_stdlib_context

# Attack names understood by select_attack(), which lower-cases its input
# before matching. 'basiciterative' is handled exactly like 'bim', and 'df'
# exactly like 'deepfool'.
ATTACKS = [
    'fgsm', 'bim', 'basiciterative', 'pgd', 'deepfool', 'cw',
    'ead', 'square', 'autoattack', 'onepixel', 'spatialtransformation',
    'jsma', 'saltandpepper', 'df', 'apgd', 'multiattack', 'tifgsm',
]

def _first_kwarg(kwargs, keys, default):
    """Return ``kwargs[key]`` for the first ``key`` in ``keys`` present, else ``default``.

    Lets callers spell a hyper-parameter more than one way (e.g. ``epsilon``
    or ``eps``). The historically-read key is listed first by the caller so
    existing calls resolve exactly as before.
    """
    for key in keys:
        if key in kwargs:
            return kwargs[key]
    return default


def select_attack(model, attack_type, **kwargs):
    """
    Build a torchattacks attack object for ``model`` from a short attack name.

    The attack is only constructed here, not run; call the returned object as
    ``attack(images, labels)`` to generate adversarial examples.

    Args:
        model: PyTorch model to attack.
        attack_type: Case-insensitive attack name, one of ``ATTACKS``
            (e.g. 'fgsm', 'pgd'). 'bim'/'basiciterative' and 'deepfool'/'df'
            are aliases of each other.
        **kwargs: Optional hyper-parameters for the chosen attack; keys the
            chosen attack does not use are ignored. The perturbation budget
            may be passed as either ``epsilon`` or ``eps`` for every attack
            that takes one. For 'apgd' the iteration count may be passed as
            ``n_iter`` or ``steps``.

    Returns:
        The constructed torchattacks attack, or ``None`` (after printing a
        message) when ``attack_type`` is not recognised.
    """
    attack_type = attack_type.lower()
    attack = None

    if attack_type == 'fgsm':
        epsilon = _first_kwarg(kwargs, ('epsilon', 'eps'), 1.0)
        attack = torchattacks.FGSM(model, eps=epsilon)

    elif attack_type in ('bim', 'basiciterative'):
        epsilon = _first_kwarg(kwargs, ('epsilon', 'eps'), 0.3)
        alpha = kwargs.get('alpha', 0.01)
        steps = kwargs.get('steps', 40)
        attack = torchattacks.BIM(model, eps=epsilon, alpha=alpha, steps=steps)

    elif attack_type == 'pgd':
        epsilon = _first_kwarg(kwargs, ('epsilon', 'eps'), 0.3)
        alpha = kwargs.get('alpha', 2/255)
        steps = kwargs.get('steps', 40)
        attack = torchattacks.PGD(model, eps=epsilon, alpha=alpha, steps=steps)

    elif attack_type in ('deepfool', 'df'):
        steps = kwargs.get('steps', 50)
        overshoot = kwargs.get('overshoot', 0.02)
        attack = torchattacks.DeepFool(model, steps=steps, overshoot=overshoot)

    elif attack_type == 'cw':
        c = kwargs.get('c', 1e-4)
        steps = kwargs.get('steps', 1000)
        lr = kwargs.get('lr', 0.01)
        attack = torchattacks.CW(model, c=c, steps=steps, lr=lr)

    elif attack_type == 'ead':
        # NOTE(review): recent torchattacks releases ship EADL1/EADEN rather
        # than EAD, and those take no `eps` argument — verify against the
        # installed version.
        eps = _first_kwarg(kwargs, ('eps', 'epsilon'), 0.3)
        beta = kwargs.get('beta', 0.01)
        steps = kwargs.get('steps', 100)
        attack = torchattacks.EAD(model, eps=eps, beta=beta, steps=steps)

    elif attack_type == 'square':
        eps = _first_kwarg(kwargs, ('eps', 'epsilon'), 1.0)
        n_queries = kwargs.get('n_queries', 5000)
        attack = torchattacks.Square(model, eps=eps, n_queries=n_queries)

    elif attack_type == 'autoattack':
        eps = _first_kwarg(kwargs, ('eps', 'epsilon'), 0.3)
        n_classes = kwargs.get('n_classes', 10)
        version = kwargs.get('version', 'standard')
        attack = torchattacks.AutoAttack(model, eps=eps, n_classes=n_classes, version=version)

    elif attack_type == 'onepixel':
        pixels = kwargs.get('pixels', 5)
        steps = kwargs.get('steps', 75)
        popsize = kwargs.get('popsize', 400)
        inf_batch = kwargs.get('inf_batch', 128)
        attack = torchattacks.OnePixel(model, pixels=pixels, steps=steps, popsize=popsize, inf_batch=inf_batch)

    elif attack_type == 'spatialtransformation':
        # NOTE(review): torchattacks may not provide SpatialTransformation —
        # verify against the installed version.
        max_translation = kwargs.get('max_translation', 3)
        num_translations = kwargs.get('num_translations', 20)
        max_rotation = kwargs.get('max_rotation', 30)
        num_rotations = kwargs.get('num_rotations', 30)
        attack = torchattacks.SpatialTransformation(
            model,
            max_translation=max_translation,
            num_translations=num_translations,
            max_rotation=max_rotation,
            num_rotations=num_rotations,
        )

    elif attack_type == 'jsma':
        theta = kwargs.get('theta', 1.0)
        gamma = kwargs.get('gamma', 0.1)
        attack = torchattacks.JSMA(model, theta=theta, gamma=gamma)

    elif attack_type == 'saltandpepper':
        # NOTE(review): torchattacks may not provide SaltAndPepper — verify
        # against the installed version.
        amount = kwargs.get('amount', 0.1)
        attack = torchattacks.SaltAndPepper(model, amount=amount)

    elif attack_type == 'apgd':
        eps = _first_kwarg(kwargs, ('eps', 'epsilon'), 0.3)
        n_restarts = kwargs.get('n_restarts', 5)
        # torchattacks.APGD names its iteration count `steps`; passing
        # `n_iter=` to the constructor raises TypeError.
        steps = _first_kwarg(kwargs, ('n_iter', 'steps'), 100)
        loss = kwargs.get('loss', 'ce')
        attack = torchattacks.APGD(model, eps=eps, steps=steps, n_restarts=n_restarts, loss=loss)

    elif attack_type == 'multiattack':
        # Only build the default sub-attacks when the caller supplied none.
        if 'attacks' in kwargs:
            attacks = kwargs['attacks']
        else:
            attacks = [torchattacks.Square(model), torchattacks.FGSM(model)]
        attack = torchattacks.MultiAttack(attacks)

    elif attack_type == 'tifgsm':
        epsilon = _first_kwarg(kwargs, ('epsilon', 'eps'), 0.3)
        alpha = kwargs.get('alpha', 0.01)
        steps = kwargs.get('steps', 40)
        decay = kwargs.get('decay', 1.0)
        attack = torchattacks.TIFGSM(model, eps=epsilon, alpha=alpha, steps=steps, decay=decay)

    else:
        print(f"Unrecognized attack type '{attack_type}'.")

    return attack



class AdversarialAttackApplier:
    """Replace a random subset of each batch with adversarial examples.

    The attack is built once, by name, through ``select_attack`` against the
    wrapped model.
    """

    def __init__(self, model, dataset, adversarial_attack):
        """
        Initialize the class with the model, dataset, and adversarial attack type.

        Args:
            model: PyTorch model to attack; it is moved to ``DEVICE``.
            dataset: Dataset name. For "cifar10" batches are denormalized
                before the attack and normalized again afterwards.
            adversarial_attack: Attack name understood by ``select_attack``.

        Raises:
            ValueError: if ``select_attack`` does not recognise
                ``adversarial_attack`` (it returns ``None``).
        """
        self.dataset = dataset
        self.arch = model.to(DEVICE)
        self.attack = select_attack(self.arch, adversarial_attack)
        if self.attack is None:
            # Otherwise every batch would try to call ``None`` and the
            # adversarial augmentation would be silently disabled.
            raise ValueError(f"Unknown adversarial attack '{adversarial_attack}'.")

    def apply_adversarial_with_probability(self, batch, targets, probability=0.5):
        """
        Replace each sample of ``batch`` by its adversarial version with probability ``probability``.

        Args:
            batch: Input tensor whose first dimension indexes samples. For
                "cifar10" it is expected to be normalized; it is denormalized
                for the attack and normalized again before returning.
            targets: Labels for ``batch``, indexed with the same per-sample mask.
            probability: Per-sample chance of being attacked.

        Returns:
            A tensor on ``DEVICE`` where the selected samples are replaced by
            adversarial examples. The caller's ``batch`` tensor is not
            modified in place. If the attack raises, the selected samples are
            kept clean and a ``RuntimeWarning`` is emitted.
        """
        # torchattacks expects inputs in [0, 1]; presumably denormalize_images
        # maps CIFAR-10 inputs back to that range — confirm.
        if self.dataset == "cifar10":
            batch = denormalize_images(self.dataset, batch)

        # Copy so the masked assignment below never mutates the caller's tensor.
        batch = batch.to(DEVICE).clone()

        # One independent draw per sample decides whether it is attacked.
        mask = torch.rand(batch.size(0), device=batch.device) < probability

        if mask.any():
            clean = batch[mask]
            selected_targets = targets.to(batch.device)[mask]
            try:
                # Gradient-based torchattacks attacks differentiate the loss
                # w.r.t. the inputs, which fails under torch.no_grad(); force
                # grad mode on even if the caller disabled it.
                with torch.enable_grad():
                    adv_batch = self.attack(clean, selected_targets)
            except Exception as err:
                # Best effort: keep training on clean samples, but make the
                # failure visible instead of swallowing it.
                warnings.warn(
                    f"Adversarial attack failed ({err!r}); using clean samples instead.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                adv_batch = clean
            # Detach so no autograd graph from the attack leaks into the batch.
            batch[mask] = adv_batch.detach().to(batch.device)

        if self.dataset == "cifar10":
            batch = normalize_images(self.dataset, batch)

        return batch
