from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional

from omegaconf import DictConfig, OmegaConf

from paths import get_image_path, get_question_path, get_model_path, get_save_path

@dataclass
class Config:
    """Resolved run configuration assembled from a raw OmegaConf config.

    Most fields are copied verbatim from the input config. ``exp_name``,
    ``image_path``, ``question_path``, ``model_path``, ``save_path``,
    ``exp_hp`` and ``use_cd`` are derived in :meth:`from_omegaconf`.
    """

    exp_name: str
    seed: int
    model: str
    model_base: str
    model_path: str
    dataset: str
    dataset_type: str
    image_path: str
    question_path: str
    save_path: str
    exp_type: str
    exp_hp: DictConfig  # re-wrapped with OmegaConf.create from the raw 'exp_hp' value
    use_cd: bool  # False exactly when exp_type == 'regular', True otherwise

    # Passed through unchanged from the input config.
    conv_mode: str
    num_chunks: int
    chunk_idx: int
    temperature: float
    top_p: float
    top_k: int

    # Passed through unchanged; presumably contrastive-decoding knobs — TODO confirm.
    cd_alpha: float
    cd_beta: float
    noise_step: float

    do_sample: bool

    @classmethod
    def from_omegaconf(cls, args: DictConfig) -> "Config":
        """Build a ``Config`` from a raw OmegaConf config.

        The config is first converted to plain Python containers with
        ``OmegaConf.to_object``. It is then indexed by key, so every field
        listed on the class must be present, as well as ``custom_name`` and
        ``question_path``. Path fields are resolved through the project-local
        ``paths`` helpers.

        Args:
            args: The raw experiment configuration.

        Returns:
            A fully populated ``Config``.

        Raises:
            KeyError: if a required key is missing. This assumes
                ``to_object`` yields a plain dict, i.e. the config is not a
                structured-config schema.
        """
        raw = OmegaConf.to_object(args)

        name = get_exp_name(
            raw['model'],
            raw['dataset'],
            raw['dataset_type'],
            raw['exp_type'],
            raw['custom_name'],
        )

        # Derived locations, resolved by the project-local path helpers.
        images = get_image_path(raw['dataset'])
        questions = get_question_path(raw['question_path'], raw['dataset_type'])
        resolved_model_path = get_model_path(raw['model'])
        output_dir = get_save_path(raw['dataset'], name)

        return cls(
            exp_name=name,
            seed=raw['seed'],
            model=raw['model'],
            model_base=raw['model_base'],
            model_path=resolved_model_path,
            dataset=raw['dataset'],
            dataset_type=raw['dataset_type'],
            image_path=images,
            question_path=questions,
            save_path=output_dir,
            exp_type=raw['exp_type'],
            exp_hp=OmegaConf.create(raw['exp_hp']),
            # Only the 'regular' experiment type disables use_cd.
            use_cd=raw['exp_type'] != 'regular',
            conv_mode=raw['conv_mode'],
            num_chunks=raw['num_chunks'],
            chunk_idx=raw['chunk_idx'],
            temperature=raw['temperature'],
            top_p=raw['top_p'],
            top_k=raw['top_k'],
            cd_alpha=raw['cd_alpha'],
            cd_beta=raw['cd_beta'],
            noise_step=raw['noise_step'],
            do_sample=raw['do_sample'],
        )
    
def get_exp_name(model: str, dataset: str, dataset_type: Optional[str], exp_type: str, custom_name: str) -> str:
    """Compose the underscore-separated experiment name.

    The result is ``{model}_{dataset}_{dataset_type}_{exp_type}_{custom_name}``.
    When ``dataset_type`` is ``None``, that segment is omitted.

    Args:
        model: Model identifier.
        dataset: Dataset identifier.
        dataset_type: Optional dataset variant; ``None`` drops the segment.
        exp_type: Experiment type.
        custom_name: Free-form suffix appended last.

    Returns:
        The experiment name string.
    """
    # Identity check: ``None`` is a singleton, and ``==`` could be hijacked by __eq__.
    if dataset_type is None:
        return f"{model}_{dataset}_{exp_type}_{custom_name}"
    return f"{model}_{dataset}_{dataset_type}_{exp_type}_{custom_name}"