import random
from typing import List, Any

import matplotlib
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from numpy import ndarray, dtype, bool_
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, TensorDataset, Subset
from tqdm import tqdm
from sklearn.decomposition import PCA

from nn_model_manager import NNModelManager
from utils import log_prints

# Select the Qt5Agg backend for every figure this module displays.
# NOTE(review): this presumably needs Qt5 bindings and a display to be
# available — confirm before running the module headless.
matplotlib.use('Qt5Agg')


@log_prints()
def full_model_mimic_attack(
        nn_mm_mimic: NNModelManager, nn_mm: NNModelManager,
        train_subset: Subset, test_subset: Subset,
        test_subset_dataloader: DataLoader, target_image: Tensor, target_label: int,
        mean_dataset: List[float], std_dataset: List[float],
        batch_size: int, cycles_of_mimic_attack: int = 10,
        n_attack_samples: int = 100,
        epsilon_min_bound: float = 0.01, epsilon_max_bound: float = 0.1,
        additional_dataset_flag: bool = True, attack_type="pgd",
        show_image_flag: bool = False, n_components: int = None,
        span_project_flag: bool = True,
        **kwargs,
):
    """
    Run the mimic attack end to end.

    The mimic model is first loaded or trained, an adversarial dataset is
    generated against it and evaluated on the original model, and then, for
    ``cycles_of_mimic_attack`` rounds, the mimic is retrained on that dataset
    and a fresh adversarial dataset is generated.

    :param nn_mm_mimic: Manager of the mimic (surrogate) model.
    :param nn_mm: Manager of the original model.
    :param train_subset: Training subset used for mimic training and projection.
    :param test_subset: Test subset used for mimic training.
    :param test_subset_dataloader: Loader used to report the mimic accuracy.
    :param target_image: Image under attack. NOTE(review): annotated as a
        Tensor, but downstream code indexes it as ``target_image[0][0]`` and
        wraps it in a ``DataLoader``, so it looks like a one-element Subset —
        confirm.
    :param target_label: Label of the target image.
    :param mean_dataset: Per-channel normalisation mean.
    :param std_dataset: Per-channel normalisation std.
    :param batch_size: Batch size for evaluation and feature extraction.
    :param cycles_of_mimic_attack: Number of retrain/regenerate rounds.
    :param n_attack_samples: Number of adversarial candidates per generation.
    :param epsilon_min_bound: Smallest perturbation budget.
    :param epsilon_max_bound: Largest perturbation budget.
    :param additional_dataset_flag: Also add the clean train/test data to the
        dataset the mimic is retrained on.
    :param attack_type: ``"pgd"`` or ``"fgsm"``.
    :param show_image_flag: Show difference plots after every generation; the
        last cycle always shows them.
    :param n_components: Number of PCA components for the span projection.
    :param span_project_flag: Project adversarial images with PCA.
    :param kwargs: Extra attack options, read downstream as ``alpha`` and
        ``num_iter``.
    """
    model_mimic_attack_init_train(
        nn_mm_mimic=nn_mm_mimic, nn_mm=nn_mm,
        train_subset=train_subset, test_subset=test_subset,
        target_image=target_image, test_subset_dataloader=test_subset_dataloader,
    )

    # Arguments shared by every dataset generation.  The extra options are
    # unpacked with ``**`` so they arrive as real keyword arguments; passing
    # ``kwargs=kwargs`` buried them in a dict under the key "kwargs" and the
    # attack silently fell back to its defaults.
    generation_args = dict(
        nn_mm_mimic=nn_mm_mimic, nn_mm=nn_mm,
        target_image=target_image, target_label=target_label,
        mean_dataset=mean_dataset, std_dataset=std_dataset,
        batch_size=batch_size, train_subset=train_subset, test_subset=test_subset,
        n_attack_samples=n_attack_samples,
        epsilon_min_bound=epsilon_min_bound, epsilon_max_bound=epsilon_max_bound,
        additional_dataset_flag=additional_dataset_flag, attack_type=attack_type,
        n_components=n_components, span_project_flag=span_project_flag,
        **kwargs,
    )

    adv_dataset_for_train = create_adv_dataset_and_eval(
        show_image_flag=show_image_flag, **generation_args,
    )

    for i in range(cycles_of_mimic_attack):
        print(f"Start {i + 1} mimic attack cycle on adv_sample. Complete {i}/{cycles_of_mimic_attack} cycles.")
        model_mimic_attack_1_cycle_adv(
            nn_mm_mimic=nn_mm_mimic, nn_mm=nn_mm, adv_dataset_for_train=adv_dataset_for_train,
            target_image=target_image,
        )

        # The final cycle always displays its result, whatever the flag says.
        is_last_cycle = i + 1 == cycles_of_mimic_attack
        adv_dataset_for_train = create_adv_dataset_and_eval(
            show_image_flag=bool(show_image_flag) or is_last_cycle, **generation_args,
        )


