#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Question-to-Code vLLM Server with Sleep Mode Support

This server converts math questions into Python solver code.
It supports sleep/wake_up mode for GPU memory time-sharing with Trainer.

Usage:
    python vllm_service_init/start_vllm_server_code.py --port 6000 --model_path /path/to/models/Qwen2.5-Coder-7B-Instruct
"""

from flask import Flask, request, jsonify
import vllm
import argparse
import json
import os
import re
import time
import torch
from transformers import AutoTokenizer
from typing import List, Optional, Tuple

# Import the prompt from utils
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from utils.prompts import QUESTION_TO_CODE_PROMPT

# ------------------------- Command-Line Arguments ------------------------- #
parser = argparse.ArgumentParser()
# type=int validates the port at parse time, so a bad value fails immediately
# instead of only at app.run(), i.e. after the (slow) model load below.
parser.add_argument('--port', type=int, default=6000,
                    help='Port for the Flask server.')
parser.add_argument('--model_path', type=str, default='/path/to/models/Qwen2.5-Coder-7B-Instruct',
                    help='Model directory used for both the tokenizer and vLLM.')
parser.add_argument('--gpu_mem_util', type=float, default=0.9,
                    help='The maximum GPU memory utilization fraction for vLLM.')
parser.add_argument('--no_sleep', action='store_true', default=False,
                    help='If set, do NOT enter sleep mode after loading (default: enter sleep)')
args = parser.parse_args()
# Reject an impossible memory fraction before any model loading starts.
if not 0.0 < args.gpu_mem_util <= 1.0:
    parser.error(f'--gpu_mem_util must be in (0, 1], got {args.gpu_mem_util}')

# ------------------------- vLLM Initialization with Sleep Mode ------------------------ #
print(f'[init] Loading model from {args.model_path} with sleep mode enabled...')

tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
model = vllm.LLM(
    model=args.model_path,
    tokenizer=args.model_path,
    gpu_memory_utilization=args.gpu_mem_util,
    trust_remote_code=True,
    enable_sleep_mode=True,  # Enable sleep mode for GPU memory time-sharing
)

# Sampling parameters for code generation (greedy decoding, temperature 0).
# vLLM does not include the matched stop string in the returned text by default,
# so generated responses normally end right before "</CODE>".
sample_params = vllm.SamplingParams(
    max_tokens=1024,
    temperature=0.0,
    top_p=1.0,
    stop=["</CODE>"],  # Stop at the closing tag to save resources
    # Compare with None: an EOS id of 0 is valid but falsy and would be dropped.
    stop_token_ids=[tokenizer.eos_token_id] if tokenizer.eos_token_id is not None else None,
)

print('[init] Model loaded successfully.')

# Enter sleep mode immediately after loading (default behavior for time-sharing)
if not args.no_sleep:
    print('[init] Entering sleep mode to release GPU memory...')
    model.sleep(level=1)
    # Same cleanup as the /sleep endpoint: release PyTorch's cached allocator blocks.
    torch.cuda.empty_cache()
    print('[init] Model is now sleeping. GPU memory released.')
else:
    print('[init] Sleep mode disabled, model remains active.')


# ------------------------ Code Extraction Utilities ------------------------ #

# Opening markdown fence at the very start of a snippet, e.g. "```python\n", "```py\n", "```\n".
_LEADING_FENCE_RE = re.compile(r"\A```[ \t]*[\w+.-]*[ \t]*(?:\n|\Z)")
# Closing markdown fence at the very end of a snippet.
_TRAILING_FENCE_RE = re.compile(r"\n?```[ \t]*\Z")
# Markdown code block with an optional language-tag line; the closing fence may be
# missing when generation stopped early (e.g. max_tokens reached).
_MARKDOWN_BLOCK_RE = re.compile(r"```(?:[ \t]*[\w+.-]*[ \t]*\n)?(.*?)(?:```|\Z)", re.DOTALL)


def _strip_markdown_fences(code: str) -> str:
    """Remove a markdown fence wrapping all of *code*, in case one was nested inside <CODE>."""
    opening = _LEADING_FENCE_RE.match(code)
    if opening is None:
        return code
    return _TRAILING_FENCE_RE.sub("", code[opening.end():], count=1).strip()


def extract_code_block(raw_response: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Extract code content from <CODE>...</CODE> tags in LLM response.

    Lookup order:
      1. A complete ``<CODE>...</CODE>`` pair.
      2. An opening ``<CODE>`` with no closing tag. This is the usual case, since
         ``sample_params`` stops on ``"</CODE>"`` and vLLM omits the stop string
         from the output by default; it also covers truncated output.
      3. A markdown ``` block, whose closing fence may be missing.
    A markdown fence wrapped around the content of a <CODE> tag is removed, so the
    returned text is the bare code.

    Args:
        raw_response: Raw generated text, or None.

    Returns:
        ``(code, None)`` for a <CODE> hit, ``(code, "markdown_fallback")`` for a
        markdown hit, otherwise ``(None, reason)`` where reason is one of
        ``"raw_response_is_none"``, ``"empty_code_block"``, ``"no_code_block_found"``.
    """
    if raw_response is None:
        return None, "raw_response_is_none"

    complete = re.search(r"<CODE>(.*?)</CODE>", raw_response, re.DOTALL)
    if complete:
        code = _strip_markdown_fences(complete.group(1).strip())
        # An explicitly closed but empty tag is a failure, not ("", None).
        return (code, None) if code else (None, "empty_code_block")

    # Handle an opening <CODE> without a closing tag.
    truncated = re.search(r"<CODE>(.*)", raw_response, re.DOTALL)
    if truncated:
        code = _strip_markdown_fences(truncated.group(1).strip())
        if code:
            return code, None

    # Fallback to markdown code blocks.
    markdown = _MARKDOWN_BLOCK_RE.search(raw_response)
    if markdown:
        code = markdown.group(1).strip()
        if code:
            return code, "markdown_fallback"

    return None, "no_code_block_found"


