import os
import ray
import noise_embed.torch_utils as torch_utils
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from typing import List, Dict, Any, Optional
from vllm import LLM, SamplingParams, AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.utils import random_uuid
from transformers import AutoTokenizer

from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, get_masked_input_and_mask
from vllm.distributed import tensor_model_parallel_all_reduce

import time

def patch_vocab_embedding(noise_std, noisy_token_ids=None):
    """Replace ``VocabParallelEmbedding.forward`` with a version that adds
    Gaussian noise to the embeddings of selected token ids.

    The replacement performs the embedding lookup itself:
    1. Shard masking via ``get_masked_input_and_mask`` when ``tp_size > 1``.
    2. ``self.quant_method.embedding``.
    3. ``tensor_model_parallel_all_reduce``.

    At every position whose input token id is in ``noisy_token_ids``, it adds
    ``randn_like(output) * noise_std`` to the looked-up embedding row. The
    patch is applied to the class, so it affects every
    ``VocabParallelEmbedding`` in the current process.

    The first unpatched ``forward`` seen is stored once as
    ``VocabParallelEmbedding.original_forward``. Calling this function again
    re-patches with the new settings but does not overwrite that saved method.

    Args:
        noise_std: Multiplier applied to standard-normal noise.
        noisy_token_ids: Iterable of token ids whose embeddings receive noise.
            ``None`` (the default) selects 128000, 128001, 128008, 128009,
            151643 and 151645. An empty iterable disables noise.
    """
    if not hasattr(VocabParallelEmbedding, 'original_forward'):
        VocabParallelEmbedding.original_forward = VocabParallelEmbedding.forward

    if noisy_token_ids is None:
        # NOTE(review): presumably Llama-3 (<|begin_of_text|>, <|end_of_text|>,
        # <|eom_id|>, <|eot_id|>) and Qwen2 (<|endoftext|>, <|im_end|>) special
        # tokens -- verify against the tokenizers actually served.
        noisy_token_ids = (128000, 128001, 128008, 128009, 151643, 151645)
    # Deduplicate and freeze the id list once at patch time rather than
    # rebuilding it on every forward call.
    noisy_ids = sorted(set(noisy_token_ids))
    # Device tensor holding noisy_ids, created lazily once per device instead
    # of being re-uploaded on every forward call.
    ids_by_device = {}

    def noisy_forward(self, input_):
        if self.tp_size > 1:
            # Map global ids to this shard's local ids; input_mask marks the
            # positions whose ids this rank does not own.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_

        output_parallel = self.quant_method.embedding(self, masked_input.long())

        if noisy_ids:
            ids_tensor = ids_by_device.get(input_.device)
            if ids_tensor is None:
                ids_tensor = torch_utils.tensor(noisy_ids, device=input_.device)
                ids_by_device[input_.device] = ids_tensor
            # Match on the original (global) ids so the selection does not
            # depend on how the vocabulary is sharded.
            noise_mask = torch_utils.isin(input_, ids_tensor)
            if noise_mask.any():
                noise = torch_utils.randn_like(output_parallel) * noise_std
                output_parallel = torch_utils.where(
                    noise_mask.unsqueeze(-1), output_parallel + noise, output_parallel)

        if self.tp_size > 1:
            # Zero rows (noise included) for ids owned by other ranks so only
            # the owning rank contributes to the all-reduce.
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)

        return tensor_model_parallel_all_reduce(output_parallel)

    VocabParallelEmbedding.forward = noisy_forward
    print("VocabParallelEmbedding.forward has been successfully patched with noisy_forward.")

