#!/usr/bin/env python3


import argparse
import json
import sys
from datetime import datetime, timezone
from pathlib import Path
from typing import Set, List, Dict, Any

from datasets import load_from_disk, Dataset
from datasets import load_dataset

def load_correct_ids(file_path: str) -> Set[int]:
    """Read the correctly answered question IDs stored in a JSON file.

    Args:
        file_path: Path to a JSON file containing a 'correct_question_ids' field.

    Returns:
        The values of 'correct_question_ids', de-duplicated into a set.

    Raises:
        ValueError: If the loaded JSON has no 'correct_question_ids' field.
    """
    with open(file_path, 'r') as handle:
        payload = json.load(handle)

    # Guard clause: reject files that lack the expected field.
    if 'correct_question_ids' not in payload:
        raise ValueError(f"File {file_path} does not contain 'correct_question_ids' field")

    return set(payload['correct_question_ids'])

def analyze_differences(model1_ids: Set[int], model2_ids: Set[int], 
                       model1_name: str, model2_name: str) -> Dict[str, Any]:
    """Compare the sets of question IDs two models answered correctly.

    Args:
        model1_ids: Question IDs answered correctly by model 1.
        model2_ids: Question IDs answered correctly by model 2.
        model1_name: Display name of model 1, copied into the result.
        model2_name: Display name of model 2, copied into the result.

    Returns:
        A dict holding both names, the size of each derived set, sorted ID
        lists for the derived sets, and a nested 'statistics' dict of ratios.
        A ratio whose denominator set is empty is reported as 0.
    """

    def share(part, whole):
        # An empty denominator yields the integer 0 instead of dividing by zero.
        return len(part) / len(whole) if whole else 0

    shared = model1_ids & model2_ids        # correct for both models
    first_only = model1_ids - model2_ids    # correct for model 1 only
    second_only = model2_ids - model1_ids   # correct for model 2 only
    either = model1_ids | model2_ids        # correct for at least one model
    exactly_one = model1_ids ^ model2_ids   # symmetric difference: correct for exactly one model

    ratios = {
        "model1_accuracy_relative": share(model1_ids, either),
        "model2_accuracy_relative": share(model2_ids, either),
        "overlap_ratio": share(shared, either),
        "model1_unique_ratio": share(first_only, model1_ids),
        "model2_unique_ratio": share(second_only, model2_ids),
    }

    return {
        "model1_name": model1_name,
        "model2_name": model2_name,
        "model1_correct_count": len(model1_ids),
        "model2_correct_count": len(model2_ids),
        "both_correct_count": len(shared),
        "model1_only_count": len(first_only),
        "model2_only_count": len(second_only),
        "union_count": len(either),
        "symmetric_diff_count": len(exactly_one),
        "model1_only_ids": sorted(first_only),
        "model2_only_ids": sorted(second_only),
        "both_correct_ids": sorted(shared),
        "symmetric_diff_ids": sorted(exactly_one),
        "statistics": ratios,
    }

def print_analysis_summary(analysis: Dict[str, Any]):
    """Print a human-readable report of a model comparison to stdout.

    Args:
        analysis: Dict with the model names, the per-set counts, and a
            'statistics' dict holding 'model1_unique_ratio',
            'model2_unique_ratio' and 'overlap_ratio'.
    """
    rule = "=" * 60
    print(rule)
    print("Model Comparison Analysis Results")
    print(rule)
    for label, key in (("Model 1", "model1_name"), ("Model 2", "model2_name")):
        print(f"{label}: {analysis[key]}")
    print()

    # Raw counts, printed in a fixed order.
    count_rows = (
        ("Model 1 correct answers", "model1_correct_count"),
        ("Model 2 correct answers", "model2_correct_count"),
        ("Correct answers by both models", "both_correct_count"),
        ("Correct answers only by Model 1", "model1_only_count"),
        ("Correct answers only by Model 2", "model2_only_count"),
        ("Correct answers by at least one model", "union_count"),
        ("Symmetric difference size", "symmetric_diff_count"),
    )
    print("Basic Statistics:")
    for label, key in count_rows:
        print(f"  {label}: {analysis[key]}")
    print()

    # Ratios are rendered as percentages with two decimals.
    stats = analysis['statistics']
    ratio_rows = (
        ("Model 1 unique correct rate", "model1_unique_ratio"),
        ("Model 2 unique correct rate", "model2_unique_ratio"),
        ("Answer overlap rate", "overlap_ratio"),
    )
    print("Relative Performance Analysis:")
    for label, key in ratio_rows:
        print(f"  {label}: {stats[key]:.2%}")
    print()

    print("Strength Analysis:")
    first_unique = analysis['model1_only_count']
    second_unique = analysis['model2_only_count']
    if first_unique > second_unique:
        print(f"  Model 1 answered {first_unique - second_unique} more questions correctly than Model 2")
    elif second_unique > first_unique:
        print(f"  Model 2 answered {second_unique - first_unique} more questions correctly than Model 1")
    else:
        print("  Both models have equal number of uniquely correct answers")

