from typing import Callable, Union

import numpy as np
import torch.distributed
import torch.utils.data
from tqdm import tqdm

from datasets.numpy import NumpyDataset
from fid_pytorch.fid_pytorch import create_inception_model, get_activations_batch_torch, \
    calculate_statistics_from_activations, calculate_frechet_distance, load_fid_stats
from fid_pytorch.inception import InceptionV3
from src.datasets.noise import NoiseDataset
from src.datasets.noise_label import NoiseLabelDataset
from torch_utils.distributed.distributed_manager import DistributedManager
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.utils import get_class_name


class RandomNoiseDataset(NumpyDataset):
    """Dataset that hands out freshly drawn standard-normal noise arrays.

    Nothing is stored: every lookup draws a new array from NumPy's global
    random state, so reading the same index twice gives different values.
    """

    def __init__(self, num_samples: int, noise_shape: tuple[int, ...]) -> None:
        """Record the dataset length and the shape of each noise sample.

        :param num_samples: value reported by ``len()``.
        :param noise_shape: shape of every array returned by indexing.
        """
        # Presumably the base class takes the per-sample shape and dtype - confirm against NumpyDataset.
        super().__init__(noise_shape, np.float64)
        self.noise_shape: tuple[int, ...] = noise_shape
        self.num_samples: int = num_samples

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, item: int) -> np.ndarray:
        """Draw a new standard-normal sample of ``noise_shape``; ``item`` is not used."""
        shape = self.noise_shape
        return np.random.randn(*shape)


class RandomLabelDataset(NumpyDataset):
    """Dataset that hands out uniformly drawn class labels.

    Nothing is stored: every lookup draws a new label from NumPy's global
    random state, so reading the same index twice may give different labels.
    """

    def __init__(self, num_samples: int, num_classes: int) -> None:
        """Record the dataset length and the number of classes to draw from.

        :param num_samples: value reported by ``len()``.
        :param num_classes: exclusive upper bound of the drawn labels.
        """
        # Presumably the base class takes the per-sample shape (scalar) and dtype - confirm against NumpyDataset.
        super().__init__((), np.int64)
        self.num_classes: int = num_classes
        self.num_samples: int = num_samples

    def __len__(self) -> int:
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, item: int) -> np.ndarray:
        """Draw a label from ``[0, num_classes)``; ``item`` is not used."""
        # NOTE(review): a scalar draw is returned here although the annotation says np.ndarray.
        upper_bound = self.num_classes
        return np.random.randint(low=0, high=upper_bound)


def _fid_from_activations(activations: torch.Tensor, reference_path: str) -> float:
    """Compute the FID of generated activations against the statistics loaded from ``reference_path``."""
    generated_mu, generated_sigma = \
        calculate_statistics_from_activations(activations.detach().cpu().numpy())
    reference_mu, reference_sigma = load_fid_stats(reference_path)
    return calculate_frechet_distance(generated_mu, generated_sigma, reference_mu, reference_sigma)


