import os
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np
from sklearn.decomposition import PCA

from device import DEVICE
from datasets import normalize_images, denormalize_images, get_dimensions
from architectures import get_architecture
from attacks import select_attack
from tqdm import tqdm
from datasets import get_dataset

import multiprocessing
# Module-import side effect: force the 'fork' start method for any
# multiprocessing done by this process (force=True overrides a start method
# that was already set elsewhere).
# NOTE(review): 'fork' is POSIX-only — confirm this module is never imported
# on a platform without it.
multiprocessing.set_start_method('fork', force=True)

class GaussianImageSampler:
    """
    Smoothing-noise sampler shaped like adversarial perturbations.

    Adversarial perturbations are generated (or loaded from a cache) for a
    random subset of the test set. They are projected onto their principal
    components, and their covariance in that subspace is estimated. Noise is
    drawn from a zero-mean Gaussian with that covariance, mapped back to the
    flattened pixel space through the PCA components, scaled by an adjusted
    sigma and added to the input image.
    """

    def __init__(self, dataset, model, attack_type, sigma, num_samples=1000, var_retain=0.95):
        """
        Initialize the GaussianImageSampler.

        Args:
        dataset (str): Name of the dataset.
        model (str): Name of the model architecture.
        attack_type (str): Type of adversarial attack.
        sigma (float): Noise standard deviation before adjustment; the adjusted
            value is stored in ``self.sigma`` (see ``compute_sigma``).
        num_samples (int): Number of test images used to estimate the covariance.
        var_retain (float): Proportion of variance to retain in PCA.
        """

        # Dataset, model and attack identifiers (also used in the cache file name).
        self.dataset_name = dataset
        self.dataset = get_dataset(self.dataset_name, "test")
        self.model = model
        self.attack = attack_type

        # Architecture and the attack used to produce perturbations.
        self.arch = get_architecture(self.model, self.dataset_name)
        self.adv_attack = select_attack(self.arch, self.attack)

        # PCA keeps enough components to explain `var_retain` of the variance.
        self.num_samples = num_samples
        self.pca = PCA(n_components=var_retain)

        # Order matters: sigma and the Cholesky factor both read the covariance.
        self.covariance_matrix = self.compute_covariance_matrix()
        self.sigma = self.compute_sigma(sigma)
        self.L = self.compute_cholesky()

        self.norm = "Mahalanobis"

    def load_perturbations(self):
        """
        Load cached adversarial perturbations, or generate and cache them.

        The cache file lives in ``datasets/`` and is keyed by dataset, model and
        attack name, so the attack only runs once per configuration.

        Returns:
        torch.Tensor: Perturbations (adversarial image minus clean image),
        concatenated along dim 0 with one entry per sampled test image,
        stored on the CPU.
        """

        file_path = os.path.join(
            'datasets', f"{self.dataset_name}_{self.model}_{self.attack}_perturbations.pt")

        if os.path.exists(file_path):
            print("Perturbations already defined.")
            return torch.load(file_path)

        os.makedirs('datasets', exist_ok=True)

        # Random subset of test indices, drawn without replacement.
        n_images = min(self.num_samples, len(self.dataset))
        indices = np.random.choice(len(self.dataset), n_images, replace=False)
        dataloader = DataLoader(self.dataset, num_workers=0,
                                batch_size=min(self.num_samples, 32),
                                sampler=SubsetRandomSampler(indices))

        perturbations = []
        for images, labels in tqdm(dataloader, desc="Computing perturbations for Sigma"):
            # CIFAR-10 images are attacked in de-normalized pixel space.
            if self.dataset_name == "cifar10":
                images = denormalize_images(self.dataset_name, images)

            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            adversarial_images = self.adv_attack(images, labels)
            # Do not squeeze: a final batch of a single image must keep its
            # batch dimension, otherwise torch.cat mis-stacks it.
            perturbations.append((adversarial_images - images).detach().cpu())

        perturbations = torch.cat(perturbations)
        torch.save(perturbations, file_path)
        return perturbations

    @staticmethod
    def _regularize_covariance(covariance, min_eigenvalue=1e-6):
        """
        Make a symmetric matrix positive definite with a minimal diagonal shift.

        The diagonal is shifted only when the smallest eigenvalue is below
        `min_eigenvalue`, and then just enough to lift it to `min_eigenvalue`.
        A well-conditioned matrix is returned unchanged.

        Args:
        covariance (torch.Tensor): Symmetric (k, k) matrix.
        min_eigenvalue (float): Smallest eigenvalue allowed after the shift.

        Returns:
        torch.Tensor: The (possibly shifted) matrix, same dtype and device.
        """
        smallest = torch.linalg.eigvalsh(covariance).min().item()
        shift = min_eigenvalue - smallest
        if shift > 0:
            covariance = covariance + shift * torch.eye(
                covariance.shape[0], dtype=covariance.dtype, device=covariance.device)
        return covariance

    def compute_covariance_matrix(self):
        """
        Fit PCA on the flattened perturbations and return their covariance in
        the PCA subspace.

        Side effect: fits ``self.pca``.

        Returns:
        torch.Tensor: (k, k) covariance matrix (k = retained components),
        regularized to be positive definite.
        """

        perturbations = self.load_perturbations()
        flattened = perturbations.detach().reshape(perturbations.size(0), -1).cpu().numpy()

        transformed = self.pca.fit_transform(flattened)

        # np.cov centers the data itself. atleast_2d keeps a 1x1 matrix when PCA
        # retains a single component (np.cov would return a 0-d array).
        covariance = torch.from_numpy(np.atleast_2d(np.cov(transformed, rowvar=False)))

        return self._regularize_covariance(covariance)

    def compute_cholesky(self):
        """
        Compute the lower Cholesky factor of the covariance matrix.

        Returns:
        torch.Tensor: The Cholesky factor, downcast to float32 if it was float64.
        """

        chol = torch.linalg.cholesky(self.covariance_matrix)
        # Match the default float32 dtype of the standard-normal draws it multiplies.
        if chol.dtype == torch.float64:
            chol = chol.to(torch.float32)

        return chol

    @staticmethod
    def _scale_sigma(sigma, covariance, dimensions):
        """
        Return sigma * det(covariance) ** (1 / (2 * dimensions)).

        The determinant is evaluated through its logarithm, so it does not
        underflow to 0 (or overflow) for large matrices. If the determinant is
        zero or non-positive, 1e-6 is used in its place.

        Args:
        sigma (float): Sigma before adjustment.
        covariance (torch.Tensor): Square covariance matrix.
        dimensions (int or float): Exponent denominator (the data dimensionality).

        Returns:
        float: The adjusted sigma.
        """
        sign, log_abs_det = torch.linalg.slogdet(covariance)
        log_det = log_abs_det.item()
        if sign.item() <= 0 or not np.isfinite(log_det):
            log_det = float(np.log(1e-6))
        return sigma * float(np.exp(log_det / (2 * dimensions)))

    def compute_sigma(self, sigma):
        """
        Adjust sigma by the determinant of the covariance matrix.

        Args:
        sigma (float): Initial sigma value.

        Returns:
        float: sigma * det(covariance) ** (1 / (2 * d)), with
        d = get_dimensions(dataset).
        """

        dimensions = get_dimensions(self.dataset_name)
        return self._scale_sigma(sigma, self.covariance_matrix, dimensions)

    def sample_multivariate_normal_gpu(self, n_samples):
        """
        Sample zero-mean Gaussian noise with the fitted covariance in pixel space.

        Args:
        n_samples (int): Number of samples to generate.

        Returns:
        torch.Tensor: (n_samples, D) tensor on DEVICE, where D is the number of
        flattened features PCA was fitted on.
        """

        L = self.L.to(DEVICE)

        # z ~ N(0, I) in the PCA subspace; z @ L.T ~ N(0, covariance).
        z = torch.randn(n_samples, L.shape[0], device=DEVICE)
        samples = z @ L.T

        # Fit PCA lazily if it has not been fitted yet (it needs 2-D input).
        if not hasattr(self.pca, 'components_'):
            perturbations = self.load_perturbations()
            self.pca.fit(perturbations.detach().reshape(perturbations.size(0), -1).cpu().numpy())

        # Map back from the k-dim PCA subspace to the D-dim flattened space on
        # the device, avoiding a CPU/numpy round-trip.
        components = torch.as_tensor(self.pca.components_, dtype=samples.dtype, device=DEVICE)
        return samples @ components

    def sample_smoothed_images(self, x, num_samples):
        """
        Generate noisy copies of `x` using the perturbation-shaped noise.

        Args:
        x (torch.Tensor): Original image; it is tiled with
            ``repeat((num_samples, 1, 1, 1))``.
        num_samples (int): Number of samples to generate.

        Returns:
        torch.Tensor: Noisy images on DEVICE. Values are clamped to [0, 1] in
        pixel space; for CIFAR-10 they are then re-normalized.
        """

        # Work in pixel space for CIFAR-10.
        if self.dataset_name == "cifar10":
            x = denormalize_images(self.dataset_name, x)

        batch = x.repeat((num_samples, 1, 1, 1)).to(DEVICE)

        # Reshape to the batch's own shape so a 3-D or 4-D `x` both line up
        # element-wise (no accidental broadcasting to (n, n, ...)).
        noise = self.sample_multivariate_normal_gpu(num_samples) * self.sigma
        smoothed_images = (batch + noise.view(batch.shape)).clamp(0, 1)

        # Clamp happens before re-normalization. Clamping afterwards would clip
        # normalized values, which presumably lie outside [0, 1].
        if self.dataset_name == "cifar10":
            smoothed_images = normalize_images(self.dataset_name, smoothed_images)

        return smoothed_images


