import pickle
from enum import Enum, auto
from typing import Any

import numpy as np
import torch.utils.data

from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name


class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves entries of an in-memory NumPy image array.

    Each item is ``self.images[index]`` wrapped with ``torch.from_numpy``.
    The tensor therefore keeps the array's dtype and shares memory with the
    indexed result. For plain integer indexing, that result is a view into
    ``self.images``.
    """

    def __init__(self, images: np.ndarray):
        # Backing storage; indexed lazily in __getitem__.
        self.images: np.ndarray = images

    def __len__(self) -> int:
        # Dataset length is the size of the array's first axis.
        return len(self.images)

    def __getitem__(self, index) -> torch.Tensor:
        selected = self.images[index]
        return torch.from_numpy(selected)


class ImagesScale(Enum):
    """Value range that input pixels are expressed in (consumed by ``convert_scale``)."""

    # Explicit values match what ``auto()`` would assign (starting at 1).
    ZERO_TO_ONE = 1  # pixels in [0, 1]
    ZERO_TO_255 = 2  # pixels in [0, 255]
    MINUS_ONE_TO_ONE = 3  # pixels in [-1, 1]


class ImagesFormat(Enum):
    """Axis layout requested from ``convert_format`` (N=batch, C=channels, H/W=spatial)."""

    # Explicit values match what ``auto()`` would assign (starting at 1).
    NCHW = 1
    NHWC = 2


def convert_scale(images: np.ndarray, images_scale: ImagesScale) -> np.ndarray:
    """Quantize ``images`` to ``uint8`` values in [0, 255].

    The input is first clipped to the range named by ``images_scale``, then
    mapped linearly onto [0, 255], rounded, and cast to ``np.uint8``.

    Args:
        images: pixel array expressed in the range given by ``images_scale``.
        images_scale: the value range ``images`` is currently expressed in.

    Returns:
        A ``np.uint8`` array with the same shape as ``images``.

    Raises:
        ValueError: if ``images_scale`` is not one of the handled ``ImagesScale`` members.
    """
    Logger.debug(f'{get_class_name(convert_scale)} - images: {get_numpy_stats(images)}, images_scale: {images_scale}')
    if images_scale == ImagesScale.ZERO_TO_255:
        # Already on the target scale; only clip out-of-range values.
        scaled = np.clip(images, 0, 255)
    elif images_scale == ImagesScale.ZERO_TO_ONE:
        scaled = np.clip(images, 0, 1) * 255
    elif images_scale == ImagesScale.MINUS_ONE_TO_ONE:
        # Shift [-1, 1] to [0, 1] before clipping and stretching to [0, 255].
        scaled = np.clip(images * 0.5 + 0.5, 0, 1) * 255
    else:
        raise ValueError(f'unknown images scale: {images_scale}')
    # Round rather than truncate so e.g. 254.6 becomes 255, not 254.
    return scaled.round().astype(np.uint8)


def convert_format(images: np.ndarray, images_format: ImagesFormat) -> np.ndarray:
    """Return ``images`` in the requested axis layout.

    The input is assumed to already be NCHW, since the NCHW case returns it
    untouched. For NHWC, the axes are permuted with ``(0, 2, 3, 1)``.

    Args:
        images: batch of images, presumably shaped (N, C, H, W).
        images_format: desired output layout.

    Returns:
        ``images`` itself for NCHW, or the permuted array for NHWC.

    Raises:
        ValueError: if ``images_format`` is not one of the handled ``ImagesFormat`` members.
    """
    message = f'{get_class_name(convert_format)} - images: {get_numpy_stats(images)}, images_format: {images_format}'
    Logger.debug(message)
    if images_format == ImagesFormat.NCHW:
        return images
    if images_format == ImagesFormat.NHWC:
        # For ndarray input, np.transpose returns a (generally non-contiguous) view, not a copy.
        return np.transpose(images, axes=(0, 2, 3, 1))
    raise ValueError(f'unknown images format: {images_format}')


def load_features(filepath: str) -> dict[str, torch.Tensor]:
    """Load a pickled ``{name: ndarray}`` mapping and wrap each array as a tensor.

    Each tensor is created with ``torch.from_numpy`` and shares memory with the
    unpickled array. Key order is preserved.

    Security note: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only call this on files from a trusted source.
    """
    with open(filepath, 'rb') as handle:
        raw: dict[str, np.ndarray] = pickle.load(handle)
    tensors: dict[str, torch.Tensor] = {}
    for name, array in raw.items():
        tensors[name] = torch.from_numpy(array)
    return tensors


def load_fid_statistics(filepath: str) -> dict[str, Any]:
    """Read the ``mu`` and ``sigma`` arrays from an ``.npz`` archive.

    The archive is opened with ``np.load`` as a context manager. Without
    ``mmap_mode``, item access reads each array fully into memory, so the
    returned arrays stay valid after the archive is closed. A missing key
    raises ``KeyError`` (``mu`` is looked up first).

    Returns:
        ``{'mu': ..., 'sigma': ...}`` holding the arrays as stored in the file.
    """
    with np.load(filepath) as archive:
        stats = {key: archive[key] for key in ('mu', 'sigma')}
    return stats