def postprocess_python(extracted_code: str) -> Tuple[Optional[str], Optional[str]]:
    """
    Validate extracted Python code and return it with surrounding whitespace removed.

    The code is returned whole; it is NOT reduced to the solver's body. The only
    structural check is that some line starts (after indentation) with a
    ``def solver...`` statement, e.g. ``def solver(...)``.

    Args:
        extracted_code: Code text, typically from ``extract_code_block``.

    Returns:
        ``(code, None)`` on success, otherwise ``(None, reason)`` where reason is one of
        ``"empty_extracted_code"``, ``"empty_code_after_cleanup"``,
        ``"no_solver_function_found"``.
    """
    if not extracted_code:
        return None, "empty_extracted_code"

    # str.strip() already removes leading/trailing blank lines, so no per-line trimming is needed.
    code = extracted_code.strip()
    if not code:
        return None, "empty_code_after_cleanup"

    # Anchor to the start of a line so mentions in comments or strings don't count.
    if not re.search(r"^[ \t]*(?:async[ \t]+)?def[ \t]+solver", code, re.MULTILINE):
        return None, "no_solver_function_found"

    return code, None


# ---------------------------- Flask Application --------------------------- #
# The WSGI application object; the route handlers below register themselves on it.
app = Flask(import_name=__name__)


@app.route('/sleep', methods=['POST'])
def sleep_endpoint():
    """
    Put the vLLM model to sleep, releasing GPU memory.

    Query parameters:
        level: vLLM sleep level, 1 or 2 (default 1). Per the vLLM docs, level 1
            offloads weights to CPU memory and discards the KV cache; level 2
            discards both the weights and the KV cache.

    Returns:
        200 ``{'status': 'sleeping', 'level': level}`` on success,
        400 ``{'error': ...}`` for a missing-integer or unsupported level,
        500 ``{'error': ...}`` if ``model.sleep`` raises.
    """
    raw_level = request.args.get('level', '1')
    try:
        level = int(raw_level)
    except ValueError:
        level = None
    # Validate explicitly: request.args.get(..., type=int) would silently fall back to
    # the default on a non-integer, and vLLM only defines levels 1 and 2.
    if level not in (1, 2):
        return jsonify({'error': f'Invalid sleep level {raw_level!r}; expected 1 or 2'}), 400
    try:
        model.sleep(level=level)
        # Also release unoccupied blocks held by PyTorch's caching allocator.
        torch.cuda.empty_cache()
        print(f'[server] Model entered sleep mode (level={level})')
        return jsonify({'status': 'sleeping', 'level': level})
    except Exception as e:
        print(f'[server] Sleep failed: {e}')
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/wake_up', methods=['POST'])
def wake_up_endpoint():
    """
    Restore the sleeping vLLM model onto the GPU.

    Repeated ``tags`` query parameters are forwarded to ``model.wake_up``. With no
    tags, everything is woken. The JSON reply reports the tags used ('all' when
    none were given) and the wall-clock seconds spent. On failure it returns
    ``{'error': ...}`` with status 500.
    """
    requested_tags = request.args.getlist('tags')
    tag_label = requested_tags if requested_tags else 'all'
    try:
        t0 = time.time()
        wake_kwargs = {'tags': requested_tags} if requested_tags else {}
        model.wake_up(**wake_kwargs)
        duration = time.time() - t0
        print(f'[server] Model woke up (tags={tag_label}) in {duration:.2f}s')
        return jsonify({'status': 'awake', 'tags': tag_label, 'elapsed': duration})
    except Exception as exc:
        print(f'[server] Wake up failed: {exc}')
        return jsonify({'error': str(exc)}), 500


