#!/usr/bin/env python3
"""
Poster QA Experiments Entry Point
Batch processing support aligned with banner experiments
"""

import os
import sys
import argparse
import json
import shutil
from pathlib import Path
from datetime import datetime

# Make the shared `src` directory importable. It is resolved relative to this
# file: the `src` folder under the directory three `.parent` hops up.
# (presumably where `html_ad_workflow` lives; verify against the repo layout)
_original_src_path = Path(__file__).parent.parent.parent / "src"
if str(_original_src_path) not in sys.path:
    sys.path.insert(0, str(_original_src_path))

# Make this file's own directory importable as well. Both entries are inserted
# at index 0, so when both are newly added this directory ends up ahead of
# `src` in the module search order.
_current_dir = Path(__file__).parent
if str(_current_dir) not in sys.path:
    sys.path.insert(0, str(_current_dir))

from html_ad_workflow import load_config
from poster_config import PosterInteractionConfig, PosterGenerationResult, get_config_by_id, generate_all_configs
from poster_data_loader import load_poster_samples, read_prompt
from poster_qa_manager import PosterQAManager
from poster_judge_evaluator import PosterJudgeEvaluator


def process_single_sample(
    sample,
    config_path: str,
    llm_config_path: str,
    interaction_config: PosterInteractionConfig,
    output_base_dir: str
) -> dict:
    """
    Process a single poster sample

    Generates a poster for the sample through the QA manager, scores the
    generated image against the answer image with the judge evaluator, and
    writes the Q&A conversation log into a per-sample output directory.
    Exceptions are captured in the returned dictionary instead of being raised.

    Args:
        sample: PosterSample object; this function reads its ``index``,
            ``brand_name``, ``prompt_file``, ``logo_file`` and
            ``answer_image_path`` attributes
        config_path: Config file path
        llm_config_path: LLM config file path
        interaction_config: Interaction configuration; its ``id``,
            ``agent_format``, ``question_format``, ``max_qa_cycles`` and
            ``max_questions_per_batch`` are recorded in the Q&A log
        output_base_dir: Output base directory; files for this sample go to
            the sub-directory ``{index}_{brand_name}``

    Returns:
        Result dictionary. ``success`` becomes True once poster generation
        succeeded; ``error`` holds the failure message otherwise. ``error``
        can also be set while ``success`` is True when a later step (for
        example saving the Q&A log) raises.
    """
    import traceback

    print("\n" + "=" * 60)
    print(f"🎨 Processing {sample.index}: {sample.brand_name}")
    print(f"📋 Config: {interaction_config.id}")
    print("=" * 60)

    result = {
        "index": sample.index,
        "brand_name": sample.brand_name,
        "prompt_file": str(sample.prompt_file),
        "logo_file": str(sample.logo_file) if sample.logo_file else None,
        "success": False,
        "generated_image_path": None,
        "ad_text": None,
        "conversation_history": None,
        "design_plan": None,
        "token_stats": None,
        "answer_image_path": str(sample.answer_image_path),
        "judge_evaluation": None,
        "error": None,
        "questions_asked": None,
    }

    try:
        # Read prompt
        prompt_text = read_prompt(sample.prompt_file)
        if not prompt_text:
            result["error"] = "Failed to read prompt"
            return result

        print(f"📝 Prompt: {prompt_text[:100]}...")

        # Create output directory for this sample
        brand_output_dir = os.path.join(output_base_dir, f"{sample.index}_{sample.brand_name}")
        os.makedirs(brand_output_dir, exist_ok=True)

        # Copy logo to output directory if exists (best effort, logo is optional)
        logo_dest = _copy_logo(sample.logo_file, brand_output_dir)

        # Copy ground truth image to output directory
        if sample.answer_image_path and os.path.exists(sample.answer_image_path):
            gt_dest = os.path.join(brand_output_dir, "ground_truth.png")
            shutil.copy2(sample.answer_image_path, gt_dest)
            print(f"📷 Ground truth image copied to: {gt_dest}")
            result["ground_truth_image"] = gt_dest

        # Load config and point its output folder at this sample's directory
        config, config_error = _load_sample_config(config_path, llm_config_path, brand_output_dir)
        if config is None:
            result["error"] = config_error
            return result

        # Initialize QA Manager
        qa_manager = PosterQAManager(config, interaction_config)

        # Generate poster
        generation_result = qa_manager.generate_poster(
            item_description=prompt_text,
            answer_image_path=str(sample.answer_image_path),
            logo_path=logo_dest,
            output_dir=brand_output_dir
        )

        if not generation_result.success:
            result["error"] = generation_result.error
            return result

        # Update result
        generated_path = generation_result.generated_image_path
        result["success"] = True
        result["generated_image_path"] = str(generated_path) if generated_path else None
        result["ad_text"] = generation_result.ad_text
        result["conversation_history"] = generation_result.conversation_history
        result["design_plan"] = generation_result.design_plan
        result["token_stats"] = generation_result.token_stats
        result["questions_asked"] = generation_result.questions_asked

        # Run evaluation (only when the generated image actually exists on disk)
        if result["generated_image_path"] and os.path.exists(result["generated_image_path"]):
            _run_judge_evaluation(config, result, prompt_text, brand_output_dir)

        # Save Q&A conversation. The payload is built before the file is opened
        # so that a failure while building it does not leave an empty file behind.
        qa_file = os.path.join(brand_output_dir, "qa_conversation.json")
        agent_format = interaction_config.agent_format
        question_format = interaction_config.question_format
        qa_payload = {
            "index": sample.index,
            "brand_name": sample.brand_name,
            "prompt": prompt_text,
            "answer_image": result["answer_image_path"],
            "config_id": interaction_config.id,
            "agent_format": agent_format.value if agent_format else None,
            "question_format": question_format.value if question_format else None,
            "max_qa_cycles": interaction_config.max_qa_cycles,
            "max_questions_per_batch": interaction_config.max_questions_per_batch,
            "questions_asked": result["questions_asked"],
            "conversation_history": [
                _normalize_qa_entry(qa) for qa in (result["conversation_history"] or [])
            ],
            "design_plan": result["design_plan"],
            "token_stats": result["token_stats"]
        }
        with open(qa_file, 'w', encoding='utf-8') as f:
            json.dump(qa_payload, f, indent=2, ensure_ascii=False, default=str)

        print(f"✅ Successfully generated poster for {sample.brand_name}")
        print(f"   Generated image: {result['generated_image_path']}")
        if result.get("judge_evaluation") and "average_score" in result["judge_evaluation"]:
            avg_score = result['judge_evaluation']['average_score']
            print(f"   Judge Evaluation Average: {avg_score:.4f}")
        print(f"   Q&A saved: {qa_file}")

    except Exception as e:
        # Batch boundary: one failing sample must not abort the whole run.
        result["error"] = str(e)
        print(f"❌ Error processing {sample.brand_name}: {e}")
        traceback.print_exc()

    return result


