#!/usr/bin/env python3
"""
Positional Bias Experiment Runner

This script runs the positional bias experiment to analyze how adding spaces
to prompts affects text embeddings and model performance.
"""

import os
import sys
import argparse
import torch
import numpy as np
from typing import List, Dict
import json
from datetime import datetime

# Add the project root to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer
from detectron2.utils.logger import setup_logger

# Import the analyzer
from utils.positional_bias_analyzer import PositionalBiasAnalyzer, compare_experiments, plot_experiment_comparison

def setup_experiment_config(base_config_path: str, space_prefix_length: int, 
                           random_space_length: bool = False, 
                           min_space_length: int = 1, 
                           max_space_length: int = 20) -> str:
    """
    Create a temporary config file for the experiment.

    The positional bias settings are written under the top-level ``MODEL``
    mapping of the generated YAML. If the base config already contains a
    block-style top-level ``MODEL:`` line, the settings are inserted directly
    beneath it, using the indentation of its existing children. Otherwise a
    new ``MODEL:`` block is written ahead of the unchanged base content.
    Emitting a second top-level ``MODEL`` key is avoided because YAML mapping
    keys must be unique; with a duplicate, one of the two sections would be
    discarded or the file rejected, depending on the loader.

    The file is written to the current working directory under a name derived
    from the experiment parameters; the caller is responsible for deleting it.

    NOTE(review): because the file lives in the working directory rather than
    next to the base config, any relative path inside the base config (for
    example a ``_BASE_`` entry) may resolve differently - confirm against the
    config loader in use.

    Args:
        base_config_path: Path to the base configuration file
        space_prefix_length: Number of spaces to add as prefix (for fixed mode)
        random_space_length: Whether to use random space lengths
        min_space_length: Minimum number of spaces (for random mode)
        max_space_length: Maximum number of spaces (for random mode)

    Returns:
        Path to the temporary config file
    """
    # Read base config
    with open(base_config_path, 'r', encoding='utf-8') as f:
        config_content = f.read()

    # Settings to place under MODEL, without indentation (added below).
    settings = [
        "POSITIONAL_BIAS_EXPERIMENT: True",
        f"SPACE_PREFIX_LENGTH: {space_prefix_length}",
        f"RANDOM_SPACE_LENGTH: {str(random_space_length).lower()}",
        f"MIN_SPACE_LENGTH: {min_space_length}",
        f"MAX_SPACE_LENGTH: {max_space_length}",
    ]

    # Look for an existing top-level, block-style "MODEL:" line (column 0,
    # nothing after the colon except an optional comment).
    base_lines = config_content.splitlines(keepends=True)
    model_idx = next(
        (idx for idx, line in enumerate(base_lines)
         if line.split('#', 1)[0].rstrip() == 'MODEL:'),
        None,
    )

    if model_idx is None:
        # No MODEL section to merge into: prepend a fresh one.
        settings_block = "\n".join(f"  {entry}" for entry in settings)
        experiment_config = (
            "\n# Positional Bias Experiment Settings\n"
            "MODEL:\n"
            f"{settings_block}\n"
            "\n"
            f"{config_content}\n"
        )
    else:
        # Reuse the indentation of MODEL's first real child so the inserted
        # keys belong to the same mapping; default to two spaces when MODEL
        # has no indented children.
        indent = "  "
        for line in base_lines[model_idx + 1:]:
            stripped = line.strip()
            if not stripped or stripped.startswith('#'):
                continue
            leading = len(line) - len(line.lstrip(' '))
            if leading:
                indent = ' ' * leading
            break

        header = base_lines[model_idx]
        if not header.endswith('\n'):
            # MODEL: was the last line of a file without a trailing newline.
            header += '\n'

        injected = [f"{indent}# Positional Bias Experiment Settings\n"]
        injected.extend(f"{indent}{entry}\n" for entry in settings)

        experiment_config = "".join(
            base_lines[:model_idx] + [header] + injected + base_lines[model_idx + 1:]
        )

    # Create temporary config file
    if random_space_length:
        temp_config_path = f"temp_config_random_{min_space_length}_{max_space_length}.yaml"
    else:
        temp_config_path = f"temp_config_space_{space_prefix_length}.yaml"

    with open(temp_config_path, 'w', encoding='utf-8') as f:
        f.write(experiment_config)

    return temp_config_path

