#!/usr/bin/env python3

import os
import argparse
import torch
from flask import Flask, request, jsonify
from vllm import LLM, SamplingParams
import time

# Maps each name accepted by the --model command-line option to the checkpoint
# path used for that model. Commented-out entries are not selectable.
CONFIG = {
    # "LFM2-350M": "./models/models-sft-copied/model_souped/model_souped/LFM2-350M/epoch_2",
    # "LFM2-700M": "./models/models-sft-copied/model_souped/model_souped/LFM2-700M/epoch_2",
    "LFM2-1.2B-fixed-sft": "./models_to_eval/model-souped-1.2B-fixed-sft",
}

# Flask application object; the HTTP routes in this file are registered on it.
app = Flask(__name__)

class VLLMServer:
    """Owns a vLLM engine and its tokenizer and runs batched text generation.

    Attributes:
        model_path: Checkpoint path given to both vLLM and the tokenizer loader.
        model_max_tokens: Context length passed to vLLM as ``max_model_len``;
            also used to compute how many prompt tokens may be kept.
        visible_devices: Value written to ``CUDA_VISIBLE_DEVICES`` before the
            engine is created, or ``None`` to leave the environment untouched.
        model: The ``LLM`` instance, or ``None`` until ``load_model`` succeeds.
        tokenizer: Hugging Face tokenizer, or ``None`` if it could not be loaded.
        parallelism_config: Human-readable label of the strategy that loaded.
    """

    # Tokens held back, in addition to ``max_tokens``, when sizing the prompt.
    _SAFETY_MARGIN = 50

    def __init__(self, model_path, model_max_tokens, visible_devices="3,4,5,6,7"):
        self.model_path = model_path
        self.model_max_tokens = model_max_tokens
        # NOTE(review): the default lists five devices while load_model() first
        # asks for TP=1 x PP=4 -- confirm the intended device set.
        self.visible_devices = visible_devices
        self.model = None
        self.tokenizer = None
        self.parallelism_config = None

    def load_model(self):
        """Create the vLLM engine, trying pipeline parallelism first.

        Falls back to tensor parallelism (currently only TP=1) when the
        pipeline-parallel attempt raises. The tokenizer is loaded after the
        first strategy that succeeds.

        Returns:
            True if an engine was created, False if every strategy failed.
        """
        if self.visible_devices is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = self.visible_devices

        # Engine arguments shared by every strategy.
        shared_kwargs = dict(
            model=self.model_path,
            gpu_memory_utilization=0.85,
            max_model_len=self.model_max_tokens,
            dtype="bfloat16",
            max_num_seqs=256,
        )

        try:
            self.model = LLM(
                tensor_parallel_size=1,
                pipeline_parallel_size=4,
                **shared_kwargs,
            )
        except Exception as e:
            print(f"Pipeline parallelism failed: {e}")
        else:
            self.parallelism_config = "Pipeline (TP=1, PP=4)"
            self._load_tokenizer()
            return True

        for tp_size in [1]:
            try:
                self.model = LLM(tensor_parallel_size=tp_size, **shared_kwargs)
            except Exception as e:
                print(f"TP={tp_size} failed: {e}")
            else:
                self.parallelism_config = f"Tensor (TP={tp_size})"
                self._load_tokenizer()
                return True

        print("All parallelism strategies failed")
        return False

    def _load_tokenizer(self):
        """Load the tokenizer for ``model_path``; on failure warn and keep None."""
        try:
            # Imported lazily so the dependency is only touched once a model loaded.
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)
        except Exception as e:
            # Best effort: without a tokenizer, prompt truncation is skipped and
            # the EOS marker is matched as a stop string instead of by token id.
            print(f"Warning: Failed to load tokenizer: {e}")

    def _convert_eos_to_token_ids(self, eos_token):
        """Encode ``eos_token`` into token id(s).

        Returns:
            A single int when the marker encodes to exactly one token, a list
            of ints when it encodes to several, or None when there is no
            tokenizer, the marker is empty, or encoding fails or yields nothing.
        """
        if self.tokenizer is None or not eos_token:
            return None
        try:
            token_ids = self.tokenizer.encode(eos_token, add_special_tokens=False)
        except Exception as e:
            print(f"Warning: EOS token conversion failed: {e}")
            return None
        if not token_ids:
            return None
        return token_ids[0] if len(token_ids) == 1 else token_ids

    def _truncate_prompts(self, prompts, max_tokens):
        """Left-truncate prompts so prompt + generation fits the context length.

        Requires ``self.tokenizer`` to be set.

        Raises:
            ValueError: if ``max_tokens`` plus the safety margin leaves no room
                for any prompt token. A non-positive budget used as a negative
                slice index would otherwise cut tokens off the *front* of every
                prompt (or disable truncation when it is exactly zero).
        """
        max_input_tokens = self.model_max_tokens - max_tokens - self._SAFETY_MARGIN
        if max_input_tokens <= 0:
            raise ValueError(
                f"max_tokens={max_tokens} plus a safety margin of "
                f"{self._SAFETY_MARGIN} leaves no room for the prompt within "
                f"model_max_tokens={self.model_max_tokens}"
            )

        fitted = []
        for prompt in prompts:
            tokens = self.tokenizer.encode(prompt)
            if len(tokens) > max_input_tokens:
                # Truncate from the left to keep the most recent context.
                # NOTE(review): decoding with skip_special_tokens=True also
                # removes special tokens from the kept part -- confirm intended.
                prompt = self.tokenizer.decode(
                    tokens[-max_input_tokens:], skip_special_tokens=True
                )
            fitted.append(prompt)
        return fitted

    def _build_sampling_params(self, eos_token, temperature, max_tokens):
        """Build SamplingParams that stop on the EOS marker when one is given."""
        stop_kwargs = {}
        eos_token_id = self._convert_eos_to_token_ids(eos_token)
        if eos_token_id is not None:
            stop_kwargs["stop_token_ids"] = (
                [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
            )
        elif eos_token:
            # No usable token id: fall back to matching the marker as text.
            stop_kwargs["stop"] = [eos_token]

        return SamplingParams(
            temperature=temperature,
            max_tokens=max_tokens,
            skip_special_tokens=False,
            include_stop_str_in_output=True,
            **stop_kwargs,
        )

    def batch_inference(self, prompts, eos_token, temperature=0.8, max_tokens=512):
        """Generate one completion per prompt and report timing statistics.

        Args:
            prompts: A list of prompt strings; any non-list value is treated as
                a single prompt.
            eos_token: Text marker at which generation should stop.
            temperature: Sampling temperature forwarded to vLLM.
            max_tokens: Maximum number of tokens to generate per prompt.

        Returns:
            A dict with ``results`` (one ``{'text', 'tokens', 'time'}`` entry
            per prompt, where ``time`` is the wall time of the whole batch),
            ``total_tokens``, ``avg_tokens``, ``generation_time``,
            ``throughput_tps`` and ``throughput_sps``.

        Raises:
            RuntimeError: if no model has been loaded.
            ValueError: if ``max_tokens`` leaves no room for the prompt (only
                checked when a tokenizer is available for truncation).
        """
        if self.model is None:
            raise RuntimeError("Model not loaded")

        if not isinstance(prompts, list):
            prompts = [prompts]

        if not prompts:
            # Nothing to generate; also avoids dividing by a zero batch size.
            return {
                'results': [],
                'total_tokens': 0,
                'avg_tokens': 0.0,
                'generation_time': 0.0,
                'throughput_tps': 0.0,
                'throughput_sps': 0.0,
            }

        batch_size = len(prompts)

        if self.tokenizer is not None:
            prompts = self._truncate_prompts(prompts, max_tokens)

        sampling_params = self._build_sampling_params(eos_token, temperature, max_tokens)

        # perf_counter is monotonic, so the interval cannot come out negative.
        start_time = time.perf_counter()
        outputs = self.model.generate(prompts, sampling_params)
        generation_time = time.perf_counter() - start_time

        results = []
        total_tokens = 0
        for output in outputs:
            completion = output.outputs[0]
            # Count generated token ids, not whitespace-separated words, so that
            # 'tokens' and 'throughput_tps' are expressed in real tokens.
            output_tokens = len(completion.token_ids)
            total_tokens += output_tokens
            results.append({
                'text': completion.text,
                'tokens': output_tokens,
                'time': generation_time,
            })

        has_elapsed = generation_time > 0
        return {
            'results': results,
            'total_tokens': total_tokens,
            'avg_tokens': total_tokens / batch_size,
            'generation_time': generation_time,
            'throughput_tps': total_tokens / generation_time if has_elapsed else 0.0,
            'throughput_sps': batch_size / generation_time if has_elapsed else 0.0,
        }

# Module-level VLLMServer read by the route handlers; stays None until main()
# assigns it.
server_instance = None

@app.route('/batch_inference', methods=['POST'])
def batch_inference():
    """Handle POST /batch_inference: run generation for a JSON request.

    The body must be a JSON object. Recognised keys: ``prompts`` (required,
    non-empty), ``eos_token`` (default ``'</s>'``), ``temperature`` (default
    0.8) and ``max_tokens`` (default 512).

    Returns:
        The statistics dict from ``VLLMServer.batch_inference`` as JSON, or a
        JSON ``{'error': ...}`` with status 400 (bad request body / no prompts)
        or 500 (model not loaded / generation failed).
    """
    if server_instance is None or server_instance.model is None:
        return jsonify({'error': 'Model not loaded'}), 500

    # silent=True returns None for a missing, non-JSON or malformed body
    # instead of raising, so every bad body is rejected the same way below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Also covers valid JSON that is not an object (list, string, null),
        # on which .get() would otherwise raise outside the try block.
        return jsonify({'error': 'Request body must be a JSON object'}), 400

    prompts = data.get('prompts', [])
    eos_token = data.get('eos_token', '</s>')
    temperature = data.get('temperature', 0.8)
    max_tokens = data.get('max_tokens', 512)

    if not prompts:
        return jsonify({'error': 'No prompts provided'}), 400

    try:
        result = server_instance.batch_inference(prompts, eos_token, temperature, max_tokens)
        return jsonify(result)
    except Exception as e:
        # Request boundary: report the failure to the client instead of
        # letting it propagate into Flask.
        return jsonify({'error': str(e)}), 500

@app.route('/health', methods=['GET'])
def health():
    """Handle GET /health: report whether a model is available.

    Responds with ``{'status': 'ready'}`` when the module-level server exists
    and holds a model, and ``{'status': 'not_ready'}`` otherwise.
    """
    if not server_instance or not server_instance.model:
        return jsonify({'status': 'not_ready'})
    return jsonify({'status': 'ready'})

def parse_arguments():
    """Parse the command line and return the resulting argparse namespace.

    ``--model`` is required and must be one of the CONFIG keys; the other
    options are optional: ``--model_max_tokens`` (int, 512), ``--port``
    (int, 5001) and ``--host`` (str, '0.0.0.0').
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', type=str, required=True, choices=list(CONFIG))

    # (flag, type, default) for the options that only differ in these values.
    optional_flags = (
        ('--model_max_tokens', int, 512),
        ('--port', int, 5001),
        ('--host', str, '0.0.0.0'),
    )
    for flag, kind, fallback in optional_flags:
        cli.add_argument(flag, type=kind, default=fallback)

    return cli.parse_args()

def main():
    """Script entry point: load the selected model, then serve HTTP requests.

    Parses the command line, builds the module-level ``server_instance`` for
    the chosen CONFIG entry and starts the Flask app once the model is loaded.

    Raises:
        SystemExit: with status 1 when the model cannot be loaded, so that a
            failed start is visible to whatever launched the process instead
            of looking like a clean exit.
    """
    global server_instance

    args = parse_arguments()

    model_path = CONFIG[args.model]

    print(f"Loading model: {args.model}")
    print(f"Model path: {model_path}")
    print(f"Max tokens: {args.model_max_tokens}")

    server_instance = VLLMServer(model_path, args.model_max_tokens)

    if not server_instance.load_model():
        print("Failed to load model")
        raise SystemExit(1)

    print(f"Model loaded successfully with {server_instance.parallelism_config}")
    print(f"Starting server on {args.host}:{args.port}")

    app.run(host=args.host, port=args.port, debug=False)

# Start the server only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()