def _copy_logo(logo_file, dest_dir: str):
    """Copy the optional logo into ``dest_dir``; return the new path or None."""
    if logo_file is None or not (logo_file.exists() and logo_file.is_file()):
        print("⚠️ No logo file for this sample (logo is optional)")
        return None

    try:
        logo_dest = os.path.join(dest_dir, logo_file.name)
        shutil.copy2(str(logo_file), logo_dest)
        print(f"📋 Logo copied to: {logo_dest}")
        return logo_dest
    except Exception as e:
        # The logo is optional, so a copy failure only downgrades to "no logo".
        print(f"⚠️ Error copying logo file: {e}")
        return None


def _load_sample_config(config_path: str, llm_config_path: str, output_folder: str):
    """Load and validate the config; return ``(config, None)`` or ``(None, error)``."""
    config_path_str = str(Path(config_path).resolve())
    llm_config_path_str = str(Path(llm_config_path).resolve())

    # Verify config files exist
    if not os.path.exists(config_path_str):
        return None, f"Config file not found: {config_path_str}"
    if not os.path.exists(llm_config_path_str):
        return None, f"LLM config file not found: {llm_config_path_str}"

    config = load_config(config_path_str, llm_config_path_str)

    # Verify KEYS section exists after loading
    if not config.has_section("KEYS"):
        return None, f"KEYS section not found in config. Available sections: {config.sections()}"

    # Ensure required sections exist
    if not config.has_section("SETTING"):
        config.add_section("SETTING")
    config.set("SETTING", "output_folder", output_folder)
    return config, None


