"""
Experiment utilities for setting up training runs.
Handles configuration, data loading, model creation, and directory setup.
"""

import os
import torch
import numpy as np
from datetime import datetime
from pathlib import Path
import json
from torch.utils.data import DataLoader

from models.pi_convnp import PIConvNP
from data.dataset import NonlinearPoissonDataset, collate_variable_length, get_cache_info, clear_cache
from configs.base_config import BaseConfig
from configs.nonlinear_poisson_config import NonlinearPoissonConfig  # config class instantiated by create_config


def get_device():
    """Return the name of the preferred compute device.

    Returns:
        'cuda' when CUDA is available; otherwise 'mps' when ``torch.backends``
        exposes an ``mps`` backend that reports itself available; otherwise
        'cpu'.
    """
    if torch.cuda.is_available():
        return 'cuda'

    # Older torch builds have no `mps` backend attribute, hence the hasattr guard.
    mps_ready = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    return 'mps' if mps_ready else 'cpu'


def set_seed(seed):
    """Seed the random number generators used for a run.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (plus all CUDA devices when CUDA is available), then puts cuDNN
    into deterministic mode with benchmark autotuning disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Imported locally because `random` is not among the module-level imports.
    import random

    # Without this, code relying on the stdlib generator is not reproducible.
    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def create_config(args):
    """Build the run configuration from defaults plus command line overrides.

    Starts from a fresh ``NonlinearPoissonConfig`` (rather than ``BaseConfig``)
    and copies every optional command line value that is not ``None`` onto the
    matching config attribute. The domain bounds and observation noise are
    always taken from ``args``. ``config.validate()`` is called before the
    config is returned.

    Args:
        args: Parsed command line namespace. Every attribute named in the
            override table below, plus ``x_range_min``, ``x_range_max`` and
            ``noise_std``, must exist on it.

    Returns:
        The populated ``NonlinearPoissonConfig`` instance.
    """
    config = NonlinearPoissonConfig()

    def _as_single_tuple(value):
        # The config stores grid resolution as a 1-tuple.
        return (value,)

    # (attribute on args, attribute on config, optional converter), applied in order.
    override_table = (
        ('latent_dim', 'latent_dim', None),
        ('conv_channels', 'conv_channels', None),
        ('num_conv_blocks', 'conv_num_blocks', None),
        ('grid_resolution', 'grid_resolution', _as_single_tuple),
        ('lr', 'learning_rate', None),
        ('batch_size', 'batch_size', None),
        ('num_iterations', 'num_iterations', None),
        ('gradient_clip', 'grad_clip_norm', None),
        ('lambda_physics', 'physics_weight', None),
        ('optimizer', 'optimizer', None),
        ('scheduler', 'scheduler_type', None),
        ('log_interval', 'log_interval', None),
        ('confidence_level', 'confidence_level', None),
        ('num_vis_samples', 'num_vis_samples', None),
        ('seed', 'seed', None),
    )
    for arg_name, config_attr, convert in override_table:
        cli_value = getattr(args, arg_name)
        if cli_value is None:
            # A None value leaves the config default untouched.
            continue
        setattr(config, config_attr, cli_value if convert is None else convert(cli_value))

    # These two are unconditional: they always come from the command line.
    config.domain_bounds = ((args.x_range_min, args.x_range_max),)
    config.observation_noise_std = args.noise_std

    config.validate()
    return config


def setup_directories(args):
    """Create the timestamped output and checkpoint directories for a run.

    Both directories share the run name ``nonlinear_poisson_<YYYYmmdd_HHMMSS>``
    derived from the current local time. Checkpoints go to
    ``<output_dir>/<run>/checkpoints`` unless ``args.checkpoint_dir`` is set,
    in which case they go to ``<checkpoint_dir>/<run>``.

    Args:
        args: Namespace providing ``output_dir`` and ``checkpoint_dir``
            (the latter may be ``None``).

    Returns:
        Tuple ``(output_dir, checkpoint_dir)`` of ``Path`` objects; both
        directories exist when the function returns.
    """
    run_name = 'nonlinear_poisson_' + datetime.now().strftime("%Y%m%d_%H%M%S")

    output_dir = Path(args.output_dir) / run_name
    output_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_dir = (
        Path(args.checkpoint_dir) / run_name
        if args.checkpoint_dir is not None
        else output_dir / 'checkpoints'
    )
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    return output_dir, checkpoint_dir


def handle_cache_operations(args):
    """Run the cache maintenance actions requested on the command line.

    The cache directory is ``data/cache`` next to the parent of
    ``args.output_dir``. With ``args.show_cache_info`` set, a summary of the
    entries returned by ``get_cache_info`` is printed. With
    ``args.clear_cache`` set, ``clear_cache`` is called, asking for
    confirmation unless ``args.force_clear_cache`` is set.

    Args:
        args: Namespace providing ``output_dir``, ``show_cache_info``,
            ``clear_cache`` and ``force_clear_cache``.

    Raises:
        SystemExit: With status 0 after a non-forced cache clear. A forced
            clear returns normally instead.
    """
    cache_dir = Path(args.output_dir).parent / 'data' / 'cache'
    rule = "=" * 80

    if args.show_cache_info:
        entries = get_cache_info(str(cache_dir))
        if not entries:
            print("\n📦 No cache files found\n")
        else:
            print("\n" + rule)
            print("📦 Existing Cache Files:")
            print(rule)
            # Label shown to the user paired with the key read from the entry's parameters.
            parameter_fields = (
                ('Samples', 'num_samples'),
                ('Grid points', 'n_grid_points'),
                ('Chebyshev', 'n_chebyshev'),
                ('Seed', 'seed'),
            )
            for entry in entries:
                print(f"\n📄 {entry['filename']}")
                print(f"   Size: {entry['size_mb']:.2f} MB")
                if 'parameters' in entry:
                    params = entry['parameters']
                    for label, key in parameter_fields:
                        print(f"   {label}: {params.get(key, 'N/A')}")
            print(rule + "\n")

    if not args.clear_cache:
        return

    forced = args.force_clear_cache
    clear_cache(str(cache_dir), confirm=not forced)
    if not forced:
        # Only the non-forced clear ends the process here.
        import sys
        sys.exit(0)


def create_dataloaders(args, config, device):
    """Create train, validation, and test dataloaders.

    Builds three ``NonlinearPoissonDataset`` splits that share every setting
    except sample count and seed, then wraps each in a ``DataLoader`` using
    ``collate_variable_length``. Datasets are always built with
    ``device='cpu'``. The cache directory is ``data/cache`` next to the
    parent of ``args.output_dir``.

    Seeds: the training split uses ``config.seed``; validation and test use
    ``config.seed + 1000`` and ``config.seed + 2000``. When ``config.seed`` is
    ``None`` all three splits receive ``None``.

    Args:
        args: Parsed command line namespace with the dataset settings
            (``n_chebyshev``, ``n_grid_points``, context/target ranges,
            ``x_range_*``, ``w_range_*``, ``noise_std``, ``num_train``,
            ``num_val``, ``num_test``, ``precompute``, ``force_regenerate``,
            ``num_workers``, ``output_dir``).
        config: Configuration providing ``observation_noise_std``, ``seed``
            and ``batch_size``.
        device: Target device, as a string such as ``'cuda'`` / ``'cuda:0'``
            or a ``torch.device``. Only used to decide whether loaders pin
            host memory.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader, parameter_dim)`` where
        ``parameter_dim`` is ``args.n_chebyshev + 1``.
    """
    parameter_dim = args.n_chebyshev + 1
    cache_dir = Path(args.output_dir).parent / 'data' / 'cache'

    print(f"\n{'='*80}")
    print("📊 Creating Nonlinear Poisson Datasets")
    print(f"{'='*80}")
    print(f"PDE: d/dx[k(u,x) * du/dx] = w")
    print(f"k(u,x) = log(1 + exp(u * Σ ξ_i T_i(x))) + 0.1")
    print(f"\nDataset Parameters:")
    print(f"  Chebyshev coefficients ξ: {args.n_chebyshev}")
    print(f"  Grid points: {args.n_grid_points}")
    print(f"  w range: [{args.w_range_min}, {args.w_range_max}]")
    print(f"  Domain: [{args.x_range_min}, {args.x_range_max}]")
    print(f"  Noise std: {args.noise_std}")
    print(f"  Cache directory: {cache_dir}")
    print(f"{'='*80}\n")

    def _build_dataset(num_samples, seed):
        """Create one dataset split; only the size and seed vary per split."""
        return NonlinearPoissonDataset(
            num_samples=num_samples,
            n_grid_points=args.n_grid_points,
            n_chebyshev=args.n_chebyshev,
            n_context_range=(args.n_context_min, args.n_context_max),
            n_target_range=(args.n_target_min, args.n_target_max),
            noise_std=config.observation_noise_std,
            x_range=(args.x_range_min, args.x_range_max),
            w_range=(args.w_range_min, args.w_range_max),
            device='cpu',
            precompute=args.precompute,
            seed=seed,
            cache_dir=str(cache_dir),
            force_regenerate=args.force_regenerate
        )

    def _shifted_seed(offset):
        """Return ``config.seed + offset``, or None when no seed is configured."""
        return config.seed + offset if config.seed is not None else None

    train_dataset = _build_dataset(args.num_train, config.seed)
    val_dataset = _build_dataset(args.num_val, _shifted_seed(1000))
    test_dataset = _build_dataset(args.num_test, _shifted_seed(2000))

    # Pin host memory for any CUDA target. The string form covers 'cuda',
    # indexed names like 'cuda:0', and torch.device objects alike.
    pin_memory = str(device).startswith('cuda')

    def _build_loader(dataset, shuffle):
        """Wrap a dataset in a DataLoader with the shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=shuffle,
            num_workers=args.num_workers,
            collate_fn=collate_variable_length,
            pin_memory=pin_memory
        )

    # Only the training split is shuffled.
    train_loader = _build_loader(train_dataset, True)
    val_loader = _build_loader(val_dataset, False)
    test_loader = _build_loader(test_dataset, False)

    print(f"✅ Datasets created successfully")
    print(f"   Train: {len(train_dataset)} samples")
    print(f"   Val:   {len(val_dataset)} samples")
    print(f"   Test:  {len(test_dataset)} samples")
    print(f"   Parameter dimension: {parameter_dim} (ξ: {args.n_chebyshev}, w: 1)\n")

    return train_loader, val_loader, test_loader, parameter_dim


