#!/usr/bin/env python3

import json
import os
import signal
import subprocess
import tempfile
import time
from pathlib import Path

import psutil
import requests

# Models to benchmark: model name -> model directory.
# NOTE(review): only the keys are read in this script (passed as --model);
# the directory paths are not used here -- confirm that is intended.
CONFIG = {
    "LFM2-1.2B-fixed-sft": "./models_to_eval/model-souped-1.2B-fixed-sft",
}

# Interpreter used to launch argus_vllm_server.py.
PYTHON_PATH = "/home/name/rag-reasoning/argus/.venv/bin/python"

# Base URL for the server's /health and /generate endpoints.
SERVER_URL = "http://localhost:5001"

# Number of prompts sent in each generation request.
BATCH_SIZE = 16

# Sampling temperature used for every generation request.
TEMPERATURE = 0.8

def kill_all_vllm_processes():
    """Kill every process whose command line mentions ``argus_vllm_server.py``.

    Processes whose command line also contains ``vllm_healthcheck.py`` are
    left alone. Processes that disappear or deny access while being inspected
    are skipped silently. When at least one process was killed, sleep 10
    seconds to give the system time to clean up before returning.
    """
    print("Killing all VLLM server processes...")
    victims = 0

    for candidate in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            argv = candidate.info['cmdline']
            joined = ' '.join(argv) if argv else ''
            is_server = 'argus_vllm_server.py' in joined
            is_healthcheck = 'vllm_healthcheck.py' in joined
            # Only the server itself is a target; never the health check script.
            if not is_server or is_healthcheck:
                continue
            print(f"Killing process {candidate.info['pid']}: {joined[:100]}")
            candidate.kill()
            victims += 1
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass

    if not victims:
        print("No VLLM server processes found")
        return

    print(f"Killed {victims} VLLM server processes, waiting for cleanup...")
    time.sleep(10)

def save_generation_to_file(model_name, batch_size, temperature, eos_token, texts, file_path):
    """Append one sample generation to *file_path* (UTF-8, append mode).

    Two lines are written: a header ``"<model_name> <batch_size> <temperature>
    <eos_token>"`` and then ``texts[0]``. Only the first text is recorded;
    the rest of *texts* is ignored.
    """
    header = f"{model_name} {batch_size} {temperature} {eos_token}\n"
    with open(file_path, mode='a', encoding='utf-8') as sink:
        # Header goes out first so it is written even if texts[0] is missing.
        sink.write(header)
        sink.write(texts[0] + "\n")

def _read_log_tail(path, limit=500):
    """Return the last *limit* characters of the file at *path*, decoded leniently.

    Used for crash diagnostics, where the tail (typically a traceback) is the
    informative part. Returns a placeholder string if the file can't be read.
    """
    try:
        with open(path, 'rb') as fh:
            data = fh.read()
    except OSError as exc:
        return f"<could not read {path}: {exc}>"
    return data.decode('utf-8', errors='replace')[-limit:]


def start_server(model_name, max_wait=120, poll_interval=3):
    """Launch ``argus_vllm_server.py`` for *model_name* and wait until it is ready.

    The server runs in its own session (so ``kill_process_group`` can signal
    it and its children together). Its stdout/stderr go to temporary log files
    whose paths are printed. Pipes are not used because nothing reads them
    while the server runs, and a full pipe would block the server.

    Readiness means ``GET {SERVER_URL}/health`` returned HTTP 200 with a JSON
    object whose ``status`` is ``"ready"``.

    Args:
        model_name: Value passed to the server as ``--model``.
        max_wait: Seconds to wait for readiness before giving up.
        poll_interval: Seconds to sleep between readiness checks.

    Returns:
        The running ``subprocess.Popen`` on success. Returns ``None`` if the
        process exits early (the tail of its logs is printed) or if the
        timeout elapses (the process group is killed).
    """
    print(f"Starting server for {model_name}...")
    # NOTE(review): CONFIG maps model names to paths, but only the name is
    # passed as --model here; confirm argus_vllm_server.py resolves it.
    cmd = [PYTHON_PATH, "argus_vllm_server.py", "--model", model_name]

    # delete=False keeps the logs on disk for inspection after a failure.
    with tempfile.NamedTemporaryFile(prefix="vllm_server_", suffix=".stdout.log", delete=False) as stdout_log:
        with tempfile.NamedTemporaryFile(prefix="vllm_server_", suffix=".stderr.log", delete=False) as stderr_log:
            # start_new_session=True is the thread-safe replacement for
            # preexec_fn=os.setsid: the server leads its own process group.
            process = subprocess.Popen(
                cmd,
                stdout=stdout_log,
                stderr=stderr_log,
                start_new_session=True,
            )
    # The child holds its own copies of the descriptors; ours are closed now.
    print(f"Server logs: {stdout_log.name} / {stderr_log.name}")

    start_time = time.monotonic()

    print("Waiting for server to be ready...")
    while time.monotonic() - start_time < max_wait:
        # Check our own child first, so its crash is reported even if some
        # other server is still answering on SERVER_URL.
        if process.poll() is not None:
            print(f"Server process died early for {model_name}")
            print(f"stdout: {_read_log_tail(stdout_log.name)}")
            print(f"stderr: {_read_log_tail(stderr_log.name)}")
            return None

        try:
            response = requests.get(f"{SERVER_URL}/health", timeout=3)
            if response.status_code == 200:
                data = response.json()
                if isinstance(data, dict) and data.get('status') == 'ready':
                    elapsed = time.monotonic() - start_time
                    print(f"Server ready for {model_name} after {elapsed:.1f}s")
                    return process
        except (requests.RequestException, ValueError):
            # Not listening yet, or a non-JSON body: keep polling.
            pass

        time.sleep(poll_interval)

    print(f"Server startup timeout for {model_name}")
    kill_process_group(process)
    return None

