import argparse
import os.path
import pickle
from typing import Any, Optional

import numpy as np
import torch.nn

from datasets.numpy import NumpyDataset, FolderNumpyDataset
from src.datasets.noise import NumpyNoiseDataset
from src.datasets.noise_label import NoiseLabelDataset, NumpyNoiseLabelDataset
from src.fid.utils import RandomNoiseDataset, RandomLabelDataset
from src.fidelity.fid.fid import calculate_fid_for_dataset_and_model
from src.fidelity.metrics.metrics import calculate_metrics_for_dataset_and_model
from src.fidelity.samplers.multi_step import inference_multi_step
from src.models.models.edm import create_edm_model
from src.utils.load import load_from_state_dict
from torch_utils.utils import get_default_device
from utils.logger.logger import Logger
from utils.utils import get_class_name


def save_result(result: Any, save_path: str, name: str = None) -> None:
    """Pickle ``result`` to ``save_path``; do nothing when ``save_path`` is None.

    Args:
        result: Object to serialize with ``pickle.dump``.
        save_path: Destination file path. When it has a directory component,
            that directory is created first if it does not exist.
        name: Label used only in the debug log message.
    """
    # Guard clause: no destination means nothing to persist.
    if save_path is None:
        return

    Logger.debug(f'saving {name} to {save_path}')

    # A bare filename has an empty dirname, and makedirs('') would fail.
    parent_dir = os.path.dirname(save_path)
    if parent_dir != '':
        os.makedirs(parent_dir, exist_ok=True)

    with open(save_path, 'wb') as output_file:
        pickle.dump(result, output_file)