@log_prints()
def create_adv_dataset_and_eval(
        nn_mm_mimic: NNModelManager, target_image: Tensor, target_label: int,
        mean_dataset: List[float], std_dataset: List[float], batch_size: int, nn_mm: NNModelManager,
        train_subset: Subset, test_subset: Subset, n_attack_samples: int = 100,
        epsilon_min_bound: float = 0.01, epsilon_max_bound: float = 0.1,
        additional_dataset_flag: bool = True, attack_type="pgd", n_components: int = None,
        span_project_flag=True,
        show_image_flag: bool = False, **kwargs,
) -> TensorDataset:
    """
    Attack the mimic model, score the result on the original model and build
    the dataset for the next mimic training round.

    :param nn_mm_mimic: Manager of the mimic model that is attacked.
    :param target_image: Image under attack. NOTE(review): indexed below as
        ``target_image[0][0]``, so it looks like a one-element Subset rather
        than a Tensor — confirm.
    :param target_label: Label of the target image.
    :param mean_dataset: Per-channel normalisation mean.
    :param std_dataset: Per-channel normalisation std.
    :param batch_size: Batch size for the evaluation loader and the attack.
    :param nn_mm: Manager of the original model; it scores the adversarial
        samples and labels the training images.
    :param train_subset: Clean training subset.
    :param test_subset: Clean test subset.
    :param n_attack_samples: Number of adversarial candidates to generate.
    :param epsilon_min_bound: Smallest perturbation budget.
    :param epsilon_max_bound: Largest perturbation budget.
    :param additional_dataset_flag: Append the clean test and train data (with
        their own labels) to the returned dataset.
    :param attack_type: ``"pgd"`` or ``"fgsm"``.
    :param n_components: Number of PCA components for the span projection.
    :param span_project_flag: Project adversarial images with PCA.
    :param show_image_flag: Show the closest and farthest successful samples.
    :param kwargs: Extra attack options forwarded to ``attack_vary``.
    :return: Dataset of images and labels for retraining the mimic model.
    """
    # Forward the options with ``**``: ``kwargs=kwargs`` would nest them under
    # a "kwargs" key where ``attack_vary`` never looks.
    adv_dataset, adversarial_images_for_train, successful_adv_samples_distance = attack_vary(
        model=nn_mm_mimic.model,
        target_image=target_image,
        target_label=target_label,
        train_dataset=train_subset,
        epsilon_min_bound=epsilon_min_bound,
        epsilon_max_bound=epsilon_max_bound,
        n_attack_samples=n_attack_samples,
        mean_dataset=mean_dataset,
        std_dataset=std_dataset,
        attack_type=attack_type,
        batch_size=batch_size,
        n_components=n_components,
        span_project_flag=span_project_flag,
        **kwargs,
    )

    print(f"\nMax distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.max():.4f}.\n"
          f"Min distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.min():.4f}.\n"
          f"Mean distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.mean():.4f}.\n"
          f"Var distance between successful adv sample and target image: "
          f"{successful_adv_samples_distance.var():.4f}.\n"
          )
    adv_dataloader = DataLoader(adv_dataset, batch_size=batch_size, shuffle=True)
    acc_full = nn_mm.evaluate_model(dataloader=adv_dataloader)
    print(f"Origin model accuracy on adv test data: {acc_full:.4f}")

    # The attack counts as transferred when the original model is fooled on at
    # least half of the samples and more than a quarter of the candidates
    # succeeded.  The count must be the number of SAMPLES (len of the dataset);
    # len(adv_dataloader) is the number of batches and made this check
    # effectively unreachable.
    attack_succeeded = acc_full <= 0.5 and len(adv_dataset) > 0.25 * n_attack_samples

    # Training labels come from the original model's predictions.
    predicted_origin_model_labels = nn_mm.create_labels(adversarial_images_for_train)
    if additional_dataset_flag:
        # Load each clean subset as one full batch and append it.
        inputs_train, labels_train = next(iter(
            DataLoader(train_subset, batch_size=len(train_subset), shuffle=False)
        ))
        inputs_test, labels_test = next(iter(
            DataLoader(test_subset, batch_size=len(test_subset), shuffle=False)
        ))
        adversarial_images_for_train = torch.cat(
            tensors=[adversarial_images_for_train, inputs_test, inputs_train],
            dim=0
        )
        predicted_origin_model_labels = torch.cat(
            tensors=[predicted_origin_model_labels, labels_test, labels_train],
            dim=0
        )
    adv_dataset_for_train = TensorDataset(adversarial_images_for_train, predicted_origin_model_labels)

    if show_image_flag or attack_succeeded:
        # Farthest sample first, then the closest one.
        extreme_indices = (
            torch.argmax(successful_adv_samples_distance),
            torch.argmin(successful_adv_samples_distance),
        )
        for index in extreme_indices:
            show_images_diff(target_image[0][0], adv_dataset.tensors[0][index:index + 1],
                             mean_dataset, std_dataset)
    return adv_dataset_for_train


