from __future__ import annotations
import os, math, glob, random, sys, json, yaml, hashlib, traceback
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Tuple, List, Dict, Optional, Any
from collections import defaultdict
import logging
from datetime import datetime
import time
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchaudio
from tqdm import tqdm
import csv
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio._backend")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
warnings.filterwarnings("ignore", message=".*torchaudio.load_with_torchcodec.*")

# ============================================================================
# Configuration for Downstream Tasks
# ============================================================================

@dataclass
class DownstreamConfig:
    """Settings for training a downstream model on top of a pretrained EATMAE.

    Audio and architecture fields must match the checkpoint at
    ``pretrained_model_path``; the checkpoint is loaded with ``strict=True``.
    """
    # Audio settings (same as training code)
    dataset: str = 'Hybrid'
    sr: int = 48_000
    sample_len: int = 48_000
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 376
    n_mels: int = 128
    top_db: int = 80

    # Model architecture (must match pretrained model)
    # `mr` / `ut` are plain class attributes (no annotation, so not dataclass
    # fields); they only provide the defaults for mask_ratio / utter_loss_weight.
    mr = 0.4
    ut = 0.2
    base_channels: int = 32
    patch_size: Tuple[int, int] = (16, 16)
    embed_dim: int = 384
    depth: int = 8
    nhead: int = 16
    decoder_dim: int = 256
    mask_ratio: float = mr
    dropout: float = 0.0
    utter_loss_weight: float = ut

    # Patch Embedding settings
    patch_embed_use_bn: bool = False
    patch_embed_use_activation: bool = False
    patch_embed_activation: str = 'none'
    patch_embed_bias: bool = False

    # Downstream task specific settings
    classifier_hidden_dim: int = 256
    classifier_dropout: float = 0.0

    # Training settings
    lr: float = 1e-3
    wd: float = 5e-3
    batch_size: int = 16
    num_workers: int = 16
    epochs: int = 100
    seed: Optional[int] = 42
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")

    # Training optimization
    patience: int = 10
    use_amp: bool = False
    gradient_clip: float = 1.0
    log_interval: int = 10

    # Scheduler settings
    scheduler_type: str = "plateau"  # "plateau", "cosine", "step"
    lr_decay_factor: float = 0.5
    lr_decay_patience: int = 5
    lr_min: float = 1e-6

    # Paths
    # NOTE: this f-string is evaluated once, at class definition, with the
    # default `dataset`; passing a different `dataset=` does not change it.
    pretrained_model_path: str = f"./IMPACT_Models/impact_{dataset}_epoch_0010.pth"
    dataset_root: str = f"./Datasets/DINOS_Downstreams"
    save_dir: str = "./IMPACT_Models/"
    log_dir: str = "./IMPACT_Models/logs"

    def __post_init__(self):
        """Validate the configuration.

        Raises:
            FileNotFoundError: if ``pretrained_model_path`` does not exist.
            ValueError: if ``classifier_hidden_dim`` is not positive or
                ``classifier_dropout`` is outside ``[0, 1)``.
        """
        # Explicit raises rather than `assert`, so validation still runs under `python -O`.
        if not os.path.exists(self.pretrained_model_path):
            raise FileNotFoundError(f"Pretrained model not found: {self.pretrained_model_path}")
        if self.classifier_hidden_dim <= 0:
            raise ValueError("classifier_hidden_dim must be positive")
        if not 0 <= self.classifier_dropout < 1:
            raise ValueError("classifier_dropout must be between 0 and 1")

# ============================================================================
# Import Model Components from Training Code (from training code)
# ============================================================================

def build_2d_sincos_pos_embed(d_model: int, grid_f: int, grid_t: int) -> torch.Tensor:
    """Return a fixed 2-D sine/cosine positional table for a grid_f x grid_t patch grid.

    The first half of the channels encodes the frequency-axis index and the
    second half the time-axis index. Rows are frequency-major:
    ``row = f * grid_t + t``.

    Returns:
        Tensor of shape ``(grid_f * grid_t, 4 * (d_model // 4))``, i.e.
        ``(grid_f * grid_t, d_model)`` when ``d_model`` is divisible by 4.
    """
    axis_width = d_model // 2

    def encode_axis(width, positions):
        # `width // 2` frequencies, each contributing a sin and a cos column.
        n_freq = width // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(n_freq) / n_freq))
        angles = positions.float().unsqueeze(1) * inv_freq.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=1)

    rows, cols = torch.meshgrid(torch.arange(grid_f), torch.arange(grid_t), indexing="ij")
    freq_part = encode_axis(axis_width, rows.flatten())
    time_part = encode_axis(axis_width, cols.flatten())
    return torch.cat([freq_part, time_part], dim=1)

