import json
import pdb
import pickle
import numpy as np
from datasets import Dataset
from re import M
from datasets import load_dataset, DatasetDict, concatenate_datasets
import hashlib

if __name__=='__main__':
    
    # Load the reward-annotated generations for the train and test splits.
    with open('outputs/mistral-7b-instruct-v0.2-on-policy-clean-train-data_rewards.json', 'r') as train_file:
        train_data = json.load(train_file)
    with open('outputs/mistral-7b-instruct-v0.2-on-policy-clean-test-data_rewards.json', 'r') as test_file:
        test_data = json.load(test_file)

    # Keep the raw records under separate names; train_data / test_data are
    # rebound further down.
    fin_train_data, fin_test_data = train_data, test_data

    # Pool every reward score from both splits so that a single global
    # minimum and maximum can be computed across the whole corpus.
    list_new_rm_score = [
        score
        for split in (fin_train_data, fin_test_data)
        for record in split
        for score in record['all_reward_scores']
    ]

    min_rm_score = np.min(list_new_rm_score)
    max_rm_score = np.max(list_new_rm_score)
    
    def build_preference_item(record, min_score, max_score):
        """Turn one scored record into a top-2 / bottom-2 preference item.

        The record's responses are ranked by reward score. ``A0`` and ``A1``
        are the two highest-scored responses, ``A2`` and ``A3`` the
        second-lowest and lowest. Each ``A*`` value is a two-turn chat
        (user prompt, assistant response) and each ``score_A*`` is the
        response's reward min-max normalized with ``min_score`` and
        ``max_score``.

        Args:
            record: mapping with 'prompt', 'all_generated_responses' and
                'all_reward_scores' entries.
            min_score: lower bound used for normalization.
            max_score: upper bound used for normalization.

        Returns:
            dict with keys 'prompt', 'A0'..'A3', 'score_A0'..'score_A3' and
            'prompt_id' (SHA-256 hex digest of the UTF-8 encoded prompt).

        Raises:
            IndexError: if the record has fewer than two scored responses.
                With two or three responses the selected slots overlap.
        """
        prompt = record['prompt']
        responses = record['all_generated_responses']
        scores = record['all_reward_scores']

        # argsort is ascending; reverse it to get best-first order.
        descending_indices = np.array(scores).argsort()[::-1]
        selected = (
            descending_indices[0],
            descending_indices[1],
            descending_indices[-2],
            descending_indices[-1],
        )

        # Loop-invariant denominator of the min-max normalization.
        score_range = max_score - min_score

        item = {'prompt': prompt}
        for slot, index in enumerate(selected):
            item[f'A{slot}'] = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": responses[index]},
            ]
            item[f'score_A{slot}'] = (scores[index] - min_score) / score_range

        item["prompt_id"] = hashlib.sha256(prompt.encode("utf-8")).hexdigest()
        return item

    # Both splits go through the same helper and share the global score range.
    train_data = [
        build_preference_item(record, min_rm_score, max_rm_score)
        for record in fin_train_data
    ]
    test_data = [
        build_preference_item(record, min_rm_score, max_rm_score)
        for record in fin_test_data
    ]
    
    # Columns exported to the datasets, in the order they should appear.
    dataset_columns = (
        "prompt_id",
        "prompt",
        "A0",
        "A1",
        "A2",
        "A3",
        "score_A0",
        "score_A1",
        "score_A2",
        "score_A3",
    )

    # Pivot the per-example dicts into column-oriented dicts for from_dict.
    train_dataset = Dataset.from_dict(
        {column: [row[column] for row in train_data] for column in dataset_columns}
    )

    test_dataset = Dataset.from_dict(
        {column: [row[column] for row in test_data] for column in dataset_columns}
    )
    
    def filter_dataset(example, max_token_gap=100):
        """Return True when A0's completion length is close to A1, A2 and A3.

        An example is rejected (False) when ``completion_token_A0`` differs
        from any of ``completion_token_A1``, ``completion_token_A2`` or
        ``completion_token_A3`` by ``max_token_gap`` or more.

        If any of the four ``completion_token_*`` entries is absent, the
        length check cannot be evaluated. Rather than failing with a
        KeyError on the first row, the example is kept and a RuntimeWarning
        is issued so the skipped filter is visible.

        Args:
            example: mapping of column name to value for one row.
            max_token_gap: smallest absolute length difference that causes
                the example to be dropped. Defaults to 100.

        Returns:
            bool: True to keep the example, False to drop it.
        """
        import warnings

        length_columns = [f'completion_token_A{rank}' for rank in range(4)]
        missing = [column for column in length_columns if column not in example]
        if missing:
            # NOTE(review): without these columns no length filtering
            # happens; add the token counts to the rows to enable this filter.
            warnings.warn(
                f"length filter skipped, missing columns: {missing}",
                RuntimeWarning,
            )
            return True

        reference, *others = (example[column] for column in length_columns)
        return all(abs(reference - other) < max_token_gap for other in others)
    
    # Apply the length filter to each split independently.
    filtered_train_dataset = train_dataset.filter(filter_dataset)
    filtered_test_dataset = test_dataset.filter(filter_dataset)

    # Report how many examples survived in each split.
    print(len(filtered_train_dataset), len(filtered_test_dataset))

    # Bundle the two splits under the names used on the Hub.
    all_ds = DatasetDict(
        {
            "train_prefs": filtered_train_dataset,
            "test_prefs": filtered_test_dataset,
        }
    )

    # remove empty last turns
    def filter_empty_messages(example):
        """Drop a trailing user turn from each of the A0..A3 conversations.

        For every conversation column, if the last message has role
        "user" it is removed, so the conversation does not end on a user
        turn. The example is updated in place and returned.
        """
        for column in ("A0", "A1", "A2", "A3"):
            conversation = example[column]
            if conversation[-1]["role"] == "user":
                example[column] = conversation[:-1]
        return example


    # Clean every split with filter_empty_messages, then upload the result.
    all_ds = all_ds.map(filter_empty_messages)

    hub_repo_id = "Ultrafeedback-mistral-7b-instruct-v0.2-top2vsbottom2-selection"
    all_ds.push_to_hub(hub_repo_id)
        
        
        
            