#!/usr/bin/env python3
"""
Script for decontaminating seed datasets by removing questions that overlap with evaluation datasets.

As mentioned in the paper, we perform decontamination to ensure that none of the
evaluation questions are present in the seed samples used for training.
"""

import argparse
import datetime
import glob
import json
import logging
import os
import re
import sys
from typing import Any, Dict, List, Set, Tuple

from tqdm import tqdm

# Configure the root logger at import time: INFO and above, timestamped records
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

def load_jsonl_file(file_path: str) -> List[Dict[str, Any]]:
    """
    Load data from a JSONL file (one JSON object per line).
    
    Blank lines are ignored silently. Lines that are not valid JSON, or whose
    top-level value is not a JSON object, are skipped with a warning that
    includes the line number.
    
    Args:
        file_path: Path to the JSONL file
        
    Returns:
        List of dictionaries, each representing a sample. If the file cannot
        be opened or read, the error is logged and the samples parsed so far
        (possibly none) are returned.
    """
    samples = []
    try:
        # utf-8-sig decodes plain UTF-8 and also strips a leading BOM, so the
        # result does not depend on the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            for line_number, line in enumerate(f, start=1):
                stripped = line.strip()
                if not stripped:
                    # Blank separator/trailing lines are not malformed records
                    continue
                try:
                    sample = json.loads(stripped)
                except json.JSONDecodeError:
                    logging.warning(
                        f"Skipping invalid JSON line {line_number} in {file_path}")
                    continue
                if not isinstance(sample, dict):
                    # Callers look up keys on every sample, so scalars/arrays are unusable
                    logging.warning(
                        f"Skipping non-object JSON line {line_number} in {file_path}")
                    continue
                samples.append(sample)
    except Exception as e:
        # Best-effort loader: report the failure and return what was read
        logging.error(f"Error loading {file_path}: {e}")
    
    return samples

def load_json_file(file_path: str) -> List[Dict[str, Any]]:
    """
    Load data from a JSON file.
    
    Supported top-level layouts:
      * a list: returned as-is;
      * a single sample object (has a "question", "problem" or "prompt" key):
        wrapped in a one-element list;
      * an object with a "data" key holding a list: that list is returned.
    
    Anything else is reported with a warning and yields an empty list.
    
    Args:
        file_path: Path to the JSON file
        
    Returns:
        List of samples; empty if the file cannot be read, is not valid JSON,
        or has an unrecognized layout
    """
    try:
        # utf-8-sig decodes plain UTF-8 and also tolerates a leading BOM
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            data = json.load(f)
    except Exception as e:
        # Best-effort loader: report the failure and carry on with no samples
        logging.error(f"Error loading {file_path}: {e}")
        return []
    
    if isinstance(data, list):
        return data
    
    if not isinstance(data, dict):
        logging.warning(f"Unrecognized data type in {file_path}")
        return []
    
    # If it's a single sample, wrap it in a list
    if "question" in data:
        return [data]
    
    # It might be a HuggingFace dataset format
    if "data" in data:
        records = data["data"]
        if isinstance(records, list):
            return records
        # A non-list payload would corrupt the caller's samples.extend(...)
        logging.warning(f"Unrecognized JSON format in {file_path}")
        return []
    
    # Single samples may also use the alternative question keys that the rest
    # of this script recognizes
    if "problem" in data or "prompt" in data:
        return [data]
    
    logging.warning(f"Unrecognized JSON format in {file_path}")
    return []

def load_samples_from_directory(dir_path: str) -> List[Dict[str, Any]]:
    """
    Load all samples from a directory containing JSON/JSONL files.
    
    A path to a single .json/.jsonl file is also accepted and loaded directly.
    Inside a directory, only top-level ``*.jsonl`` and ``*.json`` files are
    read (no recursion), each group in sorted order so the resulting sample
    order is reproducible across runs and filesystems. JSON files whose name
    starts with "metadata" or ends with "_metadata.json" (the naming used by
    this script's own metadata output) are skipped.
    
    Args:
        dir_path: Path to the directory (or to a single .json/.jsonl file)
        
    Returns:
        List of all samples from the path; empty if the path does not exist
        or is a file with an unsupported extension
    """
    # Single file - determine type and load
    if os.path.isfile(dir_path):
        if dir_path.endswith('.jsonl'):
            return load_jsonl_file(dir_path)
        if dir_path.endswith('.json'):
            return load_json_file(dir_path)
        logging.warning(f"Unsupported file type: {dir_path}")
        return []
    
    if not os.path.isdir(dir_path):
        logging.error(f"Path does not exist: {dir_path}")
        return []
    
    samples = []
    
    # JSONL files first; sorted because glob order follows the filesystem
    jsonl_files = sorted(glob.glob(os.path.join(dir_path, "*.jsonl")))
    for file_path in tqdm(jsonl_files, desc=f"Loading JSONL files from {dir_path}"):
        samples.extend(load_jsonl_file(file_path))
    
    # Then JSON files, leaving out metadata files (they hold no samples)
    json_files = []
    for file_path in sorted(glob.glob(os.path.join(dir_path, "*.json"))):
        file_name = os.path.basename(file_path)
        if file_name.startswith("metadata") or file_name.endswith("_metadata.json"):
            continue
        json_files.append(file_path)
    
    for file_path in tqdm(json_files, desc=f"Loading JSON files from {dir_path}"):
        samples.extend(load_json_file(file_path))
    
    logging.info(f"Loaded {len(samples)} total samples from {dir_path}")
    return samples