def _run_judge_evaluation(config, result: dict, prompt_text: str, output_dir: str) -> None:
    """Score the generated poster, save the verdict as JSON and store it in ``result``.

    Failures are reported on stdout and otherwise ignored, leaving the rest of
    ``result`` untouched.
    """
    import traceback

    score_labels = (
        ("Overall Color", "Overall_Color"),
        ("Layout & Composition", "Layout_Composition"),
        ("Button Style", "Button_Style"),
        ("Image Content", "Image_Content"),
        ("Text Content", "Text_Content"),
    )

    try:
        print("\n📊 Computing Judge evaluation (GPT-4o)...")

        # The judge uses a fixed model rather than the experiment model.
        # NOTE(review): the log text says GPT-4o, but the evaluator is created
        # with model_version="gpt41" - confirm which model is intended.
        judge_evaluator = PosterJudgeEvaluator(config, model_version="gpt41")

        # Evaluate poster
        judge_result = judge_evaluator.evaluate_similarity(
            reference_image_path=result["answer_image_path"],
            generated_image_path=result["generated_image_path"],
            item_description=prompt_text
        )

        # Save evaluation result
        eval_output_path = os.path.join(output_dir, "judge_evaluation.json")
        with open(eval_output_path, 'w', encoding='utf-8') as f:
            json.dump(judge_result, f, indent=2, ensure_ascii=False)

        # Store in result (format aligned with banner) before printing, so a
        # formatting problem below cannot discard the evaluation.
        result["judge_evaluation"] = judge_result

        avg_score = judge_result.get("average_score", 0.0)
        print("✅ Judge Evaluation (Similarity):")
        print(f"   Average Score: {avg_score:.4f} ({avg_score*100:.2f}%)")
        if "scores" in judge_result:
            scores = judge_result["scores"]
            for label, key in score_labels:
                print(f"   {label}: {scores.get(key, 0):.1f}")

    except Exception as e:
        print(f"⚠️ Error computing Judge evaluation: {e}")
        traceback.print_exc()


def _normalize_qa_entry(qa) -> dict:
    """Convert one conversation entry (dict or question/answer sequence) to a plain dict."""
    if isinstance(qa, dict):
        return {
            "question": qa.get("question", ""),
            "answer": qa.get("answer", ""),
            "questioner_tokens": qa.get("questioner_tokens", {}),
            "answerer_tokens": qa.get("answerer_tokens", {}),
            "format": qa.get("format", "open_text"),
        }

    # Non-dict entries are treated as (question, answer) sequences; anything
    # else yields empty strings.
    is_sequence = isinstance(qa, (list, tuple))
    return {
        "question": qa[0] if is_sequence and len(qa) > 0 else "",
        "answer": qa[1] if is_sequence and len(qa) > 1 else "",
        "questioner_tokens": {},
        "answerer_tokens": {},
        "format": "open_text",
    }


