import argparse
import json
import numpy as np
import os
import pandas as pd
import polars as pl
import random
import tqdm
from client import vllmClientModel
from config import (
    GPQA_DIR, GPQA_MAX_LEN, GPQA_NUM_CHAINS,
    MATH_DIR, MATH_MAX_LEN, MATH_NUM_CHAINS,
    MMLU_DIR, MMLU_MAX_LEN, MMLU_NUM_CHAINS,
    GSM8K_DIR, GSM8K_MAX_LEN, GSM8K_NUM_CHAINS,
    OLYMPIAD_DIR, OLYMPIAD_MAX_LEN, OLYMPIAD_NUM_CHAINS,
    MODEL_IDS,
)
from evaluator import extract_answer, extract_first_boxed_answer
from math_answer import MathAnswer
from sklearn.model_selection import train_test_split
from utils import process_math_id

# vllm serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B -tp 1 --enable-prefix-caching --port 30000

def profile_math(model_id, model_url):
    """Sample responses for the MATH benchmark and store them one file per problem.

    Two CSVs are materialised under ``MATH_DIR`` when they are not already on
    disk:

    * ``math.csv``   -- a 5000-row stratified (by ``level``) sample of the train
      split plus the complete test split, flagged by a ``train`` column (1/0),
      with ``unique_id`` passed through ``process_math_id``.
    * ``math3k.csv`` -- the same test rows plus a 3000-row stratified
      down-sample of the train rows found in ``math.csv``.

    Every row of ``math.csv`` is then visited in shuffled order; a problem whose
    ``<unique_id>.json`` already exists in the response directory is skipped,
    otherwise ``MATH_NUM_CHAINS`` completions are requested and their texts are
    dumped as a JSON list.

    Args:
        model_id: Model identifier handed to ``vllmClientModel``; also the key
            looked up in ``MODEL_IDS`` to name the output directory.
        model_url: Endpoint URL handed to ``vllmClientModel``.
    """
    os.makedirs(MATH_DIR, exist_ok=True)
    full_csv = os.path.join(MATH_DIR, 'math.csv')
    small_csv = os.path.join(MATH_DIR, 'math3k.csv')

    if os.path.isfile(full_csv):
        print("Found cached math prompts from disk")
    else:
        hf_root = "hf://datasets/nlile/hendrycks-MATH-benchmark/"
        train_all = pd.read_parquet(hf_root + 'data/train-00000-of-00001.parquet')
        test_part = pd.read_parquet(hf_root + 'data/test-00000-of-00001.parquet')
        # Keep only the sampled side of the split; copy so the new column is
        # written to an independent frame.
        train_part = train_test_split(
            train_all,
            train_size=5000,
            random_state=0,
            stratify=train_all['level'].values,
        )[0].copy()
        train_part['train'] = 1
        test_part['train'] = 0
        combined = pd.concat([train_part, test_part], axis=0)
        combined['unique_id'] = combined['unique_id'].map(process_math_id)
        combined.to_csv(full_csv, index=False, header=True)

    if os.path.isfile(small_csv):
        print("Found cached math 3k prompts from disk")
    else:
        cached = pd.read_csv(full_csv)
        train_rows = cached[cached['train'] == 1]
        test_rows = cached[cached['train'] == 0]
        train_3k = train_test_split(
            train_rows,
            train_size=3000,
            random_state=0,
            stratify=train_rows['level'].values,
        )[0]
        pd.concat([train_3k, test_rows], axis=0).to_csv(small_csv, index=False, header=True)

    # Unseeded shuffle: the visiting order differs from run to run.
    prompts = pd.read_csv(full_csv).sample(frac=1).reset_index(drop=True)

    model = vllmClientModel(model_id, model_url, "token-abc123")

    output_dir = os.path.join(MATH_DIR, MODEL_IDS[model_id], "response")
    os.makedirs(output_dir, exist_ok=True)

    total_problems = prompts.shape[0]
    pbar = tqdm.tqdm(total=total_problems)
    for problem_counter, (_, row) in enumerate(prompts.iterrows(), start=1):
        unique_id = row['unique_id']
        pbar.update(1)
        pbar.set_description(f"Problem {unique_id} {problem_counter}/{total_problems}")

        output_fpath = os.path.join(output_dir, f"{unique_id}.json")
        if os.path.exists(output_fpath):
            # A response file is already present for this problem.
            continue

        completions = model.generate(
            prompt=model.prepare_prompt(row['problem']),
            max_tokens=MATH_MAX_LEN,
            temperature=0.6,
            top_p=0.95,
            n=MATH_NUM_CHAINS,
        )

        response_strings = [choice.text for choice in completions.choices]
        with open(output_fpath, 'w') as f:
            json.dump(response_strings, f)
    pbar.close()


