import os, json
from itertools import combinations
from datasets import Dataset, load_from_disk
import random
random.seed(42)
import numpy as np
import argparse
import torch
import pickle

'''
    Divide the preference dataset according to the average instruction-quality score given by Llama-3.1-70B-Instruct and Qwen2.5-72B-Instruct. The margin distribution is controlled to be the same across all subsets.

        create_judge_prompt: prompt for instruction quality judgement
        generate_quality_judge: generate judgement text
        parse_prompt_quality_judgement: parse generation results above
        quality_split: formulate dataset
'''

def create_judge_prompt(chat_his):
    """Build the chat messages asking an LLM judge to rate a dialogue's quality.

    Args:
        chat_his: Text of the conversation to rate. It is interpolated as-is
            under the "## Conversation History" heading of the user message.

    Returns:
        A list of two ``{"role": ..., "content": ...}`` dicts: a system
        message with empty content, followed by the user message holding the
        five-level rating instructions and the requested answer format.
    """
    rating_request = f"""# Instruction
You are an expert evaluator tasked with rating the quality of the single or multi-turn
dialogue based on its clarity, specificity, and coherence.
The rating scale is as follows:
− very poor: The dialogue is unclear, vague, or incoherent. It lacks essential information and context.
− poor: The dialogue is somewhat unclear or lacks important details. It requires significant clarification.
− average: The dialogue is moderately clear and specific. It may require some additional information for a complete understanding.
− good: The dialogue is clear, specific, and mostly well−formed. It provides sufficient context for understanding the user’s intent.
− excellent: The dialogue is very clear, specific, and well−articulated. It contains all the necessary information and context for providing a comprehensive response.

## Conversation History
{chat_his}

Please directly evaluate the overall quality of the response by 5 levels in the
following format:
Dialog quality: [Very Poor/Poor/Average/Good/Excellent]
"""

    # The system turn is kept (with empty content) so the message layout
    # handed to the chat template is always system + user.
    return [
        {"role": "system", "content": ""},
        {"role": "user", "content": rating_request},
    ]

