import cv2
cv2.setNumThreads(1)
from typing import Any, Dict, List, Optional, Union
from pathlib import Path
import logging
import uuid
import os

import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig
import pytorch_lightning as pl

from nuplan.planning.utils.multithreading.worker_pool import WorkerPool
from nuplan.planning.utils.multithreading.worker_utils import worker_map

from navsim.planning.training.dataset import Dataset
from navsim.common.dataloader import SceneLoader
from navsim.common.dataclasses import SceneFilter, SensorConfig
from navsim.agents.abstract_agent import AbstractAgent
import json
from tqdm import tqdm
# Module-level logger used by cache_features() and main().
logger = logging.getLogger(__name__)

# Hydra config search path and default config name, passed to @hydra.main on main().
CONFIG_PATH = "config/training"
CONFIG_NAME = "default_training"


def cache_features(args: List[Dict[str, Union[List[str], DictConfig]]]) -> List[Optional[Any]]:
    """
    Helper function to cache features and targets of learnable agent.

    Builds a SceneLoader restricted to the logs and tokens listed in ``args`` and
    instantiates ``cfg.dataset`` over it with ``cfg.cache_path`` /
    ``cfg.force_cache_computation``. The caching is presumably performed as a side
    effect of that dataset construction (confirm against the class configured in
    ``cfg.dataset``); the dataset object itself is not kept.

    :param args: one dict per log with keys ``"log_file"`` (log name), ``"tokens"``
        (scene tokens belonging to that log) and ``"cfg"`` (the full hydra config;
        only the first entry's config is read). May be empty.
    :return: always an empty list (the caller does not use the result).
    """
    node_id = int(os.environ.get("NODE_RANK", 0))
    thread_id = str(uuid.uuid4())

    if not args:
        # Nothing assigned to this call; reading args[0]["cfg"] below would raise IndexError.
        logger.info(f"No logs to cache for thread_id={thread_id}, node_id={node_id}.")
        return []

    log_names = [a["log_file"] for a in args]
    tokens = [t for a in args for t in a["tokens"]]
    cfg: DictConfig = args[0]["cfg"]

    agent: AbstractAgent = instantiate(cfg.agent)

    # Restrict the split's scene filter to exactly the logs/tokens handed to this call.
    scene_filter: SceneFilter = instantiate(cfg.train_test_split.scene_filter)
    scene_filter.log_names = log_names
    scene_filter.tokens = tokens
    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        # Load only the sensors this agent declares it needs.
        sensor_config=agent.get_sensor_config(),
        # Tokens were already selected by the caller's filtered loader.
        enable_filter=False,
    )
    logger.info(f"Extracted {len(scene_loader.tokens)} scenarios for thread_id={thread_id}, node_id={node_id}.")

    # Constructed only for its side effects (cache writes); the instance is not needed afterwards.
    instantiate(
        cfg.dataset,
        scene_loader=scene_loader,
        agent=agent,
        cache_path=cfg.cache_path,
        force_cache_computation=cfg.force_cache_computation,
    )
    return []

import sys
def silence_worker():
    """Send this interpreter's stdout and stderr to the null device and set the root logger to WARNING.

    Each stream is replaced by its own freshly opened handle on ``os.devnull``; the
    handles are not closed here. Not called anywhere in this file — presumably meant
    as a worker-process initializer (verify against callers).
    """
    for stream_name in ("stdout", "stderr"):
        setattr(sys, stream_name, open(os.devnull, 'w'))
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.WARNING)

def _load_json_records(fp: Path) -> List[Any]:
    """Read one JSON (``.json``) or JSON Lines (``.jsonl``) file and return its records.

    Each decoded top-level value contributes its elements if it is a list, or itself
    otherwise. In a ``.jsonl`` file every non-blank line is decoded separately; any
    other suffix is parsed as a single JSON document.
    """
    with fp.open("r", encoding="utf-8") as f:
        if fp.suffix == ".jsonl":
            # json.load() on a multi-line JSONL file raises JSONDecodeError ("Extra data").
            values = [json.loads(line) for line in f if line.strip()]
        else:
            values = [json.load(f)]

    records: List[Any] = []
    for value in values:
        if isinstance(value, list):
            records.extend(value)
        else:
            records.append(value)
    return records


def _merge_split_dir(split_dir: Path) -> None:
    """Merge every ``*.json*`` file under ``split_dir`` into ``<parent>/<name>_merged.json`` as one JSON array."""
    if not split_dir.exists():
        print(f"[SKIP] {split_dir} 不存在")
        return

    # rglob also collects files in nested sub-directories; sorting keeps the merge order deterministic.
    json_files = sorted(fp for fp in split_dir.rglob("*.json*") if fp.is_file())

    merged_qas: List[Any] = []
    for fp in tqdm(json_files, desc=f"Collecting {split_dir.name}"):
        merged_qas.extend(_load_json_records(fp))

    # Written next to split_dir (not inside it), so a re-run's rglob does not pick it up again.
    out_json = split_dir.parent / f"{split_dir.name}_merged.json"
    with out_json.open("w", encoding="utf-8") as f_json:
        json.dump(merged_qas, f_json, ensure_ascii=False, indent=2)

    print(f"[OK] {split_dir.name}: {len(merged_qas)} samples → "
          f"{out_json.name}")


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME, version_base=None)
def main(cfg: DictConfig) -> None:
    """
    Main entrypoint for dataset caching script.

    1. Builds a sensor-less SceneLoader over the configured train/test split and
       groups its tokens per log.
    2. Runs ``cache_features`` in this process on those per-log groups.
    3. Merges the JSON / JSONL files under ``<cache_path>/trainval`` and
       ``<cache_path>/trainval_annotation`` into one JSON array file per directory.

    :param cfg: omegaconf dictionary
    """
    logger.info("Global Seed set to 0")
    pl.seed_everything(0, workers=True)

    logger.info("Building Worker")
    # NOTE(review): the worker is currently unused. The parallel fan-out
    # `worker_map(worker, cache_features, data_points)` is disabled, and caching
    # runs in-process below. Re-enable it here if parallel caching is wanted.
    worker: WorkerPool = instantiate(cfg.worker)

    logger.info("Building SceneLoader")
    scene_filter: SceneFilter = instantiate(cfg.train_test_split.scene_filter)
    scene_loader = SceneLoader(
        sensor_blobs_path=Path(cfg.sensor_blobs_path),
        data_path=Path(cfg.navsim_log_path),
        scene_filter=scene_filter,
        sensor_config=SensorConfig.build_no_sensors(),
        enable_filter=True,
    )
    logger.info(f"Extracted {len(scene_loader)} scenarios for training/validation dataset")

    # One work item per log; cache_features rebuilds a loader limited to these logs/tokens.
    data_points = [
        {"cfg": cfg, "log_file": log_file, "tokens": tokens_list}
        for log_file, tokens_list in scene_loader.get_tokens_list_per_log().items()
    ]
    logger.info(f"Built {len(data_points)} per-log data points")
    # Release this loader so it is not held while cache_features builds its own.
    del scene_loader

    if data_points:
        cache_features(data_points)
    else:
        # cache_features reads args[0]["cfg"], so an empty list must not reach it.
        logger.warning("No scenes matched the scene filter; skipping feature caching.")

    cache_root = Path(cfg.cache_path)
    for split_dir in (cache_root / "trainval", cache_root / "trainval_annotation"):
        _merge_split_dir(split_dir)


if __name__ == "__main__":
    # main is wrapped by @hydra.main, which supplies cfg from CONFIG_PATH/CONFIG_NAME,
    # so it is invoked without arguments here.
    main()
