#!/usr/bin/env python3
"""
Create datasets of the questions that only one of two models answered correctly,
based on each model's correct_ids.json file.
"""

import argparse
import json
import sys
from pathlib import Path


from analyze_model_differences import load_correct_ids, filter_dataset_by_correct_ids

def create_exclusive_datasets_simple(model1_file: str, model2_file: str,
                                     output_dir: str, model1_name: str = None,
                                     model2_name: str = None,
                                     dataset_name: str = "cais/mmlu"):
    """Build datasets of the questions that only one of two models got right.

    Loads the correct-answer IDs of both models, writes a summary of the
    overlap to ``<output_dir>/exclusive_analysis.json`` and, for each model
    that has at least one exclusively-correct question, asks
    ``filter_dataset_by_correct_ids`` to create
    ``<output_dir>/<model_name>_only_dataset``.

    Args:
        model1_file: Path to model 1's correct-IDs JSON file.
        model2_file: Path to model 2's correct-IDs JSON file.
        output_dir: Directory for all outputs; created if missing.
        model1_name: Display/output name for model 1. When empty or None it is
            derived from the file stem with ``_correct_ids`` removed.
        model2_name: Same as ``model1_name``, for model 2.
        dataset_name: Source dataset identifier forwarded to
            ``filter_dataset_by_correct_ids``.

    Returns:
        A list of ``(dataset_path, count, name)`` tuples, one per dataset
        created, where ``count`` is whatever ``filter_dataset_by_correct_ids``
        returned (presumably the number of questions written — confirm
        against that helper).
    """
    if not model1_name:
        model1_name = Path(model1_file).stem.replace('_correct_ids', '')
    if not model2_name:
        model2_name = Path(model2_file).stem.replace('_correct_ids', '')

    # Two inputs that are both called e.g. "correct_ids.json" (in different
    # directories) derive the same name, which would make both datasets share
    # one output path. Disambiguate instead of letting one clobber the other.
    if model1_name == model2_name:
        print(f"Warning: both models resolve to the name '{model1_name}'; "
              "appending '_model1' / '_model2' to keep their outputs separate")
        model1_name, model2_name = f"{model1_name}_model1", f"{model2_name}_model2"

    print("Creating unique datasets...")
    print(f"Model 1: {model1_name}")
    print(f"Model 2: {model2_name}")
    print(f"Output directory: {output_dir}")
    print("-" * 50)

    # The helper must return set-like objects: set difference and
    # intersection are applied to its results below.
    model1_ids = load_correct_ids(model1_file)
    model2_ids = load_correct_ids(model2_file)

    model1_only = model1_ids - model2_ids
    model2_only = model2_ids - model1_ids
    both_correct = model1_ids & model2_ids

    print(f"Total correct by Model 1: {len(model1_ids)} questions")
    print(f"Total correct by Model 2: {len(model2_ids)} questions")
    print(f"Correct by both models: {len(both_correct)} questions")
    print(f"Correct only by Model 1: {len(model1_only)} questions")
    print(f"Correct only by Model 2: {len(model2_only)} questions")
    print()

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    exclusive_data = {
        "model1_name": model1_name,
        "model2_name": model2_name,
        "model1_total_correct": len(model1_ids),
        "model2_total_correct": len(model2_ids),
        "both_correct_count": len(both_correct),
        "model1_only_count": len(model1_only),
        "model2_only_count": len(model2_only),
        "model1_only_ids": sorted(model1_only),
        "model2_only_ids": sorted(model2_only),
        "both_correct_ids": sorted(both_correct),
    }

    analysis_path = output_path / "exclusive_analysis.json"
    # ensure_ascii=False emits raw non-ASCII characters, so the encoding must
    # be pinned rather than left to the platform's locale default.
    with open(analysis_path, 'w', encoding='utf-8') as f:
        json.dump(exclusive_data, f, indent=2, ensure_ascii=False)
    print(f"✅ Exclusive question analysis saved to: {analysis_path}")

    datasets_created = []
    exclusive_sets = (
        ("Model 1", model1_name, model1_only),
        ("Model 2", model2_name, model2_only),
    )
    for label, name, only_ids in exclusive_sets:
        if not only_ids:
            # Nothing exclusive to this model, so no dataset is created.
            continue
        print(f"\nCreating dataset unique to {label} ({len(only_ids)} questions)...")
        dataset_path = output_path / f"{name}_only_dataset"
        count = filter_dataset_by_correct_ids(
            dataset_name, only_ids, str(dataset_path), f"{name}_only"
        )
        datasets_created.append((dataset_path, count, f"{name}_only"))
        print(f"✅ finished: {dataset_path}")

    print("\n" + "=" * 60)
    print("Creation completed!")
    print("=" * 60)
    print(f"Output directory: {output_path}")
    print("Analysis file: exclusive_analysis.json")
    print("\nCreated datasets:")
    for path, count, name in datasets_created:
        print(f"  - {name}: {path} ({count} questions)")

    if not datasets_created:
        print("  (No exclusive questions, no datasets created)")

    print("\nDataset uses:")
    print("- Analyze model strengths and weaknesses")
    print("- Conduct targeted training")
    print("- Research model complementarity")
    print("=" * 60)

    return datasets_created

def main():
    """Command-line entry point: parse arguments and create the datasets.

    Requires ``--model1_file``, ``--model2_file`` and ``--output_dir``;
    ``--model1_name`` and ``--model2_name`` are optional.

    Exits with status 1, after printing an ``Error: ...`` line to stderr,
    when an input path is not an existing regular file or when dataset
    creation raises.
    """
    parser = argparse.ArgumentParser(description='Quickly create datasets with questions answered correctly only by each respective model')
    parser.add_argument('--model1_file', type=str, required=True,
                        help="File path to model 1's correct_ids.json")
    parser.add_argument('--model2_file', type=str, required=True,
                        help="File path to model 2's correct_ids.json")
    parser.add_argument('--output_dir', type=str, required=True,
                        help='Output directory path')
    parser.add_argument('--model1_name', type=str, default=None,
                        help='Name of model 1 (optional)')
    parser.add_argument('--model2_name', type=str, default=None,
                        help='Name of model 2 (optional)')

    args = parser.parse_args()

    # Validate inputs up front. is_file() (rather than exists()) also rejects
    # a directory passed by mistake instead of failing later while loading.
    for file_path in (args.model1_file, args.model2_file):
        if not Path(file_path).is_file():
            # Diagnostics go to stderr so stdout stays clean for piping.
            print(f"Error: File does not exist {file_path}", file=sys.stderr)
            sys.exit(1)

    try:
        create_exclusive_datasets_simple(
            args.model1_file,
            args.model2_file,
            args.output_dir,
            args.model1_name,
            args.model2_name
        )
    except Exception as e:
        # Top-level boundary: report any failure and signal it via exit code.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
