import os
from typing import Union, Callable

import numpy as np
import scipy
import torch.utils.data
import torch.distributed
from sympy.matrices.expressions.factorizations import LofLU
from tqdm import tqdm

from external.consistency_models.evaluations.inception_v3 import InceptionV3, SoftmaxModel
import blobfile as bf

from src.datasets.noise import NoiseDataset
from src.datasets.noise_label import NoiseLabelDataset
from torch_utils.distributed.distributed_manager import DistributedManager
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name


def create_inception_model(
        device: str = None,
        weights_path: str = 'PATH/inception-2015-12-05.pt'
) -> InceptionV3:
    """Build an ``InceptionV3`` network, load its weights and put it in eval mode.

    Args:
        device: Device the model is moved to. When ``None`` the value returned
            by ``get_default_device()`` is used.
        weights_path: Location of the serialized state dict, opened through
            ``blobfile``. Defaults to the path that was previously hard-coded.

    Returns:
        The model in evaluation mode, placed on ``device``.
    """
    if device is None:
        device = get_default_device()

    inception: InceptionV3 = InceptionV3()
    with bf.BlobFile(weights_path, 'rb') as f:
        # Deserialize onto the CPU first: without map_location a checkpoint
        # saved from an accelerator cannot be loaded on a host that lacks it.
        # The explicit .to(device) below performs the final placement.
        inception.load_state_dict(torch.load(f, map_location='cpu'))
    inception.eval()
    inception.to(device)

    return inception


