from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from omegaconf import MISSING, SI

from conf._util import return_factory


@dataclass
class SupervisionDatasetConfig:
    # Selects which domain combinations of a multi-domain dataset are returned,
    # and in which proportions. Each returned sample is still the whole dataset
    # row; every model must interpret the supervision mode it is using.

    # Interpolated from the dataset-specific parameters at config-resolve time.
    number_of_domains: int = SI('${dataset_params.data_params.n_dom}')

    # region supervision
    # Which part of the dataset is returned:
    #   None  -> no filtering is performed.
    #   list  -> each str is built from domain indices in [0, number_of_domains),
    #            always written in ascending order; domains that are not listed
    #            are ignored. Example with three domains x, y, z, keeping the
    #            samples where x and z are present: ['0,2'].
    # NOTE(review): `proportions` below spells its keys as '0v2' (with 'v');
    # confirm which separator the loader actually expects here.
    supervision_mode: Optional[List[str]] = None

    # How the values in `proportions` are interpreted:
    #   'frac' -> floats, fraction of the samples (0.5 means 50%)
    #   'abso' -> ints, absolute number of samples for each part
    proportions_mode: str = 'frac'

    # Flat list alternating <domain key>, <proportion>; keys use the same domain
    # representation as `supervision_mode`. None means every sample is present.
    # In 'frac' mode the proportions should sum to one or less.
    proportions: Optional[List[str]] = return_factory([
        '0v1v2', '0.4',
        '0v1', '0.2',
        '0v2', '0.2',
        '1v2', '0.2',
    ])
    # When not None, the proportions are loaded from this file instead;
    # incompatible with a non-None `proportions`.
    proportions_file: Optional[str] = None

    return_supervision: bool = True
    # Whether the token list is shuffled.
    random_supervision: bool = True
    # Whether samples are drawn at random from the source dataset.
    random_from_dataset: bool = True
    # When not None, the random list is pickled to this file; in that case
    # random_supervision and random_from_dataset should both be False.
    random_file: Optional[str] = None
    # endregion


class ValueRange(Enum):
    """Value range of the tensors produced by a dataset.

    NOTE(review): presumably '01' means [0, 1] and '11' means [-1, 1], while the
    'unbound' variants are not clamped to that interval -- confirm against the
    dataset loaders.
    """

    Zero        = '01'
    # Must differ from Zero's value: Enum turns a member whose value duplicates
    # an earlier member into an alias, which made ZeroUnbound *be* Zero.
    # Mirrors the One / OneUnbound pair below.
    ZeroUnbound = '01unbound'
    One         = '11'
    OneUnbound  = '11unbound'


# region dataset spec params
@dataclass
class BlenderParams:
    # Dataset-specific parameters for the 'blender' dataset.

    n_dom: int = 3
    name: str = 'blender'
    root: str = 'path/to/blender_inversion_rgb_pickle.lmdb'
    return_params: bool = False
    # Collapse all domains into a single one.
    is_one: bool = False
    # When is_one is False, return only these domain indices (None = all).
    return_domain: Optional[List[int]] = None
    height: int = 64
    width: int = 64
    channels: int = 3
    n_class: int = 0
    ignore_index: Optional[int] = None

    value_range: ValueRange = ValueRange.One
    # One entry per domain (presumably its channel count -- confirm).
    dimension_per_domain: List[int] = return_factory([3, 3, 3])

    # Set when exporting the pyarrow data to pickle.
    get_item_for_translation: bool = False
    # Whether the pickle loading function is used.
    use_pickle: bool = True

    return_indice: bool = False


@dataclass
class CelebAParams:
    # Dataset-specific parameters for the CelebAMask-HQ dataset.

    n_dom: int = 1
    # When n_dom is 1, only this domain is returned.
    targeted_domain: Optional[int] = None

    name: str = 'celeba'
    root: str = r'path/to/CelebAMask-HQ'
    height0: Optional[int] = None
    width0: Optional[int] = None
    height: int = 256
    width: int = 256

    only_image: bool = True

    # Whether the background is returned among the domains.
    return_background: bool = True

    value_range: ValueRange = ValueRange.One

    # Fuse the segmentation labels, going from 19 to 10 classes.
    segmentation_fusion: bool = True
    number_class_before_fusion: int = 19
    number_class_after_fusion: int = 10

    dimension_per_domain: List[int] = return_factory([3, 1, 10])
    ignore_index: Optional[int] = None
    n_class: int = 10

    # Only used in the one-domain case, for generation.
    channels: int = 3

    # Which sketch source is used:
    #   1 -> sketches from pix2style2pix
    #   2 -> sketches from "Semi-Supervised Learning for Face Sketch Synthesis
    #        in the Wild" (ACCV 2018)
    sketch_version: int = 2

    random_flip: bool = True

    return_indice: bool = True


