from typing import Union, Callable, Any

import numpy as np
import torch.utils.data
import torch_fidelity
from torch_fidelity.utils import create_feature_extractor, get_featuresdict_from_dataset
from tqdm import tqdm

from src.datasets.noise import NoiseDataset
from src.datasets.noise_label import NoiseLabelDataset
from src.fidelity.utils import ImagesScale, ImagesFormat, ImageDataset, convert_scale, convert_format, \
    load_features
from torch_utils.stats import get_torch_stats
from utils.logger.logger import Logger
from utils.numpy.stats import get_numpy_stats
from utils.utils import get_class_name


def calc_cdist_full(features_1, features_2, batch_size=10000):
    """Build the complete pairwise distance matrix between two feature sets, tile by tile.

    Both inputs are chunked along dim 0 into pieces of at most ``batch_size`` rows.
    Every ``torch.cdist`` tile is copied to the CPU immediately. The tiles are then
    concatenated column-wise into row strips, and the strips row-wise, so the returned
    tensor always lives on the CPU.

    Progress is printed to stdout as ``Batch <row>`` per row strip and
    ``Batch <row>, <col>`` per tile.

    Args:
        features_1: Feature tensor; rows are samples (expected 2-D ``[N1, D]``).
        features_2: Feature tensor; rows are samples (expected 2-D ``[N2, D]``).
        batch_size: Maximum number of rows per chunk on either side.

    Returns:
        CPU tensor whose ``[i, j]`` entry is the distance between ``features_1[i]`` and
        ``features_2[j]`` (``[N1, N2]`` for 2-D inputs).
    """
    Logger.debug(
        f'{get_class_name(calc_cdist_full)} - '
        f'features_1: {get_torch_stats(features_1)}, '
        f'features_2: {get_torch_stats(features_2)}, '
        f'batch_size: {batch_size}'
    )
    row_strips = []
    for row_idx, row_chunk in enumerate(features_1.split(batch_size)):
        print(f'Batch {row_idx}')
        tiles = []
        for col_idx, col_chunk in enumerate(features_2.split(batch_size)):
            print(f'Batch {row_idx}, {col_idx}')
            # Offload each tile right away so device memory only ever holds one tile.
            tiles.append(torch.cdist(row_chunk, col_chunk).cpu())
        row_strips.append(torch.cat(tiles, dim=1))
    return torch.cat(row_strips, dim=0)


def calc_cdist_part(features_1, features_2, batch_size=10000):
    """Compute distances from every row of ``features_1`` to every row of ``features_2``.

    Only ``features_2`` is chunked (``batch_size`` rows at a time). ``features_1`` is used
    whole in each ``torch.cdist`` call, so callers are expected to pass an already-small
    ``features_1``. Each column block is moved to the CPU before the blocks are joined
    along dim 1. ``Batch <idx>`` is printed per block.

    Args:
        features_1: Query features; rows are samples (expected 2-D ``[N1, D]``).
        features_2: Reference features; rows are samples (expected 2-D ``[N2, D]``).
        batch_size: Maximum number of ``features_2`` rows per ``cdist`` call.

    Returns:
        CPU tensor of pairwise distances (``[N1, N2]`` for 2-D inputs).
    """
    Logger.debug(
        f'{get_class_name(calc_cdist_part)} - '
        f'features_1: {get_torch_stats(features_1)}, '
        f'features_2: {get_torch_stats(features_2)}, '
        f'batch_size: {batch_size}'
    )
    column_blocks = []
    for block_idx, col_chunk in enumerate(features_2.split(batch_size)):
        print(f'Batch {block_idx}')
        column_blocks.append(torch.cdist(features_1, col_chunk).cpu())
    return torch.cat(column_blocks, dim=1)