def create_model(config, parameter_dim, use_parameter_conditioning, device):
    """Instantiate the PI-ConvNP model from the run configuration.

    Prints the number of trainable parameters after construction.

    Args:
        config: Configuration supplying the architecture settings read below.
        parameter_dim: Size of the conditioning parameter vector. It is
            forwarded only when ``use_parameter_conditioning`` is truthy;
            otherwise the model receives ``None``.
        use_parameter_conditioning: Whether the model conditions on parameters.
        device: Device forwarded to the model constructor.

    Returns:
        The constructed ``PIConvNP`` instance.
    """
    conditioning_dim = None
    if use_parameter_conditioning:
        conditioning_dim = parameter_dim

    model_kwargs = {
        'spatial_dim': config.spatial_dim,
        'observation_dim': config.observation_dim,
        'output_dim': config.output_dim,
        'latent_dim': config.latent_dim,
        'observation_encoder_dim': config.observation_encoder_dim,
        'conv_channels': config.conv_channels,
        # Three encoder hidden layers, all of width latent_encoder_dim.
        'encoder_hidden_dims': (config.latent_encoder_dim,) * 3,
        'num_conv_blocks': config.conv_num_blocks,
        'grid_resolution': config.grid_resolution,
        'domain_bounds': config.domain_bounds,
        'min_sigma': config.sigma_min,
        'parameter_dim': conditioning_dim,
        'use_parameter_conditioning': use_parameter_conditioning,
        'device': device,
    }
    model = PIConvNP(**model_kwargs)

    # Count only parameters that will receive gradients.
    trainable_count = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_count += param.numel()

    print(f"\n🏗️  Model created successfully")
    print(f"   Trainable parameters: {trainable_count:,}\n")

    return model


