#!/usr/bin/env python3
"""
PoE Reachability Audit (PRA) - analyzes how much oracle potential is reachable
under gPoE guard constraints.
"""

import argparse
import json
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple


def load_qrels(qrels_file: str) -> Dict[str, Set[str]]:
    """Load qrels and return the relevant document ids for each query.

    Every line is split on arbitrary whitespace, so both tab-separated and
    space-separated qrels files are accepted (the same tokenization that
    ``load_trec_run`` applies to run files). A line contributes a document
    only when it has at least four fields and the fourth field, parsed as an
    integer, is greater than zero. Field 0 is used as the query id and
    field 2 as the document id.

    Args:
        qrels_file: Path to the qrels file.

    Returns:
        Mapping from query id to the set of document ids judged relevant.
        Queries without any positive judgment are absent from the mapping.

    Raises:
        ValueError: If the fourth field of a line is not an integer.
        OSError: If the file cannot be opened.
    """
    qrels = defaultdict(set)
    with open(qrels_file, 'r') as f:
        for line in f:
            # split() with no separator handles tabs, runs of spaces and the
            # trailing newline in one go; blank lines yield an empty list.
            parts = line.split()
            if len(parts) >= 4 and int(parts[3]) > 0:
                qrels[parts[0]].add(parts[2])
    return dict(qrels)


def load_trec_run(run_file: str, cutoff: int = 100) -> Dict[str, List[Tuple[str, float]]]:
    """Read a TREC run file and keep the best-scoring documents per query.

    Lines are split on whitespace; lines with fewer than six fields are
    ignored. Field 0 is taken as the query id, field 2 as the document id
    and field 4 as the score (converted to ``float``).

    Args:
        run_file: Path to the run file.
        cutoff: Maximum number of entries retained for each query.

    Returns:
        Mapping from query id to a list of ``(docid, score)`` pairs ordered
        by descending score and truncated to ``cutoff`` entries. Entries with
        equal scores keep the order in which they appear in the file.
    """
    per_query = defaultdict(list)
    with open(run_file, 'r') as handle:
        for raw_line in handle:
            fields = raw_line.split()
            if len(fields) < 6:
                continue
            per_query[fields[0]].append((fields[2], float(fields[4])))

    # Rank each query's entries by score (highest first) and cut to size.
    return {
        query_id: sorted(entries, key=lambda entry: entry[1], reverse=True)[:cutoff]
        for query_id, entries in per_query.items()
    }


def parse_guard_config(guard_str: str) -> Dict[str, float]:
    """Parse a guard configuration string like 'H=4,L=1.10,J=60,C=10,TAU=0.25'.

    The string is a comma-separated list of ``KEY=VALUE`` entries. Whitespace
    around entries, keys and values is ignored, and blank entries (an empty
    string, a doubled or trailing comma) are skipped. Keys are kept exactly
    as written, so they are case-sensitive.

    Args:
        guard_str: The configuration string.

    Returns:
        Mapping from guard name to its value as a float; empty when the
        string contains no entries.

    Raises:
        ValueError: If an entry has no '=', has an empty key, or its value
            cannot be converted to a float.
    """
    guards = {}
    for part in guard_str.split(','):
        part = part.strip()
        if not part:
            # Tolerate '' and stray commas instead of failing on unpacking.
            continue

        key, sep, value = part.partition('=')
        key = key.strip()
        value = value.strip()
        if not sep or not key:
            raise ValueError(
                f"Invalid guard entry {part!r}: expected KEY=VALUE"
            )

        try:
            guards[key] = float(value)
        except ValueError as err:
            raise ValueError(
                f"Invalid numeric value {value!r} for guard {key!r}"
            ) from err
    return guards


