from datasets import load_dataset, Dataset, DatasetDict
import json
from src.helper import extract_math_answer, format_MMLU, format_MED
from src.model_loader import base_model_chat_template
from src.generate_sudoku import generate_sudoku_dataset

def process_gsm():
    """Load the ``main`` configuration of the ``openai/gsm8k`` dataset.

    Returns:
        The object produced by ``load_dataset`` for that repository/config
        pair. No split is selected and no post-processing is applied.
    """
    gsm8k = load_dataset("openai/gsm8k", "main")
    return gsm8k

def process_aime():
    """Load the AIME 1983-2024 problems and split them into train/test.

    The ``Question`` and ``Answer`` columns are renamed to lowercase
    ``question`` and ``answer`` before splitting.

    Returns:
        The result of ``train_test_split`` with 10% of the rows assigned to
        the test split, using a fixed seed of 42.
    """
    column_renames = {'Answer': 'answer', 'Question': 'question'}
    problems = load_dataset("di-zhang-fdu/AIME_1983_2024", split='train')
    problems = problems.rename_columns(column_renames)
    return problems.train_test_split(test_size=0.1, seed=42)

def process_mmlu():
    """Load MMLU-STEM, re-split its ``test`` split, and format every row.

    The loaded dataset is shuffled (seed 42), its ``test`` split is divided
    into train/test with 20% held out, and ``format_MMLU`` is mapped over the
    result. All columns present before formatting are removed, so only the
    fields produced by ``format_MMLU`` remain.

    Returns:
        The mapped train/test dataset.
    """
    raw = load_dataset("TIGER-Lab/MMLU-STEM")
    shuffled = raw.shuffle(seed=42)
    splits = shuffled['test'].train_test_split(test_size=0.2, seed=42)
    original_columns = splits['train'].column_names
    formatted = splits.map(
        format_MMLU,
        remove_columns=original_columns,
        # Always recompute rather than reusing a previously cached mapping.
        load_from_cache_file=False,
    )
    return formatted

def process_med(subset="Surgery"):
    """Load MedMCQA, keep single-choice questions, and format every row.

    Args:
        subset: Value that ``subject_name`` must equal for a row to be kept.
            A falsy value (``None``, ``""``) skips the subject filter.

    Returns:
        The filtered dataset after mapping ``format_MED`` over it, with all
        pre-existing columns removed.
    """
    def is_single_choice(example):
        return example["choice_type"] == "single"

    def in_requested_subject(example):
        return example["subject_name"] == subset

    medmcqa = load_dataset("openlifescienceai/medmcqa")
    medmcqa = medmcqa.filter(is_single_choice)
    if subset:
        medmcqa = medmcqa.filter(in_requested_subject)

    source_columns = medmcqa['train'].column_names
    return medmcqa.map(
        format_MED,
        remove_columns=source_columns,
        # Always recompute rather than reusing a previously cached mapping.
        load_from_cache_file=False,
    )

def process_sudoku(num_prefilled=10):
    """Generate 10000 sudoku examples and split them into train/test.

    Args:
        num_prefilled: Forwarded as the second argument to
            ``generate_sudoku_dataset`` (presumably the number of given
            cells per puzzle -- confirm against that helper).

    Returns:
        The result of ``train_test_split`` with 10% of the generated rows
        assigned to the test split, using a fixed seed of 42.
    """
    puzzle_count = 10000
    puzzles = generate_sudoku_dataset(puzzle_count, num_prefilled)
    return puzzles.train_test_split(test_size=0.1, seed=42)

def process_gsm_symbolic():
    """Load the ``p1`` configuration of GSM-Symbolic and re-split its test set.

    Returns:
        The result of ``train_test_split`` applied to the ``test`` split with
        ``test_size=0.9`` and a fixed seed of 42.
    """
    symbolic = load_dataset("apple/GSM-Symbolic", "p1")
    test_rows = symbolic['test']
    # test_size=0.9: the larger share of rows goes to the new test split.
    return test_rows.train_test_split(test_size=0.9, seed=42)