# Runs of characters that are neither letters nor digits: whitespace,
# punctuation, and the underscore (which a bare ``\w`` would let through).
_NON_ALNUM_RE = re.compile(r'[\W_]+')


def normalize_question(question: str, max_length: int = 200) -> str:
    """
    Normalize a question text for comparison.
    
    The text is lowercased and every character that is not a letter or digit
    (spaces, punctuation, underscores) is removed. The result is truncated to
    ``max_length`` characters, so two questions sharing the same normalized
    prefix produce the same fingerprint.
    
    Args:
        question: The original question text
        max_length: Maximum length of the returned fingerprint (default 200)
        
    Returns:
        Normalized question text
    """
    normalized = _NON_ALNUM_RE.sub('', question.lower())
    return normalized[:max_length]  # Use the leading characters as a fingerprint

def extract_questions(samples: List[Dict[str, Any]]) -> Set[str]:
    """
    Extract and normalize questions from a list of samples.
    
    The question text is taken from the first key present among "question",
    "problem" and "prompt". Samples that are not dictionaries, have none of
    those keys, hold a non-string value, or whose text is 10 characters or
    shorter are ignored. Texts that normalize to an empty fingerprint are
    ignored as well, because an empty fingerprint identifies no question.
    
    Args:
        samples: List of sample dictionaries
        
    Returns:
        Set of normalized question texts
    """
    question_keys = ("question", "problem", "prompt")
    questions = set()
    for sample in samples:
        if not isinstance(sample, dict):
            continue
        
        # First matching key wins, mirroring the lookup order used for seed samples
        question_text = next(
            (sample[key] for key in question_keys if key in sample), None)
        
        if not isinstance(question_text, str) or len(question_text) <= 10:
            continue
        
        fingerprint = normalize_question(question_text)
        if fingerprint:
            questions.add(fingerprint)
    
    return questions

def decontaminate_dataset(seed_samples: List[Dict[str, Any]], 
                         eval_questions: Set[str]) -> Tuple[List[Dict[str, Any]], int]:
    """
    Remove samples from seed dataset that overlap with evaluation questions.
    
    The question text is taken from the first key present among "question",
    "problem" and "prompt", normalized, and looked up in ``eval_questions``.
    Samples that cannot be checked (not a dictionary, none of the keys
    present, non-string question value, or an empty normalized text) are kept.
    
    Args:
        seed_samples: List of samples from the seed dataset
        eval_questions: Set of normalized evaluation questions
        
    Returns:
        Tuple of (decontaminated list of seed samples in their original order,
        number of samples removed)
    """
    question_keys = ("question", "problem", "prompt")
    decontaminated_samples = []
    contaminated_count = 0
    
    for sample in tqdm(seed_samples, desc="Decontaminating dataset"):
        question_text = None
        if isinstance(sample, dict):
            question_text = next(
                (sample[key] for key in question_keys if key in sample), None)
        
        if not isinstance(question_text, str):
            # Keep samples where we can't find a textual question field
            decontaminated_samples.append(sample)
            continue
        
        normalized = normalize_question(question_text)
        
        # An empty fingerprint carries no information, so it never counts as a match
        if normalized and normalized in eval_questions:
            contaminated_count += 1
        else:
            decontaminated_samples.append(sample)
    
    return decontaminated_samples, contaminated_count

