
import json
import datasets
from fire import Fire
from functools import partial
from typing import List
from loguru import logger
import os
import sys
sys.path.append("..")
from utils import (
    generate_with_references,
    DEBUG,
)
from datasets import load_dataset, load_from_disk

def process_fn(
    item,
    model,
    reference_models=None,
    temperature=0.7,
    max_tokens=2048,
    rounds=1,
    n=1,
    aggPrompt="Default",
    args=None,
):
    """Generate the assistant replies for every user turn of one dataset row.

    Args:
        item: Mapping with a ``'conversation'`` list of ``{"role", "content"}``
            turns and, optionally, a ``'references'`` entry.
        model: Model identifier forwarded to ``generate_with_references``; it
            is also used to build the ``'generator'`` field of the result.
        reference_models: Not used by this function; accepted so existing
            callers can keep passing it.
        temperature: Forwarded to ``generate_with_references``.
        max_tokens: Forwarded to ``generate_with_references``.
        rounds: ``0`` sends ``item['references']`` unchanged with every call
            and fans the first turn out into ``n`` samples.  Any other value
            builds a single conversation and selects per-turn references from
            the reference conversations.
        n: Number of samples requested from ``generate_with_references``.
        aggPrompt: Forwarded to ``generate_with_references``.
        args: ``[max_turns, use_mt]``.  When omitted, every user turn is
            processed and the multi-turn prompt flag is enabled.

    Returns:
        A dict with ``'messages'`` (a list of conversations, each a list of
        role/content dicts), ``'generator'``, ``'total_completion_tokens'``
        and ``'total_encoded_tokens'``.
    """
    # `args` used to be indexed unconditionally, which failed for the default
    # value of None.
    max_turns, use_mt = (None, True) if args is None else (args[0], args[1])

    references = item.get('references', [])
    prompts = [turn["content"] for turn in item['conversation'] if turn["role"] == "user"]
    # logger.info(len(prompts))

    total_completion_tokens = 0
    total_encoded_tokens = 0

    if rounds == 0:
        # If n > 1 the first turn is sampled in parallel (one conversation per
        # sample); later turns extend each conversation sequentially with n=1.
        # Start from n empty conversations so the result keeps its shape even
        # when the row contains no user turn.
        new_messages = [[] for _ in range(n)]
        for idx, prompt in enumerate(prompts[:max_turns]):
            if idx == 0:
                result = generate_with_references(
                    model=model,
                    messages=[{"role": "user", "content": prompt}],
                    references=references,
                    temperature=temperature,
                    n=n,
                    max_tokens=max_tokens,
                    aggPrompt=aggPrompt,
                )
                # Usage keys are the ones read in the `rounds != 0` branch;
                # `.get` keeps this branch working if they are absent here.
                total_completion_tokens += result.get("tokens_generated", 0)
                total_encoded_tokens += result.get("tokens_encoded", 0)
                new_messages = [
                    [
                        {"role": "user", "content": prompt},
                        {"role": "assistant", "content": completion},
                    ]
                    for completion in result["outputs"]
                ]
            else:
                for conversation in new_messages:
                    conversation.append({"role": "user", "content": prompt})
                    result = generate_with_references(
                        model=model,
                        messages=conversation,
                        references=references,
                        temperature=temperature,
                        n=1,
                        max_tokens=max_tokens,
                        aggPrompt=aggPrompt,
                    )
                    total_completion_tokens += result.get("tokens_generated", 0)
                    total_encoded_tokens += result.get("tokens_encoded", 0)
                    conversation.append({"role": "assistant", "content": result["outputs"][0]})
    else:
        # The running conversation must exist before the first append; it was
        # previously never initialised in this branch.
        messages = []
        for idx, prompt in enumerate(prompts[:max_turns]):
            messages.append({"role": "user", "content": prompt})

            # Assumes every reference conversation strictly alternates
            # user/assistant, so index 2 * idx + 1 is the reply to turn idx
            # -- TODO confirm against the reference files.
            cur_references = [ref[idx * 2 + 1]["content"] for ref in references]

            # The multi-turn prompt is only requested from the second turn on,
            # and only when `use_mt` is enabled.
            mt_prompt_order = bool(use_mt) and idx != 0

            result = generate_with_references(
                model=model,
                messages=messages,
                references=cur_references,
                temperature=temperature,
                n=n,
                max_tokens=max_tokens,
                aggPrompt=aggPrompt,
                mt_bench=mt_prompt_order,
            )

            total_completion_tokens += result["tokens_generated"]
            total_encoded_tokens += result["tokens_encoded"]
            # Only the first sample continues the conversation, so this branch
            # yields exactly one conversation regardless of `n`.
            messages.append({"role": "assistant", "content": result["outputs"][0]})

        # Wrap the single conversation so 'messages' has the same
        # list-of-conversations shape as in the `rounds == 0` branch.
        new_messages = [messages]

    return {
        'messages': new_messages,
        'generator': model + '-together',
        'total_completion_tokens': total_completion_tokens,
        'total_encoded_tokens': total_encoded_tokens,
    }


