# experiments/preprocess.py
"""
Preprocess stage: run samplers in parallel.
"""

import os
import json
import time
import multiprocessing as mp
from datetime import datetime
from tqdm import tqdm
from typing import Dict, Any, List, Optional, cast
from collections import Counter
import numpy as np

from core.samplers import get_sampler
from utils import load_dataset
from utils.file_manager import OutputFileManager
from config.settings import EXPERIMENT_CONFIG, DATASET_CONFIG


# Sampler names routed through the scope code path: init_worker passes them the
# extra planner / frame-allocation kwargs, and process_task calls them with the
# task index and id.
SCOPE_SAMPLERS = {"scope"}
# Metadata "mode" values for which run_preprocess records query-decomposition
# diagnostics (model retries / fallbacks) into errors.json.
SCOPE_METADATA_MODES = {"scope"}


# Per-process worker state: assigned by init_worker (the Pool initializer) and
# read by process_task inside the same worker process.
worker_sampler = worker_sampler_type = worker_gpu_id = worker_frame_allocation_mode = None


def get_worker_gpu_id(available_gpus: List[int]) -> int:
    """
    Choose a GPU id for the current worker process.

    Processes started by ``multiprocessing`` carry a 1-based ordinal in
    ``current_process()._identity``. When it is present (e.g. inside a Pool
    worker), GPUs are assigned round-robin by that ordinal, so N workers over
    N GPUs land one per GPU. Otherwise (main process, or the attribute is
    unavailable), the PID is used instead. Either way the choice is stable
    for the lifetime of the process.

    Args:
        available_gpus: Non-empty list of CUDA device ids to choose from.

    Returns:
        One element of ``available_gpus``.

    Raises:
        ValueError: If ``available_gpus`` is empty.
    """
    if not available_gpus:
        raise ValueError("available_gpus must contain at least one GPU id")

    # PIDs of pool workers are not guaranteed to be consecutive (other processes
    # may be spawned in between), so PID % n can stack several workers on one GPU.
    # The worker ordinal is sequential and yields an even split.
    identity = getattr(mp.current_process(), "_identity", ())
    slot = identity[-1] - 1 if identity else os.getpid()

    return available_gpus[slot % len(available_gpus)]


def init_worker(sampler_type: str, clip_path: str, available_gpus: List[int], 
               use_api: Optional[bool] = None, api_key: Optional[str] = None, meta_root: Optional[str] = None,
               frame_allocation_mode: str = "importance",
               reuse_meta_frame_allocation: bool = True,
               match_fps: Optional[float] = None,
               planner_api_model: Optional[str] = None):
    """
    Pool initializer: build this worker process's sampler.

    Records the sampler type and frame-allocation mode in module globals and
    picks a CUDA device: the id from ``get_worker_gpu_id`` when GPUs are
    listed, otherwise ``cuda:0``. It then constructs the sampler through
    ``get_sampler``. Scope samplers additionally receive the frame-allocation
    settings, plus any optional planner/meta/match-fps options that are not
    ``None``.

    Any exception is printed and leaves ``worker_sampler`` as ``None``, so that
    ``process_task`` reports an initialization error for each task instead of
    the pool crashing.
    """
    global worker_sampler, worker_sampler_type, worker_gpu_id, worker_frame_allocation_mode
    worker_sampler_type = sampler_type
    worker_frame_allocation_mode = frame_allocation_mode

    try:
        worker_gpu_id = get_worker_gpu_id(available_gpus) if available_gpus else 0
        device = f"cuda:{worker_gpu_id}"

        is_scope = sampler_type in SCOPE_SAMPLERS
        sampler_kwargs: Dict[str, Any] = {"clip_model_path": clip_path, "device": device}
        if is_scope:
            sampler_kwargs.update(
                frame_allocation_mode=frame_allocation_mode,
                reuse_meta_frame_allocation=reuse_meta_frame_allocation,
            )
            # Only forward options the caller actually set; the sampler keeps its own defaults otherwise.
            optional_kwargs = {
                "use_api": use_api,
                "api_key": api_key,
                "planner_api_model": planner_api_model,
                "meta_root": meta_root,
                "match_fps": match_fps,
            }
            sampler_kwargs.update({name: value for name, value in optional_kwargs.items() if value is not None})

        worker_sampler = get_sampler(sampler_type, **sampler_kwargs)
        print(f"Worker initialized with {sampler_type} sampler on {device}")

        if is_scope:
            backend = "api" if use_api else "local"
            print(f"Planner backend: {backend}, GPU: {device}")

    except Exception as exc:
        print(f"Failed to initialize worker with {sampler_type}: {exc}")
        worker_sampler = None


