import os
import json
import subprocess
import sys
import argparse
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple


# Root directory for all batch outputs; default value of the --output option.
DEFAULT_OUTPUT_BASE = Path("./batch_refinement_results")
# Refinement script launched once per job; default value of --refiner-script.
DEFAULT_REFINER_SCRIPT = "intelligent_survey_refiner_new.py"


# Iteration counts used when --iterations is empty or cannot be parsed.
DEFAULT_ITERATIONS = [1, 2, 3]
# Value forwarded to the refiner script's --retrieval option.
RETRIEVAL_METHOD = "embed"

def setup_output_directories(output_base: Path, iterations: list):
    """Build the results tree under ``output_base`` and return its root.

    Ensures the root exists together with one ``iter_<n>`` folder per entry
    of ``iterations`` and the fixed ``analysis`` and ``best_results`` folders.
    Directories that already exist are left as they are.
    """
    output_base.mkdir(parents=True, exist_ok=True)

    folder_names = [f"iter_{count}" for count in iterations]
    folder_names += ["analysis", "best_results"]
    for folder_name in folder_names:
        (output_base / folder_name).mkdir(parents=True, exist_ok=True)

    return output_base

def parse_file_lists(svx_list: str, ref_list: str) -> List[Tuple[str, str]]:
    """Pair each survey file with the reference directory at the same position.

    Both arguments are comma-separated path lists; surrounding whitespace and
    empty entries are ignored. Entry *i* of ``svx_list`` belongs to entry *i*
    of ``ref_list``.

    Args:
        svx_list: Comma-separated survey file paths.
        ref_list: Comma-separated reference directory paths.

    Returns:
        ``(survey_path, reference_dir)`` string tuples, in input order, for
        every pair whose two paths both exist. A pair with a missing survey
        or a missing reference directory is reported and skipped as a whole.

    Raises:
        SystemExit: With status 1 when the two lists do not contain the same
            number of entries.
    """
    svx_items = [entry.strip() for entry in svx_list.split(',') if entry.strip()]
    ref_items = [entry.strip() for entry in ref_list.split(',') if entry.strip()]

    # Compare the raw counts BEFORE dropping missing paths: filtering each
    # list on its own and zipping afterwards would shift the remaining
    # entries and silently pair surveys with the wrong reference directory.
    if len(svx_items) != len(ref_items):
        print(f"[ERROR] Number of survey files ({len(svx_items)}) must match number of reference directories ({len(ref_items)})")
        sys.exit(1)

    surveys = []
    for svx_item, ref_item in zip(svx_items, ref_items):
        svx_path = Path(svx_item)
        ref_path = Path(ref_item)

        pair_is_usable = True
        if not svx_path.exists():
            print(f"[WARNING] Survey file not found: {svx_path}")
            pair_is_usable = False
        if not ref_path.exists():
            print(f"[WARNING] Reference directory not found: {ref_path}")
            pair_is_usable = False

        if pair_is_usable:
            surveys.append((str(svx_path), str(ref_path)))

    return surveys

def get_topic_name_from_path(survey_path: str) -> str:
    """Derive a topic name from a survey file path.

    If a directory named ``outputs`` appears in the path and is followed by
    another directory, that following directory's name is the topic (spaces
    replaced by underscores). Otherwise the file name without its extension
    is returned.

    Args:
        survey_path: Path to the survey file.

    Returns:
        The topic name.
    """
    path = Path(survey_path)
    # Only directory components can name a topic. Leaving the final component
    # (the file itself) out prevents "x/outputs/survey.tex" from yielding the
    # topic "survey.tex", extension included.
    directories = path.parts[:-1]
    for index, part in enumerate(directories[:-1]):
        if part == "outputs":
            return directories[index + 1].replace(" ", "_")
    return path.stem

