import math
import sys
import utils
import torch
from typing import Iterable, Optional, Dict, List, Tuple
from loss import sequence_loss


def train_one_epoch(model: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    loss_scaler,
                    max_norm: float = 0,
                    model_ema: Optional[object] = None):
    """
    Run one mixed-precision training epoch over the three chains (H, L, Ag).

    For every batch the three embeddings are moved to ``device``, transposed
    on dims 1 and 2, and passed to ``model``, which must return one logits
    tensor per chain. A masked ``sequence_loss`` is computed per chain and
    the optimised objective is the mean of the three.

    Args:
        model: Network called as ``model(H, L, Ag)`` and returning three
            logits tensors; it is switched to train mode here.
        data_loader: Iterable of dict batches providing the keys
            ``{H,L,Ag}_embedding``, ``{H,L,Ag}_labels`` and ``{H,L,Ag}_mask``.
        optimizer: Optimizer stepped once per batch through ``loss_scaler``.
        device: Device every batch tensor is moved to.
        epoch: Epoch index, used only in the log header.
        loss_scaler: Object exposing ``scale``, ``unscale_``, ``step`` and
            ``update`` (the ``torch.cuda.amp.GradScaler`` interface).
        max_norm: Gradient-norm clipping threshold; clipping is skipped
            unless this is greater than zero.
        model_ema: Optional object whose ``update(model)`` is called after
            each optimizer step.

    Returns:
        Dict mapping each meter name to its ``global_avg`` after
        ``synchronize_between_processes()``.

    Raises:
        SystemExit: With status 1 as soon as a batch yields a non-finite loss.
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    # Register one meter per chain loss, each with its own SmoothedValue.
    for meter_name in ('loss_H', 'loss_L', 'loss_Ag'):
        metric_logger.add_meter(meter_name, utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    header = f"Epoch: [{epoch}]"
    log_interval = 10

    # Batch keys, always listed in H, L, Ag order so that the lists built
    # from them line up with the positional arguments of ``model``.
    embed_keys = ('H_embedding', 'L_embedding', 'Ag_embedding')
    label_keys = ('H_labels', 'L_labels', 'Ag_labels')
    mask_keys = ('H_mask', 'L_mask', 'Ag_mask')

    for batch in metric_logger.log_every(data_loader, log_interval, header):
        # Move the batch to the target device; masks are cast to bool.
        embeddings = [batch[key].to(device) for key in embed_keys]
        targets = [batch[key].to(device) for key in label_keys]
        valid = [batch[key].to(device).bool() for key in mask_keys]

        # Swap dims 1 and 2 before the forward pass (the layout the model
        # input expects, per the original author's note — TODO confirm
        # exact shape against the model definition).
        model_inputs = [emb.transpose(1, 2) for emb in embeddings]

        with torch.cuda.amp.autocast():
            h_logits, l_logits, ag_logits = model(*model_inputs)

            # One masked loss per chain, then their plain average.
            h_loss, l_loss, ag_loss = [
                sequence_loss(logits, target, mask)
                for logits, target, mask in zip(
                    (h_logits, l_logits, ag_logits), targets, valid)
            ]
            loss = (h_loss + l_loss + ag_loss) / 3.0

        # Abort the whole process before backprop if the loss is NaN/Inf.
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        # Scaled backward pass; gradients are unscaled before clipping so
        # that max_norm applies to the true gradient magnitudes.
        optimizer.zero_grad()
        loss_scaler.scale(loss).backward()
        if max_norm > 0:
            loss_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        loss_scaler.step(optimizer)
        loss_scaler.update()

        if model_ema is not None:
            model_ema.update(model)

        # Record every loss value and the current learning rate, one meter
        # update per call and in a fixed order.
        step_metrics = (
            ('loss', loss_value),
            ('loss_H', h_loss.item()),
            ('loss_L', l_loss.item()),
            ('loss_Ag', ag_loss.item()),
            ('lr', optimizer.param_groups[0]["lr"]),
        )
        for name, value in step_metrics:
            metric_logger.update(**{name: value})

    metric_logger.synchronize_between_processes()

    epoch_stats = {}
    for name, meter in metric_logger.meters.items():
        epoch_stats[name] = meter.global_avg
    return epoch_stats