def _build_query_text(task: Dict[str, Any]) -> str:
    """Return the task question, with non-empty list/tuple candidates appended as a ' | '-joined line."""
    query_text = task.get("question", "")
    candidates = task.get("candidates")
    if isinstance(candidates, (list, tuple)) and candidates:
        joined = " | ".join(str(c) for c in candidates)
        query_text = f"{query_text}\nCandidates: {joined}"
    return query_text


def _task_error_meta(sampler_type: str, error_type: str, detail: str,
                     task: Dict[str, Any], video_path: str) -> Dict[str, Any]:
    """Build the standardized error-metadata dict returned for a failed task."""
    return {
        "mode": sampler_type,
        "frame_allocation_mode": worker_frame_allocation_mode,
        "errors": [{"type": error_type, "detail": detail}],
        "task_id": task.get('id', 'N/A'),
        "video_path": video_path,
        "fallback": True,
    }


def process_task(task_info):
    """
    Sample keyframes for one task inside a pool worker.

    Args:
        task_info: ``(task_index, task, video_base_path, num_frames, sampler_type)``.
            ``task["video_path"]`` is joined onto ``video_base_path``; a missing
            key raises ``KeyError``. ``"question"``, ``"candidates"`` and
            ``"id"`` are optional, except for the ``aks``/``topk`` samplers,
            which read ``task["question"]`` directly; a missing key there is
            reported as a task error.

    Returns:
        ``(indices, seconds, meta)``. If the worker's sampler failed to
        initialize, or sampling raises, returns ``([], 0.0, error_meta)``,
        where ``error_meta["errors"]`` describes the problem and
        ``error_meta["fallback"]`` is True.
    """
    task_index, task, video_base_path, num_frames, sampler_type = task_info
    video_path = os.path.join(video_base_path, task["video_path"])
    # Question plus candidate answers; used by the query-aware samplers.
    query_text = _build_query_text(task)

    if worker_sampler is None:
        error_msg = f"Worker sampler not initialized for {sampler_type}"
        print(f"Error processing task {task.get('id', 'N/A')}: {error_msg}")
        return ([], 0.0, _task_error_meta(sampler_type, "sampler_initialization_error", error_msg, task, video_path))

    try:
        # perf_counter is monotonic, so the elapsed time cannot go negative on clock adjustments.
        start_time = time.perf_counter()
        meta: Dict[str, Any] = {"mode": sampler_type}
        if sampler_type in SCOPE_SAMPLERS:
            meta["frame_allocation_mode"] = worker_frame_allocation_mode

        if sampler_type == 'uniform':
            from utils import get_video_frame_count_decord
            total_frames = get_video_frame_count_decord(video_path)
            indices = worker_sampler.select_keyframes(total_frames, num_frames)

        elif sampler_type in ('aks', 'topk'):
            # Both score 1-fps frames against the bare question (without candidates).
            from utils import extract_frames_decord
            frames_1fps, original_indices = extract_frames_decord(video_path, fps=1)
            indices = worker_sampler.select_keyframes(frames_1fps, original_indices, task["question"], num_frames)

        elif sampler_type == 'division':
            indices = worker_sampler.select_keyframes(video_path, num_frames, query_text)
            meta = worker_sampler.get_sampling_metadata(video_path, query_text, num_frames)

        elif sampler_type in SCOPE_SAMPLERS:
            indices = worker_sampler.select_keyframes(
                video_path,
                num_frames,
                query_text,
                task_index=task_index,
                task_id=task.get("id", task_index),
            )
            meta = worker_sampler.get_sampling_metadata(video_path, query_text, num_frames)

        else:
            raise ValueError(f"Unknown sampler: {sampler_type}")

        return (indices, time.perf_counter() - start_time, meta)

    except Exception as e:
        error_msg = str(e)
        print(f"Error processing task {task.get('id', 'N/A')}: {error_msg}")
        return ([], 0.0, _task_error_meta(sampler_type, "task_processing_error", error_msg, task, video_path))

    finally:
        import gc

        # Best-effort per-task cleanup: a failing clear_cache must not mask the task result.
        try:
            if worker_sampler is not None and hasattr(worker_sampler, 'clear_cache'):
                worker_sampler.clear_cache()
        except Exception:
            pass
        # Kept outside the try above so it always runs (previously an unrelated
        # failed `import torch` silently skipped it).
        gc.collect()