@ray.remote(num_gpus=1)
class VLLMActor:
    """Ray actor (one GPU) wrapping a vLLM ``AsyncLLMEngine`` and its tokenizer.

    Each instance loads ``model_path`` with tensor parallelism 1 and serves
    chat requests through :meth:`generate`.
    """

    def __init__(self, model_path: str, use_noise: bool = False, noise_std: float = 0.01):
        """Optionally patch the embedding layer, then build the engine and tokenizer.

        Args:
            model_path: Path or hub id passed to both vLLM and ``AutoTokenizer``.
            use_noise: When True, call ``patch_vocab_embedding(noise_std)``
                before the engine is created.
            noise_std: Noise scale forwarded to ``patch_vocab_embedding``.
        """
        if use_noise:
            # The patch replaces VocabParallelEmbedding.forward class-wide in
            # this process.
            # NOTE(review): if the vLLM version in use runs the model in a
            # separate worker process, the patch will not reach it -- confirm.
            patch_vocab_embedding(noise_std)

        self.engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model=model_path,
                # NOTE(review): disables CUDA-graph capture; presumably needed
                # because the noisy forward branches on a host-side .any() -- confirm.
                enforce_eager=True,
                tensor_parallel_size=1,
                gpu_memory_utilization=0.90,
                max_model_len=20000,
                dtype="bfloat16",
            )
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    async def generate(self, request_body: Dict[str, Any]):
        """Run one chat completion to the end.

        Args:
            request_body: Chat request dict. ``messages`` is required.
                ``n``, ``temperature``, ``top_p``, ``max_tokens`` and ``stop``
                are optional and default to 1, 0.7, 1.0, 1024 and None.

        Returns:
            ``(final_output, request_id)``:
            - ``final_output`` is the last ``RequestOutput`` yielded by the
              engine.
            - ``request_id`` is the ``chatcmpl-<uuid>`` id used for the request.

        Raises:
            KeyError: If ``messages`` is missing.
            RuntimeError: If the engine finishes without yielding any output.
        """
        prompt = self.tokenizer.apply_chat_template(
            conversation=request_body["messages"],
            tokenize=False,
            add_generation_prompt=True,
        )
        # The chat template already emits the special tokens it needs (e.g. a
        # leading BOS). Handing the text to vLLM would re-tokenize it with the
        # tokenizer's default special tokens and can prepend a second BOS, so
        # tokenize here without special tokens and pass the ids directly.
        prompt_token_ids = self.tokenizer.encode(prompt, add_special_tokens=False)

        sampling_params = SamplingParams(
            n=request_body.get("n", 1),
            temperature=request_body.get("temperature", 0.7),
            top_p=request_body.get("top_p", 1.0),
            max_tokens=request_body.get("max_tokens", 1024),
            stop=request_body.get("stop", None),
        )

        request_id = f"chatcmpl-{random_uuid()}"
        results_generator = self.engine.generate(
            {"prompt_token_ids": prompt_token_ids}, sampling_params, request_id
        )

        # The engine streams cumulative snapshots; keep only the last one.
        final_output = None
        async for request_output in results_generator:
            final_output = request_output

        if final_output is None:
            raise RuntimeError(f"vLLM engine produced no output for request {request_id}")

        return final_output, request_id

# The ASGI application that uvicorn serves.
app = FastAPI()

# Maps each served model name to its Ray actor handle; filled at startup.
actor_handles: Dict[str, "ray.actor.ActorHandle"] = dict()

# Eight identical replicas of the "llama3_base_ins" checkpoint, one per GPU
# index 0..7, served as "llama3_base_ins-1" .. "llama3_base_ins-8".
# None enable embedding noise; with no "noise_std" key, startup falls back to its default.
MODEL_CONFIGS = [
    {
        "served_name": f"llama3_base_ins-{gpu + 1}",
        "path": "llama3_base_ins",
        "use_noise": False,
        "gpu_id": gpu,
    }
    for gpu in range(8)
]