def run_single_experiment(config_path: str, space_prefix_length: int, 
                         output_dir: str, random_space_length: bool = False,
                         min_space_length: int = 1, max_space_length: int = 20) -> Dict:
    """
    Run a single experiment with a specific space prefix length.

    Loads the configuration, builds a GLEE model, embeds a small fixed set of
    prompts with the positional bias experiment enabled, and hands the
    embeddings to ``PositionalBiasAnalyzer``. Plots, a text report and a
    ``results.json`` file are written to ``<output_dir>/<experiment name>``.

    Configuration loading and logger setup happen before the guarded section,
    so errors raised there propagate to the caller. Any exception raised
    afterwards is printed, logged with its traceback, and turned into a
    ``None`` return value.

    Args:
        config_path: Path to the configuration file
        space_prefix_length: Number of spaces to add as prefix (for fixed mode)
        output_dir: Directory to save results
        random_space_length: Whether to use random space lengths
        min_space_length: Minimum number of spaces (for random mode)
        max_space_length: Maximum number of spaces (for random mode)

    Returns:
        Experiment results dictionary, or None if the experiment failed
    """
    if random_space_length:
        experiment_name = f"random_{min_space_length}_{max_space_length}"
        banner = f"Running experiment with random space lengths: [{min_space_length}, {max_space_length}]"
        # The report is labelled with the experiment name in random mode...
        report_label = experiment_name
    else:
        experiment_name = f"space_length_{space_prefix_length}"
        banner = f"Running experiment with space prefix length: {space_prefix_length}"
        # ...and with the plain integer prefix length in fixed mode.
        report_label = space_prefix_length

    print(f"\n{'='*50}")
    print(banner)
    print(f"{'='*50}")

    # Setup configuration
    cfg = get_cfg()
    cfg.merge_from_file(config_path)

    # Create output directory for this experiment
    experiment_dir = os.path.join(output_dir, experiment_name)
    os.makedirs(experiment_dir, exist_ok=True)

    # Setup logger
    logger = setup_logger(output_dir=experiment_dir)

    try:
        # Project imports are done here, inside the guarded section, so an
        # import failure is reported like any other experiment failure.
        from models.glee_model import GLEE_Model
        from models.matcher import HungarianMatcher

        # Create dummy matcher and video info for initialization
        matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
        video_info = {'len': 1, 'bz': 1}
        contras_mean = torch.zeros(1)

        # Prefer the GPU but fall back to the CPU so the experiment can still
        # be attempted on hosts without CUDA.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Initialize model
        model = GLEE_Model(cfg, matcher, device=device, video_info=video_info, contras_mean=contras_mean)
        model.eval()

        # Create dummy prompts for testing
        batch_size = 2
        grounding_prompts = ["person", "car"]
        reference_prompts = [
            "positive person", "negative person", "positive car", "negative car",
            "positive person", "negative person", "positive car", "negative car"
        ]  # 4 reference prompts per batch element

        all_prompts = grounding_prompts + reference_prompts

        # Describe the reference layout from the prompts actually embedded.
        # NOTE(review): this value used to be hard-coded to 6 although only
        # 8 reference prompts (4 per batch element) are supplied; confirm
        # whether the model or analyzer requires exactly 6 per element, in
        # which case the prompt list should be extended instead.
        references_per_batch = len(reference_prompts) // batch_size

        # Get embeddings with positional bias experiment enabled
        with torch.no_grad():
            token_x, eot, attn_mask = model.get_text_embedding(
                all_prompts, 
                positional_bias_experiment=True, 
                space_prefix_length=space_prefix_length,
                random_space_length=random_space_length,
                min_space_length=min_space_length,
                max_space_length=max_space_length
            )

        # Analyze results
        analyzer = PositionalBiasAnalyzer()

        # Split embeddings for analysis
        original_embeddings = token_x[:batch_size]  # grounding prompts
        modified_embeddings = token_x[batch_size:]  # reference prompts with spaces

        # Analyze similarity
        similarity_results = analyzer.analyze_embedding_similarity(
            original_embeddings, 
            modified_embeddings,
            grounding_prompts,
            reference_prompts
        )

        # Analyze positional impact
        positional_results = analyzer.analyze_positional_shift_impact(
            token_x, 
            batch_size=batch_size,
            num_references_per_batch=references_per_batch
        )

        # Generate plots
        analyzer.plot_similarity_analysis(
            similarity_results, 
            save_path=os.path.join(experiment_dir, 'similarity_analysis.png')
        )

        analyzer.plot_positional_analysis(
            positional_results, 
            save_path=os.path.join(experiment_dir, 'positional_analysis.png')
        )

        # Generate report
        report = analyzer.generate_experiment_report(
            similarity_results, 
            positional_results, 
            report_label
        )

        # Save report
        with open(os.path.join(experiment_dir, 'experiment_report.txt'), 'w') as f:
            f.write(report)

        print(report)

        # Save results as JSON
        results = {
            'space_prefix_length': space_prefix_length,
            'random_space_length': random_space_length,
            'min_space_length': min_space_length,
            'max_space_length': max_space_length,
            'similarity_results': similarity_results,
            'positional_results': positional_results,
            'timestamp': datetime.now().isoformat()
        }

        # default=str stringifies any value json cannot encode natively.
        with open(os.path.join(experiment_dir, 'results.json'), 'w') as f:
            json.dump(results, f, indent=2, default=str)

        return results

    except Exception as e:
        if random_space_length:
            print(f"Error in experiment with random space lengths [{min_space_length}, {max_space_length}]: {e}")
        else:
            print(f"Error in experiment with space length {space_prefix_length}: {e}")
        # Keep the full traceback in the experiment log; the printed summary
        # above only carries the exception message.
        logger.exception("Experiment %s failed", experiment_name)
        return None

