import math
from pathlib import Path
from uuid import uuid4

import click
import torch
from PIL import Image
from torch.nn import SmoothL1Loss, Linear, Sequential, Module
from torchvision.transforms import ToTensor, Resize

import applications
from algorithms.convergence_algorithms.egl_scheduler import EGLScheduler
from algorithms.convergence_algorithms.hegl import HEGL
from algorithms.convergence_algorithms.utils import GOAL_IS_REACHED_STOPPING_CONDITION
from algorithms.mapping.trust_region import TanhTrustRegion
from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
from algorithms.nn.datasets import PairsInEpsRangeDataset
from algorithms.nn.distributions import QuantileWeights
from algorithms.nn.modules import BaseSequentialModel, BigLinearNetwork
from applications.code_space import DiscriminatorSpace
from applications.common import IMG_MODELS
from applications.losses import (
    SplitLogLoss,
    LossManipulation,
    SigmoidScaleLogExpLoss,
    PolynomScaleSplitLoss,
    ZeroLoss,
    MultiplierLoss,
)
from applications.manipulate import NormalizeWindowsSoftmax
from applications.options import (
    MODEL_NAME_OPTION,
    HEIGHT_OPTION,
    WIDTH_OPTION,
    IMAGE_NAME_OPTION,
    NUM_OF_WINDOWS_OPTION,
)
from applications.processor import ModelProcessorWrapper
from applications.saver import (
    GenerateImage,
    SaverCallbackHandler,
    AccuracyPlotSaver,
    ImageFalseNegativeCalculator,
    MSEFidelityCalculator,
)
from applications.stop_conditions import ClassificationInTop
from handlers.drawer_handlers import LoggerDrawerHandler
from handlers.drawers.loss_drawer import StepSizeDrawer
from handlers.drawers.utils import convert_to_real_drawer
from run_options import DEVICE_OPTION
from utils.dynamically_load_class import find_class_by_name
from utils.logger import create_logger
from utils.python import timestamp_file_signature


