"""
Model loading utilities for PIConvNP.
"""

import torch
from pathlib import Path
import re


# Matches ``<prefix>observation_encoder.network.<idx>.weight`` keys (anchored,
# so e.g. ``...weight_g`` is not picked up) and captures the module prefix and
# the layer index.
_OBS_ENC_WEIGHT_RE = re.compile(
    r'^(?P<prefix>.*observation_encoder\.network)\.(?P<idx>\d+)\.weight$'
)
# Encoder prefix the loader historically assumed; preferred when several
# observation encoders are present in one state dict.
_DEFAULT_OBS_ENC_PREFIX = 'set_encoder.observation_encoder.network'


def _detect_parameter_dim(state_dict, default=6):
    """Return the input width of ``parameter_encoder.network.0``, or *default* if absent."""
    weight = state_dict.get('parameter_encoder.network.0.weight')
    if weight is None:
        return default
    parameter_dim = weight.shape[1]
    print(f"✅ Detected parameter_dim: {parameter_dim}")
    return parameter_dim


def _detect_encoder_architecture(state_dict):
    """Infer the observation-encoder MLP layout from its 2-D weight tensors.

    Only 2-D weights are treated as Linear layers; 1-D weights (such as a
    normalization layer's) carry no in/out width and are skipped. Keys are
    grouped by their module prefix; ``set_encoder.observation_encoder.network``
    is used when present, otherwise the lexicographically first prefix.

    Returns:
        (layer_indices, input_dim, hidden_dims, output_dim). With fewer than
        two Linear layers, falls back to input_dim=1, hidden_dims=(64, 64),
        output_dim=64.
    """
    layers_by_prefix = {}
    for key, tensor in state_dict.items():
        match = _OBS_ENC_WEIGHT_RE.match(key)
        if match is None or getattr(tensor, 'ndim', None) != 2:
            continue
        layers = layers_by_prefix.setdefault(match.group('prefix'), {})
        # nn.Linear stores weight as (out_features, in_features).
        layers[int(match.group('idx'))] = (tensor.shape[1], tensor.shape[0])

    if _DEFAULT_OBS_ENC_PREFIX in layers_by_prefix:
        layers = layers_by_prefix[_DEFAULT_OBS_ENC_PREFIX]
    elif layers_by_prefix:
        layers = layers_by_prefix[min(layers_by_prefix)]
    else:
        layers = {}

    layer_indices = sorted(layers)
    # dims[i] = (index, in_dim, out_dim). Consecutive Linear layers form
    # input_dim -> hidden_0 -> ... -> hidden_k -> output_dim, so the hidden
    # widths are the outputs of every layer except the last.
    dims = [(idx, *layers[idx]) for idx in layer_indices]

    if len(dims) >= 2:
        hidden_dims = tuple(out_dim for _, _, out_dim in dims[:-1])
        input_dim = dims[0][1]
        output_dim = dims[-1][2]
    else:
        # Fallback
        hidden_dims = (64, 64)
        input_dim = 1
        output_dim = 64
    return layer_indices, input_dim, hidden_dims, output_dim


def _resolve_grid_resolution(config):
    """Return config['grid_resolution'] as a sequence, defaulting to (256,)."""
    grid_resolution = config.get('grid_resolution')
    if grid_resolution is None:
        return (256,)
    if not isinstance(grid_resolution, (tuple, list)):
        return (grid_resolution,)
    return grid_resolution


def _resolve_domain_bounds(config):
    """Return config['domain_bounds'] as a tuple of tuples, defaulting to ((-1.0, 1.0),).

    Both list and tuple containers are accepted; anything else (e.g. None)
    falls back to the default.
    """
    domain_bounds = config.get('domain_bounds', [[-1.0, 1.0]])
    if not isinstance(domain_bounds, (list, tuple)):
        domain_bounds = [[-1.0, 1.0]]
    return tuple(tuple(b) for b in domain_bounds)


