"""
Snapshot saver modules for FeatureInversionPipeline.
"""

import os
from typing import List, Union
import numpy as np
from PIL import Image
from abc import ABC, abstractmethod
import torch

from .image_domain import finalize

class SnapshotSaver(ABC):
    """Abstract base class for snapshot savers.

    Concrete savers are constructed once and then called with the current
    optimization step and the current images. Each saver decides on its own
    whether to write snapshots to disk.
    """

    @abstractmethod
    def __init__(self, *args, **kwargs):
        """Initialize the snapshot saver."""
        pass

    @abstractmethod
    def __call__(self, step: int, images: torch.Tensor, *args, **kwargs):
        """Save a snapshot at the given step if the saver's conditions are met.

        The signature matches the concrete implementations in this module,
        which take ``(step, images, ...)``. Saver-specific inputs (such as
        per-step metrics) are passed via ``*args`` / ``**kwargs``.

        Args:
            step (int): Current optimization step.
            images (torch.Tensor): Images at the current step.
        """
        pass

    def reset_states(self):
        """Reset any internal state so the saver can be reused (no-op by default)."""
        pass


def crop_image(image: np.ndarray, size: int = 224) -> np.ndarray:
    """
    Center-crop the image to a square of the specified size.

    The crop is taken from the center of the first two axes (height, width).
    Any trailing axes, such as channels, are kept unchanged. If the image is
    smaller than ``size`` along an axis, that axis keeps its full extent
    instead of being cropped.

    Args:
        image (np.ndarray): Image to crop, shaped (H, W) or (H, W, C).
        size (int): Side length of the square crop.

    Returns:
        np.ndarray: Cropped image (a view into ``image``).
    """
    h, w = image.shape[:2]
    # Clamp at 0: a negative start index would wrap around to the end of the
    # axis and silently return the wrong (and truncated) region for images
    # smaller than `size`.
    h_start = max((h - size) // 2, 0)
    w_start = max((w - size) // 2, 0)
    return image[h_start : h_start + size, w_start : w_start + size]


class IntervalSnapshotSaver(SnapshotSaver):
    """Save snapshots at regular intervals.

    Args:
        path_templates (list[str]): Snapshot save path templates, one per image
            in the batch (template ``i`` is used for ``images[i]``). Each
            template may contain the placeholder ``{step}``.
        save_steps (int): Interval of steps to save snapshots. A snapshot is
            saved whenever ``step % save_steps == 0``, including step 0.
        crop_size (int): Side length of the center crop applied before saving.

    Raises:
        ValueError: If ``save_steps`` is not positive.
    """

    def __init__(self, path_templates: list[str], save_steps: int, crop_size: int = 224):
        # Fail fast here rather than with a ZeroDivisionError on the first call.
        if save_steps <= 0:
            raise ValueError(f"save_steps must be a positive integer, got {save_steps}")
        self.path_templates = path_templates
        self.save_steps = save_steps
        self.crop_size = crop_size

    def __call__(self, step: int, images: torch.Tensor, *args, **kwargs):
        """
        Save the snapshot if the current step is a multiple of the save interval.

        Args:
            step (int): Current optimization step.
            images (torch.Tensor): Images to save. They are passed through
                ``finalize`` and then indexed per template along the first axis.
            *args, **kwargs: Ignored; accepted for interface compatibility.
        """
        if step % self.save_steps != 0:
            return

        images = finalize(images)
        for i, path_template in enumerate(self.path_templates):
            pil_image = self._to_pil_image(images[i], self.crop_size)
            self._save(pil_image, path_template.format(step=step))

    @staticmethod
    def _to_pil_image(image, crop_size: int) -> Image.Image:
        """Convert one finalized image tensor to a center-cropped uint8 PIL image."""
        array = image.detach().cpu().numpy().astype(np.uint8)
        return Image.fromarray(crop_image(array, crop_size))

    @staticmethod
    def _save(pil_image: Image.Image, path: str) -> None:
        """Write ``pil_image`` to ``path``, creating the parent directory if needed."""
        directory = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the path actually contains one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        pil_image.save(path)


class ThresholdSnapshotSaver(SnapshotSaver):
    """Save snapshots when the metric exceeds the threshold for the first time.
    The metric is evaluated for each sample separately.

    Args:
        path_templates (list[str]): Snapshot save path templates, one per sample
            (template ``i`` is used for ``images[i]``). Each template may
            contain the placeholders ``{step}``, ``{threshold}``, and ``{value}``.
        thresholds (List[float]): Thresholds to save the snapshot. Each
            threshold triggers at most one snapshot per sample until
            ``reset_states`` is called.
        metric_name (str): Name of the metric to evaluate, used as the key
            into ``step_metrics``.
        crop_size (int): Side length of the center crop applied before saving.
    """

    def __init__(self, path_templates: list[str], thresholds: List[float], metric_name: str, crop_size: int = 224):
        self.path_templates = path_templates
        self.thresholds = thresholds
        self.initialized = False
        # Per-sample lists of thresholds not yet crossed. Built lazily on the
        # first call, because the batch size is only known then.
        self.thresholds_remain = None
        self.metric_name = metric_name
        self.crop_size = crop_size

    def __call__(self, step: int, images: torch.Tensor, step_metrics, *args, **kwargs):
        """Save the snapshot if the metric exceeds the threshold for the first time.

        Args:
            step (int): Current optimization step.
            images (torch.Tensor): Images to save, indexed per sample along
                the first axis.
            step_metrics (dict[str, list[float]]): Metrics evaluated at the current step.
                metric_name -> list[float], element i corresponds to the i-th stimulus.
        """
        batch_size = images.shape[0]
        if not self.initialized:
            # Track thresholds per batch index rather than per stimulus name,
            # because stimulus names are not assumed to be unique. list() also
            # accepts tuple/array thresholds and always yields a mutable list.
            self.thresholds_remain = [list(self.thresholds) for _ in range(batch_size)]
            self.initialized = True

        # Every threshold has already fired for every sample: skip the
        # finalize/convert work entirely.
        if not any(self.thresholds_remain):
            return

        images = finalize(images)
        metric_values = step_metrics[self.metric_name]

        for i in range(batch_size):
            remaining_thresholds = self.thresholds_remain[i]
            if not remaining_thresholds:
                continue

            metric_value = metric_values[i]
            crossed = [t for t in remaining_thresholds if metric_value >= t]
            if not crossed:
                continue

            # Convert the sample once, even if several thresholds were crossed.
            pil_image = self._to_pil_image(images[i], self.crop_size)
            template = self.path_templates[i]
            for threshold in crossed:
                path = template.format(step=step, threshold=threshold, value=metric_value)
                self._save(pil_image, path)
                print(f'Saved snapshot for index {i} at step {step}, threshold {threshold} to {path}')
                # Each threshold fires only once per sample.
                remaining_thresholds.remove(threshold)

    @staticmethod
    def _to_pil_image(image, crop_size: int) -> Image.Image:
        """Convert one finalized image tensor to a center-cropped uint8 PIL image."""
        array = image.detach().cpu().numpy().astype(np.uint8)
        return Image.fromarray(crop_image(array, crop_size))

    @staticmethod
    def _save(pil_image: Image.Image, path: str) -> None:
        """Write ``pil_image`` to ``path``, creating the parent directory if needed."""
        directory = os.path.dirname(path)
        # os.makedirs('') raises FileNotFoundError, so only create a directory
        # when the path actually contains one.
        if directory:
            os.makedirs(directory, exist_ok=True)
        pil_image.save(path)

    def reset_states(self):
        """Reset the state of the snapshot saver so all thresholds can fire again."""
        self.initialized = False
        self.thresholds_remain = None