def check_reachability_under_guards(base_docs: List[Tuple[str, float]],
                                   ges_docs: List[Tuple[str, float]],
                                   oracle_docs: List[str],
                                   guards: Dict[str, float],
                                   k: int = 10) -> Dict[str, Any]:
    """
    Check which oracle documents are reachable under gPoE guard constraints.

    Only the first ``k`` entries of ``oracle_docs`` are examined. An oracle
    document counts as reachable when all of the following hold:

    1. it appears in ``base_docs``;
    2. it appears in ``ges_docs``;
    3. its GES score is at least TAU;
    4. its 0-based position in ``base_docs`` is below H, or below J.

    Guards consulted (with the defaults used when a key is missing):
    - H (freeze_head_k, default 4): documents ranked inside the first H base
      positions are treated as reachable without the jump check.
    - J (max_jump, default 60): documents at base position J or beyond are
      counted as blocked by the max-jump guard.
    - TAU (min_ges, default 0.25): minimum GES score.

    The L (lambda_cap) and C (cutoff_target) guards are accepted in
    ``guards`` but are not used by this check.

    Args:
        base_docs: ``(docid, score)`` pairs of the base run, in rank order.
        ges_docs: ``(docid, score)`` pairs of the GES run.
        oracle_docs: Candidate oracle document ids.
        guards: Guard name -> value mapping.
        k: Number of oracle documents to examine.

    Returns:
        A dict with the counters ``total_oracle_docs``, ``reachable_docs``,
        ``blocked_by_head_freeze``, ``blocked_by_max_jump``,
        ``blocked_by_min_ges`` and ``blocked_by_coverage``, plus
        ``reachability_percentage`` (a fraction in [0, 1], 0.0 when no oracle
        documents were examined) and ``guard_config`` (the ``guards`` object
        itself). ``blocked_by_head_freeze`` and ``blocked_by_coverage`` are
        never incremented here and are always 0. Documents missing from the
        base or GES pool are unreachable but not attributed to any
        ``blocked_by_*`` counter.
    """
    ges_scores = dict(ges_docs)

    # Map each base document to its first (best) position once, instead of
    # rescanning the base list for every oracle document.
    base_rank_of: Dict[str, int] = {}
    for rank, (doc, _) in enumerate(base_docs):
        base_rank_of.setdefault(doc, rank)

    freeze_head_k = int(guards.get('H', 4))
    max_jump = int(guards.get('J', 60))
    min_ges = guards.get('TAU', 0.25)

    candidates = oracle_docs[:k]
    analysis = {
        "total_oracle_docs": len(candidates),
        "reachable_docs": 0,
        "blocked_by_head_freeze": 0,
        "blocked_by_max_jump": 0,
        "blocked_by_min_ges": 0,
        "blocked_by_coverage": 0
    }

    for oracle_doc in candidates:
        base_rank = base_rank_of.get(oracle_doc)
        if base_rank is None:
            # Not in the base pool: unreachable, no dedicated counter.
            continue
        if oracle_doc not in ges_scores:
            # Not in the GES pool: unreachable, no dedicated counter.
            continue
        # NOTE(review): the GES checks also apply to documents already inside
        # the frozen head - confirm that this is the intended semantics.
        if ges_scores[oracle_doc] < min_ges:
            analysis["blocked_by_min_ges"] += 1
            continue
        # Documents inside the frozen head skip the jump check entirely.
        if base_rank >= freeze_head_k and base_rank >= max_jump:
            analysis["blocked_by_max_jump"] += 1
            continue
        analysis["reachable_docs"] += 1

    total = analysis["total_oracle_docs"]
    analysis["reachability_percentage"] = (
        analysis["reachable_docs"] / total if total > 0 else 0.0
    )
    analysis["guard_config"] = guards

    return analysis


