import sys
sys.path.insert(0, __file__.rsplit("/", 2)[0])
sys.path.insert(0, __file__.rsplit("/", 2)[0] + "/utils")

import os
import json
import argparse

import torch
import transformers
from tqdm import tqdm
from datasets import load_dataset, load_from_disk

from tts_evaluator import TTSEvaluator
from utils.answer_processing import find_last_valid_expression, check_equality_judge, check_equality_local_model
from utils.gpu_parallel import get_worker_rank, init_worker_logger
from utils.task_queue import TaskQueue

if "NV_YT_OPERATION_ID" in os.environ:
    import nirvana_dl

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--queue",
        type=str,
        default=None,
        help="Endpoint for a zmq-like task dispenser that dispenses task indices. Provide *either* this or start & end"
    )
    parser.add_argument(
        "--start",
        type=int,
        default=None,
        help="First task to be processed by script inclusive. E.g --start 0 --end 100 will process tasks [0-99]"
    )
    parser.add_argument(
        "--end",
        type=int,
        default=None,
        help="Last task to be processed by script exclusive. E.g --start 0 --end 100 will process tasks [0-99]"
    )
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=["async_reasoning", "baseline_think", "baseline_no_think"],
        help="Select reasoning mode",
    )
    parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-32B", help="Model name from hf")
    parser.add_argument("--budget", type=int, default=16384, help="Budget to eval on")
    parser.add_argument("--use-slow-kernel", action="store_true", default=False, help="Disable fast kernel")
    parser.add_argument("--use-local-judge", action="store_true", default=False, help="Use the same model as a judge for result.")
    parser.add_argument("--dataset_path", type=str, required=True,
                        help="sharded math500 dataset path for load_from_disk")
    parser.add_argument("--path-to-results", type=str, help="path to store exp results", default="./eval_results/math-500")
    parser.add_argument("--dump_snapshot_freq", type=int, default=4, help="yandex-internal snapshotting frequency")
    parser.add_argument("--device_map", type=str, default="auto", help="passed to model.from_pretrained")
    parser.add_argument("--next_shard_every_steps", type=int, help="Setting to set up shards appearance frequency. Exceptions are: 0 -- concat, -1 -- never supply the rest of the shards.")
    parser.add_argument(
        "--shard_to_target",
        nargs="+",
        choices=["thinker", "writer", "input"],
        default=None,
        help='Where to share live context. Use: --shard_to_target input | thinker | writer',
    )
    parser.add_argument(
        "--target_reminders",
        nargs="+",
        choices=["thinker", "writer"],
        default=[],
        help='Which of shard_to_target are reminders. Use: --target_reminders thinker | writer',
    )

    parser.add_argument(
        "--shard_wait_step", action="store_true", default=False, help='Wait for \\n\\n before inserting shard?',
    )
    return parser.parse_args()


