# main.py

import os
import sys
import argparse
import multiprocessing as mp
from pathlib import Path

# Project imports
from config.settings import (
    DEFAULT_PATHS,
    QWEN_CONFIG,
    EXPERIMENT_CONFIG,
    QUERY_DECOMPOSITION_CONFIG,
)


SCOPE_SAMPLERS = {"scope"}


def _infer_vision_model_family_from_path(model_path: str) -> str:
    """
    Infer output subdirectory prefix from vision model path:
    - contains 'blip' -> 'blip'
    - contains 'clip' -> 'clip'
    - otherwise default to 'clip' (backward compatible)
    """
    p = (model_path or "").lower()
    # Allow user-provided model dir names like xxx/blip-... or xxx/BLIP_...
    if "blip" in p:
        return "blip"
    if "clip" in p:
        return "clip"
    return "clip"


def _maybe_prefix_output_path_with_model_family(output_path: str, vision_model_path: str) -> str:
    """
    Route an ``outputs`` root into a per-model-family subdirectory.

    - A path whose last component is ``outputs`` (absolute or relative,
      trailing separators ignored) becomes ``<path>/<family>``, where family
      is 'clip' or 'blip' as inferred from ``vision_model_path``.
    - Any other path is returned unchanged. That includes paths already
      ending in ``outputs/clip`` or ``outputs/blip``.
    - A falsy ``output_path`` yields an empty string.
    """
    family = _infer_vision_model_family_from_path(vision_model_path)
    raw_path = output_path or ""

    # Trailing separators are dropped only for the check and the join;
    # whether the path is absolute or relative is left untouched.
    trimmed = raw_path.rstrip("/\\")

    if os.path.basename(trimmed) == "outputs":
        return os.path.join(trimmed, family)

    # Everything else, including .../outputs/clip and .../outputs/blip,
    # is handed back exactly as given.
    return raw_path


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser shared by the preprocess and inference stages."""
    parser = argparse.ArgumentParser(description="SCOPE (Semantic Cloud-Orchestrated Perception at Edge)")

    # Basics
    parser.add_argument('--stage', type=str, required=True, choices=['preprocess', 'inference'],
                        help='Stage: preprocess or inference')
    parser.add_argument('--dataset_name', type=str, default='longvideobench',
                        help='Dataset name')
    parser.add_argument('--dataset_path', type=str, default=DEFAULT_PATHS["dataset_path"],
                        help='Dataset root path')
    parser.add_argument('--output_path', type=str, default=DEFAULT_PATHS["output_path"],
                        help='Output directory')
    # Optional: a single shard json/jsonl file
    parser.add_argument('--dataset_file', type=str, default=None,
                        help='Use a single shard file (overrides dataset_name default loader)')

    # Sampling
    parser.add_argument('--sampler', type=str, required=True,
                        choices=EXPERIMENT_CONFIG["supported_samplers"],
                        help='Sampler type')
    parser.add_argument('--num_frames', type=int, default=16,
                        help='Number of frames to sample')
    parser.add_argument('--frame_allocation_mode', type=str, default='importance',
                        choices=['importance', 'uniform', 'random', 'dirichlet', 'winner_take_all'],
                        help='Frame allocation mode (for scope): importance/uniform/random/dirichlet/winner_take_all')
    parser.add_argument('--clip_path', type=str, default=DEFAULT_PATHS["clip_path"],
                        help='Vision similarity model path (CLIP/BLIP)')

    # Preprocess-only
    parser.add_argument('--num_workers', type=int,
                        default=EXPERIMENT_CONFIG["default_num_workers"],
                        help='Number of preprocess workers')
    parser.add_argument('--gpu_ids', type=str, default=None,
                        help='Comma-separated GPU ids (e.g., 1,2)')

    # Inference-only
    parser.add_argument('--agent', type=str, default='direct',
                        choices=EXPERIMENT_CONFIG["supported_agents"],
                        help='Agent type')
    parser.add_argument('--preprocess_session', type=str, default=None,
                        help='Preprocess session directory (defaults to latest)')

    # Model config
    parser.add_argument('--model_path', type=str,
                        default=QWEN_CONFIG["default_model"],
                        help='VLM model path')
    parser.add_argument('--vlm_backend', type=str, default='auto',
                        choices=['auto', 'qwen', 'internvl3'],
                        help='VLM backend: auto/qwen/internvl3')

    # Generation
    parser.add_argument('--max_new_tokens', type=int, default=4096,
                        help='Max new tokens')
    parser.add_argument('--temperature', type=float, default=0.0,
                        help='Sampling temperature')
    parser.add_argument('--do_sample', action='store_true',
                        help='Enable sampling')
    parser.add_argument('--num_beams', type=int, default=1,
                        help='Beam search beams')

    # Scope: global matching FPS
    parser.add_argument('--match_fps', type=float, default=None,
                        help='Candidate extraction FPS for scope (e.g., 0.1/0.5/1/2)')

    # Query decomposition (Scope)
    parser.add_argument('--use_api', action='store_true',
                        help='Use remote API for query decomposition')
    parser.add_argument('--use_local', action='store_true',
                        help='Force local model for query decomposition')
    parser.add_argument('--api_key', type=str, default=None,
                        help='API key (or env API_KEY)')
    parser.add_argument('--planner_api_model', type=str, default=None,
                        help='Planner API model name (e.g., deepseek-v3)')
    # Offline meta reuse
    parser.add_argument('--meta_root', type=str, default=None,
                        help='Offline meta root directory (optional)')
    parser.add_argument('--no_reuse_meta_frame_allocation', action='store_true',
                        help='Do not reuse frame_allocation from offline meta')

    return parser


def _resolve_use_api(use_api_flag: bool, use_local_flag: bool):
    """
    Translate the --use_api / --use_local flags into a tri-state value.

    Returns True (--use_api), False (--use_local), or None when neither flag
    is given so the callee can fall back to its configured default.
    Prints an error and exits with status 1 if both flags are set.
    """
    if use_api_flag and use_local_flag:
        print("Error: cannot use both --use_api and --use_local")
        sys.exit(1)
    if use_api_flag:
        return True
    if use_local_flag:
        return False
    return None


def _parse_gpu_ids(raw_gpu_ids):
    """
    Parse a comma-separated GPU id string such as "1,2" into a list of ints.

    Returns None when the argument is missing or empty. Prints an error and
    exits with status 1 if any element is not an integer.
    """
    if not raw_gpu_ids:
        return None
    try:
        gpu_ids = [int(token.strip()) for token in raw_gpu_ids.split(',')]
    except ValueError:
        print("Error: invalid GPU id list (expected comma-separated integers, e.g., 1,2)")
        sys.exit(1)
    print(f"Using GPUs: {gpu_ids}")
    return gpu_ids


def main():
    """
    Command-line entry point: parse arguments and run the requested stage.

    - ``--stage preprocess`` forwards sampling / query-decomposition options to
      ``experiments.run_preprocess``.
    - ``--stage inference`` forwards model / generation options to
      ``experiments.run_inference``.

    Invalid argument combinations print an error and exit with status 1.
    """
    # os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"
    args = _build_arg_parser().parse_args()

    # NOTE(review): the "[SCPOE]" log prefix looks like a transposition of
    # "SCOPE"; it is kept as-is because log consumers may match on it.

    # Auto-separate outputs by vision backbone (clip vs blip), unless user already did.
    resolved_output_path = _maybe_prefix_output_path_with_model_family(args.output_path, args.clip_path)
    if resolved_output_path != args.output_path:
        print(
            f"[SCPOE] Detected vision model path: {args.clip_path}\n"
            f"[SCPOE] Adjusting output dir from {args.output_path} to {resolved_output_path}"
        )
        args.output_path = resolved_output_path

    # Multiprocessing start method
    mp.set_start_method('spawn', force=True)

    # Optional parameter info
    if args.sampler in SCOPE_SAMPLERS:
        print("[SCPOE] Scope config:")
        if args.match_fps is not None:
            print(f"  Global match FPS: {args.match_fps} fps")
        # For large models, a single worker is often safer to avoid GPU OOM;
        # num_workers is nevertheless left as the user specified it.

    if args.stage == 'preprocess':
        # Validate CLI input before importing the experiments package so that
        # bad arguments fail fast.
        use_api = _resolve_use_api(args.use_api, args.use_local)
        gpu_ids = _parse_gpu_ids(args.gpu_ids)

        from experiments import run_preprocess

        # Run preprocess
        run_preprocess(
            dataset_name=args.dataset_name,
            dataset_path=args.dataset_path,
            output_path=args.output_path,
            sampler_type=args.sampler,
            num_frames=args.num_frames,
            clip_path=args.clip_path,
            num_workers=args.num_workers,
            frame_allocation_mode=args.frame_allocation_mode,
            use_api=use_api,
            api_key=args.api_key,
            gpu_ids=gpu_ids,
            dataset_file=args.dataset_file,
            meta_root=args.meta_root,
            reuse_meta_frame_allocation=(not args.no_reuse_meta_frame_allocation),
            match_fps=args.match_fps,
            planner_api_model=args.planner_api_model,
        )
        print("\nPreprocess completed. Stats saved.")

    elif args.stage == 'inference':
        from experiments import run_inference

        # Model config
        model_config = {
            'model_path': args.model_path,
            'vlm_backend': args.vlm_backend,
        }
        print(f"[SCPOE] Using model: {args.model_path}")

        # Generation config
        generation_kwargs = {
            'max_new_tokens': args.max_new_tokens,
            'temperature': args.temperature,
            'do_sample': args.do_sample,
            'num_beams': args.num_beams,
        }

        # Agent config (no agent-specific options are exposed on the CLI)
        agent_config = {}

        # Run inference
        run_inference(
            dataset_name=args.dataset_name,
            dataset_path=args.dataset_path,
            output_path=args.output_path,
            sampler_type=args.sampler,
            agent_type=args.agent,
            num_frames=args.num_frames,
            frame_allocation_mode=args.frame_allocation_mode,
            match_fps=args.match_fps,
            planner_api_model=args.planner_api_model,
            model_config=model_config,
            generation_kwargs=generation_kwargs,
            agent_config=agent_config,
            preprocess_session=args.preprocess_session,
        )
        print("\nInference completed. Results saved.")

    else:
        # argparse `choices` should make this unreachable; if it is ever hit,
        # fail with a non-zero status like the other error paths.
        print("Error: invalid stage")
        sys.exit(1)


if __name__ == "__main__":
    # Script entry point. `mp` is the module-level multiprocessing import;
    # select the "spawn" start method before handing control to main().
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        # Ignore a RuntimeError from set_start_method and proceed to main().
        pass

    main()
