import argparse
import json
import numpy as np
import os
import pandas as pd
import polars as pl
import random
import torch
import tqdm
from client import vllmClientModel
from config import (
    GPQA_DIR, GPQA_PROBE_FREQ, GPQA_NUM_CHAINS,
    MATH_DIR, MATH_PROBE_FREQ, MATH_NUM_CHAINS,
    MMLU_DIR, MMLU_PROBE_FREQ, MMLU_NUM_CHAINS,
    GSM8K_DIR, GSM8K_PROBE_FREQ, GSM8K_NUM_CHAINS,
    OLYMPIAD_DIR, OLYMPIAD_PROBE_FREQ, OLYMPIAD_NUM_CHAINS,
    MODEL_IDS,
)
from evaluator import extract_answer, extract_first_boxed_answer
from math_answer import MathAnswer
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import process_math_id


"""
CUDA_VISIBLE_DEVICES=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B -tp 1 --enable-prefix-caching --port 30000
"""


def probe_math(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model_url="http://localhost:30000/v1"):
    """Probe intermediate answers along saved MATH reasoning chains.

    For every ``*.json`` file in the model's MATH ``response`` directory,
    each stored response is tokenized and cut into consecutive chunks of
    ``MATH_PROBE_FREQ`` tokens.  After each chunk except the last, the
    prompt accumulated so far is extended with a forced "Final Answer"
    prefix and the model is asked for a short completion.  The resulting
    ``(chunk, completion)`` pairs are written, one list per chain, to a file
    of the same name in the sibling ``probe`` directory.  Files that already
    have a probe output are skipped, so an interrupted run can be resumed.

    Args:
        model_id: Model name; used to load the tokenizer, passed to
            ``vllmClientModel``, and used as the key into ``MODEL_IDS`` for
            the on-disk directory name.
        model_url: Server URL passed to ``vllmClientModel``.

    Raises:
        KeyError: If ``model_id`` is not in ``MODEL_IDS`` or a response file
            name (minus ``.json``) is not a ``unique_id`` in ``math.csv``.
    """
    problems = pd.read_csv(os.path.join(MATH_DIR, 'math.csv'))
    id2problem = dict(zip(problems['unique_id'], problems['problem']))

    response_dir = os.path.join(MATH_DIR, MODEL_IDS[model_id], "response")
    probe_dir = os.path.join(MATH_DIR, MODEL_IDS[model_id], "probe")
    os.makedirs(probe_dir, exist_ok=True)

    model = vllmClientModel(
        model_id,
        model_url,
        "token-abc123")
    probing_tokens = 10
    probing_text = "**Final Answer**\n\n\\[ \\boxed{"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    fnames = [x for x in os.listdir(response_dir) if x.endswith('.json')]
    # Visit files in random order; already-probed files are skipped below.
    random.shuffle(fnames)
    pbar = tqdm.tqdm(total=len(fnames) * MATH_NUM_CHAINS)
    for fidx, fname in enumerate(fnames):
        output_fpath = os.path.join(probe_dir, fname)
        if os.path.exists(output_fpath):
            # Already probed: account for all of its chains at once.
            pbar.update(MATH_NUM_CHAINS)
            continue

        with open(os.path.join(response_dir, fname), 'r') as f:
            response_strings = json.load(f)
        problem = id2problem[fname.replace('.json', '')]
        problem_prompt = model.prepare_prompt(problem)

        probe_results = []
        for ridx, response_string in enumerate(response_strings):
            pbar.update(1)
            pbar.set_description(f"Probing problem {fidx+1}/{len(fnames)}, chain {ridx+1}/{MATH_NUM_CHAINS}")
            # Ask for a plain list of ids.  Going through a tensor and
            # .squeeze() collapses a one-token encoding to a scalar, which
            # then breaks len() and slicing.
            token_ids = tokenizer(response_string)["input_ids"]
            chunks = [
                tokenizer.decode(token_ids[i:i + MATH_PROBE_FREQ], skip_special_tokens=True)
                for i in range(0, len(token_ids), MATH_PROBE_FREQ)
            ]
            curr_prompt = problem_prompt

            chunk_probes = []
            # The final chunk is never probed.
            for chunk in chunks[:-1]:
                curr_prompt += chunk
                probing_response = model.generate(
                    prompt=curr_prompt + probing_text,
                    max_tokens=probing_tokens,
                    temperature=0.6,
                    top_p=0.95,
                    n=1,
                )
                chunk_probes.append((chunk, probing_response.choices[0].text))
            probe_results.append(chunk_probes)

        # The existence of output_fpath is what marks a problem as done, so
        # write to a scratch file and rename it into place; a crash mid-write
        # must not leave a truncated file that later runs would skip.
        tmp_fpath = f"{output_fpath}.{os.getpid()}.tmp"
        with open(tmp_fpath, 'w') as f:
            json.dump(probe_results, f, indent=2)
        os.replace(tmp_fpath, output_fpath)
    pbar.close()


