#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# consensus_scoring.py
# -----------------------------------------------------------------------------
# Computes global consensus scores between models based on cross/self perplexity scores.
# -----------------------------------------------------------------------------

from __future__ import annotations

import json
import math
import argparse
import random
import logging
from pathlib import Path
from typing import Dict, Any, Iterable, List, Optional, Tuple, Union
from collections import defaultdict
from dataclasses import dataclass

@dataclass
class RankingConfig:
    """Configuration for consensus ranking computation.

    Attributes:
        top_k: Number of best-consensus models to select.
        bottom_k: Number of worst-consensus models to select.
        eps: Absolute tolerance under which two scores are treated as tied.
        seed: Seed handed to the tie-breaking random generator.

    Raises:
        ValueError: If top_k or bottom_k is not positive, or if eps is not a
            positive number (NaN included).
    """
    top_k: int = 3
    bottom_k: int = 3
    eps: float = 1e-12
    seed: Optional[int] = None

    def __post_init__(self):
        if self.top_k <= 0 or self.bottom_k <= 0:
            raise ValueError("top_k and bottom_k must be positive")
        # `not eps > 0` instead of `eps <= 0`: every comparison with NaN is
        # False, so the latter would accept NaN, and a NaN tolerance makes
        # every `abs(a - b) > eps` test False (i.e. everything ties).
        if not self.eps > 0:
            raise ValueError("eps must be positive")

@dataclass
class ConsensusScore:
    """Aggregate consensus result for a single model."""
    # Name of the model this score belongs to.
    model: str
    # Summed score, or None when nothing could be summed; lower is better.
    score_sum: Optional[float]
    # Number of pairwise values that contributed to score_sum.
    n_pairs: int
    # Rank among all models; None until (or unless) one is assigned.
    rank: Optional[int] = None

    @property
    def is_valid(self) -> bool:
        """True when a score exists and at least one pair contributed to it."""
        if self.score_sum is None:
            return False
        return self.n_pairs > 0

def load_entries(input_path: Union[str, Path]) -> Iterable[Dict[str, Any]]:
    """
    Load entries from JSON array or JSONL file.

    The format is chosen from the first non-whitespace character: '[' means
    the whole file is one JSON array, anything else is read as one JSON
    document per line. The file is decoded as UTF-8 and a leading byte-order
    mark, if present, is ignored.

    This is a generator: the file is only opened, and errors are only raised,
    once iteration starts.

    Args:
        input_path: Path to input file

    Yields:
        Dictionary entries from the file. Non-dict items are skipped with a
        warning; JSONL lines that fail to decode are logged and skipped.

    Raises:
        ValueError: If the file cannot be read or parsed
        FileNotFoundError: If input file doesn't exist
    """
    input_path = Path(input_path)

    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    try:
        # 'utf-8-sig' reads plain UTF-8 too, but strips a leading BOM. With
        # plain 'utf-8' the BOM becomes the first character, which hides the
        # opening '[' and breaks decoding of the first JSONL line.
        with input_path.open('r', encoding='utf-8-sig') as f:
            # Sniff the first non-whitespace character in buffered chunks.
            first_char = None
            for chunk in iter(lambda: f.read(4096), ''):
                stripped = chunk.lstrip()
                if stripped:
                    first_char = stripped[0]
                    break

            # Rewind so the actual parser sees the file from the start.
            f.seek(0)

            if first_char == '[':
                # JSON array format
                data = json.load(f)
                if not isinstance(data, list):
                    raise ValueError("Top-level JSON must be an array")

                for item in data:
                    if isinstance(item, dict):
                        yield item
                    else:
                        logging.warning(f"Skipping non-dict item: {type(item)}")
            else:
                # JSONL format
                for line_no, line in enumerate(f, 1):
                    line = line.strip()
                    if not line:
                        continue

                    try:
                        obj = json.loads(line)
                    except json.JSONDecodeError as e:
                        # A bad line is reported but does not abort the load.
                        logging.error(f"Line {line_no}: JSON decode error - {e}")
                        continue

                    if isinstance(obj, dict):
                        yield obj
                    else:
                        logging.warning(f"Line {line_no}: Skipping non-dict item")

    except Exception as e:
        # Callers see a single ValueError; the original error stays attached
        # as __cause__ for debugging.
        raise ValueError(f"Error reading {input_path}: {e}") from e