def kill_process_group(process, grace=5):
    """Terminate *process* together with the rest of its process group.

    The sequence is:
    1. Send SIGTERM to the group.
    2. Wait up to *grace* seconds for the leader to exit. This returns as
       soon as it does, rather than always sleeping.
    3. Send SIGKILL to the group. This catches a leader that ignored
       SIGTERM and any group members that outlived it.

    If the child shares this script's own process group, only the child is
    killed, because signalling the group would kill this script as well. If
    the group cannot be signalled at all, this falls back to killing the
    direct child. It does nothing when *process* is falsy or has already
    exited.

    Args:
        process: ``subprocess.Popen`` handle, or ``None``.
        grace: Seconds to wait after SIGTERM before escalating to SIGKILL.
    """
    if not process or process.poll() is not None:
        return

    try:
        pgid = os.getpgid(process.pid)
        if pgid == os.getpgrp():
            # Not started in its own session: never signal our own group.
            process.kill()
            return

        os.killpg(pgid, signal.SIGTERM)
        try:
            process.wait(timeout=grace)
        except subprocess.TimeoutExpired:
            pass

        try:
            os.killpg(pgid, signal.SIGKILL)
        except ProcessLookupError:
            pass  # Group already empty: everything exited on SIGTERM.
    except OSError:
        # Group vanished or cannot be signalled: fall back to the direct child.
        try:
            process.kill()
        except OSError:
            pass

def stop_server(process):
    """Shut down a server process and sweep up any leftover server processes.

    The steps are:
    1. Signal the process group via ``kill_process_group``.
    2. Wait up to 10 seconds for the process to be reaped.
    3. Run ``kill_all_vllm_processes``, which matches server processes by
       command line.

    A falsy *process* is ignored and nothing is printed.
    """
    if process:
        print("Stopping server...")
        kill_process_group(process)

        try:
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            pass  # Still not reaped; the command-line sweep below follows.

        kill_all_vllm_processes()
        print("Server stopped")

def call_batch_inference(prompts, eos_token, temperature, max_tokens, timeout=120):
    """POST a batch of prompts to ``{SERVER_URL}/generate`` and return the JSON reply.

    Args:
        prompts: Prompt strings to generate from.
        eos_token: A single stop string. It is sent to the server wrapped in
            a one-element list.
        temperature: Sampling temperature.
        max_tokens: Generation limit sent with the request.
        timeout: Seconds before the HTTP request is abandoned.

    Returns:
        The decoded JSON body of an HTTP 200 response.

    Raises:
        Exception: ``"Server error: ..."`` on a non-200 status. Raises
            ``"Request failed: ..."`` on any ``requests`` error, with the
            original exception chained as ``__cause__``.
    """
    payload = {
        'prompts': prompts,
        'eos_token': [eos_token],
        'temperature': temperature,
        'max_tokens': max_tokens,
    }

    try:
        response = requests.post(f"{SERVER_URL}/generate", json=payload, timeout=timeout)
        if response.status_code != 200:
            raise Exception(f"Server error: {response.status_code} - {response.text}")
        return response.json()
    except requests.RequestException as e:
        # Chain the cause so the underlying transport error stays in tracebacks.
        raise Exception(f"Request failed: {e}") from e

