from dataclasses import dataclass, field
from typing import Any, List, Optional, Tuple

from omegaconf import MISSING, SI

from conf._util import return_factory
from conf.dataset import ValueRange


@dataclass
class LossParams:
    """Options selecting and configuring the training loss.

    Attributes:
        loss_name: Loss identifier, one of ``l1``, ``l2``, ``huber``, ``ce``.
        reduction: Reduction applied to per-element losses, one of
            ``mean``, ``sum``, ``none``.
        predict_noise: Boolean switch (default True); presumably selects
            whether the network is trained to predict the injected noise —
            confirm against the training step.
    """

    # [ l1 | l2 | huber | ce ]
    loss_name: str = 'l2'
    # [ mean | sum | none ]
    reduction: str = 'none'
    predict_noise: bool = True


@dataclass
class MetricsParams:
    """Metric computation and metric-logging settings.

    ``n_class`` and ``ignore_index`` are OmegaConf interpolations of the
    matching keys under ``dataset_params.data_params``.
    ``metrics_logging_stage`` and ``metrics_logging_freq`` are lists of the
    same default length; presumably entry *i* of the frequencies applies to
    entry *i* of the stages (TODO confirm in the metrics logger).
    """

    name: str = 'blender3'

    # Interpolated from the dataset configuration.
    n_class: Optional[int] = SI('${dataset_params.data_params.n_class}')
    ignore_index: Optional[int] = SI('${dataset_params.data_params.ignore_index}')

    # Stages on which metrics are logged, and their logging frequencies.
    metrics_logging_stage: List[str] = return_factory(['train', 'valid', 'test'])
    metrics_logging_freq: List[int] = return_factory([1, 1, 1])


@dataclass
class LearningRateWarmUp:
    """Learning-rate warm-up settings.

    ``use_scheduler`` toggles the warm-up. The other fields share their names
    with the arguments of ``torch.optim.lr_scheduler.LinearLR``
    (``start_factor``, ``end_factor``, ``total_iters``, ``last_epoch``,
    ``verbose``); they are presumably forwarded there — confirm in the
    optimizer setup.
    """

    use_scheduler: bool = True
    # Default multipliers: 0.2 at the start of warm-up, 1.0 at its end.
    start_factor: float = 1 / 5
    end_factor: float = 1.
    total_iters: int = 5
    last_epoch: int = -1
    verbose: bool = True


@dataclass
class ReduceOnPlateauParams:
    """Reduce-on-plateau learning-rate scheduler settings.

    The fields from ``mode`` to ``verbose`` share their names with the
    arguments of ``torch.optim.lr_scheduler.ReduceLROnPlateau``; they are
    presumably forwarded there (confirm in the optimizer setup).
    ``interval`` and ``monitor`` are kept separate from those and look like
    scheduler-attachment options (TODO confirm where they are consumed).
    """

    # Enables the reduce-on-plateau scheduler.
    use_scheduler: bool = True
    # [ min | max ]
    mode: str = 'min'
    factor: float = 0.5
    patience: int = 5
    threshold: float = 0
    # [ rel | abs ]
    threshold_mode: str = 'rel'
    cooldown: int = 0
    min_lr: float = 1e-07
    eps: float = 1e-08
    verbose: bool = True

    # --- Scheduling / monitoring options ---
    # [ epoch | step ]
    interval: str = 'epoch'
    monitor: str = 'valid/step/loss'


@dataclass
class CosinesSchedulerParams:
    """Cosine learning-rate schedule settings (disabled by default).

    ``T_max`` and ``eta_min`` share their names with the arguments of
    ``torch.optim.lr_scheduler.CosineAnnealingLR`` (presumably forwarded
    there — confirm). ``T_max`` is an interpolation of
    ``trainer_params.max_epochs``.
    """

    use_scheduler: bool = False
    # Interpolated from the trainer configuration.
    T_max: int = SI('${trainer_params.max_epochs}')
    eta_min: float = 0
    verbose: bool = True


@dataclass
class EMAParams:
    """Exponential-moving-average (EMA) of the model weights.

    Per the original section marker, ``perform_double_validation`` is not an
    EMA constructor parameter; the fields before it are.
    """

    use: bool = False
    decay: float = 0.9999
    # True  -> the EMA callback does NOT swap in the EMA params during validation steps.
    # False -> the EMA callback swaps in the EMA params during validation steps.
    validate_original_weights: bool = True
    every_n_steps: int = 1
    cpu_offload: bool = False

    # --- Not EMA constructor parameters ---
    # Whether to run one validation step over the original weights and one
    # over the EMA weights.
    perform_double_validation: bool = True


