#!/usr/bin/env python
"""
Script to split a dataset with 16 buffer points into versions with 8 and 4 buffer points.
Together with the original 16-buffer dataset, this yields three dataset versions.
"""

import torch
from pathlib import Path
import json
import shutil
from tqdm import tqdm


def split_buffer_dataset(source_dir: "str | Path", num_buffer_points: int, output_dir: "str | Path"):
    """
    Create a new dataset with reduced buffer points from the original.

    Every ``chunk_*.pt`` file in ``source_dir`` is loaded. For each batch, the
    ``"xb"`` and ``"yb"`` entries are truncated to their first
    ``num_buffer_points`` entries along dim 1. The batch is rebuilt with the
    keys ``xc, yc, xb, yb, xt, yt`` and written under the same file name in
    ``output_dir``. ``metadata.json`` is copied with
    ``generation_kwargs.num_buffer`` updated and provenance fields added.

    Args:
        source_dir: Path to the original dataset (e.g. one with 16 buffer points).
        num_buffer_points: Number of buffer points to keep (e.g. 4 or 8). Must be
            non-negative and, when the source metadata records ``num_buffer``,
            no larger than that value.
        output_dir: Path to save the new dataset. Must not resolve to ``source_dir``.

    Raises:
        ValueError: If ``num_buffer_points`` is out of range, or if ``output_dir``
            resolves to the same directory as ``source_dir``.
    """
    source_path = Path(source_dir)
    output_path = Path(output_dir)

    if output_path.resolve() == source_path.resolve():
        # Chunks are written under their original names, so this would
        # overwrite the source dataset in place.
        raise ValueError(f"output_dir must differ from source_dir: {source_path}")

    # Load metadata
    with open(source_path / "metadata.json", "r") as f:
        metadata = json.load(f)

    # Capture the source buffer size *before* overwriting it, so provenance is
    # recorded from the data rather than assumed to be 16.
    generation_kwargs = metadata["generation_kwargs"]
    original_num_buffer = generation_kwargs.get("num_buffer")

    if num_buffer_points < 0:
        # A negative bound would make the slice below drop points from the end.
        raise ValueError(f"num_buffer_points must be non-negative, got {num_buffer_points}")
    if original_num_buffer is not None and num_buffer_points > original_num_buffer:
        # Slicing past the end would silently keep fewer points than requested.
        raise ValueError(
            f"num_buffer_points={num_buffer_points} exceeds the source dataset's "
            f"num_buffer={original_num_buffer}"
        )

    # Update metadata for new buffer size
    generation_kwargs["num_buffer"] = num_buffer_points
    metadata["original_source"] = str(source_path)
    metadata["buffer_split_from"] = original_num_buffer

    # Create the output directory only once the inputs have been validated.
    output_path.mkdir(parents=True, exist_ok=True)

    with open(output_path / "metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    # Process each chunk file
    chunk_files = sorted(source_path.glob("chunk_*.pt"))

    for chunk_file in tqdm(chunk_files, desc=f"Processing chunks for {num_buffer_points}-buffer"):
        chunk_data = torch.load(chunk_file)

        new_chunk_data = []
        for batch in chunk_data:
            # Keep only the first num_buffer_points along dim 1 (presumably the
            # buffer axis, i.e. (batch, num_buffer, ...) — confirm with generator).
            # clone(): a slice is a view, and torch.save serializes a view's whole
            # underlying storage, so without it the saved file would still hold
            # every original buffer point.
            new_batch = {
                "xc": batch["xc"],
                "yc": batch["yc"],
                "xb": batch["xb"][:, :num_buffer_points].clone(),
                "yb": batch["yb"][:, :num_buffer_points].clone(),
                "xt": batch["xt"],
                "yt": batch["yt"],
            }
            new_chunk_data.append(new_batch)

        # Save modified chunk under the same file name
        torch.save(new_chunk_data, output_path / chunk_file.name)

    print(f"✓ Created {num_buffer_points}-buffer dataset in {output_path}")


def main():
    """Create the 8- and 4-buffer dataset versions from the 16-buffer source.

    Reads ``data/gp_128batch_16buf_256tar`` (relative to the current working
    directory) and writes ``data/gp_128batch_{N}buf_256tar`` for N in 8 and 4
    via ``split_buffer_dataset``.

    Returns:
        Process exit status: 0 on success, 1 if the source dataset is missing.
    """
    import sys

    # Define paths
    base_dir = Path("data")
    source_dataset = base_dir / "gp_128batch_16buf_256tar"

    # Check if source dataset exists. Diagnostics go to stderr and a non-zero
    # status is returned so shells/CI can detect the failure.
    if not source_dataset.exists():
        print(f"Error: Source dataset not found at {source_dataset}", file=sys.stderr)
        print("Please generate the 16-buffer dataset first using:", file=sys.stderr)
        print("  uv run python -m src.data.generate_offline_data \\", file=sys.stderr)
        print("      --config-name offline_data_gp_highprecision \\", file=sys.stderr)
        print("      output_dir=data/gp_128batch_16buf_256tar \\", file=sys.stderr)
        print("      generation.num_batches=10000 \\", file=sys.stderr)
        print("      generation.batch_size=128 \\", file=sys.stderr)
        print("      generation.num_buffer=16 \\", file=sys.stderr)
        print("      generation.num_target=256", file=sys.stderr)
        return 1

    print("Creating dataset versions with different buffer sizes...")
    print(f"Source: {source_dataset}")
    print()

    # The 16-buffer version is the source itself; derive the smaller ones.
    outputs = {}
    for num_buffer in (8, 4):
        output_dir = base_dir / f"gp_128batch_{num_buffer}buf_256tar"
        print(f"Creating {num_buffer}-buffer version: {output_dir}")
        split_buffer_dataset(source_dataset, num_buffer, output_dir)
        outputs[num_buffer] = output_dir

    print("\n✓ Successfully created all three dataset versions:")
    print(f"  - 16-buffer: {source_dataset}")
    for num_buffer, output_dir in outputs.items():
        print(f"  - {num_buffer}-buffer:  {output_dir}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    # (None is treated as success, i.e. exit code 0).
    raise SystemExit(main())