#!/usr/bin/env python3
"""
Script to balance ST and ST-HC datasets by downsampling the larger dataset.

As mentioned in the paper, we align the datasets to the same number of samples 
by downsizing the larger dataset, randomly removing questions not present in the 
smaller dataset and ensuring both datasets used for finetuning contain an equal 
number of samples.
"""

import argparse
import glob
import hashlib
import json
import logging
import os
import random
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

from tqdm import tqdm

# Root logger configuration: INFO and above, each line prefixed with a
# timestamp and the record's level name.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# File names that hold a whole dataset as a single JSON list, in lookup order.
_COMBINED_SAMPLE_FILES = ("samples.json", "st_hc_dataset.json")


def _load_sample_list(path: str) -> Optional[List[Dict[str, Any]]]:
    """
    Read a file that holds a list of samples.

    Files ending in ``.jsonl`` are parsed as JSON Lines (one sample per
    non-blank line). Any other file is parsed as one JSON document, which
    must be a list.

    Args:
        path: File to read

    Returns:
        The list of samples, or ``None`` (after logging why) if the file
        cannot be read or parsed, or does not contain a list.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            if path.lower().endswith(".jsonl"):
                data = [json.loads(line) for line in f if line.strip()]
            else:
                data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logging.error(f"Error loading {path}: {e}")
        return None
    if not isinstance(data, list):
        logging.error(f"Error loading {path}: expected a JSON list, got {type(data).__name__}")
        return None
    logging.info(f"Loaded {len(data)} samples from {path}")
    return data


def _is_individual_sample_file(path: str) -> bool:
    """Return True if *path* should be read as a one-sample JSON file."""
    name = os.path.basename(path)
    return not (
        name.startswith("metadata")
        or name.endswith("_metadata.json")  # naming used by save_dataset
        or name in _COMBINED_SAMPLE_FILES
    )


def load_dataset(dataset_path: str) -> List[Dict[str, Any]]:
    """
    Load a dataset from a directory or file.

    For a directory the lookup order is:
      1. ``samples.json`` (a JSON list of samples)
      2. ``st_hc_dataset.json`` (a JSON list of samples)
      3. every other ``*.json`` file in sorted name order, each expected to
         hold one sample object. Metadata files are skipped, and non-object
         contents are ignored.
    Any other path is read as a single JSON list, or as JSON Lines when the
    name ends in ``.jsonl``.

    Errors are logged rather than raised, so a broken file yields fewer (or
    no) samples instead of aborting the run.

    Args:
        dataset_path: Path to the dataset directory or file

    Returns:
        List of dataset samples (empty if nothing could be loaded)
    """
    if os.path.isdir(dataset_path):
        for combined_name in _COMBINED_SAMPLE_FILES:
            combined_file = os.path.join(dataset_path, combined_name)
            if os.path.exists(combined_file):
                samples = _load_sample_list(combined_file)
                if samples is not None:
                    return samples

        # Sort the files so sample order does not depend on directory-listing
        # order. That order decides which duplicate question survives and what
        # the seeded balancing selects downstream.
        sample_files = sorted(
            p for p in glob.glob(os.path.join(dataset_path, "*.json"))
            if _is_individual_sample_file(p)
        )

        if sample_files:
            samples = []
            for file_path in tqdm(sample_files, desc=f"Loading samples from {dataset_path}"):
                try:
                    with open(file_path, 'r', encoding='utf-8') as f:
                        sample = json.load(f)
                except (OSError, ValueError) as e:
                    logging.error(f"Error loading {file_path}: {e}")
                    continue
                # Only JSON objects count as samples.
                if isinstance(sample, dict):
                    samples.append(sample)

            logging.info(f"Loaded {len(samples)} samples from {len(sample_files)} individual files")
            return samples
    else:
        samples = _load_sample_list(dataset_path)
        if samples is not None:
            return samples

    logging.error(f"Failed to load any samples from {dataset_path}")
    return []

def get_question_hash(sample: Dict[str, Any]) -> str:
    """
    Return a stable key identifying the question in a sample.

    The question is lower-cased and stripped of all whitespace, so that
    formatting differences do not split one question into two. The entire
    normalized text is then hashed with SHA-256. Hashing the full text,
    rather than keeping only a prefix, stops questions that share a long
    common preamble from collapsing into a single key.

    Args:
        sample: Dataset sample whose ``"question"`` field is used. A missing
            or ``None`` question is treated as the empty string, so all such
            samples share one key. Non-string values are converted with
            ``str()``.

    Returns:
        64-character hexadecimal SHA-256 digest of the normalized question.
    """
    question = sample.get("question")
    if question is None:
        question = ""
    elif not isinstance(question, str):
        question = str(question)

    # Simple normalization: lowercase, remove all whitespace.
    normalized = ''.join(question.lower().split())

    # surrogatepass: JSON text may decode to lone surrogates, which strict
    # UTF-8 encoding would reject.
    return hashlib.sha256(normalized.encode("utf-8", "surrogatepass")).hexdigest()

def save_dataset(samples: List[Dict[str, Any]], output_path: str, dataset_name: str):
    """
    Save a balanced dataset as JSON, JSON Lines, and a metadata file.

    Writes three files into ``output_path``:
      - ``{dataset_name}_balanced.json``: all samples as one indented list
      - ``{dataset_name}_balanced.jsonl``: one sample per line
      - ``{dataset_name}_metadata.json``: name, description, sample count,
        UTC creation timestamp, and the name of the generating script

    Each write is best-effort: a failure is logged and the remaining files
    are still attempted.

    Args:
        samples: List of dataset samples
        output_path: Directory to save the dataset (created if missing)
        dataset_name: Name of the dataset (ST or ST-HC)
    """
    os.makedirs(output_path, exist_ok=True)

    # Save combined file
    combined_path = os.path.join(output_path, f"{dataset_name}_balanced.json")
    try:
        with open(combined_path, 'w', encoding='utf-8') as f:
            json.dump(samples, f, indent=2)
        logging.info(f"Saved {len(samples)} balanced samples to {combined_path}")
    except (OSError, TypeError, ValueError) as e:
        # TypeError/ValueError: non-serializable or circular sample content.
        logging.error(f"Error saving to {combined_path}: {e}")

    # Save JSONL version
    jsonl_path = os.path.join(output_path, f"{dataset_name}_balanced.jsonl")
    try:
        with open(jsonl_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(sample) + "\n" for sample in samples)
        logging.info(f"Saved {len(samples)} balanced samples to {jsonl_path}")
    except (OSError, TypeError, ValueError) as e:
        logging.error(f"Error saving to {jsonl_path}: {e}")

    metadata = {
        "dataset_name": dataset_name,
        "description": f"Balanced {dataset_name} dataset",
        "num_samples": len(samples),
        # An actual timezone-aware timestamp; the script that produced the
        # files is recorded separately under "created_by".
        "date_created": datetime.now(timezone.utc).isoformat(),
        "created_by": os.path.basename(__file__),
        "original_dataset": dataset_name,
    }

    metadata_path = os.path.join(output_path, f"{dataset_name}_metadata.json")
    try:
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=2)
        logging.info(f"Saved metadata to {metadata_path}")
    except (OSError, TypeError, ValueError) as e:
        logging.error(f"Error saving metadata to {metadata_path}: {e}")

def balance_datasets(st_samples: List[Dict[str, Any]],
                     st_hc_samples: List[Dict[str, Any]],
                     output_dir: str,
                     seed: int = 42) -> tuple:
    """
    Balance the ST and ST-HC datasets by downsampling the larger one.

    The smaller dataset is kept unchanged. The larger dataset is reduced to
    one sample per distinct question (the last occurrence wins), and every
    question it shares with the smaller dataset is kept. The result is then
    topped up with a seeded random selection of its remaining questions until
    it matches the smaller dataset's size, or until those questions run out.
    Both results are written with ``save_dataset`` in every case.

    Args:
        st_samples: Samples from the ST dataset
        st_hc_samples: Samples from the ST-HC dataset
        output_dir: Directory to save the balanced datasets
        seed: Seed for the random selection and shuffle. The same inputs in
            the same order always give the same output.

    Returns:
        Tuple of (balanced_st_samples, balanced_st_hc_samples)
    """
    # Ties go to ST being the "smaller" dataset.
    if len(st_samples) <= len(st_hc_samples):
        smaller_dataset, smaller_samples = "ST", st_samples
        larger_dataset, larger_samples = "ST-HC", st_hc_samples
    else:
        smaller_dataset, smaller_samples = "ST-HC", st_hc_samples
        larger_dataset, larger_samples = "ST", st_samples

    logging.info(f"Smaller dataset: {smaller_dataset} with {len(smaller_samples)} samples")
    logging.info(f"Larger dataset: {larger_dataset} with {len(larger_samples)} samples")

    smaller_question_hashes = {get_question_hash(sample) for sample in smaller_samples}

    # One sample per distinct question of the larger dataset. A dict keeps
    # keys in first-seen order (a later duplicate only replaces the value), so
    # the lists below follow input order. Iterating a set of str keys would
    # not be reproducible, because str hashing is randomized per process.
    larger_by_hash: Dict[str, Dict[str, Any]] = {}
    for sample in larger_samples:
        larger_by_hash[get_question_hash(sample)] = sample

    common_samples = [s for h, s in larger_by_hash.items() if h in smaller_question_hashes]
    larger_unique_samples = [s for h, s in larger_by_hash.items() if h not in smaller_question_hashes]

    logging.info(f"Found {len(common_samples)} questions common to both datasets")
    logging.info(f"Found {len(larger_unique_samples)} questions unique to the {larger_dataset} dataset")

    # How many of the larger dataset's unique questions are needed to match sizes.
    num_to_keep = len(smaller_samples) - len(common_samples)

    if num_to_keep <= 0:
        logging.warning(f"Number of common samples ({len(common_samples)}) exceeds or equals smaller dataset size ({len(smaller_samples)})")
        logging.warning("Using only common samples for both datasets")
        balanced_larger = common_samples[:len(smaller_samples)]
    else:
        if num_to_keep > len(larger_unique_samples):
            logging.warning("Not enough unique samples in larger dataset to reach target size")
            logging.warning(f"Keeping all {len(larger_unique_samples)} unique samples")
            num_to_keep = len(larger_unique_samples)

        # A private generator gives the same sequence as random.seed(seed)
        # without resetting the process-wide random state.
        rng = random.Random(seed)
        selected_unique_samples = rng.sample(larger_unique_samples, num_to_keep)

        balanced_larger = common_samples + selected_unique_samples
        rng.shuffle(balanced_larger)  # interleave common and selected samples

    logging.info(f"Balanced {larger_dataset} dataset now has {len(balanced_larger)} samples")
    logging.info(f"Balanced {smaller_dataset} dataset has {len(smaller_samples)} samples")

    if smaller_dataset == "ST":
        balanced_st, balanced_st_hc = smaller_samples, balanced_larger
    else:
        balanced_st, balanced_st_hc = balanced_larger, smaller_samples

    # Save on every path, including the "only common samples" case.
    save_dataset(balanced_st, output_dir, "ST")
    save_dataset(balanced_st_hc, output_dir, "ST-HC")
    return balanced_st, balanced_st_hc

def main(argv: Optional[List[str]] = None) -> int:
    """
    Command-line entry point: load both datasets, balance them, and save them.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps the normal command-line behaviour.

    Returns:
        Exit status: 0 on success, 1 if either dataset could not be loaded.
    """
    parser = argparse.ArgumentParser(description="Balance ST and ST-HC datasets")
    parser.add_argument("--st_path", type=str, required=True,
                        help="Path to the ST dataset")
    parser.add_argument("--st_hc_path", type=str, required=True,
                        help="Path to the ST-HC dataset")
    parser.add_argument("--output_dir", type=str, default="datasets/balanced",
                        help="Directory to save the balanced datasets")

    args = parser.parse_args(argv)

    st_samples = load_dataset(args.st_path)
    st_hc_samples = load_dataset(args.st_hc_path)

    if not st_samples or not st_hc_samples:
        logging.error("Failed to load one or both datasets. Aborting.")
        return 1

    # Create the output directory only once there is something to write.
    os.makedirs(args.output_dir, exist_ok=True)

    balanced_st, balanced_st_hc = balance_datasets(st_samples, st_hc_samples, args.output_dir)

    logging.info("Dataset balancing complete!")
    logging.info(f"Balanced ST dataset: {len(balanced_st)} samples")
    logging.info(f"Balanced ST-HC dataset: {len(balanced_st_hc)} samples")
    logging.info(f"Datasets saved to {args.output_dir}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status, so a failed load
    # shows up to shells and pipelines as exit code 1 instead of 0.
    raise SystemExit(main())