def main():
    """Command-line entry point: compare two models' correct-answer ID files.

    Parses the arguments, loads both ID files, prints the comparison summary,
    optionally writes the full analysis as JSON (--output_file) and optionally
    builds the per-model filtered datasets (--create_datasets).

    Error messages are written to stderr and the process exits with status 1
    when an input file is missing or any processing step raises.
    """
    parser = argparse.ArgumentParser(description='Analyze differences between two models\' results')
    parser.add_argument('--model1_file', type=str, required=True, 
                    help='File path to Model 1\'s correct_ids.json')
    parser.add_argument('--model2_file', type=str, required=True,
                    help='File path to Model 2\'s correct_ids.json')
    parser.add_argument('--output_file', type=str, default=None,
                    help='File path to output analysis results as JSON (optional)')
    parser.add_argument('--model1_name', type=str, default=None,
                    help='Name of Model 1 (optional, inferred from filename)')
    parser.add_argument('--model2_name', type=str, default=None,
                    help='Name of Model 2 (optional, inferred from filename)')

    # Arguments for creating filtered datasets
    parser.add_argument('--create_datasets', action='store_true',
                    help='Whether to create filtered datasets')
    parser.add_argument('--dataset_path', type=str, default=None,
                    help='Path to original dataset (required when using --create_datasets)')
    parser.add_argument('--datasets_output_dir', type=str, default='./filtered_datasets',
                    help='Output directory for filtered datasets')

    args = parser.parse_args()

    # Fail fast on missing inputs, before any loading starts.
    for input_file in (args.model1_file, args.model2_file):
        if not Path(input_file).exists():
            print(f"Error: File does not exist {input_file}", file=sys.stderr)
            sys.exit(1)

    # Fall back to the file stem when no explicit model name was given.
    model1_name = args.model1_name or Path(args.model1_file).stem
    model2_name = args.model2_name or Path(args.model2_file).stem

    try:
        print(f"Loading Model 1 data: {args.model1_file}")
        model1_ids = load_correct_ids(args.model1_file)

        print(f"Loading Model 2 data: {args.model2_file}")
        model2_ids = load_correct_ids(args.model2_file)

        analysis = analyze_differences(model1_ids, model2_ids, model1_name, model2_name)
        print_analysis_summary(analysis)

        if args.output_file:
            output_path = Path(args.output_file)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            # ensure_ascii=False may emit non-ASCII text, so pin the encoding
            # instead of relying on the platform default.
            with open(output_path, 'w', encoding='utf-8') as f:
                json.dump(analysis, f, indent=2, ensure_ascii=False)
            print(f"\nDetailed analysis results saved to: {output_path}")

        # Dataset creation is independent of --output_file: it must run
        # whenever --create_datasets is given, otherwise the
        # "Filtered datasets saved to" line below would be untrue.
        if args.create_datasets:
            create_filtered_datasets_from_results(
                args.model1_file,
                args.model2_file,
                args.dataset_path,
                args.datasets_output_dir,
                model1_name,
                model2_name
            )

        print("\n" + "=" * 60)
        print("Usage suggestions:")
        print("- model1_only_ids: Questions answered correctly only by Model 1, potentially indicating Model 1's area of strength")
        print("- model2_only_ids: Questions answered correctly only by Model 2, potentially indicating Model 2's area of strength")
        print("- symmetric_diff_ids: Questions where the two models show inconsistent performance, worthy of further analysis")
        print("- both_correct_ids: Questions that both models can handle well")

        if args.create_datasets:
            print(f"- Filtered datasets saved to: {args.datasets_output_dir}")

    except Exception as e:
        # Top-level boundary: report the failure and exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