def profile_mmlu(model_id, model_url):
    """Sample responses for MMLU multiple-choice prompts, one file per problem.

    When ``mmlu.csv`` is missing from ``MMLU_DIR`` it is built from the
    ``auxiliary_train``, ``validation`` and ``test`` parquet files: each split is
    down-sampled with stratification on ``subject`` (5000 / 1000 / 1000 rows),
    every question is rendered as an instruction followed by the question and
    its lettered choices, and the rows are written with a running integer
    ``unique_id`` and a ``category`` of ``train`` / ``val`` / ``test``.

    The CSV rows are then visited in shuffled order; a problem whose
    ``<unique_id>.json`` already exists is skipped, otherwise
    ``MMLU_NUM_CHAINS`` completions are requested and their texts are dumped as
    a JSON list.

    Args:
        model_id: Model identifier handed to ``vllmClientModel``; also the key
            looked up in ``MODEL_IDS`` to name the output directory.
        model_url: Endpoint URL handed to ``vllmClientModel``.
    """
    os.makedirs(MMLU_DIR, exist_ok=True)
    csv_path = os.path.join(MMLU_DIR, 'mmlu.csv')

    if os.path.isfile(csv_path):
        print("Found cached MMLU prompts from disk")
    else:
        hf_root = "hf://datasets/cais/mmlu/"
        train_raw = pd.read_parquet(hf_root + 'all/auxiliary_train-00000-of-00001.parquet')
        val_raw = pd.read_parquet(hf_root + 'all/validation-00000-of-00001.parquet')
        test_raw = pd.read_parquet(hf_root + 'all/test-00000-of-00001.parquet')

        letters = ["A", "B", "C", "D"]
        # NOTE(review): this is a plain (non-f) string, so the prompt contains
        # the literal text \boxed{{}} with doubled braces -- confirm that is
        # what the model is meant to see.
        instruction = (
            "Return your final response within \\boxed{{}} and only include the letter choice (A, B, C, or D) as your final response. "
        )

        records = []
        for category, frame, keep in (
            ('train', train_raw, 5000),
            ('val', val_raw, 1000),
            ('test', test_raw, 1000),
        ):
            # Stratified down-sample; the remainder of the split is discarded.
            sampled = train_test_split(
                frame, train_size=keep, random_state=0, stratify=frame['subject'].values
            )[0]
            for _, row in sampled.iterrows():
                choices_str = ", ".join(
                    f"{letter}) {choice}" for letter, choice in zip(letters, row['choices'])
                )
                records.append({
                    # Rows are only ever appended, so the current length is the
                    # next running id.
                    "unique_id": len(records),
                    "problem": instruction + row['question'] + "\n" + choices_str,
                    "subject": row['subject'],
                    "answer": letters[row['answer']],
                    "category": category,
                })

        pd.DataFrame(records).to_csv(csv_path, index=False, header=True)

    mmlu = pd.read_csv(csv_path)
    model = vllmClientModel(model_id, model_url, "token-abc123")

    output_dir = os.path.join(MMLU_DIR, MODEL_IDS[model_id], "response")
    os.makedirs(output_dir, exist_ok=True)

    total_problems = mmlu.shape[0]
    pbar = tqdm.tqdm(total=total_problems)
    # Unseeded shuffle: the visiting order differs from run to run.
    mmlu = mmlu.sample(frac=1).reset_index(drop=True)
    for problem_counter, (_, row) in enumerate(mmlu.iterrows(), start=1):
        unique_id = row['unique_id']
        pbar.update(1)
        pbar.set_description(f"Problem {problem_counter}/{total_problems}; unique id {unique_id}")

        output_fpath = os.path.join(output_dir, f"{unique_id}.json")
        if os.path.exists(output_fpath):
            # A response file is already present for this problem.
            continue

        completions = model.generate(
            prompt=model.prepare_prompt(row['problem']),
            max_tokens=MMLU_MAX_LEN,
            temperature=0.6,
            top_p=0.95,
            n=MMLU_NUM_CHAINS,
        )

        response_strings = [choice.text for choice in completions.choices]
        with open(output_fpath, 'w') as f:
            json.dump(response_strings, f)
    pbar.close()