def calculate_precision_recall_full(features_1, features_2, neighborhood=3, batch_size=10000):
    """k-NN precision/recall computed from fully materialised distance matrices.

    This variant builds three complete distance matrices on the CPU via ``calc_cdist_full``:
    features_1 x features_1, features_2 x features_2 and features_2 x features_1.

    Definitions:
        - precision: fraction of ``features_2`` rows lying within the ``neighborhood``-th
          nearest-neighbour radius of at least one ``features_1`` row.
        - recall: the same quantity with the roles of the two sets swapped.

    Args:
        features_1: First (real/reference) feature set, rows are samples.
        features_2: Second (generated) feature set, rows are samples.
        neighborhood: ``k`` of the k-NN radius.
        batch_size: Tile size forwarded to ``calc_cdist_full``.

    Returns:
        ``(precision, recall)`` as Python floats.
    """
    Logger.debug(
        f'{get_class_name(calculate_precision_recall_full)} - '
        f'features_1: {get_torch_stats(features_1)}, '
        f'features_2: {get_torch_stats(features_2)}, '
        f'neighborhood: {neighborhood}, '
        f'batch_size: {batch_size}'
    )
    # kthvalue is 1-indexed; the smallest in-set distance is the sample to itself (0),
    # so rank neighborhood + 1 is the neighborhood-th genuine neighbour.
    kth_rank = neighborhood + 1
    radii_1 = calc_cdist_full(features_1, features_1, batch_size).kthvalue(kth_rank).values
    radii_2 = calc_cdist_full(features_2, features_2, batch_size).kthvalue(kth_rank).values
    cross = calc_cdist_full(features_2, features_1, batch_size)  # rows: features_2, cols: features_1

    # Broadcasting compares column j against radii_1[j] / radii_2[j] respectively.
    precision = (cross <= radii_1).any(dim=1).float().mean().item()
    recall = (cross.T <= radii_2).any(dim=1).float().mean().item()

    Logger.debug(f'{get_class_name(calculate_precision_recall_full)} - precision: {precision}, recall: {recall}')
    return precision, recall


def _knn_radii_part(features, neighborhood, batch_size):
    """Return each row's distance to its ``neighborhood``-th nearest other row of ``features``."""
    # kthvalue is 1-indexed and the smallest in-set distance is the row to itself,
    # hence rank neighborhood + 1.
    radii = [
        calc_cdist_part(chunk, features, batch_size).kthvalue(neighborhood + 1).values
        for chunk in features.split(batch_size)
    ]
    return torch.cat(radii)


def _manifold_coverage_part(queries, reference, reference_radii, batch_size):
    """Return the fraction of ``queries`` rows within ``reference_radii[j]`` of some ``reference`` row ``j``."""
    hits = []
    for chunk in queries.split(batch_size):
        dists = calc_cdist_part(chunk, reference, batch_size)  # rows: chunk, cols: reference
        # Broadcasting compares column j against reference_radii[j].
        hits.append((dists <= reference_radii).any(dim=1).float())
    return torch.cat(hits).mean().item()


def calculate_precision_recall_part(features_1, features_2, neighborhood=3, batch_size=10000):
    """k-NN precision/recall computed in row chunks to bound peak memory.

    Same metric as ``calculate_precision_recall_full``, but no complete distance matrix is
    ever held. Distances are produced ``batch_size`` query rows at a time via
    ``calc_cdist_part``.

    Definitions:
        - precision: fraction of ``features_2`` rows lying within the ``neighborhood``-th
          nearest-neighbour radius of at least one ``features_1`` row.
        - recall: the same quantity with the roles of the two sets swapped.

    Args:
        features_1: First (real/reference) feature set, rows are samples.
        features_2: Second (generated) feature set, rows are samples.
        neighborhood: ``k`` of the k-NN radius.
        batch_size: Chunk size for both query rows and ``calc_cdist_part`` columns.

    Returns:
        ``(precision, recall)`` as Python floats.
    """
    Logger.debug(
        f'{get_class_name(calculate_precision_recall_part)} - '
        f'features_1: {get_torch_stats(features_1)}, '
        f'features_2: {get_torch_stats(features_2)}, '
        f'neighborhood: {neighborhood}, '
        f'batch_size: {batch_size}'
    )
    # Precision: how much of features_2 falls inside the k-NN manifold of features_1.
    radii_1 = _knn_radii_part(features_1, neighborhood, batch_size)
    precision = _manifold_coverage_part(features_2, features_1, radii_1, batch_size)
    # Recall: how much of features_1 falls inside the k-NN manifold of features_2.
    radii_2 = _knn_radii_part(features_2, neighborhood, batch_size)
    recall = _manifold_coverage_part(features_1, features_2, radii_2, batch_size)
    Logger.debug(f'{get_class_name(calculate_precision_recall_part)} - precision: {precision}, recall: {recall}')
    return precision, recall