def convert_numpy_types(obj):
    """
    Recursively convert numpy values into plain Python values for JSON serialization.

    - numpy bool scalars -> ``bool``
    - numpy integer scalars -> ``int``
    - numpy floating scalars -> ``float``
    - ``ndarray`` -> nested lists (via ``tolist``)
    - dicts: both keys and values are converted (``json`` only accepts
      str/int/float/bool/None keys, so e.g. ``np.int64`` keys would fail)
    - lists are rebuilt as lists and tuples as tuples (``json`` writes tuples
      as arrays)

    Anything else is returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer and json.dump rejects it, so handle it first.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        return {convert_numpy_types(key): convert_numpy_types(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    if isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    return obj


def generate_summary_report(stats: Dict[str, Any],
                            stage_agg: Dict[str, Dict[str, Any]],
                            sampler_type: str) -> Dict[str, Any]:
    """
    Build a compact, JSON-serializable summary report from aggregated run statistics.

    Works only from the aggregates; it does not read any per-task meta.

    Args:
        stats: Run statistics (``dataset_size``, ``num_frames``,
            ``session_info``, ``timing``, ``errors``, ``files``, ``success``).
            Missing keys fall back to neutral defaults.
        stage_agg: ``{"sum": {stage: float}, "cnt": {stage: int}}`` per-stage
            timing totals and sample counts.
        sampler_type: Sampler name recorded under ``experiment_info``.

    Returns:
        A dict with these sections: ``experiment_info``, ``performance_metrics``,
        ``error_analysis``, ``file_locations`` and ``sampling_analysis``.
        When any stage timings exist it also has ``stage_timings``, and, when
        recognized stage keys are present, ``timing_breakdown_compact``.
        Numpy values are converted to plain Python types.
    """
    timing = stats.get("timing", {})
    session_info = stats.get("session_info", {})
    success = stats.get("success", {})
    dataset_size = stats.get("dataset_size", 0)
    wall_clock = timing.get("wall_clock_time", 0.0)

    summary: Dict[str, Any] = {
        "experiment_info": {
            "sampler_type": sampler_type,
            "num_frames": stats.get("num_frames", 0),
            "dataset_size": dataset_size,
            "timestamp": session_info.get("timestamp", ""),
            "session_dir": session_info.get("session_dir", ""),
        },
        "performance_metrics": {
            "wall_clock_time": wall_clock,
            "total_cpu_time": timing.get("total_cpu_time", 0.0),
            "average_time_per_task": timing.get("average_time_per_task", 0.0),
            "speedup": timing.get("speedup", 1.0),
            "tasks_per_second": (dataset_size / wall_clock) if wall_clock > 0 else 0.0,
        },
        "error_analysis": stats.get("errors", {}),
        "file_locations": stats.get("files", {}),
        "sampling_analysis": {
            "successful_tasks": success.get("successful_tasks", 0),
            "failed_tasks": success.get("failed_tasks", 0),
            "success_rate": success.get("success_rate", 0.0),
        },
    }

    # Per-stage timing: average = sum / count.
    stage_sums = stage_agg.get("sum", {})
    stage_counts = stage_agg.get("cnt", {})
    stage_timings: Dict[str, Dict[str, Any]] = {}
    for stage, total in stage_sums.items():
        count = int(stage_counts.get(stage, 0))
        stage_timings[stage] = {
            # Divide by at least 1: a stage with a total but no recorded count
            # reports its total rather than raising ZeroDivisionError.
            "average": float(total) / max(1, count),
            "count": count,
            "total": float(total),
        }

    if stage_timings:
        summary["stage_timings"] = stage_timings

        # Friendly overall breakdown, restricted to the scope compact-breakdown keys.
        preferred_keys = (
            "query_decomposition",
            "frame_allocation",
            "frame_feature_extraction",
            "matching_score",
            "matching_select",
            "postprocess_dedup_budget",
            "fill_remaining_frames",
            "other",
            "total_time_actual",
        )
        compact = {k: stage_timings[k] for k in preferred_keys if k in stage_timings}
        if compact:
            summary["timing_breakdown_compact"] = {
                **compact,
                "_notes": (
                    "Fields aggregated from scope timing_breakdown_compact per task. "
                    "frame_feature_extraction=encoding frame features for candidates (controlled by match_fps); "
                    "matching_score/matching_select=scoring and top-k selection; "
                    "fill_remaining_frames=time spent filling remaining budget (reuses scores when possible)."
                ),
            }

    # Ensure the returned object is JSON-serializable and dict-shaped.
    return cast(Dict[str, Any], convert_numpy_types(summary))


def _read_any_json_or_jsonl(path: str) -> List[Any]:
    """
    Load a list of records from a JSON-array file or a JSONL file.

    The format is detected from the first non-whitespace character: ``[``
    means a JSON array; anything else is parsed as one JSON value per
    non-blank line. A leading UTF-8 BOM is skipped.

    Raises:
        ValueError: If the file is a JSON document whose top level is not a list
            (``json.JSONDecodeError``, a ValueError subclass, on malformed content).
    """
    with open(path, 'r', encoding='utf-8-sig') as f:
        text = f.read()
    if text.lstrip().startswith('['):
        data = json.loads(text)
        if not isinstance(data, list):
            raise ValueError(f"{path} is JSON but not a list")
        return data
    # Text-mode reads already normalize newlines to '\n'. Split on that only, so
    # that U+2028 and similar characters inside JSON strings are not treated as
    # line breaks.
    return [json.loads(line) for line in text.split('\n') if line.strip()]


def _safe_dump_json(path: str, obj: Any) -> None:
    """Write ``obj`` to ``path`` as indented UTF-8 JSON, converting numpy values first."""
    # Convert before opening so a conversion failure does not truncate an existing file.
    payload = convert_numpy_types(obj)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _build_scope_variant(frame_allocation_mode: str,
                         match_fps: Optional[float],
                         planner_api_model: Optional[str]) -> str:
    """
    Build the session-variant tag for scope runs.

    Joins these parts with ``__``:
    - the frame-allocation mode;
    - ``mfps<match_fps>`` when ``match_fps`` is set;
    - ``planner_<model>`` when ``planner_api_model`` is set.

    Any run of characters outside ``[0-9A-Za-z._-]`` is then collapsed to
    ``_``, and leading/trailing underscores are trimmed.
    """
    import re

    parts = [frame_allocation_mode]
    if match_fps is not None:
        try:
            parts.append(f"mfps{float(match_fps):g}")
        except (TypeError, ValueError, OverflowError):
            parts.append(f"mfps{match_fps}")
    if planner_api_model:
        parts.append(f"planner_{planner_api_model}")
    return re.sub(r"[^0-9A-Za-z._-]+", "_", "__".join(parts)).strip("_")


def _meta_filename(task_index: int, task_id: Any) -> str:
    """
    Return the per-task meta file name ``meta_{task_index}_{task_id}.json``.

    Path separators in ``task_id`` are replaced with ``_``. Without this, an id
    such as ``"clip/01"`` would point into a non-existent subdirectory and
    abort the whole run.
    """
    safe_id = str(task_id)
    for sep in (os.sep, os.altsep):
        if sep:
            safe_id = safe_id.replace(sep, "_")
    return f"meta_{task_index}_{safe_id}.json"


def _accumulate_stage_timings(meta: Dict[str, Any],
                              stage_sum: Dict[str, float],
                              stage_cnt: Dict[str, int]) -> None:
    """
    Fold one task's numeric stage timings into running sum/count maps (defaultdicts).

    Timings are taken from the first source that applies:
      1. ``meta["timing_breakdown"]["timing_breakdown_compact"]`` (every numeric entry);
      2. a fixed set of legacy keys, when ``meta["timing_breakdown"]`` is a dict without a compact breakdown;
      3. top-level string keys ending in ``_time`` or containing ``timing``, when there is no ``timing_breakdown`` dict.

    Booleans are skipped even though they are ``int`` subclasses.
    """
    def _add(key: Any, value: Any) -> None:
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            stage_sum[str(key)] += float(value)
            stage_cnt[str(key)] += 1

    breakdown = meta.get("timing_breakdown")
    if not isinstance(breakdown, dict):
        # Non-scope samplers: fall back to *_time / *timing* fields if present.
        for key, value in meta.items():
            if isinstance(key, str) and (key.endswith("_time") or "timing" in key.lower()):
                _add(key, value)
        return

    compact = breakdown.get("timing_breakdown_compact")
    if isinstance(compact, dict):
        for key, value in compact.items():
            _add(key, value)
        return

    # Backward compatibility: without a compact breakdown, keep only key stages.
    legacy_keys = (
        "query_decomposition",
        "frame_allocation",
        "shared_frame_features_time",
        "layer_processing_actual",
        "fill_remaining_frames_time",
        "postprocess_dedup_budget_time",
        "total_time_actual",
    )
    for key in legacy_keys:
        if key in breakdown:
            _add(key, breakdown[key])


def _diagnostic_error_rows(meta: Dict[str, Any], task_index: int,
                           task_id: Any, video_path: str) -> List[Dict[str, Any]]:
    """
    Return non-fatal model-diagnostic rows for errors.json (possibly empty).

    - Scope metadata modes: a row is emitted when ``query_decomposition.model_diag``
      reports retry errors or a fallback.
    - Other modes: a row is emitted when ``api_diag`` carries a truthy ``errors`` entry.

    Non-dict diagnostic payloads are ignored instead of raising.
    """
    base = {"task_index": task_index, "task_id": task_id, "video_path": video_path}
    mode = meta.get('mode', 'unknown')

    if mode in SCOPE_METADATA_MODES:
        query_decomposition = meta.get("query_decomposition")
        model_diag = (query_decomposition.get("model_diag")
                      if isinstance(query_decomposition, dict) else None)
        if isinstance(model_diag, dict) and (model_diag.get("retry_errors") or model_diag.get("fallback_used")):
            return [{
                **base,
                "error_type": "query_decomposition_diag",
                "detail": {
                    "attempts": model_diag.get("attempts", 1),
                    "success": model_diag.get("final_success", True),
                    "fallback_used": model_diag.get("fallback_used", False),
                },
            }]
        return []

    api_diag = meta.get("api_diag")
    if isinstance(api_diag, dict) and api_diag.get("errors"):
        return [{**base, "error_type": f"{mode}_model_diag", "detail": api_diag}]
    return []


def run_preprocess(dataset_name: str, dataset_path: str, output_path: str, 
                  sampler_type: str, num_frames: int, clip_path: str, 
                  frame_allocation_mode: str = "importance",
                  num_workers: Optional[int] = None, use_api: Optional[bool] = None, 
                  api_key: Optional[str] = None, gpu_ids: Optional[List[int]] = None,
                  dataset_file: Optional[str] = None, meta_root: Optional[str] = None,
                  reuse_meta_frame_allocation: bool = True,
                  match_fps: Optional[float] = None,
                  planner_api_model: Optional[str] = None) -> Dict[str, Any]:
    """
    Run the preprocess (frame-sampling) stage over a dataset with a worker pool.

    Each task is sampled by ``process_task`` in a ``multiprocessing.Pool`` whose
    workers are set up by ``init_worker``. Results are consumed in task order.
    Each task's metadata is written to ``<session_dir>/meta/`` as soon as it
    arrives; the aggregate files are written after all tasks finish:

    - indices.json: sampled indices per task (list, dataset order)
    - timing.json: per-task sampling time in seconds (list, dataset order)
    - errors.json / error_counts.json: error rows and per-type error counts
    - summary.json: aggregated statistics (no large payload)

    For scope samplers, ``use_api`` / ``api_key`` left as ``None`` are resolved
    before the workers start. ``use_api`` defaults to
    ``QUERY_DECOMPOSITION_CONFIG["use_api"]`` (True if absent) and ``api_key``
    to ``API_CONFIG["api_key"]``.

    Args:
        dataset_name: Key for ``load_dataset`` and ``DATASET_CONFIG`` (``video_dir`` lookup).
        dataset_path: Dataset root. Videos live under
            ``<dataset_path>/<video_dir>`` (default ``videos``).
        output_path: Root directory for the ``OutputFileManager`` session.
        sampler_type: Sampler name passed to ``get_sampler``.
        num_frames: Frame budget per task.
        clip_path: CLIP model path passed to the sampler.
        frame_allocation_mode, reuse_meta_frame_allocation, match_fps,
        planner_api_model, meta_root: Scope-sampler options.
        num_workers: Pool size. Defaults to
            ``min(EXPERIMENT_CONFIG["default_num_workers"], cpu_count())``.
        use_api, api_key: Planner backend settings (see above).
        gpu_ids: CUDA device ids to spread workers over. Defaults to ``[0]``.
        dataset_file: Optional JSON/JSONL shard, used instead of ``load_dataset``.

    Returns:
        The statistics dict. It includes ``files["summary"]`` when summary.json
        was written successfully.
    """
    from collections import defaultdict

    print(f"======== Preprocess start [{sampler_type}] (multiprocessing) ========")
    is_scope = sampler_type in SCOPE_SAMPLERS

    # Scope: resolve planner defaults from config and print backend info.
    if is_scope:
        from config.settings import QUERY_DECOMPOSITION_CONFIG, API_CONFIG
        if use_api is None:
            use_api = QUERY_DECOMPOSITION_CONFIG.get("use_api", True)
        if api_key is None:
            api_key = API_CONFIG.get("api_key", None)

        api_status = "api" if use_api else "local"
        print(f"Planner backend: {api_status}")
        if use_api and api_key:
            masked = ('*' * (len(api_key) - 8) + api_key[-8:]) if len(api_key) > 8 else '***'
            print(f"API key: {masked}")
        if planner_api_model:
            print(f"Planner API model: {planner_api_model}")
        if match_fps is not None:
            print(f"Match FPS: {match_fps} fps")
        print(f"Frame allocation mode: {frame_allocation_mode}")

    # Create file manager and a new session.
    file_manager = OutputFileManager(output_path, dataset_name)
    variant = _build_scope_variant(frame_allocation_mode, match_fps, planner_api_model) if is_scope else None
    file_paths = file_manager.create_preprocess_session(num_frames, sampler_type, variant=variant)

    # Directory for per-task meta.
    meta_dir = os.path.join(file_paths['session_dir'], 'meta')
    os.makedirs(meta_dir, exist_ok=True)

    print(f"Session dir: {file_paths['session_dir']}")
    print(f"Timestamp: {file_paths['timestamp']}")

    # Load dataset (a shard file takes precedence over the registered dataset loader).
    if dataset_file:
        print(f"[Data] Using shard file: {dataset_file}")
        dataset = _read_any_json_or_jsonl(dataset_file)
    else:
        dataset = load_dataset(dataset_path, dataset_name)
    video_dir = DATASET_CONFIG.get(dataset_name, {}).get('video_dir', 'videos')
    video_base_path = os.path.join(dataset_path, video_dir)

    tasks_to_process = [(i, task, video_base_path, num_frames, sampler_type) for i, task in enumerate(dataset)]

    # Worker count and GPU config.
    if num_workers is None:
        num_workers = min(EXPERIMENT_CONFIG["default_num_workers"], mp.cpu_count())
    available_gpus = gpu_ids if gpu_ids is not None else [0]
    max_workers_per_gpu = 2
    recommended_workers = len(available_gpus) * max_workers_per_gpu
    if num_workers > recommended_workers:
        print(f"[Hint] num_workers={num_workers}; recommended <= {recommended_workers} (num_gpus×{max_workers_per_gpu})")

    print(f"Using {num_workers} worker processes...")
    print(f"Available GPUs: {available_gpus}")

    # Streaming accumulators.
    all_indices: List[List[int]] = []
    all_times: List[float] = []
    error_rows: List[Dict[str, Any]] = []
    error_counter: Counter = Counter()
    success_cnt = 0
    fail_cnt = 0
    stage_sum: Dict[str, float] = defaultdict(float)
    stage_cnt: Dict[str, int] = defaultdict(int)

    total_start_time = time.perf_counter()
    with mp.Pool(processes=num_workers, initializer=init_worker,
                 initargs=(
                     sampler_type,
                     clip_path,
                     available_gpus,
                     use_api,
                     api_key,
                     meta_root,
                     frame_allocation_mode,
                     reuse_meta_frame_allocation,
                     match_fps,
                     planner_api_model,
                 )) as pool:
        # imap yields results in submission order, so index i lines up with dataset[i].
        results = pool.imap(process_task, tasks_to_process)
        for i, (indices, t_spent, meta) in enumerate(tqdm(
            results,
            total=len(tasks_to_process),
            desc=f"Sampling... ({sampler_type})"
        )):
            task = dataset[i]
            task_id = task.get("id", i)
            video_path = os.path.join(video_base_path, task["video_path"])
            meta_dict: Dict[str, Any] = meta if isinstance(meta, dict) else {}

            all_indices.append(indices)
            all_times.append(t_spent)

            # Per-task meta: lightweight context + sampler meta (indices are not
            # embedded, to avoid duplication with indices.json).
            per_meta = {
                "task_id": task_id,
                "video_path": video_path,
                "sampler_type": sampler_type,
                "frame_allocation_mode": frame_allocation_mode if is_scope else None,
                "num_frames": num_frames,
                "match_fps": match_fps if is_scope else None,
                "planner_api_model": planner_api_model if is_scope else None,
                "planner_use_api": use_api if is_scope else None,
                "question": task.get("question", ""),
                "candidates": task.get("candidates", None),
            }
            per_meta.update(meta_dict)
            _safe_dump_json(os.path.join(meta_dir, _meta_filename(i, task_id)), per_meta)

            # Any entry in meta["errors"] marks the task as failed.
            errors = meta_dict.get("errors") or []
            if errors:
                fail_cnt += 1
                for err in errors:
                    if not isinstance(err, dict):
                        err = {"detail": str(err)}
                    e_type = err.get("type", "unknown")
                    error_counter[e_type] += 1
                    error_rows.append({
                        "task_index": i,
                        "task_id": task_id,
                        "video_path": video_path,
                        "error_type": e_type,
                        "detail": err.get("detail", "")
                    })
            else:
                success_cnt += 1

            _accumulate_stage_timings(meta_dict, stage_sum, stage_cnt)
            error_rows.extend(_diagnostic_error_rows(meta_dict, i, task_id, video_path))

    wall_clock_time = time.perf_counter() - total_start_time
    total_cpu_time = float(sum(all_times))

    # Order-preserving outputs and error logs. errors.json can carry
    # sampler-provided api_diag payloads, so it also goes through numpy conversion.
    _safe_dump_json(file_paths['indices'], all_indices)
    _safe_dump_json(file_paths['timing'], all_times)
    _safe_dump_json(file_paths['errors'], error_rows)
    _safe_dump_json(file_paths['error_counts'], error_counter.most_common())

    stats: Dict[str, Any] = {
        "sampler_type": sampler_type,
        "frame_allocation_mode": frame_allocation_mode if is_scope else None,
        "num_frames": num_frames,
        "match_fps": match_fps if is_scope else None,
        "planner_api_model": planner_api_model if is_scope else None,
        "dataset_size": len(dataset),
        "session_info": {
            "session_dir": file_paths['session_dir'],
            "timestamp": file_paths['timestamp']
        },
        "files": {
            "indices": file_paths['indices'],
            "timing": file_paths['timing'],
            "errors": file_paths['errors'],
            "error_counts": file_paths['error_counts'],
            "meta_dir": meta_dir  # per-task meta directory
        },
        "timing": {
            "wall_clock_time": wall_clock_time,
            "total_cpu_time": total_cpu_time,
            "average_time_per_task": total_cpu_time / len(dataset) if dataset else 0.0,
            "speedup": (total_cpu_time / wall_clock_time) if wall_clock_time > 0 else 1.0
        },
        "errors": {
            "total_errors": len(error_rows),
            "error_types": dict(error_counter.most_common())
        },
        "success": {
            "successful_tasks": success_cnt,
            "failed_tasks": fail_cnt,
            "success_rate": (success_cnt / (success_cnt + fail_cnt) * 100.0) if (success_cnt + fail_cnt) else 0.0
        }
    }

    print("\n======== Preprocess completed ========")
    print(f"Session dir: {file_paths['session_dir']}")
    print(f"Indices: {file_paths['indices']}")
    print(f"Timing: {file_paths['timing']}")
    print(f"Wall clock: {wall_clock_time:.2f}s (CPU total: {total_cpu_time:.2f}s)")
    print(f"Speedup: {stats['timing']['speedup']:.2f}x")
    if error_rows:
        print(f"Errors: {len(error_rows)} (see {file_paths['errors']})")

    # Compact summary report, built from the aggregates only (per-task meta is not re-read).
    print("\n======== Generating summary report ========")
    try:
        summary_report = generate_summary_report(
            stats=stats,
            stage_agg={"sum": dict(stage_sum), "cnt": dict(stage_cnt)},
            sampler_type=sampler_type
        )
        summary_path = os.path.join(file_paths['session_dir'], 'summary.json')
        _safe_dump_json(summary_path, summary_report)
        print(f"Summary report saved to: {summary_path}")
        stats["files"]["summary"] = summary_path

        perf = summary_report.get("performance_metrics", {})
        print(f"Average time per task: {perf.get('average_time_per_task', 0):.3f}s")
        print(f"Throughput: {perf.get('tasks_per_second', 0):.2f} tasks/s")

        sa = summary_report.get("sampling_analysis", {})
        if sa:
            print(f"Success rate: {sa.get('success_rate', 0):.1f}% ({sa.get('successful_tasks', 0)}/{stats['dataset_size']})")

        stage_timings = summary_report.get("stage_timings", {})
        if stage_timings:
            print("Average time per stage:")
            for stage, timing in stage_timings.items():
                if isinstance(timing, dict) and 'average' in timing:
                    print(f"  {stage}: {timing['average']:.3f}s")
    except Exception as e:
        # The summary is a convenience artifact; the primary outputs are already on disk.
        print(f"Failed to generate summary report: {e}")
        print("Continuing...")

    return stats
