import argparse
import json
import os
import pandas as pd
from config import (
    MATH_DIR, MATH_MAX_LEN, MATH_NUM_CHAINS, MATH_PROBE_FREQ,
    MMLU_DIR, MMLU_MAX_LEN, MMLU_NUM_CHAINS, MMLU_PROBE_FREQ,
    GSM8K_DIR, GSM8K_MAX_LEN, GSM8K_NUM_CHAINS, GSM8K_PROBE_FREQ,
    MODEL_IDS,
)
from evaluator import extract_first_boxed_answer, extract_answer
from math_answer import MathAnswer
from transformers import AutoTokenizer
from tqdm import tqdm
from utils import process_math_id


"""
Postprocessed intermediate probing results should have the following columns:
  - unique_id: problem identifier
  - chain_id: chain identifier
  - tokens: current token budget
  - curr_answer: answer in canonical form
  - curr_answer_raw: (MMLU only) answer text before multiple-choice normalization
  - type: "intermediate" or "final"
"""

def strip_deepseek_r1_thinking(response):
    """Return the text that follows the first ``</think>`` tag.

    The response is split on ``</think>`` and the second piece is returned
    with surrounding whitespace removed, so anything after a *second*
    ``</think>`` tag is not included.  When the tag is absent the sentinel
    string ``"Unfinished"`` is returned instead.
    """
    closing_tag = '</think>'
    if closing_tag in response:
        pieces = response.split(closing_tag)
        return pieces[1].strip()
    return "Unfinished"


