from __future__ import annotations
import os
import random
import time
import logging
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Tuple, Optional, Dict, List
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import csv

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    average_precision_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
)

# Silence noisy torchaudio backend / decoder warnings. Filters are installed in
# the listed order (each call prepends to warnings.filters, as before).
for _filter_kwargs in (
    {"category": UserWarning, "module": "torchaudio._backend"},
    {"message": ".*StreamingMediaDecoder.*"},
    {"message": ".*torchaudio.load_with_torchcodec.*"},
):
    warnings.filterwarnings("ignore", **_filter_kwargs)
del _filter_kwargs

# ============================================================================
# Configuration (same as training code)
# ============================================================================

@dataclass
class TestConfig:
    """Settings for evaluating downstream classifiers on top of a pretrained EATMAE.

    Audio and architecture values must match the training code, otherwise the
    checkpoints will not load (or will produce meaningless features).

    ``pretrained_model_path`` and ``results_dir`` default to ``None`` and are
    then derived from ``dataset`` in ``__post_init__``. This way
    ``TestConfig(dataset="X")`` points at X's checkpoint and results folder
    instead of the ones baked in for the class-level default dataset.
    """
    # Audio settings (keep same as training code)
    dataset: str = 'Hybrid'
    sr: int = 48_000
    sample_len: int = 48_000
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 376
    n_mels: int = 128
    top_db: int = 80

    # Model architecture (same as training code)
    # `mr` / `ut` are plain class attributes (no annotation, so not dataclass
    # fields); they only provide the defaults of mask_ratio / utter_loss_weight.
    mr = 0.4
    ut = 0.2
    base_channels: int = 32
    patch_size: Tuple[int, int] = (16, 16)
    embed_dim: int = 384
    depth: int = 8
    nhead: int = 16
    decoder_dim: int = 256
    mask_ratio: float = mr
    dropout: float = 0.0
    utter_loss_weight: float = ut

    # Patch Embedding settings (same as training code)
    patch_embed_use_bn: bool = False
    patch_embed_use_activation: bool = False
    patch_embed_activation: str = 'none'
    patch_embed_bias: bool = False

    # Test settings
    batch_size: int = 16
    num_workers: int = 16
    seed: Optional[int] = 42
    # Prefer the second GPU when there is one; fall back to cuda:0 on a
    # single-GPU machine (where "cuda:1" would fail), and to CPU without CUDA.
    device: str = field(
        default_factory=lambda: (
            ("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
            if torch.cuda.is_available()
            else "cpu"
        )
    )

    # Paths (None -> derived from `dataset` in __post_init__)
    pretrained_model_path: Optional[str] = None
    downstream_model_dir: str = "./IMPACT_Models"
    dataset_root: str = "./Datasets/DINOS_Downstreams"
    results_dir: Optional[str] = None

    def __post_init__(self):
        """Fill dataset-dependent paths that were not given explicitly."""
        if self.pretrained_model_path is None:
            self.pretrained_model_path = f"./IMPACT_Models/impact_{self.dataset}_epoch_0010.pth"
        if self.results_dir is None:
            self.results_dir = f"./IMPACT_Results/{self.dataset}"

# ============================================================================
# Model Components (from training code)
# ============================================================================

def build_2d_sincos_pos_embed(d_model: int, grid_f: int, grid_t: int) -> torch.Tensor:
    """Return a fixed 2-D sine/cosine position table for a (grid_f x grid_t) grid.

    Cells are enumerated row-major over (frequency index, time index). The
    first half of each row encodes the frequency index and the second half the
    time index. Each half is laid out as ``[sin(pos * w_k) ..., cos(pos * w_k) ...]``
    with ``w_k = 1 / 10000 ** (k / half)``.

    Returns:
        Tensor of shape ``(grid_f * grid_t, 4 * ((d_model // 2) // 2))``;
        this equals ``d_model`` when ``d_model`` is divisible by 4.
    """
    def _axis_table(width, positions):
        half = width // 2
        inv_freq = 1.0 / (10000 ** (torch.arange(half) / half))
        angles = positions.float()[:, None] * inv_freq[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=1)

    axis_width = d_model // 2
    freq_idx, time_idx = torch.meshgrid(
        torch.arange(grid_f), torch.arange(grid_t), indexing="ij"
    )
    freq_part = _axis_table(axis_width, freq_idx.flatten())
    time_part = _axis_table(axis_width, time_idx.flatten())
    return torch.cat((freq_part, time_part), dim=1)

class MaskingGenerator:
    """Randomly split N token positions per sample into kept and masked sets.

    Args:
        mask_ratio: fraction of tokens to mask; ``int(N * (1 - mask_ratio))``
            positions are kept.
    """

    def __init__(self, mask_ratio: float):
        self.mask_ratio = mask_ratio

    def __call__(self, B: int, N: int, device):
        """Return ``(keep_idx, mask_idx, restore_idx)``.

        ``keep_idx`` is (B, n_keep), ``mask_idx`` is (B, N - n_keep), and
        ``restore_idx`` (B, N) is the inverse of the random permutation, so that
        gathering ``cat([keep_idx, mask_idx], 1)`` with it yields ``arange(N)``.
        """
        n_keep = int(N * (1 - self.mask_ratio))
        permutation = torch.rand(B, N, device=device).argsort(dim=1)
        inverse = permutation.argsort(dim=1)
        kept, masked = permutation[:, :n_keep], permutation[:, n_keep:]
        return kept, masked, inverse

class CNNEncoder(nn.Module):
    """Convolutional stem: Conv3x3 (stride 2, pad 1) -> BatchNorm -> ReLU.

    The input is cropped to even height and width first, so the stride-2
    convolution halves both dimensions exactly.

    Args:
        c: number of output channels (exposed as ``out_c``).
    """

    def __init__(self, c=32):
        super().__init__()
        conv = nn.Conv2d(1, c, kernel_size=3, stride=2, padding=1)
        self.net = nn.Sequential(conv, nn.BatchNorm2d(c), nn.ReLU())
        self.out_c = c

    def forward(self, x):
        # x: (B, C, F, T); drop the trailing row/column of an odd-sized dim.
        _, _, n_freq, n_time = x.shape
        cropped = x[:, :, : n_freq - n_freq % 2, : n_time - n_time % 2]
        return self.net(cropped)

class OptimalPatchEmbed(nn.Module):
    """Patchify a feature map with a strided Conv2d, plus optional BatchNorm / activation.

    Args:
        c_in: channels of the incoming feature map.
        d_model: embedding width of each patch token.
        patch: (patch_h, patch_w); used as both kernel size and stride, so
            patches do not overlap.
        use_bn: apply BatchNorm2d after the convolution.
        use_activation: apply the activation named by ``activation_type``.
        activation_type: 'gelu', 'relu' or 'silu'; any other value means identity.
        bias: whether the patch convolution has a bias term.
    """

    def __init__(self, c_in: int, d_model: int, patch: Tuple[int, int],
                 use_bn: bool = True, use_activation: bool = True,
                 activation_type: str = 'gelu', bias: bool = False):
        super().__init__()
        self.p = patch

        self.net = nn.Conv2d(c_in, d_model, kernel_size=patch, stride=patch, bias=bias)
        self.bn = nn.BatchNorm2d(d_model) if use_bn else nn.Identity()

        activation_cls = nn.Identity
        if use_activation:
            known = {'gelu': nn.GELU, 'relu': nn.ReLU, 'silu': nn.SiLU}
            activation_cls = known.get(activation_type, nn.Identity)
        self.act = activation_cls()

        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) conv weights; zero bias when present."""
        weight = self.net.weight.data
        # Initialise a (out_channels, -1) view; it shares storage with the weight.
        nn.init.trunc_normal_(weight.view(weight.size(0), -1), std=0.02)
        if self.net.bias is not None:
            nn.init.zeros_(self.net.bias)

    def forward(self, x):
        """Map (B, C, F, T) to ``(tokens, (H, W))``.

        ``tokens`` is (B, H*W, d_model) with H = F // patch_h and W = T // patch_w.
        """
        _, _, n_freq, n_time = x.shape
        grid = (n_freq // self.p[0], n_time // self.p[1])
        feats = self.act(self.bn(self.net(x)))  # (B, d_model, H, W)
        tokens = feats.flatten(2).transpose(1, 2)
        return tokens, grid

class Encoder(nn.Module):
    """Transformer encoder that prepends a learnable CLS token to the input tokens.

    Args:
        d_model: token width.
        nhead: attention heads per layer.
        depth: number of ``nn.TransformerEncoderLayer`` blocks (FFN width 4 * d_model).
        drop: dropout on the input tokens and inside every layer.
        return_all: if True, ``forward`` also returns per-layer patch tokens
            and per-layer CLS vectors.
    """

    def __init__(self, d_model, nhead, depth, drop=0., return_all=False):
        super().__init__()
        self.return_all = return_all
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        self.dp = nn.Dropout(drop)
        layers = []
        for _ in range(depth):
            layers.append(
                nn.TransformerEncoderLayer(
                    d_model, nhead, d_model * 4, drop,
                    batch_first=True, activation='gelu'
                )
            )
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Encode (B, N, d) tokens.

        Returns (B, 1+N, d), with the CLS token at index 0. When ``return_all``
        is set, it returns ``(final, patch_tokens_per_layer, cls_per_layer)``
        instead, where each list element is (B, N, d) and (B, d) respectively.
        """
        cls_tokens = self.cls.expand(x.size(0), -1, -1)
        hidden = torch.cat([cls_tokens, self.dp(x)], 1)  # (B, 1+N, d)

        if self.return_all:
            patch_history, cls_history = [], []
            for blk in self.blocks:
                hidden = blk(hidden)
                patch_history.append(hidden[:, 1:, :])  # patch tokens only
                cls_history.append(hidden[:, 0, :])     # CLS token only
            return hidden, patch_history, cls_history

        for blk in self.blocks:
            hidden = blk(hidden)
        return hidden
        
