"""
Unified Trainer - Supports all datasets and models

Features:
1. Supports onestep and pushforward training (correct implementation: one-step loss + stability loss)
2. Automatically adapts to different model output formats
3. Adaptive loss weighting strategy
4. Soft conservation loss (for baseline models only)
5. Complete logging and visualization
"""

import os
import time
import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional, List
from tqdm import tqdm
import joblib
import numpy as np
import matplotlib.pyplot as plt

from .config import TrainingConfig


def compute_conservation_loss(pred: torch.Tensor, target: torch.Tensor,
                               normalize: bool = True) -> torch.Tensor:
    """
    Soft conservation loss - penalizes drift of the summed (global) field.

    Both tensors are summed over their "spatial" axes and the absolute
    difference of the two totals is reduced to a scalar.

    Axis selection is driven by the rank of ``pred``:
        - rank 2 (``[batch, length]``): sum over axis 1
        - rank 3: when ``pred.shape[1] <= 4`` the layout is taken to be
          ``[batch, channels, length]`` (sum over axis 2); otherwise it is
          taken to be ``[batch, H, W]`` (sum over axes 1 and 2).
          This is a size-based heuristic, not a guarantee.
        - any other rank: sum over every axis from index 2 onwards

    Args:
        pred: Predicted field
        target: Target field, summed over the same axes as ``pred``
        normalize: If True return ``mean(diff**2 / N)`` where ``N`` is the
            number of summed elements; if False return the mean relative
            error ``mean(diff / (|target_sum| + 1e-8))``

    Returns:
        Conservation loss scalar
    """
    rank = len(pred.shape)
    if rank == 2:
        reduce_axes = (1,)
    elif rank == 3:
        # Small second axis -> treat it as a channel axis; otherwise as H.
        reduce_axes = (2,) if pred.shape[1] <= 4 else (1, 2)
    else:
        reduce_axes = tuple(range(2, rank))

    # Number of elements that contribute to each total.
    num_elements = 1
    for axis in reduce_axes:
        num_elements *= pred.shape[axis]

    target_total = target.sum(dim=reduce_axes)
    drift = (pred.sum(dim=reduce_axes) - target_total).abs()

    if not normalize:
        # Relative drift of the total (monitoring metric).
        return (drift / (torch.abs(target_total) + 1e-8)).mean()

    # A total drift of d spread over N elements is d/N per element, so the
    # summed squared error is (d/N)^2 * N = d^2 / N.
    return ((drift ** 2) / num_elements).mean()


def get_model_output_format(model) -> str:
    """
    Map a model instance to an output-format identifier using its class name.

    Matching precedence (first hit wins):
        1. class name contains 'FluxNet_D' or 'FNO_FluxD'  -> 'fluxnet_d'
        2. class name equals 'FluxNet_SW_2D'               -> 'fluxnet_sw'
        3. class name equals 'FluxNet_SW_Baseline'         -> 'sw_baseline'
        4. substring markers, checked in order:
           'CNN_SW_Proj' -> 'sw_baseline', 'FNO_SW_Proj' -> 'fno_sw',
           'FNO_SW' -> 'fno_sw', 'FNO' -> 'fno',
           'FluxNet' -> 'fluxnet_nl', 'CNN_Baseline' -> 'cnn'

    Returns:
        'fluxnet_d': (next, outflow, inflow) - FluxNet-D series (including FNO_FluxD)
        'fluxnet_sw': (next_state, h_delta, mx_delta, my_delta) - Shallow water equation FluxNet
        'fluxnet_nl': (next, delta) - FluxNet-N/L/P/U series
        'cnn': (next,) - CNN baseline
        'sw_baseline': (next,) - Shallow water equation baseline (including CNN_SW_Proj)
        'fno_sw': (next,) - FNO shallow water equation (including FNO_SW_Proj)
        'fno': (next,) - General FNO

    Raises:
        ValueError: if the class name matches none of the rules above
    """
    model_name = model.__class__.__name__

    # Dual-flux models take priority over every other rule.
    if 'FluxNet_D' in model_name or 'FNO_FluxD' in model_name:
        return 'fluxnet_d'

    # Names that must match exactly.
    exact_names = {
        'FluxNet_SW_2D': 'fluxnet_sw',
        'FluxNet_SW_Baseline': 'sw_baseline',
    }
    if model_name in exact_names:
        return exact_names[model_name]

    # Substring markers; more specific markers are listed before generic ones.
    ordered_markers = (
        ('CNN_SW_Proj', 'sw_baseline'),
        ('FNO_SW_Proj', 'fno_sw'),
        ('FNO_SW', 'fno_sw'),
        ('FNO', 'fno'),
        ('FluxNet', 'fluxnet_nl'),
        ('CNN_Baseline', 'cnn'),
    )
    for marker, output_format in ordered_markers:
        if marker in model_name:
            return output_format

    raise ValueError(f"Unknown model type: {model_name}. "
                     f"Supported: FluxNet_D_*, FluxNet_SW_2D, FluxNet_SW_Baseline, "
                     f"FluxNet_N_*, FluxNet_L_*, FluxNet_P_*, FluxNet_U_*, "
                     f"CNN_Baseline_*, FNO_SW, FNO_SW_Proj, CNN_SW_Proj, FNO_1D, FNO_FluxD_1D")