def prc_features_to_metric(features_1, features_2, neighborhood=3, batch_size=10000):
    """Compute the precision / recall / F-score metric dictionary for two feature tensors.

    Convention: ``features_1`` is REAL and ``features_2`` is GENERATED. The convention only
    affects which number is reported as precision and which as recall. The chunked
    implementation ``calculate_precision_recall_part`` is used. The three values are also
    printed to stdout.

    Args:
        features_1: 2-D tensor of real features, rows are samples.
        features_2: 2-D tensor of generated features, same number of columns as ``features_1``.
        neighborhood: ``k`` of the k-NN radius.
        batch_size: Chunk size for the distance computation.

    Returns:
        dict with ``torch_fidelity.KEY_METRIC_PRECISION``, ``KEY_METRIC_RECALL`` and
        ``KEY_METRIC_F_SCORE`` entries.

    Raises:
        AssertionError: if either input is not a 2-D tensor or the feature widths differ.
    """
    Logger.debug(
        f'{get_class_name(prc_features_to_metric)} - '
        f'features_1: {get_torch_stats(features_1)}, '
        f'features_2: {get_torch_stats(features_2)}, '
        f'neighborhood: {neighborhood}, '
        f'batch_size: {batch_size}'
    )
    for features in (features_1, features_2):
        assert torch.is_tensor(features) and features.dim() == 2
    assert features_1.shape[1] == features_2.shape[1]

    precision, recall = calculate_precision_recall_part(features_1, features_2, neighborhood, batch_size)
    # Harmonic mean of precision and recall; the denominator is clamped so 0/0 yields 0.
    f_score = 2 * precision * recall / max(1e-5, precision + recall)

    print(f"Precision: {precision:.7g}")
    print(f"Recall: {recall:.7g}")
    print(f"F-score: {f_score:.7g}")

    return {
        torch_fidelity.KEY_METRIC_PRECISION: precision,
        torch_fidelity.KEY_METRIC_RECALL: recall,
        torch_fidelity.KEY_METRIC_F_SCORE: f_score,
    }


def prc_featuresdict_to_metric(featuresdict_1, featuresdict_2, feat_layer_name, **kwargs):
    """Pick one feature layer out of two feature dicts and compute the PRC metric on it.

    ``featuresdict_1`` plays the real/reference role and ``featuresdict_2`` the generated role.
    Any extra keyword arguments (e.g. ``neighborhood``, ``batch_size``) are forwarded
    unchanged to ``prc_features_to_metric``.

    Returns:
        The metric dict produced by ``prc_features_to_metric``.
    """
    return prc_features_to_metric(
        featuresdict_1[feat_layer_name],
        featuresdict_2[feat_layer_name],
        **kwargs,
    )


@torch.inference_mode()
def calculate_prc(
        images: np.ndarray,
        features_path: str,
        batch_size: int,
        images_scale: ImagesScale = ImagesScale.MINUS_ONE_TO_ONE,
        images_format: ImagesFormat = ImagesFormat.NCHW,
        save_cpu_ram: bool = False
) -> tuple[float, float, float]:
    """Score ``images`` against stored reference features with the k-NN precision/recall metric.

    The images are first normalised with ``convert_format`` / ``convert_scale``. Features are
    then extracted with torch-fidelity's ``inception-v3-compat`` extractor at layer ``'2048'``.
    These features are compared against the ``'2048'`` entry loaded from ``features_path``. The
    loaded reference features are treated as the real set and the extracted ones as the
    generated set.

    Args:
        images: Images to evaluate.
        features_path: Path handed to ``load_features``; the result must provide a ``'2048'``
            tensor entry.
        batch_size: Batch size for feature extraction.
        images_scale: Value range of ``images``.
        images_format: Memory layout of ``images``.
        save_cpu_ram: Forwarded to ``get_featuresdict_from_dataset``.

    Returns:
        ``(precision, recall, f_score)``.
    """
    Logger.debug(
        f'{get_class_name(calculate_prc)} - '
        f'images: {get_numpy_stats(images)}, '
        f'features_path: {features_path}, '
        f'batch_size: {batch_size}, '
        f'images_scale: {images_scale}, '
        f'images_format: {images_format}, '
        f'save_cpu_ram: {save_cpu_ram}'
    )
    use_cuda: bool = torch.cuda.is_available()
    dataset: ImageDataset = ImageDataset(convert_scale(convert_format(images, images_format), images_scale))
    feat_extractor = create_feature_extractor('inception-v3-compat', ['2048'], cuda=use_cuda)
    features_dict = get_featuresdict_from_dataset(
        input=dataset,
        feat_extractor=feat_extractor,
        batch_size=batch_size,
        cuda=use_cuda,
        save_cpu_ram=save_cpu_ram,
        verbose=True
    )
    reference_feature_dict = load_features(features_path)

    generated_features = features_dict['2048']
    reference_features = reference_feature_dict['2048']
    # Shape/dtype/device of both sets, logged (instead of printed) to help diagnose mismatches
    # before the pairwise-distance computation.
    Logger.debug(
        f'{get_class_name(calculate_prc)} - '
        f'generated features: shape={tuple(generated_features.shape)}, '
        f'dtype={generated_features.dtype}, device={generated_features.device}; '
        f'reference features: shape={tuple(reference_features.shape)}, '
        f'dtype={reference_features.dtype}, device={reference_features.device}'
    )

    # Reference (real) features go first: the metric code treats argument 1 as REAL.
    result: dict[str, Any] = prc_featuresdict_to_metric(reference_feature_dict, features_dict, '2048')
    Logger.debug(f'{get_class_name(calculate_prc)} - result: {result}')
    precision: float = result[torch_fidelity.KEY_METRIC_PRECISION]
    recall: float = result[torch_fidelity.KEY_METRIC_RECALL]
    f_score: float = result[torch_fidelity.KEY_METRIC_F_SCORE]
    Logger.debug(f'{get_class_name(calculate_prc)} - precision: {precision}, recall: {recall}, f_score: {f_score}')
    return precision, recall, f_score