def _load_original_dataset(dataset_path):
    """Load the dataset saved at dataset_path, or the default MMLU dataset as a fallback."""
    if dataset_path:
        try:
            return load_from_disk(dataset_path)
        except FileNotFoundError:
            # NOTE(review): load_from_disk is expected to raise FileNotFoundError
            # for a missing or non-dataset directory - confirm against the
            # installed `datasets` version.
            print(f"Warning: '{dataset_path}' is not a saved dataset directory; "
                  "falling back to 'cais/mmlu' (all)")
    return load_dataset("cais/mmlu", "all")


def filter_dataset_by_correct_ids(dataset_path: str, correct_ids: Set[int], 
                                output_path: str, model_name: str):
    """
    Filter dataset and save based on question IDs answered correctly by the model

    The dataset is loaded from dataset_path with load_from_disk. When no path
    is given, or the path does not hold a saved dataset, the 'cais/mmlu'
    ('all') dataset is loaded instead. If the loaded object has several
    splits, the 'test' split is used when present, otherwise the first split.

    IDs are treated as row indices into the chosen split; indices outside
    [0, len(split)) are ignored with a warning. The filtered rows are saved
    to output_path together with a 'filter_info.json' summary file.

    Args:
        dataset_path: Path to the original dataset
        correct_ids: Set of question IDs answered correctly by the model
        output_path: Path to save the filtered dataset
        model_name: Name of the model, used for log output

    Returns:
        Number of rows in the filtered dataset.

    Raises:
        Exception: Any error raised while loading, filtering or saving is
            printed and re-raised unchanged.
    """
    print(f"\nFiltering dataset for model '{model_name}'...")
    print(f"Original dataset path: {dataset_path}")
    print(f"Number of correctly answered questions: {len(correct_ids)}")

    try:
        dataset = _load_original_dataset(dataset_path)

        # A multi-split container exposes keys(); a single dataset does not.
        if hasattr(dataset, 'keys'):
            if 'test' in dataset:
                original_dataset = dataset['test']
                print(f"Using test split, original dataset size: {len(original_dataset)}")
            else:
                split_name = next(iter(dataset.keys()))
                original_dataset = dataset[split_name]
                print(f"Using '{split_name}' split, original dataset size: {len(original_dataset)}")
        else:
            original_dataset = dataset
            print(f"Original dataset size: {len(original_dataset)}")

        # Keep only valid row indices. Negative values are rejected as well:
        # they would otherwise silently index from the end of the dataset.
        total_rows = len(original_dataset)
        filtered_indices = sorted(i for i in correct_ids if 0 <= i < total_rows)
        if len(filtered_indices) != len(correct_ids):
            missing_count = len(correct_ids) - len(filtered_indices)
            print(f"Warning: {missing_count} indices are out of dataset range and will be ignored")

        # select() picks the rows by index in one call instead of
        # materialising every row in Python and re-inferring the schema.
        filtered_dataset = original_dataset.select(filtered_indices)

        print(f"Filtered dataset size: {len(filtered_dataset)}")

        target_dir = Path(output_path)
        target_dir.mkdir(parents=True, exist_ok=True)

        filtered_dataset.save_to_disk(str(target_dir))
        print(f"Filtered dataset saved to: {target_dir}")

        index_info = {
            "model_name": model_name,
            "original_dataset_path": dataset_path,
            "original_size": total_rows,
            "filtered_size": len(filtered_dataset),
            "correct_ids": filtered_indices,
            # Real UTC timestamp; this key previously stored the working directory.
            "creation_time": datetime.now(timezone.utc).isoformat(),
        }

        index_file = target_dir / "filter_info.json"
        # ensure_ascii=False may emit non-ASCII text, so pin the encoding.
        with open(index_file, 'w', encoding='utf-8') as f:
            json.dump(index_info, f, indent=2, ensure_ascii=False)
        print(f"Filter information saved to: {index_file}")

        return len(filtered_dataset)

    except Exception as e:
        print(f"Error filtering dataset: {e}")
        raise