@dataclass
class BRATS2020Params:
    # Dataset-specific parameters for the BraTS 2020 dataset.

    n_dom: int = 5
    n_dom_scan: int = 4
    n_dom_seg: int = 2

    name: str = 'brats2020'
    root: str = r'path/to/brats2020_bounded'
    return_params: bool = False
    height: int = 256
    width: int = 256
    n_class: int = 4
    ignore_index: Optional[int] = None
    # How segmentation labels are turned into channels:
    #   1 -> drop the background, OR the remaining labels together
    #   2 -> keep the background, OR the remaining labels together
    #   3 -> drop the background, keep the other labels separate
    #   4 -> keep every label
    segmentation_mode: int = 2

    scan_names: List[str] = return_factory(['t1', 't1ce', 't2', 'flair'])
    seg_names: List[str] = return_factory(['GD-enhancing', 'peritumoral_edema', 'necrotic_and_other'])

    value_range: ValueRange = ValueRange.One
    # Accepted values: None or 'pf01'.
    preprocess_func: Optional[str] = None

    dimension_per_domain: List[int] = return_factory([
        1, 1, 1, 1,  # scans: t1, t1ce, t2, flair
        # Segmentation map. Raw labels are background, GD-enhancing, peritumoral
        # edema, necrotic and non-enhancing tumor core; 2 channels presumably
        # match segmentation_mode=2 (background + OR of the rest) -- confirm.
        2,
    ])
    # Treat each segmentation as its own domain.
    split_segmentation: bool = False

    # When True, the test set is returned.
    test_flag: bool = False

    return_indice: bool = True

# endregion


def _pairwise_supervision_config() -> 'SupervisionDatasetConfig':
    # Default valid/test supervision: only pairs of domains, split about evenly.
    # A factory (not a shared instance) so every DatasetParams gets its own copy.
    return SupervisionDatasetConfig(proportions=[
        '0v1', '0.33',
        '0v2', '0.33',
        '1v2', '0.34',
    ])


@dataclass
class DatasetParams:
    # Dataset-agnostic loading parameters: batching, supervision and splits.

    # Dataset-specific parameters; must be provided (MISSING by default). It has
    # to expose `n_dom`, which SupervisionDatasetConfig interpolates.
    data_params: Any = MISSING

    drop_last_train: bool = True
    drop_last_valid: bool = False
    drop_last_test : bool = False

    batch_size: int = 64
    batch_size_val : int = SI('${dataset_params.batch_size}')
    batch_size_test: int = SI('${dataset_params.batch_size}')
    # If drop_last_train is True and len(train) < batch_size, use len(train) as
    # the batch size.
    use_min_for_batch_size: bool = True
    workers: int = 0
    pin_memory: bool = True

    # Nested configs need default_factory: a dataclass instance used as a plain
    # default is unhashable, which Python >= 3.11 rejects with ValueError, and on
    # older versions it would be one object shared by every DatasetParams.
    supervision_params_train: SupervisionDatasetConfig = field(default_factory=SupervisionDatasetConfig)
    supervision_params_valid: SupervisionDatasetConfig = field(default_factory=_pairwise_supervision_config)
    supervision_params_test : SupervisionDatasetConfig = field(default_factory=_pairwise_supervision_config)

    train_prop: Union[float, int, str] = 0.80
    valid_prop: Union[float, int, str] = 0.10
    test_prop : Union[float, int, str] = 0.10
    # When not None, this file orders the dataset indices before splitting.
    file_path: Optional[str] = None
    # How train_prop / valid_prop / test_prop are interpreted:
    #   'frac' -> floats, fractions (0.5 means 50%)
    #   'perc' -> ints or floats, percentages (50 means 50%)
    #   'abso' -> ints, absolute number of samples for each part
    proportion_mode: str = 'frac'

    # Optional per-split limits (presumably a cap on the number of samples --
    # confirm against the loader).
    limit_train: Optional[int] = None
    limit_valid: Optional[int] = None
    limit_test : Optional[int] = None
