#!/usr/bin/env python3
"""
Generate ST-NT dataset (No Traces) as described in the paper.

This script creates a version of the ST dataset that contains only question-answer
pairs without any reasoning traces. This allows for evaluating the impact of
reasoning traces on model performance.
"""

import os
import json
import argparse
import logging
import time
import glob
from tqdm import tqdm

# Configure the root logger at import time: records at INFO level and above go
# both to the console (StreamHandler's default stream) and to a log file whose
# name carries the start timestamp. The file path is relative, so the log file
# is created in the current working directory.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(f"st_nt_generation_{time.strftime('%Y%m%d_%H%M%S')}.log")
    ]
)

def load_st_samples(input_dir):
    """
    Load samples from ST dataset directory.

    Individual sample files (``*.json`` files whose name starts with ``sthc_``
    or ``st_``) take priority and are loaded in sorted filename order so the
    resulting sample order is reproducible. If no such files exist, the
    combined ``samples.json`` file is used instead; it must contain a JSON
    array.

    Args:
        input_dir: Directory containing ST dataset files

    Returns:
        List of samples from the ST dataset. An empty list is returned when no
        sample files are found or the combined file cannot be loaded.
        Individual files that fail to load are logged and skipped.
    """
    # Escape the directory so characters such as "[" or "*" in its name are
    # taken literally, and sort because glob returns matches in arbitrary,
    # filesystem-dependent order.
    pattern = os.path.join(glob.escape(input_dir), "*.json")
    sample_files = sorted(
        path
        for path in glob.glob(pattern)
        if os.path.basename(path).startswith(("sthc_", "st_"))
    )

    if not sample_files:
        # Try to load from samples.json if individual files not found
        combined_file = os.path.join(input_dir, "samples.json")
        if not os.path.exists(combined_file):
            logging.error(f"No sample files found in {input_dir}")
            return []

        try:
            # JSON is UTF-8 by specification; do not depend on the locale.
            with open(combined_file, 'r', encoding="utf-8") as f:
                samples = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.error(f"Error loading combined file {combined_file}: {str(e)}")
            return []

        if not isinstance(samples, list):
            # Anything but a list of samples would break downstream iteration.
            logging.error(
                f"Error loading combined file {combined_file}: "
                f"expected a JSON array of samples, got {type(samples).__name__}"
            )
            return []

        logging.info(f"Loaded {len(samples)} samples from combined file {combined_file}")
        return samples

    # Load from individual sample files; a bad file is logged and skipped.
    samples = []
    for file_path in tqdm(sample_files, desc="Loading ST samples"):
        try:
            with open(file_path, 'r', encoding="utf-8") as f:
                samples.append(json.load(f))
        except (OSError, ValueError) as e:
            logging.error(f"Error loading sample file {file_path}: {str(e)}")

    logging.info(f"Loaded {len(samples)} samples from {len(sample_files)} files")
    return samples

def create_st_nt_sample(st_sample):
    """
    Create an ST-NT sample by removing the reasoning trace from an ST sample.

    Only the fields listed below are copied; every other key of the input
    (including any reasoning trace) is dropped.

    Args:
        st_sample: A sample (dict) from the ST dataset

    Returns:
        A new dict for the ST-NT dataset without the reasoning trace. A missing
        or null ``question_id`` is replaced by a generated ``stnt_<hex>`` id, and
        a missing, null or empty ``original_dataset`` falls back to ``"ST"``.
    """
    import uuid  # local import: only needed for the fallback id

    question_id = st_sample.get("question_id")
    if question_id is None:
        # A timestamp-based id can repeat when samples are created faster than
        # the clock resolution; a random UUID cannot realistically collide.
        question_id = f"stnt_{uuid.uuid4().hex}"

    # Treat a null/empty source like a missing one, and normalise the path so
    # a trailing separator ("data/ST/") still yields the directory name.
    source = st_sample.get("original_dataset") or "ST"
    source_name = os.path.basename(os.path.normpath(source))

    # Create a new sample with just question and answer
    st_nt_sample = {
        "question_id": question_id,
        "question": st_sample.get("question", ""),
        "answer": st_sample.get("answer", ""),
        "ground_truth_answer": st_sample.get("ground_truth_answer"),
        "answer_correct": st_sample.get("answer_correct", True),
        "domain": st_sample.get("domain", "unknown"),
        "dataset": st_sample.get("dataset", "unknown"),
        "date_generated": time.strftime("%Y-%m-%d %H:%M:%S"),
        "original_dataset": source_name,
    }

    return st_nt_sample