def model_mimic_attack_init_train(
        nn_mm_mimic: NNModelManager, nn_mm: NNModelManager, train_subset: Subset,
        test_subset: Subset, target_image: Tensor, test_subset_dataloader: DataLoader,
) -> None:
    """
    Prepare the mimic model and report its initial quality.

    A load in ``"mimic"`` train mode is attempted with ``save_name`` set to 0.
    If that raises ``FileNotFoundError``, ``save_name`` is reset to ``None``
    and the mimic model is trained against the original model. Afterwards the
    accuracy on the test loader and the mimic score on the target image are
    printed.

    :param nn_mm_mimic: Manager of the mimic model.
    :param nn_mm: Manager of the original model.
    :param train_subset: Training subset handed to ``mimic_train``.
    :param test_subset: Test subset handed to ``mimic_train``.
    :param target_image: Target image; it is wrapped in a ``DataLoader`` for
        the final check.
    :param test_subset_dataloader: Loader used for the accuracy report.
    """
    try:
        nn_mm_mimic.save_name = 0
        nn_mm_mimic.load(train_mode="mimic")
    except FileNotFoundError:
        # Nothing to load: train the mimic model from the original one.
        nn_mm_mimic.save_name = None
        nn_mm_mimic.mimic_train(
            origin_model=nn_mm.model,
            dataset=(train_subset, test_subset),
            target_image=target_image,
        )

    test_accuracy = nn_mm_mimic.evaluate_model(dataloader=test_subset_dataloader)
    print(f"Mimic model accuracy on test subset data: {test_accuracy:.4f}")

    target_loader = DataLoader(target_image)
    target_accuracy = nn_mm_mimic.mimic_evaluate(dataloader=target_loader, origin_model=nn_mm.model)
    print(f"Mimic model accuracy on target_image: {target_accuracy:.4f}")


@log_prints()
def model_mimic_attack_1_cycle_adv(nn_mm_mimic: NNModelManager, nn_mm: NNModelManager,
                                   adv_dataset_for_train: TensorDataset, target_image: Tensor) -> None:
    """
    Retrain the mimic model on one adversarial dataset and report how it
    scores on the target image.

    The dataset is split with ``random_dataset_subset`` (10% test part, all
    samples used) and the mimic model is always retrained; no saved model is
    loaded in these cycles.

    :param nn_mm_mimic: Manager of the mimic model.
    :param nn_mm: Manager of the original model.
    :param adv_dataset_for_train: Images and labels to retrain on.
    :param target_image: Target image; it is wrapped in a ``DataLoader`` for
        the final check.
    """
    train_subset, test_subset = random_dataset_subset(
        dataset=adv_dataset_for_train, subset_size=-1, test_part=0.1, target_image_flag=False
    )

    # Loading is intentionally disabled for the adversarial cycles.  The
    # original forced this with ``try: raise FileNotFoundError()``; calling the
    # training path directly behaves the same without using an exception as
    # a jump.
    nn_mm_mimic.save_name = None
    nn_mm_mimic.mimic_train(origin_model=nn_mm.model,
                            dataset=(train_subset, test_subset),
                            target_image=target_image,
                            additional_train_flag=False
                            )

    acc = nn_mm_mimic.mimic_evaluate(dataloader=DataLoader(target_image), origin_model=nn_mm.model)
    print(f"Mimic model accuracy on target_image: {acc:.4f}")


