#!/usr/bin/env python3
"""
Oracle Upper Bound (OUB) analysis for GEC verification.
Computes theoretical maximum performance across fusion components.
"""

import argparse
import json
import math
from collections import defaultdict
from typing import Dict, List, Tuple, Set


def load_qrels(qrels_file: str) -> Dict[str, Set[str]]:
    """Load a qrels file and return the relevant documents per query.

    Each usable line holds at least four fields in the order
    ``qid  iteration  docid  relevance``. Tab-delimited lines are parsed
    on tabs; lines that do not yield four tab-separated fields are
    re-parsed on arbitrary whitespace so that space-delimited qrels files
    are accepted as well. Lines with fewer than four fields (including
    blank lines) are skipped.

    Args:
        qrels_file: Path to the qrels file.

    Returns:
        Mapping from query id to the set of document ids whose relevance
        label is greater than zero. Queries without any positively judged
        document do not appear in the mapping.

    Raises:
        ValueError: If a relevance field cannot be parsed as an integer.
    """
    qrels = defaultdict(set)
    with open(qrels_file, 'r') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) < 4:
                # Not a tab-delimited record: fall back to generic
                # whitespace splitting (handles space-separated qrels).
                fields = line.split()
                if len(fields) < 4:
                    continue
            qid, docid, rel = fields[0], fields[2], fields[3]
            if int(rel) > 0:  # Only positively judged documents count
                qrels[qid].add(docid)
    return dict(qrels)


def load_trec_run(run_file: str) -> Dict[str, List[Tuple[str, float]]]:
    """Load a TREC run file and return the ranked documents per query.

    Each usable line holds at least six whitespace-separated fields in the
    order ``qid  Q0  docid  rank  score  tag``. Only the query id, document
    id and score are used; the rank column is ignored and the ranking is
    rebuilt from the scores. Lines with fewer than six fields are skipped,
    and any fields beyond the sixth are ignored rather than rejected.

    Args:
        run_file: Path to the run file.

    Returns:
        Mapping from query id to a list of ``(docid, score)`` pairs sorted
        by score in descending order. Documents with equal scores keep the
        order in which they appear in the file.

    Raises:
        ValueError: If a score field cannot be parsed as a float.
    """
    run_data = defaultdict(list)
    with open(run_file, 'r') as f:
        for line in f:
            parts = line.split()
            if len(parts) < 6:
                continue
            # Index instead of unpacking so extra trailing fields
            # (e.g. a run tag containing spaces) do not raise.
            qid, docid, score = parts[0], parts[2], parts[4]
            run_data[qid].append((docid, float(score)))

    # Sort by score descending; list.sort is stable, so ties keep file order
    for docs in run_data.values():
        docs.sort(key=lambda pair: pair[1], reverse=True)

    return dict(run_data)


def compute_oracle_fusion(runs: Dict[str, Dict[str, List[Tuple[str, float]]]],
                         qrels: Dict[str, Set[str]], k: int = 10,
                         depth: int = 100) -> Dict[str, List[str]]:
    """
    Compute the Oracle Upper Bound ranking for every query in ``qrels``.

    For each query, the top ``depth`` documents of every run are pooled.
    The pool is then ranked with relevant documents first; within each
    relevance group documents are ordered by their best score across runs
    (descending), and remaining ties are broken by document id so the
    result is deterministic.

    Args:
        runs: Mapping from run name to that run's per-query list of
            ``(docid, score)`` pairs, already ordered best-first.
        qrels: Mapping from query id to the set of relevant document ids.
        k: Number of documents to keep per query.
        depth: Number of documents taken from the head of each run.

    Returns:
        Mapping from query id to the list of at most ``k`` document ids.
        Queries that no run retrieved anything for map to an empty list.
    """
    oracle_results = {}

    for qid, relevant_docs in qrels.items():
        # Best score seen for each pooled document. A plain dict is used
        # (rather than a zero-defaulting one) so that negative scores are
        # preserved instead of being clamped to 0.0.
        best_score: Dict[str, float] = {}

        for run_data in runs.values():
            for docid, score in run_data.get(qid, [])[:depth]:
                previous = best_score.get(docid)
                if previous is None or score > previous:
                    best_score[docid] = score

        def sort_key(docid):
            # False sorts before True, so relevant documents come first;
            # negating the score gives descending order; docid breaks ties.
            return (docid not in relevant_docs, -best_score[docid], docid)

        oracle_results[qid] = sorted(best_score, key=sort_key)[:k]

    return oracle_results


