from dataclasses import dataclass
from typing import List
from ...util.hparams import HyperParams
import yaml

@dataclass
class MODELHyperParams(HyperParams):
    """Settings read from the ``model`` section of a CAKE YAML config.

    Field order defines the positional constructor signature of this
    dataclass, so it must not be rearranged.
    """

    name: str
    class_name: str
    tokenizer_class: str
    tokenizer_name: str
    fan_in_fan_out: bool
    # typing.List (already imported by this module) instead of the builtin
    # list[str]: class-level annotations are evaluated at definition time and
    # list[str] raises TypeError on Python < 3.9.
    target_modules: List[str]
    pt: str  # set this to 'hallucination' inside your checkpoint directory
@dataclass
class LoRAHyperParams(HyperParams):
    """Settings read from the ``lora`` section of a CAKE YAML config.

    Field order defines the positional constructor signature of this
    dataclass, so it must not be rearranged.
    """

    # NOTE(review): the semantics of the flags below (supervised, cos, square,
    # bound_embeds, soft_weighting, ...) are not visible from this file —
    # confirm against the code that consumes them.
    cls_name: str
    cls_class: str
    supervised: bool
    cos: bool
    freeze: str
    square: bool
    bound_embeds: bool
    use_all_negatives: bool
    freeze_lora: bool
    dist_heads: int
    cross_attend: bool
    soft_weighting: bool
    checkpoint_grad: bool
    # presumably the usual LoRA adapter knobs (rank, scaling alpha, dropout) —
    # confirm against the consumer.
    lora_r: int
    lora_alpha: int
    lora_dropout: float
 
@dataclass
class CAKEHyperParams(HyperParams):
    """Top-level CAKE hyperparameters loaded from a YAML file.

    Besides the flat run settings, the config carries two nested sections:
    ``model`` (built into ``MODELHyperParams``) and ``lora`` (built into
    ``LoRAHyperParams``).
    """

    model_name: str
    alg_name: str
    model_parallel: bool
    device: int
    max_length: int
    task: str
    lora_task_type: str
    check_dir: str
    batch_size: int
    model: MODELHyperParams
    lora: LoRAHyperParams

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str):
        """Load hyperparameters from a YAML config file.

        Args:
            hparams_name_or_path: Path to the config. ``.yaml`` is appended
                unless the path already contains ``.yaml`` or ends in ``.yml``.

        Returns:
            A new instance whose ``model`` and ``lora`` fields are the nested
            dataclasses built from the matching YAML sections.

        Raises:
            FileNotFoundError: If the resolved path does not exist.
            ValueError: If the file is empty or its top level is not a mapping.
            KeyError: If the ``model`` or ``lora`` section is missing.
            TypeError: If a section is not a mapping, or any level has
                unknown or missing fields.
        """
        # Keep the historical substring test for '.yaml', but also accept the
        # common '.yml' extension instead of turning it into 'x.yml.yaml'.
        has_yaml_ext = (
            '.yaml' in hparams_name_or_path
            or hparams_name_or_path.endswith('.yml')
        )
        if not has_yaml_ext:
            hparams_name_or_path = hparams_name_or_path + '.yaml'

        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(hparams_name_or_path, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)

        # safe_load yields None for an empty file (and a scalar/list for a
        # non-mapping document); fail here with a clear message rather than
        # with a confusing error further down.
        if not isinstance(config, dict):
            raise ValueError(
                f"Expected a YAML mapping in {hparams_name_or_path!r}, "
                f"got {type(config).__name__}"
            )
        config = super().construct_float_from_scientific_notation(config)
        # NOTE(review): the helper is only applied to the top-level mapping;
        # whether it should also run on the nested 'model'/'lora' sections
        # (e.g. lora_dropout written as 1e-1) depends on HyperParams — confirm.

        for section in ('model', 'lora'):
            if section not in config:
                raise KeyError(
                    f"Missing required '{section}' section in "
                    f"{hparams_name_or_path!r}"
                )
            if not isinstance(config[section], dict):
                raise TypeError(
                    f"Section '{section}' in {hparams_name_or_path!r} must be "
                    f"a mapping, got {type(config[section]).__name__}"
                )

        config['model'] = MODELHyperParams(**config['model'])
        config['lora'] = LoRAHyperParams(**config['lora'])
        return cls(**config)