@app.on_event("startup")
def startup_event():
    """Start Ray and launch one ``VLLMActor`` per entry in ``MODEL_CONFIGS``.

    Each local CUDA device index ``i`` is registered with Ray as a custom
    resource ``GPU_<i>`` with capacity 1. Each actor requests one GPU plus the
    ``GPU_<gpu_id>`` resource named by its config. Handles are stored in
    ``actor_handles`` under the config's ``served_name``.

    Raises:
        RuntimeError: If a config's ``gpu_id`` is outside ``[0, device_count)``
            or repeats another config's ``gpu_id``. Such an actor could never
            obtain its ``GPU_<gpu_id>`` resource: it would stay pending, and
            every request routed to it would wait indefinitely.
    """
    num_devices = torch_utils.cuda.device_count()

    # Fail fast on configs whose custom resource can never be satisfied.
    claimed_gpu_ids = set()
    for config in MODEL_CONFIGS:
        gpu_id = config["gpu_id"]
        if not 0 <= gpu_id < num_devices:
            raise RuntimeError(
                f"Model '{config['served_name']}' requests gpu_id={gpu_id}, "
                f"but only {num_devices} CUDA device(s) are visible."
            )
        if gpu_id in claimed_gpu_ids:
            raise RuntimeError(
                f"Model '{config['served_name']}' reuses gpu_id={gpu_id}; "
                f"each GPU_<id> resource has capacity 1."
            )
        claimed_gpu_ids.add(gpu_id)

    custom_resources = {f"GPU_{i}": 1 for i in range(num_devices)}
    # NOTE(review): if Ray is already initialized, ignore_reinit_error=True
    # presumably makes this call a no-op, so these resources would not be
    # registered -- confirm in deployments that pre-start Ray.
    ray.init(num_gpus=num_devices, resources=custom_resources, ignore_reinit_error=True)

    for config in MODEL_CONFIGS:
        # NOTE(review): GPU_<id> acts as a scheduling label. Ray picks the
        # physical GPU for num_gpus=1 itself, so this may not pin the actor to
        # CUDA device gpu_id -- confirm if exact placement matters.
        actor_options = {"num_gpus": 1, "resources": {f"GPU_{config['gpu_id']}": 1}}

        handle = VLLMActor.options(**actor_options).remote(
            model_path=config["path"],
            use_noise=config["use_noise"],
            noise_std=config.get("noise_std", 0.01),
        )
        actor_handles[config["served_name"]] = handle

def _chat_completion_payload(final_output, request_id: str, model_name: str) -> Dict[str, Any]:
    """Build an OpenAI-style ``chat.completion`` dict from a finished vLLM ``RequestOutput``."""
    choices = [
        {
            "index": output.index,
            "message": {
                "role": "assistant",
                "content": output.text,
            },
            "finish_reason": output.finish_reason,
        }
        for output in final_output.outputs
    ]

    num_prompt_tokens = len(final_output.prompt_token_ids)
    num_generated_tokens = sum(len(output.token_ids) for output in final_output.outputs)

    return {
        "id": request_id,
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_name,
        "choices": choices,
        "usage": {
            "prompt_tokens": num_prompt_tokens,
            "completion_tokens": num_generated_tokens,
            "total_tokens": num_prompt_tokens + num_generated_tokens,
        },
    }


@app.post("/v1/chat/completions")
async def create_chat_completion(request: Request):
    """Handle a chat completion request by routing it to the named model's actor.

    The JSON body's ``model`` selects an entry in ``actor_handles``; the whole
    body is forwarded to that actor's ``generate``.

    Responses:
        200: ``chat.completion`` object with ``choices`` and ``usage``.
        400: Body is not a JSON object, or ``messages`` is not a list.
        404: ``model`` does not name a served model.
        500: The actor returned no output.
    """
    try:
        request_body = await request.json()
    except ValueError:
        # Malformed JSON (json.JSONDecodeError) or undecodable bytes
        # (UnicodeDecodeError) are both ValueError subclasses.
        return JSONResponse({"error": "Request body must be valid JSON."}, status_code=400)
    if not isinstance(request_body, dict):
        return JSONResponse({"error": "Request body must be a JSON object."}, status_code=400)

    model_name = request_body.get("model")
    # The isinstance guard also keeps unhashable values (lists, objects) out
    # of the dict lookup, where they would raise TypeError.
    if not isinstance(model_name, str) or model_name not in actor_handles:
        return JSONResponse({"error": f"Model '{model_name}' not found."}, status_code=404)

    if not isinstance(request_body.get("messages"), list):
        return JSONResponse({"error": "'messages' must be a list of chat messages."}, status_code=400)

    actor = actor_handles[model_name]
    final_output, request_id = await actor.generate.remote(request_body)
    if final_output is None:
        return JSONResponse({"error": "Model produced no output."}, status_code=500)

    return JSONResponse(_chat_completion_payload(final_output, request_id, model_name))

if __name__ == "__main__":
    # Script entry point: serve the API on all interfaces, port 8000.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )
