from dataclasses import dataclass, field
import pyrallis
import os
from datasets import load_dataset, Dataset, DatasetDict
import warnings
import torch
from prompts import *
from task_model import BoNModel, GenerateModel
from functools import partial
from tqdm import tqdm

# Workspace-specific cache root, written unconditionally when this module is imported.
# NOTE(review): because it overwrites any HF_CACHE already in the environment, main()
# never falls back to the `cache_dir` config field.
os.environ.update(HF_CACHE='/workspace/rlhf-code/.cache/root')

@dataclass
class DataConfig:
    """Configuration for generating evaluation responses (and optional Best-of-N selection).

    Parsed by ``pyrallis.parse`` in ``main()``. The generated responses are pushed to the
    HF Hub as ``<hf_org>/<save_name>``.
    """
    # Generate responses for evaluation
    gen_model_path: str = None #model path/name passed to GenerateModel; must be set (its last '/'-separated part is embedded in the output dataset name)
    tokenizer_path: str = "EleutherAI/pythia-1b" #tokenizer HF path passed to GenerateModel
    reward_model_path: str = "" #reward model name/path passed to BoNModel (only used when gen_type='bon')
    
    #distr args
    num_gen_workers: int = 20 #number of Ray actors used for response generation
    num_bon_workers_per_gpu: int = 1 #number of BoN Ray actors sharing one GPU (each requests 1/num_bon_workers_per_gpu of a GPU)
    num_bon_gpus: int = 1 #number of GPUs for BoN (total BoN actors = num_bon_workers_per_gpu * num_bon_gpus)

    
    #dataset arguments 
    ds_type: str = "summary" #dataset type for evaluation; main() supports 'summary' and 'imdb'
    split: str = 'test' # evaluation dataset split
    num_examples: int = 250 #number of examples for evaluation (-1 uses the whole split)
    ds_seed: int = 42 #seed used to shuffle the dataset before selecting examples (reproducibility)
    hf_org: str = None # HF org/user the dataset with saved responses is pushed to; must be set
    cache_dir: str = "" #cache dir for HF datasets and models; only used when the HF_CACHE env var is unset. NOTE(review): this module sets HF_CACHE at import time, so this is currently ignored
    prompt_template: str = 'summary_0_shot' # prompt template name. NOTE(review): main() currently uses an identity template for every supported ds_type, so this value is not applied
    
    
    #hf generation kwargs
    temperature: float = 0.1  #temperature during generation
    top_k: int = 0 #top k sampling
    top_p: float = 1.0 #top p sampling
    max_new_tokens: int = 64 #max new tokens generated per response (also part of the output dataset name)
    max_seq_len: int = 576 #max sequence length passed to BoNModel
    
    #BoN args
    gen_type: str = 'sft' #'sft' keeps only the first generated response per example; 'bon' additionally runs BoNModel over the generated samples
    bon_alpha: float = 1.0 # NOTE(review): not referenced by main() in this file
    best_of: int = 1 #number of samples generated per prompt (GenerateModel n_samples; also part of the output dataset name)
    bon_batch_size: int = 1 #examples per Ray batch sent to a BoN actor; also passed to BoNModel as batch_size
    bon_chunk_size: int = 1 #passed to BoNModel as reward_chunk_size

    dataset: str = ""  #HF dataset name/path to be used for evaluation


# Output dataset-name prefix for every supported ``ds_type``.
_SAVE_NAME_PREFIXES = {
    'summary': 'summarize',
    'imdb': 'imdb',
}


def _dispatch(actors, ray_dataset, batch_size):
    """Submit every batch of ``ray_dataset`` to ``actors`` round-robin and return the object refs."""
    if not actors:
        raise ValueError("at least one Ray actor is required")
    futures = []
    for batch in ray_dataset.iter_batches(batch_size=batch_size):
        actor = actors[len(futures) % len(actors)]
        futures.append(actor.process_item.remote(batch))
    return futures


def _gather(futures, desc, postprocess=None):
    """Collect actor results in completion order as one flat list of per-example rows.

    Each ``process_item`` call is expected to return a sequence of row dicts for the
    examples in its batch (presumably one per example; confirm against GenerateModel /
    BoNModel). Every row is kept, not just the first, so batch sizes > 1 lose nothing.
    ``postprocess``, when given, is applied to each row before it is stored.
    """
    import ray

    rows = []
    pending = list(futures)
    with tqdm(total=len(pending), desc=desc) as pbar:
        while pending:
            done_refs, pending = ray.wait(pending, num_returns=1)
            for row in ray.get(done_refs[0]):
                rows.append(postprocess(row) if postprocess is not None else row)
            pbar.update(1)
    return rows


def _kill_actors(actors):
    """Terminate Ray actors explicitly so their resources are released before the next phase."""
    import ray

    for actor in actors:
        ray.kill(actor)


def _push_sorted(rows, split, repo_id):
    """Restore the original example order via ``idx``, drop it, and push ``{split: rows}`` to the Hub."""
    ds = Dataset.from_list(rows).sort("idx").remove_columns("idx")
    DatasetDict({split: ds}).push_to_hub(repo_id)