def run(
        num_steps: int,
        time_steps: list[int],
        image_height: int,
        image_width: int,
        image_channels: int,
        model_name: str,
        model_path: str,
        calculate_fid: bool = False,
        calculate_isc: bool = False,
        fid_reference_path: str = None,
        noise_folder: str = None,
        label_folder: str = None,
        noise_start_index: int = 0,
        label_start_index: int = 0,
        num_classes: Optional[int] = None,
        model_load_keys: list[str] = None,
        num_samples: int = 50_000,
        batch_size: int = 50,
        fid_save_path: str = None,
        isc_save_path: str = None,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        s_churn: float = 0,
        s_min: float = 0,
        s_max: float = float('inf'),
        s_noise: float = 1,
        multiply_noises: bool = True,
        same_noise: bool = True,
        data_loader_workers: int = 8
) -> None:
    """Build a model and input datasets, compute metrics, and optionally save them.

    The model is created with ``create_edm_model``, populated through
    ``load_from_state_dict``, moved to the default device and put in eval
    mode. Noise (and, when ``num_classes`` is given, labels) come either from
    a folder or from a random dataset. The metrics are produced by
    ``calculate_metrics_for_dataset_and_model``, which samples through
    ``inference_multi_step``.

    Args:
        num_steps: Forwarded to ``inference_multi_step``.
        time_steps: Forwarded to ``inference_multi_step``.
        image_height: Height component of the noise sample shape.
        image_width: Width component of the noise sample shape.
        image_channels: Channel component of the noise sample shape.
        model_name: Passed to ``create_edm_model``.
        model_path: Passed to ``load_from_state_dict``.
        calculate_fid: Forwarded to the metrics call; when true the ``'fid'``
            entry of the result is logged and saved.
        calculate_isc: Forwarded to the metrics call; when true the ``'isc'``
            entry of the result is logged and saved.
        fid_reference_path: Forwarded to the metrics call.
        noise_folder: When given, noise is read from this folder; otherwise a
            ``RandomNoiseDataset`` is used.
        label_folder: When given, labels are read from this folder; otherwise
            a ``RandomLabelDataset`` is used in the conditional case.
        noise_start_index: ``start_index`` for the folder noise dataset.
        label_start_index: ``start_index`` for the folder label dataset.
        num_classes: When not None the run is treated as conditional.
        model_load_keys: Passed to ``load_from_state_dict``.
        num_samples: Number of samples in each dataset.
        batch_size: Forwarded to the metrics call.
        fid_save_path: Pickle destination for the FID value (None skips saving).
        isc_save_path: Pickle destination for the ISC value (None skips saving).
        sigma_min, sigma_max, rho, s_churn, s_min, s_max, s_noise,
        multiply_noises, same_noise: Forwarded unchanged to
            ``inference_multi_step``.
        data_loader_workers: Forwarded to the metrics call.
    """
    Logger.debug(
        f'{get_class_name(run)} - '
        f'num_steps: {num_steps}, '
        f'time_steps: {time_steps}, '
        f'image_height: {image_height}, '
        f'image_width: {image_width}, '
        f'image_channels: {image_channels}, '
        f'model_name: {model_name}, '
        f'model_path: {model_path}, '
        f'calculate_fid: {calculate_fid}, '
        f'calculate_isc: {calculate_isc}, '
        f'fid_reference_path: {fid_reference_path}, '
        f'noise_folder: {noise_folder}, '
        f'label_folder: {label_folder}, '
        f'noise_start_index: {noise_start_index}, '
        f'label_start_index: {label_start_index}, '
        f'num_classes: {num_classes}, '
        f'model_load_keys: {model_load_keys}, '
        f'num_samples: {num_samples}, '
        f'batch_size: {batch_size}, '
        f'fid_save_path: {fid_save_path}, '
        f'isc_save_path: {isc_save_path}, '
        f'sigma_min: {sigma_min}, '
        f'sigma_max: {sigma_max}, '
        f'rho: {rho}, '
        f's_churn: {s_churn}, '
        f's_min: {s_min}, '
        f's_max: {s_max}, '
        f's_noise: {s_noise}, '
        f'multiply_noises: {multiply_noises}, '
        f'same_noise: {same_noise}, '
        f'data_loader_workers: {data_loader_workers}'
    )
    # Supplying a class count is what switches the run into conditional mode.
    conditional: bool = num_classes is not None
    Logger.debug(f'conditional: {conditional}')

    # --- model ---------------------------------------------------------
    Logger.debug('creating model')
    model: torch.nn.Module = create_edm_model(model_name)
    Logger.debug('loading model')
    model = load_from_state_dict(model, model_path, model_load_keys)
    model.to(get_default_device())
    model.eval()

    # --- noise inputs --------------------------------------------------
    sample_shape = (image_channels, image_height, image_width)
    noise_dataset: NumpyDataset
    if noise_folder is not None:
        noise_dataset = FolderNumpyDataset(
            data_shape=sample_shape,
            data_type=np.float64,
            folder=noise_folder,
            num_samples=num_samples,
            start_index=noise_start_index
        )
    else:
        noise_dataset = RandomNoiseDataset(
            num_samples=num_samples,
            noise_shape=sample_shape
        )

    # --- label inputs --------------------------------------------------
    # A label folder takes precedence; random labels are only generated for
    # conditional runs; otherwise there is no label dataset at all.
    label_dataset: Optional[NumpyDataset]
    if label_folder is not None:
        label_dataset = FolderNumpyDataset(
            data_shape=(),
            data_type=np.int64,
            folder=label_folder,
            num_samples=num_samples,
            start_index=label_start_index
        )
    elif conditional:
        label_dataset = RandomLabelDataset(
            num_samples=num_samples,
            num_classes=num_classes
        )
    else:
        label_dataset = None

    # --- combined dataset ----------------------------------------------
    dataset: NoiseLabelDataset
    if conditional:
        dataset = NumpyNoiseLabelDataset(
            noise_dataset=noise_dataset,
            label_dataset=label_dataset
        )
    else:
        dataset = NumpyNoiseDataset(noise_dataset=noise_dataset)

    # Positional parameter names are kept short (m, n, l) to mirror the
    # callback signature the metrics helper is handed.
    def sample_batch(m, n, l):
        """Run multi-step inference for one batch with this run's settings."""
        return inference_multi_step(
            model=m,
            num_steps=num_steps,
            time_steps=time_steps,
            noises=n,
            labels=l if conditional else None,
            num_classes=num_classes,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
            s_churn=s_churn,
            s_min=s_min,
            s_max=s_max,
            s_noise=s_noise,
            multiply_noises=multiply_noises,
            same_noise=same_noise
        )

    Logger.debug('calculating fid')
    result: dict[str, Any] = calculate_metrics_for_dataset_and_model(
        model=model,
        dataset=dataset,
        calculate_fid=calculate_fid,
        calculate_isc=calculate_isc,
        fid_reference_path=fid_reference_path,
        inference_batch_func=sample_batch,
        batch_size=batch_size,
        data_loader_workers=data_loader_workers
    )

    # Only the metrics that were requested are read from the result.
    if calculate_fid:
        Logger.debug(f'fid: {result["fid"]}')
        save_result(result['fid'], fid_save_path, 'fid')
    if calculate_isc:
        Logger.debug(f'isc: {result["isc"]}')
        save_result(result['isc'], isc_save_path, 'isc')