def save_dataset(samples: List[Dict[str, Any]], output_path: str, dataset_name: str) -> None:
    """
    Save the decontaminated dataset.
    
    Writes three files into ``output_path`` (created if missing):
      * ``<dataset_name>_decontaminated.json``  - all samples as one JSON array
      * ``<dataset_name>_decontaminated.jsonl`` - one sample per line
      * ``<dataset_name>_metadata.json``        - name, description, sample
        count, creation timestamp (UTC, ISO 8601) and original dataset name
    
    A failure to write any one file is logged and does not prevent the
    remaining files from being attempted.
    
    Args:
        samples: List of dataset samples
        output_path: Directory to save the dataset
        dataset_name: Name of the dataset for file naming
    """
    os.makedirs(output_path, exist_ok=True)
    
    # Save as JSON
    json_path = os.path.join(output_path, f"{dataset_name}_decontaminated.json")
    try:
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(samples, f, indent=2)
        logging.info(f"Saved {len(samples)} decontaminated samples to {json_path}")
    except Exception as e:
        logging.error(f"Error saving to {json_path}: {e}")
    
    # Save as JSONL (more efficient for large datasets)
    jsonl_path = os.path.join(output_path, f"{dataset_name}_decontaminated.jsonl")
    try:
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(sample) + "\n" for sample in samples)
        logging.info(f"Saved {len(samples)} decontaminated samples to {jsonl_path}")
    except Exception as e:
        logging.error(f"Error saving to {jsonl_path}: {e}")
    
    # Create a metadata file
    metadata = {
        "dataset_name": f"{dataset_name}_decontaminated",
        "description": f"Decontaminated {dataset_name} dataset",
        "num_samples": len(samples),
        # Timezone-aware timestamp of when this output was produced
        "date_created": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "original_dataset": dataset_name
    }
    
    metadata_path = os.path.join(output_path, f"{dataset_name}_metadata.json")
    try:
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        logging.info(f"Saved metadata to {metadata_path}")
    except Exception as e:
        logging.error(f"Error saving metadata to {metadata_path}: {e}")

def _dataset_name_from_path(seed_path: str) -> str:
    """Derive an output name from a seed path: its last component minus a .json/.jsonl suffix."""
    # abspath drops trailing separators and resolves ".", so "data/seed/" yields "seed"
    name = os.path.basename(os.path.abspath(seed_path))
    # Strip only a real trailing extension; chained str.replace would turn
    # "foo.jsonl" into "fool" and also touch ".json" in the middle of a name
    stem, ext = os.path.splitext(name)
    if stem and ext in (".json", ".jsonl"):
        name = stem
    return name or "dataset"


def main():
    """
    Main function for decontaminating datasets.
    
    Parses the command line, builds the set of normalized evaluation
    questions from --eval_paths, then filters each --seed_paths entry against
    that set and saves the result under --output_dir. Aborts (after logging an
    error) when no evaluation samples or no evaluation questions are
    available; seed paths that yield no samples are skipped with a warning.
    """
    parser = argparse.ArgumentParser(description="Decontaminate seed datasets by removing evaluation questions")
    parser.add_argument("--seed_paths", type=str, nargs='+', required=True,
                      help="Paths to seed dataset directories or files")
    parser.add_argument("--eval_paths", type=str, nargs='+', required=True,
                      help="Paths to evaluation dataset directories or files")
    parser.add_argument("--output_dir", type=str, default="datasets/decontaminated",
                      help="Directory to save decontaminated datasets")
    
    args = parser.parse_args()
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Load evaluation datasets and extract questions
    eval_samples = []
    for eval_path in args.eval_paths:
        eval_samples.extend(load_samples_from_directory(eval_path))
    
    if not eval_samples:
        logging.error("No evaluation samples loaded. Aborting.")
        return
    
    eval_questions = extract_questions(eval_samples)
    logging.info(f"Extracted {len(eval_questions)} unique evaluation questions")
    
    if not eval_questions:
        # With nothing to compare against, every seed sample would pass and
        # the output would be labelled "decontaminated" without any check
        logging.error("No evaluation questions extracted. Aborting.")
        return
    
    # Process each seed dataset
    for seed_path in args.seed_paths:
        dataset_name = _dataset_name_from_path(seed_path)
        
        seed_samples = load_samples_from_directory(seed_path)
        if not seed_samples:
            logging.warning(f"No samples loaded from {seed_path}. Skipping.")
            continue
        
        logging.info(f"Processing {dataset_name} with {len(seed_samples)} samples")
        
        # Decontaminate dataset
        decontaminated_samples, contaminated_count = decontaminate_dataset(
            seed_samples, eval_questions)
        
        logging.info(f"Removed {contaminated_count} contaminated samples from {dataset_name}")
        logging.info(f"Remaining samples: {len(decontaminated_samples)}")
        
        # Save decontaminated dataset
        save_dataset(decontaminated_samples, args.output_dir, dataset_name)
    
    logging.info("Decontamination complete!")

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module
if __name__ == "__main__":
    main() 