import os, json
from itertools import combinations
from datasets import Dataset, load_from_disk
import random
random.seed(42)
import numpy as np

'''
    Split a preference dataset into high/mid/low subsets according to the score of the chosen response; the margin distribution is controlled to be the same across all subsets.
'''

def _load_grouped_entries(source_file, excluded_model):
    """Load scored responses from ``source_file`` and group them by instruction.

    Each entry is read for its ``instruction_id``, ``raw_prompt``, ``model_a``,
    ``response_a`` and ``score`` keys. Entries whose ``model_a`` contains
    ``excluded_model`` as a substring are dropped; pass ``None`` to keep all.

    Returns:
        dict mapping ``instruction_id`` -> list of entries, in file order.
    """
    with open(source_file, 'r', encoding='utf-8') as file:
        data = json.load(file)

    grouped = {}
    for entry in data:
        if excluded_model is not None and excluded_model in entry['model_a']:
            continue
        grouped.setdefault(entry['instruction_id'], []).append(entry)
    return grouped


def _score_tier(score):
    """Map a chosen-response score to its subset name.

    score <= 7 -> "low", 7 < score <= 8 -> "mid", otherwise "high".
    """
    if score <= 7:
        return "low"
    if score <= 8:
        return "mid"
    return "high"


def _collect_pairs_by_margin(grouped, margins):
    """Build aligned (chosen_chat, rejected_chat) pairs per margin and tier.

    Every two responses to the same instruction with different scores form a
    pair; the higher-scored response is "chosen". Pairs whose absolute score
    difference is not one of ``margins`` are discarded. The tier of a pair is
    decided by the chosen (higher) score.

    Returns:
        ``{margin: {"high": [...], "mid": [...], "low": [...]}}`` where each
        list holds ``(chosen_chat, rejected_chat)`` tuples.
    """
    # Group by exact chosen score first, then flatten into tiers. The resulting
    # order (scores in first-seen order, then encounter order) decides which
    # pairs random.sample picks for a given seed, so it is kept identical to
    # the original two-stage grouping.
    by_top_score = {margin: {} for margin in margins}
    for entry_list in grouped.values():
        raw_prompt = entry_list[0].get("raw_prompt", [])
        for item1, item2 in combinations(entry_list, 2):
            score1, score2 = item1['score'], item2['score']
            if score1 == score2:
                continue
            margin = abs(score1 - score2)
            if margin not in by_top_score:
                continue
            if score1 > score2:
                chosen, rejected = item1['response_a'], item2['response_a']
            else:
                chosen, rejected = item2['response_a'], item1['response_a']
            chosen_chat = raw_prompt + [{"role": "assistant", "content": chosen}]
            rejected_chat = raw_prompt + [{"role": "assistant", "content": rejected}]
            # The chosen score decides the subset (margin is controlled separately).
            top_score = max(score1, score2)
            by_top_score[margin].setdefault(top_score, []).append((chosen_chat, rejected_chat))

    pairs_by_margin = {}
    for margin, score_groups in by_top_score.items():
        tiers = {"high": [], "mid": [], "low": []}
        for top_score, group in score_groups.items():
            tiers[_score_tier(top_score)].extend(group)
        pairs_by_margin[margin] = tiers
    return pairs_by_margin


def _sample_pairs(pairs, sample_size, label):
    """Draw ``sample_size`` distinct pairs using ``random.sample`` on indices.

    Raises:
        ValueError: if fewer than ``sample_size`` pairs are available.
    """
    if sample_size > len(pairs):
        raise ValueError(
            f"{label}: requested {sample_size} pairs but only {len(pairs)} are available"
        )
    sample_idx = random.sample(range(len(pairs)), sample_size)
    return [pairs[idx] for idx in sample_idx]


def highlow_score_dataset_margin_cons(
    source_file="/PATH/to/sharegpt/with/generated/resposes",
    output_dirs=None,
    margin_sample_sizes=None,
    excluded_model="Llama-3.1-Tulu-3-8B-SFT",
    save=True,
    seed=None,
):
    """Split scored responses into high/mid/low-score preference datasets.

    Pairs are bucketed by the chosen response's score (<= 7 low, <= 8 mid,
    otherwise high). For every margin the same number of pairs is sampled
    into each subset, so all three subsets share one margin distribution.

    Args:
        source_file: JSON file holding a list of scored responses.
        output_dirs: ``{"high": path, "mid": path, "low": path}`` passed to
            ``Dataset.save_to_disk``; defaults to the original placeholder paths.
        margin_sample_sizes: ``{margin: pairs per subset}``, sampled in this
            order; defaults to ``{1: 12851, 2: 4844, 3: 1437}``. Margins not
            listed are ignored.
        excluded_model: responses whose ``model_a`` contains this substring are
            dropped; ``None`` keeps every response.
        save: when true, write each subset with ``Dataset.save_to_disk``.
        seed: if given, reseed ``random`` before sampling; otherwise the
            current global ``random`` state is used.

    Returns:
        ``{"high"|"mid"|"low": {"chosen": [...], "rejected": [...]}}``.

    Raises:
        ValueError: if some subset has fewer pairs for a margin than requested.
    """
    if output_dirs is None:
        output_dirs = {
            "high": '/PATH/to/high/score/preference/dataset',
            "mid": '/PATH/to/mid/score/preference/dataset',
            "low": '/PATH/to/low/score/preference/dataset',
        }
    if margin_sample_sizes is None:
        margin_sample_sizes = {1: 12851, 2: 4844, 3: 1437}
    if seed is not None:
        random.seed(seed)

    grouped_results = _load_grouped_entries(source_file, excluded_model)
    pairs_by_margin = _collect_pairs_by_margin(grouped_results, margin_sample_sizes)

    tiers = ("high", "mid", "low")
    subsets = {tier: {"chosen": [], "rejected": []} for tier in tiers}
    # Sampling order (margin-major, then high/mid/low) matches the original
    # script so a given seed reproduces the same subsets.
    for margin, sample_size in margin_sample_sizes.items():
        for tier in tiers:
            sampled = _sample_pairs(
                pairs_by_margin[margin][tier], sample_size, f"margin {margin} / {tier}"
            )
            subsets[tier]["chosen"].extend(chosen for chosen, _ in sampled)
            subsets[tier]["rejected"].extend(rejected for _, rejected in sampled)

    print(*(len(subsets[tier]["chosen"]) for tier in tiers))

    if save:
        # save pref. data for openrlhf
        for tier in tiers:
            dataset = Dataset.from_dict(subsets[tier])
            dataset.save_to_disk(output_dirs[tier])

    return subsets

if __name__ == "__main__":
    # Script entry point: build and save the high/mid/low preference subsets.
    highlow_score_dataset_margin_cons()