import os, json
from itertools import combinations
from datasets import Dataset, load_from_disk
import random
random.seed(42)
import numpy as np

'''
    Hybrid on-policy and off-policy responses for a preference dataset. The margin distribution is controlled to be the same across all subsets.
'''

# Load every scored response and split the entries by generating model:
# entries whose 'model_a' contains "Llama-3.1-Tulu-3-8B-SFT" are on-policy,
# all others are off-policy. Each side is grouped by 'instruction_id'.
source_file = "/PATH/to/sharegpt/with/generated/resposes"
data_count = data_count_SFT = 0
grouped_results = {}        # instruction_id -> list of off-policy entries
grouped_results_SFT = {}    # instruction_id -> list of on-policy entries

with open(source_file, 'r', encoding='utf-8') as fh:
    data = json.load(fh)  # 79997

# Fields read from each entry: instruction_id, model_a
# (entries also carry raw_prompt, response_a and score).
for record in data:
    if "Llama-3.1-Tulu-3-8B-SFT" in record['model_a']:
        data_count_SFT += 1
        grouped_results_SFT.setdefault(record['instruction_id'], []).append(record)
    else:
        data_count += 1
        grouped_results.setdefault(record['instruction_id'], []).append(record)

# Mixing level used when building preference pairs. The 4 levels of mixing:
#     pure           : Pure-On-Policy
#     midmix         : Mid-Mix
#     lowmix         : Low-Mix
#     pure_offpolicy : Pure-Off-Policy
# NOTE(review): "impure" is not one of the four levels listed above --
# confirm the intended value before running.
setup = "impure"

# Number of pairs to keep for score margins of 1, 2 and 3 respectively.
margin_1_num, margin_2_num, margin_3_num = 7132, 7000, 5000

_SETUP_MODES = ("pure", "midmix", "lowmix", "pure_offpolicy")


def _make_preference_pair(raw_prompt, item_a, item_b):
    """Turn two scored responses into a (margin, chosen_chat, rejected_chat) triple.

    Args:
        raw_prompt: list of chat messages that both responses answer.
        item_a, item_b: entries providing 'score' and 'response_a'.

    Returns:
        None when the absolute score difference is below 0.1 (a tie);
        otherwise the absolute score difference together with the chosen and
        rejected conversations (prompt plus one assistant message each). The
        higher-scored response is the chosen one.
    """
    margin = abs(item_a['score'] - item_b['score'])
    if margin < 0.1:
        return None
    if item_a['score'] > item_b['score']:
        better, worse = item_a, item_b
    else:
        better, worse = item_b, item_a
    chosen_chat = raw_prompt + [{"role": "assistant", "content": better['response_a']}]
    rejected_chat = raw_prompt + [{"role": "assistant", "content": worse['response_a']}]
    return margin, chosen_chat, rejected_chat


def _iter_candidate_pairs(mode, off_policy, on_policy):
    """Yield (raw_prompt, item_a, item_b) candidates for the given mixing mode.

    Args:
        mode: one of "pure", "midmix", "lowmix", "pure_offpolicy".
        off_policy: dict mapping instruction id -> list of off-policy entries.
        on_policy: dict mapping instruction id -> list of on-policy entries.

    Pairing rules:
        pure           : every 2-combination of a prompt's on-policy entries.
        pure_offpolicy : every 2-combination of a prompt's off-policy entries.
        midmix         : each off-policy entry against one randomly drawn
                         on-policy entry for the same prompt.
        lowmix         : every 2-combination of the first three off-policy
                         entries plus one randomly drawn on-policy entry.

    Raises:
        KeyError: in the mixed modes, when an off-policy prompt has no entry
            in ``on_policy``.
    """
    if mode in ("pure", "pure_offpolicy"):
        groups = on_policy if mode == "pure" else off_policy
        for entry_list in groups.values():
            raw_prompt = entry_list[0].get("raw_prompt", [])
            for item_a, item_b in combinations(entry_list, 2):
                yield raw_prompt, item_a, item_b
        return

    for instruction_id, entry_list in off_policy.items():
        raw_prompt = entry_list[0].get("raw_prompt", [])
        # Exactly one random draw per off-policy prompt, so the sequence of
        # random numbers consumed under a fixed seed stays stable.
        sft_candidates = on_policy[instruction_id]
        sft_entry = sft_candidates[random.randint(0, len(sft_candidates) - 1)]
        if mode == "midmix":
            for item in entry_list:
                yield raw_prompt, item, sft_entry
        else:  # lowmix
            for item_a, item_b in combinations(entry_list[:3] + [sft_entry], 2):
                yield raw_prompt, item_a, item_b


