"""
Main training script for PI-ConvNP on 1D Nonlinear Poisson equation.
"""

import argparse
import torch
import json
from pathlib import Path

from training.trainer import Trainer
from training.losses import build_loss_function
from training.optimizer import build_optimizer, build_scheduler
from evaluation.metrics import evaluate_model, print_metrics, compute_multiple_confidence_levels
from utils.visualization import plot_predictions, plot_uncertainty, plot_training_curves
from utils.experiment import (
    get_device, set_seed, create_config, setup_directories,
    handle_cache_operations, create_dataloaders, create_model,
    save_config, print_device_info
)
from utils.model_loader import save_inference_model


def parse_args():
    """Parse command line arguments.

    Every pre-existing option keeps its name, type and default. The four
    switches that default to ``True`` (``--precompute``,
    ``--use_parameter_conditioning``, ``--use_tqdm`` and ``--save_plots``)
    additionally get a ``--no_<name>`` counterpart: a ``store_true`` flag whose
    default is already ``True`` can otherwise never be switched off from the
    command line.

    Returns:
        argparse.Namespace: The parsed arguments (read from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(description='Train PI-ConvNP on 1D Nonlinear Poisson equation')

    def add_switch(name):
        """Register ``--<name>`` (True) and ``--no_<name>`` (False); default is True."""
        parser.add_argument(f'--{name}', dest=name, action='store_true', default=True)
        parser.add_argument(
            f'--no_{name}', dest=name, action='store_false', default=True,
            help=f'disable {name} (enabled by default)'
        )

    # Data parameters
    parser.add_argument('--num_train', type=int, default=1000)
    parser.add_argument('--num_val', type=int, default=200)
    parser.add_argument('--num_test', type=int, default=200)
    parser.add_argument('--n_grid_points', type=int, default=128)
    parser.add_argument('--n_chebyshev', type=int, default=5)
    parser.add_argument('--n_context_min', type=int, default=10)
    parser.add_argument('--n_context_max', type=int, default=50)
    parser.add_argument('--n_target_min', type=int, default=10)
    parser.add_argument('--n_target_max', type=int, default=100)
    parser.add_argument('--x_range_min', type=float, default=-1.0)
    parser.add_argument('--x_range_max', type=float, default=1.0)
    parser.add_argument('--w_range_min', type=float, default=0.5)
    parser.add_argument('--w_range_max', type=float, default=2.0)
    parser.add_argument('--noise_std', type=float, default=0.01)
    add_switch('precompute')

    # Cache management
    parser.add_argument('--force_regenerate', action='store_true')
    parser.add_argument('--show_cache_info', action='store_true')
    parser.add_argument('--clear_cache', action='store_true')
    parser.add_argument('--force_clear_cache', action='store_true')

    # Model parameters (None means "not given on the command line")
    parser.add_argument('--latent_dim', type=int, default=None)
    parser.add_argument('--conv_channels', type=int, default=None)
    parser.add_argument('--num_conv_blocks', type=int, default=None)
    parser.add_argument('--grid_resolution', type=int, default=None)
    add_switch('use_parameter_conditioning')

    # Training parameters
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=None)
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('--gradient_clip', type=float, default=None)
    parser.add_argument('--num_iterations', type=int, default=None)

    # Loss parameters
    parser.add_argument('--loss_type', type=str, default='bsnp', choices=['bsnp', 'mse', 'nll'])
    parser.add_argument('--lambda_data', type=float, default=None)
    parser.add_argument('--lambda_physics', type=float, default=None)
    parser.add_argument('--lambda_physics_final', type=float, default=None)
    parser.add_argument('--physics_warmup_epochs', type=int, default=None)

    # Optimizer and scheduler
    parser.add_argument('--optimizer', type=str, default=None)
    parser.add_argument('--scheduler', type=str, default=None)
    parser.add_argument('--scheduler_patience', type=int, default=10)
    parser.add_argument('--scheduler_step_size', type=int, default=30)
    parser.add_argument('--scheduler_gamma', type=float, default=0.1)

    # Output and logging
    parser.add_argument('--output_dir', type=str, default='./outputs')
    parser.add_argument('--checkpoint_dir', type=str, default=None)
    parser.add_argument('--log_interval', type=int, default=None)
    add_switch('use_tqdm')

    # Evaluation
    parser.add_argument('--confidence_level', type=float, default=None)
    parser.add_argument('--eval_only', action='store_true')
    parser.add_argument('--checkpoint_path', type=str, default=None)

    # Visualization
    parser.add_argument('--num_vis_samples', type=int, default=None)
    add_switch('save_plots')

    # Other
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--num_workers', type=int, default=4)

    return parser.parse_args()


def train_model(args, config, model, train_loader, val_loader, device, checkpoint_dir):
    """Build the loss, optimizer, scheduler and trainer, then run training.

    Command-line values in ``args`` take precedence over ``config`` for the
    loss weights, the physics warmup length and the scheduler type; the
    optimizer settings, gradient clipping and log interval are read from
    ``config``.

    Returns:
        The ``Trainer`` instance after ``trainer.train(...)`` has completed.
    """
    # Loss weights: explicit CLI value wins, otherwise fall back.
    if args.lambda_data is None:
        data_weight = 1.0
    else:
        data_weight = args.lambda_data

    if args.lambda_physics is None:
        physics_weight_start = config.physics_weight
    else:
        physics_weight_start = args.lambda_physics

    if args.lambda_physics_final is None:
        physics_weight_end = config.physics_weight_final
    else:
        physics_weight_end = args.lambda_physics_final

    # Warmup given in epochs is converted to steps via the number of batches.
    if args.physics_warmup_epochs is None:
        warmup_steps = config.physics_warmup_iterations
    else:
        warmup_steps = args.physics_warmup_epochs * len(train_loader)

    criterion = build_loss_function(
        args.loss_type,
        lambda_data=data_weight,
        lambda_physics=physics_weight_start,
        lambda_reg=1e-6,
        physics_warmup_steps=warmup_steps,
        physics_weight_final=physics_weight_end
    )

    optim = build_optimizer(
        model,
        config.optimizer,
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Scheduler-specific keyword arguments; unknown types get none.
    sched_name = args.scheduler if args.scheduler is not None else config.scheduler_type
    if sched_name == 'cosine':
        sched_options = {'T_max': args.num_epochs, 'eta_min': config.lr_min}
    elif sched_name == 'plateau':
        sched_options = {'patience': args.scheduler_patience, 'verbose': True}
    elif sched_name == 'step':
        sched_options = {'step_size': args.scheduler_step_size, 'gamma': args.scheduler_gamma}
    else:
        sched_options = {}

    lr_scheduler = build_scheduler(optim, sched_name, **sched_options)

    trainer = Trainer(
        model=model,
        loss_fn=criterion,
        optimizer=optim,
        device=device,
        scheduler=lr_scheduler,
        gradient_clip=config.grad_clip_norm,
        log_interval=config.log_interval,
        checkpoint_dir=str(checkpoint_dir),
        use_tqdm=args.use_tqdm
    )

    # Run summary banner.
    rule = '=' * 80
    print(f"\n{rule}\n🚀 Starting Training\n{rule}")
    print(f"Batch size: {config.batch_size} | Epochs: {args.num_epochs}")
    print(f"LR: {config.learning_rate} → {config.lr_min} | Loss: {args.loss_type}")
    print(f"Gradient clip: {config.grad_clip_norm}")
    if args.loss_type == 'bsnp':
        print(f"Physics weight: {physics_weight_start:.2e} → {physics_weight_end:.2e}")
        print(f"Warmup steps: {warmup_steps}")
    print(f"{rule}\n")

    trainer.train(
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=args.num_epochs
    )

    return trainer


def _collect_calibration_tensors(args, model, test_loader, device):
    """Run the model over ``test_loader`` and concatenate predictions and targets.

    Returns:
        tuple: ``(all_mean, all_sigma, all_target)`` tensors concatenated along dim 0.
    """
    model.eval()
    means, sigmas, targets = [], [], []

    with torch.no_grad():
        for batch in test_loader:
            x_context = batch['x_context'].to(device)
            y_context = batch['y_context'].to(device)
            x_target = batch['x_target'].to(device)
            y_target = batch['y_target'].to(device)

            lambda_params = None
            if args.use_parameter_conditioning and 'lambda_params' in batch:
                lambda_params = batch['lambda_params'].to(device)

            mean, sigma = model(x_context, y_context, x_target, lambda_params)

            if 'target_mask' in batch:
                # Keep only the masked-in target positions of each sample. The
                # indexing treats predictions as (batch, output_dim, n_target).
                # NOTE(review): unlike the unmasked branch below, a 2-D y_target
                # is not given a trailing dim here — confirm the shapes that
                # compute_multiple_confidence_levels expects.
                target_mask = batch['target_mask'].to(device)
                for b, mask_b in enumerate(target_mask):
                    means.append(mean[b, :, mask_b].transpose(0, 1))
                    sigmas.append(sigma[b, :, mask_b].transpose(0, 1))
                    targets.append(y_target[b, mask_b])
            else:
                # Bring everything to (batch, n_target, output_dim).
                if mean.dim() == 3:
                    mean = mean.transpose(1, 2)
                if sigma.dim() == 3:
                    sigma = sigma.transpose(1, 2)
                if y_target.dim() == 2:
                    y_target = y_target.unsqueeze(-1)

                means.append(mean)
                sigmas.append(sigma)
                targets.append(y_target)

    return torch.cat(means, dim=0), torch.cat(sigmas, dim=0), torch.cat(targets, dim=0)


def _save_sample_plots(args, model, dataset, device, vis_dir, num_vis):
    """Plot predictions and uncertainty for the first ``num_vis`` dataset samples.

    Plotting is best-effort: a failure for one sample is reported as a warning
    and the remaining samples are still processed.
    """
    model.eval()
    with torch.no_grad():
        for idx in range(num_vis):
            sample = dataset[idx]

            # Move to device and add batch dimension
            x_context = sample['x_context'].unsqueeze(0).to(device)
            y_context = sample['y_context'].unsqueeze(0).to(device)
            x_target = sample['x_target'].unsqueeze(0).to(device)
            y_target = sample['y_target'].unsqueeze(0).to(device)

            lambda_params = None
            if args.use_parameter_conditioning and 'lambda_params' in sample:
                lambda_params = sample['lambda_params'].unsqueeze(0).to(device)

            mean, sigma = model(x_context, y_context, x_target, lambda_params)

            try:
                # Convert once and reuse for both figures; kept inside the
                # try so a shape problem is reported as a warning, as before.
                x_target_cpu = x_target.cpu().squeeze(0)
                mean_cpu = mean.cpu().squeeze(0).transpose(0, 1)
                sigma_cpu = sigma.cpu().squeeze(0).transpose(0, 1)

                plot_predictions(
                    x_context.cpu().squeeze(0),
                    y_context.cpu().squeeze(0),
                    x_target_cpu,
                    y_target.cpu().squeeze(0),
                    mean_cpu,
                    sigma_cpu,
                    save_path=str(vis_dir / f'prediction_sample_{idx}.png'),
                    title=f'Prediction Sample {idx}'
                )

                plot_uncertainty(
                    x_target_cpu,
                    mean_cpu,
                    sigma_cpu,
                    save_path=str(vis_dir / f'uncertainty_sample_{idx}.png'),
                    title=f'Uncertainty Sample {idx}'
                )
            except Exception as e:
                print(f"Warning: Failed to create visualization for sample {idx}: {e}")


def evaluate_and_visualize(args, config, model, test_loader, device, save_dir=None):
    """Evaluate the model on the test set, report calibration, and optionally plot.

    Args:
        args: Parsed command-line arguments; reads ``use_parameter_conditioning``,
            ``save_plots`` and ``num_vis_samples``.
        config: Experiment configuration; reads ``confidence_level``.
        model: Model called as ``model(x_context, y_context, x_target, lambda_params)``
            and returning ``(mean, sigma)``.
        test_loader: Loader yielding dict batches; its ``dataset`` is indexed
            directly for the visualization samples.
        device: Device that batch tensors are moved to.
        save_dir: Directory (path object supporting ``/``) that receives a
            ``visualizations`` sub-directory. No plots are written when this is
            falsy or ``args.save_plots`` is off.

    Returns:
        The metrics returned by ``evaluate_model``.
    """
    print(f"\n{'='*80}\n📊 Evaluating Model on Test Set\n{'='*80}")

    test_metrics = evaluate_model(
        model,
        test_loader,
        device,
        compute_uncertainty=True,
        confidence_level=config.confidence_level
    )
    print_metrics(test_metrics, title="Test Set Evaluation")

    # Calibration analysis
    print(f"\n{'='*80}\n🎯 Detailed Calibration Analysis\n{'='*80}")

    all_mean, all_sigma, all_target = _collect_calibration_tensors(args, model, test_loader, device)

    multi_level_ecp = compute_multiple_confidence_levels(
        all_mean, all_sigma, all_target,
        levels=[0.80, 0.90, 0.95, 0.99]
    )

    print("\n📈 Empirical Coverage Probability:")
    for key, value in multi_level_ecp.items():
        # Keys are parsed as '<prefix>_<percent>', e.g. 'ecp_90'.
        level = int(key.split('_')[1])
        print(f"  {level}% CI: ECP = {value:.4f} (Expected: {level/100:.2f})")

    ecp_90 = multi_level_ecp['ecp_90']
    if abs(ecp_90 - 0.90) < 0.05:
        print(f"\n✅ Well-calibrated at 90% (ECP = {ecp_90:.4f})")
    else:
        # Coverage below nominal means the intervals are too narrow (over-confident).
        print(f"\n⚠️  {'Over' if ecp_90 < 0.90 else 'Under'}-confident (ECP = {ecp_90:.4f})")

    print(f"{'='*80}\n")

    # Visualizations
    if save_dir and args.save_plots:
        vis_dir = save_dir / 'visualizations'
        vis_dir.mkdir(parents=True, exist_ok=True)
        print(f"📸 Creating visualizations in {vis_dir}...")

        # Only an unset value (None) falls back to 5, so an explicit
        # --num_vis_samples 0 is honoured instead of silently becoming 5.
        requested = args.num_vis_samples if args.num_vis_samples is not None else 5
        num_vis = min(requested, len(test_loader.dataset))

        _save_sample_plots(args, model, test_loader.dataset, device, vis_dir, num_vis)

        print(f"✅ Visualizations saved to {vis_dir}")

    return test_metrics


def _restore_model_weights(model, checkpoint_path, device):
    """Load the ``model_state_dict`` entry of a checkpoint file into ``model``."""
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON."""
    with open(path, 'w') as f:
        json.dump(payload, f, indent=4)


