import argparse
import random
import os

from utils import *
from functools import partial

# Seed Python's global RNG at import time for reproducibility. Nothing visible
# in this file calls `random` directly, so this presumably targets code reached
# through the `utils` star-import — TODO confirm.
random.seed(1021)

def process_prompt_for_ruler_qa(example, template=None, doc_key="short_docs_list"):
    """Build a RULER-QA style prompt string from an example and a template.

    Each document in ``example[doc_key]`` (or ``example["all_docs"]`` when
    ``doc_key`` is absent) is rendered by substituting it for the literal
    ``{sentence}`` placeholder in ``template["evidence_template_str"]``. The
    rendered documents are concatenated and wrapped with the template's
    context, question and answer affixes.

    Args:
        example: Mapping holding ``"question"`` and either ``doc_key`` or
            ``"all_docs"`` (an iterable of document strings).
        template: Mapping providing ``"instruction"``, ``"evidence_template_str"``,
            ``"context_prefix"``, ``"context_suffix"``, ``"question_prefix"``,
            ``"question_suffix"`` and ``"answer_prefix"``.
        doc_key: Key of the document list to render.

    Returns:
        The assembled prompt string.

    Raises:
        KeyError: If a required template or example key is missing.
    """
    # ``None`` replaces the shared mutable ``{}`` default. An omitted template
    # still raises KeyError on the first lookup, exactly as before.
    if template is None:
        template = {}
    instruction = template["instruction"]
    doc_format = template["evidence_template_str"]
    documents = example[doc_key] if doc_key in example else example["all_docs"]

    # Render every document, then join once instead of repeated `+=`.
    content = "".join(doc_format.replace("{sentence}", paragraph) for paragraph in documents)

    return (
        instruction
        + "\nInput:\n"
        + template["context_prefix"]
        + content
        + template["context_suffix"]
        + template["question_prefix"]
        + example["question"]
        + template["question_suffix"]
        + template["answer_prefix"]
    )

def format_chosen_rejected(example,template,doc_type):
    """Attach chat-style ``prompt`` / ``chosen`` / ``rejected`` fields to ``example``.

    The user turn is the prompt built from the ``{doc_type}_docs_list``
    documents. The assistant replies come from ``positive_sample`` (chosen)
    and ``negative_sample`` (rejected). ``example`` is updated in place and
    returned.
    """
    def _single_turn(role, text):
        # A one-message conversation in role/content form.
        return [{"role": role, "content": text}]

    user_text = process_prompt_for_ruler_qa(
        example, template=template, doc_key=f"{doc_type}_docs_list"
    )

    # Read every source field before writing, so a missing key leaves
    # `example` untouched.
    prompt_turns = _single_turn("user", user_text)
    chosen_turns = _single_turn("assistant", example["positive_sample"])
    rejected_turns = _single_turn("assistant", example["negative_sample"])

    example["prompt"] = prompt_turns
    example["chosen"] = chosen_turns
    example["rejected"] = rejected_turns
    return example

def format_llamafactory_preference_data(example,template,doc_type):
    """Add a LLaMA-Factory pairwise-preference record as ``example["llamafactory"]``.

    The human turn is the prompt built from the ``{doc_type}_docs_list``
    documents. The chosen and rejected gpt turns come from
    ``positive_sample_on_{doc_type}`` / ``negative_sample_on_{doc_type}`` when
    both keys are present. Otherwise they come from the generic
    ``positive_sample`` / ``negative_sample``.

    Args:
        example: One dataset row (mapping); updated in place.
        template: Prompt template passed to ``process_prompt_for_ruler_qa``.
        doc_type: Document-list prefix, e.g. ``"short"`` or ``"long"``.

    Returns:
        The same ``example`` with a ``"llamafactory"`` entry holding
        ``conversations``, ``chosen`` and ``rejected``.

    Raises:
        KeyError: If neither the type-specific pair nor the generic pair exists.
    """
    doc_key = f"{doc_type}_docs_list"
    prompt = process_prompt_for_ruler_qa(example, template=template, doc_key=doc_key)

    conversations = [{"from": "human", "value": prompt}]

    # Prefer the type-specific pair and fall back to the generic one only
    # when it is missing. This explicit check replaces a bare `except:` that
    # also swallowed unrelated errors, including KeyboardInterrupt.
    pos_key = f"positive_sample_on_{doc_type}"
    neg_key = f"negative_sample_on_{doc_type}"
    if pos_key not in example or neg_key not in example:
        pos_key, neg_key = "positive_sample", "negative_sample"

    example["llamafactory"] = {
        "conversations": conversations,
        "chosen": {"from": "gpt", "value": example[pos_key]},
        "rejected": {"from": "gpt", "value": example[neg_key]},
    }
    return example