def postprocess_math(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
    """Write one post-processed probe CSV per MATH problem for ``model_id``.

    Problems are read from ``math3k.csv`` under ``MATH_DIR`` and visited in a
    random order.  A problem is processed only when its output CSV does not
    exist yet and its probe JSON does.  For every (probes, full response)
    pair, one ``"intermediate"`` row is emitted per probe and one ``"final"``
    row for the full response.  Each extracted answer is passed through
    ``MathAnswer.add_answer`` and the returned value is stored as
    ``curr_answer``.

    Args:
        model_id: Key into ``MODEL_IDS`` (its value names the per-model
            sub-directory); also passed to ``AutoTokenizer.from_pretrained``.

    Raises:
        FileNotFoundError: If a probe JSON exists but the matching response
            JSON does not.
    """
    problems = pd.read_csv(os.path.join(MATH_DIR, 'math3k.csv'))
    # Shuffle so every run walks the problems in a different order.
    problems = problems.sample(frac=1).reset_index(drop=True)

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    model_dir = os.path.join(MATH_DIR, MODEL_IDS[model_id])
    output_dir = os.path.join(model_dir, 'probe_postprocessed')
    os.makedirs(output_dir, exist_ok=True)

    # Prepended to each probe's text before boxed-answer extraction.
    # NOTE(review): presumably mirrors the prompt suffix used when probing.
    boxed_prefix = "**Final Answer**\n\n\\[ \\boxed{"

    # Wrapping the iterator closes the progress bar even if a problem raises.
    for _, row in tqdm(problems.iterrows(), total=problems.shape[0]):
        problem_id = row['unique_id']
        fname = process_math_id(problem_id) + ".json"
        output_fname = os.path.join(output_dir, f"{problem_id}.csv")
        if os.path.isfile(output_fname):
            continue  # already post-processed

        probe_fname = os.path.join(model_dir, "probe", fname)
        if not os.path.isfile(probe_fname):
            continue  # no probe results for this problem (yet)

        response_fname = os.path.join(model_dir, "response", fname)
        with open(probe_fname, 'r') as f:
            probe_results = json.load(f)
        with open(response_fname, 'r') as f:
            response_results = json.load(f)

        ma = MathAnswer(row['answer'])

        plog = []
        for chain_id, (probes, full_response) in enumerate(
                zip(probe_results, response_results)):
            # Intermediate answers: the k-th probe (1-based) is recorded at
            # a budget of k * MATH_PROBE_FREQ tokens.
            for step, probe in enumerate(probes, start=1):
                ans = extract_first_boxed_answer(boxed_prefix + probe[1], "math500")
                plog.append({
                    "unique_id": problem_id,
                    "chain_id": chain_id,
                    "tokens": step * MATH_PROBE_FREQ,
                    "curr_answer": ma.add_answer(ans),
                    "type": "intermediate"
                })

            # Final answer: budget is the tokenized response length, capped
            # at MATH_MAX_LEN.
            full_response_tokens = len(tokenizer.encode(full_response))
            ans = extract_answer(full_response, "math500")
            plog.append({
                "unique_id": problem_id,
                "chain_id": chain_id,
                "tokens": min(full_response_tokens, MATH_MAX_LEN),
                "curr_answer": ma.add_answer(ans),
                "type": "final"
            })

        # The output name uses the raw unique_id, while the input name goes
        # through process_math_id.  If the id contains a path separator the
        # CSV lands in a sub-directory of output_dir that must exist first.
        os.makedirs(os.path.dirname(output_fname), exist_ok=True)
        pd.DataFrame(plog).to_csv(output_fname, index=False, header=True)

def force_extract_multiple_choice(pred):
    """Map a free-form prediction onto ``"A"``/``"B"``/``"C"``/``"D"``.

    The prediction is stripped and upper-cased.  It is accepted when it is
    exactly one choice letter, or when it begins with a choice letter that
    is immediately followed by one of ``: ) . *`` or a space.  Anything else
    yields ``"Invalid"``.
    """
    choices = ("A", "B", "C", "D")
    delimiters = (":", ")", ".", " ", "*")

    normalized = pred.strip().upper()
    if normalized in choices:
        return normalized

    # A leading choice letter only counts when a delimiter follows it.
    if len(normalized) >= 2:
        letter, follower = normalized[0], normalized[1]
        if letter in choices and follower in delimiters:
            return letter
    return "Invalid"


def postprocess_mmlu(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
    """Write one post-processed probe CSV per MMLU test problem.

    Only rows of ``mmlu.csv`` whose ``category`` is ``'test'`` are visited.
    A problem is processed only when its output CSV does not exist yet and
    its probe JSON does.  For every (probes, full response) pair, one
    ``"intermediate"`` row is emitted per probe and one ``"final"`` row for
    the full response.  ``curr_answer`` holds the result of
    ``force_extract_multiple_choice``; ``curr_answer_raw`` holds the text
    before that normalization.

    Args:
        model_id: Key into ``MODEL_IDS`` (its value names the per-model
            sub-directory); also passed to ``AutoTokenizer.from_pretrained``.
            Ids starting with ``"deepseek-ai"`` take the branch that strips
            the ``</think>`` section for ``curr_answer_raw``.

    Raises:
        FileNotFoundError: If a probe JSON exists but the matching response
            JSON does not.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    mmlu = pd.read_csv(os.path.join(MMLU_DIR, 'mmlu.csv'))
    # Only test rows are processed, so size the progress bar on them.
    test_rows = mmlu[mmlu.category == 'test']

    model_dir = os.path.join(MMLU_DIR, MODEL_IDS[model_id])
    output_dir = os.path.join(model_dir, 'probe_postprocessed')
    os.makedirs(output_dir, exist_ok=True)

    # Prepended to each probe's text before boxed-answer extraction.
    # NOTE(review): presumably mirrors the prompt suffix used when probing.
    boxed_prefix = "**Final Answer**\n\n\\[ \\boxed{"
    is_deepseek = model_id.startswith("deepseek-ai")

    # Wrapping the iterator closes the progress bar even if a problem raises.
    for _, row in tqdm(test_rows.iterrows(), total=test_rows.shape[0]):
        problem_id = row['unique_id']
        fname = f"{problem_id}.json"
        output_fname = os.path.join(output_dir, f"{problem_id}.csv")
        if os.path.isfile(output_fname):
            continue  # already post-processed

        probe_fname = os.path.join(model_dir, "probe", fname)
        if not os.path.isfile(probe_fname):
            continue  # no probe results for this problem (yet)

        response_fname = os.path.join(model_dir, "response", fname)
        with open(probe_fname, 'r') as f:
            probe_results = json.load(f)
        with open(response_fname, 'r') as f:
            response_results = json.load(f)

        plog = []
        for chain_id, (probes, full_response) in enumerate(
                zip(probe_results, response_results)):
            # Intermediate answers: the k-th probe (1-based) is recorded at
            # a budget of k * MMLU_PROBE_FREQ tokens.
            for step, probe in enumerate(probes, start=1):
                probe_raw = extract_first_boxed_answer(boxed_prefix + probe[1], "mmlu")
                plog.append({
                    "unique_id": problem_id,
                    "chain_id": chain_id,
                    "tokens": step * MMLU_PROBE_FREQ,
                    "curr_answer": force_extract_multiple_choice(probe_raw),
                    "curr_answer_raw": probe_raw,
                    "type": "intermediate"
                })

            # Final answer: budget is the tokenized response length, capped
            # at MMLU_MAX_LEN.
            full_response_tokens = len(tokenizer.encode(full_response))
            if is_deepseek:
                ans_raw = strip_deepseek_r1_thinking(full_response)
                # NOTE(review): the choice is extracted from the full
                # response (thinking included), not from ans_raw — confirm
                # this is intended.
                ans = force_extract_multiple_choice(
                    extract_answer(full_response, "mmlu", False))
            else:
                # Extract from this chain's full response.  The previous code
                # passed the leftover raw answer of the last intermediate
                # probe (or an unbound name when there were no probes).
                ans_raw = extract_answer(full_response, "mmlu", False)
                ans = force_extract_multiple_choice(ans_raw)
            plog.append({
                "unique_id": problem_id,
                "chain_id": chain_id,
                "tokens": min(full_response_tokens, MMLU_MAX_LEN),
                "curr_answer": ans,
                "curr_answer_raw": ans_raw,
                "type": "final"
            })

        pd.DataFrame(plog).to_csv(output_fname, index=False, header=True)


def postprocess_gsm8k(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
    """Write one post-processed probe CSV per GSM8K problem for ``model_id``.

    Every row of ``gsm8k.csv`` is visited.  A problem is processed only when
    its output CSV does not exist yet and its probe JSON does.  For every
    (probes, full response) pair, one ``"intermediate"`` row is emitted per
    probe and one ``"final"`` row for the full response.

    Args:
        model_id: Key into ``MODEL_IDS`` (its value names the per-model
            sub-directory); also passed to ``AutoTokenizer.from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    problems = pd.read_csv(os.path.join(GSM8K_DIR, 'gsm8k.csv'))
    progress = tqdm(total=problems.shape[0])

    model_subdir = MODEL_IDS[model_id]
    output_dir = os.path.join(GSM8K_DIR, model_subdir, 'probe_postprocessed')
    os.makedirs(output_dir, exist_ok=True)

    for _, row in problems.iterrows():
        progress.update(1)
        problem_id = row['unique_id']
        json_name = f"{problem_id}.json"
        csv_path = os.path.join(output_dir, f"{problem_id}.csv")
        if os.path.isfile(csv_path):
            continue  # already post-processed

        probe_path = os.path.join(GSM8K_DIR, model_subdir, "probe", json_name)
        if not os.path.isfile(probe_path):
            continue  # no probe results for this problem (yet)

        response_path = os.path.join(GSM8K_DIR, model_subdir, "response", json_name)
        with open(probe_path, 'r') as fh:
            probe_results = json.load(fh)
        with open(response_path, 'r') as fh:
            response_results = json.load(fh)

        records = []
        for chain_id, (probes, full_response) in enumerate(
                zip(probe_results, response_results)):
            # The k-th probe (1-based) sits at k * GSM8K_PROBE_FREQ tokens.
            for step, probe in enumerate(probes, start=1):
                probe_answer = extract_first_boxed_answer(
                    "**Final Answer**\n\n\\[ \\boxed{" + probe[1], "gsm8k")
                records.append({
                    "unique_id": problem_id,
                    "chain_id": chain_id,
                    "tokens": step * GSM8K_PROBE_FREQ,
                    "curr_answer": probe_answer,
                    "type": "intermediate"
                })

            # Final row: tokenized response length, capped at GSM8K_MAX_LEN.
            token_count = len(tokenizer.encode(full_response))
            final_answer = extract_answer(full_response, "gsm8k", True)
            records.append({
                "unique_id": problem_id,
                "chain_id": chain_id,
                "tokens": min(token_count, GSM8K_MAX_LEN),
                "curr_answer": final_answer,
                "type": "final"
            })

        pd.DataFrame(records).to_csv(csv_path, index=False, header=True)

    progress.close()
                


if __name__ == '__main__':
    # Command-line entry point: run the post-processor for one dataset/model.
    postprocessors = {
        "math": postprocess_math,
        "mmlu": postprocess_mmlu,
        "gsm8k": postprocess_gsm8k,
    }

    parser = argparse.ArgumentParser()
    # `choices` makes an unknown dataset a usage error instead of a silent
    # no-op that exits successfully.
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        choices=list(postprocessors))
    parser.add_argument("-m", "--model", type=str, required=True)
    args = parser.parse_args()

    postprocessors[args.dataset](args.model)
