import os, json
from itertools import combinations
from datasets import Dataset, load_from_disk
import random
random.seed(42)
# from transformers import AutoTokenizer
import numpy as np
import statistics

'''
    Split the preference dataset into low/mid/high subsets according to the variance of the scores of the five responses generated for each instruction. The score-margin distribution is kept identical across all subsets.
'''

# Variance classes, in the order their subsets are sampled and saved.
_VAR_CLASSES = ("low", "mid", "high")

# Entries whose model_a is exactly this name are kept; entries from any other
# model whose name merely contains it are dropped (the kept data is described
# as "off-policy" in this script).
_SFT_MODEL = "Llama-3.1-Tulu-3-8B-SFT"


def _load_off_policy_groups(source_file):
    """Load scored responses from JSON and group the kept ones by instruction.

    Each entry is expected to carry ``instruction_id``, ``raw_prompt``,
    ``model_a``, ``response_a`` and ``score``.

    Returns:
        (grouped, kept_count): ``grouped`` maps ``instruction_id`` to its kept
        entries in file order; ``kept_count`` is the total number kept.
    """
    with open(source_file, 'r', encoding='utf-8') as file:
        data = json.load(file)

    grouped = {}
    kept_count = 0
    for entry in data:
        model = entry['model_a']
        if model == _SFT_MODEL or _SFT_MODEL not in model:
            kept_count += 1
            grouped.setdefault(entry['instruction_id'], []).append(entry)
    return grouped, kept_count


def _variance_class(var):
    """Map a score variance to "low", "mid" or "high".

    The bucket index is ``int(var // 0.1)`` evaluated in floating point, so the
    cut-offs are only approximately 1.5 and 3.0 (e.g. ``var == 1.55`` is still
    "low"). The rule is kept unchanged so existing splits are reproduced.
    """
    bucket = int(var // 0.1)
    if bucket <= 15:
        return "low"
    if bucket <= 30:
        return "mid"
    return "high"


def _collect_pairs(grouped, margins):
    """Build (chosen, rejected) chat pairs bucketed by variance class and margin.

    Only instructions with exactly five responses are used. Every pair of their
    responses whose absolute score difference is one of ``margins`` yields one
    pair, with the higher-scored response as "chosen". Each side is the
    instruction's ``raw_prompt`` messages followed by one assistant turn.

    Returns:
        ``{var_class: {margin: [(chosen_chat, rejected_chat), ...]}}`` with pairs
        in encounter order.
    """
    pairs = {var_class: {margin: [] for margin in margins}
             for var_class in _VAR_CLASSES}
    for entry_list in grouped.values():
        if len(entry_list) != 5:
            continue
        raw_prompt = entry_list[0].get("raw_prompt", [])
        scores = [item['score'] for item in entry_list]
        by_margin = pairs[_variance_class(statistics.variance(scores))]

        for item1, item2 in combinations(entry_list, 2):
            margin = abs(item1['score'] - item2['score'])
            if margin not in by_margin:
                continue  # tie, or a margin nobody asked for
            if item1['score'] > item2['score']:
                better, worse = item1, item2
            else:
                better, worse = item2, item1
            by_margin[margin].append((
                raw_prompt + [{"role": "assistant", "content": better['response_a']}],
                raw_prompt + [{"role": "assistant", "content": worse['response_a']}],
            ))
    return pairs


def _sample_pairs(pairs, k, label):
    """Draw ``k`` pairs without replacement using ``random.sample`` on indices.

    Raises:
        ValueError: if fewer than ``k`` pairs are available; ``label`` names the
            offending subset in the message.
    """
    if k > len(pairs):
        raise ValueError(
            f"cannot sample {k} pairs for {label}: only {len(pairs)} available")
    return [pairs[idx] for idx in random.sample(range(len(pairs)), k)]


def var_split(*,
              source_file="/PATH/to/sharegpt/with/generated/resposes",
              output_dirs=None,
              margin_sample_sizes=None,
              save=True):
    """Split preference pairs into low/mid/high score-variance subsets.

    Every subset receives the same number of pairs per score margin, so the
    margin distribution is identical across subsets.

    Args:
        source_file: JSON list of scored responses.
        output_dirs: ``{"low"|"mid"|"high": path}`` for ``save_to_disk``;
            defaults to the placeholder paths below.
        margin_sample_sizes: ``{margin: k}`` pairs to draw per margin for each
            subset; defaults to ``{1: 10290, 2: 6609, 3: 2233}``.
        save: when False, nothing is written (dry run).

    Returns:
        ``{var_class: {"chosen": [...], "rejected": [...]}}``, each subset shuffled.

    Raises:
        ValueError: if some (variance class, margin) bucket has fewer pairs than
            requested. Raised before anything is saved.
    """
    if output_dirs is None:
        output_dirs = {
            "low": '/PATH/to/low/variance/dataset',
            "mid": '/PATH/to/mid/variance/dataset',
            "high": '/PATH/to/high/variance/dataset',
        }
    if margin_sample_sizes is None:
        margin_sample_sizes = {1: 10290, 2: 6609, 3: 2233}

    grouped, kept_count = _load_off_policy_groups(source_file)
    print(kept_count)

    pairs = _collect_pairs(grouped, margin_sample_sizes)

    # Draw order is fixed (low -> mid -> high, margins in dict order) so that,
    # with the module-level random seed, the selection is reproducible.
    subsets = {}
    for var_class in _VAR_CLASSES:
        selected = []
        for margin, k in margin_sample_sizes.items():
            selected.extend(_sample_pairs(
                pairs[var_class][margin], k,
                f"{var_class} variance, margin {margin}"))
        subsets[var_class] = selected

    # A full-size sample shuffles each subset so margins are interleaved.
    for var_class in _VAR_CLASSES:
        selected = subsets[var_class]
        order = random.sample(range(len(selected)), len(selected))
        subsets[var_class] = [selected[idx] for idx in order]

    result = {}
    for var_class in _VAR_CLASSES:
        chosen = [c for c, _ in subsets[var_class]]
        rejected = [r for _, r in subsets[var_class]]
        result[var_class] = {"chosen": chosen, "rejected": rejected}
        if save:
            # Saved for OpenRLHF: "chosen"/"rejected" columns of chat message lists.
            dataset = Dataset.from_dict({"chosen": chosen, "rejected": rejected})
            dataset.save_to_disk(output_dirs[var_class])
    return result


if __name__ == "__main__":
    # Script entry point: build and save the variance-based splits.
    var_split()