@app.route('/is_sleeping', methods=['GET'])
def is_sleeping_endpoint():
    """
    Report whether the vLLM engine is currently asleep.

    Responds with ``{'is_sleeping': <engine.is_sleeping() result>}``. If the query
    raises, it prints a traceback and returns ``{'error': ...}`` with status 500.
    """
    try:
        # The sleep state is queried on the underlying engine, not on the LLM wrapper.
        engine = model.llm_engine
        return jsonify({'is_sleeping': engine.is_sleeping()})
    except Exception as exc:
        import traceback
        print(f'[server] is_sleeping check failed: {exc}')
        traceback.print_exc()
        return jsonify({'error': str(exc)}), 500


@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: always answers 'healthy' along with the configured model path."""
    status = dict(status='healthy', model=args.model_path)
    return jsonify(status)


def _format_code_prompt(question):
    """
    Render one question into the text prompt passed to vLLM.

    Fills QUESTION_TO_CODE_PROMPT. When the tokenizer defines a chat template, the
    filled prompt is wrapped as a single user turn with the generation prompt
    appended; otherwise the filled template is used as-is.
    """
    prompt = QUESTION_TO_CODE_PROMPT.format(question=question)
    if not tokenizer.chat_template:
        return prompt
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        add_special_tokens=True,
    )


def _response_to_solver_code(response):
    """
    Convert one vLLM generation result into ``(code, failure_reason)``.

    On success returns ``(code, None)``. Otherwise returns ``(None, reason)``, where
    *reason* is the human-readable text logged by ``generate``.
    """
    if not (response and response.outputs):
        return None, 'No response generated'
    extracted_code, extract_error = extract_code_block(response.outputs[0].text)
    if not extracted_code:
        return None, f'Extract failed - {extract_error}'
    processed_code, process_error = postprocess_python(extracted_code)
    if not processed_code:
        return None, f'Post-process failed - {process_error}'
    return processed_code, None


@app.route('/generate', methods=['POST'])
def generate():
    """
    Generate Python solver code for a batch of math questions.

    Request JSON: ``{"questions": [...]}``.

    Response JSON: ``{"codes": [...], "success_count": int, "failed_count": int}``,
    where ``codes[i]`` is the solver code for ``questions[i]``, or ``None`` on failure.
    Returns 400 for a malformed request body and 500 if generation raises.

    Note: The model must be awake before calling this endpoint.
    """
    try:
        # silent=True: a missing or non-JSON body yields None (-> 400) instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'questions' not in data:
            return jsonify({'error': 'Missing "questions" field in request body'}), 400

        questions = data['questions']
        # Without this check a bare string would be iterated character by character.
        if not isinstance(questions, list):
            return jsonify({'error': '"questions" must be a list'}), 400
        print(f'[server] Received batch request with {len(questions)} questions')

        # Nothing to generate; don't hand an empty batch to vLLM.
        if not questions:
            return jsonify({'codes': [], 'success_count': 0, 'failed_count': 0})

        prompts = [_format_code_prompt(question) for question in questions]

        # Batch generation
        print('[server] Starting batch generation...')
        start_time = time.time()
        responses = model.generate(prompts, sampling_params=sample_params, use_tqdm=True)
        gen_time = time.time() - start_time
        print(f'[server] Generation completed in {gen_time:.2f}s')

        codes = []
        for i, response in enumerate(responses):
            code, failure_reason = _response_to_solver_code(response)
            codes.append(code)
            if failure_reason is not None:
                print(f'[server] Question {i}: {failure_reason}')

        success_count = sum(code is not None for code in codes)
        failed_count = len(codes) - success_count
        print(f'[server] Batch complete: {success_count} success, {failed_count} failed')

        return jsonify({
            'codes': codes,
            'success_count': success_count,
            'failed_count': failed_count,
        })

    except Exception as e:
        print(f'[server] Error processing request: {e}')
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


@app.route('/generate_with_retry', methods=['POST'])
def generate_with_retry():
    """
    Generate Python solver code with retry mechanism for failed questions.

    Request JSON:
        questions (list): Questions to convert.
        max_retries (int, optional): Total number of generation attempts (default 3).
        retry_temperature (float, optional): Sampling temperature for attempts after
            the first (default 0.7). The first attempt always uses the greedy
            ``sample_params``. Greedy decoding of an unchanged prompt generally
            reproduces the same output, so a greedy retry rarely fixes a failure.
            Pass 0 to keep every attempt greedy.

    Response JSON: ``{"codes": [...], "success_count": int, "failed_count": int}``,
    with ``codes[i]`` aligned to ``questions[i]`` (``None`` if every attempt failed).
    Returns 400 for malformed input and 500 if generation raises.

    Note: The model must be awake before calling this endpoint.
    """

    def build_prompt(question):
        # Fill the question template; apply the chat template when the tokenizer has one.
        prompt = QUESTION_TO_CODE_PROMPT.format(question=question)
        if not tokenizer.chat_template:
            return prompt
        return tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
            add_special_tokens=True,
        )

    def solver_code_or_none(response):
        # Failure reasons are not needed here; only whether usable code came back.
        if not (response and response.outputs):
            return None
        extracted_code, _ = extract_code_block(response.outputs[0].text)
        if not extracted_code:
            return None
        processed_code, _ = postprocess_python(extracted_code)
        return processed_code or None

    try:
        # silent=True: a missing or non-JSON body yields None (-> 400) instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'questions' not in data:
            return jsonify({'error': 'Missing "questions" field'}), 400

        questions = data['questions']
        if not isinstance(questions, list):
            return jsonify({'error': '"questions" must be a list'}), 400

        max_retries = data.get('max_retries', 3)
        # bool is a subclass of int, so reject it explicitly.
        if isinstance(max_retries, bool) or not isinstance(max_retries, int) or max_retries < 0:
            return jsonify({'error': '"max_retries" must be a non-negative integer'}), 400

        retry_temperature = data.get('retry_temperature', 0.7)
        if (isinstance(retry_temperature, bool)
                or not isinstance(retry_temperature, (int, float))
                or retry_temperature < 0):
            return jsonify({'error': '"retry_temperature" must be a non-negative number'}), 400

        print(f'[server] Received retry request: {len(questions)} questions, max_retries={max_retries}')

        # Params for attempts 2..N: same as sample_params except for the temperature.
        if retry_temperature == 0:
            retry_params = sample_params
        else:
            retry_params = vllm.SamplingParams(
                max_tokens=sample_params.max_tokens,
                temperature=float(retry_temperature),
                top_p=sample_params.top_p,
                stop=sample_params.stop,
                stop_token_ids=sample_params.stop_token_ids,
            )

        codes = [None] * len(questions)
        pending_indices = list(range(len(questions)))

        for attempt in range(max_retries):
            if not pending_indices:
                break

            print(f'[server] Attempt {attempt + 1}/{max_retries}: Processing {len(pending_indices)} questions')

            params = sample_params if attempt == 0 else retry_params
            prompts = [build_prompt(questions[idx]) for idx in pending_indices]
            responses = model.generate(prompts, sampling_params=params, use_tqdm=True)

            # Record successes; failures are carried into the next attempt.
            still_pending = []
            for idx, response in zip(pending_indices, responses):
                code = solver_code_or_none(response)
                if code is None:
                    still_pending.append(idx)
                else:
                    codes[idx] = code
            pending_indices = still_pending

        success_count = sum(1 for c in codes if c is not None)
        failed_count = len(codes) - success_count

        print(f'[server] Final result: {success_count} success, {failed_count} failed')

        return jsonify({
            'codes': codes,
            'success_count': success_count,
            'failed_count': failed_count,
        })

    except Exception as e:
        print(f'[server] Error: {e}')
        import traceback
        traceback.print_exc()
        return jsonify({'error': str(e)}), 500


# ------------------------- Main Application Entrypoint --------------------------- #
if __name__ == '__main__':
    sleep_label = "disabled" if args.no_sleep else "enabled (default)"
    print(f'[main] Starting Question-to-Code server on port {args.port}')
    print(f'[main] Sleep mode: {sleep_label}')
    # NOTE(review): threaded=True allows /generate, /sleep and /wake_up to run
    # concurrently against the single shared vllm.LLM instance; confirm that
    # callers serialize these requests.
    listen_port = int(args.port)
    app.run(host='0.0.0.0', port=listen_port, threaded=True)