def is_baseline_model(model) -> bool:
    """
    Tell whether a model counts as a baseline (no structural conservation guarantee).

    The decision is made from the class name only: any name containing
    'FNO_FluxD' is never a baseline; otherwise a name containing 'Baseline',
    'CNN' or 'FNO' is.
    """
    class_name = model.__class__.__name__
    baseline_tags = ('Baseline', 'CNN', 'FNO')
    # FNO_FluxD carries a conservation guarantee, so it is excluded first.
    return 'FNO_FluxD' not in class_name and any(tag in class_name for tag in baseline_tags)


class AdaptiveLossWeights:
    """
    Adaptive loss weights manager

    In 'adaptive' mode the first usable set of loss values is recorded and
    each recorded term gets the weight ``geometric_mean / initial_value``,
    so that the weighted terms start out with comparable magnitudes.
    In any other mode the weights stay at 1.0 until set manually.
    """

    def __init__(self, loss_names: List[str], mode: str = 'adaptive'):
        """
        Args:
            loss_names: List of loss term names
            mode: 'adaptive' for adaptive weights, 'manual' for user-specified weights
        """
        self.loss_names = loss_names
        self.mode = mode
        self.initial_values = {}
        # Every known term starts with a neutral weight.
        self.weights = dict.fromkeys(loss_names, 1.0)
        self.initialized = False

    def initialize(self, losses: Dict[str, torch.Tensor]):
        """
        Derive the weights from the given loss values (adaptive mode only).

        Does nothing once initialization has succeeded or when the mode is
        not 'adaptive'. Terms whose value is not above 1e-10 are ignored;
        if no term qualifies the manager stays uninitialized so a later
        call can try again.
        """
        if self.mode != 'adaptive' or self.initialized:
            return

        # Record the starting magnitude of every known, non-vanishing term.
        for name in self.loss_names:
            if name not in losses:
                continue
            first_value = losses[name].item()
            if first_value > 1e-10:  # guards the division below
                self.initial_values[name] = first_value

        if not self.initial_values:
            return

        # Geometric mean of the recorded magnitudes (computed in log space).
        magnitudes = np.array(list(self.initial_values.values()))
        geo_mean = np.exp(np.mean(np.log(magnitudes + 1e-10)))

        # weight_i = geo_mean / initial_i
        for name, first_value in self.initial_values.items():
            self.weights[name] = geo_mean / (first_value + 1e-10)

        self.initialized = True
        print(f"[AdaptiveLossWeights] Initialized weights: {self.weights}")

    def set_manual_weights(self, weight_dict: Dict[str, float]):
        """
        Overwrite weights for known terms, then rescale all weights to sum to 1.

        Names in ``weight_dict`` that are not known loss terms are ignored.
        The rescaling is skipped when the total is not positive.
        """
        overrides = {name: weight for name, weight in weight_dict.items() if name in self.weights}
        self.weights.update(overrides)

        total = sum(self.weights.values())
        if not total > 0:
            return
        for name, value in list(self.weights.items()):
            self.weights[name] = value / total

    def get_weight(self, name: str) -> float:
        """Return the weight for ``name``; unknown names get 1.0."""
        return self.weights.get(name, 1.0)