def _run_generation_phase(label, prompts, eos_token, model_name, output_file, max_tokens=10):
    """Run one batched generation phase, log a sample, and return its statistics.

    Args:
        label: Phase name used in log lines (e.g. ``"THINK"``, ``"SEARCH"``).
        prompts: Prompt strings sent in a single request.
        eos_token: Stop string for this phase; also recorded in the sample header.
        model_name: Model under test, recorded in the sample header.
        output_file: File the first generated text is appended to.
        max_tokens: Generation limit forwarded to the server.

    Returns:
        ``(outputs, total_tokens, avg_tokens, generation_time)``.
        ``avg_tokens`` is averaged over the results actually returned.

    Raises:
        Exception: if the server returned no results. Errors from
            ``call_batch_inference`` also propagate.
    """
    print(f"Starting {label} batch generation (batch_size={BATCH_SIZE})...")
    response = call_batch_inference(prompts, eos_token, TEMPERATURE, max_tokens)

    items = response['results']
    if not items:
        # Fail with a clear message instead of an IndexError further down.
        raise Exception(f"{label} phase returned no results")

    outputs = [item['text'] for item in items]
    total_tokens = sum(item['tokens'] for item in items)
    # Average over what actually came back, not the requested batch size.
    avg_tokens = total_tokens / len(items)
    generation_time = response['generation_time']

    save_generation_to_file(model_name, BATCH_SIZE, TEMPERATURE, eos_token, outputs, output_file)

    print(f"{label} phase complete: {total_tokens} tokens ({avg_tokens:.1f} avg), "
          f"{total_tokens/generation_time:.1f} t/s")
    return outputs, total_tokens, avg_tokens, generation_time


def run_model_test(model_name, output_file):
    """Benchmark one model: start its server, run THINK then SEARCH phases, stop it.

    The THINK phase sends ``BATCH_SIZE`` copies of the same prompt. The SEARCH
    phase builds one prompt from each THINK output. The first generated text
    of each phase is appended to *output_file*. The server is always stopped,
    including on error.

    Returns:
        On success, a dict with ``status`` ``"SUCCESS"`` plus timing,
        token-count and throughput metrics. Otherwise ``status`` is
        ``"FAILED"`` and the message is in ``error``; exceptions raised
        during the run are caught and recorded there.
    """
    print(f"\n{'='*80}")
    print(f"PROCESSING: {model_name}")
    print(f"{'='*80}")

    result = {
        "model": model_name,
        "status": "FAILED",
        "error": "Unknown error"
    }

    server_process = None

    try:
        load_start = time.time()
        server_process = start_server(model_name)
        if not server_process:
            result["error"] = "Server startup failed"
            return result

        load_time = time.time() - load_start
        print(f"Model loading time: {load_time:.2f}s")

        user_query = "what is the size of a black-hole based hard-drive"
        query_prefix = f"<user_query>{user_query}</user_query>\n\n"

        think_prompts = [f"{query_prefix}<think>"] * BATCH_SIZE
        think_outputs, total_think_tokens, avg_think_tokens, think_time = _run_generation_phase(
            "THINK", think_prompts, "</think>", model_name, output_file)

        # Each SEARCH prompt continues one sampled THINK output.
        search_prompts = [
            f"{query_prefix}<think>{think_output}</think>\n\n<search_query>"
            for think_output in think_outputs
        ]
        _, total_search_tokens, avg_search_tokens, search_time = _run_generation_phase(
            "SEARCH", search_prompts, "</search_query>", model_name, output_file)

        total_tokens = total_think_tokens + total_search_tokens
        total_inference_time = think_time + search_time
        total_time = load_time + total_inference_time

        print(f"\nSUMMARY for {model_name}:")
        print(f"  Batch size: {BATCH_SIZE}")
        print(f"  Loading time: {load_time:.2f}s")
        print(f"  Think: {total_think_tokens} tokens in {think_time:.2f}s ({total_think_tokens/think_time:.1f} t/s)")
        print(f"  Search: {total_search_tokens} tokens in {search_time:.2f}s ({total_search_tokens/search_time:.1f} t/s)")
        print(f"  Total tokens: {total_tokens}")
        print(f"  Total inference time: {total_inference_time:.2f}s")
        print(f"  Total time: {total_time:.2f}s")
        print(f"  Average speed: {total_tokens/total_inference_time:.1f} t/s")
        print(f"  Sequences processed: {BATCH_SIZE * 2} ({BATCH_SIZE * 2 / total_inference_time:.1f} seq/s)")
        print(f"  Avg think tokens: {avg_think_tokens:.1f}")
        print(f"  Avg search tokens: {avg_search_tokens:.1f}")

        unload_start = time.time()
        stop_server(server_process)
        # Cleared so the finally clause does not stop the server a second time.
        server_process = None
        unload_time = time.time() - unload_start

        print(f"  Unloading time: {unload_time:.2f}s")

        result = {
            "model": model_name,
            "status": "SUCCESS",
            "batch_size": BATCH_SIZE,
            "load_time": load_time,
            "unload_time": unload_time,
            "think_tokens": total_think_tokens,
            "think_time": think_time,
            "think_tps": total_think_tokens/think_time,
            "search_tokens": total_search_tokens,
            "search_time": search_time,
            "search_tps": total_search_tokens/search_time,
            "total_tokens": total_tokens,
            "total_inference_time": total_inference_time,
            "total_time": total_time,
            "avg_tps": total_tokens/total_inference_time,
            "sequences_per_second": (BATCH_SIZE * 2) / total_inference_time,
            "avg_think_tokens": avg_think_tokens,
            "avg_search_tokens": avg_search_tokens
        }

    except Exception as e:
        print(f"ERROR in {model_name}: {e}")
        result["error"] = str(e)
    finally:
        if server_process:
            stop_server(server_process)

    return result