def create_filtered_datasets_from_results(model1_file: str, model2_file: str,
                                        dataset_path: str, output_dir: str,
                                        model1_name: str = None, model2_name: str = None):
    """
    Create filtered datasets based on results from two models
    Creates two datasets: questions answered correctly only by model 1 and 
    questions answered correctly only by model 2

    A model with no exclusively correct questions is skipped. When both
    models end up with the same name, '_model1' / '_model2' is appended so
    the two datasets are not written to the same directory.

    Args:
        model1_file: File path to model 1's correct_ids.json
        model2_file: File path to model 2's correct_ids.json
        dataset_path: Path to the original dataset
        output_dir: Output directory
        model1_name: Name of model 1 (derived from the file name when empty)
        model2_name: Name of model 2 (derived from the file name when empty)

    Returns:
        A 2-tuple (model1_dataset_path, model2_dataset_path); an entry is
        None when that model's dataset was skipped.
    """
    if not model1_name:
        model1_name = Path(model1_file).stem.replace('_correct_ids', '')
    if not model2_name:
        model2_name = Path(model2_file).stem.replace('_correct_ids', '')

    # Identical names (e.g. two files both called correct_ids.json in
    # different directories) would map both models to the same output
    # directory, so the second dataset would overwrite the first.
    if model1_name == model2_name:
        print(f"Warning: both models are named '{model1_name}'; "
              "appending _model1/_model2 to keep their outputs separate")
        model1_name, model2_name = f"{model1_name}_model1", f"{model2_name}_model2"

    print("=" * 60)
    print("Creating filtered datasets (only retaining questions correctly answered exclusively by each model)")
    print("=" * 60)

    model1_ids = load_correct_ids(model1_file)
    model2_ids = load_correct_ids(model2_file)

    model1_only = model1_ids - model2_ids
    model2_only = model2_ids - model1_ids

    print(f"Total questions correctly answered by Model 1: {len(model1_ids)}")
    print(f"Total questions correctly answered by Model 2: {len(model2_ids)}")
    print(f"Questions correctly answered only by Model 1: {len(model1_only)}")
    print(f"Questions correctly answered only by Model 2: {len(model2_only)}")

    output_dir = Path(output_dir)

    # Both models go through the same steps; only label, name and IDs differ.
    jobs = (
        ("Model 1", model1_name, model1_only),
        ("Model 2", model2_name, model2_only),
    )
    created_paths = []
    for label, name, exclusive_ids in jobs:
        if not exclusive_ids:
            print(f"⚠️  {label} has no exclusively correct questions, skipping dataset creation")
            created_paths.append(None)
            continue
        target = output_dir / f"{name}_only_correct_dataset"
        filter_dataset_by_correct_ids(dataset_path, exclusive_ids, str(target), f"{name}_only")
        print(f"✅ Dataset with questions correctly answered exclusively by {label} created: {target}")
        created_paths.append(str(target))

    print(f"\nCompleted! Filtered datasets saved to {output_dir}")
    print("=" * 60)
    print("Dataset descriptions:")
    print(f"- {model1_name}_only_correct_dataset: Questions answered correctly by {model1_name} but incorrectly by {model2_name}")
    print(f"- {model2_name}_only_correct_dataset: Questions answered correctly by {model2_name} but incorrectly by {model1_name}")
    print("=" * 60)

    return tuple(created_paths)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