def main():
    """Run the full pipeline: setup, then either evaluation-only or train + evaluate.

    Raises:
        ValueError: If ``--eval_only`` is given without ``--checkpoint_path``.
    """
    args = parse_args()

    # Handle cache operations
    handle_cache_operations(args)

    # Fail fast: reject an unusable eval-only invocation before creating
    # directories, building datasets or constructing the model.
    if args.eval_only and not args.checkpoint_path:
        raise ValueError("Must provide --checkpoint_path for eval_only mode")

    # Setup
    config = create_config(args)
    set_seed(config.seed)
    device = get_device()
    config.device = device
    print_device_info(device)

    output_dir, checkpoint_dir = setup_directories(args)
    print(f"\n📁 Output: {output_dir}\n📁 Checkpoints: {checkpoint_dir}")

    save_config(output_dir, config, args)

    # Create data and model
    train_loader, val_loader, test_loader, parameter_dim = create_dataloaders(args, config, device)
    model = create_model(config, parameter_dim, args.use_parameter_conditioning, device)

    # Evaluation only
    if args.eval_only:
        print(f"\n📥 Loading checkpoint: {args.checkpoint_path}")
        _restore_model_weights(model, args.checkpoint_path, device)

        test_metrics = evaluate_and_visualize(args, config, model, test_loader, device, output_dir)
        _write_json(output_dir / 'test_metrics.json', test_metrics)

        print(f"\n✅ Evaluation complete! Results saved to {output_dir}")
        return

    # Training
    trainer = train_model(args, config, model, train_loader, val_loader, device, checkpoint_dir)

    # Plot training curves (best-effort: a plotting failure must not abort the run)
    if args.save_plots and len(trainer.train_history) > 0:
        try:
            train_losses = [h['total'] for h in trainer.train_history]
            val_losses = [h['total'] for h in trainer.val_history] if trainer.val_history else None

            plot_training_curves(
                {'train_loss': train_losses, 'val_loss': val_losses},
                save_path=str(output_dir / 'training_curves.png'),
                title='PI-ConvNP Training - Nonlinear Poisson'
            )
            print(f"📈 Training curves saved to {output_dir / 'training_curves.png'}")
        except Exception as e:
            print(f"Warning: Failed to plot training curves: {e}")

    # Load best model for final evaluation; otherwise keep the weights from the last step
    best_model_path = checkpoint_dir / 'best_model.pt'
    if best_model_path.exists():
        print(f"\n📥 Loading best model from {best_model_path}")
        _restore_model_weights(model, best_model_path, device)
    else:
        print("\n⚠️  Best model checkpoint not found, using final model")

    # Final evaluation
    test_metrics = evaluate_and_visualize(args, config, model, test_loader, device, output_dir)

    # Save the evaluated model in the inference format
    final_model_path = checkpoint_dir / 'trained_model.pt'
    save_inference_model(
        model=model,
        save_path=str(final_model_path),
        test_metrics=test_metrics
    )

    # Save results
    results = {
        'test_metrics': test_metrics,
        'best_val_loss': trainer.best_val_loss,
        'final_epoch': trainer.epoch,
        'total_steps': trainer.step,
        'model_path': str(final_model_path),
        'best_model_path': str(best_model_path)
    }
    _write_json(output_dir / 'results.json', results)

    print(f"\n{'='*80}\n✅ Training Complete!\n{'='*80}")
    print(f"📁 Results: {output_dir}")
    print("\n📊 Final Test Metrics:")
    print(f"  MNSE: {test_metrics['mnse']:.6f}")
    print(f"  ECP:  {test_metrics['ecp']:.4f}")
    print(f"  RMSE: {test_metrics['rmse']:.6f}")
    print(f"  NLL:  {test_metrics['nll']:.6f}")
    print(f"\n🏆 Best Validation Loss: {trainer.best_val_loss:.6f}")
    print(f"📈 Total Training Steps: {trainer.step}")
    print("\n💾 Saved Models:")
    print(f"  ├─ Best: {best_model_path}")
    print(f"  └─ Final: {final_model_path}")
    print(f"{'='*80}\n")


# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()