from cProfile import label
from datasets import load_dataset
from datasets import Dataset
import itertools
import os
# Directory holding the benchmark JSON files. It is a relative path, so it
# resolves against the process's current working directory, not this file.
BASE_PATH='../dataset'

def load_benchmark_data(dataset_name: str, *, num_proc: int = 10) -> "Dataset":
    """Load a benchmark JSON file from ``BASE_PATH`` and reshape it into pairs.

    Every dataset is read with ``load_dataset("json", ..., split="train")``.
    Columns of the returned dataset, per ``dataset_name``:

    * ``'rm_bench'``: id, prompt, chosen, rejected, domain, difficulty. One
      row per (chosen, rejected) combination of each source record.
    * ``'reward_bench_v2'``: id, prompt, chosen, rejected, domain. The first
      entry of the ``chosen`` / ``rejected`` lists is used.
    * ``'reward_bench'``: the file's columns, returned exactly as loaded.
    * ``'chatbot_arena'``: id, prompt, chosen, rejected, domain. ``chosen``
      and ``rejected`` are ``None`` when the human label names neither
      response.
    * ``'PPE_HF'``: id, prompt, chosen, rejected.
    * ``'judgebench'``: id, prompt, chosen, rejected, domain, response_model.
    * ``'HelpSteer3'``: id, prompt, chosen, rejected, domain, split,
      difficulty.
    * ``'mixture'`` / ``'skywork'``: id, prompt, chosen, rejected. Rows are
      shuffled with ``seed=42`` before ids are assigned.

    Except for ``rm_bench`` and ``reward_bench_v2`` (which reuse the source
    ``id``) and ``reward_bench`` (untouched), ``id`` is the row index at the
    time of the transformation.

    Args:
        dataset_name: One of the names listed above (case-sensitive).
        num_proc: Number of worker processes passed to ``Dataset.map`` for
            the per-row transformations. Defaults to 10, the value that was
            previously hard-coded. The batched ``rm_bench`` expansion does
            not use it.

    Returns:
        The transformed dataset.

    Raises:
        ValueError: If ``dataset_name`` is not recognised, or if a
            ``PPE_HF`` / ``judgebench`` / ``HelpSteer3`` row carries a label
            that does not identify a winner.
    """
    pair_cols = ("id", "prompt", "chosen", "rejected")

    if dataset_name == 'rm_bench':
        ds = _load_json_split("rm_bench.json")
        # Flatten: the expansion replaces all original columns.
        ds = ds.map(_expand_rm_bench_batch, batched=True,
                    remove_columns=ds.column_names)

    elif dataset_name == 'reward_bench_v2':
        ds = _map_rows(_load_json_split("reward_bench_v2.json"),
                       _format_reward_bench_v2,
                       pair_cols + ("domain",),
                       with_indices=False, num_proc=num_proc)

    elif dataset_name == 'reward_bench':
        # NOTE(review): the previous code defined a per-row formatter here but
        # never applied it, so callers have always received the dataset as
        # loaded. That behaviour is kept; confirm whether the file already has
        # the id/prompt/chosen/rejected layout before adding a transformation.
        ds = _load_json_split("reward_bench.json")

    elif dataset_name == 'chatbot_arena':
        ds = _map_rows(_load_json_split("chatbot_arena.json"),
                       _format_chatbot_arena,
                       pair_cols + ("domain",),
                       with_indices=True, num_proc=num_proc)

    elif dataset_name == 'PPE_HF':
        ds = _load_json_split("PPE_HF.json")
        # NOTE(review): rows are filtered on "winner" while the pair order is
        # decided from "label" in the formatter; confirm both columns exist
        # in the file and agree with each other.
        ds = ds.filter(lambda ex: ex["winner"] in ["model_a", "model_b"])
        ds = _map_rows(ds, _format_ppe_hf, pair_cols,
                       with_indices=True, num_proc=num_proc)

    elif dataset_name == 'judgebench':
        ds = _map_rows(_load_json_split("judgebench.json"),
                       _format_judgebench,
                       pair_cols + ("domain", "response_model"),
                       with_indices=True, num_proc=num_proc)

    elif dataset_name == 'HelpSteer3':
        ds = _map_rows(_load_json_split("helpsteer3.json"),
                       _format_helpsteer3,
                       pair_cols + ("domain", "split", "difficulty"),
                       with_indices=True, num_proc=num_proc)

    elif dataset_name == 'mixture':
        ds = _load_json_split("mixture.json").shuffle(seed=42)
        ds = _map_rows(ds, _format_mixture, pair_cols,
                       with_indices=True, num_proc=num_proc)

    elif dataset_name == 'skywork':
        ds = _load_json_split("skywork.json").shuffle(seed=42)
        # Keep only rows whose chosen/rejected lists have exactly two
        # entries; the formatter reads index 0 and index 1 of them.
        ds = ds.filter(lambda ex: len(ex['chosen']) == 2 and len(ex['rejected']) == 2)
        ds = _map_rows(ds, _format_skywork, pair_cols,
                       with_indices=True, num_proc=num_proc)

    else:
        raise ValueError(f"Unknown dataset name: {dataset_name}")

    return ds


def _load_json_split(file_name):
    """Load ``BASE_PATH/<file_name>`` as the "train" split of a JSON dataset."""
    data_path = os.path.join(BASE_PATH, file_name)
    return load_dataset("json", data_files=data_path, split="train")


