import argparse
import json
import numpy as np
import os
import pandas as pd
import random
import time
import torch
import tqdm
from client import vllmClientModel
from config import (
    GPQA_DIR, GPQA_MAX_LEN, GPQA_NUM_CHAINS,
    MATH_DIR, MATH_MAX_LEN, MATH_NUM_CHAINS,
    MMLU_DIR, MMLU_MAX_LEN, MMLU_NUM_CHAINS,
    GSM8K_DIR, GSM8K_MAX_LEN, GSM8K_NUM_CHAINS,
    OLYMPIAD_DIR, OLYMPIAD_MAX_LEN, OLYMPIAD_NUM_CHAINS,
    MODEL_IDS,
)
from config import *
from evaluator import extract_answer, extract_first_boxed_answer
from math_answer import MathAnswer
from sklearn.model_selection import train_test_split
from utils import process_math_id

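# Example command for launching the vLLM server this script sends requests to
# (pass its URL via --model_url):
#   vllm serve deepseek-ai/DeepSeek-R1-Distill-Llama-8B -tp 1 --enable-prefix-caching --port 30000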


# Values accepted by --dataset (each has a CSV mapping in _load_test_split).
_SUPPORTED_DATASETS = ("math", "mmlu", "gsm8k")

# Sampling settings shared by every request this script issues.
_TEMPERATURE = 0.6
_TOP_P = 0.95

# Text appended to a chain's prefix to prompt a boxed final answer, and the
# number of tokens generated for that probe.
_PROBE_SUFFIX = "**Final Answer**\n\n\\[ \\boxed{"
_PROBE_MAX_TOKENS = 10


def _timed_generate(model, prompts, max_tokens, is_actives, extra_body):
    """Send one ``generate_batch`` request and measure its wall-clock duration.

    Args:
        model: client exposing ``generate_batch`` (``vllmClientModel`` here).
        prompts: list of prompt strings, one per chain.
        max_tokens: per-prompt token budgets (a list); probe calls pass a
            single int instead.
        is_actives: per-prompt flags forwarded unchanged to ``generate_batch``.
        extra_body: extra request fields forwarded unchanged.

    Returns:
        ``(completions, seconds)``: the value returned by ``generate_batch``
        and the elapsed wall-clock time of the call in seconds.
    """
    # The model runs in a separate server process (--model_url), so the latency
    # of interest is wall-clock time around the request. CUDA events on this
    # client's idle stream only approximated it and required a local GPU.
    start = time.perf_counter()
    completions = model.generate_batch(
        prompts=prompts,
        max_tokens=max_tokens,
        temperatures=_TEMPERATURE,
        top_p=_TOP_P,
        n=1,
        is_actives=is_actives,
        extra_body=extra_body,
    )
    return completions, time.perf_counter() - start


def _extend_prefixes(prefixes, completions):
    """Append each completion's first choice text to its prefix.

    A ``None`` completion leaves the prefix unchanged. Presumably
    ``generate_batch`` returns ``None`` for inactive chains; verify against
    the client.
    """
    return [
        prefix if completion is None else prefix + completion.choices[0].text
        for prefix, completion in zip(prefixes, completions)
    ]


def _dynasor_step_sizes(max_tokens, interval, step):
    """Return the tokens each chain decodes in chunk ``step`` (0-based).

    Each chain has already received ``interval * step`` tokens. It gets the
    rest of its budget, capped at one interval and never negative.
    """
    consumed = interval * step
    return [max(0, min(interval, mt - consumed)) for mt in max_tokens]


def _shortm_capped_max_tokens(max_tokens, m):
    """Cap every chain's budget at the m-th smallest recorded budget.

    ``m < 1`` is read as a fraction of the chains, otherwise as a count. The
    resulting rank is clamped to ``[1, len(max_tokens)]``, so a fraction that
    rounds down to 0 selects the shortest chain rather than wrapping to
    index -1 (the longest).
    """
    num_chains = len(max_tokens)
    if num_chains == 0:
        return []
    if m < 1:
        m_value = int(m * num_chains)
    else:
        m_value = int(min(m, num_chains))
    m_value = max(1, m_value)
    cap = sorted(max_tokens)[m_value - 1]
    return [min(mt, cap) for mt in max_tokens]


def _load_test_split(dataset):
    """Load the CSV for ``dataset`` and keep only its test rows.

    The test split is ``train == 0`` when a ``train`` column exists, otherwise
    ``category == 'test'``.

    Raises:
        KeyError: if ``dataset`` is not one of ``_SUPPORTED_DATASETS``.
        SystemExit: (status 1) if the CSV has neither column.
    """
    data_dir, filename = {
        "math": (MATH_DIR, "math3k.csv"),
        "mmlu": (MMLU_DIR, "mmlu.csv"),
        "gsm8k": (GSM8K_DIR, "gsm8k.csv"),
    }[dataset]
    csv_path = os.path.join(data_dir, filename)
    df = pd.read_csv(csv_path)
    if 'train' in df.columns:
        return df[df['train'] == 0]
    if 'category' in df.columns:
        return df[df['category'] == 'test']
    raise SystemExit(f"{csv_path}: no 'train' or 'category' column to select the test split")


def _run_default(model, problem_prompt, trace, args):
    """Decode every chain to its full recorded budget in a single request."""
    max_tokens = trace['max_tokens']
    num_chains = len(max_tokens)
    # ignore_eos presumably makes each chain emit its whole recorded budget.
    _, latency = _timed_generate(
        model, [problem_prompt] * num_chains, max_tokens,
        [True] * num_chains, {"ignore_eos": True})
    return {'latency_default': latency}


def _run_dynasor(model, problem_prompt, trace, args):
    """Decode in chunks of ``args.interval`` tokens and probe chains after each chunk.

    Returns the summed decode and probe latencies (seconds) separately.
    """
    max_tokens = trace['max_tokens']
    interval = args.interval
    num_chains = len(max_tokens)
    num_steps = int(np.ceil(max(max_tokens, default=0) / interval))
    prefixes = [problem_prompt] * num_chains
    latency_decode = latency_probe = 0.0
    for step in range(num_steps):
        step_sizes = _dynasor_step_sizes(max_tokens, interval, step)
        # Chains whose budget is already exhausted decode nothing this step.
        completions, elapsed = _timed_generate(
            model, prefixes, step_sizes,
            [size > 0 for size in step_sizes], {"ignore_eos": True})
        latency_decode += elapsed
        prefixes = _extend_prefixes(prefixes, completions)

        # Probe only chains whose budget covers the whole interval just decoded.
        curr_budget = interval * (step + 1)
        _, elapsed = _timed_generate(
            model, [prefix + _PROBE_SUFFIX for prefix in prefixes],
            _PROBE_MAX_TOKENS,
            [curr_budget <= mt for mt in max_tokens], {})
        latency_probe += elapsed
    return {'latency_dynasor': latency_decode, 'latency_probe': latency_probe}


def _run_shortm(model, problem_prompt, trace, args):
    """Decode all chains in one request, budgets capped by ``--m`` (see _shortm_capped_max_tokens)."""
    capped_max_tokens = _shortm_capped_max_tokens(trace['max_tokens'], args.m)
    num_chains = len(capped_max_tokens)
    _, latency = _timed_generate(
        model, [problem_prompt] * num_chains, capped_max_tokens,
        [True] * num_chains, {"ignore_eos": True})
    return {'latency_shortm': latency}


def _run_duchess(model, problem_prompt, trace, args):
    """Replay the trace's recorded iterations: decode, probe, then branch out.

    Each entry of ``trace['iterations']`` supplies ``step_sizes`` (per-chain
    decode budget), ``probed_chains`` (chain indices to probe) and
    ``branch_outs`` (``(target, source)`` pairs whose prefix is copied).
    """
    num_chains = len(trace['max_tokens'])
    prefixes = [problem_prompt] * num_chains
    is_actives = [True] * num_chains
    latency_decode = latency_probe = 0.0
    for iteration_log in trace['iterations']:
        completions, elapsed = _timed_generate(
            model, prefixes, iteration_log['step_sizes'],
            is_actives, {"ignore_eos": True})
        latency_decode += elapsed
        prefixes = _extend_prefixes(prefixes, completions)

        probe_queries = [prefixes[i] + _PROBE_SUFFIX for i in iteration_log['probed_chains']]
        # Skip the request entirely when nothing is probed this iteration;
        # otherwise the active flags must match the probe batch, not all chains.
        if probe_queries:
            _, elapsed = _timed_generate(
                model, probe_queries, _PROBE_MAX_TOKENS,
                [True] * len(probe_queries), {})
            latency_probe += elapsed

        # Applied in recorded order: a later pair may read a prefix an earlier one set.
        for branch_target, branch_source in iteration_log['branch_outs']:
            prefixes[branch_target] = prefixes[branch_source]
    return {'latency_decode': latency_decode, 'latency_probe': latency_probe}


def _run_prefill(model, problem_prompt, trace, args):
    """Generate a single token per chain, isolating the prompt-processing cost."""
    num_chains = len(trace['max_tokens'])
    _, latency = _timed_generate(
        model, [problem_prompt] * num_chains, [1] * num_chains,
        [True] * num_chains, {"ignore_eos": True})
    return {'latency_prefill': latency}


# --mode value -> (runner, output CSV written into the trace directory).
_MODES = {
    "default": (_run_default, "latency_default.csv"),
    "dynasor": (_run_dynasor, "latency_dynasor.csv"),
    "shortm": (_run_shortm, "latency_shortm.csv"),
    "duchess": (_run_duchess, "latency_duchess_vllm.csv"),
    "prefill": (_run_prefill, "latency_prefill.csv"),
}


def main():
    """Replay recorded generation traces against a model server and save latencies.

    For every test-split problem of ``--dataset`` that has a ``<unique_id>.json``
    trace under ``<dataset dir>/<model dir>/latency_traces/<--trace_dir>``, the
    selected ``--mode`` re-issues the traced requests. The per-problem timings
    (seconds) are written to a mode-specific CSV in that same directory.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', type=str, required=True, choices=_SUPPORTED_DATASETS)
    parser.add_argument('--model_id', type=str, required=True)
    parser.add_argument('--trace_dir', type=str, required=True)
    parser.add_argument('--model_url', type=str, required=True)
    parser.add_argument('--mode', type=str, required=True, choices=list(_MODES))
    # shortm only: values < 1 are a fraction of the chains, otherwise a count.
    parser.add_argument('--m', type=float, required=False, default=3)
    # dynasor only: tokens decoded between probes.
    parser.add_argument('--interval', type=int, required=False, default=64)

    args = parser.parse_args()
    run_mode, output_csv = _MODES[args.mode]

    trace_dir_full = os.path.join(
        DSET_TO_DIR[args.dataset],
        MODEL_IDS[args.model_id],
        "latency_traces",
        args.trace_dir,
    )

    df = _load_test_split(args.dataset)
    uid_to_prompt = dict(zip(df['unique_id'], df['problem']))
    # Only problems with a recorded trace are replayed.
    uids = [u for u in df.unique_id.values if os.path.isfile(os.path.join(trace_dir_full, f"{u}.json"))]

    model = vllmClientModel(
        args.model_id,
        args.model_url,
        "token-abc123")

    results = []
    uid_iter = tqdm.tqdm(uids)
    for uid in uid_iter:
        problem_prompt = model.prepare_prompt(uid_to_prompt[uid])
        with open(os.path.join(trace_dir_full, f"{uid}.json"), 'r') as f:
            trace = json.load(f)
        row = {'uid': uid, **run_mode(model, problem_prompt, trace, args)}
        results.append(row)
        if args.mode == "dynasor":
            uid_iter.set_description(
                f"UID: {uid}, Latency: {row['latency_dynasor']:.2f}s, Latency Probe: {row['latency_probe']:.2f}s")

    pd.DataFrame(results).to_csv(
        os.path.join(trace_dir_full, output_csv),
        index=False, header=True)


if __name__ == "__main__":
    # Script entry point: main() returns None, so this exits with status 0
    # once the benchmark finishes.
    raise SystemExit(main())