@dataclass
class OptimizersParams:
    """Optimizer hyper-parameters plus the schedulers and EMA attached to it.

    Nested configs use ``field(default_factory=...)`` so every
    ``OptimizersParams`` owns its own sub-configs. A plain instance default
    (``= ReduceOnPlateauParams()``) is rejected by ``dataclasses`` on
    Python >= 3.11, because non-frozen dataclass instances are unhashable.
    On older versions that single instance would be shared by all
    ``OptimizersParams`` objects.
    """

    learning_rate: float = 2e-5
    optimizer: str = 'adam'
    # Presumably consumed by Adam-family optimizers (betas) and SGD
    # (momentum) respectively — confirm in the optimizer factory.
    betas: Tuple[float, float] = (0.9, 0.999)
    weight_decay: float = 0
    momentum: float = 0.9

    reduce_on_plateau: ReduceOnPlateauParams = field(default_factory=ReduceOnPlateauParams)
    cosines_scheduler: CosinesSchedulerParams = field(default_factory=CosinesSchedulerParams)
    learning_rate_warmup: LearningRateWarmUp = field(default_factory=LearningRateWarmUp)
    ema: EMAParams = field(default_factory=EMAParams)

    # Interpolated from the trainer configuration.
    max_epochs: int = SI('${trainer_params.max_epochs}')
    max_steps: int = SI('${trainer_params.max_steps}')


@dataclass
class BackboneParams:
    """Backbone network architecture and initialisation settings.

    ``out_dim`` is an interpolation of ``model_params.backbone.channels`` and
    ``dimension_per_domain`` is an interpolation of
    ``dataset_params.data_params.dimension_per_domain``.
    """

    controlnet_ckpt: Optional[str] = None

    # Known names: unet_cold, unet_cold_multi_time, celeba_model1, celeba_model3
    name: str = 'none'
    dim: int = 64
    time_dim: int = 256
    init_dim: int = 64
    out_dim: int = SI('${model_params.backbone.channels}')
    dim_mults: List[int] = return_factory([1, 2, 4, 8])
    channels: int = 9
    resnet_block_groups: int = 4
    use_convnext: bool = True
    convnext_mult: int = 2

    with_time_emb: bool = True
    # Should be set to False if classes != in_channels.
    residual: bool = True

    # --- Added for multi-time models ---
    dimension_per_domain: List[int] = SI('${dataset_params.data_params.dimension_per_domain}')

    # Encoder
    encoder_split: bool = False
    encoder_attention_per_block: Optional[List[bool]] = None
    encoder_time_embedding_per_block: Optional[List[bool]] = None

    # Only used when encoder_split is True.
    # pz_strat: [ cat | sum | mean | m0 | m1 | m2 | m01 | m02 ]
    pz_strat: str = 'cat'
    # z_mid_strat: [ identity | reduce_to_single ]
    z_mid_strat: str = 'reduce_to_single'

    # Middle (bottleneck)
    middle_linear_attention: bool = False
    middle_attention: bool = True
    middle_time_embedding: bool = True
    middle_nb_block_following: int = 0

    # Decoder
    decoder_split: bool = False
    decoder_attention_per_block: Optional[List[bool]] = None
    decoder_time_embedding_per_block: Optional[List[bool]] = None

    use_double_skip: bool = True

    # --- Mode handling ---
    umm_csgm_vanilla_middle: bool = True
    # Whether to perform additive attention on the mode.
    with_mode_emb: bool = True
    # Whether to perform concatenation on the mode.
    use_pz_m: bool = True

    # --- Celebahq3 parameters (TODO: move to a dedicated config) ---
    # Three-entry lists; presumably one entry per encoder/decoder branch —
    # confirm in the model builder.
    freeze_encoder: List[bool] = return_factory([True, False, False])
    freeze_decoder: List[bool] = return_factory([True, False, False])
    freeze_time_embedding: bool = False
    freeze_bottleneck: bool = False

    # Weight sources, each one of: non | path_<path> | link_<link>
    pretrained_encoder: List[str] = return_factory(
        ['link_google/ddpm-ema-celebahq-256', 'non', 'non']
    )
    pretrained_decoder: List[str] = return_factory(
        ['link_google/ddpm-ema-celebahq-256', 'non', 'non']
    )
    pretrained_bottleneck: str = 'non'
    pretrained_time_embedding: str = 'link_google/ddpm-ema-celebahq-256'


