from dataclasses import dataclass
from typing import Optional, Any, List
import yaml

from ...util.hparams import HyperParams


@dataclass
class MOREMultimodalTrainingHparams(HyperParams):
    """Hyperparameter container for MORE multimodal training runs.

    Every field corresponds to a key in the YAML file read by
    :meth:`from_hparams`. Fields without a default must be present in that
    file, because the dataclass constructor receives the parsed mapping as
    keyword arguments. The ``# ...`` headers below only group related settings.
    """

    # M-ORE
    # NOTE(review): semantics of these knobs live in the MORE trainer/editor;
    # not visible from this file.
    rank: int
    lora_alpha: float
    lora_dropout: float
    top_k: int
    eta: float
    rls_lambda: float
    n_last_layers: int
    more_update_mode: str
    more_use_masked_z: bool
    more_restore_p: bool
    more_score_norm: str
    # Vision-specific variants; presumably fall back to eta / rls_lambda when
    # None — TODO confirm in the consumer.
    more_eta_vision: Optional[float]
    more_rls_lambda_vision: Optional[float]

    # Multimodal
    qformer_name_or_path: str
    state_dict_file: str

    # Image_dir
    coco_image: str
    rephrase_image: str

    # Model
    name: str
    model_name: str
    model_class: str
    tokenizer_class: str
    tokenizer_name: str
    inner_params: List[str]

    archive: Any

    # Method
    alg: str  # must be "MORE"; enforced by from_hparams
    lr: float
    seed: int
    debug: bool
    cedit: float
    iedit: float
    cloc: float
    cbase: float
    dropout: float
    train_base: bool
    no_grad_layers: Any
    one_sided: bool
    n_hidden: int
    hidden_dim: Any
    init: str
    norm: bool
    combine: bool
    x_only: bool
    delta_only: bool
    act: str
    shared: bool

    # Output
    results_dir: str

    # Train
    device: str
    batch_size: int
    model_save_pt: int
    silent: bool
    log_interval: int
    eval_log_interval: int
    final_eval: bool
    val_interval: int
    early_stop_patience: int
    early_stop_key: str
    eval_only: bool
    half: bool
    save: bool
    verbose: bool

    val_batch_size: int
    accumulate_bs: int
    val_steps: int
    opt: str
    grad_clip: float

    # Optional settings; these may be omitted from the YAML file.
    more_nonlinear: str = "none"
    more_layer_norm: bool = False
    more_layer_norm_eps: float = 1e-5
    qformer_checkpoint: Optional[str] = None
    exact_match: bool = False
    model_parallel: bool = False
    freeze_qformer: bool = True
    max_epochs: Optional[int] = None
    max_iters: Optional[int] = None
    pretrained_ckpt: Optional[str] = None
    opt_precision: Optional[str] = None
    use_chat_template: bool = False
    image_processor_name_or_path: Optional[str] = None

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str):
        """Build an instance from a YAML hyperparameter file.

        Args:
            hparams_name_or_path: Path to the YAML file. ``.yaml`` is appended
                when the path ends in neither ``.yaml`` nor ``.yml``.

        Returns:
            A ``MOREMultimodalTrainingHparams`` populated from the file.

        Raises:
            AssertionError: If the file does not contain a YAML mapping, or
                its ``alg`` entry is not ``"MORE"``. This is raised explicitly
                rather than through ``assert``, so the check also runs under
                ``python -O``. The exception type is kept for existing callers.
            TypeError: If required fields are missing or unknown keys are
                present (raised by the dataclass constructor).
        """
        # Check the suffix, not a substring, so "foo.yml" is not turned into
        # "foo.yml.yaml".
        if not hparams_name_or_path.endswith((".yaml", ".yml")):
            hparams_name_or_path = hparams_name_or_path + ".yaml"

        with open(hparams_name_or_path, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)

        # safe_load returns None for an empty file. A non-mapping document
        # cannot be splatted into the constructor, so reject it before use.
        if not isinstance(config, dict):
            raise AssertionError(
                f"MOREMultimodalTrainingHparams can not load from {hparams_name_or_path}, "
                f"expected a YAML mapping but got {type(config).__name__}"
            )

        config = super().construct_float_from_scientific_notation(config)

        # Use .get so a missing key produces the intended error message,
        # not a KeyError raised while building the message.
        alg = config.get("alg")
        if alg != "MORE":
            raise AssertionError(
                f"MOREMultimodalTrainingHparams can not load from {hparams_name_or_path}, "
                f"alg_name is {alg}"
            )
        return cls(**config)