@torch.inference_mode()
def calculate_prc_for_dataset_and_model(
        model: torch.nn.Module,
        dataset: Union[NoiseLabelDataset, NoiseDataset],
        features_path: str,
        inference_batch_func: Callable[[torch.nn.Module, torch.Tensor, Union[torch.Tensor, None]], torch.Tensor],
        batch_size: int = 50,
        data_loader_workers: int = 8,
        save_cpu_ram: bool = False
) -> tuple[float, float, float]:
    """Generate one sample per ``dataset`` item with ``model`` and score them with ``calculate_prc``.

    The dataset is iterated in order (``shuffle=False``). For each batch, ``'noise'`` is passed
    to ``inference_batch_func``, together with ``'label'`` when the dataset is a
    ``NoiseLabelDataset`` (``None`` otherwise). Outputs are collected into a float64 NumPy
    array and passed to ``calculate_prc`` as NCHW images in [-1, 1].

    Args:
        model: Generator. It is switched to eval mode for generation, and its previous
            train/eval mode is restored afterwards, also when generation raises.
        dataset: Source of noise (and optionally labels).
        features_path: Reference features, forwarded to ``calculate_prc``.
        inference_batch_func: Called as ``inference_batch_func(model, noise, label_or_None)``;
            must return a tensor whose first dimension is the batch size.
        batch_size: Batch size for generation and for feature extraction.
        data_loader_workers: ``num_workers`` of the generation ``DataLoader``.
        save_cpu_ram: Forwarded to ``calculate_prc``.

    Returns:
        ``(precision, recall, f_score)``.

    Raises:
        AssertionError: if ``dataset`` is neither a ``NoiseLabelDataset`` nor a ``NoiseDataset``.
    """
    Logger.debug(
        f'{get_class_name(calculate_prc_for_dataset_and_model)} start - '
        f'model: {type(model)}, '
        f'dataset: {type(dataset)}, '
        f'features_path: {features_path}, '
        f'inference_batch_func: {get_class_name(inference_batch_func)}, '
        f'batch_size: {batch_size}, '
        f'data_loader_workers: {data_loader_workers}, '
        f'save_cpu_ram: {save_cpu_ram}'
    )
    assert isinstance(dataset, (NoiseLabelDataset, NoiseDataset)), \
        f'dataset must be NoiseLabelDataset or NoiseDataset, got {type(dataset)}'
    conditional: bool = isinstance(dataset, NoiseLabelDataset)
    Logger.debug(f'conditional: {conditional}')

    noise_shape: tuple[int, ...] = dataset.data_shape['noise']

    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=data_loader_workers,
        drop_last=False
    )

    images: Union[np.ndarray, None] = None
    index: int = 0
    was_training: bool = model.training
    model.eval()
    try:
        for data_batch in tqdm(data_loader):
            output_batch: torch.Tensor = inference_batch_func(
                model,
                data_batch['noise'],
                data_batch['label'] if conditional else None
            )
            output_np: np.ndarray = output_batch.detach().cpu().numpy()
            if images is None:
                # Size the buffer from the model output: the generated sample shape need not
                # equal the noise shape. np.zeros keeps the original float64 dtype.
                images = np.zeros((len(dataset), *output_np.shape[1:]))
            images[index:index + len(output_np)] = output_np
            index += len(output_np)
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)

    if images is None:
        # Empty dataset: fall back to the noise-shaped empty buffer.
        images = np.zeros((len(dataset), *noise_shape))

    precision, recall, f_score = calculate_prc(
        images=images,
        features_path=features_path,
        batch_size=batch_size,
        images_scale=ImagesScale.MINUS_ONE_TO_ONE,
        images_format=ImagesFormat.NCHW,
        save_cpu_ram=save_cpu_ram
    )
    Logger.debug(
        f'{get_class_name(calculate_prc_for_dataset_and_model)} end - precision: {precision}, recall: {recall}, f_score: {f_score}')
    return precision, recall, f_score
