from io import BytesIO
from pathlib import Path
from typing import Callable

import numpy as np
import torch
from PIL import Image
from matplotlib import pyplot as plt

from handlers.base_handler import EpochEndCallbackHandler


class SaverCallbackHandler(EpochEndCallbackHandler):
    """Epoch-end callback that writes a snapshot whenever the tracked value improves.

    Whenever ``best_model_value`` is strictly lower than the lowest value seen so
    far, the algorithm's current point is converted with ``alg.real_data``,
    rendered to bytes by ``generate_data`` and written to
    ``<path>/model_<n>.<ending>``. Here ``n`` counts the snapshots written by
    this handler, starting at 0.

    Args:
        path: Directory receiving the snapshot files (created on demand).
        generate_data: Callable mapping real data to the ``bytes`` to write.
        ending: File extension used for the snapshot files.
    """

    def __init__(self, path: Path, generate_data: Callable, ending: str = "py"):
        self.path = path
        self.generate_data = generate_data
        # Lowest value seen so far; lower is treated as better.
        self.curr_best_value = float("inf")
        # Index embedded in the next snapshot file name.
        self.curr = 0
        self.ending = ending

    def on_epoch_end(self, alg, *args, best_model_value=None, **kwargs):
        """Save a snapshot if ``best_model_value`` beats the best seen so far."""
        improved = best_model_value is not None and best_model_value < self.curr_best_value
        if not improved:
            return
        self.curr_best_value = best_model_value
        rendered = self.generate_data(alg.real_data(alg.curr_point_to_draw))
        self.save_model(rendered)

    def save_model(self, curr_point_to_draw):
        """Write the given bytes to the next numbered snapshot file.

        Args:
            curr_point_to_draw: Raw ``bytes`` to store (despite the name, this is
                the already-rendered payload, not the point itself).
        """
        self.path.mkdir(parents=True, exist_ok=True)
        target_file = self.path / f"model_{self.curr}.{self.ending}"
        self.curr += 1
        target_file.write_bytes(curr_point_to_draw)


class GenerateCode:
    """Callable that turns model output into encoded source text.

    Args:
        llm: Callable invoked on the input data. Its result is indexed as
            ``result[0][0]`` and that element is ``.encode()``-d, so it is
            presumably a nested sequence of generated strings — TODO confirm.
    """

    def __init__(self, llm):
        self.llm = llm

    def __call__(self, data):
        """Run ``llm`` on ``data`` without gradient tracking and return bytes."""
        with torch.no_grad():
            generated_text = self.llm(data)[0][0]
        return generated_text.encode()


class GenerateImage:
    """Callable that renders processed data to PNG-encoded bytes.

    Args:
        processor: Callable mapping the input data to a channel-first image
            tensor ``(C, H, W)``, optionally with a leading batch axis of which
            only the first image is used. Pixel values are presumably meant to
            lie in ``[0, 1]`` — TODO confirm against the processors used.
    """

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, data):
        """Return the first processed image encoded as PNG bytes."""
        with torch.no_grad():
            processed_image = self.processor(data)
            if len(processed_image.shape) == 4:
                processed_image = processed_image[0]
            # Channel-first (C, H, W) -> channel-last (H, W, C) as PIL expects.
            image_array = processed_image.detach().permute(1, 2, 0).cpu().numpy()
        # Clip before the uint8 cast: out-of-range values would otherwise wrap
        # around (e.g. -0.01 * 255 -> 253) instead of saturating at 0 / 255.
        pixels = (np.clip(image_array, 0.0, 1.0) * 255).astype(np.uint8)
        if pixels.shape[-1] == 1:
            # PIL cannot build an image from an (H, W, 1) array; use (H, W).
            pixels = pixels[..., 0]
        buffer = BytesIO()
        Image.fromarray(pixels).save(buffer, format="PNG")
        return buffer.getvalue()


