# env : atom_slicer

import os
import sys

# Make the repository root (parent of this file's directory) importable so the
# SLICER / LLM_models packages imported below resolve regardless of the
# directory the server is launched from.
try:
    notebook_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Interactive sessions (e.g. Jupyter) have no __file__; fall back to the
    # current working directory, which is what the previous
    # os.path.abspath("__file__") string-literal trick resolved to.
    notebook_dir = os.getcwd()

slicer_root = os.path.dirname(notebook_dir)

if slicer_root not in sys.path:
    sys.path.append(slicer_root)


from SLICER.SLICER_opti import *
from SLICER.SLICER_config import *
from SLICER.SLICER_utils import *

from LLM_models.datautils import *
from LLM_models.evaluation import *

import os
import sys
import time
import json
import psutil

import torch
import uvicorn
from fastapi import FastAPI, Request

device = torch.device("cuda:0")

datasets_list = ["boolq"]

# Alternative model (Phi-3-mini). To switch, replace the Llama-2 settings and
# loader below with:
# tokenizer_name="/data/models--microsoft--Phi-3-mini-4k-instruct/snapshots/0a67737cc96d2554230f90338b163bc6380a2a85"
# model_name = "/data/models--microsoft--Phi-3-mini-4k-instruct/snapshots/0a67737cc96d2554230f90338b163bc6380a2a85"
# server_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
# server_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

tokenizer_name = "meta-llama/Llama-2-7b-hf"
model_name = "/data/llm_models"

# Built once at import time (previously this call ran twice with identical
# arguments, the second result overwriting the first).
datasets = get_dataset_dataset_processors(datasets_list)

# Global model/tokenizer pair shared by every generation endpoint below.
server_model, server_tokenizer = load_llama_model_and_tokenizer(
    model_name, tokenizer_name, device
)


import io

app = FastAPI()

# Process-wide metrics accumulated by the endpoints; /server_stats summarises,
# writes them to disk and resets them.
SERVER_STATS = dict(
    requests_count=0,
    total_tokens_generated=0,
    call_logs=[],
)


def get_current_gpu_usage():
    """Return ``torch.cuda.memory_allocated`` for the module ``device``, in MiB."""
    bytes_per_mib = 1024 * 1024
    return torch.cuda.memory_allocated(device) / bytes_per_mib


def get_cpu_usage_percent():
    """Return ``psutil.cpu_percent`` sampled without blocking.

    With ``interval=None`` psutil (per its documentation) reports utilisation
    since the previous call rather than sleeping to measure, so the first call
    in a process is not meaningful.
    """
    percent = psutil.cpu_percent(interval=None)
    return percent


def move_to_device(obj, target_device=None):
    """Recursively move every tensor inside ``obj`` to ``target_device``.

    Dicts, lists and tuples (including namedtuples and tuple subclasses such as
    ``torch.Size``) are walked. Any other value is returned unchanged.
    Containers are rebuilt rather than mutated, so the caller's ``obj`` is not
    modified.

    Args:
        obj: a tensor, or an arbitrarily nested dict/list/tuple of values.
        target_device: destination device; defaults to the module-level
            ``device`` when omitted.

    Returns:
        A structure mirroring ``obj`` with every tensor moved to
        ``target_device``.
    """
    if target_device is None:
        target_device = device
    if isinstance(obj, torch.Tensor):
        return obj.to(target_device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, target_device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(x, target_device) for x in obj]
    if isinstance(obj, tuple):
        moved = [move_to_device(x, target_device) for x in obj]
        # namedtuples take their fields positionally; plain tuples and other
        # subclasses (e.g. torch.Size) accept a single iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*moved)
        return type(obj)(moved)
    return obj