def save_results(results: List[Dict[str, Any]], output_path: Union[str, Path]) -> None:
    """
    Write results to output_path as indented, non-ASCII-preserving JSON.

    Missing parent directories are created first, and the file ends with a
    newline.

    Args:
        results: JSON-serializable list of result dictionaries
        output_path: Destination file path

    Raises:
        ValueError: If opening, serializing or writing fails
    """
    destination = Path(output_path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    try:
        with destination.open('w', encoding='utf-8') as handle:
            json.dump(results, handle, ensure_ascii=False, indent=2)
            handle.write('\n')
        logging.info(f"Results saved to {destination}")
    except Exception as exc:
        raise ValueError(f"Error saving results to {destination}: {exc}")

def assign_dense_ranks(
    items: List[Dict[str, Any]],
    key_func: callable,
    eps: float = 1e-12,
    rank_field: str = "rank"
) -> None:
    """
    Write dense ranks (1, 2, 3, ...) into items, in the order given.

    Items are expected to arrive already ordered by their key. A new rank
    starts whenever a key differs by more than eps from the key that opened
    the current rank group. Items whose key is None get rank None and do not
    affect the grouping.

    Args:
        items: List of items (modified in-place)
        key_func: Function to extract ranking key from each item
        eps: Threshold for detecting ties
        rank_field: Field name to store rank
    """
    rank = 0
    group_anchor: Optional[float] = None

    for entry in items:
        key = key_func(entry)

        if key is not None:
            # Compare against the first key of the current group, not the
            # immediately preceding one.
            opens_new_group = group_anchor is None or abs(key - group_anchor) > eps
            if opens_new_group:
                rank += 1
                group_anchor = key
            entry[rank_field] = rank
        else:
            entry[rank_field] = None

def select_topk_with_ties(
    sorted_items: List[Dict[str, Any]],
    k: int,
    key_func: callable,
    rng: random.Random,
    eps: float = 1e-12
) -> List[Dict[str, Any]]:
    """
    Pick the first k items, breaking ties at the cut-off at random.

    The list is walked one tie group at a time; a group is a run of
    consecutive items whose keys are within eps of the group's first key.
    Whole groups are taken while they fit. The group that would overflow k
    is sampled with rng to fill the remaining slots. The walk stops at the
    first item whose key is None.

    Args:
        sorted_items: Items sorted by ranking key (ascending)
        k: Number of items to select
        key_func: Function to extract ranking key
        rng: Random number generator for tie-breaking
        eps: Threshold for detecting ties

    Returns:
        List of selected items (up to k items)
    """
    if not sorted_items or k <= 0:
        return []

    total = len(sorted_items)
    chosen: List[Dict[str, Any]] = []
    start = 0

    while start < total and len(chosen) < k:
        anchor = key_func(sorted_items[start])
        if anchor is None:
            break

        # The group ends at the first later item that is None or too far
        # from the anchor; without such an item it runs to the end.
        end = total
        for probe in range(start + 1, total):
            candidate = key_func(sorted_items[probe])
            if candidate is None or abs(candidate - anchor) > eps:
                end = probe
                break

        group = sorted_items[start:end]
        free_slots = k - len(chosen)
        if len(group) > free_slots:
            # Too many tied candidates for the slots left: draw at random.
            chosen += rng.sample(group, free_slots)
        else:
            chosen += group

        start = end

    return chosen

def safe_log(value: Optional[Union[int, float]]) -> Optional[float]:
    """
    Return the natural log of a positive, finite number, otherwise None.

    Args:
        value: Candidate number; None and non-numeric values are accepted
            and map to None.

    Returns:
        math.log(value) for a positive finite int or float, else None.
    """
    if value is None:
        return None
    if not isinstance(value, (int, float)):
        return None
    # Infinity passes a plain `value > 0` test, but its log is inf, and
    # differences of such logs are inf or NaN, which cannot be ranked or
    # summed meaningfully. Only floats are checked so that arbitrarily large
    # ints, which math.log handles, are not pushed through a float conversion.
    if isinstance(value, float) and not math.isfinite(value):
        return None
    if value > 0:
        return math.log(value)
    return None

def extract_ppl_data(entry: Dict[str, Any]) -> Tuple[Dict[str, float], Dict[str, Dict[str, float]]]:
    """
    Pull the 'self_ppl' and 'cross_ppl' dictionaries out of an entry.

    Args:
        entry: Input record expected to carry both fields

    Returns:
        Tuple of (self_ppl, cross_ppl), the very objects stored in entry

    Raises:
        ValueError: If either field is missing or not a dict, or if
            'self_ppl' is an empty dict
    """
    self_ppl = entry.get("self_ppl")
    cross_ppl = entry.get("cross_ppl")

    # Checked top to bottom; the first failing rule decides the message.
    rules = (
        (isinstance(self_ppl, dict), "Missing or invalid field 'self_ppl'"),
        (isinstance(cross_ppl, dict), "Missing or invalid field 'cross_ppl'"),
        (bool(self_ppl), "'self_ppl' is empty"),
    )
    for satisfied, message in rules:
        if not satisfied:
            raise ValueError(message)

    return self_ppl, cross_ppl

def compute_pairwise_consensus(
    entry: Dict[str, Any], 
    eps: float = 1e-12
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Compute pairwise consensus scores for all model pairs.

    For each model A, compute d(A->B) = log(cross_ppl[A][B]) - log(self_ppl[B])
    for all other models B, then rank by |d(A->B)|.

    Anything that prevents a usable value (a missing or non-dict
    cross_ppl[A] row, a missing or non-positive perplexity, or a difference
    that is not finite) yields d = abs_d = rank = None for that pair instead
    of failing the whole entry.

    Args:
        entry: Dictionary containing self_ppl and cross_ppl data
        eps: Epsilon for tie detection in ranking

    Returns:
        Dictionary mapping each model to its ranked consensus scores
        Format: {model_A: [{model: B, d: float|None, abs_d: float|None, rank: int|None}, ...]}

    Raises:
        ValueError: If extract_ppl_data rejects the entry
    """
    self_ppl, cross_ppl = extract_ppl_data(entry)
    models = list(self_ppl.keys())

    # Precompute log of self perplexities
    log_self_ppl = {model: safe_log(self_ppl[model]) for model in models}

    consensus_ranks = {}

    for model_a in models:
        # A row that is present but not a dict (e.g. JSON null) is handled
        # exactly like an absent row: every pair from this model is missing.
        cross_row = cross_ppl.get(model_a)
        if not isinstance(cross_row, dict):
            cross_row = {}

        pairwise_scores = []

        for model_b in models:
            if model_a == model_b:
                continue

            # Compute d(A->B) = log(cross_ppl[A][B]) - log(self_ppl[B])
            log_cross = safe_log(cross_row.get(model_b))
            log_self = log_self_ppl[model_b]

            d_value = None
            abs_d_value = None
            if log_cross is not None and log_self is not None:
                difference = log_cross - log_self
                # inf/NaN cannot be ordered or summed reliably, so they are
                # reported as missing rather than ranked.
                if math.isfinite(difference):
                    d_value = difference
                    abs_d_value = abs(difference)

            pairwise_scores.append({
                "model": model_b,
                "d": d_value,
                "abs_d": abs_d_value
            })

        # Sort by |d| ascending (None values go to end)
        pairwise_scores.sort(key=lambda x: float('inf') if x["abs_d"] is None else x["abs_d"])

        assign_dense_ranks(pairwise_scores, key_func=lambda x: x["abs_d"], eps=eps)

        consensus_ranks[model_a] = pairwise_scores

    return consensus_ranks

def compute_aggregate_consensus(
    pairwise_ranks: Dict[str, List[Dict[str, Any]]],
    models: List[str],
    config: RankingConfig
) -> Tuple[List[ConsensusScore], List[ConsensusScore], List[ConsensusScore]]:
    """
    Compute aggregate consensus scores and select top-k/bottom-k models.

    For each model B, aggregate_score(B) = sum_A |d(A->B)| across all A != B.
    Lower scores indicate better consensus (more agreement from other models).
    A model that received no usable |d| value gets score_sum=None, rank=None
    and is never selected into top-k or bottom-k.

    Args:
        pairwise_ranks: Output from compute_pairwise_consensus
        models: List of all model names
        config: Ranking configuration

    Returns:
        Tuple of (all_scores, top_k_scores, bottom_k_scores). all_scores is
        sorted by ascending score with unscored models last; the two
        selections are fresh ConsensusScore copies, best-first and
        worst-first respectively.
    """
    # Aggregate scores: sum |d(A->B)| for each model B
    score_sums = defaultdict(float)
    pair_counts = defaultdict(int)
    known_models = set(models)

    for model_a, rankings in pairwise_ranks.items():
        if model_a not in known_models:
            continue

        for rank_info in rankings:
            abs_d = rank_info["abs_d"]
            if abs_d is None:
                continue
            model_b = rank_info["model"]
            score_sums[model_b] += abs_d
            pair_counts[model_b] += 1

    # Create ConsensusScore objects
    all_scores = [
        ConsensusScore(
            model=model,
            score_sum=score_sums[model] if pair_counts[model] > 0 else None,
            n_pairs=pair_counts[model]
        )
        for model in models
    ]

    # Sort by score (ascending - lower is better); unscored models go last
    all_scores.sort(key=lambda x: float('inf') if x.score_sum is None else x.score_sum)

    # Each __dict__ is the instance's live attribute mapping, so the ranks
    # written here land directly on the ConsensusScore objects.
    assign_dense_ranks(
        [score.__dict__ for score in all_scores],
        key_func=lambda x: x["score_sum"],
        eps=config.eps,
        rank_field="rank"
    )

    # Select top-k and bottom-k with tie handling
    rng = random.Random(config.seed)

    valid_models = [s for s in all_scores if s.is_valid]
    actual_top_k = min(config.top_k, len(valid_models))
    actual_bottom_k = min(config.bottom_k, len(valid_models))

    # Top-k (best consensus - lowest scores)
    top_k_dicts = select_topk_with_ties(
        [s.__dict__ for s in all_scores],
        actual_top_k,
        key_func=lambda x: x["score_sum"],
        rng=rng,
        eps=config.eps
    )
    top_k_scores = [ConsensusScore(**d) for d in top_k_dicts]

    # Bottom-k (worst consensus - highest scores).
    # Unscored models must sort LAST here as well: select_topk_with_ties
    # stops at the first None key, so placing them first (a -inf sort key)
    # would leave bottom-k empty whenever any model lacks a score.
    all_scores_desc = sorted(
        all_scores,
        key=lambda x: float('inf') if x.score_sum is None else -x.score_sum
    )
    bottom_k_dicts = select_topk_with_ties(
        [s.__dict__ for s in all_scores_desc],
        actual_bottom_k,
        key_func=lambda x: -x["score_sum"] if x["score_sum"] is not None else None,
        rng=rng,
        eps=config.eps
    )
    bottom_k_scores = [ConsensusScore(**d) for d in bottom_k_dicts]

    return all_scores, top_k_scores, bottom_k_scores

def process_entry(
    entry: Dict[str, Any], 
    entry_idx: int,
    config: RankingConfig
) -> Dict[str, Any]:
    """
    Run the full consensus analysis for one input entry.

    Any exception raised along the way is logged and turned into a compact
    error record, so one bad entry never aborts the caller's loop.

    Args:
        entry: Input entry with perplexity data
        entry_idx: Index of entry for tracking
        config: Ranking configuration

    Returns:
        On success, a dictionary with the entry's metadata, the pairwise
        rankings, the final ranking, the top-k/bottom-k selections and the
        configuration used. On failure, a dictionary with "idx", "question",
        "image_path" and an "error" message.
    """
    def as_dicts(scores):
        # ConsensusScore objects are exported through their attribute dicts.
        return [score.__dict__ for score in scores]

    try:
        # The keys of self_ppl define which models take part.
        self_ppl, _unused_cross = extract_ppl_data(entry)
        models = list(self_ppl.keys())

        if len(models) < 2:
            raise ValueError(f"Need at least 2 models, got {len(models)}")

        pairwise_ranks = compute_pairwise_consensus(entry, config.eps)
        all_scores, top_k, bottom_k = compute_aggregate_consensus(
            pairwise_ranks, models, config
        )

        config_snapshot = {
            "top_k": config.top_k,
            "bottom_k": config.bottom_k,
            "eps": config.eps,
            "seed": config.seed
        }

        return {
            "idx": entry_idx,
            "question": entry.get("question"),
            "image_path": entry.get("image_path"),
            "dataset": entry.get("dataset"),
            "n_models": len(models),
            "consensus_ranks": pairwise_ranks,
            "final_ranking": as_dicts(all_scores),
            "final_topk": as_dicts(top_k),
            "final_bottomk": as_dicts(bottom_k),
            "config": config_snapshot
        }

    except Exception as e:
        logging.error(f"Entry {entry_idx}: {type(e).__name__}: {e}")
        failure = {"idx": entry_idx}
        failure["question"] = entry.get("question")
        failure["image_path"] = entry.get("image_path")
        failure["error"] = f"{type(e).__name__}: {e}"
        return failure

def setup_logging(verbose: bool = False) -> None:
    """
    Configure the root logger through logging.basicConfig.

    Args:
        verbose: Use DEBUG level when true, INFO otherwise
    """
    if verbose:
        level = logging.DEBUG
    else:
        level = logging.INFO

    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=level
    )

def parse_arguments() -> argparse.Namespace:
    """
    Build the command-line parser and parse sys.argv.

    Returns:
        Namespace with input, output, top_k, bottom_k, seed, eps and verbose
    """
    usage_examples = """
Examples:
  %(prog)s --input results.json --output rankings.json --top_k 3 --bottom_k 3
  %(prog)s --input data.jsonl --output out.json --seed 42 --verbose
        """

    parser = argparse.ArgumentParser(
        description="Compute model consensus rankings from cross/self perplexity scores",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples
    )

    # One row per option, registered in this order (which is also the order
    # shown in --help).
    option_table = (
        ("--input", dict(required=True, help="Input file (JSON array or JSONL format)")),
        ("--output", dict(required=True, help="Output JSON file for results")),
        ("--top_k", dict(type=int, default=3, help="Number of top models to select (default: 3)")),
        ("--bottom_k", dict(type=int, default=3, help="Number of bottom models to select (default: 3)")),
        ("--seed", dict(type=int, default=42, help="Random seed for reproducible tie-breaking")),
        ("--eps", dict(type=float, default=1e-12, help="Threshold for tie detection (default: 1e-12)")),
        ("--verbose", dict(action="store_true", help="Enable verbose logging")),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    return parser.parse_args()

def main() -> None:
    """
    Command-line entry point.

    Parses arguments, processes every input entry, writes all results
    (including per-entry error records) to the output file and logs a
    summary. Any exception outside per-entry processing is logged as fatal
    and re-raised.
    """
    args = parse_arguments()
    setup_logging(args.verbose)

    try:
        config = RankingConfig(
            top_k=args.top_k,
            bottom_k=args.bottom_k,
            eps=args.eps,
            seed=args.seed
        )

        logging.info(f"Starting consensus ranking with config: {config}")
        logging.info(f"Input: {args.input}")
        logging.info(f"Output: {args.output}")

        results = []
        entry_count = 0
        error_count = 0

        for entry_idx, entry in enumerate(load_entries(args.input)):
            outcome = process_entry(entry, entry_idx, config)
            results.append(outcome)

            # Every entry produces exactly one result, failed or not.
            entry_count = len(results)
            if "error" in outcome:
                error_count += 1

            # Progress heartbeat every 100 entries.
            if entry_count % 100 == 0:
                logging.info(f"Processed {entry_count} entries ({error_count} errors)")

        save_results(results, args.output)

        logging.info(f"Completed processing {entry_count} entries")
        if error_count == 0:
            logging.info("All entries processed successfully")
        else:
            logging.warning(f"Encountered {error_count} errors during processing")

    except Exception as e:
        logging.error(f"Fatal error: {type(e).__name__}: {e}")
        raise

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main() 