def split_proportionally(a, b, ratio):
    """Trim two sequences so that ``a`` makes up about ``ratio`` of the result.

    Prefixes of ``a`` and ``b`` are chosen so that
    ``len(a_part) / (len(a_part) + len(b_part))`` is approximately ``ratio``
    while keeping as many elements as the shorter-supplied side allows.
    Fractional counts are truncated toward zero.

    Args:
        a: Sliceable sequence supplying the ``ratio`` share.
        b: Sliceable sequence supplying the ``1 - ratio`` share.
        ratio: Target share of ``a`` in ``[0, 1]``, or ``-1`` to disable
            rebalancing and return both inputs unchanged.

    Returns:
        A tuple ``(a_part, b_part)`` of prefixes of ``a`` and ``b``.

    Raises:
        ValueError: If ``ratio`` is neither ``-1`` nor within ``[0, 1]``.
    """
    # Sentinel: -1 means "do not rebalance".
    if ratio == -1:
        return a, b
    # Outside [0, 1] the counts below go negative and would be interpreted as
    # from-the-end slice bounds, silently returning the wrong elements.
    if not 0 <= ratio <= 1:
        raise ValueError(
            f"ratio must be -1 or within [0, 1], got {ratio!r}"
        )

    total = len(a) + len(b)
    a_idx = int(total * ratio)
    b_idx = int(total * (1 - ratio))

    if b_idx > len(b):
        # Not enough of b: take all of it and scale a down to match.
        b_idx = len(b)
        a_idx = int((ratio * len(b)) / (1 - ratio))
    if a_idx > len(a):
        # Not enough of a: take all of it and scale b down to match.
        a_idx = len(a)
        b_idx = int((1 - ratio) * len(a) / ratio)
        if b_idx > len(b):
            b_idx = len(b)

    return a[:a_idx], b[:b_idx]

def make_feedback_dataset(data_dir, feedback_prompt, tokenizer):
    """Build a prompt/ground-truth dataset from ``<data_dir>/full_record.json``.

    Each record must provide ``question``, ``response`` and ``label`` keys.

    Args:
        data_dir: Directory containing ``full_record.json``.
        feedback_prompt: Prompt template. For tokenizers whose
            ``name_or_path`` contains ``"Instruct"`` it is filled via
            ``str.format`` with ``question`` and ``initial_response``;
            otherwise it is passed unchanged to ``base_model_chat_template``.
        tokenizer: Object exposing ``name_or_path`` and, for the instruct
            path, ``apply_chat_template``.

    Returns:
        A ``Dataset`` with a ``prompt`` column and a ``ground_truth`` column
        (``extract_math_answer`` applied to each record's ``label``).
    """
    # JSON is UTF-8 by specification; decode explicitly instead of relying on
    # the platform's locale encoding, which may not be UTF-8.
    with open(f'{data_dir}/full_record.json', 'r', encoding='utf-8') as f:
        full_record = json.load(f)

    list_dataset = []
    for row in full_record:
        if "Instruct" in tokenizer.name_or_path:
            # Instruct models: wrap the filled-in prompt as a single user turn
            # and let the tokenizer's chat template add the generation prompt.
            messages = [{
                'role': 'user',
                'content': feedback_prompt.format(
                    question=row['question'],
                    initial_response=row['response'],
                ),
            }]
            formatted_messages = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                continue_final_message=False,
                add_generation_prompt=True
            )
        else:
            # Other models: delegate formatting to the project helper.
            messages = [row['question'], row['response']]
            formatted_messages = base_model_chat_template(messages, feedback_prompt)
        list_dataset.append({
            'prompt': formatted_messages,
            'ground_truth': extract_math_answer(row['label'])
        })
    hf_dataset = Dataset.from_list(list_dataset)
    return hf_dataset

def make_baseline_dataset(prompt, dataset_name):
    """Build a prompt/ground-truth dataset from a named benchmark's train data.

    The benchmark's ``train`` split is divided once more with 10% held out
    (seed 42); only the remaining training rows are turned into prompts.

    Args:
        prompt: Template string filled via ``str.format`` with
            ``question=<row question>``.
        dataset_name: One of ``"GSM"``, ``"AIME"``, ``"MMLU"``,
            ``"SUDOKU_12"`` or ``"SUDOKU_14"``.

    Returns:
        A ``Dataset`` with a ``prompt`` column and a ``ground_truth`` column
        (``extract_math_answer`` applied to each row's ``answer``).

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    processors = {
        "GSM": process_gsm,
        "AIME": process_aime,
        "MMLU": process_mmlu,
        "SUDOKU_12": lambda: process_sudoku(num_prefilled=12),
        "SUDOKU_14": lambda: process_sudoku(num_prefilled=14),
    }
    processor = processors.get(dataset_name)
    if processor is None:
        # Fail here with a clear message instead of later with an opaque
        # "'NoneType' object is not subscriptable".
        supported = ", ".join(sorted(processors))
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of: {supported}"
        )
    dataset = processor()

    # Hold out 10% of the train split; only the remaining rows are used here.
    train_rows = dataset['train'].train_test_split(test_size=0.1, seed=42)['train']

    return Dataset.from_list([
        {
            'prompt': prompt.format(question=row['question']),
            'ground_truth': extract_math_answer(row['answer']),
        }
        for row in train_rows
    ])