import math
from pathlib import Path

import torch
from beam.distributed import RayDispatcher
from torch.nn import Linear, Sequential, SmoothL1Loss
from torchvision.transforms.v2.functional import to_tensor

from algorithms.convergence_algorithms.hegl import HEGL
from algorithms.convergence_algorithms.utils import GOAL_IS_REACHED_STOPPING_CONDITION
from algorithms.mapping.trust_region import TanhTrustRegion
from algorithms.mapping.value_normalizers import AdaptedOutputUnconstrainedMapping
from algorithms.nn.datasets import PairsInEpsRangeDataset
from algorithms.nn.distributions import QuantileWeights
from algorithms.nn.modules import BigLinearNetwork, BaseSequentialModel
from applications.classifier_server import ClassifierModelSampler
from applications.code_space import BlackBoxClassifierSpace
from applications.common import IMG_MODELS
from applications.datasets_transform import dataset_from_datasets_package, images_from_folder
from applications.processor import RandomAdditionProcessor, ModelProcessorWrapper
from applications.stop_conditions import ClassificationInTop
from handlers.drawer_handlers import LoggerDrawerHandler
from handlers.drawers.loss_drawer import StepSizeDrawer
from handlers.drawers.utils import convert_to_real_drawer
from utils.logger import create_logger
from utils.python import timestamp_file_signature

def _load_imagenet(size, **kwargs):
    """Load the "ILSVRC/imagenet-1k" dataset; ``size`` is cast to ``int``.

    Extra keyword arguments are accepted and ignored so every loader can be
    called with the same parameter mapping.
    """
    return dataset_from_datasets_package("ILSVRC/imagenet-1k", int(size))


def _load_image_folder(path, **kwargs):
    """Load images from the folder at ``path``; extra keywords are ignored."""
    return images_from_folder(path)


# Maps a dataset name to the loader that builds it from keyword parameters.
DATASETS = {
    "imagenet": _load_imagenet,
    "mine": _load_image_folder,
}


