"""Generate offline tabular regression data batches."""

import argparse
import json
import random
import shutil
from pathlib import Path

import numpy as np
import torch

from src.data.utils import generate_offline_batches


def main():
    parser = argparse.ArgumentParser(description='Generate offline tabular regression data')
    
    # Output settings
    parser.add_argument('--output_dir', type=str, default='data/tabular/offline')
    parser.add_argument('--clean', action='store_true', default=False,
                       help='Delete existing output_dir before generation')
    parser.add_argument('--num_batches', type=int, default=None, 
                       help='Number of batches (use this OR --target_size_mb)')
    parser.add_argument('--target_size_mb', type=int, default=None,
                       help='Target dataset size in MB (e.g., 5000 for 5GB)')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--chunk_size', type=int, default=100)
    
    # Data generation settings
    parser.add_argument('--num_features', type=str, default='10', 
                       help='Number of features (comma-separated list to sample from)')
    parser.add_argument('--num_context', type=str, default=None,
                       help='Number of context points (comma-separated list to sample from, or use context_min/max)')
    parser.add_argument('--context_min', type=int, default=32)
    parser.add_argument('--context_max', type=int, default=256)
    parser.add_argument('--num_buffer', type=int, default=0,
                       help='Number of buffer points (fixed)')
    parser.add_argument('--num_target', type=int, default=128,
                       help='Number of target points (fixed)')
    
    # Feature/target normalization (TabICL-style for X, z-score for y)
    parser.add_argument('--normalize_x', action='store_true', default=False,
                       help='Apply TabICL-like normalization to features (fit on context; applies to xc/xb/xt)')
    parser.add_argument('--x_norm_method', type=str, default='power',
                       choices=['power', 'quantile', 'quantile_rtdl', 'none'],
                       help='Feature normalization method (requires TabICL deps for non-none)')
    parser.add_argument('--x_outlier_threshold', type=float, default=4.0,
                       help='Z-score threshold for outlier clipping in TabICL preprocessor')
    
    # MLP SCM settings
    parser.add_argument('--num_causes', type=int, default=None)
    parser.add_argument('--num_layers', type=int, default=4)
    parser.add_argument('--hidden_dim', type=int, default=64)
    parser.add_argument('--noise_std', type=float, default=0.01)
    parser.add_argument('--sampling', type=str, default='mixed', choices=['normal', 'uniform', 'mixed'])
    
    # Other settings
    parser.add_argument('--normalize_y', action='store_true', default=True)
    parser.add_argument('--dtype', type=str, default='float32', choices=['float16', 'bfloat16', 'float32', 'float64'],
                       help='Data type for generation')
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--seed', type=int, default=None)
    
    args = parser.parse_args()
    
    # Validate arguments
    if args.num_batches is None and args.target_size_mb is None:
        parser.error("Must specify either --num_batches or --target_size_mb")
    if args.num_batches is not None and args.target_size_mb is not None:
        parser.error("Cannot specify both --num_batches and --target_size_mb")
    
    # Set seed if provided
    if args.seed is not None:
        torch.manual_seed(args.seed)
        print(f"Set random seed to {args.seed}")
    
    # Parse list arguments
    def parse_list_arg(arg_str):
        if arg_str is None:
            return None
        if ',' in arg_str:
            return [int(x.strip()) for x in arg_str.split(',')]
        else:
            return int(arg_str)
    
    num_features = parse_list_arg(args.num_features)
    num_context = parse_list_arg(args.num_context)
    
    # Get dtype
    dtype_map = {
        'float16': torch.float16,
        'bfloat16': torch.bfloat16,
        'float32': torch.float32,
        'float64': torch.float64,
    }
    dtype = dtype_map[args.dtype]
    
    # Calculate bytes per element
    bytes_per_element = {
        torch.float16: 2,
        torch.bfloat16: 2,
        torch.float32: 4,
        torch.float64: 8,
    }[dtype]
    
    # Estimate number of batches if using target size
    if args.target_size_mb is not None:
        # Calculate average sizes
        avg_features = np.mean(num_features) if isinstance(num_features, list) else num_features
        if num_context is not None:
            avg_context = np.mean(num_context) if isinstance(num_context, list) else num_context
        else:
            avg_context = (args.context_min + args.context_max) / 2
        
        # Calculate elements per batch
        elements_per_batch = args.batch_size * (
            (avg_context + args.num_buffer + args.num_target) * (avg_features + 1)  # +1 for y
        )
        
        # Calculate bytes per batch  
        bytes_per_batch = elements_per_batch * bytes_per_element
        
        # Add overhead for metadata (roughly 10%)
        bytes_per_batch *= 1.1
        
        # Calculate number of batches needed
        target_bytes = args.target_size_mb * 1024 * 1024
        num_batches = int(target_bytes / bytes_per_batch)
        
        print(f"\nTarget size: {args.target_size_mb} MB")
        print(f"Estimated bytes per batch: {bytes_per_batch / 1024:.1f} KB")
        print(f"Calculated number of batches: {num_batches}")
    else:
        num_batches = args.num_batches
    
    # Prepare output directory (optionally clean first)
    output_dir = Path(args.output_dir)
    if args.clean and output_dir.exists():
        import shutil
        print(f"Cleaning existing directory: {output_dir}")
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    
    print(f"\nGenerating {num_batches} tabular regression batches to {output_dir}")
    print(f"Configuration:")
    print(f"  Batch size: {args.batch_size}")
    print(f"  Batches per chunk: {args.chunk_size}")
    print(f"  Number of features: {num_features}")
    print(f"  Context: {num_context if num_context is not None else f'[{args.context_min}, {args.context_max}]'}")
    print(f"  Buffer: {args.num_buffer} (fixed)")
    print(f"  Target: {args.num_target} (fixed)")
    print(f"  Data type: {args.dtype}")
    print(f"  Normalize y: {args.normalize_y}")
    print(f"  Normalize x: {args.normalize_x} (method={args.x_norm_method}, outlier_z={args.x_outlier_threshold})")
    print(f"  Device: {args.device}")
    
    # Prepare sampler kwargs
    sampler_kwargs = {
        "dim_x": num_features,
        "dim_y": 1,
        "is_causal": True,
        "num_causes": args.num_causes,
        "num_layers": args.num_layers,
        "hidden_dim": args.hidden_dim,
        "noise_std": args.noise_std,
        "sampling": args.sampling,
        "normalize_y": args.normalize_y,
        "normalize_x": args.normalize_x,
        "x_norm_method": args.x_norm_method,
        "x_outlier_threshold": args.x_outlier_threshold,
        "device": args.device,
        "dtype": dtype,
    }
    
    # Generate offline batches
    generate_offline_batches(
        save_dir=output_dir,
        num_batches=num_batches,  # Use calculated num_batches
        batch_size=args.batch_size,
        sampler_data="tabular",  # Use tabular sampler
        num_context=num_context,
        num_buffer=args.num_buffer,
        num_target=args.num_target,
        context_range=(args.context_min, args.context_max),
        chunk_size=args.chunk_size,
        sampler_kwargs=sampler_kwargs,
    )
    
    print(f"\n✓ Successfully generated {num_batches} batches in {output_dir}")
    
    # Print summary
    chunk_files = list(output_dir.glob("chunk_*.pt"))
    metadata_file = output_dir / "metadata.json"
    
    print(f"\nGenerated files:")
    print(f"  Metadata: {metadata_file}")
    print(f"  Chunks: {len(chunk_files)} files")
    
    # Load and display metadata
    with open(metadata_file, "r") as f:
        metadata = json.load(f)
    
    print(f"\nDataset info:")
    print(f"  Total batches: {metadata['num_batches']}")
    print(f"  Batch size: {metadata['batch_size']}")
    print(f"  Chunk size: {metadata['chunk_size']}")
    print(f"  Number of chunks: {metadata['num_chunks']}")
    
    # Calculate actual size
    total_size_bytes = sum(f.stat().st_size for f in chunk_files)
    total_size_mb = total_size_bytes / (1024 * 1024)
    print(f"  Total size: {total_size_mb:.1f} MB")
    
    if args.target_size_mb is not None:
        print(f"  Target was: {args.target_size_mb} MB")
        print(f"  Achieved: {(total_size_mb / args.target_size_mb * 100):.1f}% of target")


if __name__ == "__main__":
    main()