def calculate_metrics(run_data: Dict[str, List[str]], qrels: Dict[str, Set[str]],
                     k: int = 10) -> Dict[str, float]:
    """Calculate MRR@k, nDCG@k, and Recall@k for a run.

    Metrics are averaged over the queries in ``qrels`` that have at least
    one relevant document; a query missing from ``run_data`` contributes
    zero to every metric. nDCG uses binary gains with the standard
    ``1 / log2(rank + 1)`` discount.

    Args:
        run_data: Mapping from query id to its ranked list of document ids.
        qrels: Mapping from query id to the set of relevant document ids.
        k: Rank cutoff applied to every ranked list.

    Returns:
        Dict with the keys ``MRR@{k}``, ``nDCG@{k}``, ``Recall@{k}`` (mean
        values, 0.0 when no query was evaluated) and ``num_queries`` (the
        number of queries that were evaluated).
    """
    mrr_scores = []
    ndcg_scores = []
    recall_scores = []

    for qid, relevant_docs in qrels.items():
        if not relevant_docs:
            continue

        retrieved_docs = run_data.get(qid, [])[:k]

        # 1-based ranks at which a relevant document was retrieved
        hit_ranks = [
            rank
            for rank, docid in enumerate(retrieved_docs, start=1)
            if docid in relevant_docs
        ]

        # MRR@k: reciprocal rank of the first relevant hit
        mrr_scores.append(1.0 / hit_ranks[0] if hit_ranks else 0.0)

        # nDCG@k: binary gain, logarithmic rank discount
        dcg = sum(1.0 / math.log2(rank + 1) for rank in hit_ranks)
        ideal_hits = min(len(relevant_docs), k)
        idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
        ndcg_scores.append(dcg / idcg if idcg > 0 else 0.0)

        # Recall@k: fraction of the relevant set found in the top k
        recall_scores.append(len(hit_ranks) / len(relevant_docs))

    return {
        f"MRR@{k}": sum(mrr_scores) / len(mrr_scores) if mrr_scores else 0.0,
        f"nDCG@{k}": sum(ndcg_scores) / len(ndcg_scores) if ndcg_scores else 0.0,
        f"Recall@{k}": sum(recall_scores) / len(recall_scores) if recall_scores else 0.0,
        "num_queries": len(mrr_scores)
    }


def main():
    """Command-line entry point for the Oracle Upper Bound analysis.

    Parses the arguments, loads the qrels and every run given via
    ``--runs``, prints MRR@k for each run, builds the oracle fusion ranking
    and prints its MRR@k, then reports the per-run MRR headroom. When
    ``--output`` is given the oracle ranking is written as a tab-separated
    TREC run file; when ``--analysis`` is given the collected metrics and
    headroom figures are written as JSON.
    """
    parser = argparse.ArgumentParser(description="Compute Oracle Upper Bound analysis")
    parser.add_argument("--runs", nargs="+", required=True,
                       help="Run files in format 'file:name' (e.g., bm25.trec:BM25)")
    parser.add_argument("--qrels", required=True, help="Qrels file")
    parser.add_argument("--k", type=int, default=10, help="Cutoff for metrics")
    parser.add_argument("--output", help="Output oracle run file")
    parser.add_argument("--analysis", help="Output analysis JSON file")

    args = parser.parse_args()

    # Load qrels
    print("Loading qrels...")
    qrels = load_qrels(args.qrels)
    print(f"Loaded {len(qrels)} queries")

    # Load runs
    runs = {}
    run_results = {}

    for run_spec in args.runs:
        # 'file:name' gives the run a display name; a bare path names itself
        if ':' in run_spec:
            run_file, run_name = run_spec.split(':', 1)
        else:
            run_file, run_name = run_spec, run_spec

        print(f"Loading {run_name}...")
        run_data = load_trec_run(run_file)
        runs[run_name] = run_data

        # Convert to list format for metrics
        run_list = {qid: [doc for doc, _ in docs] for qid, docs in run_data.items()}
        run_results[run_name] = calculate_metrics(run_list, qrels, args.k)
        print(f"  {run_name}: MRR@{args.k}={run_results[run_name][f'MRR@{args.k}']:.3f}")

    # Compute Oracle Upper Bound
    print("\nComputing Oracle Upper Bound...")
    oracle_fusion = compute_oracle_fusion(runs, qrels, k=args.k)
    oracle_metrics = calculate_metrics(oracle_fusion, qrels, args.k)

    print(f"Oracle Upper Bound: MRR@{args.k}={oracle_metrics[f'MRR@{args.k}']:.3f}")

    # Write oracle run file
    if args.output:
        with open(args.output, 'w') as f:
            for qid, docs in oracle_fusion.items():
                for i, docid in enumerate(docs):
                    # Real tab/newline characters: one TREC record per line,
                    # with the reciprocal rank as a synthetic score.
                    f.write(f"{qid}\tQ0\t{docid}\t{i+1}\t{1.0/(i+1)}\toracle\n")
        print(f"Oracle run written to {args.output}")

    # Analysis report
    analysis = {
        "oracle_upper_bound": oracle_metrics,
        "individual_runs": run_results,
        "headroom_analysis": {}
    }

    # Calculate headroom for each run
    oracle_mrr = oracle_metrics[f"MRR@{args.k}"]
    for run_name, metrics in run_results.items():
        run_mrr = metrics[f"MRR@{args.k}"]
        gap = oracle_mrr - run_mrr
        # Relative improvement is undefined for a zero baseline; report 0.0
        improvement = (gap / run_mrr * 100) if run_mrr > 0 else 0.0

        analysis["headroom_analysis"][run_name] = {
            "current_mrr": run_mrr,
            "oracle_mrr": oracle_mrr,
            "absolute_gap": gap,
            "relative_improvement": improvement
        }

        print(f"{run_name} headroom: {gap:.3f} absolute ({improvement:.1f}% improvement)")

    if args.analysis:
        with open(args.analysis, 'w') as f:
            json.dump(analysis, f, indent=2)
        print(f"Analysis written to {args.analysis}")


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()