class ClassifierBatchSampler:
    """Callable wrapper that feeds large image batches to a classifier in chunks.

    Inputs with fewer than ``max_batch_size`` elements, or with fewer than
    four dimensions, are forwarded to the classifier unchanged. Larger
    4-D inputs are split along the first dimension so that each chunk holds
    at most ``max_batch_size`` elements (but always at least one image), and
    the per-chunk outputs are concatenated with ``torch.cat``.
    """

    def __init__(self, classifier, max_batch_size: int = 122400000):
        """Store the wrapped classifier and the element budget per call.

        Args:
            classifier: Callable applied to an image tensor; in the chunked
                path its outputs must be tensors that ``torch.cat`` can join.
            max_batch_size: Maximum number of tensor elements (not images)
                sent to ``classifier`` in a single chunked call.
        """
        self.classifier = classifier
        self.max_batch_size = max_batch_size

    def __call__(self, images):
        """Classify ``images``, chunking the batch when it exceeds the budget.

        Args:
            images: Tensor of images; chunking applies only to 4-D tensors
                that unpack as (batch, channels, height, width).

        Returns:
            The classifier output for the whole input; in the chunked path
            this is the concatenation of the per-chunk outputs.
        """
        if images.numel() < self.max_batch_size or len(images.shape) < 4:
            return self.classifier(images)
        _, channels, height, width = images.shape
        image_size = channels * height * width
        # A single image may exceed the budget, which would make the floor
        # division 0 and torch.split raise; fall back to one image per chunk.
        max_samples_in_batch = max(1, self.max_batch_size // image_size)
        return torch.cat(
            [
                # Each chunk is copied to CPU memory before classification.
                self.classifier(image_batch.cpu().clone())
                for image_batch in torch.split(images, max_samples_in_batch)
            ]
        )


def run_attack(
    classifier_bbo,
    image,
    window_height,
    window_width,
    channels,
    num_of_windows,
    top_to_check,
    desired_classification=None,
    true_classification=None,
):
    """Run one HEGL optimisation against a black-box classifier for one image.

    Builds a ``BlackBoxClassifierSpace`` around ``classifier_bbo`` and
    ``image``, a single bias-free linear layer as the trainable model, and a
    ``BigLinearNetwork`` helper, then trains them with ``HEGL`` until
    training returns or raises. Any exception raised by training is logged
    and swallowed, so the function always returns ``None``.

    Args:
        classifier_bbo: Classifier handed to ``BlackBoxClassifierSpace``.
        image: Image tensor; moved to device 0 before use.
        window_height: Forwarded to ``RandomAdditionProcessor``.
        window_width: Forwarded to ``RandomAdditionProcessor``.
        channels: Forwarded to ``RandomAdditionProcessor``.
        num_of_windows: Forwarded to ``RandomAdditionProcessor``; also the
            first ``2 * num_of_windows`` model parameters are initialised
            uniformly in [-0.5, 0.5).
        top_to_check: Forwarded to ``ClassificationInTop``.
        desired_classification: Target class for a targeted attack, or
            ``None`` for an untargeted one. Any non-``None`` value (including
            class index ``0``) selects the targeted mode.
        true_classification: Class used by the stop condition when no
            desired classification is given.
    """
    algorithm_name = "evasion_attack_experiment"
    run_name = "normal"
    base_dir = Path(__file__).parent.parent
    normal_logs_path = (
        base_dir
        / "app_logs"
        / algorithm_name
        / f"logs_for_parallel-{algorithm_name}-{run_name}-{timestamp_file_signature()}"
    )
    logger = create_logger(normal_logs_path, None, run_name, algorithm_name, None)

    device = 0
    dtype = torch.float32
    image = image.to(device=device)
    add_noise_processor = RandomAdditionProcessor(
        image, window_height, window_width, channels, num_of_windows
    )
    dims = add_noise_processor.dims

    # Compare against None rather than relying on truthiness: class index 0
    # is falsy but is still a legitimate target classification.
    is_targeted = desired_classification is not None
    reference_classification = (
        desired_classification if is_targeted else true_classification
    )

    # NOTE: classifier_bbo is passed directly; wrapping it in
    # ClassifierBatchSampler is currently disabled.
    space = BlackBoxClassifierSpace(
        classifier_bbo,
        image,
        add_noise_processor,
        upper_bound=torch.ones(dims) * 10,
        lower_bound=-torch.ones(dims) * 10,
        logger=logger,
        budget=150_000,
        desired_classification=desired_classification,
        correct_classifications=true_classification,
        stop_condition=ClassificationInTop(
            top_to_check,
            reference_classification,
            is_targeted,
        ),
    )
    model_to_train = BaseSequentialModel(
        Sequential(Linear(dims, 1, bias=False, dtype=dtype)).to(device=device)
    )
    # Randomise the first 2 * num_of_windows parameters of the starting point.
    curr_point = model_to_train.model_parameter_tensor()
    curr_point[: num_of_windows * 2] = torch.rand(num_of_windows * 2) - 0.5
    model_to_train.from_parameter_tensor(curr_point)

    grad_network = BigLinearNetwork(dims, [dims // 2, dims // 2, dims], device).to(
        dtype=dtype
    )
    grad_opt = torch.optim.Adam(
        grad_network.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-04
    )
    model_opt = torch.optim.Adam(
        model_to_train.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-04
    )

    egl = HEGL(
        space,
        helper_network=grad_network,
        model_to_train=model_to_train,
        value_optimizer=grad_opt,
        model_to_train_optimizer=model_opt,
        epsilon=0.6 * math.sqrt(dims),
        epsilon_factor=0.97,
        min_epsilon=1e-4,
        perturb=0,
        grad_loss=SmoothL1Loss(),
        database_type=PairsInEpsRangeDataset,
        database_size=100_000,
        input_mapping=TanhTrustRegion(
            space.upper_bound,
            space.lower_bound,
            min_trust_region_size=0,
            dtype=dtype,
        ),
        output_mapping=AdaptedOutputUnconstrainedMapping(output_epsilon=5e-4),
        weights_creator=QuantileWeights(),
        train_quantile=70,
        dtype=dtype,
        device=device,
        logger=logger,
    )

    try:
        egl.train(
            epochs=10_000,
            exploration_size=128,
            num_loop_without_improvement=20,
            min_iteration_before_shrink=60,
            helper_model_training_epochs=1,
            callback_handlers=[
                LoggerDrawerHandler(
                    convert_to_real_drawer(StepSizeDrawer()),
                    logger=logger,
                    name=f"step size {space}",
                ),
            ],
            stopping_conditions=[GOAL_IS_REACHED_STOPPING_CONDITION],
        )
    except Exception as e:
        # Best-effort: any exception from training ends the attack and is
        # only logged. NOTE(review): presumably budget exhaustion surfaces
        # here as an exception - confirm.
        logger.error(f"Finish attack: {e}")


def attack(
    server_classifier: str,
    concurrency: int,
    height: int,
    width: int,
    dataset_name: str,
    num_of_windows: int,
    top_check,
    classification,
    d_param,
):
    """Dispatch one ``run_attack`` task per dataset sample through Ray.

    Loads the pretrained model and processor registered under
    ``server_classifier`` in ``IMG_MODELS``, builds the dataset named
    ``dataset_name``, and submits ``run_attack`` for every sample with three
    channels. The ``.value`` of every result is printed before returning.

    Args:
        server_classifier: Key into ``IMG_MODELS``.
        concurrency: Each dispatched task requests ``1 / concurrency`` GPUs.
        height: Passed to ``run_attack`` as ``window_height``.
        width: Passed to ``run_attack`` as ``window_width``.
        dataset_name: Key into ``DATASETS``.
        num_of_windows: Passed through to ``run_attack``.
        top_check: Passed to ``run_attack`` as ``top_to_check``.
        classification: Desired classification for a targeted attack, or
            ``None`` to run untargeted using each sample's own label.
        d_param: Converted with ``dict(...)`` and expanded as keyword
            arguments for the dataset loader.

    Returns:
        The list of objects returned by the dispatched calls, in dataset
        order.
    """
    # NOTE: a remote classifier resource
    # (f"beam-http://{server_classifier}:{ATTACK_PORT}") was used here
    # previously; the model is now loaded locally onto device 0.
    model_card = IMG_MODELS[server_classifier]
    classifier_sim = ClassifierModelSampler(
        model_card.model_class.from_pretrained(model_card.model_name).to(0),
        ModelProcessorWrapper(
            model_card.processor_class.from_pretrained(model_card.model_name)
        ),
        model_card.post_processor,
    )
    dataset = DATASETS[dataset_name](**dict(d_param))
    workers = [
        RayDispatcher(run_attack, remote_kwargs={"num_gpus": 1 / concurrency})
        for _ in range(len(dataset))
    ]
    # Only untargeted runs need the sample's own label; use an explicit None
    # check so that a desired class index of 0 still counts as targeted.
    is_untargeted = classification is None
    results = []
    for i, worker in enumerate(workers):
        sample = dataset[i]  # index once: sample[0] is the image, sample[1] its label
        results.append(
            worker(
                classifier_sim,
                to_tensor(sample[0]),
                height,
                width,
                3,
                num_of_windows,
                top_check,
                classification,
                sample[1] if is_untargeted else None,
            )
        )
    print([res.value for res in results])
    return results