def extract_gsm8k_answer(ans_raw):
    """Return the number that follows the last ``####`` marker as a float.

    The text after the final ``####`` (or the whole string when the marker is
    absent) is stripped of surrounding whitespace and of comma separators
    before conversion.

    Raises:
        ValueError: if the remaining text is not a valid float literal.
    """
    _, _, tail = ans_raw.rpartition("####")
    digits = tail.strip().replace(",", "")
    return float(digits)


def profile_gsm8k(model_id, model_url):
    """Sample responses for GSM8K problems, one JSON file per problem.

    When ``gsm8k.csv`` is missing from ``GSM8K_DIR`` it is built from the
    ``socratic`` train and test parquet files: rows are tagged with a
    ``category`` of ``train`` / ``test``, numbered with a running ``unique_id``,
    the original columns are renamed (``question`` -> ``problem``,
    ``answer`` -> ``answer_raw``) and a numeric ``answer`` column is derived
    with ``extract_gsm8k_answer``.

    The CSV rows are then visited in file order; a problem whose
    ``<unique_id>.json`` already exists is skipped, otherwise
    ``GSM8K_NUM_CHAINS`` completions are requested and their texts are dumped
    as a JSON list.

    Args:
        model_id: Model identifier handed to ``vllmClientModel``; also the key
            looked up in ``MODEL_IDS`` to name the output directory.
        model_url: Endpoint URL handed to ``vllmClientModel``.
    """
    os.makedirs(GSM8K_DIR, exist_ok=True)
    csv_path = os.path.join(GSM8K_DIR, 'gsm8k.csv')

    if not os.path.isfile(csv_path):
        print("Downloading and processing GSM8K dataset from Hugging Face...")
        hf_root = "hf://datasets/openai/gsm8k/"
        train_part = pd.read_parquet(hf_root + 'socratic/train-00000-of-00001.parquet')
        test_part = pd.read_parquet(hf_root + 'socratic/test-00000-of-00001.parquet')

        merged = pd.concat(
            [train_part.assign(category='train'), test_part.assign(category='test')],
            axis=0,
        )
        merged = merged.rename(columns={'answer': 'answer_raw', 'question': 'problem'})
        # Running id over train rows first, then test rows.
        merged['unique_id'] = list(range(merged.shape[0]))
        merged['answer'] = merged['answer_raw'].map(extract_gsm8k_answer)
        merged.to_csv(csv_path, index=False, header=True)

    problems = pd.read_csv(csv_path)
    model = vllmClientModel(model_id, model_url, "token-abc123")

    output_dir = os.path.join(GSM8K_DIR, MODEL_IDS[model_id], "response")
    os.makedirs(output_dir, exist_ok=True)

    pbar = tqdm.tqdm(total=problems.shape[0])
    for _, row in problems.iterrows():
        pbar.update(1)
        unique_id = row['unique_id']
        output_fpath = os.path.join(output_dir, f"{unique_id}.json")
        pbar.set_description(f"Processing problem {unique_id}")
        if os.path.exists(output_fpath):
            # A response file is already present for this problem.
            continue

        completions = model.generate(
            prompt=model.prepare_prompt(row['problem']),
            max_tokens=GSM8K_MAX_LEN,
            temperature=0.6,
            top_p=0.95,
            n=GSM8K_NUM_CHAINS,
        )

        response_strings = [choice.text for choice in completions.choices]
        with open(output_fpath, 'w') as f:
            json.dump(response_strings, f)

    pbar.close()