def run_experiment(
    logos_dir: str,
    prompt_dir: str,
    answer_base_dir: str,
    config_path: str,
    llm_config_path: str,
    mode: str,
    num_samples: int = 10,
    start_index: int = 1,
    max_qa_cycles: int = 2,
    max_questions_per_batch: int = 5,
    model_version: str = "gpt41",
    question_agent_model_version: str = None,
    answer_agent_model_version: str = None,
    output_dir: str = None,
    mpc_enabled: bool = False,
    use_logos: bool = True
) -> dict:
    """
    Run poster experiment (batch processing)

    Processes every loaded sample with ``process_single_sample``, rewrites
    ``batch_results.json`` after each sample so partial progress survives an
    interruption, and finally writes the full summary to the same file.

    Args:
        logos_dir: Logos directory path
        prompt_dir: Prompt files directory path
        answer_base_dir: Answer images base directory
        config_path: Config file path
        llm_config_path: LLM config file path
        mode: Config ID (e.g., "MPQC_Adaptive")
        num_samples: Number of samples
        start_index: Start index
        max_qa_cycles: Maximum QA cycles (None keeps the config's own value)
        max_questions_per_batch: Maximum questions per batch (None keeps the
            config's own value)
        model_version: Global model version for plan operations
        question_agent_model_version: Model version for question agent (fallback to model_version)
        answer_agent_model_version: Model version for answer agent (fallback to model_version)
        output_dir: Output directory (default: results/{mode}_{timestamp})
        mpc_enabled: Accepted for interface compatibility; not used by this
            function
        use_logos: Use logo-based setup (prompts in logos_dir) or no-logo setup

    Returns:
        ``{"summary": ..., "results": [...]}`` on completion, or
        ``{"success": False, "error": ...}`` when the run cannot start
        (unknown config, no samples, or missing config files)
    """
    # Get config by ID
    interaction_config = get_config_by_id(mode)
    if interaction_config is None:
        # NOTE(review): guards against a lookup that yields None for an unknown
        # ID; confirm whether get_config_by_id raises instead.
        print(f"❌ Unknown config ID: {mode}")
        return {"success": False, "error": f"Unknown config ID: {mode}"}

    # Override parameters if provided
    if max_qa_cycles is not None:
        interaction_config.max_qa_cycles = max_qa_cycles
    if max_questions_per_batch is not None:
        interaction_config.max_questions_per_batch = max_questions_per_batch
    if model_version is not None:
        interaction_config.model_version = model_version
    if question_agent_model_version is not None:
        interaction_config.question_agent_model_version = question_agent_model_version
    if answer_agent_model_version is not None:
        interaction_config.answer_agent_model_version = answer_agent_model_version

    # Create output directory
    if output_dir is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = f"results/{mode}_{timestamp}"

    output_base_dir = Path(output_dir)
    output_base_dir.mkdir(parents=True, exist_ok=True)

    print("\n" + "=" * 80)
    print("🧪 POSTER QA EXPERIMENT")
    print(f"   Config ID: {mode}")
    if interaction_config.agent_format:
        print(f"   Agent Format: {interaction_config.agent_format.value}")
    if interaction_config.question_format:
        print(f"   Question Format: {interaction_config.question_format.value}")
    print(f"   Samples: {num_samples} (starting from {start_index})")
    print(f"   Max QA Cycles: {interaction_config.max_qa_cycles}")
    print(f"   Max Questions per Batch: {interaction_config.max_questions_per_batch}")
    print(f"   Output: {output_base_dir}")
    print("=" * 80)

    # Load samples
    samples = load_poster_samples(logos_dir, prompt_dir, answer_base_dir, num_samples, start_index, use_logos=use_logos)

    if not samples:
        print("❌ No samples loaded!")
        return {"success": False, "error": "No samples loaded"}

    print(f"\n✅ Loaded {len(samples)} samples")

    # Resolve config paths (make them absolute)
    config_path_abs = str(Path(config_path).resolve())
    llm_config_path_abs = str(Path(llm_config_path).resolve())

    # Verify config files exist
    if not os.path.exists(config_path_abs):
        print(f"❌ Config file not found: {config_path_abs}")
        return {"success": False, "error": f"Config file not found: {config_path_abs}"}
    if not os.path.exists(llm_config_path_abs):
        print(f"❌ LLM config file not found: {llm_config_path_abs}")
        return {"success": False, "error": f"LLM config file not found: {llm_config_path_abs}"}

    print(f"📋 Using config: {config_path_abs}")
    print(f"📋 Using LLM config: {llm_config_path_abs}")

    # Process each sample
    results = []
    results_file = output_base_dir / "batch_results.json"
    start_time = datetime.now()

    for i, sample in enumerate(samples, 1):
        print(f"\n[{i}/{len(samples)}] Processing {sample.index}_{sample.brand_name}...")

        result = process_single_sample(
            sample,
            config_path_abs,
            llm_config_path_abs,
            interaction_config,
            str(output_base_dir)
        )

        results.append(result)

        # Save intermediate results so partial progress is kept on disk
        _write_batch_results(
            results_file,
            {
                "mode": mode,
                "total": len(results),
                "successful": sum(1 for r in results if r["success"]),
                "failed": sum(1 for r in results if not r["success"]),
            },
            results,
        )

    end_time = datetime.now()
    total_time = (end_time - start_time).total_seconds()

    # Calculate statistics
    successful = [r for r in results if r["success"]]
    failed = [r for r in results if not r["success"]]

    total_token_stats = _aggregate_token_stats(successful)

    judge_evaluations = [r.get("judge_evaluation") for r in successful if r.get("judge_evaluation") is not None]
    judge_stats = _summarize_judge_evaluations(judge_evaluations)

    # Create summary. The QA limits and model are taken from the effective
    # configuration so the summary stays accurate when the caller passed None
    # to keep the config's own values.
    summary = {
        "mode": mode,
        "total": len(results),
        "successful": len(successful),
        "failed": len(failed),
        "total_time_seconds": total_time,
        "start_time": start_time.isoformat(),
        "end_time": end_time.isoformat(),
        "max_qa_cycles": interaction_config.max_qa_cycles,
        "max_questions_per_batch": interaction_config.max_questions_per_batch,
        "model_version": getattr(interaction_config, "model_version", model_version),
        "token_stats": total_token_stats,
        "judge_evaluation_stats": judge_stats
    }

    # Save final results
    _write_batch_results(results_file, summary, results)

    # Print summary
    print("\n" + "=" * 80)
    print("📊 EXPERIMENT SUMMARY")
    print("=" * 80)
    print(f"✅ Successful: {len(successful)}/{len(results)}")
    print(f"❌ Failed: {len(failed)}/{len(results)}")
    print(f"⏱️  Total time: {total_time:.2f} seconds")
    print(f"📁 Output directory: {output_base_dir}")

    if judge_stats and judge_stats.get("count", 0) > 0:
        print("\n📊 Judge Evaluation (GPT-4o Similarity):")
        print(f"   Overall Average: {judge_stats['overall_average']*100:.2f}%")
        print(f"   Count: {judge_stats['count']}")
        if "average_scores" in judge_stats:
            print("   Dimension Averages:")
            for key, stats in judge_stats["average_scores"].items():
                print(f"     {key}: {stats['average']*100:.2f}% (min: {stats['min']*100:.2f}%, max: {stats['max']*100:.2f}%)")

    print(f"\n💾 Results saved to: {results_file}")

    return {
        "summary": summary,
        "results": results
    }