def probe_mmlu(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model_url="http://localhost:30000/v1"):
    """Probe intermediate answers along saved MMLU reasoning chains.

    For every ``*.json`` file in the model's MMLU ``response`` directory,
    each stored response is tokenized and cut into consecutive chunks of
    ``MMLU_PROBE_FREQ`` tokens.  After each chunk except the last, the
    prompt accumulated so far is extended with a forced "Final Answer"
    prefix and the model is asked for a short completion.  The resulting
    ``(chunk, completion)`` pairs are written, one list per chain, to a file
    of the same name in the sibling ``probe`` directory.  Files that already
    have a probe output are skipped, so an interrupted run can be resumed.

    Args:
        model_id: Model name; used to load the tokenizer, passed to
            ``vllmClientModel``, and used as the key into ``MODEL_IDS`` for
            the on-disk directory name.
        model_url: Server URL passed to ``vllmClientModel``.

    Raises:
        KeyError: If ``model_id`` is not in ``MODEL_IDS`` or a response file's
            integer id is not a ``unique_id`` in ``mmlu.csv``.
        ValueError: If a response file name (minus ``.json``) is not an
            integer.
    """
    response_dir = os.path.join(MMLU_DIR, MODEL_IDS[model_id], "response")
    probe_dir = os.path.join(MMLU_DIR, MODEL_IDS[model_id], "probe")
    os.makedirs(probe_dir, exist_ok=True)

    mmlu = pd.read_csv(os.path.join(MMLU_DIR, 'mmlu.csv'))
    id2problem = dict(zip(mmlu['unique_id'], mmlu['problem']))

    model = vllmClientModel(
        model_id,
        model_url,
        "token-abc123")
    probing_tokens = 10
    probing_text = "**Final Answer**\n\n\\[ \\boxed{"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    fnames = [x for x in os.listdir(response_dir) if x.endswith('.json')]
    # Visit files in random order; already-probed files are skipped below.
    random.shuffle(fnames)
    pbar = tqdm.tqdm(total=len(fnames) * MMLU_NUM_CHAINS)
    for fname in fnames:
        output_fpath = os.path.join(probe_dir, fname)
        if os.path.exists(output_fpath):
            # Already probed: account for all of its chains at once.
            pbar.update(MMLU_NUM_CHAINS)
            continue

        with open(os.path.join(response_dir, fname), 'r') as f:
            response_strings = json.load(f)
        problem = id2problem[int(fname.replace('.json', ''))]
        problem_prompt = model.prepare_prompt(problem)

        probe_results = []
        for ridx, response_string in enumerate(response_strings):
            pbar.update(1)
            pbar.set_description(f"Probing problem {fname}, chain {ridx+1}/{MMLU_NUM_CHAINS}")
            # Ask for a plain list of ids.  Going through a tensor and
            # .squeeze() collapses a one-token encoding to a scalar, which
            # then breaks len() and slicing.
            token_ids = tokenizer(response_string)["input_ids"]
            chunks = [
                tokenizer.decode(token_ids[i:i + MMLU_PROBE_FREQ], skip_special_tokens=True)
                for i in range(0, len(token_ids), MMLU_PROBE_FREQ)
            ]
            curr_prompt = problem_prompt

            chunk_probes = []
            # The final chunk is never probed.
            for chunk in chunks[:-1]:
                curr_prompt += chunk
                probing_response = model.generate(
                    prompt=curr_prompt + probing_text,
                    max_tokens=probing_tokens,
                    temperature=0.6,
                    top_p=0.95,
                    n=1,
                )
                chunk_probes.append((chunk, probing_response.choices[0].text))
            probe_results.append(chunk_probes)

        # The existence of output_fpath is what marks a problem as done, so
        # write to a scratch file and rename it into place; a crash mid-write
        # must not leave a truncated file that later runs would skip.
        tmp_fpath = f"{output_fpath}.{os.getpid()}.tmp"
        with open(tmp_fpath, 'w') as f:
            json.dump(probe_results, f, indent=2)
        os.replace(tmp_fpath, output_fpath)
    pbar.close()


