from dataclasses import dataclass
from typing import Optional
from ...util.hparams import HyperParams
import yaml


@dataclass
class MOREMultimodalHyperParams(HyperParams):
    """Hyperparameters for the multimodal M-ORE editing algorithm.

    Instances are normally built with :meth:`from_hparams`, which reads a
    YAML file whose top-level keys are exactly the field names below.
    Fields without a default are required; the dataclass ``__init__`` raises
    ``TypeError`` if one is missing or an unknown key is supplied.
    """

    # M-ORE
    # NOTE(review): rank / lora_alpha / lora_dropout look like LoRA adapter
    # settings, and rls_lambda like a recursive-least-squares regulariser —
    # confirm against the M-ORE implementation.
    rank: int
    lora_alpha: float
    lora_dropout: float
    top_k: int
    eta: float
    rls_lambda: float
    n_last_layers: int
    more_score_norm: str
    # Vision-branch overrides; required keys, but may be null in the YAML.
    # Presumably None means "fall back to eta / rls_lambda" — TODO confirm.
    more_eta_vision: Optional[float]
    more_rls_lambda_vision: Optional[float]

    # Model
    name: str
    model_name: str
    model_class: str
    tokenizer_class: str
    tokenizer_name: str

    # Multimodal
    qformer_name_or_path: str
    qformer_checkpoint: str
    state_dict_file: str
    pretrained_ckpt: str

    # Image
    coco_image: str
    rephrase_image: str

    device: int  # presumably a CUDA device index — TODO confirm
    alg_name: str  # must be "M-ORE" when loaded via from_hparams

    # Defaults
    more_nonlinear: str = "none"
    more_layer_norm: bool = False
    more_layer_norm_eps: float = 1e-5
    exact_match: bool = False
    batch_size: int = 1
    max_length: int = 30
    model_parallel: bool = False
    freeze_qformer: bool = True

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str) -> "MOREMultimodalHyperParams":
        """Load hyperparameters from a YAML file and build an instance.

        Args:
            hparams_name_or_path: Path to the YAML config. If it does not
                already end in ``.yaml`` or ``.yml``, ``.yaml`` is appended.

        Returns:
            An instance of ``cls`` whose fields are the YAML mapping's keys.

        Raises:
            OSError: If the file cannot be opened.
            AssertionError: If the YAML document is not a mapping (e.g. the
                file is empty) or its ``alg_name`` is not ``"M-ORE"``.
                AssertionError is kept for compatibility with the previous
                ``assert``-based check, but is raised explicitly so the
                validation is not stripped under ``python -O``.
            TypeError: If the YAML omits a required field or contains a key
                that is not a field (raised by the dataclass ``__init__``).
        """
        # Use a suffix test rather than a substring test so that e.g.
        # "cfg.yml" is not turned into "cfg.yml.yaml".
        if not hparams_name_or_path.endswith((".yaml", ".yml")):
            hparams_name_or_path = hparams_name_or_path + ".yaml"

        with open(hparams_name_or_path, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)

        # safe_load returns None for an empty document; reject anything that
        # is not a mapping before indexing into it or passing it on.
        if not isinstance(config, dict):
            raise AssertionError(
                f"MOREMultimodalHyperParams can not load from {hparams_name_or_path}, "
                f"expected a YAML mapping but got {type(config).__name__}"
            )

        # Defined on HyperParams; presumably converts scientific-notation
        # strings (e.g. "1e-4") to floats — behavior lives in the base class.
        config = super().construct_float_from_scientific_notation(config)

        alg_name = config.get("alg_name")
        if alg_name != "M-ORE":
            raise AssertionError(
                f"MOREMultimodalHyperParams can not load from {hparams_name_or_path}, "
                f"alg_name is {alg_name}"
            )
        return cls(**config)