def conver_2_llamafactory_preference_format(src_data_path,tgt_data_path,template_path,doc_type,num_samples=-1):
    """Convert a source dataset to LLaMA-Factory preference format and save it.

    Args:
        src_data_path: Path passed to ``load_custom_dataset``.
        tgt_data_path: Output path passed to ``save_data``. Its parent
            directory is created if needed.
        template_path: Prompt-template path, loaded with ``load_json``.
        doc_type: ``"long"`` or ``"short"``; selects the document list used
            to build the prompt.
        num_samples: When positive, keep only the first ``num_samples`` rows
            (capped at the dataset size). Non-positive keeps all rows.

    Raises:
        ValueError: If ``doc_type`` is not ``"long"`` or ``"short"``.
    """
    # Validate before doing any (potentially slow) loading.
    if doc_type not in ("long", "short"):
        raise ValueError("wrong document type")

    template = load_json(template_path)
    dataset = load_custom_dataset(src_data_path)
    if num_samples > 0:
        # `select` rejects out-of-range indices, so cap at the dataset size.
        dataset = dataset.select(range(min(num_samples, len(dataset))))
    print(dataset)

    dataset = dataset.map(partial(format_llamafactory_preference_data, template=template, doc_type=doc_type),
                          batched=False,
                          num_proc=16,
                          desc="format llamafactory dataset")
    llamafactory = dataset["llamafactory"]

    print(f"num_samples: {len(llamafactory)}")
    # `exist_ok` removes the exists/makedirs race. Skip when the target is
    # in the CWD (dirname == ""), since `os.makedirs("")` raises.
    out_dir = os.path.dirname(tgt_data_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_data(llamafactory, tgt_data_path)

def format_llamafactory_short2long_preference_data(example,template,pair_source="short"):
    """Add a short-to-long LLaMA-Factory preference record as ``example["llamafactory"]``.

    Prompts are built for both the ``short`` and ``long`` document lists and
    stored as ``short_conversations`` / ``long_conversations``.
    ``conversations`` is set to the short one. The chosen and rejected gpt
    turns come from ``positive_sample_on_short`` / ``negative_sample_on_short``
    when both are present, else from ``positive_sample`` / ``negative_sample``.

    Args:
        example: One dataset row (mapping); updated in place.
        template: Prompt template passed to ``process_prompt_for_ruler_qa``.
        pair_source: Which responses form the pair; only ``"short"`` is supported.

    Returns:
        The same ``example`` with a populated ``"llamafactory"`` entry.

    Raises:
        ValueError: If ``pair_source`` is not ``"short"``.
        KeyError: If neither response pair exists.
    """
    if pair_source != "short":
        # Any other value used to fall through and fail with an
        # UnboundLocalError on the chosen/rejected messages.
        raise ValueError(f"unsupported pair_source: {pair_source!r} (only 'short' is supported)")

    llamafactory = {}
    for doc_type in ["short", "long"]:
        prompt = process_prompt_for_ruler_qa(example, template=template, doc_key=f"{doc_type}_docs_list")
        llamafactory[f"{doc_type}_conversations"] = [{"from": "human", "value": prompt}]

    # Prefer responses generated on the pair source and fall back to the
    # generic pair; this replaces a bare `except:`.
    pos_key = f"positive_sample_on_{pair_source}"
    neg_key = f"negative_sample_on_{pair_source}"
    if pos_key not in example or neg_key not in example:
        pos_key, neg_key = "positive_sample", "negative_sample"

    llamafactory["conversations"] = llamafactory["short_conversations"]
    llamafactory["chosen"] = {"from": "gpt", "value": example[pos_key]}
    llamafactory["rejected"] = {"from": "gpt", "value": example[neg_key]}
    example["llamafactory"] = llamafactory
    return example

def conver_2_llamafactory_short2long_preference_data(src_data_path,tgt_data_path,template_path,doc_type,pair_source):
    """Convert a source dataset to short-to-long preference format and save it.

    Args:
        src_data_path: Path passed to ``load_custom_dataset``.
        tgt_data_path: Output path passed to ``save_data``. Its parent
            directory is created if needed.
        template_path: Prompt-template path, loaded with ``load_json``.
        doc_type: Unused by this function; kept for call-site compatibility.
        pair_source: Forwarded to ``format_llamafactory_short2long_preference_data``.
    """
    template = load_json(template_path)
    dataset = load_custom_dataset(src_data_path)
    print(dataset)
    dataset = dataset.map(partial(format_llamafactory_short2long_preference_data, template=template, pair_source=pair_source),
                          batched=False,
                          num_proc=16,
                          desc="format llamafactory dataset",
                          load_from_cache_file=False,
                          )
    llamafactory = dataset["llamafactory"]
    # Peek at the record layout; guard so an empty dataset does not crash.
    if len(llamafactory) > 0:
        print(llamafactory[0].keys())

    # `exist_ok` removes the exists/makedirs race. Skip when the target is
    # in the CWD (dirname == ""), since `os.makedirs("")` raises.
    out_dir = os.path.dirname(tgt_data_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_data(llamafactory, tgt_data_path)

def format_llamafactory_sft_data(example,template,doc_type):
    """Add a single-turn LLaMA-Factory SFT record as ``example["llamafactory"]``.

    The human turn is the prompt built from ``{doc_type}_docs_list``. The gpt
    turn is ``positive_sample_on_{doc_type}`` when that key exists in
    ``example``, and ``positive_sample`` otherwise. ``example`` is updated in
    place and returned.
    """
    human_text = process_prompt_for_ruler_qa(
        example, template=template, doc_key=f"{doc_type}_docs_list"
    )

    response_key = f"positive_sample_on_{doc_type}"
    if response_key not in example:
        response_key = "positive_sample"

    human_turn = {"from": "human", "value": human_text}
    gpt_turn = {"from": "gpt", "value": example[response_key]}
    example["llamafactory"] = {"conversations": [human_turn, gpt_turn]}
    return example

def convert_2_llamafactory_sft_format(src_data_path,tgt_data_path,template_path,doc_type, num_train):
    """Convert a source dataset to LLaMA-Factory SFT format and save it.

    Args:
        src_data_path: Path passed to ``load_custom_dataset``.
        tgt_data_path: Output path passed to ``save_data``. Its parent
            directory is created if needed.
        template_path: Prompt-template path, loaded with ``load_json``.
        doc_type: ``"long"`` or ``"short"``; selects the document list used
            to build the prompt.
        num_train: When positive, keep only the first ``num_train`` rows
            (capped at the dataset size). Non-positive (e.g. the CLI default
            ``-1``) keeps all rows. Previously ``0`` selected no rows and
            then crashed on ``llamafactory[0]``.

    Raises:
        ValueError: If ``doc_type`` is not ``"long"`` or ``"short"``.
    """
    # Validate before doing any (potentially slow) loading.
    if doc_type not in ("long", "short"):
        raise ValueError("wrong document type")

    template = load_json(template_path)
    dataset = load_custom_dataset(src_data_path)
    if num_train > 0:
        # `select` rejects out-of-range indices, so cap at the dataset size.
        dataset = dataset.select(range(min(num_train, len(dataset))))

    dataset = dataset.map(partial(format_llamafactory_sft_data, template=template, doc_type=doc_type),
                          batched=False,
                          num_proc=16,
                          desc="format llamafactory dataset")
    llamafactory = dataset["llamafactory"]

    # Peek at the record layout; guard so an empty dataset does not crash.
    if len(llamafactory) > 0:
        print(llamafactory[0].keys())
    print(f"num samples: {len(llamafactory)}")

    # `exist_ok` removes the exists/makedirs race. Skip when the target is
    # in the CWD (dirname == ""), since `os.makedirs("")` raises.
    out_dir = os.path.dirname(tgt_data_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_data(llamafactory, tgt_data_path)

def get_args():
    """Parse the command-line arguments for the conversion script.

    The source file name and the source/target directories are required,
    because every conversion path is built from them. ``--convert_type`` is
    restricted to ``all`` or ``reallong``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_train", type=int, default=-1, help="-1 means all")

    # These feed every src/tgt path. A missing one used to crash with
    # AttributeError (`None.endswith`) or silently write under "None/".
    parser.add_argument("--src_data_file_name", type=str, required=True)
    parser.add_argument("--src_data_dir_path", type=str, required=True)
    parser.add_argument("--tgt_save_data_dir_path", type=str, required=True)

    parser.add_argument("--template_path", type=str, default="./template/cot.json")
    parser.add_argument("--length_list", type=str, default="8k", help="use commas for separation: 4k,8k,12k")
    # Any other value would leave the document-type list undefined downstream.
    parser.add_argument("--convert_type", type=str, default="all", choices=["all", "reallong"], help="all or reallong")
    return parser.parse_args()

def _strip_jsonl_suffix(name):
    """Return ``name`` without a single trailing ``.jsonl`` extension."""
    suffix = ".jsonl"
    return name[: -len(suffix)] if name.endswith(suffix) else name


def _src_data_path(src_dir, length, src_name, reallong):
    """Return the source ``.jsonl`` path: flat for ``reallong``, else under a per-length subdirectory."""
    if reallong:
        return f"{src_dir}/{src_name}.jsonl"
    return f"{src_dir}/{length}/{src_name}.jsonl"


def main():
    """Run every conversion selected by the command-line arguments.

    For each requested length, this writes:
      1. preference data (``{length}_po_{doc_type}_{name}.json``);
      2. SFT data (``{length}_sft_{doc_type}_{name}.json``);
      3. in ``all`` mode only, short-to-long preference data
         (``{length}_s2l_short2long_short_{name}.json``).
    """
    args = get_args()

    template_path = args.template_path
    num_train = args.num_train
    length_list = args.length_list.split(",")

    if args.convert_type == "all":
        doc_type_list = ["short", "long"]
    elif args.convert_type == "reallong":
        doc_type_list = ["long"]
    else:
        # Without this, `doc_type_list` would be unbound below.
        raise ValueError(f"unknown convert_type: {args.convert_type!r}")
    reallong = args.convert_type == "reallong"

    # Strip only a trailing ".jsonl"; `str.replace` removed every occurrence.
    src_data_name = _strip_jsonl_suffix(args.src_data_file_name)
    src_dir = args.src_data_dir_path
    tgt_dir = args.tgt_save_data_dir_path

    #### 1. convert preference data
    for length in length_list:
        for doc_type in doc_type_list:
            src_data_path = _src_data_path(src_dir, length, src_data_name, reallong)
            tgt_data_path = f"{tgt_dir}/{length}_po_{doc_type}_{src_data_name}.json"
            conver_2_llamafactory_preference_format(src_data_path, tgt_data_path, template_path, doc_type, num_train)

    ##### 2. convert sft data
    for length in length_list:
        for doc_type in doc_type_list:
            src_data_path = _src_data_path(src_dir, length, src_data_name, reallong)
            tgt_data_path = f"{tgt_dir}/{length}_sft_{doc_type}_{src_data_name}.json"
            convert_2_llamafactory_sft_format(src_data_path, tgt_data_path, template_path, doc_type, num_train)

    ##### 3. convert short2long preference data (only in "all" mode)
    if args.convert_type == "all":
        doc_type = "short2long"
        pair_source = "short"
        for length in length_list:
            src_data_path = _src_data_path(src_dir, length, src_data_name, reallong=False)
            tgt_data_path = f"{tgt_dir}/{length}_s2l_{doc_type}_{pair_source}_{src_data_name}.json"
            conver_2_llamafactory_short2long_preference_data(src_data_path, tgt_data_path, template_path, doc_type, pair_source)

# Run the conversion pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()