import torch
from torch import Tensor

from algorithms.space.basic_callable_space import BasicCallableSpace


class LLMPythonGeneratorSpace(BasicCallableSpace):
    """Callable space whose objective feeds points through ``llm`` and scores them.

    The objective exposed by :attr:`callable_env` calls ``llm`` on the input
    tensor (with autograd disabled), passes the result to ``loss_func`` and
    returns that value cast to the input's device and dtype.
    """

    def __init__(self, llm, loss_func, upper_bound, lower_bound, logger, **kwargs):
        """Store the collaborators and forward ``kwargs`` to the base class.

        Args:
            llm: Callable invoked on the input tensor.
            loss_func: Callable invoked on the output of ``llm``; its result
                must support ``.to(device=..., dtype=...)``.
            upper_bound: Value returned by :attr:`upper_bound`.
            lower_bound: Value returned by :attr:`lower_bound`.
            logger: Kept on the instance; not used by this class itself.
            **kwargs: Passed unchanged to ``BasicCallableSpace.__init__``.
        """
        super().__init__(**kwargs)
        self.llm = llm
        self.loss_func = loss_func
        self.my_upper_bound = upper_bound
        self.my_lower_bound = lower_bound
        self.logger = logger

    @property
    def callable_env(self):
        """Return a fresh closure mapping an input tensor to its loss tensor."""

        def evaluate(data: Tensor) -> Tensor:
            # Only the llm call runs with autograd disabled; the loss
            # function is evaluated outside of the no_grad scope.
            with torch.no_grad():
                generated = self.llm(data)
            raw_loss = self.loss_func(generated)
            # Align the result with the caller's tensor placement and precision.
            return raw_loss.to(device=data.device, dtype=data.dtype)

        return evaluate

    @property
    def upper_bound(self) -> Tensor:
        """Upper bound supplied at construction time."""
        return self.my_upper_bound

    @property
    def lower_bound(self) -> Tensor:
        """Lower bound supplied at construction time."""
        return self.my_lower_bound


class DiscriminatorSpace(BasicCallableSpace):
    """Callable space combining a discriminator output with a distance term.

    The objective exposed by :attr:`callable_env` turns the input into images
    via ``processor``, runs them through ``model_processor``,
    ``discriminator`` and ``post_processor`` (all with autograd disabled), and
    then returns ``loss(classification, distance)`` where ``distance`` is
    ``distance_calculator`` applied to the broadcast original image and the
    new images.
    """

    def __init__(
        self,
        discriminator,
        original_image,
        distance_calculator,
        processor,
        model_processor,
        post_processor,
        loss,
        upper_bound,
        lower_bound,
        logger,
        stop_condition=None,
        **kwargs
    ):
        """Store the collaborators and forward ``kwargs`` to the base class.

        Args:
            discriminator: Callable applied to the model-ready images.
            original_image: Tensor expanded to the shape of the new images
                before the distance is computed.
            distance_calculator: Callable ``(expanded_original, new_images)``.
            processor: Callable turning the raw input into images.
            model_processor: Callable preparing images for ``discriminator``.
            post_processor: Callable applied to the discriminator output.
            loss: Callable ``(classification, distance_loss)`` producing the
                objective value.
            upper_bound: Value returned by :attr:`upper_bound`.
            lower_bound: Value returned by :attr:`lower_bound`.
            logger: Kept on the instance; not used by this class itself.
            stop_condition: Optional callable ``(new_images, classification)``
                whose result is stored as the "finished" flag on every
                evaluation.
            **kwargs: Passed unchanged to ``BasicCallableSpace.__init__``.
        """
        super().__init__(**kwargs)
        self.discriminator = discriminator
        self.original_image = original_image
        self.distance_calculator = distance_calculator
        self.model_processor = model_processor
        self.post_processor = post_processor
        self.loss = loss
        self.my_upper_bound = upper_bound
        self.my_lower_bound = lower_bound
        self.logger = logger
        self.processor = processor
        self.stop_condition = stop_condition
        # Overwritten by stop_condition's result on each objective evaluation.
        self.finished = False

    @property
    def callable_env(self):
        """Return a fresh closure mapping an input tensor to its loss."""

        def evaluate(data: Tensor) -> Tensor:
            with torch.no_grad():
                candidates = self.processor(data)
                model_input = self.model_processor(candidates)
                raw_scores = self.discriminator(model_input)
                scores = self.post_processor(raw_scores)
                if self.stop_condition:
                    # The latest evaluation always replaces the previous flag.
                    self.finished = self.stop_condition(candidates, scores)

            # Distance and final loss are evaluated outside the no_grad scope.
            distance_loss = self.distance_calculator(
                self.original_image.expand_as(candidates), candidates
            )
            return self.loss(scores, distance_loss)

        return evaluate

    @property
    def upper_bound(self) -> Tensor:
        """Upper bound supplied at construction time."""
        return self.my_upper_bound

    @property
    def lower_bound(self) -> Tensor:
        """Lower bound supplied at construction time."""
        return self.my_lower_bound

    def is_goal_reached(self):
        """Return the flag last produced by ``stop_condition`` (initially False)."""
        return self.finished


