from __future__ import annotations
import os, math, glob, random, sys, json, yaml, hashlib, traceback
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Tuple, List, Dict, Optional, Any
from collections import defaultdict
import logging
from datetime import datetime
import time
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchaudio
from tqdm import tqdm
import csv
import warnings
# Silence torchaudio-related noise during audio decoding: UserWarnings from
# modules matching "torchaudio._backend", plus any warning whose message
# mentions StreamingMediaDecoder or torchaudio.load_with_torchcodec.
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio._backend")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
warnings.filterwarnings("ignore", message=".*torchaudio.load_with_torchcodec.*")

# ============================================================================
# Configuration
# ============================================================================

@dataclass
class EATMAEConfig:
    """All hyper-parameters, paths and switches for EATMAE pre-training.

    Values are checked in ``__post_init__`` (see ``validate``), so an invalid
    combination fails at construction time. ``save``/``load`` provide a YAML
    round-trip of every field.
    """
    # Audio settings
    dataset: str = 'Hybrid'
    sr: int = 48_000
    sample_len: int = 48_000          # clip length in samples
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 376
    n_mels: int = 128
    top_db: int = 80

    # Model architecture
    base_channels: int = 32
    patch_size: Tuple[int, int] = (16, 16)
    embed_dim: int = 384
    depth: int = 8
    nhead: int = 16
    decoder_dim: int = 256
    mask_ratio: float = 0.4
    dropout: float = 0.0
    utter_loss_weight: float = 0.2

    # Patch Embedding settings
    patch_embed_use_bn: bool = False
    patch_embed_use_activation: bool = False
    patch_embed_activation: str = 'none'
    patch_embed_bias: bool = False

    # Training settings
    lr: float = 5e-5
    wd: float = 1e-5
    batch_size: int = 128
    num_workers: int = 32
    epochs: int = 10
    seed: Optional[int] = 42
    device: str = field(default_factory=lambda: "cuda:1" if torch.cuda.is_available() else "cpu")

    # Additional training settings
    use_amp: bool = False
    gradient_clip: float = 0.5
    log_interval: int = 10
    save_best_only: bool = False

    # Logging settings
    experiment_name: str = "EATMAE"
    log_dir: str = "./IMPACT_Models/logs"
    checkpoint_dir: str = "./IMPACT_Models"
    tensorboard: bool = True

    # Data settings
    # NOTE(review): this f-string is evaluated once, when the class body runs,
    # against the class-level ``dataset`` default above. Passing a different
    # ``dataset=`` to the constructor does NOT change this default path.
    train_data_path: str = f"./Datasets/{dataset}_Pretrain"
    clips_per_folder: Optional[int] = None
    use_weighted_sampling: bool = False

    # Advanced settings
    teacher_momentum: float = 0.9999
    frame_loss_type: str = "mse"
    utter_loss_type: str = "mse"  # "mse", "huber", "l1"

    def __post_init__(self):
        """Validate the configuration, then derive dependent settings."""
        self.validate()
        # A bias right before BatchNorm is redundant, so it is turned off.
        if self.patch_embed_use_bn:
            self.patch_embed_bias = False

    def validate(self):
        """Check value ranges and patch-grid compatibility.

        Raises:
            ValueError: on the first violated constraint. Explicit raises are
                used instead of ``assert`` so validation still runs under
                ``python -O``.
        """
        def _require(condition: bool, message: str) -> None:
            if not condition:
                raise ValueError(message)

        _require(0 < self.mask_ratio < 1, "mask_ratio must be between 0 and 1")
        _require(self.embed_dim % self.nhead == 0, "embed_dim must be divisible by nhead")
        _require(self.sr > 0, "Sample rate must be positive")
        _require(self.batch_size > 0, "Batch size must be positive")
        _require(self.gradient_clip > 0, "gradient_clip must be positive")

        # Patch-grid compatibility. The CNN stem halves both spectrogram axes,
        # and the time axis is assumed to be 128 frames (the crop applied by
        # SpecConverter).
        expected_f = self.n_mels // 2
        expected_t = 128 // 2
        _require(expected_f % self.patch_size[0] == 0,
                 f"Frequency dimension {expected_f} not divisible by patch_size[0]={self.patch_size[0]}")
        _require(expected_t % self.patch_size[1] == 0,
                 f"Time dimension {expected_t} not divisible by patch_size[1]={self.patch_size[1]}")

    def save(self, path: str):
        """Write every field to a YAML file at *path*.

        Tuples such as ``patch_size`` are written as plain YAML lists.
        ``yaml.dump`` would otherwise tag them ``!!python/tuple``, which the
        safe loader used by ``load`` refuses to read.
        """
        data = {key: list(value) if isinstance(value, tuple) else value
                for key, value in asdict(self).items()}
        with open(path, 'w', encoding='utf-8') as f:
            yaml.safe_dump(data, f, default_flow_style=False)

    @classmethod
    def load(cls, path: str):
        """Build a config from a YAML file written by ``save``.

        Reading stays safe (no arbitrary Python objects are constructed).
        Files written by older versions of ``save``, which contain
        ``!!python/tuple`` tags, are still accepted. ``patch_size`` is turned
        back into a tuple.
        """
        class _TupleSafeLoader(yaml.SafeLoader):
            """SafeLoader that additionally understands ``!!python/tuple``."""

        _TupleSafeLoader.add_constructor(
            "tag:yaml.org,2002:python/tuple",
            lambda loader, node: tuple(loader.construct_sequence(node)),
        )
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = yaml.load(f, Loader=_TupleSafeLoader)
        if config_dict.get('patch_size') is not None:
            config_dict['patch_size'] = tuple(config_dict['patch_size'])
        return cls(**config_dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return all fields as a plain dictionary (via ``dataclasses.asdict``)."""
        return asdict(self)

# ============================================================================
# Utilities
# ============================================================================

def setup_logging(log_dir: str):
    """Configure root logging to write to ``<log_dir>/training.log`` and stderr.

    The directory (and any missing parents) is created first. The
    configuration goes through ``logging.basicConfig``, so it has no effect
    if the root logger already has handlers.

    Returns:
        The logger named after this module.
    """
    target_dir = Path(log_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    line_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    sinks = [
        logging.FileHandler(target_dir / 'training.log'),
        logging.StreamHandler(),
    ]
    logging.basicConfig(level=logging.INFO, format=line_format, handlers=sinks)
    return logging.getLogger(__name__)

def set_deterministic_mode(seed: int = 42):
    """Seed every RNG in use and ask PyTorch for deterministic kernels.

    This function:

    * seeds Python's ``random``, NumPy, and torch (CPU and all CUDA devices);
    * puts cuDNN into deterministic, non-benchmark mode;
    * disables the memory-efficient SDPA backend when this PyTorch has it;
    * enables ``torch.use_deterministic_algorithms`` in warn-only mode when
      supported.

    Args:
        seed: value used for every generator.
    """
    # Environment first. CUBLAS_WORKSPACE_CONFIG only matters if it is set
    # before cuBLAS is initialised. PYTHONHASHSEED influences processes
    # started afterwards, not the already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    if hasattr(torch.backends.cuda, 'enable_mem_efficient_sdp'):
        torch.backends.cuda.enable_mem_efficient_sdp(False)

    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except (AttributeError, TypeError):
        # AttributeError: the API does not exist in this PyTorch version.
        # TypeError: the API exists but has no `warn_only` keyword.
        # In either case, keep going without the global switch.
        pass

def worker_init_fn(worker_id):
    """Seed NumPy and ``random`` inside a DataLoader worker.

    The seed is derived from ``torch.initial_seed()`` (reduced into the 32-bit
    range NumPy accepts), so each worker's Python-side RNGs follow its torch
    seed. *worker_id* is accepted for the DataLoader callback signature but is
    not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)
    for seed_rng in (np.random.seed, random.seed):
        seed_rng(derived_seed)

