import re
from datasets import load_dataset, concatenate_datasets, DatasetDict
from functools import partial

# GSM8K ("main" configuration)
gsm_ds = load_dataset("gsm8k", 'main')

# MATH: load every subcategory and remember which one each row came from
math_subcategories = ['algebra', 'counting_and_probability', 'geometry', 'intermediate_algebra', 'number_theory', 'prealgebra', 'precalculus']
all_train_ds = []
all_test_ds = []
for subcategory in math_subcategories:
    math_subcategory = load_dataset('EleutherAI/hendrycks_math', subcategory)
    # Dataset.add_column returns a new dataset rather than modifying in place,
    # so the result has to be rebound or the 'subset' column is lost.
    train_ds = math_subcategory['train']
    train_ds = train_ds.add_column('subset', [subcategory] * len(train_ds))
    test_ds = math_subcategory['test']
    test_ds = test_ds.add_column('subset', [subcategory] * len(test_ds))
    all_train_ds.append(train_ds)
    all_test_ds.append(test_ds)

# One DatasetDict holding the concatenation of all subcategories per split
math_ds = DatasetDict({'train': concatenate_datasets(all_train_ds), 'test': concatenate_datasets(all_test_ds)})

import re

def _clean_step(_step):
    """Strip surrounding whitespace from a step and make sure it ends with '.'."""
    _step = _step.strip()
    if not _step.endswith('.'):
        _step += '.'
    return _step


def parse_gsm_sample(_sample, idx):
    """Parse one GSM8K record into an id, final answer, raw reasoning and steps.

    Args:
        _sample: mapping with a 'question' key and an 'answer' key whose value
            is the reasoning text followed by '#### ' and the final number.
        idx: index of the record, used to build the 'id' field.

    Returns:
        The same mapping, updated in place with 'id', 'question',
        'final_answer' (int), 'cot' (the stripped reasoning text) and 'steps'
        (one cleaned string per non-blank reasoning line, with
        '<<expr=value>>value' calculator annotations collapsed to 'value').

    Raises:
        IndexError: if the answer contains no '#### ' separator.
        ValueError: if the text after '#### ' is not an integer.
    """
    # Function-scope import keeps this block self-contained.
    from fractions import Fraction

    number = r'\d+/\d+|[-+]?\d*\.\d+|[-+]?\d+'
    # The expression part may contain anything except the '<<' / '>>' markers,
    # so a match can never stretch across two annotations on the same line.
    annotation_re = re.compile(
        r'<<(?:(?!<<|>>).)*=(' + number + r')>>(' + number + r')'
    )

    def _same_value(computed, written):
        # Identical spellings are always equal; otherwise compare exactly, so
        # that '1/2' and '0.5' agree instead of crashing float().
        if computed == written:
            return True
        try:
            return Fraction(computed) == Fraction(written)
        except (ValueError, ZeroDivisionError):
            return False

    def _remove_intermediate_compute(_cot):
        # Drop thousands separators so '1,000' parses as one number.
        _cot = re.sub(r"(\d),(\d)", r"\1\2", _cot)
        mismatches = []

        def _collapse(m):
            # Each annotation is replaced by its own trailing value; an
            # annotation whose two numbers disagree is left untouched.
            if _same_value(m.group(1), m.group(2)):
                return m.group(2)
            mismatches.append(m.group(0))
            return m.group(0)

        parsed = annotation_re.sub(_collapse, _cot)
        if mismatches:
            print('Error 1', _cot)
        elif '<<' in parsed or '>>' in parsed:
            # Marker present but not in the expected '<<expr=value>>value' form.
            print('Error 2', _cot)

        return _clean_step(parsed)

    original_answer = _sample['answer']
    final_answer = int(original_answer.split('#### ')[1].replace(',', ''))
    cot = original_answer.split('#### ')[0]
    parsed_cot = [_remove_intermediate_compute(c) for c in cot.split('\n') if c.strip()]

    _sample.update({
        "id": f"gsm_{idx}",
        "question": _sample['question'],
        "final_answer": final_answer,
        "cot": cot.strip(),
        "steps": parsed_cot,
    })
    return _sample
    

def remove_boxed(s):
    """Return the contents of a '\\boxed{...}' or '\\fbox{...}' wrapper.

    Args:
        s: the wrapper string, e.g. the result of last_boxed_only_string;
            may be None.

    Returns:
        The text between the opening '{' and the final '}', or None when
        s is not a string or is not of that exact form.
    """
    # Explicit checks instead of assert + bare except: asserts are removed
    # under `python -O`, which would turn malformed input into garbage output.
    if not isinstance(s, str):
        return None
    for left in ("\\boxed{", "\\fbox{"):
        if s.startswith(left) and s.endswith("}"):
            return s[len(left):-1]
    return None