def main():
    args = parse_args()
    rank = get_worker_rank()
    logger = init_worker_logger()
    logger.info(f'The script was run in the following way:')
    logger.info(f"python {__file__} \\\n" + "\n".join(f"\t\t--{k} {v} \\" for k, v in vars(args).items()))
    use_fast_kernel = not args.use_slow_kernel
    assert (not args.use_local_judge) or (not use_fast_kernel), "You cannot use local model with kernel as a judge"
    for reminder in args.target_reminders:
        assert reminder not in args.shard_to_target, f"{reminder} is both in shard_to_target and target_reminders, choose one "
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.model_name, torch_dtype='auto', device_map=args.device_map, low_cpu_mem_usage=True
    )

    solver_kwargs = {}
    if args.mode in ["async_reasoning"]:
        from async_reasoning.solver import AsyncReasoningSolver as Solver
        from async_reasoning.async_input_hook import async_input_hook_constructor

        system_tokens = [key for key in tokenizer.vocab.keys() if key.endswith("SYSTEM") or key.endswith("SYSTEM:")]
        writer_forbidden_token_ix = [tokenizer.vocab[x] for x in ["</think>", "<|im_start|>", "<|endoftext|>"] + system_tokens]
        thinker_forbidden_token_ix = [tokenizer.vocab[x] for x in ["<|im_start|>", "<|im_end|>", "<|endoftext|>"] + system_tokens]
        end_of_think_token_ix = [tokenizer.vocab[x] for x in ["</think>"]]
        solver_kwargs.update({
            "writer_forbidden_token_ix": writer_forbidden_token_ix,
            "thinker_forbidden_token_ix": thinker_forbidden_token_ix,
            "use_fast_kernel": use_fast_kernel,
            "end_of_think_token_ix": end_of_think_token_ix,
        })
    elif args.mode in ["baseline_think", "baseline_no_think"]:
        assert args.next_shard_every_steps is None, "async_input mode does not work in baselines."
        assert args.shard_to_target is None, "async_input mode does not work in baselines."
        assert args.shard_wait_step is [], "async_input mode does not work in baselines."
        from evals.baseline_solver import BaselineSolver as Solver
        solver_kwargs.update({
            "thinker_enabled": (args.mode == "baseline_think"),
        })
    else:
        raise ValueError("unsupported mode")

    solver = Solver(model, tokenizer, **solver_kwargs)
    dataset_math = load_from_disk(args.dataset_path)
    accuracy_numerator = accuracy_denominator = 0
    exp_dir_path = f"{args.path_to_results}/math-500_sharded_{args.next_shard_every_steps}_steps/{args.mode}"
    os.makedirs(exp_dir_path, exist_ok=True)
    evaluator = TTSEvaluator()

    def _solve_task_and_save(idx: int):
        save_path = f"{exp_dir_path}/sample_{idx}.json"
        if os.path.exists(save_path):  
            return  # already solved by previous run and saved in snapshot

        nonlocal accuracy_numerator, accuracy_denominator
        problem_shards = dataset_math[idx]['problem_shards']
        answer = str(dataset_math[idx]['answer'])
        assert len(problem_shards) >= 2, f"Unexpected number of shards on id: {idx}, {len(problem_shards)}"
        instruction = "".join(problem_shards) if args.next_shard_every_steps == 0 else problem_shards[0]
        problem = f"Please reason step by step, and put your final answer within \\boxed{{}}.\n\n{instruction}"

        writer_output_str, thinker_output_str, token_times, eos_generated = \
            solver.solve(
                problem, 
                budget=args.budget,  
                on_new_tokens_generated=async_input_hook_constructor(
                    solver,
                    args.shard_to_target,
                    args.target_reminders,
                    args.next_shard_every_steps,
                    problem_shards[1:],
                    args.shard_wait_step,
                )
            )
        response = find_last_valid_expression(writer_output_str, extract_result=lambda x: x[7:-1])
        assert len(token_times) > 0

        if args.use_local_judge:
            is_equal = check_equality_local_model(model, tokenizer, response, answer)
        else:
            is_equal = check_equality_judge(response, answer)

        chunks = evaluator.get_chunks_with_tts(token_times[:-1] if eos_generated else token_times, k_chunks=5, return_audio=False)
        metrics = evaluator(**chunks, add_tts_in_parrallel=True, return_delays=False)
        total_delay = metrics["total_delay"]
        result = {
            "idx": idx,
            "is_equal": is_equal,
            "metrics": metrics,
            "token_times": token_times,
            "eos_generated": eos_generated,
            "response_answers": response,
            "correct_answer": answer,
            "writer_response": writer_output_str,
            "thinker_response": thinker_output_str,
        }
        accuracy_numerator += int(is_equal)
        accuracy_denominator += 1
        current_accuracy = (accuracy_numerator / accuracy_denominator)
        print(end=f'[{rank=}] {idx=}, {eos_generated=}, {is_equal=}, {total_delay=:.3f}\t| {current_accuracy=:.3f}', file=sys.stderr)
        with open(save_path, "w") as f:
            json.dump(result, f, indent=2)
        if "NV_YT_OPERATION_ID" in os.environ and rank == 0 and (
                accuracy_denominator % args.dump_snapshot_freq == args.dump_snapshot_freq - 1):
            nirvana_dl.snapshot.dump_snapshot()
            logger.info("Dumped Nirvana snapshot")

    if args.start is not None and args.end is not None:
        logger.info(f'Generating tasks [{args.start}; {args.end})')
        for idx in tqdm(range(args.start, args.end), desc=f'Process {rank}'):
            _solve_task_and_save(idx)
    elif args.queue is not None:
        logger.info(f'Generating tasks from {args.queue}')
        for idx in tqdm(TaskQueue.iterate_tasks_from_queue(endpoint=args.queue), desc=f"Process {rank}"):
            _solve_task_and_save(idx)
    else:
        raise NotImplementedError("Please specify either --queue or both --start and --end")
    logger.info(f'Process {rank} has finished.')

if __name__ == "__main__":
    main()