def run_from_config(config: dict[str, Any]) -> None:
    """Call ``run`` with the entries of ``config`` unpacked as keyword arguments.

    Args:
        config: Mapping whose keys must match ``run``'s parameter names.
    """
    run(**config)


def _str_to_bool(value: str) -> bool:
    """Convert a command-line token to a boolean.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is True
    because any non-empty string is truthy. This converter interprets the
    text instead.

    Args:
        value: Raw command-line token.

    Returns:
        True for ``true/t/yes/y/on/1``, False for ``false/f/no/n/off/0`` and
        for the empty string (the only value that yielded False under the old
        ``type=bool`` behaviour). Matching is case-insensitive.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'expected a boolean value (true/false), got {value!r}'
    )


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script from ``sys.argv``.

    Option names and defaults mirror the parameters of ``run``. Boolean
    options take an explicit value (e.g. ``--calculate_fid true`` or
    ``--same_noise false``), parsed by ``_str_to_bool``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser: argparse.ArgumentParser = argparse.ArgumentParser()
    parser.add_argument('--num_steps', type=int, required=True)
    parser.add_argument('--time_steps', type=int, nargs='+', required=True)
    parser.add_argument('--image_height', type=int, required=True)
    parser.add_argument('--image_width', type=int, required=True)
    parser.add_argument('--image_channels', type=int, required=True)
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--model_path', type=str, required=True)
    # Booleans go through _str_to_bool so that "False"/"0"/"no" really mean False.
    parser.add_argument('--calculate_fid', type=_str_to_bool, default=False)
    parser.add_argument('--calculate_isc', type=_str_to_bool, default=False)
    parser.add_argument('--fid_reference_path', type=str, default=None)
    parser.add_argument('--noise_folder', type=str, default=None)
    parser.add_argument('--label_folder', type=str, default=None)
    parser.add_argument('--noise_start_index', type=int, default=0)
    parser.add_argument('--label_start_index', type=int, default=0)
    parser.add_argument('--num_classes', type=int, default=None)
    parser.add_argument('--model_load_keys', type=str, nargs='+', default=None)
    parser.add_argument('--num_samples', type=int, default=50_000)
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--fid_save_path', type=str, default=None)
    parser.add_argument('--isc_save_path', type=str, default=None)
    parser.add_argument('--sigma_min', type=float, default=0.002)
    parser.add_argument('--sigma_max', type=float, default=80)
    parser.add_argument('--rho', type=float, default=7)
    parser.add_argument('--s_churn', type=float, default=0)
    parser.add_argument('--s_min', type=float, default=0)
    parser.add_argument('--s_max', type=float, default=float('inf'))
    parser.add_argument('--s_noise', type=float, default=1)
    parser.add_argument('--multiply_noises', type=_str_to_bool, default=True)
    parser.add_argument('--same_noise', type=_str_to_bool, default=True)
    parser.add_argument('--data_loader_workers', type=int, default=8)
    return parser.parse_args()


def get_config_from_args(args: argparse.Namespace) -> dict[str, Any]:
    """Expose the parsed namespace's attributes as a name-to-value dictionary.

    Args:
        args: Namespace produced by the argument parser.

    Returns:
        The namespace's own attribute dictionary as given by ``vars`` (not a
        copy).
    """
    config: dict[str, Any] = vars(args)
    return config


def main() -> None:
    """Script entry point: parse CLI arguments, turn them into a config, and run."""
    # The three stages are chained: parse -> namespace to dict -> run.
    run_from_config(get_config_from_args(parse_args()))


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()
