import os, json
from itertools import combinations
from datasets import Dataset, load_from_disk
import random
random.seed(42)
import numpy as np

'''
    Split a preference dataset into subsets according to the score margin
    between the chosen and rejected responses of each pair.
'''

# Bucket keys in the order they are sampled and saved: margin in [1, 2),
# margin in [2, 4) and margin >= 4.
_BUCKET_KEYS = ("1", "23", ">4")


def _margin_bucket(margin):
    """Map a non-zero absolute score margin to its bucket key.

    Returns ">4" for margins of at least 4, "23" for margins in [2, 4) and
    "1" for margins in [1, 2).

    Raises:
        ValueError: if the margin is below 1 (e.g. fractional scores) or is
            NaN, because such a pair fits none of the buckets.
    """
    if margin >= 4:
        return ">4"
    if margin >= 2:
        return "23"
    if margin >= 1:
        return "1"
    raise ValueError(
        f"score margin {margin!r} does not fit any bucket; "
        "expected a margin of at least 1"
    )


def margin_dataset_keep_prompt_divide_score(
    *,
    source_file="/PATH/to/sharegpt/with/generated/resposes",
    data_size=19132,
    output_dirs=None,
):
    """Build three equally sized preference datasets split by score margin.

    The JSON file at ``source_file`` is expected to hold a list of entries
    with the keys ``instruction_id``, ``raw_prompt``, ``model_a``,
    ``response_a`` and ``score``. Entries whose ``model_a`` contains
    "Llama-3.1-Tulu-3-8B-SFT" are dropped, and the rest are grouped by
    ``instruction_id``. Every pair of responses within a group becomes one
    preference sample: the higher-scored response is "chosen", the other is
    "rejected", and equal scores are skipped as ties. Each sample is the
    group's ``raw_prompt`` (taken from the group's first entry) followed by
    one assistant message.

    Samples are bucketed by absolute score margin ("1", "23", ">4"). Each
    bucket is then down-sampled to ``data_size`` pairs with the module-level
    ``random`` generator and saved with ``Dataset.save_to_disk``. The bucket
    counts and the tie count are printed before sampling.

    Args:
        source_file: Path of the JSON file with scored responses.
        data_size: Number of pairs to keep in every bucket.
        output_dirs: Optional mapping from bucket key ("1", "23", ">4") to
            the directory the bucket is saved to. ``None`` selects the
            original hard-coded locations.

    Raises:
        ValueError: if ``data_size`` is negative, if ``output_dirs`` lacks a
            bucket key, if a pair has a non-zero margin below 1, or if any
            bucket holds fewer than ``data_size`` pairs.
    """
    if output_dirs is None:
        output_dirs = {
            "1": "/PATH/to/margin/1/preference/dataset",
            "23": "/PATH/to/margin/23/preference/dataset",
            ">4": "/PATH/to/margin/over4/preference/dataset",
        }
    missing = [key for key in _BUCKET_KEYS if key not in output_dirs]
    if missing:
        raise ValueError(f"output_dirs is missing bucket keys: {missing}")
    if data_size < 0:
        raise ValueError(f"data_size must be non-negative, got {data_size}")

    with open(source_file, "r", encoding="utf-8") as file:
        data = json.load(file)  # original author's note: 79997 entries

    # Group the kept responses by prompt; dict insertion order fixes the
    # order in which pairs are generated below.
    grouped_results = {}
    for entry in data:
        if "Llama-3.1-Tulu-3-8B-SFT" in entry["model_a"]:
            continue
        grouped_results.setdefault(entry["instruction_id"], []).append(entry)

    buckets = {key: {"chosen": [], "rejected": []} for key in _BUCKET_KEYS}
    tie_count = 0

    for entry_list in grouped_results.values():
        # Presumably a list of chat messages; it must support `+ [message]`.
        raw_prompt = entry_list[0].get("raw_prompt", [])

        for item1, item2 in combinations(entry_list, 2):
            margin = abs(item1["score"] - item2["score"])
            if margin == 0:
                tie_count += 1
                continue

            if item1["score"] > item2["score"]:
                chosen, rejected = item1["response_a"], item2["response_a"]
            else:
                chosen, rejected = item2["response_a"], item1["response_a"]

            bucket = buckets[_margin_bucket(margin)]
            bucket["chosen"].append(
                raw_prompt + [{"role": "assistant", "content": chosen}]
            )
            bucket["rejected"].append(
                raw_prompt + [{"role": "assistant", "content": rejected}]
            )

    print(
        len(buckets["1"]["chosen"]),
        len(buckets["23"]["chosen"]),
        len(buckets[">4"]["chosen"]),
        tie_count,
    )

    # Check every bucket before drawing any sample, so an undersized bucket
    # is reported by name instead of surfacing as random.sample's generic
    # "Sample larger than population" error halfway through.
    for key in _BUCKET_KEYS:
        available = len(buckets[key]["chosen"])
        if available < data_size:
            raise ValueError(
                f"margin bucket {key!r} has only {available} pairs; "
                f"cannot sample data_size={data_size}"
            )

    # Sample in the fixed order "1", "23", ">4" so a seeded run draws the
    # same indices as before.
    for key in _BUCKET_KEYS:
        bucket = buckets[key]
        sample_idx = random.sample(range(len(bucket["chosen"])), data_size)
        bucket["chosen"] = [bucket["chosen"][idx] for idx in sample_idx]
        bucket["rejected"] = [bucket["rejected"][idx] for idx in sample_idx]

    for key in _BUCKET_KEYS:
        processed_samples = {
            "chosen": buckets[key]["chosen"],
            "rejected": buckets[key]["rejected"],
        }
        dataset = Dataset.from_dict(processed_samples)
        dataset.save_to_disk(output_dirs[key])

# Run the dataset split only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == "__main__":
    margin_dataset_keep_prompt_divide_score()