def main():
    """Run the PoE reachability audit from the command line.

    Steps: parse the arguments and the guard string, load the qrels and
    every run, pick a base run and a GES run, build the per-query oracle
    (relevant documents retrieved by at least one run), evaluate
    reachability per query, print a summary and, when ``--output`` is given,
    write the full analysis as JSON.
    """
    parser = argparse.ArgumentParser(description="PoE Reachability Audit")
    parser.add_argument("--qrels", required=True, help="Qrels file")
    parser.add_argument("--runs", nargs="+", required=True,
                       help="Run files as 'file:name' (need base, ges, and optionally oracle)")
    parser.add_argument("--guards", required=True,
                       help="Guard config like 'H=4,L=1.10,J=60,C=10,TAU=0.25'")
    parser.add_argument("--k", type=int, default=10, help="Evaluation cutoff")
    parser.add_argument("--cutoff", type=int, default=100, help="Document retrieval cutoff")
    parser.add_argument("--output", help="Output analysis JSON")

    args = parser.parse_args()

    # Parse guard configuration
    guards = parse_guard_config(args.guards)
    print(f"Guard configuration: {guards}")

    # Load qrels
    qrels = load_qrels(args.qrels)
    print(f"Loaded {len(qrels)} queries")

    # Load runs; a spec without ':' uses the file path as the run name.
    runs = {}
    for run_spec in args.runs:
        if ':' in run_spec:
            run_file, run_name = run_spec.split(':', 1)
        else:
            run_file, run_name = run_spec, run_spec
        runs[run_name] = load_trec_run(run_file, args.cutoff)
        print(f"Loaded {run_name}: {len(runs[run_name])} queries")

    # Identify base run and GES run: prefer the well-known names, otherwise
    # fall back to the first / last run given on the command line.
    base_run = runs.get('BGE') or runs.get('BM25') or list(runs.values())[0]
    ges_run = runs.get('Multi-GES') or runs.get('GES') or list(runs.values())[-1]

    print(f"Using base run with {len(base_run)} queries")
    print(f"Using GES run with {len(ges_run)} queries")

    # Compute oracle for comparison
    print("Computing oracle documents...")
    oracle_docs = {}
    for qid, relevant_docs in qrels.items():
        # Oracle = union of all relevant docs found by any run
        found_docs = set()
        for run_data in runs.values():
            found_docs.update(
                docid for docid, _ in run_data.get(qid, ()) if docid in relevant_docs
            )
        # Sort so that the top-k truncation applied during the reachability
        # check is deterministic; plain set order changes between runs
        # because of string hash randomization.
        oracle_docs[qid] = sorted(found_docs)

    # Perform reachability analysis
    print("\nPerforming reachability analysis...")
    overall_stats = {
        "total_oracle_docs": 0,
        "reachable_docs": 0,
        "blocked_by_head_freeze": 0,
        "blocked_by_max_jump": 0,
        "blocked_by_min_ges": 0,
        "queries_analyzed": 0
    }
    summed_keys = (
        "total_oracle_docs",
        "reachable_docs",
        "blocked_by_head_freeze",
        "blocked_by_max_jump",
        "blocked_by_min_ges",
    )

    query_analyses = {}

    for qid in qrels:
        # Only queries present in both the base and the GES run are audited.
        if qid not in base_run or qid not in ges_run:
            continue

        analysis = check_reachability_under_guards(
            base_run[qid], ges_run[qid], oracle_docs[qid], guards, args.k
        )

        query_analyses[qid] = analysis

        # Accumulate stats
        for key in summed_keys:
            overall_stats[key] += analysis[key]
        overall_stats["queries_analyzed"] += 1

    # Calculate overall reachability
    total_oracle = overall_stats["total_oracle_docs"]
    total_reachable = overall_stats["reachable_docs"]
    overall_reachability = (total_reachable / total_oracle) if total_oracle > 0 else 0.0

    print("\n=== PoE Reachability Audit Results ===")
    print(f"Queries analyzed: {overall_stats['queries_analyzed']}")
    print(f"Total oracle documents: {total_oracle}")
    print(f"Reachable documents: {total_reachable}")
    print(f"Overall reachability: {overall_reachability:.3f} ({overall_reachability*100:.1f}%)")
    print(f"Blocked by min GES threshold: {overall_stats['blocked_by_min_ges']}")
    print(f"Blocked by max jump constraint: {overall_stats['blocked_by_max_jump']}")

    # Prepare output
    output_data = {
        "guard_configuration": guards,
        "overall_reachability": overall_reachability,
        "overall_statistics": overall_stats,
        "per_query_analysis": query_analyses,
        "summary": {
            "total_oracle_docs": total_oracle,
            "reachable_docs": total_reachable,
            "reachability_percentage": overall_reachability,
            "primary_blocking_factors": {
                "min_ges_threshold": overall_stats["blocked_by_min_ges"] / total_oracle if total_oracle > 0 else 0,
                "max_jump_constraint": overall_stats["blocked_by_max_jump"] / total_oracle if total_oracle > 0 else 0
            }
        }
    }

    if args.output:
        with open(args.output, 'w') as f:
            json.dump(output_data, f, indent=2)
        print(f"\nAnalysis saved to {args.output}")
    

# Run the audit only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()