def last_boxed_only_string(string):
    """Return the last '\\boxed...' (or, failing that, '\\fbox...') expression.

    Scans forward from the last occurrence of the macro, tracking brace depth,
    and returns the slice from the macro up to and including the '}' that
    brings the depth back to zero. Returns None when neither macro occurs or
    when no such closing brace is found.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    depth = 0
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # Closing brace of the macro's argument: include it in the slice.
                return string[start:pos + 1]

    return None

def parse_math_example(sample, idx):
    """Parse one MATH record into an id, final answer and list of steps.

    Args:
        sample: mapping with a 'solution' key holding the worked solution text.
        idx: index of the record, used to build the 'id' field.

    Returns:
        The same mapping, updated in place with 'id', 'final_answer' (the
        contents of the last boxed expression in the solution, or None when
        there is none) and 'steps'.
    """
    solution = sample['solution']
    final_answer = remove_boxed(last_boxed_only_string(solution))

    # Prefer line breaks as step boundaries, then sentence breaks, else keep
    # the whole solution as a single step.
    if '\n' in solution:
        fragments = solution.split('\n')
    elif '. ' in solution:
        fragments = solution.split('. ')
    else:
        fragments = [solution]

    # Skip blank fragments (e.g. from empty lines between paragraphs), which
    # would otherwise be turned into steps consisting of a lone '.'.
    steps = [_clean_step(fragment) for fragment in fragments if fragment.strip()]

    sample.update({
        "id": f"math_{idx}",
        "final_answer": final_answer,
        "steps": steps
    })

    return sample
        
def create_completion_prompt_format(example, instruction_key="question", final_answer_key="final_answer", steps_key="steps"):
    """Add 'prompt' and 'completion' fields to an example and return it.

    Args:
        example: mapping holding the instruction, steps and final answer.
        instruction_key: key of the instruction text (stripped before use).
        final_answer_key: key of the final answer (rendered with str()).
        steps_key: key of the steps; a list is rendered as '- ' bullet lines,
            any other value is used as the steps text directly.

    Returns:
        The same mapping, updated in place.
    """
    instruction = example[instruction_key].strip()
    prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n### Response: Let's think step by step. "
    )

    raw_steps = example[steps_key]
    rendered_steps = ("- " + "\n- ".join(raw_steps)) if isinstance(raw_steps, list) else raw_steps

    completion = "\nSteps:\n" + rendered_steps
    completion += "\n\nFinal answer:\n" + str(example[final_answer_key])

    example.update({"prompt": prompt, "completion": completion})
    return example

# GSM8K: parse every record, then render prompt/completion from its "question" field.
create_gsm_completion_prompt_format = partial(
    create_completion_prompt_format,
    instruction_key="question",
    final_answer_key="final_answer",
    steps_key="steps",
)
parsed_gsm_ds = gsm_ds.map(parse_gsm_sample, with_indices=True)
parsed_gsm_ds = parsed_gsm_ds.map(create_gsm_completion_prompt_format)

# MATH: same pipeline, but the instruction text lives under "problem".
create_math_completion_prompt_format = partial(
    create_completion_prompt_format,
    instruction_key="problem",
    final_answer_key="final_answer",
    steps_key="steps",
)
parsed_math_ds = math_ds.map(parse_math_example, with_indices=True)
parsed_math_ds = parsed_math_ds.map(create_math_completion_prompt_format)

def add_gsm_category(example):
    """Set the example's 'category_name' to the fixed value 'gsm' and return it."""
    example.update({'category_name': 'gsm'})
    return example
    
# Tag every GSM8K example with its category name.
parsed_gsm_ds = parsed_gsm_ds.map(add_gsm_category)

def add_math_category(example):
    """Set 'category_name' to 'math_<type>_<level>' and return the example.

    Both parts are lower-cased; spaces are replaced with '_' in the level
    only, so any spaces in the type are kept as they are.
    """
    subject = example['type'].lower()
    difficulty = example['level'].lower().replace(' ', '_')
    example.update({'category_name': f"math_{subject}_{difficulty}"})
    return example
    
# Tag every MATH example with its category name.
parsed_math_ds = parsed_math_ds.map(add_math_category)
# Bare expression: looks up the first train row for inspection; the value is
# not stored or printed by this statement.
parsed_math_ds['train'][0]

import os
from pathlib import Path

# Workspace root, read from the environment (KeyError if WS_PATH is unset).
ws = os.environ['WS_PATH']
# Save train
version = 2
# Only these columns are written to the exported files.
_EXPORT_COLUMNS = ['id', 'prompt', 'completion', 'category_name']

gsm_train = parsed_gsm_ds['train']
gsm_train = gsm_train.remove_columns([c for c in gsm_train.column_names if c not in _EXPORT_COLUMNS])
gsm_path = f'{ws}/data/train/gsm_prompt_{version}/gsm_prompt_{version}_data.jsonl'

math_train = parsed_math_ds['train']
math_train = math_train.remove_columns([c for c in math_train.column_names if c not in _EXPORT_COLUMNS])
math_path = f'{ws}/data/train/math_prompt_{version}/math_prompt_{version}_data.jsonl'

# Never overwrite an existing export. Both targets are checked before anything
# is written so a clash on the second file cannot leave a half-finished export.
for _path in (gsm_path, math_path):
    if Path(_path).exists():
        raise FileExistsError(_path)

for _split, _path in ((gsm_train, gsm_path), (math_train, math_path)):
    # Create the versioned target directory up front so the write does not
    # depend on it already existing.
    Path(_path).parent.mkdir(parents=True, exist_ok=True)
    _split.to_json(_path, orient='records', lines=True)
    