def run_refinement(survey_path: str, ref_dir: str, topic_name: str,
                  max_iterations: int, output_dir: Path, refiner_script: str,
                  timeout: int = 1800) -> Dict:
    """Run the refiner script once for one survey and one iteration budget.

    Args:
        survey_path: Survey file passed to the refiner as ``--survey``.
        ref_dir: Reference directory passed as ``--refdir``.
        topic_name: Prefix used for the generated output file names.
        max_iterations: Value passed as ``--max-iterations``.
        output_dir: Directory that receives the ``.tex`` and progress files.
        refiner_script: Path of the script to execute.
        timeout: Seconds to wait for the subprocess (default: 30 minutes).

    Returns:
        A dict that always contains ``success``, ``command`` and
        ``timestamp``. When the subprocess ran to completion it also holds
        ``output_file``, ``progress_file``, ``stdout`` and ``stderr``. Every
        failure (non-zero exit, timeout, launch error) carries an ``error``
        message.
    """

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_file = output_dir / f"{topic_name}_iter{max_iterations}_{timestamp}.tex"
    progress_file = output_dir / f"{topic_name}_iter{max_iterations}_{timestamp}_progress.json"

    cmd = [
        # Reuse the interpreter running this script so the refiner sees the
        # same environment; "python3" is not on PATH on every platform.
        sys.executable or "python3", refiner_script,
        "--survey", survey_path,
        "--refdir", ref_dir,
        "--retrieval", RETRIEVAL_METHOD,
        "--max-iterations", str(max_iterations),
        "--output", str(output_file),
        "--save-progress", str(progress_file),
        "--no-rollback-on-no-improvement"  # Keep neutral changes
    ]

    print(f"    Running: {' '.join(cmd)}")

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired:
        return {
            "success": False,
            "error": f"Timeout after {timeout / 60:g} minutes",
            "command": cmd,
            "timestamp": timestamp
        }
    except Exception as e:
        # Launch problems (missing interpreter, permissions, ...) are reported
        # to the caller instead of aborting the whole batch.
        return {
            "success": False,
            "error": str(e),
            "command": cmd,
            "timestamp": timestamp
        }

    outcome = {
        "success": result.returncode == 0,
        "output_file": str(output_file),
        "progress_file": str(progress_file),
        "stdout": result.stdout,
        "stderr": result.stderr,
        "command": cmd,
        "timestamp": timestamp
    }
    if result.returncode != 0:
        # Without an explicit message callers can only report "Unknown error";
        # surface the exit code plus the last stderr line.
        stderr_lines = (result.stderr or "").strip().splitlines()
        message = f"Refiner exited with code {result.returncode}"
        if stderr_lines:
            message += f": {stderr_lines[-1]}"
        outcome["error"] = message
    return outcome