def run_comparison_experiments(base_config_path: str, space_lengths: List[int], 
                             output_dir: str, random_experiments: List[Dict] = None):
    """
    Run a series of fixed-length and random-length experiments and compare them.

    Every experiment gets its own temporary config file, which is removed
    again once the experiment has finished (successfully or not). When more
    than one experiment succeeds, the similarity results are compared, a
    comparison plot is saved, and ``comparison_results.json`` plus
    ``experiment_summary.json`` are written to ``output_dir``.

    Args:
        base_config_path: Path to the base configuration file
        space_lengths: List of space prefix lengths to test (fixed mode)
        output_dir: Directory to save results
        random_experiments: List of dicts with random experiment configs
                           [{'min': 1, 'max': 10}, {'min': 5, 'max': 20}, ...]
    """
    print(f"Running positional bias experiments with space lengths: {space_lengths}")
    if random_experiments:
        print(f"Random experiments: {random_experiments}")

    # Make sure the top-level results directory is present
    os.makedirs(output_dir, exist_ok=True)

    successful_results = []   # full result dicts of experiments that succeeded
    experiment_results = []   # their 'similarity_results', in the same order

    def _execute(prefix_length, **mode):
        """Run one experiment through a throwaway config and record its result."""
        cfg_file = setup_experiment_config(base_config_path, prefix_length, **mode)
        try:
            outcome = run_single_experiment(cfg_file, prefix_length, output_dir, **mode)
            if outcome is not None:
                successful_results.append(outcome)
                experiment_results.append(outcome['similarity_results'])
        finally:
            # The temporary config is removed whether or not the run succeeded
            if os.path.exists(cfg_file):
                os.remove(cfg_file)

    # Fixed-length experiments first
    for fixed_length in space_lengths:
        _execute(fixed_length, random_space_length=False)

    # Then the random-length experiments; the prefix length of 5 is only a
    # placeholder value passed along in random mode
    for spec in (random_experiments or []):
        low = spec['min']
        high = spec['max']
        _execute(5, random_space_length=True,
                 min_space_length=low, max_space_length=high)

    # A comparison needs at least two successful experiments
    if len(successful_results) <= 1:
        return

    rule = '=' * 50
    print(f"\n{rule}")
    print("COMPARING EXPERIMENTS")
    print(f"{rule}")

    # One label per successful experiment, aligned with experiment_results
    experiment_labels = [
        f"random_{entry['min_space_length']}_{entry['max_space_length']}"
        if entry['random_space_length']
        else f"fixed_{entry['space_prefix_length']}"
        for entry in successful_results
    ]

    comparison_results = compare_experiments(experiment_results, experiment_labels)

    # Comparison plot
    comparison_plot = os.path.join(output_dir, 'experiment_comparison.png')
    plot_experiment_comparison(comparison_results, save_path=comparison_plot)

    # Raw comparison data
    comparison_json = os.path.join(output_dir, 'comparison_results.json')
    with open(comparison_json, 'w') as out:
        json.dump(comparison_results, out, indent=2, default=str)

    # Console summary
    print("\nEXPERIMENT SUMMARY:")
    print(f"Experiments Tested: {comparison_results['space_lengths']}")
    cosine_text = [f'{s:.4f}' for s in comparison_results['mean_cosine_similarities']]
    print(f"Mean Cosine Similarities: {cosine_text}")
    l2_text = [f'{d:.4f}' for d in comparison_results['mean_l2_distances']]
    print(f"Mean L2 Distances: {l2_text}")

    # The configuration with the highest mean cosine similarity is reported as best
    best_idx = np.argmax(comparison_results['mean_cosine_similarities'])
    best_config = comparison_results['space_lengths'][best_idx]
    best_similarity = comparison_results['mean_cosine_similarities'][best_idx]

    print("\nBest Configuration:")
    print(f"Config: {best_config}")
    print(f"Cosine Similarity: {best_similarity:.4f}")

    # Persist the summary
    summary = {
        'experiment_configs': comparison_results['space_lengths'],
        'mean_cosine_similarities': comparison_results['mean_cosine_similarities'],
        'mean_l2_distances': comparison_results['mean_l2_distances'],
        'best_config': best_config,
        'best_cosine_similarity': best_similarity,
        'timestamp': datetime.now().isoformat()
    }

    summary_json = os.path.join(output_dir, 'experiment_summary.json')
    with open(summary_json, 'w') as out:
        json.dump(summary, out, indent=2, default=str)

