import argparse
import os
import pickle
from typing import Any, Optional

import numpy as np

from src.fidelity.isc.isc import calculate_isc
from src.fidelity.utils import ImagesFormat, ImagesScale
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.numpy.load import load
from utils.utils import get_class_name


def run(
        folder: str,
        save_path: Optional[str] = None,
        num_samples: int = 50_000,
        batch_size: int = 50,
        start_index: int = 0,
        num_processes: int = 8,
        device: Optional[str] = None
) -> None:
    """Compute the ISC metric for the images in ``folder`` and report it.

    The images are read with ``load`` and handed to ``calculate_isc``, which
    is told they are laid out as NCHW with values in [-1, 1]. The returned
    mean and standard deviation are logged at INFO level and, when
    ``save_path`` is given, pickled to that file as
    ``{'mean': mean, 'std': std}``.

    Args:
        folder: Source passed to ``load`` as ``folder``.
        save_path: Optional output file for the pickled result; any missing
            parent directories are created first.
        num_samples: Forwarded to ``load`` as ``n_samples``.
        batch_size: Forwarded to ``calculate_isc``.
        start_index: Forwarded to ``load`` as ``start_index``.
        num_processes: Forwarded to ``load`` as ``n_processes``.
        device: Device name; replaced by ``get_default_device()`` when None.
    """
    Logger.debug(
        f'{get_class_name(run)} - '
        f'folder: {folder}, '
        f'save_path: {save_path}, '
        f'num_samples: {num_samples}, '
        f'batch_size: {batch_size}, '
        f'start_index: {start_index}, '
        f'num_processes: {num_processes}, '
        f'device: {device}'
    )
    if device is None:
        # NOTE(review): `device` is resolved here but never forwarded to
        # `calculate_isc` below, so it currently has no effect on the
        # computation. Confirm whether `calculate_isc` accepts a device
        # argument and pass it through if it does.
        device = get_default_device()

    # NOTE(review): assumes `load` returns an NCHW array scaled to [-1, 1],
    # matching the format declared to `calculate_isc`; nothing here checks it.
    images: np.ndarray = load(
        folder=folder,
        n_samples=num_samples,
        start_index=start_index,
        n_processes=num_processes
    )
    mean, std = calculate_isc(
        images=images,
        batch_size=batch_size,
        images_scale=ImagesScale.MINUS_ONE_TO_ONE,
        images_format=ImagesFormat.NCHW
    )

    Logger.info(f'isc mean: {mean}, isc std: {std}')

    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        # A bare filename has no directory part, and os.makedirs('') raises,
        # so only create a directory when one is present.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_path, 'wb') as file:
            pickle.dump({'mean': mean, 'std': std}, file)

    Logger.debug(f'{get_class_name(run)} - done')


def run_from_config(config: dict[str, Any]) -> None:
    """Invoke ``run`` with every entry of ``config`` as a keyword argument.

    Keys must match ``run``'s parameter names; an unknown key makes the
    call raise ``TypeError``.
    """
    # `run` returns None, so forwarding its result keeps the None return.
    return run(**config)


def parse_args() -> argparse.Namespace:
    """Parse the command-line flags of this script from ``sys.argv``.

    ``--folder`` is mandatory; every other flag is optional and falls back
    to the same default as the matching ``run`` parameter.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--folder', type=str, required=True)

    # (flag, converter, default) in registration order, which is also the
    # order the flags appear in the generated help text.
    optional_flags: list[tuple[str, type, Any]] = [
        ('--save_path', str, None),
        ('--num_samples', int, 50_000),
        ('--batch_size', int, 50),
        ('--start_index', int, 0),
        ('--num_processes', int, 8),
    ]
    for flag, converter, fallback in optional_flags:
        cli.add_argument(flag, type=converter, default=fallback)

    return cli.parse_args()


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Expose the parsed flags as a ``{dest: value}`` mapping for ``run``.

    The returned dict is the namespace's own attribute dict (not a copy),
    exactly as ``vars(args)`` would return.
    """
    return args.__dict__


def main() -> None:
    """Command-line entry point: parse flags, build the config, and run it."""
    run_from_config(get_config_from_args(parse_args()))


if __name__ == '__main__':
    main()