class BlackBoxClassifierSpace(BasicCallableSpace):
    """Callable space scoring inputs by a classifier's cross-entropy.

    Two mutually exclusive modes are supported:

    * ``desired_classification`` given: the objective is the per-sample
      cross-entropy between the classifier output and that label.
    * otherwise: the objective is the *negated* per-sample cross-entropy
      between the classifier output and ``correct_classifications``.

    NOTE(review): the classifier output is passed straight to
    ``torch.nn.functional.cross_entropy``, so it is presumably a
    ``(batch, num_classes)`` tensor of logits -- confirm against callers.
    """

    def __init__(
        self,
        classifier,
        original_image,
        processor,
        upper_bound,
        lower_bound,
        logger,
        desired_classification=None,
        correct_classifications=None,
        stop_condition=None,
        **kwargs
    ):
        """Store the collaborators and forward ``kwargs`` to the base class.

        Args:
            classifier: Callable applied to the processed images.
            original_image: Kept on the instance; not used by this class's
                objective.
            processor: Callable turning the raw input into images.
            upper_bound: Value returned by :attr:`upper_bound`.
            lower_bound: Value returned by :attr:`lower_bound`.
            logger: Kept on the instance; not used by this class itself.
            desired_classification: Label used in targeted mode. ``0`` is a
                valid label.
            correct_classifications: Label used in untargeted mode.
            stop_condition: Optional callable ``(new_images, classification)``
                whose result is stored as the "finished" flag on every
                evaluation.
            **kwargs: Passed unchanged to ``BasicCallableSpace.__init__``.

        Raises:
            AssertionError: If both ``desired_classification`` and
                ``correct_classifications`` are provided.
        """
        assert (
            correct_classifications is None or desired_classification is None
        ), "pass either desired_classification or correct_classifications, not both"
        super().__init__(**kwargs)
        self.classifier = classifier
        self.original_image = original_image
        self.desired_classification = desired_classification
        self.correct_classifications = correct_classifications
        self.my_upper_bound = upper_bound
        self.my_lower_bound = lower_bound
        self.logger = logger
        self.processor = processor
        self.stop_condition = stop_condition
        # Overwritten by stop_condition's result on each objective evaluation.
        self.finished = False

    @property
    def callable_env(self):
        """Return a fresh closure mapping an input tensor to per-sample losses."""

        def func(data: Tensor) -> Tensor:
            with torch.no_grad():
                new_images = self.processor(data)
                classification = self.classifier(new_images)
                if self.stop_condition:
                    self.finished = self.stop_condition(new_images, classification)

            # Compare against None rather than relying on truthiness: class
            # index 0 is falsy and a multi-element tensor label has no truth
            # value, yet both are legitimate targets.
            targeted = self.desired_classification is not None
            label = (
                self.desired_classification
                if targeted
                else self.correct_classifications
            )
            # One target per row of the batch, on the same device as the input.
            targets = (
                torch.ones(data.shape[0], device=data.device, dtype=torch.long)
                * label
            )
            cross_entropy = torch.nn.functional.cross_entropy(
                classification, targets, reduction="none"
            )
            # Targeted mode returns the loss as-is; untargeted mode negates it.
            return cross_entropy if targeted else -cross_entropy

        return func

    def is_goal_reached(self):
        """Return the flag last produced by ``stop_condition`` (initially False)."""
        return self.finished

    @property
    def upper_bound(self) -> Tensor:
        """Upper bound supplied at construction time."""
        return self.my_upper_bound

    @property
    def lower_bound(self) -> Tensor:
        """Lower bound supplied at construction time."""
        return self.my_lower_bound