def save_st_nt_samples(st_nt_samples, output_dir):
    """
    Write the ST-NT samples to disk in three forms.

    The output directory is created if needed. Each sample is written to its
    own ``stnt_NNNN.json`` file, then all samples are written together to
    ``samples.json`` (one JSON array) and ``samples.jsonl`` (one JSON object
    per line). A failure in any single write is logged and does not stop the
    remaining writes.

    Args:
        st_nt_samples: List of ST-NT samples to save
        output_dir: Directory to save samples to
    """
    os.makedirs(output_dir, exist_ok=True)

    # One file per sample, numbered by position in the list.
    progress = tqdm(st_nt_samples, desc="Saving ST-NT samples")
    for index, record in enumerate(progress):
        target = os.path.join(output_dir, f"stnt_{index:04d}.json")
        try:
            with open(target, 'w') as out:
                json.dump(record, out, indent=2)
        except Exception as err:
            logging.error(f"Error saving sample to {target}: {err!s}")

    # All samples as a single JSON array.
    array_target = os.path.join(output_dir, "samples.json")
    try:
        with open(array_target, 'w') as out:
            json.dump(st_nt_samples, out, indent=2)
        logging.info(f"Saved combined file with {len(st_nt_samples)} samples to {array_target}")
    except Exception as err:
        logging.error(f"Error saving combined file to {array_target}: {err!s}")

    # Line-delimited JSON for easier streaming/processing.
    lines_target = os.path.join(output_dir, "samples.jsonl")
    try:
        with open(lines_target, 'w') as out:
            out.writelines(json.dumps(record) + "\n" for record in st_nt_samples)
        logging.info(f"Saved JSONL file with {len(st_nt_samples)} samples to {lines_target}")
    except Exception as err:
        logging.error(f"Error saving JSONL file to {lines_target}: {err!s}")

def generate_st_nt_dataset(input_dir, output_dir):
    """
    Build the ST-NT dataset from an ST dataset directory.

    Loads the ST samples, converts each one with ``create_st_nt_sample`` and
    hands the converted list to ``save_st_nt_samples``. If no samples could be
    loaded, an error is logged and nothing is written.

    Args:
        input_dir: Directory containing ST dataset
        output_dir: Directory to save ST-NT dataset to
    """
    logging.info(f"Generating ST-NT dataset from {input_dir} to {output_dir}")

    source_samples = load_st_samples(input_dir)
    if not source_samples:
        logging.error("No samples loaded, aborting")
        return

    # Convert every ST sample to its trace-free counterpart.
    converted = [
        create_st_nt_sample(source)
        for source in tqdm(source_samples, desc="Creating ST-NT samples")
    ]

    save_st_nt_samples(converted, output_dir)

    logging.info(f"ST-NT dataset generation complete: {len(converted)} samples")
    logging.info(f"Output directory: {output_dir}")

def main():
    """Parse command-line arguments and run the ST-NT dataset generation."""
    cli = argparse.ArgumentParser(description="Generate ST-NT dataset (No Traces)")
    cli.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing ST dataset files",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="datasets/SmolTraces-NT",
        help="Directory to save ST-NT dataset to",
    )
    options = cli.parse_args()

    generate_st_nt_dataset(options.input_dir, options.output_dir)

# Run the command-line entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()