def probe_gsm8k(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model_url="http://localhost:30000/v1"):
    """Probe intermediate answers along saved GSM8K reasoning chains.

    For every ``*.json`` file in the model's GSM8K ``response`` directory,
    each stored response is tokenized and cut into consecutive chunks of
    ``GSM8K_PROBE_FREQ`` tokens.  After each chunk except the last, the
    prompt accumulated so far is extended with a forced "Final Answer"
    prefix and the model is asked for a short completion.  The resulting
    ``(chunk, completion)`` pairs are written, one list per chain, to a file
    of the same name in the sibling ``probe`` directory.  Files that already
    have a probe output are skipped, so an interrupted run can be resumed.

    Args:
        model_id: Model name; used to load the tokenizer, passed to
            ``vllmClientModel``, and used as the key into ``MODEL_IDS`` for
            the on-disk directory name.
        model_url: Server URL passed to ``vllmClientModel``.

    Raises:
        KeyError: If ``model_id`` is not in ``MODEL_IDS`` or a response file's
            integer id is not a ``unique_id`` in ``gsm8k.csv``.
        ValueError: If a response file name (minus ``.json``) is not an
            integer.
    """
    response_dir = os.path.join(GSM8K_DIR, MODEL_IDS[model_id], "response")
    probe_dir = os.path.join(GSM8K_DIR, MODEL_IDS[model_id], "probe")
    os.makedirs(probe_dir, exist_ok=True)

    gsm8k = pd.read_csv(os.path.join(GSM8K_DIR, 'gsm8k.csv'))
    id2problem = dict(zip(gsm8k['unique_id'], gsm8k['problem']))

    model = vllmClientModel(
        model_id,
        model_url,
        "token-abc123")
    probing_tokens = 10
    probing_text = "**Final Answer**\n\n\\[ \\boxed{"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    fnames = [x for x in os.listdir(response_dir) if x.endswith('.json')]
    # Visit files in random order; already-probed files are skipped below.
    random.shuffle(fnames)
    pbar = tqdm.tqdm(total=len(fnames) * GSM8K_NUM_CHAINS)
    for fname in fnames:
        output_fpath = os.path.join(probe_dir, fname)
        if os.path.exists(output_fpath):
            # Already probed: account for all of its chains at once.
            pbar.update(GSM8K_NUM_CHAINS)
            continue

        with open(os.path.join(response_dir, fname), 'r') as f:
            response_strings = json.load(f)
        problem = id2problem[int(fname.replace('.json', ''))]
        problem_prompt = model.prepare_prompt(problem)

        probe_results = []
        for ridx, response_string in enumerate(response_strings):
            pbar.update(1)
            pbar.set_description(f"Probing problem {fname}, chain {ridx+1}/{GSM8K_NUM_CHAINS}")
            # Ask for a plain list of ids.  Going through a tensor and
            # .squeeze() collapses a one-token encoding to a scalar, which
            # then breaks len() and slicing.
            token_ids = tokenizer(response_string)["input_ids"]
            chunks = [
                tokenizer.decode(token_ids[i:i + GSM8K_PROBE_FREQ], skip_special_tokens=True)
                for i in range(0, len(token_ids), GSM8K_PROBE_FREQ)
            ]
            curr_prompt = problem_prompt

            chunk_probes = []
            # The final chunk is never probed.
            for chunk in chunks[:-1]:
                curr_prompt += chunk
                probing_response = model.generate(
                    prompt=curr_prompt + probing_text,
                    max_tokens=probing_tokens,
                    temperature=0.6,
                    top_p=0.95,
                    n=1,
                )
                chunk_probes.append((chunk, probing_response.choices[0].text))
            probe_results.append(chunk_probes)

        # The existence of output_fpath is what marks a problem as done, so
        # write to a scratch file and rename it into place; a crash mid-write
        # must not leave a truncated file that later runs would skip.
        tmp_fpath = f"{output_fpath}.{os.getpid()}.tmp"
        with open(tmp_fpath, 'w') as f:
            json.dump(probe_results, f, indent=2)
        os.replace(tmp_fpath, output_fpath)
    pbar.close()