def _sample_continuation(input_ids, max_new_tokens, past_key_values=None, temperature=0.8):
    """Autoregressively extend ``input_ids`` with ``server_model`` using temperature sampling.

    The first step feeds the whole sequence (or only the last token when a
    ``past_key_values`` cache is supplied). Every later step feeds just the
    newest token together with the cache returned by the previous step
    (``use_cache=True``). Generation stops after ``max_new_tokens`` steps, or
    as soon as the sampled token equals ``server_tokenizer.eos_token_id``. The
    EOS token itself is kept in the returned ids.

    Args:
        input_ids: prompt token ids. Batch size 1 is assumed because the EOS
            check calls ``.item()``; a larger batch would raise.
        max_new_tokens: maximum number of sampling steps.
        past_key_values: optional pre-computed cache passed to the first
            forward call.
        temperature: logits are divided by this before the softmax.

    Returns:
        ``input_ids`` with the sampled tokens appended along the last dimension.
    """
    generated_ids = input_ids.clone()
    eos_id = server_tokenizer.eos_token_id
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            cur_input = generated_ids if past_key_values is None else generated_ids[:, -1:]
            outputs = server_model(
                input_ids=cur_input,
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values

            next_token_logits = outputs.logits[:, -1, :]
            probs = torch.nn.functional.softmax(next_token_logits / temperature, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)

            generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

            if next_token_id.item() == eos_id:
                break
    return generated_ids


def _record_call(
    start_time,
    end_time,
    tokens_generated,
    start_gpu_mem,
    end_gpu_mem,
    start_cpu_percent,
    end_cpu_percent,
):
    """Update the ``SERVER_STATS`` counters and append one per-call log entry."""
    SERVER_STATS["requests_count"] += 1
    SERVER_STATS["total_tokens_generated"] += tokens_generated
    SERVER_STATS["call_logs"].append(
        {
            "start_time": start_time,
            "end_time": end_time,
            "duration_sec": end_time - start_time,
            "tokens_generated": tokens_generated,
            "start_gpu_mem_MB": start_gpu_mem,
            "end_gpu_mem_MB": end_gpu_mem,
            "gpu_mem_diff_MB": end_gpu_mem - start_gpu_mem,
            "start_cpu_percent": start_cpu_percent,
            "end_cpu_percent": end_cpu_percent,
        }
    )


@app.post("/continue_generation")
async def continue_generation(request: Request):
    """Continue a prompt sent as a ``torch.load``-able binary payload.

    The payload must deserialize to a dict with ``input_ids`` (token id tensor)
    and ``max_new_tokens`` (step count). The timed window starts after the
    body is received, so it includes deserialization.

    Returns ``{"generated_text": ...}``: prompt plus continuation, decoded with
    special tokens skipped.

    NOTE(review): generation runs synchronously inside this async handler, so
    it blocks the event loop while sampling; concurrent requests are
    effectively serialized.
    """
    raw_body = await request.body()
    start_time = time.time()

    # SECURITY(review): torch.load unpickles the raw request body, which allows
    # arbitrary code execution from any client that can reach this port.
    # Consider weights_only=True if the payload only holds tensors and ints.
    all_data_dict = torch.load(io.BytesIO(raw_body), map_location="cpu")
    all_data_dict = move_to_device(all_data_dict)

    input_ids = all_data_dict["input_ids"]
    max_new_tokens = all_data_dict["max_new_tokens"]

    # NOTE: reusing a client-supplied KV cache (all_data_dict["encoded_pkv"]
    # decoded with decode_past_key_values_optimized) is currently disabled,
    # so the model re-encodes the full prompt.
    past_key_values = None

    start_gpu_mem = get_current_gpu_usage()
    start_cpu_percent = get_cpu_usage_percent()

    generated_ids = _sample_continuation(
        input_ids, max_new_tokens, past_key_values=past_key_values
    )
    tokens_generated = generated_ids.size(1) - input_ids.size(1)

    end_time = time.time()
    end_gpu_mem = get_current_gpu_usage()
    end_cpu_percent = get_cpu_usage_percent()

    _record_call(
        start_time,
        end_time,
        tokens_generated,
        start_gpu_mem,
        end_gpu_mem,
        start_cpu_percent,
        end_cpu_percent,
    )

    generated_text = server_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return {"generated_text": generated_text}


@app.post("/raw_continue_generation")
async def raw_continue_generation(request: Request):
    """Continue a plain-text prompt sent as JSON.

    Request JSON: ``partial_text`` (required) and ``max_new_tokens``
    (optional, default 50). The timed window starts just before tokenization.

    Returns ``{"generated_text": ...}``: prompt plus continuation, decoded with
    special tokens skipped.
    """
    data = await request.json()

    partial_text = data["partial_text"]
    max_new_tokens = data.get("max_new_tokens", 50)

    start_gpu_mem = get_current_gpu_usage()
    start_cpu_percent = get_cpu_usage_percent()

    start_time = time.time()
    input_ids = server_tokenizer.encode(partial_text, return_tensors="pt").to(device)

    generated_ids = _sample_continuation(input_ids, max_new_tokens)
    tokens_generated = generated_ids.size(1) - input_ids.size(1)

    end_time = time.time()
    end_gpu_mem = get_current_gpu_usage()
    end_cpu_percent = get_cpu_usage_percent()

    _record_call(
        start_time,
        end_time,
        tokens_generated,
        start_gpu_mem,
        end_gpu_mem,
        start_cpu_percent,
        end_cpu_percent,
    )

    generated_text = server_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return {"generated_text": generated_text}


@app.post("/just_call")
async def just_call(request: Request):
    """Echo ``partial_text`` back without running the model, logging a zero-token call.

    Presumably a baseline for measuring request overhead. NOTE(review): the
    logged ``duration_sec`` and ``gpu_mem_diff_MB`` are fixed 0.0001
    placeholders, not measured values, even though start/end timestamps are
    recorded. Confirm this is intended.
    """
    payload = await request.json()

    t_begin = time.time()
    gpu_before = get_current_gpu_usage()
    cpu_before = get_cpu_usage_percent()

    t_finish = time.time()
    gpu_after = get_current_gpu_usage()
    cpu_after = get_cpu_usage_percent()

    placeholder = 0.0001

    SERVER_STATS["requests_count"] += 1
    SERVER_STATS["total_tokens_generated"] += 0  # nothing is generated here
    log_entry = dict(
        start_time=t_begin,
        end_time=t_finish,
        duration_sec=placeholder,
        tokens_generated=0,
        start_gpu_mem_MB=gpu_before,
        end_gpu_mem_MB=gpu_after,
        gpu_mem_diff_MB=placeholder,
        start_cpu_percent=cpu_before,
        end_cpu_percent=cpu_after,
    )
    SERVER_STATS["call_logs"].append(log_entry)

    return {"generated_text": payload["partial_text"]}


@app.get("/server_stats")
def get_server_stats():
    """Summarise the accumulated call metrics, persist them, then reset them.

    Writes ``{"summary", "call_logs"}`` to ``server_ours_overhead.json`` in the
    current working directory, overwriting any previous file. All averages
    fall back to 0.0 when their denominator is zero.

    Returns:
        ``{"summary": ..., "detail_logs": [...]}``, where ``detail_logs`` are
        the per-call entries that were just summarised (captured before the
        reset; previously this was always an empty list).
    """
    # Snapshot the logs before the reset below rebinds SERVER_STATS["call_logs"].
    call_logs = SERVER_STATS["call_logs"]
    total_reqs = SERVER_STATS["requests_count"]
    total_tokens = SERVER_STATS["total_tokens_generated"]
    total_duration = sum(log["duration_sec"] for log in call_logs)

    avg_req_time = total_duration / total_reqs if total_reqs > 0 else 0.0
    time_per_token = total_duration / total_tokens if total_tokens > 0 else 0.0
    tokens_per_second = total_tokens / total_duration if total_duration > 0 else 0.0

    if total_reqs > 0:
        avg_gpu_mem_start = sum(log["start_gpu_mem_MB"] for log in call_logs) / total_reqs
        avg_gpu_mem_end = sum(log["end_gpu_mem_MB"] for log in call_logs) / total_reqs
    else:
        avg_gpu_mem_start = 0.0
        avg_gpu_mem_end = 0.0

    summary = {
        "requests_count": total_reqs,
        "total_tokens_generated": total_tokens,
        "total_duration_sec": total_duration,
        "avg_request_time_sec": avg_req_time,
        "tokens_per_second": tokens_per_second,
        "time_per_token_sec": time_per_token,
        "avg_gpu_mem_start_MB": avg_gpu_mem_start,
        "avg_gpu_mem_end_MB": avg_gpu_mem_end,
    }

    with open("server_ours_overhead.json", "w", encoding="utf-8") as f:
        json.dump({"summary": summary, "call_logs": call_logs}, f, indent=4)

    SERVER_STATS["requests_count"] = 0
    SERVER_STATS["total_tokens_generated"] = 0
    SERVER_STATS["call_logs"] = []

    return {"summary": summary, "detail_logs": call_logs}


def run_server(host="0.0.0.0", port=8000):
    """Serve the module-level FastAPI ``app`` with uvicorn.

    Args:
        host: interface to bind. The default "0.0.0.0" listens on all
            interfaces.
        port: TCP port to listen on.
    """
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    run_server()