@dataclass
class DiffusionParams_DDPM:
    """DDPM diffusion settings.

    ``n_steps_generation`` is an interpolation of
    ``model_params.diffusion.n_steps_training``. When this config is the one
    selected as ``model_params.diffusion``, it therefore equals this config's
    own training step count.
    """

    diffusion_name: str = 'ddpm'

    # Number of diffusion steps for training and for generation.
    n_steps_training: int = 1_000
    n_steps_generation: int = SI('${model_params.diffusion.n_steps_training}')
    # Noise-schedule bounds (presumably for a linear beta schedule — confirm).
    beta_min: float = 0.0001
    beta_max: float = 0.02

    # Output clamping; the range is given by clamp_min_max.
    clamp_generation: bool = True
    clamp_end: bool = True
    clamp_min_max: Tuple[int, int] = (-1, 1)
    # Jump options; semantics are defined by the sampler (TODO document).
    jump: bool = False
    jump_len: int = 10
    jump_n_sample: int = 10


@dataclass
class DiffusionParams_Cosine:
    """Cosine-schedule diffusion settings.

    ``n_steps_generation`` is an interpolation of
    ``model_params.diffusion.n_steps_training``. When this config is the one
    selected as ``model_params.diffusion``, it therefore equals this config's
    own training step count.
    """

    diffusion_name: str = 'cosine'

    n_steps_training: int = 50
    n_steps_generation: int = SI('${model_params.diffusion.n_steps_training}')
    # Presumably the offset ``s`` of the cosine noise schedule — confirm.
    s: float = 0.008
    # Bounds applied to beta (presumably clipping — confirm in the schedule code).
    beta_min: float = 0.
    beta_max: float = 0.999

    # Output clamping; the range is given by clamp_min_max.
    clamp_generation: bool = True
    clamp_end: bool = True
    clamp_min_max: Tuple[int, int] = (-1, 1)
    # Jump options; semantics are defined by the sampler (TODO document).
    jump: bool = False
    jump_len: int = 10
    jump_n_sample: int = 10


@dataclass
class DiffusionParams_DDIM:
    """DDIM diffusion settings.

    Training uses ``n_steps_training`` steps. Generation uses
    ``n_steps_generation`` time steps τ extracted from [1, 2, …, T] according
    to ``ddim_discretize``.
    """

    diffusion_name: str = 'ddim'

    n_steps_training: int = 1_000
    n_steps_generation: int = 100
    # How to extract τ from [1, 2, …, T]: [ uniform | quad ]
    ddim_discretize: str = 'uniform'
    # η used to compute σ_{τ_i}; η = 0 makes the sampling process deterministic.
    ddim_eta: float = 0.

    time_step_factor: float = 0.8

    beta_min: float = 0.0001
    beta_max: float = 0.02

    # --- Sampling noise ---
    repeat_noise: bool = False
    # Noise temperature: the random noise is multiplied by this value.
    temperature: float = 1.
    skip_steps: int = 0

    # Output clamping; the range is given by clamp_min_max.
    clamp_generation: bool = True
    clamp_end: bool = True
    clamp_min_max: Tuple[int, int] = (-1, 1)
    # Jump options; semantics are defined by the sampler (TODO document).
    jump: bool = False
    jump_len: int = 10
    jump_n_sample: int = 10


@dataclass
class SubLoggingParams:
    """Schedule for one kind of periodic logging.

    ``stages``, ``frequencies`` and ``log_first`` have the same default
    length. Presumably entry *i* of each refers to ``stages[i]`` — confirm in
    the logging callback.
    """

    # [ epoch | batch ]
    logging_mode: Optional[str] = 'epoch'
    # Stages in which to log, and the matching frequencies.
    stages: List[str] = return_factory(['train', 'valid', 'test'])
    frequencies: List[int] = return_factory([1, 1, 1])
    # Whether to log the first batch/epoch.
    log_first: tuple[bool, bool, bool] = (True, True, True)
    # Presumably an upper bound on the number of logged items — confirm.
    max_quantity: int = 5


@dataclass
class SubLoggingParamsDiversity(SubLoggingParams):
    """``SubLoggingParams`` extended with options for diversity logging."""

    # When not None, a batch size. TODO(review): confirm what this batch holds.
    variation_quantity: Optional[int] = None
    # False: only generate from one random image of the batch.
    generate_all_in_batch: bool = False