class MaskingGenerator:
    """Draw an independent random token permutation per sample and split it into visible/masked indices."""

    def __init__(self, mask_ratio: float):
        # Fraction of tokens to hide; int(N * (1 - mask_ratio)) tokens stay visible.
        self.mask_ratio = mask_ratio

    def __call__(self, B: int, N: int, device):
        """Sample a mask for ``B`` sequences of ``N`` tokens.

        Returns:
            ``(visible_idx, masked_idx, restore_idx)``. ``visible_idx`` is
            ``(B, keep)`` and ``masked_idx`` is ``(B, N - keep)``; together
            they form the shuffled order. ``restore_idx`` is ``(B, N)`` and
            holds the inverse permutation of that shuffle.
        """
        n_visible = int(N * (1 - self.mask_ratio))
        order = torch.rand(B, N, device=device).argsort(dim=1)
        inverse = order.argsort(dim=1)
        visible, masked = order[:, :n_visible], order[:, n_visible:]
        return visible, masked, inverse

class CNNEncoder(nn.Module):
    """Single-stage conv stem applied before patch embedding.

    The stage is Conv3x3 with stride 2 and padding 1, then BatchNorm, then ReLU.

    Args:
        c: number of output channels (exposed as ``out_c``).
    """

    def __init__(self, c=32):
        super().__init__()
        stem = [nn.Conv2d(1, c, 3, 2, 1), nn.BatchNorm2d(c), nn.ReLU()]
        self.net = nn.Sequential(*stem)
        self.out_c = c

    def forward(self, x):
        """Apply the stem to a 4-D ``(B, 1, F, T)`` batch.

        Odd F/T sizes lose their last bin so both axes are even before the
        stride-2 convolution halves them.
        """
        _, _, n_freq, n_time = x.shape
        even_freq = n_freq - n_freq % 2
        even_time = n_time - n_time % 2
        cropped = x[:, :, :even_freq, :even_time]
        return self.net(cropped)

