import os
import subprocess

# Process-wide environment settings, applied at import time before the
# remaining imports below run. Backend subprocesses started by
# run_model_backend() receive them through os.environ.copy().
# NOTE(review): these look like CUDA / PyTorch debugging and allocator knobs;
# nothing in this file reads them directly -- confirm they are still wanted
# (CUDA_LAUNCH_BLOCKING presumably trades speed for debuggability).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import warnings

# Silence every FutureWarning and UserWarning raised in this process.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

import argparse
import json
import time
from pathlib import Path
from tqdm import tqdm
import os
import sys
import multiprocessing as mp
import yaml
import importlib
import datetime
from collections import deque
from typing import Dict, List, Any, Optional, Tuple


from video_utils.trim_video import trim_video
from time_utils.time_utils import time_to_seconds, seconds_to_time

def put_request(data, req_path):
    """Write ``data`` as JSON to ``req_path`` via a temp file and a final move.

    The payload is first dumped to a sibling ``*.json.tmp`` file and only then
    moved onto ``req_path``, so a process polling ``req_path`` should not
    observe a half-written file.

    Args:
        data: JSON-serializable object to dump (UTF-8, non-ASCII kept as-is).
        req_path: Destination path (``str`` or ``Path``). Missing parent
            directories are created.

    Raises:
        TypeError: If ``data`` is not JSON-serializable (from ``json.dump``).
        OSError: If the file cannot be written or moved into place.
    """
    req_path = Path(req_path)
    # The destination directory may not exist yet (e.g. no actor has created
    # the requests directory before the "done" marker is written).
    req_path.parent.mkdir(parents=True, exist_ok=True)

    tmp_file = req_path.with_suffix(".json.tmp")
    try:
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False)
            f.flush()
        # os.replace overwrites an existing destination on every platform,
        # whereas Path.rename fails on Windows when the target already exists.
        os.replace(tmp_file, req_path)
    except BaseException:
        # Do not leave a stale temp file behind when dumping or moving failed.
        tmp_file.unlink(missing_ok=True)
        raise

class Streaming:
    """Replay one benchmark sample as a clocked stream of video chunks and events.

    The sample dict supplies the chunk output directory (``stream_addr``), the
    video identifier and path (``uuid`` / ``video``), the video duration
    (``video_info.duration``) and a list of timestamped events (``sqa``).
    Every :meth:`step` moves the clock forward, cuts the matching video chunk
    and releases the events whose timestamp has been reached.
    """

    def __init__(self, sample: Dict[str, Any]):
        """Read the sample fields and queue its events in timestamp order."""
        video_info = sample.get("video_info", {})
        timeline = list(sample.get("sqa", []))
        timeline.sort(key=lambda item: time_to_seconds(item["timestamp"]))

        self.stream_addr = Path(sample.get("stream_addr", ""))
        self.video_uuid = sample.get("uuid", "")
        self.video_path = sample.get("video", "")
        self.duration = video_info.get("duration", 0)
        self.events = timeline
        self.events_queue = deque(timeline)
        self.current_time = 0.0
        self.processed_events = []
        # Timestamp (seconds) of the latest event; 0.0 when there are none.
        if timeline:
            self.last_event_time = time_to_seconds(timeline[-1]["timestamp"])
        else:
            self.last_event_time = 0.0
        self.is_finished = False

    def get_current_events(self) -> List[Dict[str, Any]]:
        """Pop and return every queued event due at or before ``current_time``.

        Returned events are also appended to ``processed_events``.
        """
        due: List[Dict[str, Any]] = []
        pending = self.events_queue
        while pending:
            head_seconds = time_to_seconds(pending[0]["timestamp"])
            if not (head_seconds <= self.current_time):
                break
            released = pending.popleft()
            due.append(released)
            self.processed_events.append(released)
        return due

    def get_video_chunk(self, chunk_path, start_time, end_time):
        """Make sure the trimmed chunk exists at ``chunk_path`` and return that path.

        An already existing file is reused. Trimming is best effort: any
        exception raised while trimming is printed and the path is returned
        regardless.
        """
        if not os.path.exists(chunk_path):
            try:
                trim_video(
                    video_path=self.video_path,
                    trim_path=chunk_path,
                    start_time=start_time,
                    end_time=end_time,
                    temp_dir=str(Path(chunk_path).parent / "tmp"),
                )
            except Exception as err:
                print(err)
        return chunk_path

    def step(self, time_increment: float = 1.0) -> Dict[str, Any]:
        """Advance the clock by ``time_increment`` seconds and report the tick.

        Returns:
            A dict with ``current_timestamp`` (formatted time after the tick),
            ``stream_chunk`` (chunk path, or ``None`` when the clock had already
            reached the video duration before the tick), ``new_events`` and
            ``is_finished``.
        """
        # Whether there was still video left is decided before the clock moves.
        had_video_left = self.current_time < self.duration

        start_ts = seconds_to_time(int(self.current_time))
        self.current_time += time_increment
        end_ts = seconds_to_time(int(self.current_time))

        raw_name = f"video_{self.video_uuid}_{start_ts}_{end_ts}.mp4"
        chunk_name = raw_name.replace(":", "")

        released = self.get_current_events()

        # Finished once the video end is reached, or 10 seconds after the
        # last event, whichever happens first.
        self.is_finished = (
            self.current_time >= int(self.duration)
            or self.current_time >= self.last_event_time + 10.0
        )

        if had_video_left:
            chunk = self.get_video_chunk(
                str(self.stream_addr / chunk_name), start_ts, end_ts
            )
        else:
            chunk = None

        return {
            "current_timestamp": end_ts,
            "stream_chunk": chunk,
            "new_events": released,
            "is_finished": self.is_finished,
        }