def _map_rows(ds, row_fn, wanted, *, with_indices, num_proc):
    """Apply *row_fn* to every row, then drop columns not named in *wanted*."""
    ds = ds.map(row_fn,
                with_indices=with_indices,
                batched=False,
                num_proc=num_proc)
    return ds.remove_columns(
        [col for col in ds.column_names if col not in wanted])


def _ordered_pair(label, first_label, second_label, first, second):
    """Return ``(chosen, rejected)`` according to which label won.

    Raises ``ValueError`` for any other label instead of letting the caller
    fail later on an unbound variable.
    """
    if label == first_label:
        return first, second
    if label == second_label:
        return second, first
    raise ValueError(f"Unknown label: {label!r}")


def _expand_rm_bench_batch(batch):
    """Expand each rm_bench record into all chosen x rejected combinations.

    Works on a column-oriented batch and returns a column-oriented dict.
    ``difficulty`` is 1 when the chosen index is greater than the rejected
    index, 2 when they are equal, and 3 otherwise.
    """
    out = {
        'id': [],
        'prompt': [],
        'chosen': [],
        'rejected': [],
        'domain': [],
        'difficulty': [],
    }

    for idx, base_id in enumerate(batch['id']):
        prompt = batch['prompt'][idx]
        domain = batch['domain'][idx]
        rejected_list = batch['rejected'][idx]
        for i, chosen_text in enumerate(batch['chosen'][idx]):
            for j, rejected_text in enumerate(rejected_list):
                if i > j:
                    difficulty = 1
                elif i == j:
                    difficulty = 2
                else:
                    difficulty = 3
                # f-string so that non-string ids do not break concatenation.
                out['id'].append(f"{base_id}/{i}_{j}")
                out['prompt'].append(prompt)
                out['chosen'].append(chosen_text)
                out['rejected'].append(rejected_text)
                out['domain'].append(domain)
                out['difficulty'].append(difficulty)

    return out


def _format_reward_bench_v2(example):
    """Take the first chosen/rejected candidate; expose ``subset`` as domain."""
    return {
        "id": example['id'],
        "prompt": example['prompt'],
        "chosen": example['chosen'][0],
        "rejected": example['rejected'][0],
        "domain": example['subset'],
    }


def _format_chatbot_arena(example, idx):
    """Order the two responses by the human preference label."""
    label = example['preference_labels']['human']
    if label == 'response_1':
        chosen, rejected = example['response_1'], example['response_2']
    elif label == 'response_2':
        chosen, rejected = example['response_2'], example['response_1']
    else:
        # Deliberate fallback: rows without a clear winner are kept with
        # empty sides rather than rejected.
        chosen, rejected = None, None

    return {
        "id": idx,
        "prompt": example['query'],
        "chosen": chosen,
        "rejected": rejected,
        "domain": example['scenario_group'],
    }


def _format_ppe_hf(example, idx):
    """Order response_1/response_2 by ``label`` ('model_a' or 'model_b')."""
    chosen, rejected = _ordered_pair(
        example['label'], 'model_a', 'model_b',
        example['response_1'], example['response_2'])
    return {
        "id": idx,
        "prompt": example['prompt'],
        "chosen": chosen,
        "rejected": rejected,
    }


def _format_judgebench(example, idx):
    """Order response_A/response_B by ``label`` and derive a coarse domain."""
    chosen, rejected = _ordered_pair(
        example['label'], 'A>B', 'B>A',
        example['response_A'], example['response_B'])

    source = example['source']
    # 'mmlu-pro' is a substring match; the other sources must match exactly.
    if 'mmlu-pro' in source:
        domain = 'knowledge'
    elif source == 'livebench-reasoning':
        domain = 'reasoning'
    elif source == 'livebench-math':
        domain = 'math'
    elif source == 'livecodebench':
        domain = 'coding'
    else:
        domain = 'other'

    return {
        "id": idx,
        "prompt": example['question'],
        "chosen": chosen,
        "rejected": rejected,
        "domain": domain,
        "response_model": example['response_model'],
    }


def _format_helpsteer3(example, idx):
    """Order response1/response2 by the sign of ``overall_preference``.

    A negative preference selects response1, a positive one response2; the
    absolute value is reported as ``difficulty``. Only the first turn of
    ``context`` is used as the prompt.
    """
    preference = example['overall_preference']
    if preference < 0:
        chosen, rejected = example['response1'], example['response2']
    elif preference > 0:
        chosen, rejected = example['response2'], example['response1']
    else:
        raise ValueError("Unknown label")

    return {
        "id": idx,
        "prompt": example['context'][0]['content'],
        "chosen": chosen,
        "rejected": rejected,
        "domain": example['domain'],
        "split": example['split'],
        "difficulty": abs(preference),
    }


def _format_mixture(example, idx):
    """Pass prompt/chosen/rejected through and attach the row index as id."""
    return {
        "id": idx,
        "prompt": example['prompt'],
        "chosen": example['chosen'],
        "rejected": example['rejected'],
    }


def _format_skywork(example, idx):
    """Split two-message conversations into prompt and responses.

    The first message of ``chosen`` supplies the prompt; the second message
    of ``chosen`` / ``rejected`` supplies the respective response.
    """
    return {
        "id": idx,
        "prompt": example['chosen'][0]['content'],
        "chosen": example['chosen'][1]['content'],
        "rejected": example['rejected'][1]['content'],
    }