def generate_quality_judge():
    """Ask one judge LLM to rate every grouped instruction and dump the raw text.

    Command-line flags (parsed here) select the output directory, sampling
    temperature, token budget, tensor-parallel size and judge model name.
    Entries of the source JSON whose ``model_a`` contains
    "Llama-3.1-Tulu-3-8B-SFT" are skipped; the rest are grouped by
    ``instruction_id``. For each group, the ``raw_prompt`` of its first entry
    is rendered as "-role: content" lines, wrapped by ``create_judge_prompt``
    and the tokenizer's chat template, and sent to vLLM.

    The result is written to
    ``<output_dir>/<model>-judge_prompt_quality.json`` as a list of
    ``{"id", "raw_prompt", "judge"}`` records, in group order.
    """
    # Heavy third-party imports stay local to this stage.
    from vllm import LLM, SamplingParams
    from transformers import AutoTokenizer

    cli = argparse.ArgumentParser()
    cli.add_argument("--output_dir", type=str, default="/path/to/judge/result/text")
    cli.add_argument("--temperature", type=float, default=0.8)
    cli.add_argument("--max_tokens", type=int, default=2048)
    cli.add_argument("--tensor_parallel_size", type=int, default=torch.cuda.device_count())
    # Candidate judges: Meta-Llama-3.1-70B-Instruct, Qwen2.5-72B-Instruct, DeepSeek-V2.5
    cli.add_argument("--model", type=str, default="Meta-Llama-3.1-70B-Instruct")
    args = cli.parse_args()

    #### Read the source file
    source_file = "/PATH/to/sharegpt/with/generated/resposes"
    grouped_results = {}
    with open(source_file, 'r', encoding='utf-8') as file:
        data = json.load(file)  # 79997
        for entry in data:
            if "Llama-3.1-Tulu-3-8B-SFT" in entry['model_a']:
                continue
            # Everything kept here is what the original author calls
            # "all off-policy data".
            grouped_results.setdefault(entry['instruction_id'], []).append(entry)

    os.makedirs(args.output_dir, exist_ok=True)
    model_name_or_path = os.path.join("/PATH/to/models", args.model)

    # Initialize sampling configuration, tokenizer and engine
    sampling_params = SamplingParams(
        temperature=args.temperature,
        max_tokens=args.max_tokens,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    llm = LLM(
        model=model_name_or_path,
        tensor_parallel_size=args.tensor_parallel_size,
    )

    # Parallel lists: position i of each refers to the same instruction.
    all_ids, all_raw_prompt, prompts = [], [], []
    for instruction_id, entries in grouped_results.items():
        raw_prompt = entries[0].get("raw_prompt", [])
        all_ids.append(instruction_id)
        all_raw_prompt.append(raw_prompt)
        dialogue = "".join(
            "-" + turn['role'] + ": " + turn['content'] + "\n"
            for turn in raw_prompt
        )
        templated = tokenizer.apply_chat_template(
            create_judge_prompt(dialogue),
            tokenize=False,
            add_generation_prompt=True,
        )
        prompts.append(templated)

    outputs = llm.generate(prompts, sampling_params)
    # Only the first completion of each request is kept.
    all_generated_text = [
        {
            "id": all_ids[pos],
            "raw_prompt": all_raw_prompt[pos],
            "judge": result.outputs[0].text,
        }
        for pos, result in enumerate(outputs)
    ]

    result_path = os.path.join(args.output_dir, args.model + "-judge_prompt_quality.json")
    with open(result_path, "w", encoding="utf-8") as file:
        json.dump(all_generated_text, file, indent=4, ensure_ascii=False)


# Rating labels requested by the judge prompt, mapped to a 1-5 score.
_QUALITY_LABEL_SCORES = (
    ("very poor", 1),
    ("poor", 2),
    ("average", 3),
    ("good", 4),
    ("excellent", 5),
)

# Lower-cased markers that may precede the verdict. "dialog quality" is the
# format requested by create_judge_prompt; "input_quality" is the marker the
# previous implementation split on and is kept for backward compatibility.
_VERDICT_MARKERS = ("dialog quality", "dialogue quality", "input_quality")


def _parse_quality_label(judge_text):
    """Map one judge response to a score in 1..5, or -1 if no label is found."""
    text = judge_text.lower()

    # Keep only what follows the LAST verdict marker, when one is present.
    cut = -1
    for marker in _VERDICT_MARKERS:
        pos = text.rfind(marker)
        if pos != -1:
            cut = max(cut, pos + len(marker))
    if cut != -1:
        text = text[cut:]

    # The earliest label wins, so words in a trailing explanation (e.g.
    # "... nothing about it is poor") cannot override the actual verdict.
    # "very poor" always starts before the "poor" it contains, so it is
    # never mis-read as "poor".
    best_pos, best_score = len(text), -1
    for label, score in _QUALITY_LABEL_SCORES:
        pos = text.find(label)
        if pos != -1 and pos < best_pos:
            best_pos, best_score = pos, score
    return best_score


def parse_prompt_quality_judgement(
    llama_judge_path='/path/to/llama/judge/result/text',
    qwen_judge_path='/path/to/qwen/judge/result/text',
    output_path='path/to/parsed/judge/results',
):
    """Turn two judges' raw verdict texts into one averaged score per instruction.

    Both input files are JSON lists of records with at least ``"id"`` and
    ``"judge"`` keys, aligned position by position. Each verdict is parsed
    case-insensitively (the prompt asks for capitalised labels such as
    "Good"), mapped to 1-5, and the two scores are averaged. If either verdict
    cannot be parsed the instruction gets the sentinel score ``-1``.

    Args:
        llama_judge_path: JSON file with the first judge's results.
        qwen_judge_path: JSON file with the second judge's results.
        output_path: Destination of the pickled ``{id: score}`` dict.

    Returns:
        The ``{id: score}`` dict that was pickled to ``output_path``.

    Raises:
        ValueError: If the two files differ in length or their ids are not
            aligned at some position.
    """
    with open(llama_judge_path, 'r', encoding='utf-8') as file:
        data1 = json.load(file)
    with open(qwen_judge_path, 'r', encoding='utf-8') as file:
        data2 = json.load(file)

    # Validate alignment up front with a real exception: an ``assert`` would
    # be stripped under ``python -O`` and silently pair unrelated verdicts.
    if len(data1) != len(data2):
        raise ValueError(
            f"Judge result files differ in length: {len(data1)} vs {len(data2)}"
        )

    score_dict = {}
    for position, (record1, record2) in enumerate(zip(data1, data2)):
        if record1["id"] != record2["id"]:
            raise ValueError(
                f"Judge results are misaligned at index {position}: "
                f"{record1['id']!r} != {record2['id']!r}"
            )

        score1 = _parse_quality_label(record1["judge"])
        score2 = _parse_quality_label(record2["judge"])
        if score1 == -1 or score2 == -1:
            score = -1
        else:
            score = (score1 + score2) / 2
        score_dict[record1["id"]] = score

    with open(output_path, 'wb') as file:
        pickle.dump(score_dict, file)
    return score_dict


def _sample_pairs(pairs, k, label):
    """Draw ``k`` items from ``pairs`` without replacement, in sampled order."""
    if k > len(pairs):
        # random.sample would raise a ValueError here too, but without saying
        # which bucket was short.
        raise ValueError(
            f"Cannot sample {k} preference pairs from bucket '{label}': "
            f"only {len(pairs)} available"
        )
    return [pairs[i] for i in random.sample(range(len(pairs)), k)]


def quality_split():
    """Build low/mid/high instruction-quality preference datasets of equal shape.

    Steps:
      1. Load the pickled ``{instruction_id: quality_score}`` dict and the
         source JSON; skip entries whose ``model_a`` contains
         "Llama-3.1-Tulu-3-8B-SFT" and group the rest by ``instruction_id``.
      2. Classify each instruction by its quality score: ``< 4`` is "low",
         ``== 4`` is "mid", anything higher is "high"; ``-1`` (unparsed) is
         skipped.
      3. For every pair of responses to the same instruction with different
         ``score``, the higher-scored one is "chosen" and the other
         "rejected". Pairs are bucketed by score margin (1, 2 or 3); pairs
         with any other margin are dropped.
      4. Sample a fixed number of pairs per margin from every class so all
         three subsets share the same margin distribution, shuffle each
         subset, and save it with ``Dataset.save_to_disk``.

    The order of ``random`` calls is: the three margins of "low", then "mid",
    then "high", followed by one shuffle per class in the same class order.

    Raises:
        KeyError: If a grouped instruction id is missing from the score dict.
        ValueError: If a (class, margin) bucket holds fewer pairs than its
            quota.
    """
    quality_classes = ("low", "mid", "high")
    # Pairs kept per score margin, identical for every quality class.
    # NOTE(review): presumably the smallest bucket size per margin observed
    # on the original data - confirm before reusing on another dataset.
    margin_quota = {1: 9884, 2: 6395, 3: 2853}
    save_paths = {
        "low": '/PATH/to/low/quality/instruction/dataset',
        "mid": '/PATH/to/mid/quality/instruction/dataset',
        "high": '/PATH/to/high/quality/instruction/dataset',
    }

    with open('path/to/parsed/judge/results', 'rb') as file:
        score_dict = pickle.load(file)

    grouped_results = {}
    source_file = "/PATH/to/sharegpt/with/generated/resposes"
    with open(source_file, 'r', encoding='utf-8') as file:
        data = json.load(file)  # 79997
        # entry keys used: instruction_id raw_prompt model_a response_a score
        for entry in data:
            if "Llama-3.1-Tulu-3-8B-SFT" not in entry['model_a']:
                grouped_results.setdefault(entry['instruction_id'], []).append(entry)

    # buckets[quality_class][margin] -> list of (chosen_chat, rejected_chat)
    buckets = {
        quality_class: {margin: [] for margin in margin_quota}
        for quality_class in quality_classes
    }

    for instruction_id, entry_list in grouped_results.items():
        raw_prompt = entry_list[0].get("raw_prompt", [])

        quality = score_dict[instruction_id]
        if quality == -1:
            continue
        if quality < 4:
            quality_class = "low"
        elif quality == 4:
            quality_class = "mid"
        else:
            quality_class = "high"

        for item1, item2 in combinations(entry_list, 2):
            if item1['score'] == item2['score']:
                continue
            if item1['score'] > item2['score']:
                chosen, rejected = item1['response_a'], item2['response_a']
            else:
                chosen, rejected = item2['response_a'], item1['response_a']
            chosen_chat = raw_prompt + [{"role": "assistant", "content": chosen}]
            rejected_chat = raw_prompt + [{"role": "assistant", "content": rejected}]

            # Margins other than 1, 2 or 3 have no bucket and are dropped.
            bucket = buckets[quality_class].get(abs(item1['score'] - item2['score']))
            if bucket is not None:
                bucket.append((chosen_chat, rejected_chat))

    # Keep the margin distribution of all subsets the same: draw the same
    # number of pairs per margin from every quality class.
    balanced = {}
    for quality_class in quality_classes:
        kept = []
        for margin, quota in margin_quota.items():
            kept.extend(_sample_pairs(
                buckets[quality_class][margin],
                quota,
                f"{quality_class}/margin={margin}",
            ))
        balanced[quality_class] = kept

    # Shuffle each subset (sample every element) so margins are interleaved.
    subset_size = sum(margin_quota.values())  # 19132
    for quality_class in quality_classes:
        balanced[quality_class] = _sample_pairs(
            balanced[quality_class], subset_size, quality_class
        )

    #### save dataset for openrlhf
    for quality_class in quality_classes:
        pairs = balanced[quality_class]
        processed_samples = {
            "chosen": [chosen_chat for chosen_chat, _ in pairs],
            "rejected": [rejected_chat for _, rejected_chat in pairs],
        }
        dataset = Dataset.from_dict(processed_samples)
        dataset.save_to_disk(save_paths[quality_class])

if __name__ == "__main__":
    # Run the three pipeline stages back to back: judge generation, verdict
    # parsing, then dataset construction.
    # NOTE(review): generate_quality_judge handles a single --model per run,
    # while parse_prompt_quality_judgement reads both a Llama and a Qwen
    # result file - presumably the script is launched once per judge model
    # before the later stages are meaningful; confirm.
    for stage in (generate_quality_judge, parse_prompt_quality_judgement, quality_split):
        stage()