class Actor:
    """Run one benchmark sample against a model through a file-based queue.

    Each :meth:`step` either polls for the response to the outstanding request
    or advances the sample's :class:`Streaming` clock by one tick, appends the
    new video chunk / event prompts to the model context and, when an answer is
    wanted, writes a request file to ``<work_dir>/requests/<req_id>.json``.
    Responses are read from ``<work_dir>/responses/<req_id>.json``.
    NOTE(review): the process that consumes the request files is not part of
    this class -- presumably ``model_backend.py``; confirm.

    In sparse mode a request is only issued while ``answer_window`` is
    non-negative. The window is opened by incoming events (1 tick for an event
    carrying both a question and a response, ``active_window`` ticks for an
    event carrying only one of them) and shrinks by one on every tick that adds
    content.
    """

    def __init__(
        self,
        config: Dict[str, Any],
        model_name: str,
        prompts_name: str,
        model_path: Optional[str],
        sample: Dict[str, Any],
        work_dir: Path,
        bench: str,
        sparse_mode: bool,
        active_window: int,
        max_retries: int,
    ):
        """Instantiate the model wrapper, prompt modules and stream for ``sample``.

        Args:
            config: Parsed configuration with ``models`` and ``prompts`` sections.
            model_name: Key into ``config["models"]``; its ``class`` entry is a
                dotted ``module.ClassName`` path and ``args`` its constructor kwargs.
            prompts_name: Key into ``config["prompts"]`` naming the modules that
                provide ``get_prompt``.
            model_path: Optional override for the model's ``model_path``.
            sample: Benchmark sample dict handed to :class:`Streaming`.
            work_dir: Directory holding the ``requests`` / ``responses`` folders.
            bench: Benchmark name stored alongside the result.
            sparse_mode: Only query the model inside the answer window.
            active_window: Window length (ticks) opened by a one-sided event.
            max_retries: Total number of re-submissions allowed after non-200
                responses (shared by all requests of this actor).
        """
        # Accept a plain string as well as a Path for the work directory.
        work_dir = Path(work_dir)
        self.REQ_DIR = work_dir / "requests"
        self.RES_DIR = work_dir / "responses"

        self._sample = sample
        self.model_name = model_name
        self.work_dir = work_dir
        self.bench = bench
        self.sparse_mode = sparse_mode
        self.active_window = active_window
        self.max_retries = max_retries
        # -1 means "no answer currently expected".
        self.answer_window = -1

        model_cfg = config["models"][model_name]
        mod_name, cls_name = model_cfg["class"].rsplit(".", 1)
        model_cls = getattr(importlib.import_module(mod_name), cls_name)
        self.model = model_cls(**model_cfg["args"])
        self.model.model_path = model_path or getattr(self.model, "model_path", None)

        prompt_cfg = config["prompts"][prompts_name]
        self.sys_mod = importlib.import_module(prompt_cfg["system_class"])
        self.usr_mod = importlib.import_module(prompt_cfg["user_class"])

        self.streaming = Streaming(sample)
        self.responses: List[Dict[str, Any]] = []
        # Id of the request currently awaiting a response, if any.
        self.req_id: Optional[str] = None

        sys_msg = {
            "role": "system",
            "content": [{"type": "text", "text": self.sys_mod.get_prompt()}],
        }
        self.session_id = self.model.new_session(sys_msg)

        self.REQ_DIR.mkdir(parents=True, exist_ok=True)
        self.RES_DIR.mkdir(parents=True, exist_ok=True)

    def step(self) -> bool:
        """Advance this actor by one non-blocking step.

        Returns:
            ``True`` when the actor is done (stream finished, a request failed
            for good, or an exception occurred); ``False`` when it must be
            stepped again (still waiting for a response or more stream left).
        """
        try:
            if self.req_id:
                res_path = self.RES_DIR / f"{self.req_id}.json"
                if not self._handle_response(res_path):
                    # No response yet, or the request was just re-submitted.
                    return False
                self.req_id = None
                if self.responses[-1]["status_code"] != 200:
                    # Retries are exhausted: stop processing this sample.
                    return True

            step_res = self.streaming.step()
            chunk = step_res["stream_chunk"]
            new_events = step_res["new_events"]
            is_finished = step_res["is_finished"]

            content = []
            if chunk:
                content.append({"type": "video", "video": chunk})
            if new_events:
                act_w = -1
                for ev in new_events:
                    prompt = self.usr_mod.get_prompt(ev)
                    if prompt:
                        content.append({"type": "text", "text": prompt})

                    has_question = bool(ev.get("question"))
                    has_response = bool(ev.get("response"))
                    if has_question and has_response:
                        act_w = max(1, act_w)
                    elif has_question or has_response:
                        act_w = max(self.active_window, act_w)
                self.answer_window = max(act_w, self.answer_window)
            if content:
                self.model.add_chunk({"role": "user", "content": content})
                self.answer_window = max(-1, self.answer_window - 1)
                if (not self.sparse_mode) or (self.answer_window >= 0):
                    self.req_id = self.mock_generate(
                        self.model.context, self.model.current_context_video_info
                    )
                    return False
            return is_finished

        except Exception as e:
            # Any failure ends this sample; the error text is kept in the
            # result so the run as a whole can continue.
            self.responses.append({
                "timestamp": seconds_to_time(int(self.streaming.current_time)),
                "response": f"[ERROR] {str(e)}"
            })
            return True

    def mock_generate(self, messages, video_info=None):
        """Queue a generation request as a JSON file and return its request id.

        The id is ``<session_id>_<current stream second>``, so a retry issued
        at the same stream time reuses the same id and file name.
        """
        content = {
            "messages": messages,
            "video_info": video_info
        }

        req_id = f"{self.session_id}_{int(self.streaming.current_time)}"
        req_path = self.REQ_DIR / f"{req_id}.json"
        put_request(content, req_path=req_path)
        return req_id

    @staticmethod
    def _read_response_file(res_path: Path) -> Dict[str, Any]:
        """Decode a response file as UTF-8 JSON.

        The encoding is explicit so non-ASCII responses do not depend on the
        platform's default locale (request files are written as UTF-8 too).
        """
        with open(res_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _handle_response(self, res_path: Path) -> bool:
        """Consume the response file at ``res_path`` if it exists.

        Returns:
            ``False`` if the file is not there yet or the request was re-queued
            after a non-200 status; ``True`` once a final response (successful
            or not) has been appended to ``self.responses``.
        """
        if not res_path.exists():
            return False

        read_out = self._read_response_file(res_path)
        response = read_out.get("response", "[ERROR] No Response")
        raw_response = read_out.get("raw_response", response)
        status_code = int(read_out.get("status_code", 404))
        res_path.unlink(missing_ok=True)

        if status_code != 200 and self.max_retries > 0:
            # Re-submit the unchanged context and spend one retry.
            self.req_id = self.mock_generate(
                self.model.context, self.model.current_context_video_info
            )
            self.max_retries -= 1
            return False

        if status_code == 200:
            # Prompt modules may flag a reply as "silent"; such replies are
            # recorded below but not added to the conversation context.
            is_silent = getattr(self.sys_mod, "is_silent_response", lambda _: False)(response)
            if not is_silent:
                self.model.add_chunk(
                    {"role": "assistant", "content": [{"type": "text", "text": response}]}
                )

        self.responses.append(read_out | {
            "timestamp": seconds_to_time(int(self.streaming.current_time)),
            "status_code": status_code,
            "response": response,
            "raw_response": raw_response,
        })

        return True

    def get_result(self) -> Dict[str, Any]:
        """Return the sample dict extended with ``responses`` and ``bench``."""
        return self._sample | {"responses": self.responses, "bench": self.bench}

def actor_worker(kwargs_list: List[Dict[str, Any]]):
    """Build one Actor per kwargs dict and step them round-robin until all finish.

    When an actor reports completion its result is written to
    ``<work_dir>/<bench>/<session_id>.json`` and collected. Results are
    returned in completion order.
    """
    actors = [Actor(**actor_kwargs) for actor_kwargs in kwargs_list]
    finished = {actor.session_id: False for actor in actors}
    collected = []

    while not all(finished.values()):
        for actor in actors:
            session = actor.session_id
            if finished[session]:
                continue

            finished[session] = actor.step()
            if not finished[session]:
                continue

            # The actor just completed: persist its result right away.
            result = actor.get_result()
            out_dir = actor.work_dir / actor.bench
            out_dir.mkdir(parents=True, exist_ok=True)
            out_file = out_dir / f"{session}.json"
            with open(out_file, "w", encoding="utf-8") as f:
                json.dump(result, f, ensure_ascii=False)
            collected.append(result)

    return collected

class Scheduler:
    """Load benchmark samples, run them through pooled actors and store results.

    Results are appended to ``<output_dir>/<bench>.jsonl``. When evaluation
    ends, a ``requests/done.json`` marker is written into the output directory.
    """

    # Number of actors that one pool task (one ``actor_worker`` call) interleaves.
    ACTORS_PER_GROUP = 3

    def __init__(
        self,
        config_path: str,
        model_name: str,
        model_path: Optional[str],
        output_dir: str,
        prompts_name: str,
        sparse_mode: bool,
        active_window: int,
        n_workers: int = 16,
        **kwargs
    ):
        """Read the YAML config, create the output directory and dump the setup.

        Args:
            config_path: Path to the YAML configuration file.
            model_name: Model key forwarded to every actor.
            model_path: Optional model path override forwarded to every actor.
            output_dir: Directory for results and the request/response queue.
            prompts_name: Prompt configuration key forwarded to every actor.
            sparse_mode: Forwarded to every actor.
            active_window: Forwarded to every actor.
            n_workers: Size of the multiprocessing pool.
            **kwargs: May contain ``args`` (an ``argparse.Namespace`` or a dict)
                that is saved to ``config.json`` for reproducibility.
        """
        with open(config_path, encoding="utf-8") as f:
            self.config = yaml.safe_load(f)
        self.model_name = model_name
        self.model_path = model_path
        self.prompts_name = prompts_name
        self.sparse_mode = sparse_mode
        self.active_window = active_window
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.n_workers = n_workers

        # vars() only works on objects that have a __dict__ (such as an
        # argparse.Namespace); a missing or dict-typed "args" must not crash.
        cli_args = kwargs.get("args")
        if cli_args is None:
            args_dump = {}
        elif isinstance(cli_args, dict):
            args_dump = dict(cli_args)
        else:
            args_dump = vars(cli_args)

        put_request({
            "args": args_dump,
            "config": self.config,
        }, self.output_dir / "config.json")
        print("output dir:", self.output_dir)

    def _load_benchmark(self, bench: str) -> List[Dict[str, Any]]:
        """Load the JSONL samples of ``bench``.

        ``bench`` is looked up in the config's ``benchmarks`` section; when it
        is not listed there it is treated as a path to the JSONL file itself.
        Blank lines are skipped.
        """
        benchmarks = (self.config or {}).get("benchmarks") or {}
        bench_path = benchmarks.get(bench, {"path": bench})["path"]
        with open(bench_path, "r", encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]

    def evaluate_benchmarks(self, bench_list: List[str]):
        """Run every sample of every benchmark in ``bench_list`` and save results.

        Samples are grouped ``ACTORS_PER_GROUP`` at a time and each group is
        handled by one ``actor_worker`` call in a process pool. Records are
        sorted by their ``id`` field before being appended to the per-benchmark
        ``.jsonl`` files.
        """
        print(f"\n🎯 Evaluating: {bench_list}")

        actors_kwargs = []

        for bench in bench_list:
            samples = self._load_benchmark(bench=bench)

            actors_kwargs.extend([
                {
                    "config": self.config,
                    "model_name": self.model_name,
                    "prompts_name": self.prompts_name,
                    "model_path": self.model_path,
                    "sample": s,
                    "work_dir": self.output_dir,
                    "bench": bench,
                    "sparse_mode": self.sparse_mode,
                    "active_window": self.active_window,
                    "max_retries": 2
                }
                for s in tqdm(samples, desc="Init Actors...")
            ])

        group_size = self.ACTORS_PER_GROUP
        actors_kwargs_groups = [
            actors_kwargs[i:i + group_size]
            for i in range(0, len(actors_kwargs), group_size)
        ]

        print(f"Begin to run {self.model_name} with {len(actors_kwargs)} samples:")
        requests_dir = self.output_dir / "requests"
        try:
            with mp.Pool(self.n_workers) as pool:
                group_records = list(tqdm(
                    pool.imap_unordered(actor_worker, actors_kwargs_groups),
                    total=len(actors_kwargs_groups),
                    desc="Processing"
                ))
        finally:
            # Always leave the marker, even when the pool failed or no actor
            # ever created the requests directory: main() and (presumably) the
            # model backends wait for this file before shutting down.
            requests_dir.mkdir(parents=True, exist_ok=True)
            put_request({"status": "done"}, requests_dir / "done.json")

        records = [r for gr in group_records for r in gr]

        records.sort(key=lambda x: x["id"])
        for record in records:
            self._write(
                path=self.output_dir / (record["bench"] + ".jsonl"),
                record=record
            )

    def _write(self, path: Path, record: Dict[str, Any]):
        """Append ``record`` to ``path`` as one UTF-8 JSON line."""
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")


def run_model_backend(
    cuda_visible_devices: str,
    gpu_per_model_backend: int,
    n_model_backend: int,
    param: dict
):
    """Start ``n_model_backend`` ``model_backend.py`` subprocesses on disjoint GPUs.

    Args:
        cuda_visible_devices: Comma-separated GPU ids; blanks are ignored.
        gpu_per_model_backend: Number of GPU ids handed to each backend.
        n_model_backend: Number of backend processes to start.
        param: Mapping of command-line flag to value; both are stringified and
            appended to the backend command line.

    Returns:
        The list of started ``subprocess.Popen`` objects.

    Raises:
        ValueError: If fewer GPU ids are supplied than the backends need.
    """
    device_ids = [
        token.strip() for token in cuda_visible_devices.split(",") if token.strip()
    ]
    required = gpu_per_model_backend * n_model_backend
    available = len(device_ids)
    if available < required:
        raise ValueError(f"Need {required} GPUs but only {available} available")

    launched = []
    for backend_idx in range(n_model_backend):
        first = backend_idx * gpu_per_model_backend
        assigned = ",".join(device_ids[first:first + gpu_per_model_backend])

        # Each child sees only its own slice of the GPU list.
        child_env = {**os.environ, "CUDA_VISIBLE_DEVICES": assigned}

        cmd = [sys.executable, "model_backend.py"]
        cmd += [str(part) for flag_and_value in param.items() for part in flag_and_value]

        print(f"Start process(GPU={assigned}): {cmd}")
        launched.append(subprocess.Popen(cmd, env=child_env))
    return launched

def main():
    """Command-line entry point: start model backends and/or run the scheduler.

    ``--load-modules`` selects what this process does:

    * ``model_backend``: spawn the backend subprocesses, then wait until
      ``<output_dir>/requests/done.json`` exists.
    * ``scheduler``: evaluate the requested benchmarks in this process.

    The effective output directory is ``<output-dir>/<model-name>_<run-id>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/stream_config.yaml")
    parser.add_argument("--prompts", default="streaming")
    parser.add_argument("--output-dir", default="outputs")
    parser.add_argument("--model-name", required=True)
    parser.add_argument("--model-path", type=str, default="")
    parser.add_argument("--benchmarks", nargs="+", required=True)
    parser.add_argument("--sparse-mode", type=int, default=1)
    parser.add_argument("--active-window", type=int, default=3)

    parser.add_argument("--cuda-visible-devices", type=str, default="0,1")
    parser.add_argument("--gpu-per-model-backend", type=int, default=1)
    parser.add_argument("--n-model-backend", type=int, default=1)
    parser.add_argument("--n-actors-alive", type=int, default=16)
    parser.add_argument("--batch-size", type=int, default=1)
    parser.add_argument("--load-modules", nargs="+", default=["model_backend", "scheduler"])
    parser.add_argument("--run-id", default=datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))

    args = parser.parse_args()
    # The flag is given as 0/1 on the command line.
    args.sparse_mode = bool(args.sparse_mode)

    args.output_dir = str(
        Path(args.output_dir) / f"{args.model_name}_{args.run_id}"
    )

    # Keep the backend handles so they can be cleaned up and monitored below.
    backend_procs = []
    if "model_backend" in args.load_modules:
        backend_procs = run_model_backend(
            cuda_visible_devices=args.cuda_visible_devices,
            gpu_per_model_backend=args.gpu_per_model_backend,
            n_model_backend=args.n_model_backend,
            param={
                "--config": args.config,
                "--model-name": args.model_name,
                "--model-path": args.model_path,
                "--output-dir": args.output_dir,
                "--batch-size": args.batch_size
            }
        )

    try:
        if "scheduler" in args.load_modules:
            scheduler = Scheduler(
                config_path=args.config,
                model_name=args.model_name,
                model_path=args.model_path,
                output_dir=args.output_dir,
                prompts_name=args.prompts,
                sparse_mode=args.sparse_mode,
                active_window=args.active_window,
                n_workers=args.n_actors_alive,
                args=args
            )

            scheduler.evaluate_benchmarks(args.benchmarks)

            print(f"\n🎉 Done: {args.model_name}")
    except BaseException:
        # Do not leave backend subprocesses running (and holding GPUs) when
        # the scheduler fails or is interrupted.
        for proc in backend_procs:
            if proc.poll() is None:
                proc.terminate()
        raise

    if "model_backend" in args.load_modules:
        done_path = Path(args.output_dir) / "requests" / "done.json"
        print(f"Processes are holding for done.json in {done_path}")
        while not done_path.exists():
            # If every backend has already exited, nothing can be served any
            # more, so waiting for the marker would block forever.
            if backend_procs and all(p.poll() is not None for p in backend_procs):
                print(
                    f"All model backends exited before {done_path} appeared.",
                    file=sys.stderr,
                )
                sys.exit(1)
            time.sleep(1)
        print(f"Processes Exit(0). Found {done_path}.")

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
