import json
import time
import argparse
import importlib
import yaml
from pathlib import Path
import os
import traceback
import warnings
import sys

# Process-wide CUDA / PyTorch settings, applied as soon as this module is
# imported. NOTE(review): presumably these need to be in place before the
# model module (imported later via importlib) initializes CUDA — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Silence FutureWarning and UserWarning for the whole process.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

def try_lock(req_file):
    """Claim a request file by renaming it to ``<stem>.json.locked``.

    Args:
        req_file: ``pathlib.Path`` of the request file to claim.

    Returns:
        The path of the locked file when the rename succeeded, or ``None``
        when it failed with an ``OSError`` (for example because the file no
        longer exists).
    """
    target = req_file.with_suffix(".json.locked")
    try:
        req_file.rename(target)
    except OSError:  # FileNotFoundError is a subclass of OSError
        return None
    return target

def unlock(locked_file):
    """Release a locked request by renaming it back to its request name.

    A path produced by ``try_lock`` ends in ``.json.locked``; dropping the
    final ``.locked`` suffix restores the original ``*.json`` name. (Replacing
    the last suffix with ``.json`` would instead yield ``*.json.json``.)

    Args:
        locked_file: ``pathlib.Path`` of the locked request file.

    Returns:
        The restored request path when the rename succeeded, or ``None`` when
        it failed with an ``OSError`` (for example because the file no longer
        exists).
    """
    if locked_file.suffix == ".locked":
        # "foo.json.locked" -> "foo.json"
        req_file = locked_file.with_suffix("")
    else:
        # Not a ".locked" path: keep the previous behaviour of swapping the
        # last suffix for ".json".
        req_file = locked_file.with_suffix(".json")
    try:
        locked_file.rename(req_file)
    except OSError:  # FileNotFoundError is a subclass of OSError
        return None
    return req_file

def put_request(data, req_path):
    """Write ``data`` as JSON to ``req_path`` via a temporary file.

    The payload is first written to a sibling file whose last suffix is
    replaced by ``.json.tmp`` and then moved onto ``req_path``, so readers
    never observe a partially written file at the final path.

    Args:
        data: JSON-serializable object to store.
        req_path: ``pathlib.Path`` of the final file.

    Raises:
        TypeError, ValueError: if ``data`` cannot be serialized by ``json``.
        OSError: if the temporary file cannot be written or moved.
    """
    tmp_file = req_path.with_suffix(".json.tmp")
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
            f.flush()
        # replace() overwrites an existing destination on every platform,
        # whereas rename() fails on Windows when the target already exists.
        tmp_file.replace(req_path)
    except BaseException:
        # Do not leave a half-written temp file behind; cleanup is best
        # effort so that the original error is the one that propagates.
        try:
            tmp_file.unlink()
        except OSError:
            pass
        raise

def load_model(config_path, model_name):
    """Instantiate the model described under ``models.<model_name>`` in a YAML config.

    The entry's ``class`` value is a dotted path whose last component is the
    class name and whose prefix is the module to import; the entry's ``args``
    mapping is passed to the class as keyword arguments.

    Args:
        config_path: Path of the YAML configuration file.
        model_name: Key of the model entry inside the ``models`` section.

    Returns:
        The newly constructed model instance.
    """
    with open(config_path, encoding="utf-8") as cfg_file:
        settings = yaml.safe_load(cfg_file)

    entry = settings["models"][model_name]
    dotted_path = entry["class"]
    module_name, class_name = dotted_path.rsplit(".", 1)

    module = importlib.import_module(module_name)
    factory = getattr(module, class_name)
    return factory(**entry["args"])

def get_error_msg(e):
    """Format an exception as a multi-line error report.

    The report contains the exception's type name, its message, and the
    output of ``traceback.format_exc()`` for the exception currently being
    handled, so it is meant to be called from inside an ``except`` block.

    Args:
        e: The exception to describe.

    Returns:
        The formatted report as a single string.
    """
    type_line = f"[ERROR] Type: {type(e).__name__}"
    message_line = f"Message: {e!s}"
    trace_section = f"Traceback:\n{traceback.format_exc()}"
    return "\n".join((type_line, message_line, trace_section))