# Fail early with a clear message; otherwise an unrecognised setup would leave
# the margin buckets undefined and surface later as a NameError.
if setup not in _SETUP_MODES:
    raise ValueError(
        f"Unknown setup {setup!r}; expected one of {', '.join(_SETUP_MODES)}"
    )

chosen_list_1, rejected_list_1 = [], []
chosen_list_2, rejected_list_2 = [], []
chosen_list_3, rejected_list_3 = [], []
tie_count, count = 0, 0

# Score margin -> (chosen bucket, rejected bucket). Pairs whose margin is not
# exactly 1, 2 or 3 are counted but not stored.
_margin_buckets = {
    1: (chosen_list_1, rejected_list_1),
    2: (chosen_list_2, rejected_list_2),
    3: (chosen_list_3, rejected_list_3),
}

for raw_prompt, item_a, item_b in _iter_candidate_pairs(
    setup, grouped_results, grouped_results_SFT
):
    pair = _make_preference_pair(raw_prompt, item_a, item_b)
    if pair is None:
        tie_count += 1
        continue
    count += 1
    margin, chosen_chat, rejected_chat = pair
    bucket = _margin_buckets.get(margin)
    if bucket is not None:
        bucket[0].append(chosen_chat)
        bucket[1].append(rejected_chat)

# Report the number of non-tie pairs and ties for every setup.
print(count, tie_count)

## Keep the margin distribution the same: draw a fixed number of pairs from
## each margin bucket, then shuffle the combined result.
chosen_list, rejected_list = [], []
for margin, bucket_chosen, bucket_rejected, target in (
    (1, chosen_list_1, rejected_list_1, margin_1_num),
    (2, chosen_list_2, rejected_list_2, margin_2_num),
    (3, chosen_list_3, rejected_list_3, margin_3_num),
):
    # random.sample would raise an opaque "Sample larger than population"
    # error; say which bucket is short and by how much instead.
    if len(bucket_chosen) < target:
        raise ValueError(
            f"margin-{margin} bucket holds {len(bucket_chosen)} pairs "
            f"but {target} were requested"
        )
    # Sample indices so each chosen response stays with its rejected partner.
    for idx in random.sample(range(len(bucket_chosen)), target):
        chosen_list.append(bucket_chosen[idx])
        rejected_list.append(bucket_rejected[idx])

# Apply one shared random permutation to both lists to keep pairs aligned.
indices = list(range(len(chosen_list)))
random.shuffle(indices)
chosen_list[:] = [chosen_list[i] for i in indices]
rejected_list[:] = [rejected_list[i] for i in indices]

# sample_idx = random.sample(range(len(chosen_list)), data_size)
# chosen_list = [chosen_list[idx] for idx in sample_idx]
# rejected_list = [rejected_list[idx] for idx in sample_idx]
# print(len(chosen_list))

#### Save the chosen/rejected conversations as a dataset for openrlhf.
processed_samples = dict(chosen=chosen_list, rejected=rejected_list)
output_dir = f'/PATH/to/{setup}/mix/dataset'
dataset = Dataset.from_dict(processed_samples)
dataset.save_to_disk(output_dir)