def extract_scores_from_progress(progress_file: Path) -> Dict:
    """Read the final evaluation scores from a refiner progress file.

    The file is expected to be JSON with an ``iterations`` list; the last
    entry's ``final_evaluation`` maps dimension names to dicts that carry a
    ``score``. Entries without a ``score`` are ignored.

    Args:
        progress_file: Path to the progress JSON file.

    Returns:
        A dict of per-dimension scores plus ``total`` (their sum) and
        ``average`` (their mean). An empty dict is returned when the file is
        missing, has no usable data, or cannot be read/parsed; in the last
        case a warning is printed.
    """
    try:
        if not progress_file.exists():
            return {}

        with open(progress_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        if "iterations" not in data or not data["iterations"]:
            return {}

        # Only the last iteration's evaluation is reported.
        final_iter = data["iterations"][-1]
        if "final_evaluation" not in final_iter:
            return {}

        scores = {
            dim: result["score"]
            for dim, result in final_iter["final_evaluation"].items()
            if isinstance(result, dict) and "score" in result
        }

        if scores:
            # Compute both aggregates from the per-dimension values BEFORE
            # storing them; otherwise "total" would be counted as one more
            # dimension and skew the average.
            total = sum(scores.values())
            dimension_count = len(scores)
            scores["total"] = total
            scores["average"] = total / dimension_count

        return scores

    except Exception as e:
        # Best effort: a broken progress file must not abort the analysis.
        print(f"    [WARNING] Failed to extract scores from {progress_file}: {e}")
        return {}

def analyze_results(all_results: Dict) -> Dict:
    """Score every refinement run and pick the best run per topic.

    ``all_results`` maps topic name -> {iteration -> run result dict}. Runs
    flagged as unsuccessful are recorded as ``failed``; successful runs whose
    progress file yields no scores are recorded as ``no_scores``. Among the
    remaining runs the one with the highest total score wins (the first one
    encountered wins ties).

    Returns a dict with ``timestamp``, ``summary``, ``best_results`` (winning
    run per topic) and ``detailed_analysis`` (per-topic breakdown).
    """
    rule = '=' * 60
    report = {
        "timestamp": datetime.now().isoformat(),
        "summary": {},
        "best_results": {},
        "detailed_analysis": {}
    }

    print(f"\n{rule}")
    print("ANALYZING REFINEMENT RESULTS")
    print(rule)

    for topic_name, runs in all_results.items():
        print(f"\nAnalyzing {topic_name}...")

        # Running maximum; -1 is the "nothing scored yet" starting point.
        top_score, top_iteration, top_run = -1, None, None

        per_iteration = {}
        topic_entry = {
            "iterations_tested": len(runs),
            "results_by_iteration": per_iteration,
            "best_iteration": None,
            "improvement_summary": {}
        }

        for iteration, run in runs.items():
            if not run["success"]:
                print(f"  Iteration {iteration}: FAILED")
                per_iteration[iteration] = {
                    "status": "failed",
                    "error": run.get("error", "Unknown error")
                }
                continue

            scores = extract_scores_from_progress(Path(run["progress_file"]))
            if not scores:
                print(f"  Iteration {iteration}: No scores extracted")
                per_iteration[iteration] = {"status": "no_scores"}
                continue

            total_score = scores.get("total", 0)
            avg_score = scores.get("average", 0)
            # Per-dimension view without the two aggregate entries.
            dimension_scores = {
                key: value for key, value in scores.items()
                if key not in ('total', 'average')
            }

            print(f"  Iteration {iteration}: Total={total_score}, Average={avg_score:.2f}")
            print(f"    Scores: {dimension_scores}")

            per_iteration[iteration] = {
                "status": "success",
                "scores": scores,
                "total_score": total_score,
                "average_score": avg_score,
                "output_file": run["output_file"]
            }

            if total_score > top_score:
                top_score, top_iteration, top_run = total_score, iteration, run

        if top_run:
            print(f"  BEST: Iteration {top_iteration} (Total Score: {top_score})")
            topic_entry["best_iteration"] = top_iteration
            topic_entry["best_score"] = top_score
            report["best_results"][topic_name] = {
                "best_iteration": top_iteration,
                "best_score": top_score,
                "output_file": top_run["output_file"],
                "progress_file": top_run["progress_file"]
            }
        else:
            print(f"  BEST: None (all iterations failed)")
            topic_entry["best_iteration"] = None

        report["detailed_analysis"][topic_name] = topic_entry

    return report

def copy_best_results(analysis: Dict, output_dir: Path):
    """Collect each topic's winning files under ``output_dir / "best_results"``.

    For every entry of ``analysis["best_results"]`` the refined output is
    copied to ``<topic>_BEST.tex`` and, when the entry names a progress file
    that exists, that file is copied to ``<topic>_BEST_progress.json``.
    Entries without an ``output_file`` are skipped silently; entries whose
    output file is missing on disk are skipped with a warning.
    """
    import shutil

    target_dir = output_dir / "best_results"
    target_dir.mkdir(parents=True, exist_ok=True)

    rule = '=' * 60
    print(f"\n{rule}")
    print("COPYING BEST RESULTS")
    print(rule)

    for topic_name, winner in analysis["best_results"].items():
        if "output_file" not in winner:
            continue

        refined_src = Path(winner["output_file"])
        if not refined_src.exists():
            print(f"[WARNING] Best result file not found: {refined_src}")
            continue

        refined_dest = target_dir / f"{topic_name}_BEST.tex"
        shutil.copy2(refined_src, refined_dest)

        # The progress file is optional: copy it only if it is named and present.
        if "progress_file" in winner:
            progress_src = Path(winner["progress_file"])
            if progress_src.exists():
                shutil.copy2(progress_src, target_dir / f"{topic_name}_BEST_progress.json")

        print(f"[COPY] {topic_name} -> {refined_dest.name}")

def save_analysis(analysis: Dict, output_dir: Path):
    """Write the analysis as JSON, a CSV summary and a Markdown report.

    All three files go to ``output_dir / "analysis"`` (created if needed) and
    share one timestamp in their names.

    Args:
        analysis: Dict with ``timestamp``, ``best_results`` and
            ``detailed_analysis`` entries, as built by ``analyze_results``.
        output_dir: Base output directory.

    Returns:
        Tuple ``(analysis_file, csv_file, report_file)`` of the written paths.
    """
    # Local import keeps this function self-contained.
    import csv

    analysis_dir = output_dir / "analysis"
    analysis_dir.mkdir(parents=True, exist_ok=True)

    # One timestamp for all three files so they always share the same suffix,
    # even when the clock ticks over between writes.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Save detailed JSON
    analysis_file = analysis_dir / f"batch_analysis_{stamp}.json"
    with open(analysis_file, 'w', encoding='utf-8') as f:
        json.dump(analysis, f, indent=2, ensure_ascii=False)

    # Save summary CSV (csv.writer quotes topic names that contain commas)
    csv_file = analysis_dir / f"batch_summary_{stamp}.csv"
    with open(csv_file, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["topic", "best_iteration", "best_score", "status"])
        for topic_name, topic_analysis in analysis["detailed_analysis"].items():
            best_iter = topic_analysis.get("best_iteration")
            best_score = topic_analysis.get("best_score", 0)
            # A topic without a winning run stores None here, so test for
            # None itself rather than for the string "None".
            status = "success" if best_iter is not None else "failed"
            writer.writerow([topic_name, str(best_iter), best_score, status])

    # Save readable report
    report_file = analysis_dir / f"batch_report_{stamp}.md"
    total_topics = len(analysis["detailed_analysis"])
    successful_topics = len(analysis["best_results"])
    # Guard against division by zero when no topic was processed.
    success_rate = successful_topics / total_topics * 100 if total_topics else 0.0

    with open(report_file, 'w', encoding='utf-8') as f:
        f.write("# Batch Refinement Analysis Report\n")
        f.write(f"**Generated:** {analysis['timestamp']}\n\n")

        f.write("## Summary\n")
        f.write(f"- Total topics processed: {total_topics}\n")
        f.write(f"- Successful refinements: {successful_topics}\n")
        f.write(f"- Success rate: {success_rate:.1f}%\n\n")

        f.write("## Best Results by Topic\n")
        for topic_name, best_info in analysis["best_results"].items():
            f.write(f"### {topic_name}\n")
            f.write(f"- Best iteration: {best_info['best_iteration']}\n")
            f.write(f"- Best score: {best_info['best_score']}\n")
            f.write(f"- Output file: `{Path(best_info['output_file']).name}`\n\n")

    return analysis_file, csv_file, report_file

def _parse_iterations(raw: str) -> List[int]:
    """Turn a comma-separated iteration string into distinct ints, or the defaults."""
    try:
        parsed = [int(x.strip()) for x in raw.split(',') if x.strip()]
    except ValueError:
        print(f"Error: Invalid iterations format '{raw}'. Using default {DEFAULT_ITERATIONS}")
        return list(DEFAULT_ITERATIONS)
    if not parsed:
        return list(DEFAULT_ITERATIONS)
    # Drop repeats (keeping first-seen order): a repeated count would rerun
    # the same job and overwrite its entry in the per-topic results.
    return list(dict.fromkeys(parsed))


def _unique_topic_names(names: List[str]) -> List[str]:
    """Make topic names distinct by suffixing repeats with ``_2``, ``_3``, ..."""
    taken = set()
    unique = []
    for name in names:
        candidate = name
        suffix = 1
        while candidate in taken:
            suffix += 1
            candidate = f"{name}_{suffix}"
        taken.add(candidate)
        unique.append(candidate)
    return unique


def main():
    """Run the batch refinement from the command line.

    Parses the CLI arguments, runs the refiner for every survey/reference
    pair with every requested iteration count, analyses the outcomes, copies
    the best result per topic and writes the analysis files.

    Returns:
        The analysis dict, or ``None`` when no valid survey/reference pair
        was found.
    """

    parser = argparse.ArgumentParser(description="Batch survey refinement with custom file lists")
    parser.add_argument("--svx-list", type=str, required=True,
                       help="Comma-separated list of survey .tex file paths")
    parser.add_argument("--ref-list", type=str, required=True,
                       help="Comma-separated list of reference directory paths")
    parser.add_argument("--output", "-o", type=str, default=str(DEFAULT_OUTPUT_BASE),
                       help="Output directory for results")
    parser.add_argument("--refiner-script", type=str, default=DEFAULT_REFINER_SCRIPT,
                       help="Path to the refinement script")
    parser.add_argument("--iterations", type=str, default="1,2,3",
                       help="Comma-separated list of iteration numbers to test (e.g., '1,3,5')")
    args = parser.parse_args()

    output_base = Path(args.output)
    refiner_script = args.refiner_script
    iterations = _parse_iterations(args.iterations)

    print(f"{'='*80}")
    print("BATCH SURVEY REFINEMENT (Custom Lists)")
    print(f"{'='*80}")
    print(f"Output directory: {output_base}")
    print(f"Iterations to test: {iterations}")
    print(f"Retrieval method: {RETRIEVAL_METHOD}")
    print(f"Refiner script: {refiner_script}")

    # Setup directories
    output_dir = setup_output_directories(output_base, iterations)

    # Parse file lists
    surveys = parse_file_lists(args.svx_list, args.ref_list)
    if not surveys:
        print("No valid survey-reference pairs found!")
        return

    # Resolve topic names once and keep them distinct: results are keyed by
    # topic, so two surveys with the same derived name would otherwise
    # overwrite each other in the analysis.
    topic_names = _unique_topic_names(
        [get_topic_name_from_path(survey_path) for survey_path, _ in surveys]
    )

    print(f"\nFound {len(surveys)} survey-reference pairs to process:")
    for i, ((survey_path, ref_dir), topic_name) in enumerate(zip(surveys, topic_names), 1):
        print(f"  {i}. {topic_name}")
        print(f"     Survey: {survey_path}")
        print(f"     References: {ref_dir}")

    # Process each survey
    all_results = {}
    total_jobs = len(surveys) * len(iterations)
    current_job = 0

    for (survey_path, ref_dir), topic_name in zip(surveys, topic_names):
        print(f"\n{'='*60}")
        print(f"PROCESSING: {topic_name}")
        print(f"{'='*60}")

        topic_results = {}

        for max_iter in iterations:
            current_job += 1
            print(f"\n[JOB {current_job}/{total_jobs}] Testing {max_iter} iterations...")

            result = run_refinement(
                survey_path=survey_path,
                ref_dir=ref_dir,
                topic_name=topic_name,
                max_iterations=max_iter,
                output_dir=output_dir / f"iter_{max_iter}",
                refiner_script=refiner_script
            )

            if result["success"]:
                print(f"SUCCESS: {Path(result['output_file']).name}")
            else:
                print(f"FAILED: {result.get('error', 'Unknown error')}")

            topic_results[max_iter] = result

        all_results[topic_name] = topic_results

    # Analyze results
    analysis = analyze_results(all_results)

    # Copy best results
    copy_best_results(analysis, output_dir)

    # Save analysis
    analysis_file, csv_file, report_file = save_analysis(analysis, output_dir)

    print(f"\n{'='*80}")
    print("BATCH PROCESSING COMPLETED")
    print(f"{'='*80}")
    print(f"Analysis saved to: {analysis_file}")
    print(f"Summary CSV: {csv_file}")
    print(f"Report: {report_file}")
    print(f"Best results: {output_dir / 'best_results'}")

    return analysis

# Script entry point: run the batch and translate the outcome into an exit code.
if __name__ == "__main__":
    try:
        analysis = main()
    except KeyboardInterrupt:
        print("Process interrupted by user")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report the failure with a traceback and signal
        # it through the exit status.
        print(f"Batch refinement failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    else:
        # main() returns None when it found no valid survey/reference pair;
        # nothing was refined, so do not claim success.
        if analysis is None:
            sys.exit(1)
        print("Batch refinement completed successfully!")