from typing import Any, Union, Callable

import numpy as np
import torch.utils.data
import torch_fidelity
from torch_fidelity.metric_fid import fid_featuresdict_to_statistics, fid_statistics_to_metric
from torch_fidelity.metric_isc import isc_featuresdict_to_metric
from torch_fidelity.utils import create_feature_extractor, get_featuresdict_from_dataset
from tqdm import tqdm

from src.datasets.noise import NoiseDataset
from src.datasets.noise_label import NoiseLabelDataset
from src.fidelity.utils import ImagesScale, ImagesFormat, ImageDataset, convert_scale, convert_format, \
    load_fid_statistics
from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name


@torch.inference_mode()
def calculate_metrics(
        images: np.ndarray,
        batch_size: int,
        calculate_fid: bool = False,
        calculate_isc: bool = False,
        fid_reference_path: str = None,
        images_scale: ImagesScale = ImagesScale.MINUS_ONE_TO_ONE,
        images_format: ImagesFormat = ImagesFormat.NCHW,
) -> dict[str, Any]:
    """Compute FID and/or Inception Score for an array of images.

    Features are extracted with torch_fidelity's 'inception-v3-compat'
    extractor, running on CUDA when ``torch.cuda.is_available()``.

    Args:
        images: Images to evaluate. They are passed through
            ``convert_format`` and ``convert_scale`` (using ``images_format``
            and ``images_scale``) before being wrapped in an ``ImageDataset``.
        batch_size: Batch size handed to torch_fidelity's feature extraction.
        calculate_fid: Whether to compute FID.
        calculate_isc: Whether to compute the Inception Score.
        fid_reference_path: Passed to ``load_fid_statistics`` to obtain the
            reference statistics; must be given when ``calculate_fid`` is set.
        images_scale: Declared value scale of ``images``.
        images_format: Declared layout of ``images``.

    Returns:
        A dict containing ``'fid'`` when FID was requested and ``'isc'``
        (a dict with ``'mean'`` and ``'std'``) when the Inception Score was
        requested.

    Raises:
        AssertionError: If neither metric is requested, or FID is requested
            without ``fid_reference_path``.
    """
    Logger.debug(
        f'{get_class_name(calculate_metrics)} - '
        f'images: {get_numpy_stats(images)}, '
        f'batch_size: {batch_size}, '
        f'calculate_fid: {calculate_fid}, '
        f'calculate_isc: {calculate_isc}, '
        f'fid_reference_path: {fid_reference_path}, '
        f'images_scale: {images_scale}, '
        f'images_format: {images_format}'
    )
    assert calculate_fid or calculate_isc, 'at least one metric must be calculated'
    assert not (calculate_fid and fid_reference_path is None), \
        'fid_reference_path must be provided for fid calculation'

    # Only request the Inception layers that the selected metrics consume.
    requested_layers: set[str] = {
        layer_name
        for layer_name, wanted in (('2048', calculate_fid), ('logits_unbiased', calculate_isc))
        if wanted
    }

    # Bring the images to the representation ImageDataset is fed with.
    formatted_images = convert_format(images, images_format)
    scaled_images = convert_scale(formatted_images, images_scale)
    image_dataset: ImageDataset = ImageDataset(scaled_images)

    use_cuda: bool = torch.cuda.is_available()
    extractor = create_feature_extractor('inception-v3-compat', list(requested_layers), cuda=use_cuda)
    features = get_featuresdict_from_dataset(
        input=image_dataset,
        feat_extractor=extractor,
        batch_size=batch_size,
        cuda=use_cuda,
        save_cpu_ram=False,
        verbose=True
    )

    metrics: dict[str, Any] = {}

    if calculate_fid:
        generated_statistics: dict[str, np.ndarray] = fid_featuresdict_to_statistics(features, '2048')
        reference_statistics: dict[str, np.ndarray] = load_fid_statistics(fid_reference_path)
        fid_metrics = fid_statistics_to_metric(generated_statistics, reference_statistics, verbose=True)
        fid: float = fid_metrics[torch_fidelity.KEY_METRIC_FID]
        Logger.debug(f'{get_class_name(calculate_metrics)} - fid: {fid}')
        metrics['fid'] = fid

    if calculate_isc:
        isc_metrics: dict[str, float] = isc_featuresdict_to_metric(features, 'logits_unbiased')
        isc_mean: float = isc_metrics[torch_fidelity.KEY_METRIC_ISC_MEAN]
        isc_std: float = isc_metrics[torch_fidelity.KEY_METRIC_ISC_STD]
        Logger.debug(f'{get_class_name(calculate_metrics)} - isc mean: {isc_mean}, isc std: {isc_std}')
        metrics['isc'] = {'mean': isc_mean, 'std': isc_std}

    Logger.debug(f'{get_class_name(calculate_metrics)} - result: {metrics}')
    return metrics