class OptimalPatchEmbed(nn.Module):
    """Non-overlapping patch embedding with optional BatchNorm and activation.

    The projection is a Conv2d whose kernel and stride both equal the patch size.

    Args:
        c_in: input channels.
        d_model: embedding width of each patch token.
        patch: ``(patch_f, patch_t)`` patch size (also the conv stride).
        use_bn: insert ``BatchNorm2d(d_model)`` after the projection.
        use_activation: apply the activation named by ``activation_type``.
        activation_type: ``'gelu'``, ``'relu'`` or ``'silu'``; any other value
            falls back to identity.
        bias: whether the projection conv has a bias term.
    """

    # Lookup table for the supported activations; unknown names map to Identity.
    _ACTIVATIONS = {'gelu': nn.GELU, 'relu': nn.ReLU, 'silu': nn.SiLU}

    def __init__(self, c_in: int, d_model: int, patch: Tuple[int, int],
                 use_bn: bool = True, use_activation: bool = True,
                 activation_type: str = 'gelu', bias: bool = False):
        super().__init__()
        self.p = patch
        self.net = nn.Conv2d(c_in, d_model, kernel_size=patch, stride=patch, bias=bias)
        self.bn = nn.BatchNorm2d(d_model) if use_bn else nn.Identity()
        if use_activation:
            act_cls = self._ACTIVATIONS.get(activation_type, nn.Identity)
        else:
            act_cls = nn.Identity
        self.act = act_cls()
        self._init_weights()

    def _init_weights(self):
        """Redraw the projection weight from a truncated normal (std 0.02) and zero the bias if present."""
        weight = self.net.weight.data
        flat = weight.view(weight.size(0), -1)
        nn.init.trunc_normal_(flat, std=0.02)
        if self.net.bias is not None:
            nn.init.zeros_(self.net.bias)

    def forward(self, x):
        """Embed a ``(B, C, F, T)`` feature map.

        Returns:
            ``(tokens, (H, W))``. ``tokens`` is ``(B, H*W, d_model)`` in
            row-major grid order, with ``H = F // patch_f`` and
            ``W = T // patch_t``.
        """
        _, _, n_freq, n_time = x.shape
        grid = (n_freq // self.p[0], n_time // self.p[1])
        feat = self.act(self.bn(self.net(x)))  # (B, d_model, H, W)
        tokens = feat.flatten(2).transpose(1, 2)
        return tokens, grid

class Encoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` blocks with a learnable CLS token prepended.

    Args:
        d_model: token width.
        nhead: attention heads per layer.
        depth: number of layers.
        drop: dropout applied to the input tokens and inside every layer.
        return_all: if True, ``forward`` also returns the CLS vector after each layer.
    """

    def __init__(self, d_model, nhead, depth, drop=0., return_all=False):
        super().__init__()
        self.return_all = return_all
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        self.dp = nn.Dropout(drop)
        layers = []
        for _ in range(depth):
            layers.append(nn.TransformerEncoderLayer(
                d_model, nhead, d_model * 4, drop,
                batch_first=True, activation='gelu',
            ))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Encode ``(B, N, d_model)`` tokens.

        Returns:
            The ``(B, N + 1, d_model)`` sequence, with the CLS token at index 0.
            When ``return_all`` is set, returns a tuple of that sequence and
            the list of per-layer CLS vectors.
        """
        cls_tokens = self.cls.expand(x.size(0), -1, -1)
        hidden = torch.cat([cls_tokens, self.dp(x)], 1)
        per_layer_cls = []
        for layer in self.blocks:
            hidden = layer(hidden)
            per_layer_cls.append(hidden[:, 0, :])
        if self.return_all:
            return hidden, per_layer_cls
        return hidden

class CNNDecoder(nn.Module):
    """Map patch tokens to a feature grid and upsample it 32x.

    A linear layer projects each token to ``base`` channels, then five
    stride-2 transposed convolutions upsample the grid. Channel widths go
    base -> base -> base/2 -> base/4 -> base/8 -> out, with BatchNorm and
    ReLU between stages.

    Args:
        patch_dim: width of the incoming token vectors.
        base: channel width after the linear projection.
        out: output channels of the final transposed conv.
    """

    def __init__(self, patch_dim, base=128, out=1):
        super().__init__()
        self.base = base
        self.fc = nn.Linear(patch_dim, base)
        widths = [base, base, base // 2, base // 4, base // 8]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()])
        stages.append(nn.ConvTranspose2d(widths[-1], out, 4, 2, 1))
        self.up = nn.Sequential(*stages)

    def forward(self, x, H, W):
        """Decode ``(B, H*W, patch_dim)`` tokens (row-major grid order) to ``(B, out, 32*H, 32*W)``."""
        batch, _, _ = x.shape
        grid = self.fc(x).transpose(1, 2).contiguous().view(batch, self.base, H, W)
        return self.up(grid)

# ============================================================================
# EATMAE Model for Feature Extraction (from training code)
# ============================================================================

class EATMAE(nn.Module):
    """EATMAE masked autoencoder with OptimalPatchEmbed.

    The model is a CNN stem, then a patch embedding, then a Transformer
    encoder, then a CNN decoder.

    Downstream code only needs ``forward_features``. The decoder parts are kept
    so that a full pretraining checkpoint can be loaded with ``strict=True``.

    Args:
        cfg: supplies ``mask_ratio``, ``base_channels``, ``patch_size``,
            ``embed_dim``, ``nhead``, ``depth``, ``dropout``, ``decoder_dim``
            and the ``patch_embed_*`` options.
    """
    def __init__(self, cfg: DownstreamConfig):
        super().__init__()
        self.cfg = cfg
        self.masker = MaskingGenerator(cfg.mask_ratio)

        # Model components
        self.cnn = CNNEncoder(cfg.base_channels)

        # Use OptimalPatchEmbed
        self.embed = OptimalPatchEmbed(
            c_in=self.cnn.out_c,
            d_model=cfg.embed_dim,
            patch=cfg.patch_size,
            use_bn=cfg.patch_embed_use_bn,
            use_activation=cfg.patch_embed_use_activation,
            activation_type=cfg.patch_embed_activation,
            bias=cfg.patch_embed_bias
        )

        self.enc = Encoder(cfg.embed_dim, cfg.nhead, cfg.depth, cfg.dropout, return_all=False)

        # Fixed sin-cos tables keyed by patch grid (H, W). Kept in a plain dict
        # (not registered buffers), so they never appear in state_dict and
        # strict checkpoint loading is unaffected.
        self.pos_embed_cache = {}
        self._precompute_common_sizes()

        # Decoder components (not used during inference but needed for model loading)
        self.enc2dec = nn.Linear(cfg.embed_dim, cfg.decoder_dim, bias=False)
        self.mask_tok = nn.Parameter(torch.zeros(1, 1, cfg.decoder_dim))
        self.dec = CNNDecoder(cfg.decoder_dim)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialise the mask token (normal, std 0.02) and enc2dec (Xavier uniform)."""
        nn.init.normal_(self.mask_tok, std=0.02)
        nn.init.xavier_uniform_(self.enc2dec.weight)

    def _precompute_common_sizes(self):
        """Fill the positional-embedding cache for a few common patch-grid sizes."""
        common_sizes = [(8, 8), (7, 8), (8, 7), (6, 8), (8, 6), (4, 4)]
        for H, W in common_sizes:
            self.pos_embed_cache[(H, W)] = self._create_pos_embed(H, W)

    def _create_pos_embed(self, H, W):
        """Build encoder and decoder positional tables for an H x W grid.

        Returns:
            ``(pe_enc, pe_dec)`` with shapes ``(1, H*W + 1, embed_dim)`` and
            ``(1, H*W + 1, decoder_dim)``. Row 0 is an all-zero slot for the
            CLS position.
        """
        De, Dd = self.cfg.embed_dim, self.cfg.decoder_dim
        pe_enc = build_2d_sincos_pos_embed(De, H, W)
        pe_dec = build_2d_sincos_pos_embed(Dd, H, W)
        pe_enc = torch.cat([torch.zeros(1, De), pe_enc], 0)[None]
        pe_dec = torch.cat([torch.zeros(1, Dd), pe_dec], 0)[None]
        return pe_enc, pe_dec

    def get_pos_embed(self, H, W, device, dtype):
        """Return the cached (or newly built) tables cast to ``device``/``dtype``."""
        key = (H, W)
        if key not in self.pos_embed_cache:
            self.pos_embed_cache[key] = self._create_pos_embed(H, W)
        pe_enc, pe_dec = self.pos_embed_cache[key]
        return pe_enc.to(device, dtype), pe_dec.to(device, dtype)

    def forward_features(self, spec):
        """Encode a spectrogram batch without masking and return the CLS embedding.

        Args:
            spec: spectrogram batch for ``self.cnn``, which expects a single
                input channel (its Conv2d takes 1 input channel).

        Returns:
            ``(B, embed_dim)`` CLS-token output of the final encoder layer.
        """
        tokens, (H, W) = self.embed(self.cnn(spec))

        pos_enc, _ = self.get_pos_embed(H, W, tokens.device, tokens.dtype)
        # Row 0 of the table is the CLS slot; patch tokens take rows 1:.
        tokens = tokens + pos_enc[:, 1:, :]

        # No masking for downstream tasks: every patch token is encoded.
        enc_out = self.enc(tokens)
        return enc_out[:, 0, :]

    def forward(self, spec):
        """Run the masked-autoencoder pretraining pass.

        Returns:
            A tuple ``(pred, mask_idx, cls_last, (H, W))``:

            - ``pred``: decoder reconstruction, ``(B, 1, 32*H, 32*W)``.
            - ``mask_idx``: ``(B, N - keep)`` indices of the masked patch tokens.
            - ``cls_last``: ``(B, embed_dim)`` CLS output of the last encoder layer.
            - ``(H, W)``: patch-grid size.
        """
        B = spec.size(0)
        tokens, (H, W) = self.embed(self.cnn(spec))
        N, D = tokens.shape[1:]

        pos_enc, pos_dec = self.get_pos_embed(H, W, tokens.device, tokens.dtype)
        tokens = tokens + pos_enc[:, 1:, :]

        # Masking (only during pretraining): encode a random visible subset.
        vis_idx, mask_idx, restore_idx = self.masker(B, N, tokens.device)
        vis_tokens = torch.gather(tokens, 1, vis_idx.unsqueeze(-1).expand(-1, -1, D))

        # Encoding
        enc_out = self.enc(vis_tokens)
        cls_last = enc_out[:, 0, :]

        # Decoding. `merged` is in shuffled order: merged[:, i] belongs to
        # original position idx_shuffle[:, i]. Gathering with the inverse
        # permutation restores grid order: full[:, j] = merged[:, restore_idx[:, j]].
        # (The previous scatter_(1, restore_idx, merged) applied the permutation
        # in the wrong direction, placing tokens at the wrong grid positions.)
        # NOTE(review): if the pretraining script still uses the scatter_ form,
        # its decoder was trained on mis-ordered tokens; worth checking.
        vis_feat = self.enc2dec(enc_out[:, 1:, :])
        mask_tok = self.mask_tok.expand(B, mask_idx.size(1), -1)
        merged = torch.cat([vis_feat, mask_tok], 1)
        full = torch.gather(merged, 1, restore_idx.unsqueeze(-1).expand_as(merged))
        full = full + pos_dec[:, 1:, :]
        pred = self.dec(full, H, W)

        return pred, mask_idx, cls_last, (H, W)

# ============================================================================
# Spectrogram Converter (from training code)
# ============================================================================

class SpecConverter(nn.Module):
    """Turn raw waveforms into the scaled log-mel spectrograms the EATMAE backbone consumes.

    The transforms run on whatever device this module has been moved to.

    Args:
        cfg: provides ``sr``, ``n_fft``, ``win_length``, ``hop_length``,
            ``n_mels`` and ``top_db``.
    """

    def __init__(self, cfg: DownstreamConfig):
        super().__init__()
        self.cfg = cfg
        mel_kwargs = dict(
            sample_rate=cfg.sr,
            n_fft=cfg.n_fft,
            win_length=cfg.win_length,
            hop_length=cfg.hop_length,
            n_mels=cfg.n_mels,
            power=2.0,
        )
        self.mel = torchaudio.transforms.MelSpectrogram(**mel_kwargs)
        self.db = torchaudio.transforms.AmplitudeToDB("power", top_db=cfg.top_db)

    def forward(self, waveform):
        """Convert a waveform batch to a scaled log-mel spectrogram.

        Args:
            waveform: ``(B, T)``, or already ``(B, C, T)``; the backbone
                presumably expects C == 1 (TODO confirm for multi-channel input).

        Returns:
            A ``(B, C, n_mels, frames)`` tensor with at most 128 frames. The
            dB values are mapped by ``(x + top_db) / (2 * top_db)``. They are
            not clamped, so the result is not guaranteed to lie in [0, 1].
        """
        x = waveform.unsqueeze(1) if waveform.dim() == 2 else waveform

        # RMS loudness normalisation, then per-clip zero-mean / unit-std scaling.
        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        x = x / (rms + 1e-6)
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        x = (x - mu) / (sigma + 1e-6)

        log_mel = self.db(self.mel(x))
        log_mel = log_mel[:, :, :, :128]  # keep at most 128 time frames
        top_db = self.cfg.top_db
        return (log_mel + top_db) / (2 * top_db)

# ============================================================================
# Audio Loading Utilities (same method as training code)
# ============================================================================

def get_audio_info_safe(filepath: str) -> Optional[Tuple[int, int]]:
    """Return ``(num_frames, sample_rate)`` for an audio file, or ``None`` on failure.

    ``.wav`` files are first read via the stdlib ``wave`` header, which avoids
    decoding. If that fails, or for other extensions, the file is decoded once
    with ``torchaudio.load``. Before decoding, files that are missing or
    smaller than 1 KiB are rejected. Failures are logged at DEBUG level and
    never raised.

    Args:
        filepath: path to the audio file (``str`` or ``os.PathLike``).
    """
    import wave

    try:
        path = os.fspath(filepath)
    except TypeError as e:
        logging.debug(f"Error processing {filepath}: {e}")
        return None

    if path.lower().endswith('.wav'):
        try:
            with wave.open(path, 'rb') as wav_file:
                return wav_file.getnframes(), wav_file.getframerate()
        except Exception as e:
            # Best effort: anything the stdlib reader rejects falls through to torchaudio.
            logging.debug(f"wave could not read {path}: {e}")

    # fallback: use torchaudio with enhanced error handling
    try:
        if not os.path.exists(path):
            return None
        if os.path.getsize(path) < 1024:  # Less than 1KB: treat as unusable
            return None
        # Single full decode; the frame count is the length of the last axis.
        audio, sr = torchaudio.load(path)
    except Exception as e:
        logging.debug(f"Failed to get info for {path}: {e}")
        return None
    return audio.shape[-1], sr

# ============================================================================
# Downstream Dataset (same audio loading method as training code)
# ============================================================================

class DownstreamDataset(Dataset):
    """CSV-indexed audio clips for a downstream task.

    Each CSV row must provide ``filepath``, ``start_sample``, ``label`` and
    ``class_name``. Rows whose file is missing at construction time are skipped.

    Items are ``(waveform, class_index, class_name)``:

    - ``waveform`` is a mono float32 tensor of exactly ``cfg.sample_len`` samples.
    - ``class_index`` comes from the sorted set of ``class_name`` values. The
      CSV ``label`` column is parsed but not returned.
    """
    def __init__(self, csv_path: str | os.PathLike, cfg: DownstreamConfig, max_retries: int = 3):
        super().__init__()
        self.cfg = cfg
        self.sample_length = cfg.sample_len
        self.max_retries = max_retries
        self.error_count = defaultdict(int)
        self.index_list = []
        self.class_to_idx = {}
        self.csv_path = csv_path
        # Resample transforms keyed by source sample rate, so each kernel is built only once.
        self._resamplers = {}
        class_names = set()

        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        if cfg.seed is not None:
            random.seed(cfg.seed)

        # Read CSV file
        logging.info("Starting to read CSV file...")
        with open(self.csv_path, newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                # Check file existence
                filepath = row["filepath"]
                if not os.path.exists(filepath):
                    logging.warning(f"File does not exist: {filepath}")
                    continue

                self.index_list.append(
                    (filepath, int(row["start_sample"]),
                     int(row["label"]), row["class_name"])
                )
                class_names.add(row["class_name"])

        self.class_to_idx = {name: idx for idx, name in enumerate(sorted(class_names))}

        logging.info(f"Dataset loading complete - Total {len(self.index_list)} samples, {len(class_names)} classes")

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, idx):
        """Load item ``idx``, falling back to the following items if it keeps failing.

        Each candidate is tried ``max_retries`` times with a short, growing
        back-off. When its retries are exhausted, the next index (wrapping
        around) is tried. This is done iteratively, so a dataset full of bad
        files raises instead of overflowing the stack.

        Raises:
            IndexError: if ``idx`` is out of range (including an empty dataset).
            RuntimeError: if no item in the dataset can be loaded.
        """
        n = len(self)
        if not -n <= idx < n:
            raise IndexError(f"Index {idx} out of range for dataset of size {n}")

        for offset in range(n):
            cand = (idx + offset) % n
            fp = self.index_list[cand][0]
            for retry in range(self.max_retries):
                try:
                    return self._load_item(cand)
                except Exception as e:
                    self.error_count[fp] += 1
                    logging.error(f"Error loading {fp} (attempt {retry+1}/{self.max_retries}): {e}")
                    if retry == self.max_retries - 1:
                        logging.error(f"Max retries exceeded for {fp}, using next sample")
                    else:
                        time.sleep(0.1 * (retry + 1))

        raise RuntimeError(f"Could not load any sample (starting from index {idx})")

    def _load_item(self, idx):
        """Read, resample, down-mix and pad/trim one clip.

        Errors propagate to the retry loop in ``__getitem__``.
        """
        fp, start, _label, folder = self.index_list[idx]

        if not os.path.exists(fp):
            raise FileNotFoundError(f"File not found: {fp}")

        # Suppress warnings while loading audio (same method as training code)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                audio, sr = torchaudio.load(
                    fp,
                    frame_offset=start,
                    num_frames=self.cfg.sample_len
                )
            except Exception:
                # If the ranged read fails, decode the whole file and slice instead.
                audio, sr = torchaudio.load(fp)
                audio = audio[:, start:start + self.cfg.sample_len]

        # Handle sample rate mismatch (same method as training code)
        if sr != self.cfg.sr:
            audio = self._get_resampler(sr)(audio)

        # Down-mix to mono in numpy (same method as training code)
        audio = audio.numpy()
        audio = audio.mean(0) if audio.shape[0] > 1 else audio[0]

        # Zero-pad short clips, then trim to exactly sample_len samples.
        audio = np.pad(audio, (0, max(0, self.cfg.sample_len - len(audio))))[:self.cfg.sample_len]

        waveform = torch.tensor(audio, dtype=torch.float32)
        return waveform, self.class_to_idx[folder], folder

    def _get_resampler(self, orig_sr):
        """Return a cached ``Resample(orig_sr -> cfg.sr)`` transform."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sr, self.cfg.sr)
            self._resamplers[orig_sr] = resampler
        return resampler

# ============================================================================
# Downstream Classifier
# ============================================================================

class SimpleAE(nn.Module):
    """MLP autoencoder over embedding vectors.

    The shape path is input_dim -> 256 -> latent_dim -> 256 -> input_dim, and
    each half is Linear -> ReLU -> Linear. ``forward`` returns only the
    reconstruction. The default ``input_dim`` of 384 matches the default
    ``embed_dim`` in DownstreamConfig.
    """

    def __init__(self, input_dim=384, latent_dim=128):
        super().__init__()
        hidden = 256
        self.encoder = self._mlp(input_dim, hidden, latent_dim)
        self.decoder = self._mlp(latent_dim, hidden, input_dim)

    @staticmethod
    def _mlp(d_in, d_hidden, d_out):
        """Build a Linear -> ReLU -> Linear stack."""
        return nn.Sequential(
            nn.Linear(d_in, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_out),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

# ============================================================================
# Training Utilities
# ============================================================================

def setup_logging(log_dir: str, downstream_name: str, fold: int):
    """Send root logging to the fold's log file and the console.

    The log file is ``<log_dir>/<downstream_name>/fold_XX.log``; the
    directory is created if needed.

    ``force=True`` replaces any handlers installed by an earlier call. Plain
    ``logging.basicConfig`` does nothing once the root logger has handlers,
    so without it every fold after the first kept writing into the first
    fold's log file.

    Returns:
        This module's logger, which propagates to the reconfigured root logger.
    """
    log_dir = Path(log_dir) / downstream_name
    log_dir.mkdir(exist_ok=True, parents=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_dir / f'fold_{fold:02d}.log'),
            logging.StreamHandler(),
        ],
        force=True,
    )
    return logging.getLogger(__name__)

def load_pretrained_model(model: EATMAE, pretrained_path: str, device: str):
    """Load checkpoint weights into ``model`` strictly, then freeze it for feature extraction.

    The checkpoint may be a raw state dict, or a dict that holds the state
    dict under ``'model_state_dict'``. On success the model is put in eval
    mode and every parameter gets ``requires_grad=False``.

    Returns:
        True on success. False if loading failed; the error is printed, not raised.

    Raises:
        FileNotFoundError: if ``pretrained_path`` does not exist.
    """
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained model not found: {pretrained_path}")

    print(f"Loading pretrained model: {pretrained_path}")

    try:
        # weights_only=False keeps PyTorch >= 2.6 (default True) loading full checkpoints.
        # SECURITY NOTE: this unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(pretrained_path, map_location=device, weights_only=False)
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint

        model.load_state_dict(state_dict, strict=True)
        print("Pretrained model loaded successfully")

        # Frozen feature extractor: inference-mode layers and no gradients.
        model.eval()
        model.requires_grad_(False)
        return True
    except Exception as e:
        print(f"Model loading failed: {e}")
        return False

def get_scheduler(optimizer, cfg):
    """Build the learning-rate scheduler named by ``cfg.scheduler_type``.

    - ``'plateau'``: ReduceLROnPlateau(mode='min'); the caller passes the metric to ``step()``.
    - ``'cosine'``: CosineAnnealingLR with ``T_max=cfg.epochs``.
    - ``'step'``: StepLR, multiplying the LR by ``cfg.lr_decay_factor`` every 30 ``step()`` calls.
    - Anything else: returns None (constant LR).
    """
    lr_sched = torch.optim.lr_scheduler
    builders = {
        "plateau": lambda: lr_sched.ReduceLROnPlateau(
            optimizer, mode='min', factor=cfg.lr_decay_factor,
            patience=cfg.lr_decay_patience, min_lr=cfg.lr_min,
        ),
        "cosine": lambda: lr_sched.CosineAnnealingLR(
            optimizer, T_max=cfg.epochs, eta_min=cfg.lr_min,
        ),
        "step": lambda: lr_sched.StepLR(
            optimizer, step_size=30, gamma=cfg.lr_decay_factor,
        ),
    }
    build = builders.get(cfg.scheduler_type)
    return build() if build is not None else None

# ============================================================================
# Training Function
# ============================================================================

def train_downstream_model(fold: int, downstream: str, csv_path: str, cfg: DownstreamConfig):
    """Train a SimpleAE on frozen EATMAE CLS embeddings for one fold of a downstream task.

    Per batch, the pipeline is: waveform, then SpecConverter, then
    ``EATMAE.forward_features`` (frozen, under no_grad), then SimpleAE. The
    autoencoder is trained with MSE to reconstruct the embedding.

    After the last epoch the checkpoint is written to
    ``<cfg.save_dir>/<downstream>/<downstream>_auc_fold_XX_<cfg.dataset>.pth``.
    Failures while setting up the dataset or models are logged, and the
    function returns early without raising.

    NOTE: cfg.patience, cfg.gradient_clip and cfg.use_amp are not used here;
    training always runs cfg.epochs epochs.
    """

    # Setup logging
    logger = setup_logging(cfg.log_dir, downstream, fold)
    logger.info(f"=== Starting downstream training: {downstream}, Fold {fold} ===")

    # Set seed
    if cfg.seed is not None:
        torch.manual_seed(cfg.seed)
        np.random.seed(cfg.seed)
        random.seed(cfg.seed)

    # Load dataset
    logger.info(f"Loading dataset: {csv_path}")
    try:
        dataset = DownstreamDataset(csv_path, cfg)
        dataloader = DataLoader(
            dataset,
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=cfg.num_workers,
            pin_memory=(cfg.device == "cuda"),
        )
        logger.info(f"Dataset loaded successfully - Classes: {len(dataset.class_to_idx)}, Samples: {len(dataset)}")
    except Exception as e:
        logger.error(f"Dataset loading failed: {e}")
        return

    # An empty loader would divide by zero when averaging the loss, and zero
    # epochs would leave `epoch` / `avg_loss` unbound when saving.
    if len(dataloader) == 0:
        logger.error("Dataset is empty - nothing to train on")
        return
    if cfg.epochs <= 0:
        logger.error(f"cfg.epochs must be positive, got {cfg.epochs}")
        return

    # Initialize models
    logger.info("Initializing models...")
    try:
        # Spectrogram converter
        spec_converter = SpecConverter(cfg).to(cfg.device)

        # Pretrained EATMAE model (feature extractor)
        feature_extractor = EATMAE(cfg).to(cfg.device)

        # Load pretrained model
        if not load_pretrained_model(feature_extractor, cfg.pretrained_model_path, cfg.device):
            logger.error("Failed to load pretrained model")
            return

        # Downstream autoencoder
        num_classes = len(dataset.class_to_idx)
        ae = SimpleAE().to(cfg.device)
        ae.train()

    except Exception as e:
        logger.error(f"Model initialization failed: {e}")
        return

    # Setup optimizer and scheduler
    optimizer = torch.optim.AdamW(ae.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
    criterion = nn.MSELoss()
    scheduler = get_scheduler(optimizer, cfg)

    # Create save directory
    save_dir = Path(cfg.save_dir) / downstream
    save_dir.mkdir(exist_ok=True, parents=True)

    # Training loop
    logger.info("Starting training...")
    best_loss = float('inf')

    for epoch in range(cfg.epochs):
        feature_extractor.eval()  # frozen backbone: always inference mode

        running_loss = 0.0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{cfg.epochs}")

        for waveform, _label, _class_name in progress_bar:
            waveform = waveform.to(cfg.device)

            # Frozen feature extraction
            with torch.no_grad():
                spec = spec_converter(waveform)
                features = feature_extractor.forward_features(spec)

            # Reconstruct the embedding
            reconstructed_features = ae(features)
            loss = criterion(reconstructed_features, features)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update metrics
            batch_loss = loss.item()
            running_loss += batch_loss
            progress_bar.set_postfix({
                'batch_loss': f"{batch_loss:.8f}",
            })

        # Epoch results (logged so they also reach the fold's log file)
        avg_loss = running_loss / len(dataloader)
        best_loss = min(best_loss, avg_loss)
        current_lr = optimizer.param_groups[0]['lr']
        logger.info(f"[Epoch {epoch+1}/{cfg.epochs}] AE Reconstruction Loss: {avg_loss:.8f}, LR: {current_lr:.8f}")

        # Update scheduler
        if scheduler:
            if cfg.scheduler_type == "plateau":
                scheduler.step(avg_loss)
            else:
                scheduler.step()

    # Save final model
    save_path = save_dir / f"{downstream}_auc_fold_{fold:02d}_{cfg.dataset}.pth"
    torch.save({
        'epoch': epoch,
        'model_state_dict': ae.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        'num_classes': num_classes,
        'class_to_idx': dataset.class_to_idx
    }, save_path)

    logger.info(f"Final model saved: {save_path} (Loss: {avg_loss:.4f})")
    logger.info(f"=== Downstream training complete: {downstream}, Fold {fold} ===")
    logger.info(f"Best performance: Loss {best_loss:.4f}")

# ============================================================================
# Main Function
# ============================================================================

def set_deterministic_mode(seed: int = 42):
    """Seed every RNG this script uses and request deterministic PyTorch kernels.

    This function:

    - Seeds Python, NumPy and torch (CPU and all CUDA devices).
    - Puts cuDNN in deterministic, non-benchmark mode.
    - Disables the memory-efficient SDP attention backend.
    - Turns on ``torch.use_deterministic_algorithms`` in warn-only mode, if
      the installed PyTorch has it.

    Note: ``PYTHONHASHSEED`` is read at interpreter start-up, so setting it
    here only affects Python processes launched afterwards.
    """
    # Environment first: CUBLAS_WORKSPACE_CONFIG only takes effect if set before cuBLAS initialises.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.enable_mem_efficient_sdp(False)

    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except AttributeError:
        # Older PyTorch without this API: keep the other settings.
        pass

def main():
    """Interactive entry point: pick a downstream task, then train one SimpleAE per fold (1-10).

    Each fold reads ``<cfg.dataset_root>/<task>/train_auc_list_<fold>.csv``.
    Missing CSVs are reported and skipped, and a failing fold does not stop
    the remaining folds.
    """
    # Configure settings. DownstreamConfig validates itself in __post_init__
    # (including that the checkpoint exists), so report problems instead of crashing.
    try:
        cfg = DownstreamConfig()
    except (AssertionError, ValueError, OSError) as e:
        print(f"Invalid configuration: {e}")
        return

    # Check pretrained model existence
    if not os.path.exists(cfg.pretrained_model_path):
        print(f"Uncertain: Pretrained model file does not exist: {cfg.pretrained_model_path}")
        return

    downstreams_map = {1: 'ColdSpray', 2: 'Yornew'}
    downstream_input = input("Select downstream task (1.ColdSpray, 2.Yornew): ")
    try:
        choice = int(downstream_input)
    except ValueError:
        print("Please enter a number.")
        return
    if choice not in downstreams_map:
        print("Invalid selection.")
        return
    downstreams = [downstreams_map[choice]]

    folds = list(range(1, 11))  # All folds 1-10

    # `cfg.seed or 42` would silently turn a valid seed of 0 into 42.
    set_deterministic_mode(cfg.seed if cfg.seed is not None else 42)

    # Execute training
    for downstream in downstreams:
        print(f"\n=== Starting {downstream} Downstream Task ===")

        for fold in folds:
            print(f"\n--- Training Fold {fold}... ---")

            csv_path = Path(cfg.dataset_root) / downstream / f"train_auc_list_{fold}.csv"

            if not csv_path.exists():
                print(f"I don't know: CSV file not found: {csv_path}")
                continue

            try:
                train_downstream_model(fold, downstream, str(csv_path), cfg)
                print(f"Fold {fold} training complete")
                print("=" * 50)

            except Exception as e:
                print(f"Fold {fold} training failed: {e}")
                traceback.print_exc()
                continue

        print(f"{downstream} all fold training complete")

if __name__ == "__main__":
    main()