# experiments/inference.py
"""
Inference stage: run agents for video QA.
"""

import os
import json
import time
from datetime import datetime
from tqdm import tqdm
from typing import Dict, Any, List

from core.agents import get_agent
from core.models import initialize_model_specific, initialize_model
from utils import load_dataset, judge_correct, calculate_accuracy, extract_model_size_tag
from utils.file_manager import OutputFileManager
from config.settings import GENERATION_CONFIG, DATASET_CONFIG

# Sampler types handled as "scope" samplers: for these, run_inference prints the
# frame allocation mode and derives a sampler variant label that is used in the
# output directory name and when auto-selecting the latest preprocess session.
SCOPE_SAMPLERS = {"scope"}


def _build_sampler_variant(sampler_type, frame_allocation_mode, match_fps,
                           planner_api_model):
    """Return the filesystem-safe variant label for a scope sampler, else None."""
    # Kept function-local because the module does not import `re` at top level.
    import re

    if sampler_type not in SCOPE_SAMPLERS:
        return None

    parts = [frame_allocation_mode]
    if match_fps is not None:
        try:
            # ':g' drops trailing zeros so 2 and 2.0 yield the same label.
            parts.append(f"mfps{float(match_fps):g}")
        except (TypeError, ValueError, OverflowError):
            # Not convertible to float: fall back to its plain string form.
            parts.append(f"mfps{match_fps}")
    if planner_api_model:
        parts.append(f"planner_{planner_api_model}")

    # Collapse anything that is not alphanumeric, '.', '_' or '-' into '_'.
    return re.sub(r"[^0-9A-Za-z._-]+", "_", "__".join(parts)).strip("_")


