import json
import argparse
import sys
import os
from typing import Dict, List, Any
from collections import defaultdict
import logging

# Add parent directory to path to import sal modules
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from sal.utils.score import aggregate_scores
from sal.utils.math import find_answer_with_largest_sum, find_majority_answer, extract_answer

# Configure the root logger at module load so INFO-level progress messages
# from this script are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module, used by the functions below.
logger = logging.getLogger(__name__)

def load_jsonl(file_path: str) -> Dict[str, Dict]:
    """Read a JSONL file into a dict keyed by each record's ``unique_id``.

    Lines that are empty or whitespace-only are skipped. If several records
    share a ``unique_id``, the one appearing last in the file is kept.

    Args:
        file_path: Path to a UTF-8 encoded JSONL file.

    Returns:
        Mapping from ``unique_id`` to the parsed record.

    Raises:
        KeyError: If a record has no ``unique_id`` key.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records: Dict[str, Dict] = {}
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            # Blank separator lines carry no record.
            if not raw_line.strip():
                continue
            record = json.loads(raw_line)
            records[record['unique_id']] = record
    return records

def merge_list_fields(before_data: Dict, after_data: Dict, field: str, n: int) -> List:
    """Concatenate the first ``n`` entries of ``field`` from both records.

    Args:
        before_data: Record whose entries come first in the result.
        after_data: Record whose entries are appended afterwards.
        field: Key of the list-valued field to merge.
        n: Maximum number of entries taken from each record.

    Returns:
        The leading ``n`` entries from ``before_data[field]`` followed by the
        leading ``n`` entries from ``after_data[field]``. A record that lacks
        the field contributes nothing.
    """
    # Slicing copies, so the input records are left untouched.
    head, tail = (source.get(field, [])[:n] for source in (before_data, after_data))
    return head + tail

def compute_predictions(completions: List[str], agg_scores: List[float], answer: str, k_values: List[int]) -> Dict[str, str]:
    """Build boxed weighted, majority, and naive predictions for each k.

    For every ``k`` in ``k_values`` the first ``min(k, len(completions))``
    completions and scores are used, in the order given (nothing is sorted
    here). An answer is extracted from each of those completions with
    ``extract_answer(..., "default")`` and three predictions are produced:

    * ``pred_weighted@k``: the answer picked by
      ``find_answer_with_largest_sum`` from the answers and their scores.
    * ``pred_maj@k``: the answer picked by ``find_majority_answer``.
    * ``pred_naive@k``: the answer extracted from the completion with the
      largest score (the first such completion on ties).

    Every prediction is the extracted answer wrapped as ``\\boxed{...}``, or
    an empty string when no answer could be selected or it is empty.

    Args:
        completions: Candidate completion texts.
        agg_scores: One aggregated score per completion, in the same order.
        answer: Reference answer; accepted for interface compatibility but
            not used by this function.
        k_values: Values of k to compute predictions for.

    Returns:
        Dict with ``pred_weighted@k``, ``pred_maj@k`` and ``pred_naive@k``
        entries for each requested k. Keys use the requested k even when
        fewer completions are available.
    """

    def boxed_for(texts: List[str], answers: List[str], target: str) -> str:
        # Re-extract from the first completion whose answer equals the chosen
        # one; fall back to "" when nothing matches.
        picked = next(
            (
                extract_answer(text, "default")
                for text, candidate in zip(texts, answers)
                if candidate == target
            ),
            "",
        )
        return f"\\boxed{{{picked}}}" if picked else ""

    results: Dict[str, str] = {}

    for k in k_values:
        # Never take more than what is available.
        limit = min(k, len(completions))
        subset = completions[:limit]
        subset_scores = agg_scores[:limit]

        extracted = [extract_answer(text, "default") for text in subset]

        # Score-weighted vote over the extracted answers.
        weighted_choice = find_answer_with_largest_sum(extracted, subset_scores)
        results[f"pred_weighted@{k}"] = boxed_for(subset, extracted, weighted_choice)

        # Unweighted majority vote over the extracted answers.
        majority_choice = find_majority_answer(extracted)
        results[f"pred_maj@{k}"] = boxed_for(subset, extracted, majority_choice)

        # Naive: answer of the single highest-scoring completion.
        best_answer = ""
        if subset_scores:
            best_position = subset_scores.index(max(subset_scores))
            best_answer = extract_answer(subset[best_position], "default")
        results[f"pred_naive@{k}"] = f"\\boxed{{{best_answer}}}" if best_answer else ""

    return results

def generate_k_values(max_k: int) -> List[int]:
    """Return the powers of two below ``max_k`` followed by ``max_k`` itself.

    For example ``8`` gives ``[1, 2, 4, 8]`` and ``6`` gives ``[1, 2, 4, 6]``.
    The result is empty when ``max_k`` is less than 1.
    """
    ks: List[int] = []
    power = 1
    while power < max_k:
        ks.append(power)
        power *= 2
    # max_k always closes the sequence, whether or not it is a power of two.
    if max_k >= 1:
        ks.append(max_k)
    return ks

def merge_completions(before_file: str, after_file: str, n: int, output_file: str, aggregation_strategy: str = "last"):
    """Merge per-problem completions from two JSONL files and write the result.

    Problems are matched by ``unique_id``; only ids present in both files are
    kept, and they are written in sorted id order. For each problem the
    scalar fields are copied from the "before" record, the list fields take
    up to ``n`` entries from each file ("before" first), ``agg_scores`` is
    recomputed from the merged ``scores`` when available, ``pred`` is set to
    the first merged completion, and predictions are computed for
    ``k = 2 * n`` only.

    Args:
        before_file: Path to before calibration JSONL file.
        after_file: Path to after calibration JSONL file.
        n: Number of completions to take from each file.
        output_file: Output JSONL file path; its directory is created when
            missing.
        aggregation_strategy: One of "mean", "max", "min" or "last". Any
            other value is treated as "last".

    Raises:
        ValueError: If the two files share no ``unique_id``.
    """

    def aggregate_one(score_list):
        # "mean" and "max" are computed here (0.0 for an empty list); "min"
        # and "last" are delegated to aggregate_scores. Unknown strategies
        # fall through to "last".
        if aggregation_strategy == "mean":
            return sum(score_list) / len(score_list) if score_list else 0.0
        if aggregation_strategy == "max":
            return max(score_list) if score_list else 0.0
        if aggregation_strategy == "min":
            return aggregate_scores(score_list, "min")
        return aggregate_scores(score_list, "last")

    logger.info(f"Loading before calibration file: {before_file}")
    before_data = load_jsonl(before_file)

    logger.info(f"Loading after calibration file: {after_file}")
    after_data = load_jsonl(after_file)

    # Only problems present in both files can be merged.
    shared_ids = before_data.keys() & after_data.keys()
    logger.info(f"Found {len(shared_ids)} common problems")
    if not shared_ids:
        raise ValueError("No common unique_ids found between the two files")

    max_k = 2 * n
    # Predictions are computed only for the largest k.
    k_values = [max_k]
    logger.info(f"Will compute predictions for k values: {k_values}")

    scalar_fields = ('problem', 'solution', 'answer', 'subject', 'level', 'unique_id')
    list_fields = ('completions', 'scores', 'completion_tokens')

    merged_data = []
    for unique_id in sorted(shared_ids):
        before_item = before_data[unique_id]
        after_item = after_data[unique_id]

        # Scalar fields use the "before" record as the reference.
        merged_item = {
            field: before_item[field]
            for field in scalar_fields
            if field in before_item
        }

        for field in list_fields:
            if field in before_item or field in after_item:
                merged_item[field] = merge_list_fields(before_item, after_item, field, n)

        if 'scores' in merged_item:
            # Re-aggregate from the merged per-completion score lists.
            merged_item['agg_scores'] = [
                aggregate_one(score_list) for score_list in merged_item['scores']
            ]
        else:
            # No raw scores: merge whatever agg_scores the inputs carry.
            merged_item['agg_scores'] = merge_list_fields(before_item, after_item, 'agg_scores', n)

        # 'pred' mirrors the first merged completion when there is one.
        if merged_item.get('completions'):
            merged_item['pred'] = merged_item['completions'][0]

        if 'completions' in merged_item and 'agg_scores' in merged_item:
            merged_item.update(
                compute_predictions(
                    merged_item['completions'],
                    merged_item['agg_scores'],
                    merged_item.get('answer', ''),
                    k_values,
                )
            )

        merged_data.append(merged_item)

    # Make sure the destination directory exists before writing.
    output_dir = os.path.dirname(output_file)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Created output directory: {output_dir}")

    logger.info(f"Saving merged data to: {output_file}")
    with open(output_file, 'w', encoding='utf-8') as out:
        for merged_item in merged_data:
            out.write(json.dumps(merged_item, ensure_ascii=False) + '\n')

    logger.info(f"Successfully merged {len(merged_data)} problems with {max_k} completions each")

def main():
    """Parse command-line arguments and run ``merge_completions``.

    All of ``--before_file``, ``--after_file``, ``--n`` and ``--output_file``
    are required; ``--aggregation_strategy`` defaults to "last".
    """
    parser = argparse.ArgumentParser(description="Merge completions from before and after calibration")

    # Required options, registered in the order they appear in --help.
    required_options = (
        ("--before_file", str, "Path to before calibration JSONL file"),
        ("--after_file", str, "Path to after calibration JSONL file"),
        ("--n", int, "Number of completions to take from each file"),
        ("--output_file", str, "Output JSONL file path"),
    )
    for flag, value_type, help_text in required_options:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)

    parser.add_argument(
        "--aggregation_strategy",
        type=str,
        default="last",
        choices=["last", "min", "mean", "max"],
        help="Strategy for score aggregation",
    )

    args = parser.parse_args()

    merge_completions(
        before_file=args.before_file,
        after_file=args.after_file,
        n=args.n,
        output_file=args.output_file,
        aggregation_strategy=args.aggregation_strategy,
    )

    # Force cleanup to avoid multiprocessing cleanup warnings
    import gc
    gc.collect()

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()