def profile_olympiad(model_id, model_url):
    """Sample responses for OlympiadBench problems, one JSON file per problem.

    When ``olympiad.csv`` is missing from ``OLYMPIAD_DIR`` it is built from the
    benchmark's ``test.parquet``: the ``id``, ``question``, ``final_answer`` and
    ``subfield`` columns are kept (the first three renamed to ``unique_id``,
    ``problem`` and ``answer``), the rows are split 80/20 with
    ``random_state=0`` and tagged with a ``category`` of ``train`` / ``test``.

    The CSV rows are then visited in shuffled order; a problem whose
    ``<unique_id>.json`` already exists is skipped, otherwise
    ``OLYMPIAD_NUM_CHAINS`` completions are requested and their texts are
    dumped as a JSON list.

    NOTE: earlier revisions of this function wrote the training label as
    ``'trian'``. An ``olympiad.csv`` cached by such a revision still carries
    that spelling; delete the file to regenerate it with the corrected label.

    Args:
        model_id: Model identifier handed to ``vllmClientModel``; also the key
            looked up in ``MODEL_IDS`` to name the output directory.
        model_url: Endpoint URL handed to ``vllmClientModel``.
    """
    os.makedirs(OLYMPIAD_DIR, exist_ok=True)
    csv_path = os.path.join(OLYMPIAD_DIR, 'olympiad.csv')

    if not os.path.isfile(csv_path):
        df = pd.read_parquet("hf://datasets/math-ai/olympiadbench/test.parquet")
        # Select and rename in one expression so no column slice is mutated
        # in place.
        df = df[['id', 'question', 'final_answer', 'subfield']].rename(
            columns={'id': 'unique_id', 'question': 'problem', 'final_answer': 'answer'}
        )
        df_train, df_test = train_test_split(df, train_size=0.8, random_state=0)
        df_train = df_train.copy()
        # Fixed label: this was previously misspelled as 'trian'.
        df_train['category'] = 'train'
        df_test = df_test.copy()
        df_test['category'] = 'test'
        df = pd.concat([df_train, df_test], axis=0)
        df.to_csv(csv_path, index=False, header=True)
    else:
        print("Found cached olympiad dataset from disk")

    df = pd.read_csv(csv_path)
    model = vllmClientModel(
        model_id,
        model_url,
        "token-abc123")

    output_dir = os.path.join(OLYMPIAD_DIR, MODEL_IDS[model_id], "response")
    os.makedirs(output_dir, exist_ok=True)
    # Unseeded shuffle: the visiting order differs from run to run.
    df = df.sample(frac=1).reset_index(drop=True)

    pbar = tqdm.tqdm(total=df.shape[0])
    for _, row in df.iterrows():
        pbar.update(1)
        unique_id = row['unique_id']
        output_fpath = os.path.join(output_dir, f"{unique_id}.json")
        pbar.set_description(f"Processing problem {unique_id}")
        if not os.path.exists(output_fpath):
            problem = row['problem']
            problem_prompt = model.prepare_prompt(problem)

            completions = model.generate(
                prompt=problem_prompt,
                max_tokens=OLYMPIAD_MAX_LEN,
                temperature=0.6,
                top_p=0.95,
                n=OLYMPIAD_NUM_CHAINS,
            )

            response_strings = [choice.text for choice in completions.choices]
            with open(output_fpath, 'w') as f:
                json.dump(response_strings, f)

    pbar.close()


if __name__ == '__main__':
    # Example invocation:
    # python3 profile_prompts.py -d mmlu -m "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" -u "http://localhost:30000/v1"

    # Dataset name -> profiling routine. Listing the keys as argparse choices
    # makes an unrecognised dataset a usage error instead of a silent no-op.
    profilers = {
        "math": profile_math,
        "mmlu": profile_mmlu,
        "gsm8k": profile_gsm8k,
        "olympiad": profile_olympiad,
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, choices=list(profilers),
        help="benchmark to profile")
    parser.add_argument(
        "-m", "--model", type=str, required=False,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="model identifier passed to the client")
    parser.add_argument(
        "-u", "--url", type=str, required=False,
        default="http://localhost:30000/v1",
        help="endpoint URL passed to the client")
    args = parser.parse_args()

    profilers[args.dataset](model_id=args.model, model_url=args.url)