def _write_json(path, payload):
    """Write *payload* to *path* as UTF-8 encoded, indented JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def run_inference(dataset_name: str, dataset_path: str, output_path: str,
                 sampler_type: str, agent_type: str, num_frames: int,
                 frame_allocation_mode: str,
                 model_config: Dict[str, Any], generation_kwargs: Dict = None,
                 agent_config: Dict = None, preprocess_session: str = None,
                 match_fps: float = None,
                 planner_api_model: str = None) -> Dict[str, Any]:
    """
    Run inference over a dataset using precomputed preprocess results.

    Loads the dataset and a preprocess session (frame indices and sampling
    times), builds the model and agent, runs the agent on every task, and
    writes 'detailed_results.json' (rewritten after every task) and
    'summary_report.json' into a timestamped output directory.

    Args:
        dataset_name: Dataset name.
        dataset_path: Dataset root path.
        output_path: Output directory.
        sampler_type: Sampler type.
        agent_type: Agent type.
        num_frames: Frame budget.
        frame_allocation_mode: Frame allocation mode (for scope).
        model_config: Model config; 'model_path' is required, 'vlm_backend'
            is optional.
        generation_kwargs: Generation parameters. Defaults to a copy of
            GENERATION_CONFIG.
        agent_config: Agent config. The caller's dict is not modified; a copy
            extended with the run parameters is passed to the agent.
        preprocess_session: Preprocess session directory (optional). When
            omitted, the latest matching session is used.
        match_fps: Optional value that becomes part of the scope sampler
            variant label, is forwarded to the agent, and is recorded in the
            summary.
        planner_api_model: Optional planner model name; handled the same way
            as match_fps.

    Returns:
        Summary dict with keys 'experiment_config', 'performance' and
        'timing_summary_seconds'.

    Raises:
        FileNotFoundError: If the preprocess session is missing or incomplete,
            or its indices cannot be loaded.
        ValueError: If the preprocess session holds fewer entries than the
            dataset has tasks.
    """
    print(f"======== Inference start [{sampler_type} + {agent_type}] ========")
    if sampler_type in SCOPE_SAMPLERS:
        print(f"Frame allocation mode: {frame_allocation_mode}")

    # Defaults
    if generation_kwargs is None:
        generation_kwargs = GENERATION_CONFIG.copy()

    print(f"Generation kwargs: {generation_kwargs}")

    # Output directory layout:
    #   <output_path>/<dataset>/<model>/<sampler>_<agent>_<N>frames/<timestamp>
    model_name = os.path.basename(model_config['model_path'].rstrip('/\\'))
    sampler_variant = _build_sampler_variant(
        sampler_type, frame_allocation_mode, match_fps, planner_api_model
    )
    sampler_label = f"{sampler_type}_{sampler_variant}" if sampler_variant else sampler_type
    config_dir_name = f"{sampler_label}_{agent_type}_{num_frames}frames"
    base_exp_dir = os.path.join(output_path, dataset_name, model_name, config_dir_name)
    run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = os.path.join(base_exp_dir, run_timestamp)
    os.makedirs(output_dir, exist_ok=True)
    print(f"Output dir: {output_dir}")

    # Resolve both report paths up front so they are defined even when the
    # dataset is empty and the main loop never runs.
    result_file = os.path.join(output_dir, 'detailed_results.json')
    summary_file = os.path.join(output_dir, 'summary_report.json')

    # Load dataset and preprocess results
    dataset = load_dataset(dataset_path, dataset_name)
    num_tasks = len(dataset)
    # Resolve video dir (default: 'videos')
    video_dir = DATASET_CONFIG.get(dataset_name, {}).get('video_dir', 'videos')
    video_base_path = os.path.join(dataset_path, video_dir)

    # Find preprocess session
    file_manager = OutputFileManager(output_path, dataset_name)

    if preprocess_session:
        # Use specified session dir
        session_dir = preprocess_session
        if not os.path.exists(session_dir):
            raise FileNotFoundError(f"Preprocess session directory not found: {session_dir}")
        print(f"Using preprocess session: {session_dir}")
    else:
        # Auto-pick latest session
        session_dir = file_manager.find_latest_session(num_frames, sampler_type, variant=sampler_variant)
        if not session_dir:
            raise FileNotFoundError(f"Preprocess results not found for sampler={sampler_type}. Run preprocess first.")
        print(f"Using latest preprocess session: {session_dir}")

    # Load session files
    file_paths = file_manager.get_session_files(session_dir, sampler_type)

    # Verify completeness
    completeness = file_manager.verify_session_completeness(session_dir, sampler_type)
    if not all(completeness.values()):
        missing_files = [k for k, v in completeness.items() if not v]
        raise FileNotFoundError(f"Preprocess session incomplete, missing: {missing_files}")

    # Load preprocess data
    all_precomputed_indices = file_manager.load_indices(session_dir)
    if all_precomputed_indices is None:
        raise FileNotFoundError(f"Failed to load indices file: {file_paths['indices']}")
    with open(file_paths['timing'], 'r', encoding='utf-8') as f:
        all_precomputed_times = json.load(f)

    print(f"Loaded preprocess results: {len(all_precomputed_indices)} tasks")

    # Fail fast, before the expensive model load, if the session cannot cover
    # every task; otherwise the per-task lookup below would raise mid-run.
    if len(all_precomputed_indices) < num_tasks or len(all_precomputed_times) < num_tasks:
        raise ValueError(
            f"Preprocess session covers {len(all_precomputed_indices)} index entries and "
            f"{len(all_precomputed_times)} timing entries, but the dataset has {num_tasks} tasks: "
            f"{session_dir}"
        )

    # Load model
    model, processor = initialize_model(model_config['model_path'], backend=model_config.get("vlm_backend"))
    # Pass preprocess session info to agent (on a copy, so the caller's dict
    # is left untouched).
    agent_config = dict(agent_config or {})
    agent_config.update({
        "preprocess_session": session_dir,
        "sampler_type": sampler_type,
        "num_frames": num_frames,
        "frame_allocation_mode": frame_allocation_mode,
        "match_fps": match_fps,
        "planner_api_model": planner_api_model,
    })
    agent = get_agent(agent_type, model=model, processor=processor, **agent_config)

    # Main loop
    results = []
    total_sampling_time = sum(all_precomputed_times)
    total_agent_time = 0.0

    for i, task in enumerate(tqdm(dataset, desc="Inferring...")):
        sampling_time = all_precomputed_times[i]
        frame_indices = all_precomputed_indices[i]
        video_path = os.path.join(video_base_path, task["video_path"])

        try:
            # Run agent
            agent_result = agent.infer(
                task=task,
                video_path=video_path,
                frame_indices=frame_indices,
                generation_kwargs=generation_kwargs,
                task_index=i
            )

            # Judge correctness
            final_answer = agent_result["final_answer"]
            is_correct, predicted_letter = judge_correct(final_answer, task)

            # Timing
            agent_inference_time = agent_result["times"]["agent_inference_time"]
            total_agent_time += agent_inference_time

            # Build result record
            row = {
                'task_id': task.get('id', i),
                'prediction': final_answer,
                'prediction_letter': predicted_letter,
                'ground_truth': task.get('correct_choice'),
                'is_correct': is_correct,
                'iterations': agent_result.get("iterations", 1),
                'times': {
                    'sampling_time': sampling_time,
                    'agent_inference_time': agent_inference_time,
                    'end_to_end_time': sampling_time + agent_inference_time
                },
                # Agent fields
                'tokens': agent_result.get("tokens", {}),
                'debug': agent_result.get("debug", {})
            }

        except Exception as e:
            # Deliberate per-task boundary: one failing task is recorded as an
            # error row and must not abort the whole run.
            print(f"Error on task {i}: {e}")
            row = {
                'task_id': task.get('id', i),
                'error': str(e)[:1000],
                'times': {
                    'sampling_time': sampling_time,
                    'agent_inference_time': 0.0,
                    'end_to_end_time': sampling_time
                }
            }

        results.append(row)

        # Save incrementally so partial results survive an interrupted run.
        _write_json(result_file, results)

    if not results:
        # Empty dataset: still produce the file whose path is reported below.
        _write_json(result_file, results)

    # Summary
    accuracy_stats = calculate_accuracy(results)
    total_e2e_time = total_sampling_time + total_agent_time
    avg_sampling_time = total_sampling_time / num_tasks if num_tasks else 0
    avg_agent_time = total_agent_time / num_tasks if num_tasks else 0
    avg_e2e_time = total_e2e_time / num_tasks if num_tasks else 0

    # Base summary
    summary = {
        'experiment_config': {
            'sampler': sampler_type,
            'agent': agent_type,
            'num_frames': num_frames,
            'frame_allocation_mode': frame_allocation_mode,
            'match_fps': match_fps,
            'planner_api_model': planner_api_model,
            'dataset': dataset_name,
            'generation_kwargs': generation_kwargs,
            'agent_config': agent_config
        },
        'performance': accuracy_stats,
        'timing_summary_seconds': {
            'total_sampling_time': total_sampling_time,
            'total_agent_inference_time': total_agent_time,
            'total_end_to_end_time': total_e2e_time,
            'average_sampling_time': avg_sampling_time,
            'average_agent_inference_time': avg_agent_time,
            'average_end_to_end_time': avg_e2e_time
        }
    }

    # Save summary report
    _write_json(summary_file, summary)

    print("\n======== Inference completed ========")
    print(f"Accuracy: {accuracy_stats['accuracy_percent']:.2f}% ({accuracy_stats['correct_count']}/{accuracy_stats['total_tasks']})")
    print(f"Avg sampling time: {avg_sampling_time:.4f}s")
    print(f"Avg inference time: {avg_agent_time:.4f}s")
    print(f"Avg end-to-end time: {avg_e2e_time:.4f}s")
    print(f"Detailed results: {result_file}")
    print(f"Summary report: {summary_file}")

    return summary
