#!/usr/bin/env python
# -*- coding: utf-8 -*-
'''
vLLM Solver Server with Sleep Mode Support

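This server supports vLLM sleep/wake-up mode so it can time-share GPU memory with the trainer.
At sleep level 1 (used at startup and by default on /sleep), the model weights are offloaded
to CPU RAM and the KV cache is discarded, freeing GPU memory; level 2 discards both.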

Setup Instructions:
    pip install stopit
    (the server also imports flask, vllm, torch, transformers and mathruler)

Usage:
    python start_vllm_server.py --port 5000 --model_path Qwen/Qwen3-4B-Base
'''

from flask import Flask, request, jsonify
import vllm
import argparse
import json
import os
import time
import torch
from transformers import AutoTokenizer
from mathruler.grader import extract_boxed_content, grade_answer
import stopit

# ------------------------- Command-Line Arguments ------------------------- #
parser = argparse.ArgumentParser(
    description='vLLM solver server with sleep-mode support.')
# The port is parsed as an int so that an invalid value is rejected here,
# before the (slow, GPU-heavy) model load, instead of crashing at app.run().
parser.add_argument('--port', type=int, default=5000,
                    help='TCP port the Flask server listens on.')
parser.add_argument('--model_path', type=str, default='Qwen/Qwen3-4B-Base',
                    help='Model name or path; used for both the vLLM model and the tokenizer.')
parser.add_argument('--gpu_mem_util', type=float, default=0.9,
                    help='The maximum GPU memory utilization fraction for vLLM.')
parser.add_argument('--no_sleep', action='store_true', default=False,
                    help='If set, do NOT enter sleep mode after loading (default: enter sleep)')
args = parser.parse_args()

# ------------------------- vLLM Initialization with Sleep Mode ------------------------ #
print('[init] Loading model with sleep mode enabled...')

tokenizer = AutoTokenizer.from_pretrained(args.model_path)

# One path serves as both the model and the tokenizer source. enable_sleep_mode
# is what allows the model.sleep()/wake_up() calls below and in the
# /sleep and /wake_up endpoints.
model = vllm.LLM(
    model=args.model_path,
    tokenizer=args.model_path,
    enable_sleep_mode=True,
    gpu_memory_utilization=args.gpu_mem_util,
)

# Sampling configuration shared by every /hello request: n=10 independent
# candidate answers per question, temperature 1.0 with top-k 40, at most 4096
# new tokens each, stopping on the tokenizer's EOS token.
sample_params = vllm.SamplingParams(
    n=10,
    temperature=1.0,
    top_p=1.0,
    top_k=40,
    max_tokens=4096,
    stop_token_ids=[tokenizer.eos_token_id],
)

print('[init] Model loaded successfully.')

# Unless --no_sleep was passed, go to sleep (level 1) right away so the GPU is
# free for the trainer until a client calls /wake_up.
if args.no_sleep:
    print('[init] Sleep mode disabled, model remains active.')
else:
    print('[init] Entering sleep mode to release GPU memory...')
    model.sleep(level=1)
    print('[init] Model is now sleeping. GPU memory released.')

# ------------------------ Timeout Utility --------------------------- #
@stopit.threading_timeoutable(default='TIMED_OUT')
def grade_answer_with_timeout(res1, res2):
    """Compare two answers with mathruler's `grade_answer`, bounded by a timeout.

    The stopit decorator adds a `timeout=<seconds>` keyword argument. When the
    call does not finish in time, the string 'TIMED_OUT' is returned instead
    of the grader's verdict.
    """
    verdict = grade_answer(res1, res2)
    return verdict

# ---------------------------- Flask Application --------------------------- #
# Flask application object; the route handlers below register themselves on it.
app = Flask(import_name=__name__)


@app.route('/sleep', methods=['POST'])
def sleep_endpoint():
    """
    Put the vLLM model to sleep, releasing GPU memory.

    Query params:
        level: Sleep level (1 or 2, default 1)
            - Level 1: Offload weights to CPU, discard KV cache
            - Level 2: Discard both weights and KV cache

    Returns:
        200 {'status': 'sleeping', 'level': level} on success,
        400 {'error': ...} if `level` is not exactly 1 or 2,
        500 {'error': ...} if vLLM raises while going to sleep.
    """
    raw_level = request.args.get('level', '1')
    # Parse explicitly instead of using `request.args.get(..., type=int)`.
    # That form silently falls back to the default on a malformed value, so
    # e.g. `?level=2x` would have slept at level 1 without telling the caller.
    try:
        level = int(raw_level)
    except ValueError:
        level = None
    if level not in (1, 2):
        print(f'[server] Rejected sleep request with invalid level {raw_level!r}')
        return jsonify({'error': f'invalid sleep level {raw_level!r}; expected 1 or 2'}), 400

    try:
        model.sleep(level=level)
        torch.cuda.empty_cache()
        print(f'[server] Model entered sleep mode (level={level})')
        return jsonify({'status': 'sleeping', 'level': level})
    except Exception as e:
        print(f'[server] Sleep failed: {e}')
        return jsonify({'error': str(e)}), 500


@app.route('/wake_up', methods=['POST'])
def wake_up_endpoint():
    """
    Wake up the vLLM model, restoring GPU memory allocation.

    Query params:
        tags: Optional and repeatable (``?tags=weights&tags=kv_cache``). Lists
              the components to wake up (weights, kv_cache).
              If not specified, wakes up all components.

    Returns:
        200 {'status': 'awake', 'tags': <tags or 'all'>, 'elapsed': seconds}
        on success, or 500 {'error': message} if vLLM raises.
    """
    tags = request.args.getlist('tags')
    tags_desc = tags if tags else 'all'
    try:
        # perf_counter is monotonic. Unlike time.time(), the reported duration
        # cannot be distorted (or go negative) if the system clock is adjusted
        # during the wake-up.
        start_time = time.perf_counter()
        if tags:
            model.wake_up(tags=tags)
        else:
            model.wake_up()
        elapsed = time.perf_counter() - start_time
        print(f'[server] Model woke up (tags={tags_desc}) in {elapsed:.2f}s')
        return jsonify({'status': 'awake', 'tags': tags_desc, 'elapsed': elapsed})
    except Exception as e:
        print(f'[server] Wake up failed: {e}')
        return jsonify({'error': str(e)}), 500


@app.route('/is_sleeping', methods=['GET'])
def is_sleeping_endpoint():
    """Report whether the vLLM engine is currently asleep.

    Returns:
        200 {'is_sleeping': <engine answer>} on success, or
        500 {'error': message} if querying the engine raises; the traceback
        is printed as well.
    """
    try:
        # The query lives on the underlying engine (model.llm_engine), not on
        # the vllm.LLM wrapper itself.
        status = {'is_sleeping': model.llm_engine.is_sleeping()}
        return jsonify(status)
    except Exception as err:
        import traceback
        print(f'[server] is_sleeping check failed: {err}')
        traceback.print_exc()
        return jsonify({'error': str(err)}), 500


def _results_path(task_path):
    '''Return the path the results for task file `task_path` are written to.

    Only a trailing ".json" is replaced by "_results.json"
    ("/tmp/t.json" -> "/tmp/t_results.json"). A ".json" elsewhere in the path,
    e.g. in a directory name, is left untouched. Names without that suffix get
    "_results.json" appended.
    '''
    suffix = '.json'
    stem = task_path[:-len(suffix)] if task_path.endswith(suffix) else task_path
    return stem + '_results.json'


def _process_single(question, response):
    '''Consolidate the sampled vLLM completions for one question by majority vote.

    Args:
        question: The question text; copied into the result unchanged.
        response: One element returned by `model.generate(...)`. Each entry of
            `response.outputs` is a sampled completion exposing `.text`.

    Returns:
        A dict with keys:
          - 'question': the question text,
          - 'answer': the most frequent extracted answer ('' if there is none),
          - 'score': that answer's count divided by the number of completions
            (empty extractions count toward the denominator),
          - 'results': the extracted answer of every completion, in order.
    '''
    results = [extract_boxed_content(out.text) for out in response.outputs]
    # NOTE(review): empty extractions are skipped below. However, if
    # extract_boxed_content returns a non-empty placeholder for a completion
    # without a boxed answer, that placeholder is voted on like a real answer
    # -- confirm against mathruler.

    answer_counts = {}
    for res in results:
        if not res:
            continue
        matched = False

        # Only the value of an existing key is changed inside this loop (and we
        # break right after), so iterating the dict directly is safe.
        for exist_ans in answer_counts:
            # Cheap comparisons first. The 'no ' heuristic treats any two
            # answers that both contain "no " (e.g. "no solution") as equal.
            if res == exist_ans or ('no ' in res.lower() and 'no ' in exist_ans.lower()):
                answer_counts[exist_ans] += 1
                matched = True
                break

            # Expensive grade_answer calls, each bounded by a 10 s timeout and
            # tried in both argument orders.
            try:
                is_match = False
                match_result_1 = grade_answer_with_timeout(res, exist_ans, timeout=10)
                if match_result_1 == 'TIMED_OUT':
                    print(f"      [grader] TIMEOUT comparing '{res[:30]}...' with '{exist_ans[:30]}...'.")
                elif match_result_1:
                    is_match = True

                if not is_match:
                    match_result_2 = grade_answer_with_timeout(exist_ans, res, timeout=10)
                    if match_result_2 == 'TIMED_OUT':
                        print(f"      [grader] TIMEOUT comparing '{exist_ans[:30]}...' with '{res[:30]}...'. Skipping pair.")
                    elif match_result_2:
                        is_match = True

                if is_match:
                    answer_counts[exist_ans] += 1
                    matched = True
                    break

            except Exception as e:
                print(f"      [grader] ERROR comparing '{res[:30]}...' with '{exist_ans[:30]}...': {e}. Skipping.")
                continue

        if not matched:
            answer_counts[res] = 1

    if not answer_counts:
        majority_ans, max_count = '', 0
    else:
        majority_ans = max(answer_counts, key=answer_counts.get)
        max_count = answer_counts[majority_ans]

    score = max_count / len(results) if results else 0.0

    return {
        'question': question,
        'answer':   majority_ans,
        'score':    score,
        'results':  results
    }


@app.route('/hello', methods=['GET'])
def hello():
    '''
    The main processing endpoint: reads a task file, invokes vLLM,
    consolidates answers, and writes results.

    Query params:
        name: Path of a JSON task file on this machine, containing a list of
              objects with 'question' and 'answer' keys. The file is deleted
              once it has been read.

    Results are written as a JSON list, one entry per input item in input
    order, to `_results_path(name)` (a trailing ".json" becomes
    "_results.json"). Items lacking a question or an answer are not sent to
    vLLM and get score -1. Note that the reference answer only decides whether
    an item is processed; the reported 'answer' is the model's majority vote.

    Returns:
        200 {'message': ...} naming the results file, or
        400 {'error': ...} when `name` is missing or empty.

    Note: The model must be awake before calling this endpoint.
    Use /wake_up first if the model is sleeping.
    '''
    name = request.args.get('name')
    if not name:
        # Previously this fell back to the literal path 'None' and failed with
        # an opaque FileNotFoundError / 500.
        print('[server] Rejected /hello request without a task file name.')
        return jsonify({'error': 'missing required query parameter: name'}), 400
    print(f'[server] Received request for task file: {name}')

    # ---------- Load Data ----------
    # NOTE(review): `name` is a client-supplied path that is read and then
    # deleted, and results are written next to it. Since the server listens on
    # 0.0.0.0, any network client can make it read or delete files it has
    # access to. Restrict to a trusted directory or bind to localhost unless
    # the network is trusted.
    with open(name, 'r', encoding='utf-8') as f:
        data = json.load(f)
    os.remove(name)

    questions = [item.get('question', '') for item in data]
    answers   = [item.get('answer',   '') for item in data]

    # Only items with both a question and a reference answer are sent to vLLM.
    valid_chats = [
        [
            {'role': 'system', 'content': 'Please reason step by step, and put your final answer within \\boxed{}.'},
            {'role': 'user',   'content': q},
        ]
        for q, a in zip(questions, answers)
        if q and a
    ]
    print('[server] Valid chat prompts have been prepared.')

    # ---------- vLLM Generation ----------
    if valid_chats:
        if tokenizer.chat_template:
            prompts = [
                tokenizer.apply_chat_template(chat, tokenize=False,
                                              add_generation_prompt=True, add_special_tokens=True)
                for chat in valid_chats
            ]
        else:
            # No chat template (e.g. a base model): fall back to a plain
            # "role: content" prompt.
            prompts = [
                'system: ' + chat[0]['content'] + '\n' + 'user: ' + chat[1]['content']
                for chat in valid_chats
            ]
        responses = model.generate(prompts, sampling_params=sample_params, use_tqdm=True)
    else:
        responses = []
    print('[server] Generation completed.')

    # ---------- Results Post-Processing ----------
    # `responses` follows the order of the valid items, so a separate cursor
    # walks it while iterating over all items.
    results_all = []
    response_idx = 0
    for q, a in zip(questions, answers):
        try:
            if q and a:
                response = responses[response_idx]
                response_idx += 1
                item = _process_single(q, response)
                results_all.append(item)
            else:
                results_all.append({'question': q, 'answer': a, 'score': -1, 'results': []})
        except Exception as e:
            print(f'[server] CRITICAL: An unhandled error occurred while processing question: {q}')
            print(f'[server] Error details: {e}')
            results_all.append({
                'question': q,
                'answer':   a,
                'score':    -1,
                'results':  [],
                'error':    f'unhandled exception in process_single: {str(e)}'
            })
    print('[server] All results have been processed.')

    out_path = _results_path(name)
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(results_all, f, indent=4)

    print(f'[server] Processed {name}, results saved to {out_path}.')
    return jsonify({'message': f'Processed {name}, results saved to {out_path}.'})


# ------------------------- Main Application Entrypoint --------------------------- #
if __name__ == '__main__':
    sleep_mode_desc = 'disabled' if args.no_sleep else 'enabled (default)'
    print(f'[main] Starting Solver vLLM server on port {args.port}')
    print(f'[main] Sleep mode: {sleep_mode_desc}')
    # Listen on all interfaces. threaded=True lets Flask serve requests on
    # separate threads (e.g. /is_sleeping while a long /hello is running).
    # NOTE(review): all threads share the single `model`, and concurrent
    # /hello, /sleep or /wake_up calls are not serialized here -- confirm that
    # callers never overlap them.
    app.run(threaded=True, host='0.0.0.0', port=int(args.port))
