
from dataclasses import dataclass, field
import pyrallis
import os
from datasets import load_dataset

import warnings

from prompts import *
from language_model import LabelModel
from functools import partial
import srsly
from tqdm import tqdm


@dataclass
class DataConfig:
    """Configuration for pairwise judging of two response sets with a label LLM.

    Parsed from the command line by ``pyrallis`` in ``main``; every field
    below is an option of the evaluation script.
    """
    model_path: str = "Qwen/Qwen2.5-72B-Instruct" # label model path for evaluation
    model_name: str = "qwen2.5-72b" # label model name for evaluation
    
    # response datasets
    response1_path: str = None # huggingface hub path of the dataset holding response 1 (required by main)
    response2_path: str = None # huggingface hub path of the dataset holding response 2 (required by main)
    
    # dataset arguments
    cache_dir: str = "" # cache dir for hf datasets and models; main prefers the HF_CACHE env var when it is set
    ds_type: str = "summary" # dataset type; "summary" selects the summarization path, also names the output sub-directory
    split: str = "train" # dataset split for evaluation
    num_examples: int = -1 # number of examples for evaluation (-1 = all); main only applies it when ds_type is not "summary"
    seed: int = 42 # seed intended for dataset shuffling, for reproducibility
    prompt_template: str = 'detailed_1_shot_preamble'  # name of the prompt template (looked up among the names from `prompts`) used to prompt the label model
    
    zero_id: str = "1" # label the judge emits when it prefers response 1
    one_id: str = "2" # label the judge emits when it prefers response 2
    reprompt: bool = False # have label LLM generate full response and then compute forward pass to extract label
    
    device: str = 'cuda' # label model device argument
    use_auto: bool = True # allocate model parameters across available devices (GPU, CPU)
    
    # hf kwargs (forwarded unchanged to the label model call as generation kwargs)
    use_cache: bool = True # forwarded as `use_cache`; presumably enables key/value caching during generation -- confirm in LabelModel
    do_sample: bool = False  # whether to sample during generation; False presumably means greedy decoding
    temperature: float = 0.2  # temperature during generation
    max_new_tokens: int = 400 # max new tokens


def main():
    """Judge two sets of model responses with a label LLM and save win rates.

    Steps:
      1. Parse a ``DataConfig`` from the command line with pyrallis.
      2. Load both response datasets (``model_response`` column) and the
         contexts (``query`` column of the first dataset).
      3. Ask ``LabelModel`` which of the two responses it prefers for
         every example.
      4. Write both win rates and the indices of the preferred examples as
         JSON under ``eval_results/<ds_type>/``.

    Raises:
        ValueError: if a response path is missing, the two response sets
            differ in length, or there are no examples to judge.
        RuntimeError: if the label model returns something other than
            ``cfg.zero_id`` or ``cfg.one_id``.
    """
    warnings.filterwarnings("ignore")
    cfg = pyrallis.parse(config_class=DataConfig)

    # Both paths default to None; fail with a clear message instead of an
    # AttributeError on `.split` further down.
    if not cfg.response1_path or not cfg.response2_path:
        raise ValueError('response1_path and response2_path must both be set')

    # The HF_CACHE environment variable takes precedence over cfg.cache_dir.
    hf_cache = os.environ.get('HF_CACHE', cfg.cache_dir)

    os.makedirs(f"eval_results/{cfg.ds_type}", exist_ok=True)
    ovr_path = f'eval_results/{cfg.ds_type}/p1-{cfg.response1_path.split("/")[-1]}_p2-{cfg.response2_path.split("/")[-1]}'

    ds1 = load_dataset(cfg.response1_path, split=cfg.split, cache_dir=hf_cache)
    ds2 = load_dataset(cfg.response2_path, split=cfg.split, cache_dir=hf_cache)

    if cfg.ds_type == 'summary':
        # Drop the last 7 characters of each query (presumably a trailing
        # prompt suffix -- confirm against the dataset).
        contexts_lst = [query[:-7] for query in ds1['query']]
    else:
        # Evaluate some arbitrary set of responses, optionally truncated to
        # the first `num_examples` rows of each dataset.
        if cfg.num_examples != -1:
            ds1 = ds1.select(range(cfg.num_examples))
            ds2 = ds2.select(range(cfg.num_examples))
        # Previously no contexts were built on this path, so the judging
        # loop failed with a NameError.
        # NOTE(review): assumes these datasets also expose a `query` column
        # that can be used as-is -- confirm.
        contexts_lst = ds1['query']

    response1_lst = ds1['model_response']
    response2_lst = ds2['model_response']

    # Explicit raises rather than `assert`, which is stripped under `python -O`.
    if len(response1_lst) != len(response2_lst):
        raise ValueError('lengths mismatch in number of examples')
    num_total = len(response1_lst)
    if num_total == 0:
        # Avoid loading the label model only to hit a ZeroDivisionError later.
        raise ValueError('no examples to evaluate')

    gen_kwargs = {
        'use_cache': cfg.use_cache,
        'temperature': cfg.temperature,
        'do_sample': cfg.do_sample,
        'max_new_tokens': cfg.max_new_tokens,
    }
    llm = LabelModel(cfg.model_path, device=cfg.device, use_auto=cfg.use_auto, cache_dir=hf_cache)
    llm.set_constrained(cfg.zero_id, cfg.one_id)

    # SECURITY: eval() executes whatever string is passed as
    # `prompt_template`; only run this script with trusted configuration.
    prompt_template = partial(eval(cfg.prompt_template), first_identifier=cfg.zero_id, second_identifier=cfg.one_id)
    reprompt_template = cot_reprompt if ('cot' in cfg.prompt_template or cfg.reprompt) else None

    # Indices of the examples where response 1 / response 2 was preferred.
    resp1_pref, resp2_pref = [], []

    examples = zip(contexts_lst, response1_lst, response2_lst)
    for i, (context, resp1, resp2) in tqdm(enumerate(examples), total=num_total, desc='Running judging'):
        pref = llm(context, resp1, resp2, prompt_template, reprompt_template, **gen_kwargs)
        if pref == cfg.zero_id:
            resp1_pref.append(i)
        elif pref == cfg.one_id:
            resp2_pref.append(i)
        else:
            raise RuntimeError('response not valid label')

    srsly.write_json(ovr_path, {
        'model1_win_rate': len(resp1_pref) / num_total,
        'model2_win_rate': len(resp2_pref) / num_total,
        'resp1_pref': resp1_pref,
        'resp2_pref': resp2_pref,
        'resp1_path': cfg.response1_path,
        'resp2_path': cfg.response2_path,
    })

# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()