class UnifiedTrainer:
    """
    Unified Trainer

    Supports training of all FluxNet models and baseline models

    Correct implementation of pushforward training:
    - For each batch, compute simultaneously:
      1. one-step loss: predict next step using original input
      2. stability loss: rollout N-1 steps (no gradient), then run step N with gradient and compute its loss
    - total_loss = one_step_loss + stability_loss
    """

    def __init__(
        self,
        model: nn.Module,
        config: TrainingConfig,
        train_loader,
        val_loader,
        result_dir: str,
        device: torch.device,
        dataset_type: str = 'spinodal_decomposition'
    ):
        self.model = model.to(device)
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.result_dir = result_dir
        self.device = device
        self.dataset_type = dataset_type

        # Create output directories
        os.makedirs(result_dir, exist_ok=True)
        self.mpdt_dir = os.path.join(result_dir, "training_visualization")
        os.makedirs(self.mpdt_dir, exist_ok=True)

        # Detect model output format
        self.output_format = get_model_output_format(model)
        self.is_baseline = is_baseline_model(model)
        print(f"Model output format: {self.output_format}, Baseline: {self.is_baseline}")

        # Set loss function
        if config.loss_criterion == 'MSE':
            self.criterion = nn.MSELoss()
        elif config.loss_criterion == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Unknown loss_criterion: {config.loss_criterion}")

        # Determine loss term names
        self.loss_names = self._get_loss_names()
        print(f"Loss terms: {self.loss_names}")

        # Initialize adaptive weight manager
        if config.loss_weight_mode == 'adaptive':
            self.loss_weights = AdaptiveLossWeights(self.loss_names, mode='adaptive')
        else:
            self.loss_weights = AdaptiveLossWeights(self.loss_names, mode='manual')
            if config.loss_weights:
                self.loss_weights.set_manual_weights(config.loss_weights)

        # Set optimizer
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Set learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=config.scheduler_factor,
            patience=config.scheduler_patience,
        )

        # Training records
        self.train_losses = []
        self.val_losses = []
        self.val_losses_dict = {name: [] for name in self.loss_names}
        self.val_cons_losses = []  # Conservation error (all models)
        self.best_losses = []
        self.optimizer_lrs = []

        self.best_loss = float('inf')
        self.vis_counter = 0

    def _get_loss_names(self) -> List[str]:
        """Determine loss term names based on model type"""
        names = ['p_loss']  # Prediction loss always exists

        # FluxNet-D DCL loss (only when dcl_weight > 0)
        if self.output_format == 'fluxnet_d' and self.config.dcl_weight > 0:
            names.append('dcl_loss')

        if self.config.use_pushforward:
            # Pushforward: add stability loss
            names.append('stability_loss')
            if self.output_format == 'fluxnet_d' and self.config.dcl_weight > 0:
                names.append('dcl_n_loss')  # DCL loss for pushforward
            if self.is_baseline and self.config.soft_conservation_weight > 0:
                names.append('cons_n_loss')  # Conservation loss for pushforward

        if self.is_baseline and self.config.soft_conservation_weight > 0:
            names.append('cons_loss')

        return names

    def _get_model_prediction(self, model_input: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Get model prediction and extra losses

        Returns:
            pred: Predicted next state
            extra_losses: Extra loss terms (e.g., dcl_loss)
        """
        outputs = self.model(model_input)
        extra_losses = {}

        if self.output_format == 'fluxnet_d':
            pred, outflow, inflow = outputs
            extra_losses['dcl'] = self.criterion(outflow, inflow)
        elif self.output_format == 'fluxnet_sw':
            pred = outputs[0]
        elif self.output_format == 'fluxnet_nl':
            pred, _ = outputs
        else:  # cnn, sw_baseline, fno_sw, fno
            pred = outputs[0]

        return pred, extra_losses

    def compute_onestep_loss(self, inputs: torch.Tensor, target: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute one-step loss

        Args:
            inputs: Model input
            target: Ground truth for next step
        """
        losses = {}

        pred, extra_losses = self._get_model_prediction(inputs)

        # Main prediction loss
        losses['p_loss'] = self.criterion(pred, target)

        # FluxNet-D DCL loss (computed only when dcl_weight > 0)
        if 'dcl' in extra_losses and self.config.dcl_weight > 0:
            losses['dcl_loss'] = extra_losses['dcl'] * self.config.dcl_weight

        # Baseline conservation loss (using normalized version)
        if self.is_baseline and self.config.soft_conservation_weight > 0:
            cons_loss = compute_conservation_loss(pred, target, normalize=True)
            losses['cons_loss'] = cons_loss * self.config.soft_conservation_weight

        return losses, pred

    def compute_stability_loss(self, inputs: torch.Tensor, targets: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Compute stability/pushforward loss

        Correct implementation:
        1. Rollout N-1 steps first (completely without gradients)
        2. Perform one forward pass with gradient at step N
        3. Compute and return loss at step N

        Args:
            inputs: Initial input [batch, channels, ...]
            targets: Multi-step targets [batch, K, ...]
        """
        losses = {}
        N = self.config.unroll_steps  # Number of rollout steps

        # Get current state
        if self.dataset_type in ['convection_diffusion', 'traffic_flow']:
            current_state = inputs[:, 0:1]  # 守恒场
            external_field = inputs[:, 1:]
            has_external = True
        elif self.dataset_type == 'shallow_water':
            current_state = inputs  # [batch, 3, H, W]
            has_external = False
        else:  # spinodal
            current_state = inputs
            has_external = False

        # ===== Step 1: Rollout N-1 steps (no gradient) =====
        with torch.no_grad():
            for k in range(N - 1):
                # Prepare input
                if has_external:
                    model_input = torch.cat([current_state, external_field], dim=1)
                else:
                    model_input = current_state

                # Forward pass
                pred, _ = self._get_model_prediction(model_input)
                current_state = pred

        # ===== Step 2: Forward pass with gradient at step N =====
        if has_external:
            model_input = torch.cat([current_state, external_field], dim=1)
        else:
            model_input = current_state

        pred, extra_losses = self._get_model_prediction(model_input)

        # Get target at step N
        if self.dataset_type == 'shallow_water':
            target_N = targets[:, N-1]  # [batch, 3, H, W]
        elif self.dataset_type in ['convection_diffusion', 'traffic_flow']:
            # targets shape: [batch, K, length]
            target_N = targets[:, N-1:N]  # [batch, 1, length]
        else:  # spinodal
            # targets shape: [batch, K, H, W]
            target_N = targets[:, N-1:N]  # [batch, 1, H, W]

        # Compute stability loss
        losses['stability_loss'] = self.criterion(pred, target_N)

        # FluxNet-D DCL loss (computed only when dcl_weight > 0)
        if 'dcl' in extra_losses and self.config.dcl_weight > 0:
            losses['dcl_n_loss'] = extra_losses['dcl'] * self.config.dcl_weight

        # Baseline conservation loss (using normalized version)
        if self.is_baseline and self.config.soft_conservation_weight > 0:
            cons_loss = compute_conservation_loss(pred, target_N, normalize=True)
            losses['cons_n_loss'] = cons_loss * self.config.soft_conservation_weight

        return losses

    def compute_total_loss(self, losses: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute weighted total loss

        Note: Some loss terms are already multiplied by user-specified weights during computation
        (e.g., cons_loss, dcl_loss), these terms will not be re-weighted by AdaptiveLossWeights.
        """
        # Loss terms already weighted during computation, not subject to adaptive weighting
        pre_weighted_losses = {'cons_loss', 'cons_n_loss', 'dcl_loss', 'dcl_n_loss'}

        # Filter out losses that need adaptive weighting
        losses_for_adaptive = {k: v for k, v in losses.items() if k not in pre_weighted_losses}

        # Initialize adaptive weights (only for non-pre-weighted losses)
        self.loss_weights.initialize(losses_for_adaptive)

        total = 0
        for name, loss in losses.items():
            if name in pre_weighted_losses:
                # Pre-weighted losses added directly
                total += loss
            else:
                # Other losses use adaptive/manual weights
                weight = self.loss_weights.get_weight(name)
                total += weight * loss

        return total

    def train_one_epoch(self, epoch: int) -> float:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}")
        for inputs, targets in pbar:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            all_losses = {}

            # ===== Part 1: One-step loss (always computed) =====
            if self.config.use_pushforward:
                # In pushforward mode, one-step target is the first step
                if self.dataset_type == 'shallow_water':
                    onestep_target = targets[:, 0]  # [batch, 3, H, W]
                elif self.dataset_type in ['convection_diffusion', 'traffic_flow']:
                    onestep_target = targets[:, 0:1]  # [batch, 1, length]
                else:  # spinodal
                    onestep_target = targets[:, 0:1]  # [batch, 1, H, W]
            else:
                onestep_target = targets

            onestep_losses, _ = self.compute_onestep_loss(inputs, onestep_target)
            all_losses.update(onestep_losses)

            # ===== Part 2: Stability loss (pushforward mode only) =====
            if self.config.use_pushforward:
                stability_losses = self.compute_stability_loss(inputs, targets)
                all_losses.update(stability_losses)

            # ===== Compute total loss and backpropagate =====
            loss = self.compute_total_loss(all_losses)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            pbar.set_postfix({'loss': f"{loss.item():.6f}"})

        return total_loss / num_batches

    def validate(self, epoch: int) -> Tuple[float, Dict]:
        """Validation"""
        self.model.eval()

        val_metrics = {name: 0.0 for name in self.loss_names}
        val_metrics['total_loss'] = 0.0
        val_metrics['cons_error'] = 0.0  # Track actual conservation error
        num_batches = 0

        vis_done = False

        with torch.no_grad():
            for inputs, targets in tqdm(self.val_loader, desc="Validation"):
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                all_losses = {}

                # One-step loss
                if self.config.use_pushforward:
                    if self.dataset_type == 'shallow_water':
                        onestep_target = targets[:, 0]
                    elif self.dataset_type in ['convection_diffusion', 'traffic_flow']:
                        onestep_target = targets[:, 0:1]
                    else:
                        onestep_target = targets[:, 0:1]
                else:
                    onestep_target = targets

                onestep_losses, pred = self.compute_onestep_loss(inputs, onestep_target)
                all_losses.update(onestep_losses)

                # Stability loss
                if self.config.use_pushforward:
                    stability_losses = self.compute_stability_loss(inputs, targets)
                    all_losses.update(stability_losses)

                total_loss = self.compute_total_loss(all_losses)

                # Accumulate
                val_metrics['total_loss'] += total_loss.item()
                for name in self.loss_names:
                    if name in all_losses:
                        val_metrics[name] += all_losses[name].item()

                # Track conservation error (using unnormalized version for monitoring)
                cons_err = compute_conservation_loss(pred, onestep_target, normalize=False).item()
                val_metrics['cons_error'] += cons_err

                num_batches += 1

                # Visualization
                if not vis_done and epoch % self.config.save_interval == 0:
                    self._save_visualization(inputs, onestep_target, pred, epoch)
                    vis_done = True

        # Average
        for key in val_metrics:
            val_metrics[key] /= max(num_batches, 1)

        return val_metrics['total_loss'], val_metrics

    def _save_visualization(self, inputs, target, pred, epoch):
        """Save training process visualization"""
        try:
            # Select visualization based on dataset type
            if self.dataset_type in ['spinodal_decomposition']:
                self._visualize_2d_field(pred, target, epoch)
            elif self.dataset_type in ['convection_diffusion', 'traffic_flow']:
                self._visualize_1d_field(inputs, pred, target, epoch)
            elif self.dataset_type == 'shallow_water':
                self._visualize_shallow_water(pred, target, epoch)

        except Exception as e:
            print(f"Visualization save failed: {e}")

    def _visualize_2d_field(self, pred, target, epoch):
        """2D field visualization (spinodal_decomposition)"""
        pred_np = pred[0, 0].cpu().numpy()
        target_np = target[0, 0].cpu().numpy()
        error_np = np.abs(pred_np - target_np)

        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        vmin = min(pred_np.min(), target_np.min())
        vmax = max(pred_np.max(), target_np.max())

        im0 = axes[0].imshow(pred_np, cmap='viridis', vmin=vmin, vmax=vmax)
        axes[0].set_title('Predicted')
        plt.colorbar(im0, ax=axes[0], shrink=0.8)

        im1 = axes[1].imshow(target_np, cmap='viridis', vmin=vmin, vmax=vmax)
        axes[1].set_title('Target')
        plt.colorbar(im1, ax=axes[1], shrink=0.8)

        im2 = axes[2].imshow(error_np, cmap='hot')
        axes[2].set_title(f'Error (MAE={error_np.mean():.4e})')
        plt.colorbar(im2, ax=axes[2], shrink=0.8)

        plt.suptitle(f'Epoch {epoch+1}')
        plt.tight_layout()
        plt.savefig(os.path.join(self.mpdt_dir, f'epoch_{epoch+1:04d}.png'), dpi=100)
        plt.close()

    def _visualize_1d_field(self, inputs, pred, target, epoch):
        """1D field visualization (convection_diffusion, traffic_flow)"""
        current = inputs[0, 0].cpu().numpy()

        # Handle different shapes
        if len(pred.shape) == 3:
            pred_np = pred[0, 0].cpu().numpy()
        else:
            pred_np = pred[0].cpu().numpy()

        if len(target.shape) == 3:
            target_np = target[0, 0].cpu().numpy()
        else:
            target_np = target[0].cpu().numpy()

        fig, axes = plt.subplots(1, 2, figsize=(12, 4))

        x = np.arange(len(pred_np))

        # Prediction vs Target
        axes[0].plot(x, current, 'g--', label='Current', alpha=0.7)
        axes[0].plot(x, target_np, 'b-', label='Target', linewidth=2)
        axes[0].plot(x, pred_np, 'r--', label='Predicted', linewidth=2)
        axes[0].legend()
        axes[0].set_xlabel('Position')
        axes[0].set_ylabel('Value')
        axes[0].set_title('Prediction vs Target')
        axes[0].grid(True, alpha=0.3)

        # Error
        error = np.abs(pred_np - target_np)
        axes[1].plot(x, error, 'g-', linewidth=2)
        axes[1].set_xlabel('Position')
        axes[1].set_ylabel('Absolute Error')
        axes[1].set_title(f'Error (MAE={error.mean():.4e})')
        axes[1].grid(True, alpha=0.3)

        plt.suptitle(f'Epoch {epoch+1}')
        plt.tight_layout()
        plt.savefig(os.path.join(self.mpdt_dir, f'epoch_{epoch+1:04d}.png'), dpi=100)
        plt.close()

    def _visualize_shallow_water(self, pred, target, epoch):
        """Shallow water equation three-channel visualization"""
        field_names = ['h', 'mx', 'my']

        fig, axes = plt.subplots(3, 3, figsize=(12, 12))

        for i, name in enumerate(field_names):
            pred_np = pred[0, i].cpu().numpy()
            target_np = target[0, i].cpu().numpy()
            error_np = np.abs(pred_np - target_np)

            vmin = min(pred_np.min(), target_np.min())
            vmax = max(pred_np.max(), target_np.max())

            im0 = axes[i, 0].imshow(pred_np, cmap='viridis', vmin=vmin, vmax=vmax)
            axes[i, 0].set_title(f'Pred {name}')
            plt.colorbar(im0, ax=axes[i, 0], shrink=0.8)

            im1 = axes[i, 1].imshow(target_np, cmap='viridis', vmin=vmin, vmax=vmax)
            axes[i, 1].set_title(f'Target {name}')
            plt.colorbar(im1, ax=axes[i, 1], shrink=0.8)

            im2 = axes[i, 2].imshow(error_np, cmap='hot')
            axes[i, 2].set_title(f'Error {name}')
            plt.colorbar(im2, ax=axes[i, 2], shrink=0.8)

        plt.suptitle(f'Epoch {epoch+1}')
        plt.tight_layout()
        plt.savefig(os.path.join(self.mpdt_dir, f'epoch_{epoch+1:04d}.png'), dpi=100)
        plt.close()

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Save checkpoint"""
        checkpoint = {
            'epoch': epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scheduler_state': self.scheduler.state_dict(),
            'best_loss': self.best_loss,
            'config': self.config
        }

        torch.save(checkpoint, os.path.join(self.result_dir, 'latest_checkpoint.pt'))

        if is_best:
            torch.save(checkpoint, os.path.join(self.result_dir, 'best_checkpoint.pt'))
            torch.save(self.model.state_dict(), os.path.join(self.result_dir, 'best_model.pt'))

    def train(self) -> Dict:
        """Complete training pipeline"""
        print(f"\n{'='*60}")
        print(f"Starting training - {self.config.num_epochs} epochs")
        print(f"Training mode: {'pushforward (one-step + stability)' if self.config.use_pushforward else 'onestep'}")
        print(f"Results saved to: {self.result_dir}")
        print(f"{'='*60}\n")

        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            # Training
            train_loss = self.train_one_epoch(epoch)
            self.train_losses.append(train_loss)

            # Validation
            val_loss, val_metrics = self.validate(epoch)
            self.val_losses.append(val_loss)

            # Record each loss term
            for name in self.loss_names:
                self.val_losses_dict[name].append(val_metrics.get(name, 0.0))

            # Record conservation error
            self.val_cons_losses.append(val_metrics['cons_error'])

            # Learning rate scheduling
            self.scheduler.step(val_loss)
            current_lr = self.optimizer.param_groups[0]['lr']
            self.optimizer_lrs.append(current_lr)

            # Save best model
            is_best = val_loss < self.best_loss
            if is_best:
                self.best_loss = val_loss
            self.best_losses.append(self.best_loss)

            # Save checkpoint
            if (epoch + 1) % self.config.save_interval == 0 or is_best:
                self.save_checkpoint(epoch, is_best)

            # Print progress
            print(f"\nEpoch [{epoch+1}/{self.config.num_epochs}]")
            print(f"  Train Loss: {train_loss:.6e}")
            print(f"  Val Loss: {val_loss:.6e}, Best: {self.best_loss:.6e}")
            print(f"  Conservation Error: {val_metrics['cons_error']:.6e}")
            print(f"  LR: {current_lr:.6e}")

            # Save loss curves
            self._save_loss_curves()

        # Training completed
        total_time = time.time() - start_time
        print(f"\nTraining completed! Total time: {total_time:.1f}s")
        print(f"Best validation loss: {self.best_loss:.6e}")

        self.save_checkpoint(self.config.num_epochs - 1, False)

        with open(os.path.join(self.result_dir, f"training_time_{int(total_time)}_s.txt"), 'w') as f:
            f.write(f"Total training time: {total_time:.1f} seconds\n")
            f.write(f"Best validation loss: {self.best_loss:.6e}\n")

        return {
            'best_loss': self.best_loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'total_time': total_time
        }

    def _save_loss_curves(self):
        """Save loss curves"""
        # Save data
        loss_data = {
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'val_losses_dict': self.val_losses_dict,
            'val_cons_losses': self.val_cons_losses,
            'best_losses': self.best_losses,
            'optimizer_lrs': self.optimizer_lrs,
            'output_format': self.output_format,
            'is_baseline': self.is_baseline,
            'use_pushforward': self.config.use_pushforward
        }
        joblib.dump(loss_data, os.path.join(self.result_dir, "loss_curve.pkl"))

        # Plot loss curves
        self._plot_loss_curves(loss_data)

    def _plot_loss_curves(self, loss_data):
        """
        Plot appropriate loss curves based on model type

        Only plot losses actually used for backpropagation:
        - FluxNet models: don't plot conservation loss curve (conservation is guaranteed by architecture)
        - Baseline models: plot conservation loss (if soft conservation is used)
        """
        epochs = range(1, len(loss_data['train_losses']) + 1)

        # Determine which losses to plot
        losses_to_plot = []
        losses_to_plot.append(('Train Loss', loss_data['train_losses'], 'blue'))
        losses_to_plot.append(('Val Loss', loss_data['val_losses'], 'orange'))

        # Add losses based on model type
        # DCL loss (dual consistency loss, formerly io_loss)
        if self.output_format == 'fluxnet_d' and 'dcl_loss' in loss_data['val_losses_dict']:
            losses_to_plot.append(('DCL Loss', loss_data['val_losses_dict']['dcl_loss'], 'green'))

        if 'p_loss' in loss_data['val_losses_dict']:
            losses_to_plot.append(('Pred Loss', loss_data['val_losses_dict']['p_loss'], 'red'))

        # Only plot conservation error curve for baseline models (FluxNet conservation is guaranteed by architecture)
        if self.is_baseline:
            losses_to_plot.append(('Conservation Error', loss_data['val_cons_losses'], 'purple'))

        # Add stability loss for pushforward
        if self.config.use_pushforward and 'stability_loss' in loss_data['val_losses_dict']:
            losses_to_plot.append(('Stability Loss', loss_data['val_losses_dict']['stability_loss'], 'cyan'))

        # Plot
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

        # Main loss plot
        for name, data, color in losses_to_plot[:4]:  # First 4 in main plot
            ax1.plot(epochs, data, label=name, color=color if isinstance(color, str) else None, linewidth=2)

        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_yscale('log')
        ax1.legend(loc='upper right')
        ax1.grid(True, alpha=0.3)
        ax1.set_title(f'Training Curves - {self.output_format}')

        # Learning rate
        ax1_lr = ax1.twinx()
        ax1_lr.plot(epochs, loss_data['optimizer_lrs'], 'k--', alpha=0.5, label='LR')
        ax1_lr.set_ylabel('Learning Rate', color='gray')
        ax1_lr.tick_params(axis='y', labelcolor='gray')

        # Secondary loss plot (if more losses exist)
        if len(losses_to_plot) > 4:
            for name, data, color in losses_to_plot[4:]:
                ax2.plot(epochs, data, label=name, linewidth=2)
            ax2.set_xlabel('Epoch')
            ax2.set_ylabel('Loss')
            ax2.set_yscale('log')
            ax2.legend(loc='upper right')
            ax2.grid(True, alpha=0.3)
            ax2.set_title('Additional Loss Components')
        else:
            ax2.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(self.result_dir, 'loss_curve.png'), dpi=150)
        plt.close()


def train_model(
    model: nn.Module,
    dataset_type: str,
    train_folder: str,
    val_folder: str,
    result_dir: str,
    config: TrainingConfig,
    device: torch.device = None,
    num_workers: int = 4
) -> Dict:
    """
    Unified training interface

    Builds the training and validation data loaders, creates a
    ``UnifiedTrainer`` and runs it.

    Args:
        model: Network to train
        dataset_type: Dataset identifier forwarded to the loader factory and the trainer
        train_folder: Folder with the training data
        val_folder: Folder with the validation data
        result_dir: Output folder for checkpoints and plots
        config: Training configuration
        device: Target device; when None, CUDA is used if available, else CPU
        num_workers: Worker count forwarded to the loader factory

    Returns:
        The result dict produced by ``UnifiedTrainer.train``
    """
    from .dataloader import create_data_loader

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Settings shared by both loaders; only the folder and shuffling differ
    shared_loader_args = dict(
        dataset_type=dataset_type,
        batch_size=config.batch_size,
        ndt=config.ndt,
        num_workers=num_workers,
        training_mode='pushforward' if config.use_pushforward else 'onestep',
        unroll_steps=config.unroll_steps,
    )
    train_loader = create_data_loader(folder_path=train_folder, shuffle=True, **shared_loader_args)
    val_loader = create_data_loader(folder_path=val_folder, shuffle=False, **shared_loader_args)

    return UnifiedTrainer(
        model=model,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
        result_dir=result_dir,
        device=device,
        dataset_type=dataset_type,
    ).train()