def _write_batch_results(results_file, summary: dict, results: list) -> None:
    """Write the summary and the per-sample results to ``results_file`` as JSON."""
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump({
            "summary": summary,
            "results": results
        }, f, indent=2, ensure_ascii=False, default=str)


def _is_number(value) -> bool:
    """Return True for int/float values; bool is excluded on purpose."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _aggregate_token_stats(sample_results: list) -> dict:
    """Sum the per-agent token counters over the given sample results."""
    token_keys = ("input_tokens", "output_tokens", "reasoning_tokens", "total_tokens")
    totals = {
        "questioner_agent": dict.fromkeys(token_keys, 0),
        "answerer_agent": dict.fromkeys(token_keys, 0),
    }

    for r in sample_results:
        token_stats = r.get("token_stats")
        if not token_stats:
            continue
        for agent, agent_totals in totals.items():
            if agent not in token_stats:
                continue
            agent_stats = token_stats[agent]
            if not agent_stats:
                continue
            for key in token_keys:
                # `or 0` keeps a missing or None counter from breaking the sum
                agent_totals[key] += agent_stats.get(key, 0) or 0

    return totals


def _summarize_judge_evaluations(judge_evaluations: list) -> dict:
    """Build per-dimension and overall averages; return {} when there is nothing to summarize."""
    if not judge_evaluations:
        return {}

    # Collect all scores (5 similarity dimensions)
    all_scores = {
        "Overall_Color": [],
        "Layout_Composition": [],
        "Button_Style": [],
        "Image_Content": [],
        "Text_Content": []
    }
    average_scores = []

    for eval_data in judge_evaluations:
        # Evaluations without a "scores" entry contribute to `count` only.
        if not eval_data or "scores" not in eval_data:
            continue
        scores = eval_data["scores"]
        if not isinstance(scores, dict):
            scores = {}
        for key, values in all_scores.items():
            # Non-numeric scores are skipped so one malformed verdict cannot
            # abort the final summary.
            if key in scores and _is_number(scores[key]):
                values.append(scores[key])
        average = eval_data.get("average_score")
        if _is_number(average):
            average_scores.append(average)

    judge_stats = {
        "average_scores": {},
        "overall_average": sum(average_scores) / len(average_scores) if average_scores else 0.0,
        "count": len(judge_evaluations)
    }

    for key, values in all_scores.items():
        if values:
            judge_stats["average_scores"][key] = {
                "average": sum(values) / len(values),
                "min": min(values),
                "max": max(values),
                "count": len(values)
            }

    return judge_stats


def main():
    """Command-line entry point: parse the options and start one experiment run."""
    # The same model identifiers are valid for all three model options.
    model_choices = ["gpt41", "gpt52", "gemini25"]

    cli = argparse.ArgumentParser(description="Poster QA Experiments")
    add_option = cli.add_argument

    # Input data locations
    add_option("--logos-dir", type=str, required=True, help="Path to logos directory")
    add_option("--prompt-dir", type=str, required=True, help="Path to prompt files directory")
    add_option("--answer-base-dir", type=str, required=True, help="Base directory for answer images")

    # Configuration files
    add_option("--config", type=str, default="../config/config.ini", help="Path to config file")
    add_option("--llm-config", type=str, default="../config/config_llm.ini", help="Path to LLM config file")

    # Experiment selection and size
    add_option(
        "--mode",
        type=str,
        required=True,
        help="Config ID (e.g., 'MPQC_Adaptive'). Use 'list-configs' to see all available config IDs.",
    )
    add_option("--num-samples", type=int, default=10, help="Number of samples to process")
    add_option("--start-index", type=int, default=1, help="Starting index (1-based)")
    add_option("--max-qa-cycles", type=int, default=2, help="Maximum number of QA cycles")
    add_option("--max-questions-per-batch", type=int, default=5, help="Maximum number of questions per batch")

    # Model selection
    add_option(
        "--model-version",
        type=str,
        default="gpt41",
        choices=model_choices,
        help="Global model version for plan operations: gpt41, gpt52, or gemini25",
    )
    add_option(
        "--question-agent-model-version",
        type=str,
        default=None,
        choices=model_choices,
        help="Model version for question agent (fallback to --model-version)",
    )
    add_option(
        "--answer-agent-model-version",
        type=str,
        default=None,
        choices=model_choices,
        help="Model version for answer agent (fallback to --model-version)",
    )

    # Output and setup flags
    add_option("--output-dir", type=str, default=None, help="Output directory (default: results/{mode}_{timestamp})")
    add_option("--use-logos", action="store_true", help="Use logo-based setup (prompts in logos_dir, banner-style)")

    opts = cli.parse_args()

    # Hand every parsed option to the batch runner under its keyword name.
    run_experiment(
        logos_dir=opts.logos_dir,
        prompt_dir=opts.prompt_dir,
        answer_base_dir=opts.answer_base_dir,
        config_path=opts.config,
        llm_config_path=opts.llm_config,
        mode=opts.mode,
        num_samples=opts.num_samples,
        start_index=opts.start_index,
        max_qa_cycles=opts.max_qa_cycles,
        max_questions_per_batch=opts.max_questions_per_batch,
        model_version=opts.model_version,
        question_agent_model_version=opts.question_agent_model_version,
        answer_agent_model_version=opts.answer_agent_model_version,
        output_dir=opts.output_dir,
        use_logos=opts.use_logos,
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