def probe_olympiad(model_id="deepseek-ai/DeepSeek-R1-Distill-Llama-8B", model_url="http://localhost:30000/v1"):
    """Probe intermediate answers along saved Olympiad reasoning chains.

    For every ``*.json`` file in the model's Olympiad ``response`` directory,
    each stored response is tokenized and cut into consecutive chunks of
    ``OLYMPIAD_PROBE_FREQ`` tokens.  After each chunk except the last, the
    prompt accumulated so far is extended with a forced "Final Answer"
    prefix and the model is asked for a short completion.  The resulting
    ``(chunk, completion)`` pairs are written, one list per chain, to a file
    of the same name in the sibling ``probe`` directory.  Files that already
    have a probe output are skipped, so an interrupted run can be resumed.

    Args:
        model_id: Model name; used to load the tokenizer, passed to
            ``vllmClientModel``, and used as the key into ``MODEL_IDS`` for
            the on-disk directory name.
        model_url: Server URL passed to ``vllmClientModel``.

    Raises:
        KeyError: If ``model_id`` is not in ``MODEL_IDS`` or a response file's
            integer id is not a ``unique_id`` in ``olympiad.csv``.
        ValueError: If a response file name (minus ``.json``) is not an
            integer.
    """
    response_dir = os.path.join(OLYMPIAD_DIR, MODEL_IDS[model_id], "response")
    probe_dir = os.path.join(OLYMPIAD_DIR, MODEL_IDS[model_id], "probe")
    os.makedirs(probe_dir, exist_ok=True)

    olympiad = pd.read_csv(os.path.join(OLYMPIAD_DIR, 'olympiad.csv'))
    id2problem = dict(zip(olympiad['unique_id'], olympiad['problem']))

    model = vllmClientModel(
        model_id,
        model_url,
        "token-abc123")
    probing_tokens = 10
    probing_text = "**Final Answer**\n\n\\[ \\boxed{"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    fnames = [x for x in os.listdir(response_dir) if x.endswith('.json')]
    # Visit files in random order; already-probed files are skipped below.
    random.shuffle(fnames)
    pbar = tqdm.tqdm(total=len(fnames) * OLYMPIAD_NUM_CHAINS)
    for fname in fnames:
        output_fpath = os.path.join(probe_dir, fname)
        if os.path.exists(output_fpath):
            # Already probed: account for all of its chains at once.  This
            # must use the Olympiad chain count (the bar's total is sized
            # with it), not another dataset's constant.
            pbar.update(OLYMPIAD_NUM_CHAINS)
            continue

        with open(os.path.join(response_dir, fname), 'r') as f:
            response_strings = json.load(f)
        problem = id2problem[int(fname.replace('.json', ''))]
        problem_prompt = model.prepare_prompt(problem)

        probe_results = []
        for ridx, response_string in enumerate(response_strings):
            pbar.update(1)
            pbar.set_description(f"Probing problem {fname}, chain {ridx+1}/{OLYMPIAD_NUM_CHAINS}")
            # Ask for a plain list of ids.  Going through a tensor and
            # .squeeze() collapses a one-token encoding to a scalar, which
            # then breaks len() and slicing.
            token_ids = tokenizer(response_string)["input_ids"]
            chunks = [
                tokenizer.decode(token_ids[i:i + OLYMPIAD_PROBE_FREQ], skip_special_tokens=True)
                for i in range(0, len(token_ids), OLYMPIAD_PROBE_FREQ)
            ]
            curr_prompt = problem_prompt

            chunk_probes = []
            # The final chunk is never probed.
            for chunk in chunks[:-1]:
                curr_prompt += chunk
                probing_response = model.generate(
                    prompt=curr_prompt + probing_text,
                    max_tokens=probing_tokens,
                    temperature=0.6,
                    top_p=0.95,
                    n=1,
                )
                chunk_probes.append((chunk, probing_response.choices[0].text))
            probe_results.append(chunk_probes)

        # The existence of output_fpath is what marks a problem as done, so
        # write to a scratch file and rename it into place; a crash mid-write
        # must not leave a truncated file that later runs would skip.
        tmp_fpath = f"{output_fpath}.{os.getpid()}.tmp"
        with open(tmp_fpath, 'w') as f:
            json.dump(probe_results, f, indent=2)
        os.replace(tmp_fpath, output_fpath)
    pbar.close()


if __name__ == '__main__':
    # Map each supported --dataset value to its probing routine.  Driving
    # argparse's `choices` from this table makes an unknown dataset name a
    # usage error instead of a silent no-op.
    probers = {
        "math": probe_math,
        "mmlu": probe_mmlu,
        "gsm8k": probe_gsm8k,
        "olympiad": probe_olympiad,
    }

    parser = argparse.ArgumentParser(
        description="Probe intermediate answers along saved reasoning chains.")
    parser.add_argument("-d", "--dataset", type=str, required=True,
                        choices=sorted(probers),
                        help="dataset whose saved responses should be probed")
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="model id (must be a key of MODEL_IDS)")
    parser.add_argument("-u", "--url", type=str, required=True,
                        help="URL of the server hosting the model")
    args = parser.parse_args()

    probers[args.dataset](args.model, args.url)