def save_experiment_config(cfg: EATMAEConfig, save_dir: str):
    """Write the config plus environment info to ``config_<id>.json``.

    The id is the first 8 hex digits of the MD5 of the config serialized as
    key-sorted JSON, so identical configs share an id. The JSON also records
    a timestamp and the Python, PyTorch, CUDA and cuDNN versions (CUDA/cuDNN
    are None when CUDA is unavailable).

    Returns:
        The 8-character experiment id.
    """
    settings = cfg.to_dict()
    canonical = json.dumps(settings, sort_keys=True)
    experiment_id = hashlib.md5(canonical.encode()).hexdigest()[:8]

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    has_cuda = torch.cuda.is_available()
    record = {
        'experiment_id': experiment_id,
        'timestamp': datetime.now().isoformat(),
        'config': settings,
        'python_version': sys.version,
        'pytorch_version': torch.__version__,
        'cuda_version': torch.version.cuda if has_cuda else None,
        'cudnn_version': torch.backends.cudnn.version() if has_cuda else None,
    }
    with open(out_dir / f"config_{experiment_id}.json", 'w') as handle:
        json.dump(record, handle, indent=2)

    return experiment_id

# ============================================================================
# GPU Mel Spectrogram Converter
# ============================================================================

class SpecConverter(nn.Module):
    """Turn raw waveforms into scaled log-mel spectrograms (runs on the input's device).

    The pipeline is:

    1. per-clip RMS scaling;
    2. per-clip zero-mean / unit-std standardisation;
    3. a power (``power=2.0``) mel spectrogram;
    4. dB conversion via ``AmplitudeToDB`` with ``top_db`` as its cut-off;
    5. a crop to at most 128 time frames;
    6. a linear rescale ``(x + top_db) / (2 * top_db)``.
    """

    def __init__(self, cfg: EATMAEConfig):
        super().__init__()
        self.cfg = cfg
        mel_params = dict(
            sample_rate=cfg.sr,
            n_fft=cfg.n_fft,
            win_length=cfg.win_length,
            hop_length=cfg.hop_length,
            n_mels=cfg.n_mels,
            power=2.0,
        )
        self.mel = torchaudio.transforms.MelSpectrogram(**mel_params)
        self.db = torchaudio.transforms.AmplitudeToDB("power", top_db=cfg.top_db)

    @staticmethod
    def _standardize(signal):
        """RMS-scale, then mean/std-normalise along the last (time) axis."""
        root_mean_square = signal.pow(2).mean(dim=-1, keepdim=True).sqrt()
        signal = signal / (root_mean_square + 1e-6)
        centered = signal - signal.mean(dim=-1, keepdim=True)
        return centered / (signal.std(dim=-1, keepdim=True) + 1e-6)

    def forward(self, waveform):
        """
        Args:
            waveform: (B, T) or (B, 1, T) tensor.
        Returns:
            (B, 1, n_mels, <=128) tensor.
        """
        signal = waveform.unsqueeze(1) if waveform.dim() == 2 else waveform
        signal = self._standardize(signal)

        log_mel = self.db(self.mel(signal))

        # Keep at most 128 frames; the config's patch checks assume 128.
        log_mel = log_mel[:, :, :, :128]
        # NOTE(review): this maps [-top_db, top_db] dB onto [0, 1]. Values
        # outside that dB range fall outside [0, 1]; confirm this is intended.
        top_db = self.cfg.top_db
        return (log_mel + top_db) / (2 * top_db)

# ============================================================================
# Dataset - Improved version
# ============================================================================

def get_audio_info_safe(filepath: str) -> Optional[Tuple[int, int]]:
    """Return ``(num_frames, sample_rate)`` for an audio file, or ``None`` on failure.

    ``.wav`` files are probed cheaply from their header with the standard
    ``wave`` module. Anything else, or a WAV that ``wave`` cannot parse, falls
    back to decoding the whole file with ``torchaudio.load``. Missing files
    and files smaller than 1 KiB are skipped (``None``).

    Ordinary errors never propagate; they are logged at DEBUG level and
    reported as ``None``. KeyboardInterrupt/SystemExit are no longer
    swallowed.
    """
    import wave

    try:
        path = os.fspath(filepath)

        if path.lower().endswith('.wav'):
            try:
                with wave.open(path, 'rb') as wav_file:
                    return wav_file.getnframes(), wav_file.getframerate()
            except Exception as e:
                # Unsupported or corrupt header: try the generic decoder below.
                logging.debug(f"wave could not read {path}: {e}")

        if not os.path.exists(path):
            return None
        if os.path.getsize(path) < 1024:  # less than 1 KiB: treat as unusable
            return None

        # Decodes the full file to count frames; the sample rate comes from
        # the same call.
        full_audio, sr = torchaudio.load(path)
        return full_audio.shape[-1], sr

    except Exception as e:
        logging.debug(f"Failed to get info for {filepath}: {e}")
        return None