def save_config(output_dir, config, args):
    """Write the effective run configuration to ``<output_dir>/config.json``.

    The saved dictionary is ``config.to_dict()`` merged with the command line
    arguments. Arguments override config entries of the same name, except
    when the argument value is ``None`` and the config already has that key:
    there the config value is kept so the file records the value actually in
    effect.

    Args:
        output_dir: Directory to write into (``Path`` or string); it must
            already exist.
        config: Object exposing ``to_dict()`` that returns a JSON-serializable
            mapping.
        args: Parsed command line namespace; its attributes must be
            JSON-serializable.
    """
    config_path = Path(output_dir) / 'config.json'

    # Copy so the merge cannot mutate a dict owned by `config`.
    config_dict = dict(config.to_dict())
    for key, value in vars(args).items():
        # create_config treats a None argument as "not provided", so it must
        # not clobber the effective value stored in the config.
        if value is None and key in config_dict:
            continue
        config_dict[key] = value

    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config_dict, f, indent=4)

    print(f"📝 Configuration saved to {config_path}")


def print_device_info(device):
    """Print the selected device and, for CUDA, details of GPU 0.

    Args:
        device: Device identifier to display. GPU name and total memory (in
            units of 1e9 bytes) are printed only when CUDA is available and
            ``device`` equals ``'cuda'``.
    """
    print(f"\n🖥️  Device: {device}")

    if not (torch.cuda.is_available() and device == 'cuda'):
        return

    # Details are always reported for device index 0.
    gpu_name = torch.cuda.get_device_name(0)
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"   GPU: {gpu_name}")
    print(f"   Memory: {total_bytes / 1e9:.2f} GB")