def _last_history_value(history):
    """Return the last numeric entry of a loss-history list, or 0.0 if unavailable.

    Single-element tensors are converted with ``.item()``; any other
    non-numeric last entry yields 0.0.
    """
    if not isinstance(history, list) or not history:
        return 0.0
    last = history[-1]
    if isinstance(last, torch.Tensor) and last.numel() == 1:
        return last.item()
    return last if isinstance(last, (int, float)) else 0.0


def _print_key_summary(label, keys, max_full=5, head=3):
    """Print a count of *keys*; list all if <= max_full, else the first *head*."""
    if not keys:
        return
    print(f"   {label}: {len(keys)}")
    shown = keys if len(keys) <= max_full else keys[:head]
    for key in shown:
        print(f"     - {key}")
    if len(keys) > max_full:
        print(f"     ... and {len(keys) - head} more")


def load_trained_model(checkpoint_path, device='cpu'):
    """
    Load a trained PIConvNP model from a checkpoint.

    The parameter-encoder input width and the observation-encoder hidden
    dims are inferred from the checkpoint's state dict. The remaining
    hyper-parameters come from the stored config, with defaults for
    missing entries.

    Args:
        checkpoint_path: Path (str or Path) to a .pt checkpoint file.
        device: Device used as torch.load's map_location and to build the model on.

    Returns:
        model: Loaded PIConvNP model, switched to eval mode.
        info: Dict with 'epoch', 'step', 'train_loss', 'val_loss',
            'best_val_loss' and 'config'.

    Raises:
        FileNotFoundError: If checkpoint_path does not exist.
        ValueError: If the checkpoint has no (non-empty) 'model_config'/'config'.
        KeyError: If the checkpoint has no 'model_state_dict'.
        RuntimeError: Re-raised from load_state_dict (e.g. tensor size mismatch).
    """
    checkpoint_path = Path(checkpoint_path)
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    print(f"Loading model from: {checkpoint_path}")

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code. Only load checkpoints from trusted
    # sources.
    checkpoint = torch.load(
        checkpoint_path,
        map_location=device,
        weights_only=False
    )

    # 'model_config' takes precedence over the older 'config' key.
    config = checkpoint.get('model_config', checkpoint.get('config', {}))
    if not config:
        raise ValueError("No config found in checkpoint!")

    if 'model_state_dict' not in checkpoint:
        raise KeyError("No 'model_state_dict' found in checkpoint!")
    state_dict = checkpoint['model_state_dict']

    # ------------------------------------------------------------------
    # Infer architecture from the stored weights
    # ------------------------------------------------------------------
    parameter_dim = _detect_parameter_dim(state_dict)
    layer_indices, input_dim, hidden_dims, output_dim = _detect_encoder_architecture(state_dict)

    print("✅ Detected encoder architecture:")
    print(f"   {len(layer_indices)} Linear layers at indices {layer_indices}")
    print(f"   input_dim={input_dim}, hidden_dims={hidden_dims}, output_dim={output_dim}")

    obs_enc_dim = config.get('observation_encoder_dim', 64)
    latent_dim = config.get('latent_encoder_dim', 64)

    print("\nModel configuration:")
    print(f"  spatial_dim: {config.get('spatial_dim', 1)}")
    print(f"  observation_dim: {config.get('observation_dim', 1)}")
    print(f"  output_dim: {config.get('output_dim', 1)}")
    print(f"  observation_encoder_dim: {obs_enc_dim}")
    print(f"  latent_encoder_dim: {latent_dim}")
    print(f"  parameter_encoder_dim: {config.get('parameter_encoder_dim', 64)}")
    print(f"  encoder_hidden_dims: {hidden_dims}")
    print(f"  conv_num_blocks: {config.get('conv_num_blocks', 6)}")
    print(f"  conv_kernel_size: {config.get('conv_kernel_size', 3)}")
    print(f"  parameter_dim: {parameter_dim}")

    # ------------------------------------------------------------------
    # Create model
    # ------------------------------------------------------------------
    # Imported here (not at module top); keep it local so this module can be
    # imported without the models package.
    from models.pi_convnp import PIConvNP

    grid_resolution = _resolve_grid_resolution(config)
    domain_bounds = _resolve_domain_bounds(config)

    model = PIConvNP(
        # Problem specification
        spatial_dim=config.get('spatial_dim', 1),
        observation_dim=config.get('observation_dim', 1),
        output_dim=config.get('output_dim', 1),

        # Domain and grid
        grid_resolution=grid_resolution,
        domain_bounds=domain_bounds,

        # Architecture dimensions (conv width tied to the observation encoder width)
        latent_dim=latent_dim,
        observation_encoder_dim=obs_enc_dim,
        conv_channels=obs_enc_dim,

        # Encoder settings - use the hidden_dims detected from the checkpoint!
        encoder_hidden_dims=hidden_dims,

        # Kernel settings
        kernel_type=config.get('kernel_type', 'rbf'),
        kernel_lengthscale=config.get('kernel_lengthscale_init', 0.1),
        kernel_learnable=config.get('kernel_learnable', True),

        # Backbone settings
        num_conv_blocks=config.get('conv_num_blocks', 6),
        conv_kernel_size=config.get('conv_kernel_size', 3),
        use_unet=False,

        # Decoder settings
        min_sigma=config.get('sigma_min', 1e-4),

        # Parameter conditioning
        parameter_dim=parameter_dim,
        use_parameter_conditioning=True,

        # Other
        activation=config.get('conv_activation', 'swish'),
        device=str(device)
    ).to(device)

    # ------------------------------------------------------------------
    # Verify architecture (informational only; mismatches are not fatal)
    # ------------------------------------------------------------------
    print(f"\n{'='*80}")
    print("Architecture verification:")
    print(f"{'='*80}")

    if hasattr(model, 'parameter_encoder') and model.parameter_encoder is not None:
        actual_param_dim = model.parameter_encoder.network[0].weight.shape[1]
        print(f"✅ parameter_encoder.input_dim: {actual_param_dim} (expected: {parameter_dim})")

    if hasattr(model, 'set_encoder'):
        obs_enc = model.set_encoder.observation_encoder
        actual_layers = sum(isinstance(m, torch.nn.Linear) for m in obs_enc.network)
        print(f"✅ observation_encoder.num_layers: {actual_layers} (expected: {len(layer_indices)})")

    # ------------------------------------------------------------------
    # Load state dict (non-strict: key differences are reported, not fatal)
    # ------------------------------------------------------------------
    try:
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    except RuntimeError as e:
        print(f"\n❌ Error loading state_dict: {e}")
        raise

    if not missing_keys and not unexpected_keys:
        print("\n✅ Model loaded successfully (exact match)")
    else:
        print("\n⚠️  Model loaded with some differences:")
        _print_key_summary("Missing keys", missing_keys)
        _print_key_summary("Unexpected keys", unexpected_keys)

    model.eval()

    # ------------------------------------------------------------------
    # Extract training info
    # ------------------------------------------------------------------
    info = {
        'epoch': checkpoint.get('epoch', 0),
        'step': checkpoint.get('step', 0),
        'train_loss': _last_history_value(checkpoint.get('train_history', [])),
        'val_loss': _last_history_value(checkpoint.get('val_history', [])),
        'best_val_loss': checkpoint.get('best_val_loss', 0.0),
        'config': config
    }

    best_val_loss = info['best_val_loss']
    try:
        best_val_loss_str = f"{best_val_loss:.6f}"
    except (TypeError, ValueError):
        # e.g. None stored in the checkpoint; don't fail the load over a log line.
        best_val_loss_str = str(best_val_loss)

    print("\nTraining info:")
    print(f"  Epoch: {info['epoch']}")
    print(f"  Step: {info['step']}")
    print(f"  Best Val Loss: {best_val_loss_str}")
    print(f"{'='*80}\n")

    return model, info