class AccuracyPlotSaver(EpochEndCallbackHandler):
    """Epoch-end callback that tracks accuracy/fidelity of the best point and plots them.

    Every epoch, ``accuracy_calculator`` and ``fidelity_calculator`` are called
    on ``alg.best_point_until_now``. Their single-element tensor results are
    appended to ``self.accuracy`` / ``self.fidelity``. On every
    ``epoch_interval``-th epoch (counting from 0, so the first epoch is plotted),
    the full histories are written to ``<path>/fidelity_<epoch>.png`` and
    ``<path>/accuracy_<epoch>.png``.

    Args:
        epoch_interval: Plot every this many epochs (must be non-zero).
        accuracy_calculator: Callable scoring a point; must return a
            single-element tensor.
        fidelity_calculator: Callable scoring a point; must return a
            single-element tensor.
        path: Output directory (``Path`` or ``str``), created on demand.
    """

    def __init__(
        self,
        epoch_interval: int,
        accuracy_calculator: Callable,
        fidelity_calculator: Callable,
        path,
    ):
        self.epoch_interval = epoch_interval
        self.accuracy_calculator = accuracy_calculator
        self.fidelity_calculator = fidelity_calculator
        self.epoch_counter = 0
        self.accuracy = []
        self.fidelity = []
        # Normalise so that ``self.path / name`` also works when given a str.
        self.path = Path(path)

    def on_epoch_end(self, alg, *args, **kwargs):
        """Record the current metrics and periodically save both plots."""
        best_point = alg.best_point_until_now
        acc = self.accuracy_calculator(best_point)
        fidelity = self.fidelity_calculator(best_point)
        self.fidelity.append(fidelity.cpu().detach().item())
        self.accuracy.append(acc.cpu().detach().item())
        if self.epoch_counter % self.epoch_interval == 0:
            self.path.mkdir(parents=True, exist_ok=True)
            self._save_plot(self.fidelity, "Fidelity", "fidelity")
            self._save_plot(self.accuracy, "Accuracy", "accuracy")
        self.epoch_counter += 1

    def _save_plot(self, values, label, stem):
        """Plot one metric history to ``<path>/<stem>_<epoch>.png`` and free the figure."""
        fig, ax = plt.subplots()
        try:
            ax.plot(values, label=label)
            fig.legend()
            ax.set_title("Image Classification")
            ax.set_xlabel("Epochs")
            ax.set_ylabel("Value")
            fig.savefig(self.path / f"{stem}_{self.epoch_counter}.png")
        finally:
            # pyplot keeps every figure alive until closed; without this each
            # plotting epoch leaks two figures.
            plt.close(fig)


class ImageFalseNegativeCalculator:
    """Cross-entropy of a classifier's prediction against a fixed target class.

    The input is passed through ``processor``, then ``model_processor``, then
    ``classifier``. ``post_processor`` must turn the classifier output into
    logits whose first axis is the batch, as required by
    ``torch.nn.functional.cross_entropy``. Every sample is scored against the
    class index ``correct_classification``, so a lower value means the
    classifier is more confident in the correct class.
    """

    def __init__(
        self,
        classifier,
        processor,
        model_processor,
        post_processor,
        correct_classification,
    ):
        self.classifier = classifier
        self.processor = processor
        self.model_processor = model_processor
        self.post_processor = post_processor
        self.correct_classification = correct_classification

    def __call__(self, data):
        """Return the cross-entropy loss of ``data`` w.r.t. the target class.

        A 4-D tensor is treated as an already-batched input, and the per-sample
        losses (shape ``(batch,)``) are returned. Any other rank is treated as a
        single sample: a batch axis is added and that sample's scalar loss is
        returned.
        """
        is_batched = data.dim() == 4
        if not is_batched:
            # Previously the batch axis was added unconditionally, turning an
            # already-batched 4-D input into a 5-D tensor.
            data = data.unsqueeze(0)
        data = self.processor(data)
        data = self.model_processor(data)
        classification = self.classifier(data)
        targets = (
            torch.ones(data.shape[0], device=data.device, dtype=torch.long)
            * self.correct_classification
        )
        loss = torch.nn.functional.cross_entropy(
            self.post_processor(classification),
            targets,
            reduction="none",
        )
        return loss if is_batched else loss[0]


class MSEFidelityCalculator:
    """Mean-squared error between processed data and a stored reference image.

    Args:
        original_image: Reference image tensor. It is indexed with ``[0]`` when
            the processed data is 3-D, so it presumably carries a leading batch
            axis of size 1 — TODO confirm.
        processor: Callable mapping the input data to image tensor(s).
    """

    def __init__(self, original_image, processor):
        self.original_image = original_image
        self.processor = processor

    def __call__(self, data):
        """Return the mean MSE between ``processor(data)`` and the reference."""
        images = self.processor(data)
        reference = self.original_image
        if images.ndim == 3:
            # Unbatched image: compare against the first reference image only.
            reference = reference[0]
        return torch.nn.functional.mse_loss(reference, images)