class GaussianSampler:
    """Isotropic Gaussian noise sampler (norm label "L2")."""

    def __init__(self, sigma):
        """
        Create a sampler that adds independent N(0, sigma^2) noise per element.

        Args:
        sigma (float): Standard deviation of the additive Gaussian noise.
        """

        self.norm = "L2"
        self.sigma = sigma

    def sample_smoothed_images(self, x, num_samples):
        """
        Draw noisy copies of an image.

        Args:
        x (torch.Tensor): Image to perturb; tiled with
            ``repeat((num_samples, 1, 1, 1))``.
        num_samples (int): Number of noisy copies to return.

        Returns:
        torch.Tensor: The noisy copies on DEVICE, clamped to [0, 1].
        """

        tiled = x.repeat((num_samples, 1, 1, 1)).to(DEVICE)
        gaussian_noise = torch.randn_like(tiled, device=DEVICE) * self.sigma
        return (tiled + gaussian_noise).clamp(0, 1)
    
class LOneSampler:
    """Laplace noise sampler (norm label "L1")."""

    def __init__(self, sigma):
        """
        Create a sampler that adds independent Laplace(0, sigma) noise per element.

        Args:
        sigma (float): Scale parameter of the Laplace noise.
        """

        self.norm = "L1"
        self.sigma = sigma

    def sample_smoothed_images(self, x, num_samples):
        """
        Draw Laplace-noised copies of an image.

        Args:
        x (torch.Tensor): Image to perturb; tiled with
            ``repeat((num_samples, 1, 1, 1))``.
        num_samples (int): Number of noisy copies to return.

        Returns:
        torch.Tensor: The noisy copies on DEVICE, clamped to [0, 1].
        """

        tiled = x.repeat((num_samples, 1, 1, 1)).to(DEVICE)

        # The noise is sampled on the distribution's own device and then moved
        # next to the images.
        laplace = torch.distributions.Laplace(loc=0.0, scale=self.sigma)
        laplace_noise = laplace.sample(tiled.shape).to(DEVICE)

        return (tiled + laplace_noise).clamp(0, 1)
  
    
class UnifSampler:
    """Uniform noise sampler (norm label "Unif")."""

    def __init__(self, sigma):
        """
        Create a sampler that adds independent U[-sigma, sigma) noise per element.

        Args:
        sigma (float): Half-width of the uniform noise interval.
        """

        self.sigma = sigma
        self.norm = "Unif"

    def sample_smoothed_images(self, x, num_samples):
        """
        Draw uniformly-noised copies of an image.

        Args:
        x (torch.Tensor): Image to perturb; tiled with
            ``repeat((num_samples, 1, 1, 1))``.
        num_samples (int): Number of noisy copies to return.

        Returns:
        torch.Tensor: The noisy copies on DEVICE, clamped to [0, 1].
        """

        batch = x.repeat((num_samples, 1, 1, 1)).to(DEVICE)

        # Map U[0, 1) to U[-sigma, sigma) so the noise is centred on the image.
        # The previous `-1 + sigma * U` was only centred when sigma == 2; for
        # small sigma it pushed every pixel toward 0. The noise is drawn directly
        # on the images' device.
        unit = torch.rand(batch.shape, device=batch.device)
        noise = self.sigma * (2 * unit - 1)

        return (batch + noise).clamp(0, 1)