@torch.inference_mode()
def calculate_fid_for_dataset_and_model(
        model: torch.nn.Module,
        dataset: Union[NoiseLabelDataset, NoiseDataset],
        reference_path: str,
        inference_batch_func: Callable[[torch.nn.Module, torch.Tensor, torch.Tensor], torch.Tensor],
        batch_size: int = 50,
        dims: int = 2048,
        data_loader_workers: int = 8
) -> float:
    """Generate images with ``model`` from ``dataset`` and return their FID against reference statistics.

    Every batch of the dataset is passed through ``inference_batch_func``; the outputs are
    quantised to 8-bit images, fed to an Inception model and the resulting activations are
    compared with the statistics loaded from ``reference_path``.

    When ``DistributedManager.initialized`` is true the dataset is split across ranks with a
    ``DistributedSampler``, the per-rank activations are gathered on the main rank, the FID is
    computed there and then broadcast so that every rank returns the same value.

    The model is switched to eval mode for generation and restored to the mode it had on entry
    afterwards, even if generation raises.

    :param model: generator passed as first argument to ``inference_batch_func``.
    :param dataset: ``NoiseLabelDataset`` (conditional) or ``NoiseDataset`` (unconditional);
        batches are read through the ``'noise'`` and, when conditional, ``'label'`` keys.
    :param reference_path: location handed to ``load_fid_stats``.
    :param inference_batch_func: called as ``f(model, noise, label_or_None)``; its output is
        mapped with ``x * 0.5 + 0.5`` and clipped to ``[0, 1]`` (assumes outputs in ``[-1, 1]``).
    :param batch_size: data-loader batch size.
    :param dims: width of the activation vectors; also passed to ``create_inception_model``.
    :param data_loader_workers: number of data-loader worker processes.
    :return: the FID value.
    :raises AssertionError: if ``dataset`` is neither of the two supported dataset types.
    """
    Logger.debug(
        f'{get_class_name(calculate_fid_for_dataset_and_model)} start - '
        f'model: {type(model)}, '
        f'dataset: {type(dataset)}, '
        f'reference_path: {reference_path}, '
        f'inference_batch_func: {get_class_name(inference_batch_func)}, '
        f'batch_size: {batch_size}, '
        f'dims: {dims}, '
        f'data_loader_workers: {data_loader_workers}'
    )
    assert isinstance(dataset, (NoiseLabelDataset, NoiseDataset)), \
        f'dataset must be NoiseLabelDataset or NoiseDataset, got {type(dataset)}'
    conditional: bool = isinstance(dataset, NoiseLabelDataset)
    Logger.debug(f'conditional: {conditional}')

    distributed: bool = DistributedManager.initialized
    sampler = torch.utils.data.DistributedSampler(
        dataset,
        num_replicas=DistributedManager.world_size,
        rank=DistributedManager.rank,
        shuffle=False,
        drop_last=False
    ) if distributed else None
    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        # shuffle must not be set together with an explicit sampler.
        shuffle=False if sampler is None else None,
        num_workers=data_loader_workers,
        drop_last=False
    )

    device = get_default_device()
    inception_model: InceptionV3 = create_inception_model(dims).to(device)
    inception_model.eval()

    # With drop_last=False the DistributedSampler pads its index list, so len(sampler) is the
    # same on every rank; equal-sized buffers are what the gather below relies on.
    num_samples: int = len(sampler) if sampler is not None else len(dataset)
    activations: torch.Tensor = torch.zeros((num_samples, dims), dtype=torch.float64)

    if sampler is not None:
        sampler.set_epoch(0)

    # Remember the caller's mode so an eval-mode model is not silently flipped to training.
    was_training: bool = model.training
    model.eval()
    try:
        index: int = 0
        for data_batch in tqdm(data_loader):
            output_batch: torch.Tensor = inference_batch_func(
                model,
                data_batch['noise'],
                data_batch['label'] if conditional else None
            )
            # Round-trip through uint8 so the activations see the same quantisation as saved images.
            image_uint8_batch: torch.Tensor = \
                torch.round(torch.clip(output_batch * 0.5 + 0.5, 0.0, 1.0) * 255).to(torch.uint8)
            image_batch: torch.Tensor = (image_uint8_batch.to(torch.float32) / 255.0).to(device)
            activations[index:index + len(image_batch)] = get_activations_batch_torch(inception_model, image_batch)
            index += len(image_batch)
    finally:
        model.train(was_training)

    fid: float
    if distributed:
        Logger.debug('calculating fid for distributed')
        local_activations: torch.Tensor = activations.to(device)
        is_main: bool = DistributedManager.is_main()
        # Only the destination rank needs receive buffers; other ranks pass None.
        gather_list = [
            torch.zeros_like(local_activations)
            for _ in range(DistributedManager.world_size)
        ] if is_main else None
        torch.distributed.gather(
            local_activations,
            gather_list,
            dst=DistributedManager.main_rank()
        )

        if is_main:
            main_fid = _fid_from_activations(torch.concatenate(gather_list), reference_path)
            res: torch.Tensor = torch.as_tensor(main_fid, dtype=torch.float64).to(device)
        else:
            # Placeholder that the broadcast below overwrites with the main rank's value.
            res = torch.as_tensor(0.0, dtype=torch.float64).to(device)

        torch.distributed.broadcast(res, src=DistributedManager.main_rank())
        fid = res.detach().cpu().item()
    else:
        Logger.debug('calculating fid for non-distributed')
        fid = _fid_from_activations(activations, reference_path)

    Logger.debug(
        f'{get_class_name(calculate_fid_for_dataset_and_model)} end - fid: {fid}')
    return fid