def main():
    """Benchmark every model in CONFIG sequentially and print a results table.

    The steps are:
    1. Kill stray server processes.
    2. Delete any existing ``generations.txt``; each run appends samples to it.
    3. Run ``run_model_test`` for each model.
    4. Print per-model throughput, failures and overall timing.
    """
    print("Starting sequential model processing with VLLM server...")
    print(f"Test configuration: batch_size={BATCH_SIZE}, models={len(CONFIG)}")
    print("="*80)

    kill_all_vllm_processes()

    output_file = "generations.txt"

    if os.path.exists(output_file):
        os.remove(output_file)
        print(f"Cleared existing {output_file}")

    overall_start = time.time()
    results = []

    for model_name in CONFIG:
        result = run_model_test(model_name, output_file)
        results.append(result)
        print(f"\nCompleted {result['model']}: {result['status']}")

        if result['status'] == 'FAILED':
            print(f"Error: {result['error']}")

    overall_time = time.time() - overall_start

    print(f"\n\n{'='*120}")
    print("FINAL BATCH INFERENCE RESULTS:")
    print(f"{'='*120}")

    successful_results = [r for r in results if r["status"] == "SUCCESS"]
    failed_results = [r for r in results if r["status"] == "FAILED"]

    if successful_results:
        print(f"\nSUCCESSFUL MODELS ({len(successful_results)}):")
        # Widen the name column to the longest model name so rows stay aligned;
        # a fixed width of 15 is too narrow for e.g. "LFM2-1.2B-fixed-sft".
        name_w = max(15, max(len(r['model']) for r in successful_results))
        header = (f"{'Model':<{name_w}} {'Load(s)':<8} {'Think t/s':<10} {'Search t/s':<11} "
                  f"{'Total t/s':<10} {'Seq/s':<8} {'Avg Th':<8} {'Avg Sr':<8}")
        print(header)
        print("-" * len(header))

        for result in successful_results:
            row = (f"{result['model']:<{name_w}} {result['load_time']:<8.1f} "
                   f"{result['think_tps']:<10.1f} {result['search_tps']:<11.1f} {result['avg_tps']:<10.1f} "
                   f"{result['sequences_per_second']:<8.1f} {result['avg_think_tokens']:<8.1f} {result['avg_search_tokens']:<8.1f}")
            print(row)

    if failed_results:
        print(f"\nFAILED MODELS ({len(failed_results)}):")
        for result in failed_results:
            print(f"  {result['model']}: {result.get('error', 'Unknown error')}")

    print("\nOVERALL TIMING:")
    print(f"  Total execution time: {overall_time:.2f} seconds")
    print(f"  Models processed: {len(results)}")
    # Guard against an empty CONFIG, which would otherwise divide by zero.
    success_pct = 100 * len(successful_results) / len(results) if results else 0.0
    print(f"  Success rate: {len(successful_results)}/{len(results)} ({success_pct:.1f}%)")

    if successful_results:
        total_sequences = sum(r.get('batch_size', 0) * 2 for r in successful_results)
        avg_tps = sum(r.get('avg_tps', 0) for r in successful_results) / len(successful_results)
        print(f"  Total sequences processed: {total_sequences}")
        print(f"  Average throughput across models: {avg_tps:.1f} t/s")
        print(f"  Average sequences per second: {total_sequences/overall_time:.1f}")

    print(f"\nSample generations saved to: {output_file}")
    print("Each model contributes 2 entries: one </think> and one </search_query> example")

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()
