"""Script to generate offline GP data batches for training using Hydra configuration."""

import json
from pathlib import Path

import hydra
import torch
from omegaconf import DictConfig, OmegaConf

from src.data.utils import generate_offline_batches


def _resolve_torch_dtype(dtype_spec: object) -> object:
    """Translate a dtype name from the config into a ``torch.dtype``.

    Args:
        dtype_spec: A dtype name such as ``"float32"`` or ``"torch.float32"``.
            Non-string values (e.g. ``None`` or an actual ``torch.dtype``)
            are returned unchanged.

    Returns:
        The matching ``torch.dtype``, or ``dtype_spec`` itself when it is not
        a string.

    Raises:
        ValueError: If the name does not refer to a ``torch.dtype``.
    """
    if not isinstance(dtype_spec, str):
        return dtype_spec

    prefix = "torch."
    name = dtype_spec[len(prefix):] if dtype_spec.startswith(prefix) else dtype_spec

    resolved = getattr(torch, name, None)
    # A bare getattr would happily hand back functions or submodules
    # (e.g. "load", "nn"), so insist on a genuine dtype object.
    if not isinstance(resolved, torch.dtype):
        raise ValueError(
            f"Invalid sampler dtype {dtype_spec!r}: expected the name of a "
            "torch dtype such as 'float32' or 'float64'."
        )
    return resolved


def _build_sampler_kwargs(cfg: DictConfig) -> dict:
    """Build the plain-Python keyword arguments forwarded to the sampler.

    Args:
        cfg: The Hydra configuration; reads ``cfg.sampler`` and ``cfg.device``.

    Returns:
        A dict holding every entry of ``cfg.sampler`` converted to built-in
        containers, plus ``device`` and, when present, ``dtype`` resolved via
        ``_resolve_torch_dtype``.

    Raises:
        ValueError: If ``cfg.sampler.dtype`` is a string that is not a torch dtype.
    """
    # dict(cfg.sampler) would only convert the top level and leave nested
    # DictConfig/ListConfig values in place; to_container converts recursively
    # and resolves interpolations.
    sampler_kwargs = OmegaConf.to_container(cfg.sampler, resolve=True)
    sampler_kwargs["device"] = cfg.device

    if "dtype" in sampler_kwargs:
        sampler_kwargs["dtype"] = _resolve_torch_dtype(sampler_kwargs["dtype"])

    return sampler_kwargs


@hydra.main(
    version_base=None, config_path="../../configs", config_name="generate_offline_data"
)
def main(cfg: DictConfig) -> None:
    """Generate offline GP data batches based on Hydra configuration.

    Reads the following configuration entries:

    * ``seed`` (optional): passed to ``torch.manual_seed`` when not ``None``.
    * ``output_dir``: directory that receives the generated files; created
      (including parents) if it does not exist.
    * ``device``: forwarded to the sampler as the ``device`` keyword.
    * ``sampler``: keyword arguments for the sampler; a string ``dtype`` entry
      is converted to a ``torch.dtype``.
    * ``sampler_data``: forwarded unchanged to ``generate_offline_batches``.
    * ``generation``: ``num_batches``, ``batch_size``, ``chunk_size``,
      ``num_buffer``, ``num_target``, ``context_range`` and the optional
      ``num_context``.

    After generation, prints a summary based on ``metadata.json`` in the
    output directory.

    Args:
        cfg: Configuration object supplied by Hydra.

    Raises:
        ValueError: If ``sampler.dtype`` is a string that does not name a
            torch dtype.
    """
    # Set random seed if provided; .get() tolerates configs without the key.
    seed = cfg.get("seed")
    if seed is not None:
        torch.manual_seed(seed)
        print(f"Set random seed to {seed}")

    # Create output directory
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    generation = cfg.generation

    print(f"\nGenerating {generation.num_batches} batches to {output_dir}")
    print("Configuration:")
    print(f"  Batch size: {generation.batch_size}")
    print(f"  Batches per chunk: {generation.chunk_size}")
    print(f"  Device: {cfg.device}")
    print(f"  Sampler config: {OmegaConf.to_container(cfg.sampler, resolve=True)}")
    print(f"  Generation config: {OmegaConf.to_container(generation, resolve=True)}")

    # Validate and convert the sampler settings before any data is produced.
    sampler_kwargs = _build_sampler_kwargs(cfg)

    # Generate offline batches
    generate_offline_batches(
        save_dir=output_dir,
        num_batches=generation.num_batches,
        batch_size=generation.batch_size,
        sampler_data=cfg.sampler_data,
        num_context=generation.get("num_context", None),
        num_buffer=generation.num_buffer,
        num_target=generation.num_target,
        context_range=tuple(generation.context_range),
        chunk_size=generation.chunk_size,
        sampler_kwargs=sampler_kwargs,
    )

    print(
        f"\n✓ Successfully generated {generation.num_batches} batches in {output_dir}"
    )

    # Print summary of generated files.
    # NOTE: the glob counts every chunk_*.pt in the directory, so files left
    # over from an earlier run into the same directory are included.
    chunk_files = list(output_dir.glob("chunk_*.pt"))
    metadata_file = output_dir / "metadata.json"

    print("\nGenerated files:")
    print(f"  Metadata: {metadata_file}")
    print(f"  Chunks: {len(chunk_files)} files")

    # The summary is informational only; do not fail a finished generation
    # run with a traceback just because the metadata file is absent.
    if not metadata_file.is_file():
        print(f"\nWarning: {metadata_file} not found; skipping dataset info.")
        return

    # Load and display metadata
    with metadata_file.open("r", encoding="utf-8") as f:
        metadata = json.load(f)

    print("\nDataset info:")
    print(f"  Total batches: {metadata['num_batches']}")
    print(f"  Batch size: {metadata['batch_size']}")
    print(f"  Chunk size: {metadata['chunk_size']}")
    print(f"  Number of chunks: {metadata['num_chunks']}")


# Run the generator only when this file is executed as a script, not when it
# is imported. main() is called without arguments because the @hydra.main
# decorator builds the configuration object and passes it in.
if __name__ == "__main__":
    main()