def evasion_attack(
    model: Module,
    processor,
    post_processor,
    image: Image.Image,
    num_of_windows: int,
    window_height: int,
    window_width: int,
    class_number: int,
    processor_class,
    algorithm_class,
    device: int,
):
    """Run a window-based evasion attack against ``model`` on one image.

    The image is resized to 150x150 and turned into a batch of one on
    ``device``. ``processor_class`` then wraps that image into a perturbation
    processor whose ``dims`` sets the size of the searched parameter vector.
    The vector is the flattened weight of a single bias-free ``Linear(dims, 1)``
    layer. It is optimized by ``algorithm_class`` (EGL-style, with a helper
    gradient network) over a ``DiscriminatorSpace`` bounded to ``[-1, 1]``
    per coordinate.

    Generated images and accuracy/fidelity plots are saved under
    ``<repo>/saves/<signature>-<uuid>``. Logs go to ``<repo>/app_logs/``.

    Args:
        model: Classifier under attack; it is switched to ``eval()`` mode.
        processor: Image processor callable as
            ``processor(images=..., return_tensors="pt")``. It is used here
            only to find the channel count, and it is also wrapped in
            ``ModelProcessorWrapper`` for the space.
        post_processor: Passed through to the space and the accuracy
            calculator.
        image: Source PIL image.
        num_of_windows: Number of perturbation windows.
        window_height: Height of each window.
        window_width: Width of each window.
        class_number: Class index handed to the loss, the stop condition and
            the false-negative calculator.
        processor_class: Factory for the perturbation processor. It is called
            as ``processor_class(image, h, w, c, num_of_windows=...)``.
        algorithm_class: Optimizer class, e.g. ``EGLScheduler`` or ``HEGL``.
        device: Torch device index. ``None`` is accepted too (``main`` passes
            ``None`` for ``-1``).
    """
    dtype = torch.float32
    resize_image = Resize((150, 150))
    image = resize_image(ToTensor()(image)).unsqueeze(0)
    image = image.to(device=device)
    # The processor output is only inspected for its channel dimension.
    image_example = processor(images=image, return_tensors="pt").pixel_values
    model.eval()
    c = image_example.shape[-3]
    add_noise_processor = processor_class(
        image,
        window_height,
        window_width,
        c,
        num_of_windows=num_of_windows,
        # window_manipulator=NormalizeWindowsSoftmax(),
    )
    model_processor = ModelProcessorWrapper(processor)
    dims = add_noise_processor.dims

    move_toward_classification = False
    # The weight vector of this single linear layer *is* the search point.
    model_to_train = BaseSequentialModel(
        Sequential(Linear(dims, 1, bias=False, dtype=dtype)).to(device=device)
    )
    curr_point = model_to_train.model_parameter_tensor()
    # First 2 * num_of_windows entries start in [-1.5, -0.5) and the rest in
    # [0, 1). NOTE(review): presumably the first entries are per-window
    # placement parameters read by processor_class — TODO confirm.
    curr_point[: num_of_windows * 2] = torch.rand(num_of_windows * 2) - 1.5
    curr_point[num_of_windows * 2 :] = torch.rand_like(curr_point[num_of_windows * 2 :])
    model_to_train.from_parameter_tensor(curr_point)

    grad_network = BigLinearNetwork(dims, [dims // 2, dims // 2, dims], device).to(
        dtype=dtype
    )
    grad_opt = torch.optim.Adam(
        grad_network.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-04
    )
    model_opt = torch.optim.Adam(
        model_to_train.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-04
    )

    base_dir = Path(__file__).parent.parent
    algorithm_name = "evasion_attack"
    run_name = "normal"
    # Computed once so the save directory and the log directory of this run
    # carry the same signature.
    run_signature = timestamp_file_signature()
    results_save_path = base_dir / "saves" / f"{run_signature}-{uuid4()}"
    normal_logs_path = (
        base_dir
        / "app_logs"
        / rf"logs_for_parallel-{algorithm_name}-{run_name}-{run_signature}"
    )

    logger = create_logger(normal_logs_path, None, run_name, algorithm_name, None)
    space = DiscriminatorSpace(
        model,
        original_image=image,
        # Per-sample MSE: averaged over every non-batch dimension.
        distance_calculator=lambda orig, new: torch.nn.functional.mse_loss(
            orig, new, reduction="none"
        ).mean(dim=(1, 2, 3)),
        processor=add_noise_processor,
        model_processor=model_processor,
        post_processor=post_processor,
        upper_bound=torch.ones(dims),  # * 10,
        lower_bound=-torch.ones(dims),  # * 10,
        logger=logger,
        budget=150_000,
        loss=LossManipulation(
            # PolynomScaleSplitLoss(n2=-1, b=8),
            MultiplierLoss(30),
            MultiplierLoss(2),
            # SigmoidScaleLogExpLoss(log_factor=0.4, negative_loss=True),
            class_number,
        ),
        stop_condition=ClassificationInTop(
            2, class_number, move_toward_classification, image, 0.017
        ),
    )

    egl = algorithm_class(
        space,
        helper_network=grad_network,
        model_to_train=model_to_train,
        value_optimizer=grad_opt,
        model_to_train_optimizer=model_opt,
        epsilon=0.1 * math.sqrt(dims),
        epsilon_factor=0.97,
        min_epsilon=1e-4,
        perturb=0,
        grad_loss=SmoothL1Loss(),
        database_type=PairsInEpsRangeDataset,
        database_size=100_000,
        input_mapping=TanhTrustRegion(
            space.upper_bound,
            space.lower_bound,
            min_trust_region_size=0,
            dtype=dtype,
        ),
        output_mapping=AdaptedOutputUnconstrainedMapping(output_epsilon=5e-4),
        weights_creator=QuantileWeights(),
        train_quantile=70,
        dtype=dtype,
        device=device,
        logger=logger,
    )

    # Log before training: previously this line ran only after train() returned.
    logger.info(f"starting with {egl}")
    egl.train(
        epochs=1000,
        exploration_size=128,
        num_loop_without_improvement=8,
        min_iteration_before_shrink=15,
        helper_model_training_epochs=1,
        stopping_conditions=[GOAL_IS_REACHED_STOPPING_CONDITION],
        callback_handlers=[
            LoggerDrawerHandler(
                convert_to_real_drawer(StepSizeDrawer()),
                logger=logger,
                name=f"step size {space}",
            ),
            SaverCallbackHandler(
                results_save_path, GenerateImage(add_noise_processor), ending="png"
            ),
            AccuracyPlotSaver(
                30,
                ImageFalseNegativeCalculator(
                    model,
                    add_noise_processor,
                    model_processor,
                    post_processor,
                    class_number,
                ),
                MSEFidelityCalculator(image, add_noise_processor),
                results_save_path,
            ),
        ],
    )


@click.command
@MODEL_NAME_OPTION
@IMAGE_NAME_OPTION
@click.option("--class_number", type=int, default=None)
@NUM_OF_WINDOWS_OPTION
@HEIGHT_OPTION
@WIDTH_OPTION
@click.option("--processor_name", type=str, default="RandomAdditionProcessor")
@click.option("--use_hegl", is_flag=True, type=bool, default=False)
# @SET_OPTION
@DEVICE_OPTION
def main(
    model_name: str,
    image_name: str,
    class_number: int,
    height: int,
    width: int,
    num_of_windows: int,
    processor_name: str,
    # setn: List[Tuple[str, str]],
    use_hegl: bool,
    device: int,
):
    """CLI entry point: load a pretrained image model and run ``evasion_attack``.

    ``model_name`` selects an entry of ``IMG_MODELS``. Its processor and
    model are loaded with ``from_pretrained``. The image is read from
    ``images/<image_name>`` relative to the current working directory.

    When ``--class_number`` is omitted, the class is inferred from the file
    name. The part of the stem before the first ``_`` is looked up in
    ``model.config.label2id``. A ``device`` of ``-1`` is mapped to ``None``.
    ``--use_hegl`` picks ``HEGL`` instead of ``EGLScheduler``.

    Raises:
        click.BadParameter: if the class cannot be inferred from the file name.
    """
    from tqdm import tqdm
    import functools

    # Silence every tqdm progress bar globally by defaulting disable=True.
    tqdm.__init__ = functools.partialmethod(tqdm.__init__, disable=True)

    device = device if device != -1 else None
    model_card = IMG_MODELS[model_name]
    processor = model_card.processor_class.from_pretrained(model_card.model_name)
    model = model_card.model_class.from_pretrained(model_card.model_name).to(
        device=device
    )
    # Same relative path as before (``Path().parent`` is just ``Path(".")``).
    image_path = Path("images") / image_name
    if class_number is None:
        label = image_path.stem.split("_")[0]
        try:
            class_number = model.config.label2id[label]
        except KeyError as err:
            raise click.BadParameter(
                f"cannot infer the class from image name {image_name!r}: "
                f"label {label!r} is not in the model's label2id; "
                "pass --class_number explicitly",
                param_hint="--class_number",
            ) from err
    # Decode eagerly and release the file handle. Forcing RGB keeps RGBA,
    # grayscale and palette images at the 3 channels the pipeline expects.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    processor_class = find_class_by_name(applications, processor_name)
    evasion_attack(
        model,
        processor,
        model_card.post_processor,
        image,
        num_of_windows,
        height,
        width,
        class_number,
        processor_class,
        HEGL if use_hegl else EGLScheduler,
        device,
    )


if __name__ == "__main__":
    # Script entry point: click parses sys.argv and dispatches to ``main``.
    main()
