import argparse
import os.path
from typing import Any

import numpy as np
import torch.utils.data
from torch_fidelity.utils import create_feature_extractor, get_featuresdict_from_dataset
from tqdm import tqdm

from datasets.numpy import FolderNumpyDataset
from src.fidelity.utils import ImageDataset
from utils.logger.logger import Logger
from utils.utils import get_class_name


def _to_uint8_images(data_batch: torch.Tensor) -> torch.Tensor:
    """Convert a float sample batch to uint8 pixels in [0, 255].

    The ``x * 0.5 + 0.5`` mapping assumes samples lie in [-1, 1]. Values outside
    that range are clipped. A single-channel (B, 1, H, W) batch is replicated to
    three channels, because torch-fidelity's Inception-v3 extractor only accepts
    Bx3xHxW input.
    """
    images = torch.round(torch.clip(data_batch * 0.5 + 0.5, 0, 1) * 255).to(torch.uint8)
    if images.dim() == 4 and images.shape[1] == 1:
        images = images.repeat(1, 3, 1, 1)
    return images


def _save_npy(path: str, array: np.ndarray) -> None:
    """Save ``array`` to ``path`` with ``np.save``, creating parent directories first."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    np.save(path, array)


def run(
        folder: str,
        image_height: int,
        image_width: int,
        image_channels: int,
        num_samples: int,
        batch_size: int,
        distribution_save_path: str,
        start_index: int = 0,
        classes_save_path: str = None,
        num_workers: int = 8
) -> None:
    """Classify stored samples with Inception-v3 and save the class histogram.

    Samples are read with ``FolderNumpyDataset`` and converted to uint8 images.
    They are then passed through torch-fidelity's ``inception-v3-compat``
    extractor. Each sample's predicted class is the argmax of its
    ``logits_unbiased`` output.

    Args:
        folder: Directory that ``FolderNumpyDataset`` reads samples from.
        image_height: Height of each stored sample.
        image_width: Width of each stored sample.
        image_channels: Channel count of each sample. The sample shape is (C, H, W).
        num_samples: Number of samples to read. This is also the length of the
            saved per-sample class array.
        batch_size: Batch size for both the outer loader and the extractor call.
        distribution_save_path: ``.npy`` path for the int64 per-class count vector.
        start_index: Index of the first sample, forwarded to ``FolderNumpyDataset``.
        classes_save_path: Optional ``.npy`` path for the per-sample predicted classes.
        num_workers: Worker processes for the outer ``DataLoader``.
    """
    Logger.debug(
        f'{get_class_name(run)} - '
        f'folder: {folder}, '
        f'image_height: {image_height}, '
        f'image_width: {image_width}, '
        f'image_channels: {image_channels}, '
        f'num_samples: {num_samples}, '
        f'batch_size: {batch_size}, '
        f'distribution_save_path: {distribution_save_path}, '
        f'start_index: {start_index}, '
        f'classes_save_path: {classes_save_path}, '
        f'num_workers: {num_workers}'
    )
    dataset: FolderNumpyDataset = FolderNumpyDataset(
        data_shape=(image_channels, image_height, image_width),
        data_type=np.float64,
        folder=folder,
        num_samples=num_samples,
        start_index=start_index
    )
    data_loader: torch.utils.data.DataLoader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        drop_last=False
    )
    cuda: bool = torch.cuda.is_available()
    feat_extractor = create_feature_extractor('inception-v3-compat', ['logits_unbiased'], cuda=cuda)

    classes: np.ndarray = np.zeros(num_samples, dtype=np.int64)
    # Sized from the extractor's actual logit width on the first batch, so the
    # running histogram always matches np.bincount's output length.
    distribution = None
    # Next free slot in `classes`. It is kept separate from the `start_index`
    # parameter, which refers to the dataset, not to this output array.
    offset: int = 0

    for data_batch in tqdm(data_loader):
        images: torch.Tensor = _to_uint8_images(data_batch)
        features_dict = get_featuresdict_from_dataset(
            input=ImageDataset(images.detach().cpu().numpy()),
            feat_extractor=feat_extractor,
            batch_size=batch_size,
            cuda=cuda,
            save_cpu_ram=False,
            verbose=False  # the outer tqdm already reports progress
        )
        logits = features_dict['logits_unbiased']
        num_classes: int = logits.shape[1]
        curr_classes: np.ndarray = torch.argmax(logits, dim=1).detach().cpu().numpy()

        # Advance by the number of predictions actually returned, not by batch_size.
        classes[offset:offset + len(curr_classes)] = curr_classes
        offset += len(curr_classes)

        if distribution is None:
            distribution = np.zeros(num_classes, dtype=np.int64)
        distribution += np.bincount(curr_classes, minlength=num_classes)

    if distribution is None:
        # The loader produced no batches; fall back to the original 1008-bin layout.
        distribution = np.zeros(1008, dtype=np.int64)

    if classes_save_path is not None:
        _save_npy(classes_save_path, classes)
    _save_npy(distribution_save_path, distribution)


def run_from_config(config: dict[str, Any]) -> None:
    """Call ``run``, forwarding every entry of ``config`` as a keyword argument.

    The keys must match ``run``'s parameter names, as ``get_config_from_args``
    produces them. An unexpected key makes the call raise ``TypeError``.
    """
    return run(**config)


def parse_args() -> argparse.Namespace:
    """Parse the command-line flags for this script from ``sys.argv``.

    The flags correspond one-to-one to the keyword parameters of ``run``.
    Every flag without a default is mandatory.
    """
    cli = argparse.ArgumentParser()
    mandatory_flags = (
        ('--folder', str),
        ('--image_height', int),
        ('--image_width', int),
        ('--image_channels', int),
        ('--num_samples', int),
        ('--batch_size', int),
        ('--distribution_save_path', str),
    )
    for flag, converter in mandatory_flags:
        cli.add_argument(flag, type=converter, required=True)
    # Optional flags; their defaults mirror the defaults in `run`'s signature.
    cli.add_argument('--start_index', type=int, default=0)
    cli.add_argument('--classes_save_path', type=str, default=None)
    cli.add_argument('--num_workers', type=int, default=8)
    return cli.parse_args()


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Return the parsed CLI namespace as a keyword-argument mapping for ``run``.

    The result is the namespace's own attribute dict, not a copy. Mutating it
    also mutates ``args``.
    """
    return args.__dict__


def main() -> None:
    """Command-line entry point: parse the flags, then run the classification."""
    run_from_config(get_config_from_args(parse_args()))


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