class CNNDecoder(nn.Module):
    """CNN decoder: project tokens to a (base, H, W) map, then upsample 32x.

    It has four [ConvTranspose(4, 2, 1) -> BatchNorm -> ReLU] stages with widths
    base -> base -> base/2 -> base/4 -> base/8, followed by a final
    ConvTranspose to ``out`` channels. Each stage doubles H and W.

    Args:
        patch_dim: input token width.
        base: channel width after the linear projection.
        out: output channels.
    """

    def __init__(self, patch_dim, base=128, out=1):
        super().__init__()
        self.base = base
        self.fc = nn.Linear(patch_dim, base)
        widths = [base, base, base // 2, base // 4, base // 8]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ])
        stages.append(nn.ConvTranspose2d(widths[-1], out, 4, 2, 1))
        self.up = nn.Sequential(*stages)

    def forward(self, x, H, W):
        """Decode (B, N, patch_dim) tokens (N must equal H*W) to (B, out, 32H, 32W)."""
        batch, _, _ = x.shape
        grid = self.fc(x).transpose(1, 2).contiguous().view(batch, self.base, H, W)
        return self.up(grid)

class EATMAE(nn.Module):
    """EATMAE backbone (with OptimalPatchEmbed), used here as a feature extractor.

    Only ``cnn`` -> ``embed`` -> ``enc`` run in :meth:`forward_features`. The
    masker and the decoder-side modules (``enc2dec``, ``mask_tok``, ``dec``) are
    built only so that a full pretraining checkpoint loads with
    ``strict=True``; they are not used at inference time.
    """

    def __init__(self, cfg: TestConfig):
        super().__init__()
        self.cfg = cfg
        self.masker = MaskingGenerator(cfg.mask_ratio)

        # Encoder path. Construction order is unchanged, so parameter
        # initialisation consumes the RNG in the same order as before.
        self.cnn = CNNEncoder(cfg.base_channels)
        self.embed = OptimalPatchEmbed(
            c_in=self.cnn.out_c,
            d_model=cfg.embed_dim,
            patch=cfg.patch_size,
            use_bn=cfg.patch_embed_use_bn,
            use_activation=cfg.patch_embed_use_activation,
            activation_type=cfg.patch_embed_activation,
            bias=cfg.patch_embed_bias
        )
        self.enc = Encoder(cfg.embed_dim, cfg.nhead, cfg.depth, cfg.dropout, return_all=False)

        # (H, W) -> (encoder PE, decoder PE) built with default tensor options.
        # These are plain dicts, so they are not part of state_dict.
        self.pos_embed_cache = {}
        # (H, W, device, dtype) -> converted copies, so forward_features does not
        # repeat a host->device copy of the same table on every batch.
        self._pos_embed_device_cache = {}
        self._precompute_common_sizes()

        # Decoder components (not used during inference but needed for model loading)
        self.enc2dec = nn.Linear(cfg.embed_dim, cfg.decoder_dim, bias=False)
        self.mask_tok = nn.Parameter(torch.zeros(1, 1, cfg.decoder_dim))
        self.dec = CNNDecoder(cfg.decoder_dim)

        self._init_weights()

    def _init_weights(self):
        """Initialise the mask token and the encoder->decoder projection."""
        nn.init.normal_(self.mask_tok, std=0.02)
        nn.init.xavier_uniform_(self.enc2dec.weight)

    def _precompute_common_sizes(self):
        """Fill the position-embedding cache for a few frequently used grid sizes."""
        common_sizes = [(8, 8), (7, 8), (8, 7), (6, 8), (8, 6), (4, 4)]
        for H, W in common_sizes:
            self.pos_embed_cache[(H, W)] = self._create_pos_embed(H, W)

    def _create_pos_embed(self, H, W):
        """Build encoder/decoder sin-cos tables with a leading all-zero CLS row.

        Returns:
            ``(pe_enc, pe_dec)`` of shapes (1, 1 + H*W, embed_dim) and
            (1, 1 + H*W, decoder_dim).
        """
        De, Dd = self.cfg.embed_dim, self.cfg.decoder_dim
        pe_enc = build_2d_sincos_pos_embed(De, H, W)
        pe_dec = build_2d_sincos_pos_embed(Dd, H, W)
        pe_enc = torch.cat([torch.zeros(1, De), pe_enc], 0)[None]
        pe_dec = torch.cat([torch.zeros(1, Dd), pe_dec], 0)[None]
        return pe_enc, pe_dec

    def get_pos_embed(self, H, W, device, dtype):
        """Return ``(pe_enc, pe_dec)`` for an H x W grid on ``device`` / ``dtype``.

        Both the base tables and their converted copies are cached. Callers must
        not modify the returned tensors in place.
        """
        key = (H, W)
        if key not in self.pos_embed_cache:
            self.pos_embed_cache[key] = self._create_pos_embed(H, W)

        conv_key = (H, W, torch.device(device), dtype)
        converted = self._pos_embed_device_cache.get(conv_key)
        if converted is None:
            pe_enc, pe_dec = self.pos_embed_cache[key]
            converted = (pe_enc.to(device, dtype), pe_dec.to(device, dtype))
            self._pos_embed_device_cache[conv_key] = converted
        return converted

    def forward_features(self, spec):
        """Return the encoder CLS embedding for each spectrogram (no masking).

        Args:
            spec: spectrogram batch fed to ``CNNEncoder``; its Conv2d expects a
                single input channel, i.e. (B, 1, F, T).

        Returns:
            (B, embed_dim) CLS-token features.
        """
        feat = self.cnn(spec)
        tokens, (H, W) = self.embed(feat)

        pos_enc, _ = self.get_pos_embed(H, W, tokens.device, tokens.dtype)
        # Row 0 of pos_enc is the zero CLS slot; patch tokens use rows 1..H*W.
        tokens = tokens + pos_enc[:, 1:, :]

        enc_out = self.enc(tokens)
        return enc_out[:, 0, :]  # CLS token holds the global representation

class SpecConverter(nn.Module):
    """Waveform -> scaled log-mel spectrogram, on whatever device the module lives on.

    Mel/dB settings come from ``cfg`` and must match the training code.
    """

    def __init__(self, cfg: TestConfig):
        super().__init__()
        self.cfg = cfg
        mel_kwargs = dict(
            sample_rate=cfg.sr,
            n_fft=cfg.n_fft,
            win_length=cfg.win_length,
            hop_length=cfg.hop_length,
            n_mels=cfg.n_mels,
            power=2.0,
        )
        self.mel = torchaudio.transforms.MelSpectrogram(**mel_kwargs)
        self.db = torchaudio.transforms.AmplitudeToDB("power", top_db=cfg.top_db)

    def forward(self, waveform):
        """Convert a (B, T) or (B, 1, T) waveform batch to (B, 1, n_mels, <=128).

        Each signal is RMS-normalised, then standardised to zero mean and unit
        std. It is converted to a power mel spectrogram in dB, truncated to at
        most 128 frames, and affinely rescaled by ``(x + top_db) / (2 * top_db)``.
        The result lies in [0, 1] only if the dB values fall within
        [-top_db, top_db].

        NOTE(review): the 128-frame cut is hard-coded and never pads; with the
        default cfg it presumably yields exactly 128 frames — confirm if the
        audio settings change.
        """
        x = waveform.unsqueeze(1) if waveform.dim() == 2 else waveform  # (B, 1, T)

        rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt()
        x = x / (rms + 1e-6)
        centered = x - x.mean(dim=-1, keepdim=True)
        x = centered / (x.std(dim=-1, keepdim=True) + 1e-6)

        spec = self.db(self.mel(x))
        spec = spec[:, :, :, :128]
        top_db = self.cfg.top_db
        return (spec + top_db) / (2 * top_db)

class DownstreamClassifier(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Dropout -> Linear.

    Linear weights use Xavier-uniform initialisation and the biases start at zero.
    """

    def __init__(self, embed_dim: int = 384, num_classes: int = 4,
                 hidden_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        linears = (m for m in self.classifier.modules() if isinstance(m, nn.Linear))
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, features):
        """Map (B, embed_dim) features to (B, num_classes) logits."""
        return self.classifier(features)

# ============================================================================
# Dataset (same method as training code)
# ============================================================================

class TestDataset(Dataset):
    """Evaluation dataset of fixed-length waveform clips listed in a CSV.

    The CSV must have the columns ``filepath``, ``start_sample``, ``label`` and
    ``class_name``. Rows whose file does not exist are skipped at construction.
    Class indices come from the sorted set of ``class_name`` values; the CSV
    ``label`` column is parsed but not used for the returned label.

    Each item is ``(waveform, class_index, class_name)``, where ``waveform`` is a
    float32 tensor of exactly ``cfg.sample_len`` samples (mono, zero-padded).

    Args:
        csv_path: path of the CSV index.
        cfg: config providing ``sample_len``, ``sr`` and ``seed``.
        max_retries: load attempts per sample before moving to the next one
            (clamped to at least 1).

    Raises:
        FileNotFoundError: if ``csv_path`` does not exist.
    """

    def __init__(self, csv_path: str | os.PathLike, cfg: TestConfig, max_retries: int = 3):
        super().__init__()
        self.cfg = cfg
        self.sample_length = cfg.sample_len
        # At least one attempt; with 0 the retry loop would never run.
        self.max_retries = max(1, int(max_retries))
        self.error_count = defaultdict(int)
        self.index_list = []
        self.class_to_idx = {}
        self.csv_path = csv_path
        class_names = set()

        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        if cfg.seed is not None:
            random.seed(cfg.seed)

        logging.info(f"Starting to read CSV file: {csv_path}")
        with open(self.csv_path, newline="") as f:
            for row in csv.DictReader(f):
                filepath = row["filepath"]
                if not os.path.exists(filepath):
                    logging.warning(f"File does not exist: {filepath}")
                    continue

                self.index_list.append(
                    (filepath, int(row["start_sample"]),
                     int(row["label"]), row["class_name"])
                )
                class_names.add(row["class_name"])

        self.class_to_idx = {name: idx for idx, name in enumerate(sorted(class_names))}

        logging.info(f"Dataset loading complete - Total {len(self.index_list)} samples, {len(class_names)} classes")

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, idx):
        """Return sample ``idx``. If it cannot be loaded, return the next loadable one.

        Raises:
            IndexError: if ``idx`` is outside ``[-len, len)``.
            RuntimeError: if no sample in the dataset can be loaded.
        """
        n = len(self.index_list)
        if not -n <= idx < n:
            raise IndexError(f"Index {idx} out of range for dataset of size {n}")

        # Walk forward (wrapping around) from idx, trying each sample at most
        # once. This replaces unbounded recursion, which never terminated when
        # every file was unreadable.
        for offset in range(n):
            item = self._load_with_retries((idx + offset) % n)
            if item is not None:
                return item
        raise RuntimeError(
            f"Failed to load any sample from {self.csv_path} (started at index {idx})"
        )

    def _load_with_retries(self, idx):
        """Try to load sample ``idx`` up to ``max_retries`` times; return None if all fail."""
        fp, start, _label, folder = self.index_list[idx]
        for retry in range(self.max_retries):
            try:
                waveform = self._load_waveform(fp, start)
                return waveform, self.class_to_idx[folder], folder
            except Exception as e:
                self.error_count[fp] += 1
                logging.error(f"Error loading {fp} (attempt {retry+1}/{self.max_retries}): {e}")
                if retry < self.max_retries - 1:
                    time.sleep(0.1 * (retry + 1))
        logging.error(f"Max retries exceeded for {fp}, using next sample")
        return None

    def _load_waveform(self, fp, start):
        """Read ``sample_len`` frames of ``fp`` from ``start``, resample, downmix, then pad/trim."""
        if not os.path.exists(fp):
            raise FileNotFoundError(f"File not found: {fp}")

        sample_len = self.cfg.sample_len
        # Suppress warnings while loading audio (same method as training code)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                audio, sr = torchaudio.load(fp, frame_offset=start, num_frames=sample_len)
            except Exception:
                # Fallback: load the whole file, then slice the requested window.
                audio, sr = torchaudio.load(fp)
                audio = audio[:, start:start + sample_len]

        # Handle sample rate mismatch (same method as training code)
        if sr != self.cfg.sr:
            audio = torchaudio.transforms.Resample(sr, self.cfg.sr)(audio)

        audio = audio.numpy()
        # Average channels for multi-channel audio; drop the channel dim for mono.
        audio = audio.mean(0) if audio.shape[0] > 1 else audio[0]
        audio = np.pad(audio, (0, max(0, sample_len - len(audio))))[:sample_len]
        return torch.tensor(audio, dtype=torch.float32)

# ============================================================================
# Model Loading Functions
# ============================================================================

def load_pretrained_model(model: EATMAE, pretrained_path: str, device: str):
    """Load pretraining weights into ``model``, then freeze it in eval mode.

    The checkpoint may be a dict with a ``'model_state_dict'`` entry or a bare
    state dict. Loading uses ``strict=True``.

    Returns:
        True on success. False if loading raised; the error is printed, not re-raised.

    Raises:
        FileNotFoundError: if ``pretrained_path`` does not exist.
    """
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained model not found: {pretrained_path}")

    print(f"Loading pretrained model: {pretrained_path}")

    try:
        # weights_only=False unpickles arbitrary objects: trusted checkpoints only.
        checkpoint = torch.load(pretrained_path, map_location=device, weights_only=False)
        state_dict = (
            checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        )
        model.load_state_dict(state_dict, strict=True)
        print("Pretrained model loaded successfully")

        # Inference-only: eval mode and no gradients for any parameter.
        model.eval()
        model.requires_grad_(False)
        return True
    except Exception as e:
        print(f"Model loading failed: {e}")
        return False

def load_downstream_model(classifier: DownstreamClassifier, model_path: str, device: str):
    """Load a trained classifier head and return its checkpoint metadata.

    Accepts either a dict with a ``'model_state_dict'`` entry (plus optional
    metadata keys) or a bare state dict. The latter is also accepted by
    ``load_pretrained_model`` but was previously rejected here. A bare state
    dict carries no metadata, so every metadata field is then ``None`` / ``{}``.

    Returns:
        dict with ``num_classes``, ``class_to_idx``, ``epoch``, ``loss`` and
        ``accuracy``. Returns None if loading raised; the error is printed.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Downstream model not found: {model_path}")

    print(f"Loading downstream model: {model_path}")

    try:
        # weights_only=False unpickles arbitrary objects: trusted checkpoints only.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        if not isinstance(checkpoint, dict):
            raise TypeError(f"Unsupported checkpoint type: {type(checkpoint).__name__}")

        wrapped = 'model_state_dict' in checkpoint
        state_dict = checkpoint['model_state_dict'] if wrapped else checkpoint
        classifier.load_state_dict(state_dict)

        # Metadata only exists alongside a wrapped state dict.
        source = checkpoint if wrapped else {}
        metadata = {
            'num_classes': source.get('num_classes', None),
            'class_to_idx': source.get('class_to_idx', {}),
            'epoch': source.get('epoch', None),
            'loss': source.get('loss', None),
            'accuracy': source.get('accuracy', None)
        }

        print(f"Downstream model loaded successfully - Classes: {metadata['num_classes']}")

        classifier.eval()
        return metadata

    except Exception as e:
        print(f"Downstream model loading failed: {e}")
        return None

# ============================================================================
# Test Function with Inference Time Measurement
# ============================================================================

def _is_cuda_device(device) -> bool:
    """Return True if ``device`` ("cuda", "cuda:1", a torch.device, ...) is a CUDA device."""
    try:
        return torch.device(device).type == "cuda"
    except (RuntimeError, TypeError, ValueError):
        return False


def _build_models(fold, downstream, dataset, cfg):
    """Create the spectrogram front-end, the frozen EATMAE and the trained classifier.

    Returns:
        ``(spec_converter, feature_extractor, classifier)``, or None on any failure.
    """
    try:
        spec_converter = SpecConverter(cfg).to(cfg.device)
        feature_extractor = EATMAE(cfg).to(cfg.device)

        if not load_pretrained_model(feature_extractor, cfg.pretrained_model_path, cfg.device):
            print("Failed to load pretrained model")
            return None

        num_classes = len(dataset.class_to_idx)
        classifier = DownstreamClassifier(
            embed_dim=cfg.embed_dim,
            num_classes=num_classes,
            hidden_dim=256,  # Use default value
            dropout=0.0      # Disable dropout during test
        ).to(cfg.device)

        model_path = Path(cfg.downstream_model_dir) / downstream / f"{downstream}_f1_fold_{fold:02d}_{cfg.dataset}.pth"
        metadata = load_downstream_model(classifier, str(model_path), cfg.device)
        if metadata is None:
            print("Failed to load downstream model")
            return None

        # Labels come from the test CSV's sorted class names. If the checkpoint
        # recorded a different mapping, predictions and labels are misaligned.
        ckpt_map = metadata.get('class_to_idx') or {}
        if ckpt_map and dict(ckpt_map) != dataset.class_to_idx:
            print(f"Warning: checkpoint class_to_idx {dict(ckpt_map)} differs from "
                  f"test set class_to_idx {dataset.class_to_idx}; labels may be misaligned")

        print(f"Model initialization complete - Classes: {num_classes}")
    except Exception as e:
        print(f"Model initialization failed: {e}")
        return None
    return spec_converter, feature_extractor, classifier


def _run_inference(dataloader, spec_converter, feature_extractor, classifier, device, is_cuda):
    """Run the frozen pipeline over every batch and time each stage.

    Returns:
        ``(y_true, y_pred, y_score, timings, total_samples, total_test_time)``.
        ``timings`` maps 'spec' / 'feat' / 'cls' / 'batch' to per-batch durations
        in seconds.
    """
    device = torch.device(device)

    def _sync():
        # CUDA work is asynchronous. Wait on the *target* device so each stage
        # timing covers the real work; a bare synchronize() only waits on the
        # current device, which is not cfg.device for e.g. "cuda:1".
        if is_cuda:
            torch.cuda.synchronize(device)

    feature_extractor.eval()
    classifier.eval()

    true_parts, pred_parts, score_parts = [], [], []
    timings = {'spec': [], 'feat': [], 'cls': [], 'batch': []}
    total_samples = 0
    num_batches = len(dataloader)

    test_start = time.perf_counter()
    with torch.no_grad():
        for batch_idx, (waveform, label, _folder_names) in enumerate(dataloader):
            batch_start = time.perf_counter()
            waveform = waveform.to(device)

            t0 = time.perf_counter()
            spec = spec_converter(waveform)
            _sync()
            t1 = time.perf_counter()
            features = feature_extractor.forward_features(spec)
            _sync()
            t2 = time.perf_counter()
            logits = classifier(features)
            _sync()
            t3 = time.perf_counter()

            timings['spec'].append(t1 - t0)
            timings['feat'].append(t2 - t1)
            timings['cls'].append(t3 - t2)
            timings['batch'].append(t3 - batch_start)

            pred_parts.append(logits.argmax(dim=1).cpu().numpy())  # index of the top logit
            score_parts.append(logits.softmax(dim=1).cpu().numpy())
            true_parts.append(label.cpu().numpy())
            total_samples += waveform.size(0)

            if (batch_idx + 1) % 10 == 0:
                avg_batch_time = np.mean(timings['batch'][-10:])
                print(f"Batch {batch_idx + 1}/{num_batches} complete - Avg batch time: {avg_batch_time:.4f}s")
    total_test_time = time.perf_counter() - test_start

    return (np.concatenate(true_parts), np.concatenate(pred_parts),
            np.concatenate(score_parts, axis=0), timings, total_samples, total_test_time)


def _summarize_times(timings, total_samples, total_test_time):
    """Aggregate per-batch stage timings into the statistics dict saved for each fold."""
    spec_t, feat_t, cls_t, batch_t = (timings[k] for k in ('spec', 'feat', 'cls', 'batch'))
    total_inference_time = np.sum(batch_t)
    return {
        'total_test_time': total_test_time,
        'total_samples': total_samples,
        'avg_spec_conversion_time': np.mean(spec_t),
        'avg_feature_extraction_time': np.mean(feat_t),
        'avg_classification_time': np.mean(cls_t),
        'avg_batch_inference_time': np.mean(batch_t),
        # Divide by the real sample count: the final batch may be smaller than batch_size.
        'avg_sample_inference_time': total_inference_time / total_samples,
        'std_spec_conversion_time': np.std(spec_t),
        'std_feature_extraction_time': np.std(feat_t),
        'std_classification_time': np.std(cls_t),
        'std_batch_inference_time': np.std(batch_t),
        'total_inference_time': total_inference_time,
        'samples_per_second': total_samples / total_test_time if total_test_time > 0 else float('nan'),
    }


def _per_class_metrics(y_true, y_pred, y_score, class_labels, idx_to_class):
    """Print and return one-vs-rest metrics for each class present in ``y_true``.

    Returns:
        ``(class_metrics, aps, aurocs)``: a list of per-class dicts plus the AP
        and AUROC lists.
    """
    aps, aurocs, class_metrics = [], [], []
    print("\n[Per-class Performance]")
    for raw_c in class_labels:
        c = int(raw_c)
        y_true_c = (y_true == c).astype(int)
        y_pred_c = (y_pred == c).astype(int)
        y_score_c = y_score[:, c]
        class_name = idx_to_class[c]

        acc = accuracy_score(y_true_c, y_pred_c)
        prec = precision_score(y_true_c, y_pred_c, zero_division=0)
        rec = recall_score(y_true_c, y_pred_c, zero_division=0)
        f1 = f1_score(y_true_c, y_pred_c, zero_division=0)
        ap = average_precision_score(y_true_c, y_score_c)

        tp = int(np.sum((y_true_c == 1) & (y_pred_c == 1)))
        support = int(np.sum(y_true_c))
        # roc_auc_score raises when the one-vs-rest target holds a single class,
        # i.e. when every test sample belongs to class c.
        auroc = roc_auc_score(y_true_c, y_score_c) if support < len(y_true_c) else float('nan')

        aps.append(ap)
        aurocs.append(auroc)
        class_metrics.append({
            'class': class_name,
            'precision': prec,
            'recall': rec,
            'f1': f1,
            'ap': ap,
            'auroc': auroc
        })

        print(
            f"Class {class_name} | TP: {tp}/{support} | "
            f"Acc: {acc:.4f} | Prec: {prec:.4f} | Rec: {rec:.4f} | "
            f"F1: {f1:.4f} | AP: {ap:.4f} | AUROC: {auroc:.4f}"
        )
    return class_metrics, aps, aurocs


def _save_confusion_matrix(y_true, y_pred, class_labels, idx_to_class, title, path):
    """Plot the confusion matrix over ``class_labels`` to ``path`` and return it as int32."""
    labels = sorted(int(c) for c in class_labels)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm,
        display_labels=[idx_to_class[c] for c in labels]
    )
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        disp.plot(ax=ax, cmap="Blues", xticks_rotation=45, colorbar=False, values_format='d')
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    return cm.astype(np.int32)


def test_downstream_model(fold: int, downstream: str, csv_path: str, cfg: TestConfig):
    """Evaluate one fold's downstream classifier and time each inference stage.

    Each batch goes through SpecConverter -> EATMAE.forward_features (frozen)
    -> DownstreamClassifier. The classifier checkpoint is read from
    ``{downstream_model_dir}/{downstream}/{downstream}_f1_fold_{fold:02d}_{dataset}.pth``.
    Per-class and overall metrics, timing statistics and a confusion-matrix
    plot are written under ``{results_dir}/{downstream}/fold_{fold}/``.

    Args:
        fold: fold number (used in the checkpoint file name and results path).
        downstream: downstream task name (checkpoint and results sub-directory).
        csv_path: CSV listing the test clips (see TestDataset).
        cfg: shared TestConfig.

    Returns:
        dict with accuracy, macro_f1, weighted_f1, mAP, AUROC, class_metrics,
        confusion_matrix and time_stats. Returns None if dataset or model setup
        fails, or if the dataset is empty.
    """
    print(f"\n==================== Fold {fold} Test ====================")

    if cfg.seed is not None:
        torch.manual_seed(cfg.seed)
        np.random.seed(cfg.seed)
        random.seed(cfg.seed)

    # `cfg.device == "cuda"` is False for indexed devices such as the default
    # "cuda:1", which silently disabled GPU sync and pinned memory.
    is_cuda = _is_cuda_device(cfg.device)

    try:
        dataset = TestDataset(csv_path, cfg)
        dataloader = DataLoader(
            dataset,
            cfg.batch_size,
            shuffle=False,  # No shuffle during test
            num_workers=cfg.num_workers,
            pin_memory=is_cuda,
        )
        print(f"Dataset loaded successfully - Classes: {len(dataset.class_to_idx)}, Samples: {len(dataset)}")
    except Exception as e:
        print(f"Dataset loading failed: {e}")
        return None

    if len(dataset) == 0:
        # The metric code below cannot handle empty prediction arrays.
        print("Dataset is empty - skipping test")
        return None

    models = _build_models(fold, downstream, dataset, cfg)
    if models is None:
        return None
    spec_converter, feature_extractor, classifier = models

    print("Running test...")
    y_true, y_pred, y_score, timings, total_samples, total_test_time = _run_inference(
        dataloader, spec_converter, feature_extractor, classifier, cfg.device, is_cuda
    )

    class_labels = np.unique(y_true)
    idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}

    print(f"Test complete - Processed {len(y_true)} samples")

    time_stats = _summarize_times(timings, total_samples, total_test_time)

    print("\n==================== Inference Time Statistics ====================")
    print(f"Total test time           : {time_stats['total_test_time']:.4f}s")
    print(f"Total samples             : {time_stats['total_samples']}")
    print(f"Samples per second        : {time_stats['samples_per_second']:.2f} samples/sec")
    print(f"Avg inference per sample  : {time_stats['avg_sample_inference_time']:.4f}s")
    print(f"Avg inference per batch   : {time_stats['avg_batch_inference_time']:.4f}s")
    print("\n[Step-wise Average Time]")
    print(f"Spectrogram conversion    : {time_stats['avg_spec_conversion_time']:.4f}s (±{time_stats['std_spec_conversion_time']:.4f})")
    print(f"Feature extraction        : {time_stats['avg_feature_extraction_time']:.4f}s (±{time_stats['std_feature_extraction_time']:.4f})")
    print(f"Classification            : {time_stats['avg_classification_time']:.4f}s (±{time_stats['std_classification_time']:.4f})")

    class_metrics, aps, aurocs = _per_class_metrics(y_true, y_pred, y_score, class_labels, idx_to_class)

    acc_total = accuracy_score(y_true, y_pred)
    prec_macro = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec_macro = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)

    prec_weighted = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    rec_weighted = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1_weighted = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    mean_ap = np.mean(aps)
    mean_auroc = np.mean(aurocs)

    n_cls = len(class_labels)
    print("\n[Overall Multi-class Classification Performance]")
    print(f"Accuracy       : {acc_total:.4f}")
    print(f"Macro  - Prec  : {prec_macro:.4f} | Recall: {rec_macro:.4f} | F1: {f1_macro:.4f}")
    print(f"Weighted- Prec : {prec_weighted:.4f} | Recall: {rec_weighted:.4f} | F1: {f1_weighted:.4f}")
    print(f"mAP  (mean AP over {n_cls} classes): {mean_ap:.4f}")
    print(f"AUROC (mean AUROC over {n_cls} classes): {mean_auroc:.4f}")

    iter_dir = Path(cfg.results_dir) / downstream / f"fold_{fold}"
    iter_dir.mkdir(exist_ok=True, parents=True)

    pd.DataFrame(class_metrics).to_csv(iter_dir / "IMPACT_f1_per_class_metrics.csv", index=False)

    overall_iter_metrics = {
        "accuracy": acc_total,
        "precision_macro": prec_macro,
        "recall_macro": rec_macro,
        "f1_macro": f1_macro,
        "precision_weighted": prec_weighted,
        "recall_weighted": rec_weighted,
        "f1_weighted": f1_weighted,
        "mAP": mean_ap,
        "AUROC": mean_auroc
    }
    pd.DataFrame([overall_iter_metrics]).to_csv(
        iter_dir / "IMPACT_f1_overall_metrics.csv", index=False
    )

    pd.DataFrame([time_stats]).to_csv(iter_dir / "IMPACT_f1_inference_time_stats.csv", index=False)

    pd.DataFrame({
        'batch_idx': range(len(timings['batch'])),
        'spec_conversion_time': timings['spec'],
        'feature_extraction_time': timings['feat'],
        'classification_time': timings['cls'],
        'total_batch_time': timings['batch']
    }).to_csv(iter_dir / "IMPACT_f1_batch_inference_times.csv", index=False)

    cm_all = _save_confusion_matrix(
        y_true, y_pred, class_labels, idx_to_class,
        f"Confusion Matrix - {downstream} Fold {fold}",
        iter_dir / "confusion_matrix.png",
    )

    print(f"Results saved: {iter_dir}")

    return {
        'accuracy': acc_total,
        'macro_f1': f1_macro,
        'weighted_f1': f1_weighted,
        'mAP': mean_ap,
        'AUROC': mean_auroc,
        'class_metrics': class_metrics,
        'confusion_matrix': cm_all,
        'time_stats': time_stats
    }