@dataclass
class LoggingParams:
    """Configuration of step and generation logging.

    Nested ``SubLoggingParams`` use ``field(default_factory=...)`` so every
    ``LoggingParams`` owns its own copies. Instance defaults are rejected by
    ``dataclasses`` on Python >= 3.11, because non-frozen dataclass instances
    are unhashable. On older versions one instance was shared by every
    ``LoggingParams``.
    """

    name: str = 'blender3'

    # Logging step params
    log_steps: SubLoggingParams = field(default_factory=SubLoggingParams)

    # Logging generate params
    log_generate: SubLoggingParams = field(default_factory=SubLoggingParams)

    # Logging generate-diversity params; the empty `stages` list means no stage
    # is configured for it by default.
    log_generate_diversity: SubLoggingParamsDiversity = field(
        default_factory=lambda: SubLoggingParamsDiversity(stages=[])
    )

    # --- Generation logging parameters ---
    # Number of time steps (from the generation denoising process) logged in process.
    time_step_in_process: int = 10
    # Spacing of the logged time steps:
    #   uniform   : uniform time steps
    #   quad_start: quadratic time steps, more early in generation
    #   quad_end  : quadratic time steps, more at the end of generation
    strategy: str = 'quad_end'
    quad_factor: float = 0.8
    value_range: ValueRange = SI('${dataset_params.data_params.value_range}')

    # If the model uses supervision information and all of it is available,
    # flip FullSupervision to NoSupervision.
    if_all_here_generate_none: bool = False
    # If not None, set the modes to this during generation; if a list, randomly
    # populate the batch.
    hack_mode: Optional[list[list[int]]] = None

    # If True, leave once the needed number of logged images is reached.
    early_leave: bool = False

    save_image_to_disk: bool = False
    # If True, save the image as a .pt tensor to disk.
    log_pt: bool = False


@dataclass
class ApproachSpeParams:
    """Approach-specific options for training and generation."""

    # Presumably samples a separate time step per domain — confirm.
    one_t_per_dom: bool = True
    # If True, fix a specific condition (t=0 or min) during the training step;
    # the other domains are random.
    train_condition: bool = False
    # If True, also learn on the condition domain.
    train_condition_learn_condition: bool = True

    # --- Training params ---
    # How empty entries are handled:
    #   noise   : replace with random noise
    #   pred    : prediction of the model at the previous time step, given the
    #             inputs + noise of the current time step
    #   minusone: replace with -1
    empty_handling: str = 'noise'
    # Replace missing t with T.
    replace_missing_t_with_T: bool = True

    proportion_t0: Optional[float] = None

    # Noise strategy for training / generation:
    #   vanilla_noise : different noise for each domain
    #   constant_noise: same noise for all domains
    noise_train: str = 'vanilla_noise'
    noise_gen: str = 'vanilla_noise'

    # --- Generation params ---
    # How the condition is noised during generation:
    #   noisy              : apply the same noise as on the generation
    #   clean              : do not apply noise
    #   noisy_constant     : apply constant noise during the whole diffusion process
    #   noisy_constant_fade: apply constant noise; once the generation catches up, it syncs
    #   noisy_skip         : apply the regular noise level, but skip early steps
    condition_mode: str = 'noisy_constant'
    # Progression of the constant noise to use.
    noisy_condition_progression: int = 80
    # Shift of the generation time step: F(t_cond) = F(t_gen - shift_from_gen)
    # (a shift to the right).
    shift_from_gen: int = 0

    # [ None | norm_minmax | norm_max ]: put data in the [0; 1] range after generation.
    post_norm: Optional[str] = None


@dataclass
class ModelParams:
    """Top-level model configuration grouping all model sub-configs.

    ``diffusion`` is ``MISSING`` and must be supplied. Presumably it is one of
    the ``DiffusionParams_*`` configs, whose interpolations refer back to
    ``model_params.diffusion``.

    Nested configs use ``field(default_factory=...)``. Instance defaults are
    rejected by ``dataclasses`` on Python >= 3.11, because non-frozen
    dataclass instances are unhashable. ``metrics`` previously defaulted to
    the ``MetricsParams`` class object instead of an instance.
    """

    # Supported: diffusion_all_to_all
    name: str = 'diffusion_all_to_all'
    optimizer: OptimizersParams = field(default_factory=OptimizersParams)
    backbone: BackboneParams = field(default_factory=BackboneParams)
    logging: LoggingParams = field(default_factory=LoggingParams)
    loss: LossParams = field(default_factory=LossParams)
    diffusion: Any = MISSING
    metrics: MetricsParams = field(default_factory=MetricsParams)
    approach_spe: ApproachSpeParams = field(default_factory=ApproachSpeParams)