def main():
    """
    Command-line entry point.

    With ``--single-experiment N`` exactly one fixed-length experiment is run.
    Otherwise the fixed lengths from ``--space-lengths`` (skipped when
    ``--random-only`` is given) and the ranges from ``--random-experiments``
    are run and compared. Malformed ``min:max`` range strings are reported
    and skipped.
    """
    cli = argparse.ArgumentParser(description="Run positional bias experiments")
    cli.add_argument("--config", type=str, required=True,
                     help="Path to base configuration file")
    cli.add_argument("--output-dir", type=str, default="./positional_bias_results",
                     help="Output directory for results")
    cli.add_argument("--space-lengths", type=int, nargs="+",
                     default=[1, 5, 10, 15, 20],
                     help="Space prefix lengths to test (fixed mode)")
    cli.add_argument("--single-experiment", type=int, default=None,
                     help="Run single experiment with specified space length")
    cli.add_argument("--random-experiments", type=str, nargs="+",
                     help="Random experiments in format 'min:max' (e.g., '1:10' '5:20')")
    cli.add_argument("--random-only", action="store_true",
                     help="Run only random experiments, skip fixed length experiments")

    args = cli.parse_args()
    single_length = args.single_experiment

    if single_length is None:
        # Comparison mode: turn each 'min:max' string into a {'min', 'max'} dict
        parsed_ranges = None
        if args.random_experiments:
            parsed_ranges = []
            for spec in args.random_experiments:
                try:
                    low, high = map(int, spec.split(':'))
                except ValueError:
                    print(f"Invalid random experiment format: {spec}. Use 'min:max' format.")
                    continue
                parsed_ranges.append({'min': low, 'max': high})

        # --random-only drops the fixed-length experiments
        fixed_lengths = [] if args.random_only else args.space_lengths

        run_comparison_experiments(args.config, fixed_lengths, args.output_dir, parsed_ranges)
    else:
        # Single mode: one fixed-length run through a throwaway config file
        throwaway_config = setup_experiment_config(args.config, single_length)
        try:
            run_single_experiment(throwaway_config, single_length, args.output_dir)
        finally:
            if os.path.exists(throwaway_config):
                os.remove(throwaway_config)

    print(f"\nExperiments completed. Results saved to: {args.output_dir}")

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 