from dataclasses import dataclass, field
import pyrallis
import os
from datasets import load_dataset
import warnings
import torch
from prompts import *
from language_model import TaskModel
from functools import partial
import srsly
from tqdm import tqdm


@dataclass
class DataConfig:
    """Command-line configuration for generating model responses on an evaluation dataset."""

    # model arguments
    model_path: str = "" # HF path of the model to load; passed to TaskModel as `model_name`
    model_name: str = "" # short model label; only used in the name of the pushed dataset


    device: str = 'cuda' # device passed to TaskModel
    use_auto: bool = True # allocate model parameters across available devices (GPU, CPU); passed to TaskModel

    # dataset arguments
    ds_type: str = "summary" # dataset type; main() supports 'summary' and 'imdb', anything else raises NotImplementedError
    split: str = 'train' # evaluation dataset split passed to load_dataset
    num_examples: int = -1 # number of examples kept after shuffling; -1 keeps the whole split
    ds_seed: int = 42 # seed used to shuffle the dataset, for reproducibility
    hf_org: str = "" # HF Hub org/user under which the dataset with saved responses is pushed
    cache_dir: str = "" # cache dir for HF datasets and models; the HF_CACHE environment variable takes precedence when set
    # NOTE(review): for the supported ds_type values main() sends the raw 'query'
    # text to the model, so this setting currently has no effect.
    prompt_template: str = 'summary_0_shot' # name of the prompt template to be used for prompting


    # HF generation kwargs (forwarded unchanged to the model call)
    use_cache: bool = True # forwarded as `use_cache`; presumably the HF generate KV-cache flag -- TODO confirm against TaskModel
    do_sample: bool = False # whether to sample; False presumably means greedy decoding (HF generate semantics)
    temperature: float = 0.2 # sampling temperature during generation
    max_new_tokens: int = 400 # maximum number of newly generated tokens

    dataset: str = "" # dataset name passed to load_dataset


def main():
    """Generate model responses for an evaluation dataset and push them to the HF Hub.

    Parses a ``DataConfig`` from the command line, loads and shuffles the
    requested dataset split, keeps the first ``num_examples`` rows, adds a
    ``model_response`` column produced by the model loaded from
    ``cfg.model_path``, and pushes the result to ``{cfg.hf_org}/{save_name}``.

    Raises:
        NotImplementedError: if ``cfg.ds_type`` is not a supported dataset type.
    """
    warnings.filterwarnings("ignore")
    cfg = pyrallis.parse(config_class=DataConfig)

    # Prefix of the pushed dataset's name for each supported dataset type.
    # The two types share the exact same pipeline and differ only in this prefix.
    save_prefixes = {'summary': 'summarize_sft', 'imdb': 'imdb_sft'}
    if cfg.ds_type not in save_prefixes:
        raise NotImplementedError(f'{cfg.ds_type} not supported!')

    # The HF_CACHE environment variable, when set, takes precedence over cfg.cache_dir.
    hf_cache = os.environ.get('HF_CACHE', cfg.cache_dir)

    ds = load_dataset(cfg.dataset, split=cfg.split, cache_dir=hf_cache)
    ds = ds.shuffle(seed=cfg.ds_seed)

    # -1 means "use every example"; otherwise cap the request at the dataset
    # size so select() is never asked for out-of-range indices.
    if cfg.num_examples == -1:
        num_examples = len(ds)
    else:
        num_examples = min(cfg.num_examples, len(ds))
    ds = ds.select(range(num_examples))

    save_name = (
        f"{save_prefixes[cfg.ds_type]}-{cfg.split}_lm-{cfg.model_name}"
        f"_seed-{cfg.ds_seed}_numex-{num_examples}"
    )

    # Both supported dataset types send the raw 'query' text to the model, so
    # (as before) cfg.prompt_template is not applied. The previous
    # eval(prompt_template) fallback could never run for a supported type and
    # has been dropped; do not reintroduce eval() on a CLI-provided string.
    def identity_prompt(query):
        return query

    gen_kwargs = {
        'use_cache': cfg.use_cache,
        'temperature': cfg.temperature,
        'do_sample': cfg.do_sample,
        'max_new_tokens': cfg.max_new_tokens,
    }
    lm = TaskModel(
        model_name=cfg.model_path,
        device=cfg.device,
        use_auto=cfg.use_auto,
        cache_dir=hf_cache,
    )

    ds = ds.map(
        add_model_summary,
        fn_kwargs={
            'llm': lm,
            'prompt_template': identity_prompt,
            'gen_kwargs': gen_kwargs,
        },
    )

    # Drop the model reference and release cached GPU memory before uploading.
    del lm
    torch.cuda.empty_cache()
    ds.push_to_hub(f'{cfg.hf_org}/{save_name}')


def add_model_summary(example, llm, prompt_template, gen_kwargs):
    """Attach the model's response for one example.

    Args:
        example: mapping with a 'query' entry; it is modified in place.
        llm: callable invoked as ``llm(prompt, **gen_kwargs)``.
        prompt_template: callable that turns the query into the final prompt.
        gen_kwargs: keyword arguments forwarded to ``llm``.

    Returns:
        The same ``example`` object with a 'model_response' entry added.
    """
    prompt = prompt_template(example['query'])
    response = llm(prompt, **gen_kwargs)
    example['model_response'] = response
    return example

# Run the generation pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()