def main():
    """Generate evaluation responses with Ray actors and push them to the HF Hub.

    Flow:
      1. Load ``cfg.dataset``, shuffle it with ``cfg.ds_seed`` and keep ``cfg.num_examples`` rows.
      2. Generate ``cfg.best_of`` samples per prompt with ``GenerateModel`` actors.
      3. ``gen_type='sft'``: keep the first sample per example and push the dataset.
         ``gen_type='bon'``: run ``BoNModel`` actors over the generated rows and push their output.

    Raises:
        ValueError: if ``gen_type`` is invalid or a required field (``gen_model_path``,
            ``hf_org``) is unset.
        NotImplementedError: if ``ds_type`` is not supported.
    """
    import ray

    warnings.filterwarnings("ignore")
    # Parse and validate before starting Ray, so bad arguments fail fast and before any
    # expensive generation work is done.
    cfg = pyrallis.parse(config_class=DataConfig)

    if cfg.gen_type not in ('sft', 'bon'):
        raise ValueError(f"gen_type must be 'sft' or 'bon', got {cfg.gen_type!r}")
    if cfg.ds_type not in _SAVE_NAME_PREFIXES:
        raise NotImplementedError(f'{cfg.ds_type} not supported!')
    if cfg.gen_model_path is None:
        raise ValueError("gen_model_path must be set")
    if cfg.hf_org is None:
        raise ValueError("hf_org must be set; results are pushed to '<hf_org>/<save_name>'")

    if not ray.is_initialized():
        ray.init()

    HF_CACHE = os.environ.get('HF_CACHE', cfg.cache_dir)

    ds = load_dataset(cfg.dataset, split=cfg.split, cache_dir=HF_CACHE)
    ds = ds.shuffle(seed=cfg.ds_seed)
    num_examples = len(ds) if cfg.num_examples == -1 else cfg.num_examples
    ds = ds.select(range(num_examples))
    model_name = cfg.gen_model_path.split('/')[-1]
    save_name = (f"{_SAVE_NAME_PREFIXES[cfg.ds_type]}_sft-{cfg.split}_lm-{model_name}"
                 f"_{cfg.ds_seed}_{num_examples}_{cfg.max_new_tokens}_{cfg.best_of}")

    # Every supported ds_type uses the raw `query` text as the prompt, so cfg.prompt_template
    # is not applied (the former eval() lookup of the template name was unreachable).
    def prompt_template(query):
        return query

    # Tag each row with its position so the original order can be restored after the
    # out-of-order Ray collection below.
    ds = ds.map(lambda x, idx: {"idx": idx}, with_indices=True)
    ray_dataset = ray.data.from_huggingface(ds)

    @ray.remote(num_gpus=0)
    class LLMActor:
        def __init__(self):
            self.llm = GenerateModel(gen_model_path=cfg.gen_model_path, tokenizer_path=cfg.tokenizer_path,
                                     n_samples=cfg.best_of, max_new_tokens=cfg.max_new_tokens,
                                     top_k=cfg.top_k, top_p=cfg.top_p, temperature=cfg.temperature,
                                     cache_dir=HF_CACHE)

        def process_item(self, batch):
            batch['prompt'] = [prompt_template(query) for query in batch['query']]
            return self.llm(batch)

    def keep_first_sample(row):
        # SFT evaluation keeps only the first of the generated samples.
        row['model_response'] = row['model_response'][0]
        return row

    actors = [LLMActor.remote() for _ in range(cfg.num_gen_workers)]
    futures = _dispatch(actors, ray_dataset, batch_size=1)
    rows = _gather(futures, "Generating responses",
                   postprocess=keep_first_sample if cfg.gen_type == 'sft' else None)
    _kill_actors(actors)
    del futures, actors

    if cfg.gen_type == 'sft':
        _push_sorted(rows, cfg.split, f"{cfg.hf_org}/{save_name}")
        return

    # Best-of-N: feed the generated rows (idx included) to BoNModel actors sharing the GPUs.
    ray_dataset = ray.data.from_huggingface(Dataset.from_list(rows))

    @ray.remote(num_gpus=1 / cfg.num_bon_workers_per_gpu)
    class BoNActor:
        def __init__(self):
            self.bon = BoNModel(reward_model_path=cfg.reward_model_path, max_seq_len=cfg.max_seq_len,
                                batch_size=cfg.bon_batch_size, reward_chunk_size=cfg.bon_chunk_size,
                                cache_dir=HF_CACHE)

        def process_item(self, batch):
            return self.bon(batch)

    actors = [BoNActor.remote() for _ in range(cfg.num_bon_workers_per_gpu * cfg.num_bon_gpus)]
    futures = _dispatch(actors, ray_dataset, batch_size=cfg.bon_batch_size)
    rows = _gather(futures, "Selecting best-of-N responses")
    _kill_actors(actors)
    del futures, actors

    _push_sorted(rows, cfg.split, f"{cfg.hf_org}/{save_name}")
    


# Run the generation pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()