# ============================================================================
# Main Function with Aggregated Time Statistics
# ============================================================================

def set_deterministic_mode(seed: int = 42):
    """Seed every RNG used by this script and request deterministic kernels.

    Seeds Python's ``random``, NumPy's global RNG, and PyTorch (CPU, the
    current CUDA device, and all CUDA devices). Pins cuDNN to deterministic,
    non-benchmarking algorithms. Exports ``PYTHONHASHSEED`` and
    ``CUBLAS_WORKSPACE_CONFIG``, and disables the memory-efficient
    scaled-dot-product-attention backend. Finally it asks PyTorch to use
    deterministic algorithms, warning instead of raising when none exists.
    Builds without ``torch.use_deterministic_algorithms`` are tolerated.

    Note: ``PYTHONHASHSEED`` set at runtime only influences child processes;
    this interpreter's string-hash seed was already fixed at startup.

    Args:
        seed: Value fed to every seeding function and exported as
            ``PYTHONHASHSEED``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        # Workspace setting cuBLAS needs for deterministic GEMMs.
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })
    torch.backends.cuda.enable_mem_efficient_sdp(False)

    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except AttributeError:
        # Older PyTorch without this API: keep the best-effort settings above.
        pass

def main():
    """Run the downstream test for every fold and aggregate the results.

    Prompts on stdin for a downstream task. For each fold 1-10 whose
    ``<cfg.dataset_root>/<downstream>/test_f1_list_<fold>.csv`` exists, it
    calls ``test_downstream_model`` and collects that fold's scalar metrics,
    ``time_stats``, ``confusion_matrix`` and ``class_metrics``. A fold whose
    test raises, or returns ``None``, is reported and skipped entirely.

    Aggregates are printed and written to
    ``<cfg.results_dir>/<downstream>/aggregated``:

    * ``IMPACT_f1_fold_overall_mean_std.csv``: one row of scalar metrics per fold
    * ``IMPACT_f1_fold_inference_time_stats.csv`` and
      ``IMPACT_f1_inference_time_summary.csv``: per-fold timing and its
      mean/std/min/max
    * ``IMPACT_f1_fold_per_class_mean_std.csv``: per-class mean/std/min/max
    * ``confusion_matrix_avg.png``: fold-averaged confusion matrix

    Returns ``None``. It returns early, after printing a message, when the
    pretrained model file or the downstream model directory is missing, when
    the task selection is invalid, or when every fold fails.
    """
    import traceback  # only used to report why a fold failed

    # Configure settings
    cfg = TestConfig()

    # Fail fast on missing inputs before prompting the user.
    if not os.path.exists(cfg.pretrained_model_path):
        print(f"Uncertain: Pretrained model file does not exist: {cfg.pretrained_model_path}")
        return

    if not os.path.exists(cfg.downstream_model_dir):
        print(f"Uncertain: Downstream model directory does not exist: {cfg.downstream_model_dir}")
        return

    # Select downstream task
    downstreams_map = {1: 'ColdSpray', 2: 'RenishawL', 3: 'Yornew', 4: 'VF2'}
    downstream_input = input("Select downstream task (1.ColdSpray, 2.RenishawL, 3.Yornew, 4.VF2): ")
    try:
        selection = int(downstream_input)
    except ValueError:
        print("Please enter a number.")
        return
    if selection not in downstreams_map:
        print("Invalid selection.")
        return
    downstream = downstreams_map[selection]

    print(f"\n=== Starting {downstream} Downstream Task Test ===")

    # Fall back to 42 only when no seed is configured. `cfg.seed or 42` would
    # silently replace a legitimate seed of 0.
    set_deterministic_mode(cfg.seed if cfg.seed is not None else 42)

    exclude_keys = ('class_metrics', 'confusion_matrix', 'time_stats')
    class_metric_keys = ('precision', 'recall', 'f1', 'ap', 'auroc')

    # Accumulators across folds
    metrics_list = []
    time_stats_list = []
    classwise_metrics_by_class = defaultdict(lambda: defaultdict(list))
    confusion_matrix_accum = None
    confusion_matrix_count = 0

    # Run test for all folds
    for fold in range(1, 11):
        print(f"\n--- Testing Fold {fold}... ---")

        # CSV listing this fold's test samples
        csv_path = Path(cfg.dataset_root) / downstream / f"test_f1_list_{fold}.csv"
        if not csv_path.exists():
            print(f"I don't know: CSV file not found: {csv_path}")
            continue

        try:
            metrics = test_downstream_model(fold, downstream, str(csv_path), cfg)
            if metrics is None:
                print(f"Fold {fold} test failed")
                continue

            # Extract and validate everything this fold contributes *before*
            # touching the accumulators. Otherwise a malformed result could
            # count the fold in some aggregates but not in others.
            fold_scalars = {
                k: v for k, v in metrics.items()
                if k not in exclude_keys and np.isscalar(v)
            }
            # np.array copies, so the accumulator never aliases the fold's matrix.
            fold_cm = np.array(metrics['confusion_matrix'], dtype=np.float64)
            if (confusion_matrix_accum is not None
                    and fold_cm.shape != confusion_matrix_accum.shape):
                raise ValueError(
                    f"confusion matrix shape {fold_cm.shape} differs from "
                    f"previous folds {confusion_matrix_accum.shape}"
                )
            fold_class_values = [
                (row['class'], [(key, row[key]) for key in class_metric_keys])
                for row in metrics['class_metrics']
            ]
        except Exception as e:
            print(f"Fold {fold} test failed: {e}")
            traceback.print_exc()
            continue

        # Commit this fold's results to every accumulator together.
        metrics_list.append(fold_scalars)
        if 'time_stats' in metrics:
            time_stats_list.append(metrics['time_stats'])
        if confusion_matrix_accum is None:
            confusion_matrix_accum = fold_cm
        else:
            confusion_matrix_accum += fold_cm
        confusion_matrix_count += 1
        for cname, pairs in fold_class_values:
            for key, value in pairs:
                classwise_metrics_by_class[cname][key].append(value)

        print(f"Fold {fold} test complete")

    # Analyze and save overall results
    if not metrics_list:
        print("All fold tests failed.")
        return

    df = pd.DataFrame(metrics_list)

    print("\n==================== Final Aggregated Metrics (F1 Score) ====================")
    for col in df.columns:
        mean = df[col].mean()
        std = df[col].std()
        min_val = df[col].min()
        max_val = df[col].max()
        print(f"{col:12s} | Mean: {mean:.4f} | Std: {std:.4f} | Min: {min_val:.4f} | Max: {max_val:.4f}")

    # Aggregate and print time statistics
    time_df = pd.DataFrame(time_stats_list) if time_stats_list else None
    if time_df is not None:
        print("\n==================== Aggregated Inference Time Statistics ====================")
        time_metrics = [
            'total_test_time', 'avg_spec_conversion_time', 'avg_feature_extraction_time',
            'avg_classification_time', 'avg_batch_inference_time', 'avg_sample_inference_time',
            'samples_per_second'
        ]

        for metric in time_metrics:
            if metric not in time_df.columns:
                continue
            mean = time_df[metric].mean()
            std = time_df[metric].std()
            min_val = time_df[metric].min()
            max_val = time_df[metric].max()

            if 'time' in metric:
                unit = 's'
                print(f"{metric:25s} | Mean: {mean:.4f}{unit} | Std: {std:.4f}{unit} | Min: {min_val:.4f}{unit} | Max: {max_val:.4f}{unit}")
            else:
                unit = 'samples/sec' if 'samples_per_second' in metric else ''
                print(f"{metric:25s} | Mean: {mean:.2f}{unit} | Std: {std:.2f}{unit} | Min: {min_val:.2f}{unit} | Max: {max_val:.2f}{unit}")

    print("\n==================== Per-class Aggregated Metrics ====================")
    for cname, per_key in classwise_metrics_by_class.items():
        print(f"\nClass: {cname}")
        for key, values in per_key.items():
            # NOTE: np.std is the population std (ddof=0), whereas the pandas
            # .std() used for the overall metrics above is the sample std (ddof=1).
            mean = np.mean(values)
            std = np.std(values)
            min_val = np.min(values)
            max_val = np.max(values)
            print(f"  {key:>7s} | Mean: {mean:.4f} | Std: {std:.4f} | Min: {min_val:.4f} | Max: {max_val:.4f}")

    # Save aggregated results
    agg_dir = Path(cfg.results_dir) / downstream / "aggregated"
    agg_dir.mkdir(exist_ok=True, parents=True)

    # Save fold-wise overall metrics
    df.to_csv(agg_dir / "IMPACT_f1_fold_overall_mean_std.csv", index=False)

    # Save fold-wise time statistics and their summary
    if time_df is not None:
        time_df.to_csv(agg_dir / "IMPACT_f1_fold_inference_time_stats.csv", index=False)

        time_summary = {}
        for metric in time_df.columns:
            time_summary[f"{metric}_mean"] = time_df[metric].mean()
            time_summary[f"{metric}_std"] = time_df[metric].std()
            time_summary[f"{metric}_min"] = time_df[metric].min()
            time_summary[f"{metric}_max"] = time_df[metric].max()

        pd.DataFrame([time_summary]).to_csv(
            agg_dir / "IMPACT_f1_inference_time_summary.csv", index=False
        )

    # Save per-class aggregated metrics
    per_class_rows = []
    for cname, per_key in classwise_metrics_by_class.items():
        row = {"class": cname}
        for k, vals in per_key.items():
            row[f"{k}_mean"] = float(np.mean(vals))
            row[f"{k}_std"] = float(np.std(vals))
            row[f"{k}_min"] = float(np.min(vals))
            row[f"{k}_max"] = float(np.max(vals))
        per_class_rows.append(row)

    pd.DataFrame(per_class_rows).to_csv(
        agg_dir / "IMPACT_f1_fold_per_class_mean_std.csv", index=False
    )

    # Save average confusion matrix
    if confusion_matrix_accum is not None:
        label_mapping = {
            'RenishawL': ['T1', 'T2'],
            'VF2': ['T3', 'T4', 'T5'],
            'Yornew': [f'T{i}' for i in range(6, 18)],
            'ColdSpray': [f'T{i}' for i in range(18, 26)]
        }

        # Average over folds, then round to the nearest whole count.
        # Floor division (`//`) would bias every cell downward (a mean of
        # 2.9 would be shown as 2).
        cm_avg = np.rint(confusion_matrix_accum / confusion_matrix_count).astype(int)
        num_classes = cm_avg.shape[0]

        # Use the task's predefined labels only if their count matches the matrix.
        predefined_labels = label_mapping.get(downstream, [])
        if len(predefined_labels) == num_classes:
            labels = predefined_labels
        else:
            print(f"Warning: Predefined labels ({len(predefined_labels)}) don't match actual classes ({num_classes}). Using generic labels.")
            labels = [f'Class {i}' for i in range(num_classes)]

        fig, ax = plt.subplots(figsize=(6, 6))
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_avg, display_labels=labels)
        disp.plot(ax=ax, cmap="Blues", xticks_rotation=45, colorbar=False, values_format='d')
        ax.set_title(f"Average Confusion Matrix - {downstream}")
        fig.tight_layout()
        fig.savefig(agg_dir / "confusion_matrix_avg.png")
        plt.close(fig)

    print(f"\n=== All {downstream} Tests Complete ===")
    print(f"Results saved to: {cfg.results_dir}/{downstream}")

    # Final summary output. Read from the fold DataFrame and guard the columns,
    # mirroring the guarded printing loop above, so missing keys don't raise.
    summary_cols = {'avg_sample_inference_time', 'samples_per_second'}
    if time_df is not None and summary_cols.issubset(time_df.columns):
        overall_avg_inference_time = time_df['avg_sample_inference_time'].mean()
        overall_throughput = time_df['samples_per_second'].mean()
        print("\n=== Final Performance Summary ===")
        print(f"Overall avg inference time per sample: {overall_avg_inference_time:.4f}s")
        print(f"Overall avg throughput: {overall_throughput:.2f} samples/sec")

if __name__ == "__main__":
    # Run the interactive fold-wise evaluation only when executed as a script,
    # not when this module is imported.
    main()