def random_dataset_subset(dataset, subset_size=600, test_part=0.2, target_image_flag=True) -> tuple:
    """
    Draw a random train/test split from ``dataset``.

    With ``target_image_flag=True`` the last sampled index also becomes the
    target image and the return value is
    ``(train_subset, test_subset, target_image, target_label)``.

    With ``target_image_flag=False`` the last dataset element is kept out of
    the random draw and appended to the end of the train subset; the return
    value is ``(train_subset, test_subset)``.

    :param dataset: Indexable dataset; needs a ``targets`` attribute (list,
        Tensor or numpy array) when ``target_image_flag`` is true.
    :param subset_size: Number of samples to use; ``-1`` means the whole
        dataset.
    :param test_part: Fraction of ``subset_size`` that goes to the test
        subset. With ``0`` the train subset keeps every sampled index and the
        test subset is the first 10% of them (the two overlap).
    :param target_image_flag: Whether to also pick a target image and label.
    :return: 4-tuple or 2-tuple as described above.
    :raises TypeError: If ``dataset.targets`` has an unsupported type.
    :raises ValueError: Raised by ``random.sample`` when more samples are
        requested than are available.
    """
    if subset_size == -1:
        subset_size = len(dataset)
    # Without a target image, the last element is reserved for the train set.
    reserved = 0 if target_image_flag else 1
    random_indices = random.sample(range(len(dataset) - reserved), subset_size - reserved)

    test_fraction = 0.1 if test_part == 0 else test_part
    test_subset = Subset(dataset, random_indices[:int(subset_size * test_fraction)])
    train_indices = random_indices[int(subset_size * test_part):]

    if not target_image_flag:
        train_subset = Subset(dataset, train_indices + [len(dataset) - 1])
        return train_subset, test_subset

    train_subset = Subset(dataset, train_indices)
    target_image = Subset(dataset, random_indices[-1:])
    target_index = random_indices[-1]
    targets = dataset.targets
    if isinstance(targets, list):
        target_label = targets[target_index]
    elif isinstance(targets, (Tensor, np.ndarray)):
        # ``.item()`` turns tensor and numpy scalars into plain Python numbers.
        target_label = targets[target_index].item()
    else:
        raise TypeError(
            f"Unsupported type for dataset.targets: {type(targets).__name__}; "
            f"expected list, torch.Tensor or numpy.ndarray"
        )
    return train_subset, test_subset, target_image, target_label


def show_image(image_tensor: Tensor) -> None:
    """
    Display one image tensor with matplotlib.

    :param image_tensor: Image whose dimensions, after a leading dimension of
        size 1 is dropped, are permuted with ``(1, 2, 0)`` for plotting.
    """
    # detach()/cpu() let tensors that require grad or live on an accelerator
    # be converted; both are no-ops for plain CPU tensors.
    image_np = image_tensor.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()
    plt.imshow(image_np)
    plt.axis('off')
    plt.show()


