from dataclasses import dataclass
from typing import Optional, List


@dataclass
class TrainingConfig:
    seed: int = 42
    wandb_project: str = "univa-denoiser"
    wandb_name: str = "default_config"
    output_dir: str = "./output"
    logging_dir: str = "./logs"
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-5
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    adam_weight_decay: float = 1e-2
    mixed_precision: str = "bf16"
    report_to: str = "wandb"
    gradient_checkpointing: bool = False
    num_train_epochs: int = 1
    max_train_steps: Optional[int] = None
    lr_scheduler: str = "constant"
    lr_warmup_steps: int = 0
    lr_num_cycles: int = 1
    lr_power: float = 1.0
    resume_from_checkpoint: Optional[str] = None
    weighting_scheme: Optional[str] = (
        "logit_normal"  # ["sigma_sqrt", "logit_normal", "mode", "cosmap", "null"]
    )
    logit_mean: float = 0.0
    logit_std: float = 1.0
    mode_scale: float = 1.29
    max_grad_norm: float = 1.0
    checkpointing_steps: int = 100
    checkpoints_total_limit: Optional[int] = 500
    drop_condition_rate: float = 0.0
    drop_t5_rate: float = 1.0
    validation_steps: int = 100
    num_validation_images: int = 1

    noise_reference_images: bool = False 
    mask_weight_type: Optional[str] = None     # ['log', 'exp']
    sigmas_as_weight: bool = False  # Used in Flux

    discrete_timestep: bool = True  # Used in Flux


    optimizer: str = 'adamw'  # ['adamw', 'prodigy']

    prodigy_use_bias_correction: bool = True
    prodigy_safeguard_warmup: bool = True
    prodigy_decouple: bool = True
    prodigy_beta3: Optional[float] = None 
    prodigy_d_coef: float = 1.0

    profile_out_dir: Optional[str] = None

    ema_deepspeed_config_file: Optional[str] = None
    ema_update_freq: int = 1
    ema_decay: float = 0.99

@dataclass
class DatasetConfig:
    dataset_type: str
    data_txt: str
    batch_size: int = 16
    num_workers: int = 4
    height: int = 512
    width: int = 512
    min_pixels: int = 448*448
    max_pixels: int = 448*448
    anyres: str = 'any_1ratio'
    ocr_enhancer: bool = False
    random_data: bool = False
    padding_side: str = 'right'
    validation_t2i_prompt: Optional[str] = None
    validation_it2i_prompt: Optional[str] = None
    validation_image_path: Optional[str] = None
    pin_memory: bool = True
    validation_iit2i_prompt: Optional[str] = None
    validation_iit2i_path: Optional[List[str]] = None

    validation_REFiit2i_prompt: Optional[str] = None
    validation_REFiit2i_path: Optional[List[str]] = None


    validation_cannyt2i_prompt: Optional[str] = None
    validation_cannyt2i_path: Optional[str] = None
    validation_poset2i_prompt: Optional[str] = None
    validation_poset2i_path: Optional[str] = None
    
    validation_it2pose_prompt: Optional[str] = None
    validation_it2pose_path: Optional[str] = None
    validation_it2canny_prompt: Optional[str] = None
    validation_it2canny_path: Optional[str] = None

    validation_NIKEit2i_prompt: Optional[str] = None
    validation_NIKEit2i_path: Optional[str] = None

    validation_TRANSFERit2i_prompt: Optional[str] = None
    validation_TRANSFERit2i_path: Optional[str] = None

    validation_EXTRACTit2i_prompt: Optional[str] = None
    validation_EXTRACTit2i_path: Optional[str] = None

    validation_TRYONit2i_prompt: Optional[str] = None
    validation_TRYONit2i_path: Optional[str] = None

    validation_REPLACEit2i_prompt: Optional[str] = None
    validation_REPLACEit2i_path: Optional[str] = None

    validation_DETit2i_prompt: Optional[str] = None
    validation_DETit2i_path: Optional[str] = None

    validation_SEGit2i_prompt: Optional[str] = None
    validation_SEGit2i_path: Optional[str] = None

@dataclass
class ModelConfig:
    pretrained_lvlm_name_or_path: str
    pretrained_denoiser_name_or_path: str
    

    guidance_scale: float = 1.0  # Used in Flux
    tune_mlp1_only: bool = False
    pretrained_mlp1_path: Optional[str] = None

    # Denoise projector (QwenVL -> FLUX dimension mapping)
    with_tune_mlp2: bool = False
    only_tune_mlp2: bool = False
    pretrained_mlp2_path: Optional[str] = None

    only_tune_image_branch: bool = True  # Used in SD3

    flux_train_layer_idx: Optional[list] = None

    joint_ref_feature: bool = False
    joint_ref_feature_as_condition: bool = False
    only_use_t5: bool = False

    vlm_residual_image_factor: float = 0.0

    vae_fp32: bool = True
    compile_flux: bool = False
    compile_qwen2p5vl: bool = False

    ema_pretrained_lvlm_name_or_path: Optional[str] = None
    
@dataclass
class UnivaTrainingDenoiseConfig:
    training_config: TrainingConfig
    dataset_config: DatasetConfig
    model_config: ModelConfig
