from dataclasses import dataclass
from typing import Optional, List, Tuple

from omegaconf import SI

from conf._util import return_factory
from conf.dataset import ValueRange


@dataclass
class InceptionDistanceParams:
    """Configuration for Inception-based generative metrics (FID and KID).

    Groups the metric hyper-parameters, the paths of pre-computed
    initialization checkpoints, the evaluation dataloader settings, data
    normalization, and flags controlling when/how often the metrics are
    computed. Fields initialised with ``SI(...)`` are OmegaConf string
    interpolations resolved from ``dataset_params.data_params`` of the
    enclosing config.
    """
    use            : bool = False
    fid_dims       : int = 2048  # presumably the Inception feature size used for FID — TODO confirm

    # KID hyper-parameters. Names/defaults mirror the arguments of torchmetrics'
    # KernelInceptionDistance (feature, subsets, subset_size, degree, gamma,
    # coef) — presumably forwarded there; verify against the consumer.
    kid_feature    : int = 2048
    kid_subsets    : int = 100
    kid_subset_size: int = 1_000
    kid_degree     : int = 3
    kid_gamma      : Optional[float] = None
    kid_coef       : float = 1.0

    # If True and the corresponding fid_/kid_load_initialization_path does not
    # exist, initialize the metric with statistics of the real data.
    init                        : bool = True
    fid_load_initialization_path: Optional[str] = './_fids/fid_init.ckpt'
    kid_load_initialization_path: Optional[str] = './_kids/kid_init.ckpt'
    number_to_generate          : int = 2_000
    # Compute every `check_frequency` steps during validation; testing always computes.
    check_frequency             : int = 1
    compute_first               : bool = False
    stages                      : List[str] = return_factory(['valid', 'test'])

    # Dataloader params
    batch_size     : int = 100
    num_workers    : int = 10
    pin_memory     : bool = True
    prefetch_factor: int = 2

    # Data normalization
    value_range: ValueRange = SI('${dataset_params.data_params.value_range}')
    norm_func  : str = 'celeba3'  # supported: [ celeba3 ]

    # Used to build fake batch data when sampling from an unconditional model
    dimension_per_domain: List[int] = SI('${dataset_params.data_params.dimension_per_domain}')

    # Whether to compute during the running validation and testing steps
    compute_running: bool = True
    # Fixed-length triple of ints; the default (0, 0, 0) has three elements, so the
    # annotation is Tuple[int, int, int] (Tuple[int] would mean a 1-tuple).
    # NOTE(review): semantics of the three entries are not visible here — confirm.
    hack_mode      : Tuple[int, int, int] = (0, 0, 0)
    compute_on_ema : bool = True

    # Whether to call .compute() multiple times during the process, and how often
    # (presumably in steps — TODO confirm).
    running_compute     : bool = False
    running_compute_freq: int = 1_000