class RobustAudioDataset(Dataset):
    """Fixed-length audio clips cut from every audio file under a root directory.

    A file is used only if it has at least ``cfg.sample_len`` frames (as
    reported by ``get_audio_info_safe``). Each usable file is split into
    ``ceil(frames / sample_len)`` clips; the last clip is zero-padded on
    load. A clip's label is the index of the file's first path component
    under ``wav_root`` among those components sorted alphabetically.

    Items are ``(waveform, label)``, where ``waveform`` is a 1-D float32
    tensor of length ``cfg.sample_len``, downmixed to mono. A clip that keeps
    failing to load is replaced by the next loadable clip.
    """

    # Glob patterns searched recursively under wav_root.
    AUDIO_PATTERNS = ("**/*.wav", "**/*.WAV", "**/*.mp3", "**/*.flac")

    def __init__(self, wav_root: str | os.PathLike, cfg: EATMAEConfig,
                 clips_per_folder: int | None = None, max_retries: int = 3):
        super().__init__()
        self.cfg = cfg
        self.max_retries = max_retries
        self.error_count = defaultdict(int)
        # Resample transforms cached per source sample rate (costly to build).
        self._resamplers = {}

        if cfg.seed is not None:
            random.seed(cfg.seed)

        wav_root = Path(wav_root)
        logging.info("Starting audio file scan...")
        audio_files = self._scan_audio_files(wav_root)
        logging.info(f"Found {len(audio_files)} audio files")

        clip_map, valid_files = self._build_clip_map(audio_files, wav_root, cfg.sample_len)
        logging.info(f"Valid files: {valid_files}, folders: {len(clip_map)}")

        # Optionally subsample a fixed number of clips per top-level folder.
        self.index_list = []
        for clips in clip_map.values():
            chosen = clips if clips_per_folder is None else \
                     random.sample(clips, min(len(clips), clips_per_folder))
            self.index_list.extend(chosen)

        self.class_to_idx = {f: i for i, f in enumerate(sorted(clip_map))}
        logging.info(f"Total clips: {len(self.index_list)}")

    @classmethod
    def _scan_audio_files(cls, wav_root: Path) -> List[str]:
        """Return the sorted, de-duplicated list of audio files under *wav_root*."""
        # Escape the root so directory names containing glob metacharacters
        # such as "[" are matched literally.
        escaped_root = glob.escape(str(wav_root))
        found = set()
        for pattern in cls.AUDIO_PATTERNS:
            found.update(glob.glob(os.path.join(escaped_root, pattern), recursive=True))
        # Sorted: glob order depends on the filesystem, which would make clip
        # order and random.sample reproducibility machine-dependent.
        # Set: on case-insensitive filesystems "*.wav" and "*.WAV" can return
        # the same files.
        return sorted(found)

    @staticmethod
    def _build_clip_map(audio_files, wav_root: Path, sample_len: int):
        """Group ``(path, start_frame, folder)`` clip descriptors by top-level folder.

        Returns:
            ``(clip_map, valid_files)``: the folder -> clips mapping, and the
            number of files long enough to be used.
        """
        clip_map: Dict[str, List[tuple[str, int, str]]] = defaultdict(list)
        valid_files = 0
        for fp in tqdm(audio_files, desc="Analyzing file info"):
            try:
                audio_info = get_audio_info_safe(fp)
                if audio_info is None:
                    continue

                frames, _file_sr = audio_info
                if frames < sample_len:  # too short for a single clip
                    continue

                valid_files += 1
                n_clips = math.ceil(frames / sample_len)
                folder = Path(fp).relative_to(wav_root).parts[0]
                clip_map[folder].extend((fp, k * sample_len, folder) for k in range(n_clips))

            except Exception as e:
                logging.debug(f"Failed to process file {fp}: {e}")
        return clip_map, valid_files

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` for clip *idx*.

        If the clip cannot be loaded, the following clips are tried in turn
        (wrapping around). Each clip gets ``max_retries`` attempts.

        Raises:
            IndexError: if *idx* is out of range.
            RuntimeError: if no clip in the dataset can be loaded (previously
                this recursed until RecursionError).
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")

        for offset in range(n):
            current = (idx + offset) % n
            sample = self._load_with_retries(current)
            if sample is not None:
                return sample
            logging.error(f"Max retries exceeded for {self.index_list[current][0]}, using next sample")

        raise RuntimeError("No loadable audio clip found in the dataset")

    def _load_with_retries(self, idx):
        """Try to load clip *idx* up to ``max_retries`` times; return None if all attempts fail."""
        fp, start, folder = self.index_list[idx]
        for retry in range(self.max_retries):
            try:
                waveform = self._load_clip(fp, start)
                return waveform, self.class_to_idx[folder]
            except Exception as e:
                self.error_count[fp] += 1
                logging.error(f"Error loading {fp} (attempt {retry+1}/{self.max_retries}): {e}")
                if retry < self.max_retries - 1:
                    time.sleep(0.1 * (retry + 1))  # brief back-off before retrying
        return None

    def _get_resampler(self, orig_sr: int):
        """Return a cached Resample transform from *orig_sr* to ``cfg.sr``."""
        cache = self.__dict__.setdefault("_resamplers", {})
        if orig_sr not in cache:
            cache[orig_sr] = torchaudio.transforms.Resample(orig_sr, self.cfg.sr)
        return cache[orig_sr]

    def _load_clip(self, fp: str, start: int) -> torch.Tensor:
        """Load ``sample_len`` frames of *fp* from frame *start* as a mono float32 tensor."""
        if not os.path.exists(fp):
            raise FileNotFoundError(f"File not found: {fp}")

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                audio, src_sr = torchaudio.load(
                    fp,
                    frame_offset=start,
                    num_frames=self.cfg.sample_len
                )
            except Exception:
                # Fallback: decode the entire file, then slice.
                audio, src_sr = torchaudio.load(fp)
                audio = audio[:, start:start + self.cfg.sample_len]

        # NOTE(review): `start` and `num_frames` are counted at the file's
        # native rate. When src_sr != cfg.sr, the resampled clip is longer or
        # shorter than sample_len before the pad/truncate below. Confirm this
        # is acceptable.
        if src_sr != self.cfg.sr:
            audio = self._get_resampler(src_sr)(audio)

        audio = audio.numpy()
        # Downmix multi-channel audio to mono; otherwise drop the channel axis.
        audio = audio.mean(0) if audio.shape[0] > 1 else audio[0]
        # Zero-pad short clips, then truncate to exactly sample_len.
        audio = np.pad(audio, (0, max(0, self.cfg.sample_len - len(audio))))[:self.cfg.sample_len]
        return torch.tensor(audio, dtype=torch.float32)

# ============================================================================
# Model Components
# ============================================================================

def build_2d_sincos_pos_embed(d_model: int, grid_f: int, grid_t: int) -> torch.Tensor:
    """Build a fixed 2-D sine/cosine position embedding for a grid_f x grid_t grid.

    Channel layout (each part is d_model / 4 wide):
    ``[sin(f), cos(f), sin(t), cos(t)]``, with frequencies
    ``1 / 10000 ** (k / (d_model / 4))``.

    Returns:
        (grid_f * grid_t, d_model) float tensor, rows in row-major (f, t) order.

    Raises:
        ValueError: if d_model is not a positive multiple of 4. Otherwise the
            result would silently have fewer than d_model columns.
    """
    if d_model <= 0 or d_model % 4 != 0:
        raise ValueError(f"d_model must be a positive multiple of 4, got {d_model}")

    def _pe_1d(dim, pos):
        # Half of `dim` holds sines and half holds cosines.
        omega = torch.arange(dim // 2) / (dim // 2)
        omega = 1.0 / (10000 ** omega)
        out = pos.float().unsqueeze(1) * omega.unsqueeze(0)
        return torch.cat([out.sin(), out.cos()], dim=1)

    f, t = torch.meshgrid(torch.arange(grid_f), torch.arange(grid_t), indexing="ij")
    return torch.cat([_pe_1d(d_model // 2, f.reshape(-1)),
                      _pe_1d(d_model // 2, t.reshape(-1))], dim=1)

class MaskingGenerator:
    """Random per-sample token masking.

    Each call draws uniform noise per token and argsorts it into a random
    permutation. The first ``int(N * (1 - mask_ratio))`` positions of the
    permutation are kept visible; the rest are masked.
    """

    def __init__(self, mask_ratio: float):
        self.mask_ratio = mask_ratio

    def __call__(self, B: int, N: int, device):
        """Return ``(visible_idx, masked_idx, restore_idx)``.

        ``visible_idx`` is (B, keep) and ``masked_idx`` is (B, N - keep).
        ``restore_idx`` is the inverse permutation (B, N): gathering the
        shuffled sequence with it restores the original order.
        """
        n_visible = int(N * (1 - self.mask_ratio))
        permutation = torch.rand(B, N, device=device).argsort(dim=1)
        inverse = permutation.argsort(dim=1)
        visible = permutation[:, :n_visible]
        masked = permutation[:, n_visible:]
        return visible, masked, inverse

class CNNEncoder(nn.Module):
    """Single-channel CNN stem: Conv(3x3, stride 2) -> BatchNorm -> ReLU.

    Odd trailing rows/columns are cropped first, so each spatial axis is
    exactly halved.
    """

    def __init__(self, c=32):
        super().__init__()
        conv = nn.Conv2d(1, c, 3, 2, 1)
        norm = nn.BatchNorm2d(c)
        self.net = nn.Sequential(conv, norm, nn.ReLU())
        self.out_c = c

    def forward(self, x):
        """(B, 1, F, T) -> (B, c, F // 2, T // 2)."""
        _, _, n_freq, n_time = x.shape
        cropped = x[:, :, :n_freq - n_freq % 2, :n_time - n_time % 2]
        return self.net(cropped)

class OptimalPatchEmbed(nn.Module):
    """Non-overlapping patch embedding with optional BatchNorm and activation.

    A Conv2d whose kernel equals its stride maps each patch to a ``d_model``
    vector. It is optionally followed by BatchNorm2d and one of
    gelu/relu/silu; any other ``activation_type`` means no activation.
    """

    def __init__(self, c_in: int, d_model: int, patch: Tuple[int, int],
                 use_bn: bool = True, use_activation: bool = True,
                 activation_type: str = 'gelu', bias: bool = False):
        super().__init__()
        self.p = patch

        # Kernel == stride: one output position per patch.
        self.net = nn.Conv2d(c_in, d_model, kernel_size=patch, stride=patch, bias=bias)
        self.bn = nn.BatchNorm2d(d_model) if use_bn else nn.Identity()

        activation_cls = nn.Identity
        if use_activation:
            known = {'gelu': nn.GELU, 'relu': nn.ReLU, 'silu': nn.SiLU}
            activation_cls = known.get(activation_type, nn.Identity)
        self.act = activation_cls()

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) conv weights; zero bias if present."""
        kernel = self.net.weight.data
        # Flattened to 2-D; the view shares storage, so the init is in place.
        nn.init.trunc_normal_(kernel.view(kernel.size(0), -1), std=0.02)
        if self.net.bias is not None:
            nn.init.zeros_(self.net.bias)

    def forward(self, x):
        """
        Args:
            x: (B, C, F, T) input tensor.
        Returns:
            patches: (B, H*W, d_model) token sequence.
            grid_size: (H, W) = (F // patch[0], T // patch[1]).
        """
        _, _, n_freq, n_time = x.shape
        grid = (n_freq // self.p[0], n_time // self.p[1])

        feature_map = self.act(self.bn(self.net(x)))      # (B, d_model, H, W)
        tokens = feature_map.flatten(2).transpose(1, 2)   # (B, H*W, d_model)
        return tokens, grid

class Encoder(nn.Module):
    """Stack of Transformer encoder layers with a prepended learnable CLS token.

    The CLS parameter starts at zeros. Dropout is applied to the input
    tokens only, not to the CLS token. With ``return_all=True``, the forward
    pass also returns each layer's patch tokens and CLS outputs.
    """

    def __init__(self, d_model, nhead, depth, drop=0., return_all=False):
        super().__init__()
        self.return_all = return_all
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        self.dp = nn.Dropout(drop)
        self.blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model, nhead, d_model * 4, drop,
                batch_first=True, activation='gelu'
            )
            for _ in range(depth)
        )

    def forward(self, x):
        """
        Args:
            x: (B, N, d) tokens.
        Returns:
            ``(B, 1+N, d)`` output when ``return_all`` is False. Otherwise a
            tuple of that output, a list with every layer's patch tokens
            ``(B, N, d)``, and a list with every layer's CLS output ``(B, d)``.
        """
        cls_tokens = self.cls.expand(x.size(0), -1, -1)
        hidden = torch.cat([cls_tokens, self.dp(x)], 1)    # (B, 1+N, d)

        per_layer_patches, per_layer_cls = [], []
        for block in self.blocks:
            hidden = block(hidden)
            if self.return_all:
                per_layer_patches.append(hidden[:, 1:, :])   # patch tokens only
                per_layer_cls.append(hidden[:, 0, :])        # CLS token only

        if self.return_all:
            return hidden, per_layer_patches, per_layer_cls
        return hidden

class CNNDecoder(nn.Module):
    """Token-grid decoder: a linear projection followed by five x2 transposed convs.

    Each ConvTranspose2d(kernel 4, stride 2, padding 1) doubles H and W, so a
    (B, H*W, patch_dim) token sequence becomes a (B, out, 32*H, 32*W) map.
    Channel widths go base -> base -> base/2 -> base/4 -> base/8 -> out.
    """

    def __init__(self, patch_dim, base=128, out=1):
        super().__init__()
        self.base = base
        self.fc = nn.Linear(patch_dim, base)

        widths = [base, base, base // 2, base // 4, base // 8]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        stages.append(nn.ConvTranspose2d(base // 8, out, 4, 2, 1))
        # Same Sequential indices as a hand-listed stack, so state_dict keys match.
        self.up = nn.Sequential(*stages)

    def forward(self, x, H, W):
        """(B, N=H*W, patch_dim) tokens -> (B, out, 32*H, 32*W) map."""
        batch = x.shape[0]
        grid = self.fc(x).transpose(1, 2).reshape(batch, self.base, H, W)
        return self.up(grid)

# ============================================================================
# Optimized EATMAE Model with OptimalPatchEmbed
# ============================================================================

class EATMAE(nn.Module):
    """Student network: CNN stem, patch embedding, masked Transformer encoder, CNN decoder.

    ``forward(spec)`` returns ``(pred, mask_idx, cls_last, (H, W))``:

    * ``pred``: decoder output built from the visible-token encodings, with a
      shared learnable mask token at every masked patch position;
    * ``mask_idx``: (B, N_masked) indices of the masked patches;
    * ``cls_last``: (B, embed_dim) CLS output of the encoder;
    * ``(H, W)``: patch-grid size.

    Fixed sin-cos position embeddings are cached per grid size, and per
    device/dtype. They are not module buffers, so they are absent from the
    state_dict.
    """
    def __init__(self, cfg: EATMAEConfig):
        super().__init__()
        self.cfg = cfg
        self.masker = MaskingGenerator(cfg.mask_ratio)

        # Model components
        self.cnn = CNNEncoder(cfg.base_channels)

        self.embed = OptimalPatchEmbed(
            c_in=self.cnn.out_c,
            d_model=cfg.embed_dim,
            patch=cfg.patch_size,
            use_bn=cfg.patch_embed_use_bn,
            use_activation=cfg.patch_embed_use_activation,
            activation_type=cfg.patch_embed_activation,
            bias=cfg.patch_embed_bias
        )

        self.enc = Encoder(cfg.embed_dim, cfg.nhead, cfg.depth, cfg.dropout, return_all=False)

        # Position-embedding caches. CPU master copies are keyed by (H, W);
        # device/dtype-converted copies are keyed by (H, W, device, dtype).
        self.pos_embed_cache = {}
        self._pos_embed_device_cache = {}
        self._precompute_common_sizes()

        # Decoder components
        self.enc2dec = nn.Linear(cfg.embed_dim, cfg.decoder_dim, bias=False)
        self.mask_tok = nn.Parameter(torch.zeros(1, 1, cfg.decoder_dim))
        self.dec = CNNDecoder(cfg.decoder_dim)

        self._init_weights()

    def _init_weights(self):
        """Init the mask token (normal, std 0.02) and the encoder->decoder projection (Xavier)."""
        nn.init.normal_(self.mask_tok, std=0.02)
        nn.init.xavier_uniform_(self.enc2dec.weight)

    def _precompute_common_sizes(self):
        """Precompute position embeddings for a few anticipated grid sizes."""
        common_sizes = [(8, 8), (7, 8), (8, 7), (6, 8), (8, 6), (4, 4)]
        for H, W in common_sizes:
            self.pos_embed_cache[(H, W)] = self._create_pos_embed(H, W)

    def _create_pos_embed(self, H, W):
        """Return (encoder, decoder) embeddings, each with a leading all-zero CLS row.

        Shapes: (1, 1 + H*W, embed_dim) and (1, 1 + H*W, decoder_dim).
        """
        De, Dd = self.cfg.embed_dim, self.cfg.decoder_dim
        pe_enc = build_2d_sincos_pos_embed(De, H, W)
        pe_dec = build_2d_sincos_pos_embed(Dd, H, W)
        pe_enc = torch.cat([torch.zeros(1, De), pe_enc], 0)[None]
        pe_dec = torch.cat([torch.zeros(1, Dd), pe_dec], 0)[None]
        return pe_enc, pe_dec

    def get_pos_embed(self, H, W, device, dtype):
        """Return the (encoder, decoder) position embeddings for an H x W grid on device/dtype."""
        key = (H, W)
        if key not in self.pos_embed_cache:
            logging.debug(f"Creating new position embedding for size {H}x{W}")
            self.pos_embed_cache[key] = self._create_pos_embed(H, W)

        # Keep converted copies so steady-state training does not re-upload
        # the same constant tensors on every forward pass.
        device_cache = self.__dict__.setdefault('_pos_embed_device_cache', {})
        device_key = (H, W, torch.device(device), dtype)
        if device_key not in device_cache:
            pe_enc, pe_dec = self.pos_embed_cache[key]
            device_cache[device_key] = (pe_enc.to(device, dtype), pe_dec.to(device, dtype))
        return device_cache[device_key]

    @staticmethod
    def _unshuffle_tokens(shuffled, restore_idx):
        """Return tokens given in shuffled order to their original positions.

        ``shuffled[:, j]`` holds original token ``idx_shuffle[:, j]``. With
        ``restore_idx = argsort(idx_shuffle)``, the token that belongs at
        position ``i`` is ``shuffled[:, restore_idx[:, i]]``, which is a
        gather. A scatter with ``restore_idx`` would apply the permutation a
        second time instead of undoing it.
        """
        index = restore_idx.unsqueeze(-1).expand(-1, -1, shuffled.size(-1))
        return torch.gather(shuffled, 1, index)

    def forward(self, spec):
        B = spec.size(0)

        # Feature extraction with the CNN stem, then non-overlapping patch tokens.
        feat = self.cnn(spec)
        tokens, (H, W) = self.embed(feat)
        N, D = tokens.shape[1:]

        # Add encoder position embeddings (row 0 is the CLS slot, so skip it).
        pos_enc, pos_dec = self.get_pos_embed(H, W, tokens.device, tokens.dtype)
        tokens = tokens + pos_enc[:, 1:, :]

        # Random masking: only the visible tokens go through the encoder.
        vis_idx, mask_idx, restore_idx = self.masker(B, N, tokens.device)
        vis_tokens = torch.gather(tokens, 1, vis_idx.unsqueeze(-1).repeat(1, 1, D))

        enc_out = self.enc(vis_tokens)
        cls_last = enc_out[:, 0, :]

        # [visible encodings, mask tokens] follow the shuffled order
        # [vis_idx, mask_idx]. Restore the original patch order before
        # adding decoder position embeddings.
        vis_feat = self.enc2dec(enc_out[:, 1:, :])
        mask_tok = self.mask_tok.expand(B, mask_idx.size(1), -1)
        merged = torch.cat([vis_feat, mask_tok], 1)
        full = self._unshuffle_tokens(merged, restore_idx)
        full = full + pos_dec[:, 1:, :]

        pred = self.dec(full, H, W)

        return pred, mask_idx, cls_last, (H, W)

class EATMAETeacher(EATMAE):
    """Frozen teacher: encodes every patch (no masking) and provides targets.

    ``forward(spec)`` returns ``(pred, None, g_t, (H, W))``:

    * ``pred`` is the decoder output from the unmasked encoding;
    * ``g_t`` (B, embed_dim) is the patch average of the element-wise mean of
      all encoder layers' patch outputs.

    All parameters have ``requires_grad=False``.
    """

    def __init__(self, cfg: EATMAEConfig):
        super().__init__(cfg)
        # Replace the encoder with one that also exposes every layer's outputs.
        self.enc = Encoder(cfg.embed_dim, cfg.nhead, cfg.depth, cfg.dropout, return_all=True)
        self.requires_grad_(False)

    def forward(self, spec):
        patch_tokens, (H, W) = self.embed(self.cnn(spec))
        pos_enc, pos_dec = self.get_pos_embed(H, W, patch_tokens.device, patch_tokens.dtype)

        # Per-layer patch tokens: a list of depth tensors, each (B, N, d).
        final_hidden, per_layer_patches, _ = self.enc(patch_tokens + pos_enc[:, 1:, :])

        layer_average = torch.stack(per_layer_patches, dim=0).mean(dim=0)   # (B, N, d)
        g_t = layer_average.mean(dim=1)                                     # (B, d)

        decoder_in = self.enc2dec(final_hidden[:, 1:, :]) + pos_dec[:, 1:, :]
        pred = self.dec(decoder_in, H, W)

        return pred, None, g_t, (H, W)

# ============================================================================
# Training Components
# ============================================================================

class MetricLogger:
    """Thin TensorBoard wrapper around ``SummaryWriter``.

    ``step`` is the fallback global step, used whenever a logging call
    passes ``step=None``.
    """
    def __init__(self, log_dir: str):
        self.writer = SummaryWriter(log_dir)
        self.step = 0

    def _resolve_step(self, step):
        """Return *step*, or ``self.step`` when *step* is None.

        An explicit step of 0 is a valid global step. The previous
        ``step or self.step`` silently replaced it with ``self.step``.
        """
        return self.step if step is None else step

    def log_scalar(self, tag: str, value: float, step: int = None):
        """Log one scalar value under *tag*."""
        self.writer.add_scalar(tag, value, self._resolve_step(step))

    def log_histogram(self, tag: str, values: torch.Tensor, step: int = None):
        """Log a histogram of *values* under *tag*."""
        self.writer.add_histogram(tag, values, self._resolve_step(step))

    def log_model_weights(self, model: nn.Module, step: int = None):
        """Log histograms of every parameter, and of its gradient when one exists."""
        step = self._resolve_step(step)
        for name, param in model.named_parameters():
            if param.grad is not None:
                self.writer.add_histogram(f'gradients/{name}', param.grad, step)
            self.writer.add_histogram(f'weights/{name}', param, step)

    def close(self):
        """Flush and close the underlying writer."""
        self.writer.close()

class CheckpointManager:
    """Save a checkpoint for every epoch and reload the most recent one.

    Epoch checkpoints are written as ``impact_{dataset}_epoch_{epoch+1:04d}.pth``
    in ``save_dir``. ``best_model.pth`` additionally holds the most recent
    checkpoint saved with ``is_best=True``.

    Args:
        save_dir: directory for checkpoint files (created if missing).
        dataset: tag used in filenames. When None, it is read at save time
            from ``model.cfg.dataset`` if present; otherwise the
            ``EATMAEConfig`` class default is used.
    """
    def __init__(self, save_dir: str, dataset: Optional[str] = None):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(exist_ok=True, parents=True)
        self.checkpoint_files = []
        self.dataset = dataset

    def _dataset_tag(self, model) -> str:
        """Resolve the dataset name used in checkpoint filenames."""
        if self.dataset is not None:
            return self.dataset
        tag = getattr(getattr(model, "cfg", None), "dataset", None)
        # Last resort: the class-level default (the previous behaviour was to
        # always use it, whatever the model's config said).
        return tag if tag is not None else EATMAEConfig.dataset

    def _latest_checkpoint(self) -> Optional[Path]:
        """Return the newest epoch checkpoint by filename order, or None.

        Files written by ``save_checkpoint`` are searched first, restricted
        to ``self.dataset`` when it is set. If none exist, the legacy
        ``checkpoint_*.pth`` pattern is used as a fallback.
        """
        pattern = (f"impact_{self.dataset}_epoch_*.pth" if self.dataset is not None
                   else "impact_*_epoch_*.pth")
        candidates = sorted(self.save_dir.glob(pattern))
        if not candidates:
            candidates = sorted(self.save_dir.glob("checkpoint_*.pth"))
        return candidates[-1] if candidates else None

    def save_checkpoint(self, model, optimizer, epoch, metrics, is_best=False):
        """Save model/optimizer state for *epoch*; *metrics* must contain ``'loss'``.

        Returns:
            Path of the written epoch checkpoint.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics,
            'timestamp': datetime.now().isoformat(),
            # Stored as a plain str. torch.__version__ is a str subclass,
            # which a weights-only torch.load may refuse to unpickle.
            'torch_version': str(torch.__version__),
        }

        filename = f"impact_{self._dataset_tag(model)}_epoch_{epoch+1:04d}.pth"
        filepath = self.save_dir / filename

        torch.save(checkpoint, filepath)
        self.checkpoint_files.append(filepath)
        logging.info(f"Saved checkpoint for epoch {epoch+1} with loss {metrics['loss']:.4f}")

        if is_best:
            best_path = self.save_dir / "best_model.pth"
            torch.save(checkpoint, best_path)
            logging.info(f"Saved best model with loss {metrics['loss']:.4f}")

        return filepath

    def load_checkpoint(self, model, optimizer=None, checkpoint_path=None, map_location=None):
        """Load a checkpoint into *model* (and *optimizer*, if given).

        Args:
            checkpoint_path: explicit file. When None, the latest epoch
                checkpoint in ``save_dir`` is used.
            map_location: forwarded to ``torch.load``, e.g. ``"cpu"`` to load
                GPU checkpoints on a CPU-only machine.

        Returns:
            The loaded checkpoint dict, or None if nothing was found.
        """
        if checkpoint_path is None:
            checkpoint_path = self._latest_checkpoint()
            if checkpoint_path is None:
                logging.warning("No checkpoint found")
                return None

        logging.info(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

        model.load_state_dict(checkpoint['model_state_dict'])

        if optimizer and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        return checkpoint

def update_teacher(student, teacher, m=0.999):
    """Exponential-moving-average update of *teacher*'s parameters toward *student*'s.

    For each pair of parameters (matched by ``parameters()`` order), the
    teacher parameter is updated in place as ``t <- m * t + (1 - m) * s``,
    without autograd tracking. Only parameters are averaged; buffers (e.g.
    BatchNorm running statistics) are left untouched.

    Raises:
        ValueError: if the two models expose different numbers of parameters.
            ``zip`` would otherwise silently skip the extras.
    """
    student_params = list(student.parameters())
    teacher_params = list(teacher.parameters())
    if len(student_params) != len(teacher_params):
        raise ValueError(
            f"student has {len(student_params)} parameters but teacher has {len(teacher_params)}"
        )

    with torch.no_grad():
        for s, t in zip(student_params, teacher_params):
            # In place: no new tensor allocated per step, and no `.data` reassignment.
            t.mul_(m).add_(s.detach(), alpha=1 - m)

def create_optimized_dataloader(dataset, cfg, train=True):
    """Create the DataLoader used for training/evaluation.

    Behaviour:

    * shuffles only when *train* is true;
    * caps workers at the number of CPUs;
    * pins host memory whenever ``cfg.device`` is any CUDA device;
    * keeps workers alive between epochs when there are workers;
    * seeds each worker's Python-side RNGs through ``worker_init_fn``.
    """
    num_workers = min(cfg.num_workers, os.cpu_count() or 1)

    loader_kwargs = dict(
        batch_size=cfg.batch_size,
        shuffle=train,
        num_workers=num_workers,
        # cfg.device may include an index such as "cuda:1" (the config
        # default), so match on the prefix rather than equality.
        pin_memory=str(cfg.device).startswith("cuda"),
        persistent_workers=(num_workers > 0),
        worker_init_fn=worker_init_fn,
    )
    if num_workers > 0:
        # Only pass prefetch_factor with worker processes; older PyTorch
        # rejects any non-default value when num_workers == 0.
        loader_kwargs["prefetch_factor"] = 2

    return DataLoader(dataset, **loader_kwargs)

# ============================================================================
# Training Loop
# ============================================================================

def run_epoch_with_amp(student, teacher, spec_converter, loader, opt, scaler, cfg, logger, train=True):
    """Run one pass over ``loader``, distilling the EMA teacher into the student.

    For every batch, the waveform is converted by ``spec_converter`` (without
    gradients) and both networks are run on the result. The loss is
    ``frame_loss(s_pred, t_pred) + cfg.utter_loss_weight * utter_loss(s_cls, t_cls)``.
    In training mode the student is optimised with gradient clipping at
    ``cfg.gradient_clip``, and the teacher is then EMA-updated with momentum
    ``cfg.teacher_momentum``.

    Args:
        student, teacher: modules whose forward returns a 4-tuple; elements 0
            and 2 are compared (presumably frame-level predictions and an
            utterance-level embedding — confirm against the model classes).
        spec_converter: callable mapping a waveform batch to model input.
        loader: iterable of ``(waveform, label)`` batches; labels are unused.
        opt: student optimizer (used only when ``train``).
        scaler: ``torch.amp.GradScaler`` or None. On CUDA, training only uses
            autocast when a scaler is supplied.
        cfg: config with device, use_amp, frame/utter loss types, utter loss
            weight, gradient_clip, teacher_momentum and log_interval.
        logger: metric logger with ``log_scalar`` and a ``step`` attribute
            (epoch index used to build global steps), or None.
        train: optimise when True; evaluate without autograd when False.

    Returns:
        Dict with the mean ``'loss'``, ``'frame_loss'`` and ``'utter_loss'``
        over processed batches (NaN values if the loader yielded no batches).

    Raises:
        ValueError: if ``cfg.frame_loss_type`` or ``cfg.utter_loss_type`` is
            not one of ``'huber'``, ``'mse'``, ``'l1'``.
    """
    loss_fns = {"huber": F.huber_loss, "mse": F.mse_loss, "l1": F.l1_loss}
    try:
        frame_loss_fn = loss_fns[cfg.frame_loss_type]
        utter_loss_fn = loss_fns[cfg.utter_loss_type]
    except KeyError as err:
        # Previously an unknown type left the loss variable unbound.
        raise ValueError(
            f"Unsupported loss type {err.args[0]!r}; expected one of {sorted(loss_fns)}"
        ) from err

    student.train(train)

    # "cuda" and "cuda:N" both have device type "cuda". Any non-CUDA device
    # keeps using CPU autocast, as before.
    on_cuda = torch.device(cfg.device).type == "cuda"
    amp_device_type = "cuda" if on_cuda else "cpu"
    # fp16 autocast on CUDA without loss scaling can underflow gradients, so
    # CUDA training only autocasts when a GradScaler was provided.
    use_autocast = bool(cfg.use_amp) and not (train and on_cuda and scaler is None)

    totals = {"loss": 0.0, "frame_loss": 0.0, "utter_loss": 0.0}
    n_batches = 0
    pbar = tqdm(loader, leave=False, desc="train" if train else "eval")

    for batch_idx, (waveform, _label) in enumerate(pbar):
        waveform = waveform.to(cfg.device, non_blocking=on_cuda)

        # Convert waveform to spectrogram on the target device
        with torch.no_grad():
            spec = spec_converter(waveform)

        if train:
            opt.zero_grad()

        # Disable autograd entirely when evaluating so no graph is built.
        with torch.set_grad_enabled(train), torch.amp.autocast(device_type=amp_device_type, enabled=use_autocast):
            s_pred, _, s_cls, _ = student(spec)
            with torch.no_grad():
                t_pred, _, t_cls, _ = teacher(spec)

            frame_loss = frame_loss_fn(s_pred, t_pred)
            utter_loss = utter_loss_fn(s_cls, t_cls)
            loss = frame_loss + cfg.utter_loss_weight * utter_loss

        grad_value = 0
        if train:
            if scaler is not None:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(opt)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    student.parameters(),
                    max_norm=cfg.gradient_clip
                )
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    student.parameters(),
                    max_norm=cfg.gradient_clip
                )
                opt.step()

            update_teacher(student, teacher, m=cfg.teacher_momentum)
            grad_value = grad_norm.item()

        # One host sync per scalar per batch instead of several.
        loss_value = loss.item()
        frame_value = frame_loss.item()
        utter_value = utter_loss.item()
        totals["loss"] += loss_value
        totals["frame_loss"] += frame_value
        totals["utter_loss"] += utter_value
        n_batches += 1

        pbar.set_postfix({
            'loss': loss_value,
            'loc': frame_value,
            'cls': utter_value,
            'grad': grad_value
        })

        if train and logger is not None and batch_idx % cfg.log_interval == 0:
            global_step = batch_idx + len(loader) * logger.step
            logger.log_scalar('batch/loss', loss_value, global_step)
            logger.log_scalar('batch/frame_loss', frame_value, global_step)
            logger.log_scalar('batch/utter_loss', utter_value, global_step)
            logger.log_scalar('batch/grad_norm', grad_value, global_step)
            logger.log_scalar('batch/learning_rate', opt.param_groups[0]['lr'], global_step)

    if n_batches == 0:
        logging.warning("run_epoch_with_amp: loader yielded no batches")
        return {key: float("nan") for key in totals}

    return {key: total / n_batches for key, total in totals.items()}

# ============================================================================
# Main Training Function
# ============================================================================

def main():
    """Script entry point: build data and models, then train the student against an EMA teacher.

    Command line:
        --resume PATH  checkpoint file whose student/optimizer state is restored.

    Per epoch, this writes a row to ``training_log_<experiment_id>.csv`` under
    ``cfg.log_dir``, optional TensorBoard scalars, and a checkpoint via
    ``CheckpointManager(cfg.checkpoint_dir)``.
    """
    # Parse arguments and load config
    parser = argparse.ArgumentParser()
    parser.add_argument('--resume', type=str, default=None, help='Resume from checkpoint')
    args = parser.parse_args()

    warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio._backend")
    warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
    warnings.filterwarnings("ignore", message=".*Flash Attention.*")
    warnings.filterwarnings("ignore", message=".*Memory Efficient attention.*")

    cfg = EATMAEConfig()
    # "cuda" and "cuda:N" both have device type "cuda"; comparing the raw
    # string to "cuda" never matched the default "cuda:1".
    on_cuda = torch.device(cfg.device).type == "cuda"

    # Setup. Explicit None check so seed=0 is honoured (``0 or 42`` gives 42).
    set_deterministic_mode(cfg.seed if cfg.seed is not None else 42)
    logger = setup_logging(cfg.log_dir)
    experiment_id = save_experiment_config(cfg, cfg.log_dir)

    logger.info(f"Starting experiment {experiment_id}")
    logger.info(f"Configuration: {cfg}")
    logger.info(f"Using OptimalPatchEmbed with BN={cfg.patch_embed_use_bn}, Activation={cfg.patch_embed_use_activation}")
    logger.info("Checkpoints for all epochs will be saved.")

    # Create directories
    Path(cfg.checkpoint_dir).mkdir(exist_ok=True, parents=True)

    # Initialize TensorBoard logger
    tb_logger = MetricLogger(f"{cfg.log_dir}/runs/{experiment_id}") if cfg.tensorboard else None

    try:
        # Load datasets
        logger.info("Loading training dataset...")
        train_dataset = RobustAudioDataset(cfg.train_data_path, cfg)
        logger.info(f"Training dataset size: {len(train_dataset)}")

        train_loader = create_optimized_dataloader(train_dataset, cfg, train=True)

        # Initialize models
        logger.info("Initializing models with OptimalPatchEmbed...")
        spec_converter = SpecConverter(cfg).to(cfg.device)
        student = EATMAE(cfg).to(cfg.device)
        teacher = EATMAETeacher(cfg).to(cfg.device)

        # Teacher starts as an exact copy of the student
        teacher.load_state_dict(student.state_dict())

        total_params = sum(p.numel() for p in student.parameters())
        trainable_params = sum(p.numel() for p in student.parameters() if p.requires_grad)
        logger.info(f"Total parameters: {total_params:,}")
        logger.info(f"Trainable parameters: {trainable_params:,}")

        optimizer = torch.optim.AdamW(student.parameters(), lr=cfg.lr, weight_decay=cfg.wd)

        # Gradient scaler for mixed precision (CUDA only)
        scaler = torch.amp.GradScaler('cuda') if cfg.use_amp and on_cuda else None

        checkpoint_manager = CheckpointManager(cfg.checkpoint_dir)

        # Resume from checkpoint if specified
        start_epoch = 0
        if args.resume:
            checkpoint = checkpoint_manager.load_checkpoint(student, optimizer, args.resume)
            if checkpoint is not None:
                start_epoch = checkpoint['epoch'] + 1
                # Only the student is saved/restored by the checkpoint manager,
                # so the teacher still holds the fresh initialisation copied
                # above. Re-seed it from the restored student so the EMA target
                # is not random (its EMA history itself is not recoverable).
                teacher.load_state_dict(student.state_dict())
                logger.info(f"Resumed from epoch {start_epoch}")

        logger.info("Starting training...")

        # CSV logging
        csv_path = f"{cfg.log_dir}/training_log_{experiment_id}.csv"
        with open(csv_path, 'w', newline='') as csvfile:
            csv_writer = csv.writer(csvfile)
            csv_writer.writerow(['epoch', 'train_loss', 'train_frame_loss',
                                 'train_utter_loss', 'learning_rate', 'epoch_time'])

            for epoch in range(start_epoch, cfg.epochs):
                epoch_start = time.time()
                logger.info(f"Epoch {epoch + 1}/{cfg.epochs}")

                if tb_logger is not None:
                    # run_epoch_with_amp builds batch-level global steps from
                    # tb_logger.step, so it must hold the current epoch before
                    # the epoch runs; setting it afterwards made epochs collide.
                    tb_logger.step = epoch

                train_metrics = run_epoch_with_amp(
                    student, teacher, spec_converter, train_loader,
                    optimizer, scaler, cfg, tb_logger, train=True
                )

                current_lr = optimizer.param_groups[0]['lr']
                epoch_time = time.time() - epoch_start

                logger.info(f"Epoch {epoch + 1} - "
                            f"Train Loss: {train_metrics['loss']:.4f}, "
                            f"LR: {current_lr:.6f}, "
                            f"Time: {epoch_time:.2f}s")

                csv_writer.writerow([
                    epoch + 1,
                    train_metrics['loss'],
                    train_metrics['frame_loss'],
                    train_metrics['utter_loss'],
                    current_lr,
                    epoch_time
                ])
                csvfile.flush()

                if tb_logger is not None:
                    tb_logger.log_scalar('epoch/train_loss', train_metrics['loss'], epoch)
                    tb_logger.log_scalar('epoch/learning_rate', current_lr, epoch)
                    # Log model weights periodically
                    if epoch % 10 == 0:
                        tb_logger.log_model_weights(student, epoch)

                # Save checkpoint - save all epochs (train_metrics already has 'loss')
                checkpoint_manager.save_checkpoint(
                    student, optimizer, epoch,
                    dict(train_metrics),
                    is_best=False
                )

        logger.info("Training completed!")
        logger.info(f"Checkpoints for all {cfg.epochs} epochs have been saved.")
    finally:
        # Always flush/close TensorBoard, even if training raised.
        if tb_logger is not None:
            tb_logger.close()

if __name__ == "__main__":
    # Start training only when run as a script, not when this module is imported.
    main()