@torch.inference_mode()
def calculate_metrics_for_dataset_and_model(
        model: torch.nn.Module,
        dataset: Union[NoiseLabelDataset, NoiseDataset],
        inference_batch_func: Callable[[torch.nn.Module, torch.Tensor, torch.Tensor], torch.Tensor],
        calculate_fid: bool = False,
        calculate_isc: bool = False,
        fid_reference_path: str = None,
        batch_size: int = 50,
        data_loader_workers: int = 8
) -> dict[str, Any]:
    """Generate images with ``model`` from a noise dataset and score them.

    Every batch of ``dataset`` is passed through ``inference_batch_func``; the
    outputs are collected into one numpy array and forwarded to
    ``calculate_metrics``.

    The model is put into eval mode for the generation loop and its previous
    training/eval mode is restored afterwards, even if inference raises.

    Args:
        model: Model used to generate the images.
        dataset: Source of the inputs. Batches must provide a ``'noise'``
            entry, and a ``'label'`` entry when the dataset is a
            ``NoiseLabelDataset``.
        inference_batch_func: Called as ``func(model, noise, label)`` and
            expected to return the generated batch as a tensor; ``label`` is
            ``None`` for a ``NoiseDataset``.
        calculate_fid: Forwarded to ``calculate_metrics``.
        calculate_isc: Forwarded to ``calculate_metrics``.
        fid_reference_path: Forwarded to ``calculate_metrics``.
        batch_size: Batch size used both for generation and for the metric
            computation.
        data_loader_workers: Number of worker processes for the data loader.

    Returns:
        The result dict produced by ``calculate_metrics``.

    Raises:
        AssertionError: If ``dataset`` is neither a ``NoiseLabelDataset`` nor
            a ``NoiseDataset``.
    """
    Logger.debug(
        f'{get_class_name(calculate_metrics_for_dataset_and_model)} start - '
        f'model: {type(model)}, '
        f'dataset: {type(dataset)}, '
        f'inference_batch_func: {get_class_name(inference_batch_func)}, '
        f'calculate_fid: {calculate_fid}, '
        f'calculate_isc: {calculate_isc}, '
        f'fid_reference_path: {fid_reference_path}, '
        f'batch_size: {batch_size}, '
        f'data_loader_workers: {data_loader_workers}'
    )
    assert isinstance(dataset, (NoiseLabelDataset, NoiseDataset)), \
        f'dataset must be NoiseLabelDataset or NoiseDataset, got {type(dataset)}'
    conditional: bool = isinstance(dataset, NoiseLabelDataset)
    Logger.debug(f'conditional: {conditional}')

    noise_shape: tuple[int, ...] = dataset.data_shape['noise']

    # NOTE(review): the output buffer is sized from the noise shape, so this
    # assumes each generated sample has the same shape as its noise input —
    # confirm against the models used with this function.
    images: np.ndarray = np.zeros((len(dataset), *noise_shape))

    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=data_loader_workers,
        drop_last=False
    )

    # Remember the caller's mode so it can be restored afterwards, rather than
    # unconditionally forcing the model into training mode.
    was_training: bool = model.training
    model.eval()
    try:
        index: int = 0
        for data_batch in tqdm(data_loader):
            output_batch: torch.Tensor = inference_batch_func(
                model,
                data_batch['noise'],
                data_batch['label'] if conditional else None
            )
            # The last batch may be shorter (drop_last=False), so advance by
            # the actual number of generated samples.
            images[index:index + len(output_batch)] = output_batch.detach().cpu().numpy()
            index += len(output_batch)
    finally:
        # Restore the original mode even if inference failed.
        model.train(was_training)

    result: dict[str, Any] = calculate_metrics(
        images=images,
        calculate_fid=calculate_fid,
        calculate_isc=calculate_isc,
        fid_reference_path=fid_reference_path,
        batch_size=batch_size,
        images_scale=ImagesScale.MINUS_ONE_TO_ONE,
        images_format=ImagesFormat.NCHW
    )
    Logger.debug(f'{get_class_name(calculate_metrics_for_dataset_and_model)} end - result: {result}')
    return result
