from typing import Union, Callable

import numpy as np
import torch.utils.data
import torch_fidelity
from torch_fidelity.metric_fid import fid_featuresdict_to_statistics, fid_statistics_to_metric
from torch_fidelity.utils import create_feature_extractor, get_featuresdict_from_dataset
from tqdm import tqdm

from src.datasets.noise import NoiseDataset
from src.datasets.noise_label import NoiseLabelDataset
from src.fidelity.utils import ImagesScale, ImagesFormat, ImageDataset, convert_scale, convert_format, \
    load_fid_statistics
from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name


@torch.inference_mode()
def calculate_fid(
        images: np.ndarray,
        reference_path: str,
        batch_size: int,
        images_scale: ImagesScale = ImagesScale.MINUS_ONE_TO_ONE,
        images_format: ImagesFormat = ImagesFormat.NCHW,
) -> float:
    """Compute the FID between ``images`` and precomputed reference statistics.

    The images go through ``convert_format`` and then ``convert_scale``, are
    wrapped in an ``ImageDataset``, and are embedded with torch-fidelity's
    ``inception-v3-compat`` extractor at feature layer ``'2048'``. CUDA is
    used when ``torch.cuda.is_available()`` is true. The conversions
    presumably normalize layout and value range to what torch-fidelity
    expects (TODO confirm in ``src.fidelity.utils``).

    Args:
        images: Images to evaluate. Their layout is given by
            ``images_format`` and their value range by ``images_scale``.
        reference_path: Path handed to ``load_fid_statistics`` to obtain the
            reference statistics.
        batch_size: Batch size used during feature extraction.
        images_scale: Value range of ``images``.
        images_format: Memory layout of ``images``.

    Returns:
        The ``torch_fidelity.KEY_METRIC_FID`` entry of the computed metric
        dictionary.
    """
    log_tag = get_class_name(calculate_fid)
    Logger.debug(
        f'{log_tag} - '
        f'images: {get_numpy_stats(images)}, '
        f'reference_path: {reference_path}, '
        f'batch_size: {batch_size}, '
        f'images_scale: {images_scale}, '
        f'images_format: {images_format}'
    )

    normalized_images = convert_scale(convert_format(images, images_format), images_scale)
    image_dataset: ImageDataset = ImageDataset(normalized_images)

    feature_layer = '2048'
    use_cuda = torch.cuda.is_available()
    extractor = create_feature_extractor('inception-v3-compat', [feature_layer], cuda=use_cuda)
    features = get_featuresdict_from_dataset(
        input=image_dataset,
        feat_extractor=extractor,
        batch_size=batch_size,
        cuda=use_cuda,
        save_cpu_ram=False,
        verbose=True,
    )

    # Statistics of the generated set first, then the reference set (same order as the metric call).
    generated_stats = fid_featuresdict_to_statistics(features, feature_layer)
    reference_stats = load_fid_statistics(reference_path)
    metrics = fid_statistics_to_metric(generated_stats, reference_stats, verbose=True)
    Logger.debug(f'{log_tag} - result: {metrics}')

    fid_value: float = metrics[torch_fidelity.KEY_METRIC_FID]
    Logger.debug(f'{log_tag} - fid: {fid_value}')
    return fid_value


@torch.inference_mode()
def calculate_fid_for_dataset_and_model(
        model: torch.nn.Module,
        dataset: Union[NoiseLabelDataset, NoiseDataset],
        reference_path: str,
        inference_batch_func: Callable[[torch.nn.Module, torch.Tensor, Union[torch.Tensor, None]], torch.Tensor],
        batch_size: int = 50,
        data_loader_workers: int = 8
) -> float:
    """Generate one output per dataset sample with ``model`` and compute its FID.

    The dataset is iterated in order (no shuffling, last partial batch kept).
    Each batch's ``'noise'`` entry, plus its ``'label'`` entry when the dataset
    is a ``NoiseLabelDataset``, is passed to ``inference_batch_func``. All
    outputs are collected into one numpy array, which is scored with
    ``calculate_fid``. The outputs are assumed to be NCHW images in [-1, 1]
    (TODO confirm this matches what ``inference_batch_func`` returns).

    The model is put in eval mode for generation. Its previous train/eval mode
    is restored afterwards, even if generation raises.

    Args:
        model: Generator network passed to ``inference_batch_func``.
        dataset: Source of noise (and optionally labels).
        reference_path: Path to the reference FID statistics.
        inference_batch_func: Called as ``(model, noise, label_or_None)``; it
            returns a batch of generated images.
        batch_size: Batch size for both generation and feature extraction.
        data_loader_workers: ``num_workers`` for the ``DataLoader``.

    Returns:
        The FID value returned by ``calculate_fid``.

    Raises:
        AssertionError: If ``dataset`` is neither a ``NoiseLabelDataset`` nor a
            ``NoiseDataset``.
    """
    Logger.debug(
        f'{get_class_name(calculate_fid_for_dataset_and_model)} start - '
        f'model: {type(model)}, '
        f'dataset: {type(dataset)}, '
        f'reference_path: {reference_path}, '
        f'inference_batch_func: {get_class_name(inference_batch_func)}, '
        f'batch_size: {batch_size}, '
        f'data_loader_workers: {data_loader_workers}'
    )
    assert isinstance(dataset, NoiseLabelDataset) or isinstance(dataset, NoiseDataset), \
        f'dataset must be NoiseLabelDataset or NoiseDataset, got {type(dataset)}'
    conditional: bool = isinstance(dataset, NoiseLabelDataset)
    Logger.debug(f'conditional: {conditional}')

    noise_shape: tuple[int, ...] = dataset.data_shape['noise']

    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=data_loader_workers,
        drop_last=False
    )

    # Remember the caller's mode so it can be restored instead of forcing training mode.
    was_training: bool = model.training
    model.eval()

    images: Union[np.ndarray, None] = None
    index: int = 0
    try:
        for data_batch in tqdm(data_loader):
            output_batch: torch.Tensor = inference_batch_func(
                model,
                data_batch['noise'],
                data_batch['label'] if conditional else None
            )
            output_array: np.ndarray = output_batch.detach().cpu().numpy()
            if images is None:
                # Size the buffer from the actual model output rather than assuming
                # it has the same per-sample shape as the input noise.
                images = np.zeros((len(dataset), *output_array.shape[1:]))
            images[index:index + len(output_array)] = output_array
            index += len(output_array)
    finally:
        # Restore the original mode even if inference raised.
        model.train(was_training)

    if images is None:
        # Empty dataset: keep the original behaviour of an empty array shaped like the noise.
        images = np.zeros((len(dataset), *noise_shape))

    fid: float = calculate_fid(
        images=images,
        reference_path=reference_path,
        batch_size=batch_size,
        images_scale=ImagesScale.MINUS_ONE_TO_ONE,
        images_format=ImagesFormat.NCHW
    )
    Logger.debug(f'{get_class_name(calculate_fid_for_dataset_and_model)} end - fid: {fid}')
    return fid