def _split_cli_list(value):
    """Normalise a comma-separated CLI value into a list of non-empty strings.

    Accepts ``None``, a comma-separated string, or a list/tuple (the shape a
    CLI parser may already have produced for ``a,b``).
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        parts = [str(part) for part in value]
    else:
        parts = str(value).split(',')
    return [part.strip() for part in parts if part.strip()]


def main(
    model: str,
    output_path: str,
    additional_info: str = "",
    reference_paths: str = None,
    reference_models: str = None,
    num_reference_path: int = None,
    aggPrompt: str = "Default",
    temperature: float = 0.7,
    max_tokens: int = 2048,
    rounds: int = 1,
    num_proc: int = 16,
    n=1,
    max_turns=10,
    use_mt=True,
    dataset_path="",
    start: int = 0,
    finish: int = None,
):
    """Run `process_fn` over a dataset slice and dump the results as JSON.

    Args:
        model: Model identifier; its last path component names the output
            directory and files.
        output_path: Root directory for the outputs.
        additional_info: Free-form suffix appended to the output file names.
        reference_paths: Comma-separated JSON files, or a single glob pattern
            containing ``*``, holding reference responses.
        reference_models: Comma-separated model names; only logged here.
        num_reference_path: Keep at most this many reference files.
        aggPrompt: Forwarded to `process_fn`.
        temperature: Forwarded to `process_fn`.
        max_tokens: Forwarded to `process_fn`.
        rounds: Forwarded to `process_fn`; also part of the output file name.
        num_proc: Number of worker processes for ``Dataset.map``.
        n: Number of samples per row; one output file is written per sample.
        max_turns: Maximum number of user turns processed per conversation.
        use_mt: Forwarded to `process_fn` as the multi-turn prompt switch.
        dataset_path: A path for ``load_from_disk``; if it contains
            ``"ultrafeedback"`` the hub dataset is loaded instead.
        start: First row (inclusive) of the slice to process.
        finish: End row (exclusive); ``None`` means the end of the dataset.

    Raises:
        ValueError: If the reference files do not provide exactly one entry
            per dataset row.
    """
    import glob

    args = [max_turns, use_mt]

    # A string containing "*" is treated as one glob pattern; anything else is
    # a comma-separated list (possibly already split by the CLI parser).
    if isinstance(reference_paths, str) and "*" in reference_paths:
        reference_paths = sorted(glob.glob(reference_paths))
    else:
        reference_paths = _split_cli_list(reference_paths)
    reference_models = _split_cli_list(reference_models)

    if "ultrafeedback" in dataset_path:
        eval_set = load_dataset("HuggingFaceH4/ultrafeedback_binarized")["train_prefs"]
        # change the messages column to conversation
        eval_set = eval_set.rename_column("messages", "conversation")
    else:
        eval_set = load_from_disk(dataset_path)

    if reference_paths:
        if num_reference_path is None:
            num_reference_path = len(reference_paths)
        reference_paths = reference_paths[:num_reference_path]
        logger.info(f"`reference_paths` provided: {reference_paths}")

        # references[row] collects that row's conversation from each file.
        references = []
        for reference_path in reference_paths:
            with open(reference_path, encoding="utf-8") as f:
                reference_responses = json.load(f)
            logger.info(f"Reading reference outputs: {reference_path} ({len(reference_responses)})")
            for row_idx, reference_response in enumerate(reference_responses):
                if len(references) <= row_idx:
                    references.append([])
                references[row_idx].append(reference_response['messages'])

        # References are attached before slicing, so they must cover the
        # whole dataset; fail early with a readable message otherwise.
        if len(references) != len(eval_set):
            raise ValueError(
                f"Reference files provide {len(references)} rows but the dataset has {len(eval_set)}."
            )
        eval_set = eval_set.add_column("references", references)

    elif reference_models:
        logger.info(f"`reference_models` provided: {reference_models}, {len(reference_models)} of them. Will generate reference responses on-the-fly.")

    # Clamp `finish` so an end index past the dataset does not make `select` fail.
    stop = len(eval_set) if finish is None else min(finish, len(eval_set))
    eval_set = eval_set.select(range(start, stop))
    logger.info("Start.")
    logger.info(eval_set)

    eval_set = eval_set.map(
        partial(
            process_fn,
            model=model,
            reference_models=reference_models,
            temperature=temperature,
            max_tokens=max_tokens,
            rounds=rounds,
            n=n,
            aggPrompt=aggPrompt,
            args=args,
        ),
        batched=False,
        num_proc=num_proc,
    )

    model_format = model.replace("/checkpoint", "-checkpoint")
    model_name = model_format.split('/')[-1]
    output_dir = os.path.join(output_path, model_name)
    os.makedirs(output_dir, exist_ok=True)

    # Drop the bulky input columns that exist.  Removing them one by one under
    # a single try/except aborted at the first missing column and silently
    # kept the remaining ones in the output.
    droppable = [
        column
        for column in ("conversation", "chosen", "rejected", "references")
        if column in eval_set.column_names
    ]
    if droppable:
        eval_set = eval_set.remove_columns(droppable)

    rows = list(eval_set)
    for i in range(n):
        if n == 1:
            num_reference_path_str = f"-num_reference_path{num_reference_path}" if num_reference_path is not None else ""
            file_name = f'{model_name}-round_{rounds}-temp{temperature}{num_reference_path_str}_{additional_info}.json'
        else:
            file_name = f'{model_name}-round_{rounds}-temp{temperature}-{i}_{n}_{additional_info}.json'
        save_path = os.path.join(output_dir, file_name)

        logger.info(f"Saving outputs to {save_path}.")

        # Each output file keeps only the i-th sampled conversation of every row.
        sample_rows = [{**row, 'messages': row['messages'][i]} for row in rows]

        with open(save_path, 'w', encoding="utf-8") as f:
            json.dump(sample_rows, f, indent=2)


if __name__ == "__main__":
    # Script entry point: hand `main` to Fire, which builds the command-line
    # interface from its parameters.
    Fire(main)