def get_inception_features_batch(
        inception: InceptionV3,
        images: torch.Tensor,
        device: str = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run one batch of images through ``inception`` and flatten both outputs.

    The images are rescaled with ``x * 0.5 + 0.5``, clipped to ``[0, 1]``,
    multiplied by 255 and rounded before being fed to the network
    (presumably the inputs are normalized to ``[-1, 1]`` - confirm with callers).

    Args:
        inception: Callable returning a pair ``(pred, spatial_pred)``.
        images: Batch of images.
        device: Device the rescaled batch is moved to before the forward pass.
            When ``None`` the value of ``get_default_device()`` is used.

    Returns:
        ``(pred, spatial_pred)``, each reshaped to ``(batch, -1)``.
    """
    target_device = get_default_device() if device is None else device

    with torch.no_grad():
        unit_range = torch.clip(images * 0.5 + 0.5, 0, 1)
        pixels = (unit_range * 255).round()

        pooled, spatial = inception(pixels.to(target_device))
        flat_pooled = pooled.reshape([pooled.shape[0], -1])
        flat_spatial = spatial.reshape([spatial.shape[0], -1])

    return flat_pooled, flat_spatial


def get_inception_features(
        dataset: torch.utils.data.Dataset,
        batch_size: int = 50,
        num_workers: int = 8,
        device: str = None
) -> np.ndarray:
    """Compute the first (pooled) inception output for every item of ``dataset``.

    The dataset is iterated in order (no shuffling, no dropped tail batch) and
    the per-batch predictions are concatenated along axis 0.

    Args:
        dataset: Dataset whose batches are valid input for
            ``get_inception_features_batch``.
        batch_size: Number of items per forward pass.
        num_workers: Worker processes used by the ``DataLoader``.
        device: Device for the model and the batches. When ``None`` the value
            of ``get_default_device()`` is used.

    Returns:
        Array with one row of features per dataset item.
    """
    Logger.debug(
        f'{get_class_name(get_inception_features)} start - '
        f'dataset: {type(dataset)}, '
        f'batch_size: {batch_size}, '
        f'num_workers: {num_workers}, '
        f'device: {device}'
    )
    if device is None:
        device = get_default_device()

    inception: InceptionV3 = create_inception_model(device)

    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=num_workers
    )

    pred_list: list[np.ndarray] = []
    for batch in tqdm(data_loader):
        # Forward the resolved device: otherwise the batch is moved to the
        # default device while the model lives on the requested one.
        curr_pred, _ = get_inception_features_batch(inception, batch, device)
        pred_list.append(curr_pred.detach().cpu().numpy())
    pred: np.ndarray = np.concatenate(pred_list, axis=0)

    Logger.debug(f'{get_class_name(get_inception_features)} end')

    return pred


def get_inception_features_distributed(
        dataset: torch.utils.data.Dataset,
        batch_size: int = 50,
        num_workers: int = 8,
        device: str = None
) -> Union[np.ndarray, None]:
    """Compute pooled inception features for ``dataset``, sharded across ranks.

    When ``DistributedManager.initialized`` is true every rank processes its
    own shard (via ``DistributedSampler``) and the shards are gathered on the
    main rank. Otherwise the whole dataset is processed locally.

    Args:
        dataset: Dataset whose batches are valid input for
            ``get_inception_features_batch``.
        batch_size: Number of items per forward pass.
        num_workers: Worker processes used by the ``DataLoader``.
        device: Device for the model, the batches and the gather buffers.
            When ``None`` the value of ``get_default_device()`` is used.

    Returns:
        The feature array when ``DistributedManager.is_main()`` is true,
        ``None`` otherwise. In the distributed case the rows follow the
        dataset order and there is exactly one row per dataset item.
    """
    Logger.debug(
        f'{get_class_name(get_inception_features_distributed)} start - '
        f'dataset: {type(dataset)}, '
        f'batch_size: {batch_size}, '
        f'num_workers: {num_workers}, '
        f'device: {device}'
    )
    if device is None:
        device = get_default_device()

    inception: InceptionV3 = create_inception_model(device)

    distributed: bool = DistributedManager.initialized
    sampler: torch.utils.data.DistributedSampler = torch.utils.data.DistributedSampler(
        dataset=dataset,
        num_replicas=DistributedManager.world_size,
        rank=DistributedManager.rank,
        shuffle=False,
        drop_last=False
    ) if distributed else None
    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=False if sampler is None else None,
        drop_last=False,
        num_workers=num_workers
    )

    pred_list: list[torch.Tensor] = []

    if distributed:
        sampler.set_epoch(0)

    for batch in tqdm(data_loader):
        # Forward the resolved device so the batch lands where the model is.
        curr_pred, _ = get_inception_features_batch(inception, batch, device)
        pred_list.append(curr_pred.detach().cpu())
    pred: torch.Tensor = torch.concatenate(pred_list, dim=0)

    if distributed:
        # Tensor.to is not in-place; keep the moved copy for the collective.
        local_pred: torch.Tensor = pred.to(device)
        if DistributedManager.is_main():
            dist_result: list[torch.Tensor] = \
                [torch.zeros_like(local_pred) for _ in range(DistributedManager.world_size)]
            torch.distributed.gather(local_pred, dist_result, dst=DistributedManager.main_rank())
            # With shuffle=False rank r holds dataset indices r, r + W, r + 2W, ...
            # and drop_last=False pads the index list with repeated leading
            # samples so that all shards have equal length. Interleaving the
            # shards restores dataset order; truncating removes the padding so
            # no sample is counted twice in downstream statistics.
            pred = torch.stack(dist_result, dim=1).flatten(0, 1)[:len(dataset)]
        else:
            torch.distributed.gather(local_pred, dst=DistributedManager.main_rank())

    Logger.debug(f'{get_class_name(get_inception_features_distributed)} end')

    return pred.detach().cpu().numpy() if DistributedManager.is_main() else None


def get_fid_statistics(activations: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Return the per-column mean and the covariance matrix of ``activations``.

    Rows are treated as observations and columns as variables
    (``rowvar=False``).
    """
    return np.mean(activations, axis=0), np.cov(activations, rowvar=False)


def calculate_frechet_distance(
        mu1: np.ndarray, sigma1: np.ndarray, mu2: np.ndarray, sigma2: np.ndarray, eps: float = 1e-6) -> float:
    """Compute the Frechet distance between two Gaussians.

    The value returned is
    ``||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))``.

    Args:
        mu1: Mean vector of the first distribution.
        sigma1: Covariance matrix of the first distribution.
        mu2: Mean vector of the second distribution.
        sigma2: Covariance matrix of the second distribution.
        eps: Value added to the diagonal of both covariances when the matrix
            square root of their product contains non-finite entries.

    Returns:
        The Frechet distance.

    Raises:
        AssertionError: If the means or the covariances differ in shape.
        ValueError: If the matrix square root has a diagonal whose imaginary
            part is not close to zero (absolute tolerance ``1e-3``).
    """
    # `import scipy` alone does not guarantee that the `linalg` submodule is
    # loaded on every SciPy release; import it explicitly where it is used.
    import scipy.linalg

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert (
            mu1.shape == mu2.shape
    ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
    assert (
            sigma1.shape == sigma2.shape
    ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"

    diff = mu1 - mu2

    # product might be almost singular
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = (
                "fid calculation produces singular product; adding %s to diagonal of cov estimates"
                % eps
        )
        print(msg)
        # Regularize both covariances and retry the square root.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError("Imaginary component {}".format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean


def calculate_inception_score(
        activations: np.ndarray,
        batch_size: int = 50,
        split_size: int = 5000,
        device: str = None
) -> float:
    """Compute the Inception Score from an array of activations.

    The activations are pushed through the softmax head created by
    ``InceptionV3.create_softmax_model()`` in chunks of ``batch_size``. The
    resulting probabilities are divided into consecutive splits of
    ``split_size`` rows; for each split the exponential of the mean KL
    divergence between the per-row distribution and the split's mean
    distribution is computed, and the scores of all splits are averaged.

    Args:
        activations: Input rows for the softmax head (presumably the pooled
            inception features - confirm with callers).
        batch_size: Rows per softmax forward pass.
        split_size: Rows per score split.
        device: Device used for the forward passes. When ``None`` the value of
            ``get_default_device()`` is used.

    Returns:
        The mean of the per-split scores.
    """
    Logger.debug(
        f'{get_class_name(calculate_inception_score)} start - '
        f'activations: {get_numpy_stats(activations)}, '
        f'batch_size: {batch_size}, '
        f'split_size: {split_size}, '
        f'device: {device}'
    )
    target_device = get_default_device() if device is None else device

    softmax_head: SoftmaxModel = create_inception_model(target_device).create_softmax_model()

    chunk_probs = []
    with torch.no_grad():
        for start in tqdm(range(0, len(activations), batch_size)):
            chunk = torch.from_numpy(activations[start: start + batch_size]).to(target_device)
            chunk_probs.append(softmax_head(chunk).cpu().numpy())
    probs = np.concatenate(chunk_probs, axis=0)

    # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
    split_scores = []
    for start in range(0, len(probs), split_size):
        split = probs[start: start + split_size]
        marginal = np.mean(split, axis=0, keepdims=True)
        kl_per_row = np.sum(split * (np.log(split) - np.log(marginal)), axis=1)
        split_scores.append(np.exp(np.mean(kl_per_row)))
    result: float = float(np.mean(split_scores))

    Logger.debug(f'{get_class_name(calculate_inception_score)} end - result: {result}')
    return result


def save_fid_statistics(mu: np.ndarray, sigma: np.ndarray, save_path: str) -> None:
    """Write ``mu`` and ``sigma`` to ``save_path`` with ``np.savez``.

    The parent directory of ``save_path`` is created first when the path has
    a directory component. The arrays are stored under the keys ``mu`` and
    ``sigma``.
    """
    parent_dir = os.path.dirname(save_path)
    if parent_dir != '':
        os.makedirs(parent_dir, exist_ok=True)
    np.savez(save_path, mu=mu, sigma=sigma)


def load_fid_statistics(load_path: str) -> tuple[np.ndarray, np.ndarray]:
    """Read the ``mu`` and ``sigma`` arrays from an ``.npz`` archive.

    Args:
        load_path: Path of an archive containing the keys ``mu`` and ``sigma``.

    Returns:
        The pair ``(mu, sigma)``.
    """
    # The archive object keeps the underlying file open until it is closed;
    # the context manager releases the handle once both arrays are read.
    with np.load(load_path) as fid_statistics:
        return fid_statistics['mu'], fid_statistics['sigma']


def get_features_for_dataset_and_model(
        model: torch.nn.Module,
        dataset: Union[NoiseLabelDataset, NoiseDataset],
        inference_batch_func: Callable[[torch.nn.Module, torch.Tensor, torch.Tensor], torch.Tensor],
        batch_size: int = 50,
        data_loader_workers: int = 8,
        device: str = None
) -> Union[np.ndarray, None]:
    """Generate outputs with ``model`` and compute their pooled inception features.

    Each batch of ``dataset`` provides a ``'noise'`` entry (and a ``'label'``
    entry when the dataset is a ``NoiseLabelDataset``). The batch is passed to
    ``inference_batch_func`` and its output is fed to the inception network.
    When ``DistributedManager.initialized`` is true the dataset is sharded
    across ranks and the features are gathered on the main rank.

    Args:
        model: Model handed to ``inference_batch_func``. It is switched to
            eval mode for the duration of the call and restored to its
            previous mode afterwards.
        dataset: Source of noise (and optionally label) batches.
        inference_batch_func: Called as ``f(model, noise, label_or_None)`` and
            expected to return the batch that is fed to the inception network.
        batch_size: Number of items per batch.
        data_loader_workers: Worker processes used by the ``DataLoader``.
        device: Device for the inception model, its inputs and the gather
            buffers. When ``None`` the value of ``get_default_device()`` is used.

    Returns:
        The feature array when ``DistributedManager.is_main()`` is true,
        ``None`` otherwise. In the distributed case the rows follow the
        dataset order and there is exactly one row per dataset item.

    Raises:
        AssertionError: If ``dataset`` is neither a ``NoiseLabelDataset`` nor a
            ``NoiseDataset``.
    """
    Logger.debug(
        f'{get_class_name(get_features_for_dataset_and_model)} start - '
        f'model: {type(model)}, '
        f'dataset: {type(dataset)}, '
        f'inference_batch_func: {get_class_name(inference_batch_func)}, '
        f'batch_size: {batch_size}, '
        f'data_loader_workers: {data_loader_workers}, '
        f'device: {device}'
    )
    if device is None:
        device = get_default_device()

    assert isinstance(dataset, NoiseLabelDataset) or isinstance(dataset, NoiseDataset), \
        f'dataset must be NoiseLabelDataset or NoiseDataset, got {type(dataset)}'
    conditional: bool = isinstance(dataset, NoiseLabelDataset)
    Logger.debug(f'conditional: {conditional}')

    distributed: bool = DistributedManager.initialized
    sampler: torch.utils.data.DistributedSampler = torch.utils.data.DistributedSampler(
        dataset,
        num_replicas=DistributedManager.world_size,
        rank=DistributedManager.rank,
        shuffle=False,
        drop_last=False
    ) if distributed else None
    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        sampler=sampler,
        shuffle=False if sampler is None else None,
        num_workers=data_loader_workers,
        drop_last=False
    )

    # Place the inception network on the requested device rather than on the
    # default one, so the `device` argument is honoured end to end.
    inception: InceptionV3 = create_inception_model(device)

    # Remember the caller's mode so a model that was already in eval mode is
    # not silently flipped to train mode on return.
    was_training: bool = model.training
    model.eval()

    pred_list: list[torch.Tensor] = []
    try:
        if distributed:
            sampler.set_epoch(0)
        for data_batch in tqdm(data_loader):
            output_batch: torch.Tensor = inference_batch_func(
                model,
                data_batch['noise'],
                data_batch['label'] if conditional else None
            )
            curr_pred, _ = get_inception_features_batch(inception, output_batch, device)
            pred_list.append(curr_pred.detach().cpu())
    finally:
        # Restore the mode even when inference raises.
        model.train(was_training)
    pred: torch.Tensor = torch.concatenate(pred_list, dim=0)

    if distributed:
        # Tensor.to is not in-place; keep the moved copy for the collective.
        local_pred: torch.Tensor = pred.to(device)
        if DistributedManager.is_main():
            dist_result: list[torch.Tensor] = \
                [torch.zeros_like(local_pred) for _ in range(DistributedManager.world_size)]
            torch.distributed.gather(local_pred, dist_result, dst=DistributedManager.main_rank())
            # With shuffle=False rank r holds dataset indices r, r + W, r + 2W, ...
            # and drop_last=False pads the index list with repeated leading
            # samples so that all shards have equal length. Interleaving the
            # shards restores dataset order; truncating removes the padding so
            # no sample is counted twice in downstream statistics.
            pred = torch.stack(dist_result, dim=1).flatten(0, 1)[:len(dataset)]
        else:
            torch.distributed.gather(local_pred, dst=DistributedManager.main_rank())

    Logger.debug(f'{get_class_name(get_features_for_dataset_and_model)} end')
    return pred.detach().cpu().numpy() if DistributedManager.is_main() else None