def show_images_diff(target_image: Tensor, perturbed_image: Tensor,
                     mean_dataset: List[float], std_dataset: List[float]) -> None:
    """
    Plot the target image, the perturbed image, their difference and a
    highlighted copy of the target in a 2x2 figure.

    Both images are de-normalised with ``img * std + mean`` first. Only
    positive differences (perturbed brighter than target) are visualised: the
    difference panel clips negative values to 0 and both masks test for
    values above a threshold.

    :param target_image: Normalised target image; all singleton dimensions
        are squeezed and the result is transposed with ``(1, 2, 0)``.
    :param perturbed_image: Normalised perturbed image, handled the same way.
    :param mean_dataset: Per-channel normalisation mean.
    :param std_dataset: Per-channel normalisation std.
    """
    mean = np.array(mean_dataset)
    std = np.array(std_dataset)

    def to_display_array(image: Tensor) -> ndarray:
        # cpu() makes accelerator tensors convertible; a no-op on CPU.
        image_np = image.squeeze().detach().cpu().numpy().transpose(1, 2, 0)
        return image_np * std + mean

    target_image_np = to_display_array(target_image)
    perturbed_image_np = to_display_array(perturbed_image)

    diff_image = perturbed_image_np - target_image_np
    # Largest difference over the last (channel) axis for every pixel.
    channel_max_diff = np.max(diff_image, axis=2)
    mask = channel_max_diff > diff_image.mean()
    mask_full = channel_max_diff > 0

    highlighted_image = target_image_np.copy()
    highlighted_image[mask_full] = [1, 0, 1]  # any positive change
    highlighted_image[mask] = [1, 0, 0]  # change above the mean difference

    fig, axs = plt.subplots(2, 2, figsize=(10, 10))
    panels = (
        (axs[0, 0], target_image_np, 'Target Image'),
        (axs[0, 1], perturbed_image_np, 'Perturbed Image'),
        (axs[1, 0], np.clip(diff_image, 0, 1), 'Difference Image'),
        (axs[1, 1], highlighted_image, 'Highlighted Target Image'),
    )
    for ax, panel_image, title in panels:
        ax.imshow(panel_image, cmap='gray', vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis('off')

    # Annotate every third pixel of the difference panel that is above threshold.
    for i in range(0, diff_image.shape[0], 3):
        for j in range(0, diff_image.shape[1], 3):
            if mask[i, j]:
                axs[1, 0].text(j, i, f'{channel_max_diff[i, j]:.2e}', color='red', fontsize=6,
                               ha='center', va='center')

    plt.tight_layout()
    plt.show()


def span_project(train_dataset: Subset, adversarial_examples: Tensor,
                 n_components: int = None, batch_size: int = 256):
    """
    Project adversarial examples through a PCA fitted on the training data.

    :param train_dataset: Dataset whose flattened images are used to fit the PCA.
    :param adversarial_examples: Batch of images to project and reconstruct.
    :param n_components: Number of PCA components.
    :param batch_size: Batch size used while reading ``train_dataset``.
    :return: Reconstructed images with the shape of ``adversarial_examples``.
    """
    fitted_pca = perform_pca(
        extract_features(train_dataset, batch_size=batch_size),
        n_components=n_components,
    )
    return project_onto_pca(fitted_pca, adversarial_examples)


def extract_features(dataset: Subset, batch_size: int = 256):
    """
    Flatten every image of ``dataset`` into one row of a 2-D numpy array.

    :param dataset: Dataset yielding ``(image, label)`` pairs; the labels are
        ignored.
    :param batch_size: Number of images read per batch.
    :return: Array with one flattened image per row, in dataset order.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    flattened_batches = [
        batch.view(batch.size(0), -1).numpy()
        for batch, _ in loader
    ]
    return np.vstack(flattened_batches)


def perform_pca(features: np.ndarray, n_components: int = None):
    """
    Fit a PCA model on ``features``.

    :param features: 2-D array with one sample per row.
    :param n_components: Number of components handed to ``PCA``.
    :return: The fitted ``PCA`` instance.
    """
    # ``fit`` returns the estimator itself, so the calls can be chained.
    return PCA(n_components=n_components).fit(features)


def project_onto_pca(pca, adversarial_examples: Tensor):
    """
    Project images onto the PCA subspace and reconstruct them.

    :param pca: Fitted object providing ``transform`` and
        ``inverse_transform`` on 2-D numpy arrays.
    :param adversarial_examples: Batch of images; the first dimension is the
        batch.
    :return: Reconstructed images with the same shape, dtype and device as
        ``adversarial_examples`` and no autograd history.
    """
    # detach()/cpu() make the numpy conversion work for tensors that require
    # grad or live on an accelerator; reshape also accepts non-contiguous input.
    flat_examples = (
        adversarial_examples.detach().cpu()
        .reshape(adversarial_examples.size(0), -1)
        .numpy()
    )
    reconstructed = pca.inverse_transform(pca.transform(flat_examples))
    # Cast back to the input's dtype/device (numpy may have promoted to
    # float64) so the result can be fed to the same model as the input.
    # torch.tensor copies, so the result never aliases the numpy buffer.
    return torch.tensor(
        reconstructed,
        dtype=adversarial_examples.dtype,
        device=adversarial_examples.device,
    ).reshape(adversarial_examples.size())


def calculate_l2_distance(image1: Tensor, image2: Tensor) -> Tensor:
    """
    Return the Euclidean (L2) distance between two tensors.

    :param image1: First tensor.
    :param image2: Second tensor; must broadcast against ``image1``.
    :return: 0-dimensional tensor holding the square root of the summed
        squared differences.
    """
    return torch.sqrt(torch.sum((image1 - image2) ** 2))


def image_clamp_min_max_val(mean_dataset: List[float],
                            std_dataset: List[float]) -> "tuple[Tensor, Tensor]":
    """
    Compute the values that raw pixel intensities 0 and 1 take after
    normalisation with the given mean and std.

    :param mean_dataset: Per-channel normalisation mean.
    :param std_dataset: Per-channel normalisation std.
    :return: ``(min_val, max_val)`` tensors of shape ``(channels, 1, 1)``,
        ready to broadcast against ``(..., channels, H, W)`` images.
    """
    # view(-1, 1, 1) takes the channel count from the input instead of
    # hard-coding 3, so single-channel data works as well.
    mean = torch.tensor(mean_dataset).view(-1, 1, 1)
    std = torch.tensor(std_dataset).view(-1, 1, 1)
    min_val = (0 - mean) / std
    max_val = (1 - mean) / std
    return min_val, max_val


def fgsm_attack(image: Tensor, epsilon: float, data_grad: Tensor,
                min_pixel_val: Tensor, max_pixel_val: Tensor) -> Tensor:
    """
    Take one signed-gradient step of size ``epsilon`` and clamp the result.

    :param image: Image to perturb.
    :param epsilon: Step size applied to the sign of the gradient.
    :param data_grad: Gradient of the loss with respect to the image.
    :param min_pixel_val: Lower clamp bound.
    :param max_pixel_val: Upper clamp bound.
    :return: The perturbed image, clamped to the two bounds.
    """
    step = epsilon * data_grad.sign()
    return torch.clamp(image + step, min_pixel_val, max_pixel_val)


def pgd_attack(model, image: Tensor, label, epsilon: float, alpha: float,
               num_iter: int, min_pixel_val: Tensor, max_pixel_val: Tensor,
               ) -> Tensor:
    """
    Iterative signed-gradient attack that increases the cross-entropy loss
    for ``label``.

    :param model: The neural network model.
    :param image: The input image; the perturbation stays within ``epsilon``
        of it.
    :param label: Label tensor the loss is computed against.
    :param epsilon: Maximum perturbation per element.
    :param alpha: Step size of every iteration.
    :param num_iter: The number of iterations.
    :param min_pixel_val: Lower clamp bound for the result.
    :param max_pixel_val: Upper clamp bound for the result.
    :return: The adversarial example, a detached tensor with
        ``requires_grad`` enabled.
    """
    criterion = nn.CrossEntropyLoss()
    adversarial = image.clone().detach().requires_grad_(True)
    for _ in range(num_iter):
        model.zero_grad()
        criterion(model(adversarial), label.view(-1)).backward()

        # Ascend along the sign of the input gradient.
        stepped = adversarial + alpha * adversarial.grad.data.sign()
        # Project back into the epsilon-ball around the starting image ...
        stepped = torch.min(stepped, image + epsilon)
        stepped = torch.max(stepped, image - epsilon)
        # ... and into the valid pixel range.
        stepped = torch.clamp(stepped, min_pixel_val, max_pixel_val)

        # Start the next iteration from a fresh leaf tensor.
        adversarial = stepped.detach().requires_grad_(True)
    return adversarial


@log_prints()
def attack_vary(
        model, target_image: torch.Tensor, target_label: int, train_dataset: Subset,
        mean_dataset: List[float], std_dataset: List[float],
        epsilon_min_bound: float = 0.01, epsilon_max_bound: float = 0.1,
        n_attack_samples: int = None, attack_type: str = 'pgd', n_components: int = 100,
        batch_size: int = 256,
        span_project_flag: bool = True, **kwargs
) -> "tuple[TensorDataset, torch.Tensor, torch.Tensor]":
    """
    Generate adversarial variants of one target image over a range of
    perturbation budgets.

    ``n_attack_samples`` candidates are created from slightly noised copies of
    the target image, each with its own epsilon taken from an even grid
    between the two bounds. A candidate is "successful" when the model no
    longer predicts ``target_label`` for it. With ``span_project_flag`` all
    candidates are also reconstructed through a PCA fitted on
    ``train_dataset`` and the successful reconstructions are preferred.
    If nothing succeeds, the attack is retried with both bounds doubled.

    :param model: Model under attack; it is switched to eval mode.
    :param target_image: Image to attack: a 3-D or 4-D tensor, or a Subset
        whose first item holds the image.
    :param target_label: Label the model is expected to predict.
    :param train_dataset: Data used to fit the PCA for the projection.
    :param mean_dataset: Per-channel normalisation mean.
    :param std_dataset: Per-channel normalisation std.
    :param epsilon_min_bound: Smallest perturbation budget.
    :param epsilon_max_bound: Largest perturbation budget.
    :param n_attack_samples: Number of candidates; ``None`` means 100.
    :param attack_type: ``'pgd'`` or ``'fgsm'``.
    :param n_components: Number of PCA components.
    :param batch_size: Batch size used while reading ``train_dataset``.
    :param span_project_flag: Whether to add PCA-projected candidates.
    :param kwargs: Optional ``alpha`` (PGD step size, default 0.01) and
        ``num_iter`` (PGD iterations, default 10). A single ``kwargs={...}``
        dict, as passed by older callers, is accepted and unpacked.
    :return: ``(adv_dataset, images_for_train, distances)``: the successful
        samples labelled with ``target_label``, all candidates (projected
        ones first when projection is enabled), and the L2 distance of every
        successful sample to the target image.
    :raises ValueError: If ``attack_type`` is unknown.
    :raises RuntimeError: If no adversarial sample exists even once the
        budget covers the whole pixel range.
    """
    # Callers that pass ``kwargs=kwargs`` wrap the real options in a dict
    # under the key "kwargs" (possibly several levels deep); unwrap it so
    # 'alpha' and 'num_iter' are actually honoured.  Explicit keys win.
    while isinstance(kwargs.get('kwargs'), dict):
        nested_options = kwargs.pop('kwargs')
        kwargs = {**nested_options, **kwargs}
    if n_attack_samples is None:
        # torch.linspace needs an integer step count; use the default of the callers.
        n_attack_samples = 100
    if attack_type not in ('fgsm', 'pgd'):
        raise ValueError(f"Unknown attack type: {attack_type}")

    model.eval()
    if isinstance(target_image, Subset):
        target_image = target_image[0][0]
    if target_image.ndim == 3:
        target_image = target_image.unsqueeze(0)
    device = next(model.parameters()).device
    target_label_tensor = torch.tensor([target_label], dtype=torch.long, device=device)
    target_label_value = target_label_tensor.item()
    # Detached copy: the caller's tensor is never switched to requires_grad
    # and no gradient accumulates on it across retries.
    target_image = target_image.detach().to(device)

    # Input gradient at the clean image w.r.t. the model's own prediction
    # (consumed by FGSM only).
    grad_probe = target_image.clone().requires_grad_(True)
    output = model(grad_probe)
    init_pred = output.max(1, keepdim=True)[1]
    loss = nn.CrossEntropyLoss()(output, init_pred.view(-1))
    model.zero_grad()
    loss.backward()
    data_grad = grad_probe.grad.detach()

    epsilon_values = torch.linspace(epsilon_min_bound, epsilon_max_bound, n_attack_samples).tolist()
    alpha = kwargs.get('alpha', 0.01)
    alpha_values = torch.linspace(alpha, alpha * 1.01, n_attack_samples).tolist()
    num_iter = kwargs.get('num_iter', 10)

    min_pixel_val, max_pixel_val = image_clamp_min_max_val(mean_dataset, std_dataset)
    # The bounds are created on the CPU; clamp needs them next to the images.
    min_pixel_val = min_pixel_val.to(device)
    max_pixel_val = max_pixel_val.to(device)

    adv_samples = []
    successful_adv_samples = []
    successful_adv_samples_distance = []
    for i in tqdm(range(n_attack_samples)):
        # Small uniform noise so every candidate starts from a different point.
        noisy_image = target_image + (torch.rand_like(target_image) - 0.5) * 0.01
        noisy_image = torch.clamp(noisy_image, min_pixel_val, max_pixel_val)

        if attack_type == 'fgsm':
            perturbed_image = fgsm_attack(noisy_image, epsilon_values[i], data_grad,
                                          min_pixel_val, max_pixel_val)
        else:
            perturbed_image = pgd_attack(model, noisy_image, target_label_tensor,
                                         epsilon_values[i], alpha_values[i], num_iter,
                                         min_pixel_val, max_pixel_val)
        perturbed_image = perturbed_image.detach()

        with torch.no_grad():  # evaluation only, no autograd graph needed
            final_pred = model(perturbed_image).max(1, keepdim=True)[1]

        # Drop only the batch dimension so a singleton channel survives.
        sample = perturbed_image.squeeze(0)
        adv_samples.append(sample)
        if final_pred.item() != target_label_value:
            successful_adv_samples.append(sample)
            successful_adv_samples_distance.append(
                calculate_l2_distance(sample, target_image).item()
            )
    print(f"Successful adv samples: {len(successful_adv_samples)}")

    adversarial_images_for_train = torch.stack(adv_samples)
    images_for_train = adversarial_images_for_train
    if span_project_flag:
        print(f"Start project adv images on span")
        # The PCA runs on numpy, so hand over CPU data and move the result back.
        projected_adversarial_images = span_project(
            train_dataset, adversarial_examples=adversarial_images_for_train.cpu(),
            n_components=n_components, batch_size=batch_size,
        ).to(device=device, dtype=adversarial_images_for_train.dtype)
        print(f"Successful project adv images on span")
        images_for_train = torch.cat(
            tensors=[projected_adversarial_images, adversarial_images_for_train],
            dim=0
        )

        successful_projected = []
        successful_projected_distance = []
        for image in projected_adversarial_images:
            with torch.no_grad():
                final_pred = model(image.unsqueeze(0)).max(1, keepdim=True)[1]
            if final_pred.item() != target_label_value:
                successful_projected.append(image)
                successful_projected_distance.append(
                    calculate_l2_distance(image, target_image).item()
                )
        print(f"Successful project adv samples: {len(successful_projected)}")

        if successful_projected:
            # Projected successes take precedence over the raw ones; return
            # right away so a pointless retry is never started.
            adversarial_images_project = torch.stack(successful_projected)
            labels = torch.full((adversarial_images_project.size(0),), target_label, dtype=torch.long)
            return (
                TensorDataset(adversarial_images_project, labels),
                images_for_train,
                torch.tensor(successful_projected_distance),
            )

    if successful_adv_samples:
        adversarial_images = torch.stack(successful_adv_samples)
        # Placeholder labels: every sample carries the target label.
        labels = torch.full((adversarial_images.size(0),), target_label, dtype=torch.long)
        return (
            TensorDataset(adversarial_images, labels),
            images_for_train,
            torch.tensor(successful_adv_samples_distance),
        )

    # Nothing fooled the model: retry with twice the budget.  Once the budget
    # spans the whole valid pixel range (or is zero), doubling cannot produce
    # different images any more, so stop instead of recursing forever.
    pixel_range = (max_pixel_val - min_pixel_val).max().item()
    largest_epsilon = max(abs(epsilon_min_bound), abs(epsilon_max_bound))
    if largest_epsilon == 0 or largest_epsilon >= pixel_range:
        raise RuntimeError(
            f"No adversarial sample found with epsilon in "
            f"[{epsilon_min_bound}, {epsilon_max_bound}]; "
            f"a larger perturbation budget cannot change the result."
        )
    # Return the retry's result as a whole so the dataset, the training images
    # and the distances all belong to the same (successful) attempt.
    return attack_vary(
        model=model, target_image=target_image, target_label=target_label,
        mean_dataset=mean_dataset, std_dataset=std_dataset, train_dataset=train_dataset,
        epsilon_min_bound=epsilon_min_bound * 2, epsilon_max_bound=epsilon_max_bound * 2,
        n_attack_samples=n_attack_samples, attack_type=attack_type,
        span_project_flag=span_project_flag,
        batch_size=batch_size,
        n_components=n_components,
        **kwargs,
    )