def process_batch(model, locked_files, responses_dir, histories_dir):
    """Run one batch of locked requests through the model and persist results.

    Every file in ``locked_files`` is parsed as JSON and the parsed payloads
    are passed together to ``model.batch_generate``. For each request a
    response file and a history file are written, both named after the locked
    file's stem (the name without its trailing ``.locked``). Afterwards the
    locked request files are deleted.

    If ``batch_generate`` raises, or yields fewer results than there are
    requests, every request of the batch receives an error response with
    ``status_code`` 502; a new copy of this process is then spawned with the
    same command line and the current process exits with status 1.

    Args:
        model: Object exposing ``batch_generate(messages_list)``.
        locked_files: ``pathlib.Path`` objects of the locked request files.
        responses_dir: Directory (``pathlib.Path``) receiving response files.
        histories_dir: Directory (``pathlib.Path``) receiving history files.

    Raises:
        SystemExit: after a failed batch has been reported and the
            replacement process has been started.
    """
    status_code = 200

    messages_list = []
    for f in locked_files:
        with open(f, encoding="utf-8") as req_f:
            messages_list.append(json.load(req_f))

    print(f"🤖 Read {len(messages_list)} messages.", file=sys.stderr)
    t0 = time.time()
    try:
        # Non-dict results are wrapped so every response is a JSON object.
        responses = [
            res if isinstance(res, dict) else {"response": res}
            for res in model.batch_generate(messages_list)
        ]
        # zip() below stops at the shortest input: a short result would leave
        # some requests without any response while their request files are
        # still deleted. Treat that as a failed batch instead.
        if len(responses) < len(messages_list):
            raise ValueError(
                f"batch_generate returned {len(responses)} responses "
                f"for {len(messages_list)} requests"
            )
    except Exception as e:
        error_msg = get_error_msg(e)
        responses = [{"response": error_msg} for _ in messages_list]
        status_code = 502

    t1 = time.time()
    # Wall-clock time of the whole batch in milliseconds; the same value is
    # reported for every request of the batch.
    runtime = round((t1 - t0) * 1000, 2)
    for f, msg, resp in zip(locked_files, messages_list, responses):
        out_file = responses_dir / f.stem
        put_request(resp | {"status_code": status_code, "runtime_ms": runtime}, out_file)

        his_file = histories_dir / f.stem
        put_request({
            "timestamp": time.time(),
            "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"),
            "PID": os.getpid(),
            "output": resp,
            "seconds": runtime / 1000,
            "batch_size": len(messages_list),
            "input": msg,
        }, his_file)

    print(f"✅ Processed {len(locked_files)} requests.", file=sys.stderr)

    for f in locked_files:
        f.unlink(missing_ok=True)

    if status_code != 200:
        print("❌ Critical error in batch_generate. Restarting gracefully.", file=sys.stderr)
        # Start a replacement worker with the same interpreter and arguments,
        # then terminate this one.
        import subprocess
        subprocess.Popen([sys.executable] + sys.argv)
        sys.exit(1)

def model_backend(config, model_name, model_path, output_dir, batch_size=4, poll_interval=0.1):
    """Load a model and serve JSON requests dropped into a directory.

    ``output_dir`` gets three sub-directories: ``requests`` (polled for
    ``*.json`` files), ``responses`` and ``histories`` (written by
    ``process_batch``). On each pass up to ``batch_size`` request files are
    claimed with ``try_lock`` and processed as one batch.

    A file named ``done.json`` in the requests directory is a shutdown
    marker. It is never claimed as a request, and the loop exits on the first
    pass that sees the marker and finds no request left to claim. Pending
    requests are therefore served before shutdown regardless of the order in
    which the directory listing returns the marker.

    Args:
        config: Path of the YAML config passed to ``load_model``.
        model_name: Name of the model entry to load.
        model_path: Optional override for ``model.model_path``; single and
            double quote characters are stripped, and an empty result keeps
            the model's own value.
        output_dir: Base directory for the three working sub-directories.
        batch_size: Maximum number of requests claimed per pass.
        poll_interval: Seconds to sleep between passes.
    """
    output_dir = Path(output_dir).resolve()
    requests_dir = output_dir / "requests"
    responses_dir = output_dir / "responses"
    histories_dir = output_dir / "histories"
    for directory in (requests_dir, responses_dir, histories_dir):
        directory.mkdir(parents=True, exist_ok=True)

    model = load_model(config, model_name)
    model.model_path = model_path.replace("'", "").replace('"', '') or model.model_path
    model.initialize_model()
    model.lightweight_gpu_reset()
    print(f"🚀 Model {model_name} loaded with {model.model_path}: Watching {requests_dir}", file=sys.stderr)

    while True:
        locked_files = []
        done = False

        for req_file in requests_dir.glob("*.json"):
            if req_file.stem == "done":
                # Remember the marker but keep scanning: glob order is
                # arbitrary, so real requests may follow it in the listing.
                done = True
                continue
            if len(locked_files) >= batch_size:
                break

            # Shared locking helper; returns None when the file could not be
            # renamed (e.g. it disappeared in the meantime).
            locked = try_lock(req_file)
            if locked is not None:
                locked_files.append(locked)

        if locked_files:
            process_batch(model, locked_files, responses_dir, histories_dir)
        elif done:
            # Marker present and nothing left to claim: shut down.
            break

        time.sleep(poll_interval)

    print("Exit.", file=sys.stderr)

def main():
    """Parse the command-line options and start the polling backend.

    The parsed option names (``config``, ``model_name``, ``model_path``,
    ``output_dir``, ``batch_size``, ``poll_interval``) match the keyword
    parameters of ``model_backend`` one-to-one, so they are forwarded as-is.
    """
    option_specs = (
        ("--config", {"type": str, "default": "config/stream_config.yaml"}),
        ("--model-name", {"required": True}),
        ("--model-path", {"type": str, "default": ""}),
        ("--output-dir", {"required": True}),
        ("--batch-size", {"type": int, "default": 4}),
        ("--poll-interval", {"type": float, "default": 0.1}),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    model_backend(**vars(options))

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
