from __future__ import annotations
import os
import random
import time
import logging
import warnings
from dataclasses import dataclass, field
from pathlib import Path
from typing import Tuple, Optional, Dict, List
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import csv

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from sklearn.metrics import (
    roc_auc_score, average_precision_score, accuracy_score, 
    precision_score, recall_score, f1_score, 
    confusion_matrix, ConfusionMatrixDisplay
)

# Suppress noisy torchaudio warnings: UserWarnings attributed to modules
# matching "torchaudio._backend", plus any warning whose message matches
# StreamingMediaDecoder or torchaudio.load_with_torchcodec. These are
# module-level statements, so the filters are installed when this file is
# imported and apply process-wide from then on.
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio._backend")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
warnings.filterwarnings("ignore", message=".*torchaudio.load_with_torchcodec.*")

# ============================================================================
# Configuration
# ============================================================================

@dataclass
class AUCTestConfig:
    """Settings for the downstream AUC (anomaly-detection) test.

    Audio, model-architecture and patch-embedding values are meant to mirror
    the training configuration (see the "same as training code" notes). The
    pretrained checkpoint is loaded with ``strict=True``, so architecture
    values that differ from training will make that load fail.

    NOTE(review): ``pretrained_model_path`` and ``results_dir`` are f-strings
    evaluated once, when the class body runs, using the class-body value of
    ``dataset`` ('Hybrid'). Passing ``dataset=...`` to the constructor does
    NOT change these two paths -- pass them explicitly as well.
    """
    # Audio settings (same as training code)
    dataset: str = 'Hybrid'  # tag used in checkpoint / result file and directory names
    sr: int = 48_000  # target sample rate (Hz); loaded audio is resampled to this
    sample_len: int = 48_000  # samples per clip (1 s at 48 kHz)
    n_fft: int = 2048
    win_length: int = 2048
    hop_length: int = 376
    n_mels: int = 128
    top_db: int = 80  # passed to AmplitudeToDB(top_db=...) and reused to rescale the dB values
    
    # Model architecture (same as training code)
    # NOTE: `mr` and `ut` carry no type annotation, so @dataclass does not make
    # them fields; they are plain class attributes that only supply the
    # defaults for `mask_ratio` and `utter_loss_weight` below.
    mr = 0.4
    ut = 0.2
    base_channels: int = 32
    patch_size: Tuple[int, int] = (16, 16)
    embed_dim: int = 384
    depth: int = 8
    nhead: int = 16
    decoder_dim: int = 256
    mask_ratio: float = mr  # feeds the masking generator, which forward_features() does not call
    dropout: float = 0.0
    utter_loss_weight: float = ut  # training-side setting; not read by the visible test code

    # Patch Embedding settings (same as training code)
    patch_embed_use_bn: bool = False
    patch_embed_use_activation: bool = False
    patch_embed_activation: str = 'none'  # ignored while patch_embed_use_activation is False
    patch_embed_bias: bool = False
    
    # Test settings
    batch_size: int = 16
    num_workers: int = 16
    # None skips the per-dataset / per-fold reseeding (main() still falls back to 42).
    seed: Optional[int] = 42
    # Second CUDA device ("cuda:1") when CUDA is available, otherwise CPU.
    device: str = field(default_factory=lambda: "cuda:1" if torch.cuda.is_available() else "cpu")
    
    # Paths (the two f-strings are fixed at class-definition time; see docstring)
    pretrained_model_path: str = f"./IMPACT_Models/impact_{dataset}_epoch_0010.pth"
    downstream_model_dir: str = "./IMPACT_Models"
    dataset_root: str = "./Datasets/DINOS_Downstreams"
    results_dir: str = f"./IMPACT_Results/{dataset}"

# ============================================================================
# Model Components (from training code)
# ============================================================================

def build_2d_sincos_pos_embed(d_model: int, grid_f: int, grid_t: int) -> torch.Tensor:
    """Return a fixed 2D sine/cosine positional table for a ``grid_f x grid_t`` grid.

    The first half of the channels encodes the frequency-axis index and the
    second half the time-axis index. Rows are enumerated row-major: row
    ``i * grid_t + j`` belongs to grid cell ``(i, j)``.

    Args:
        d_model: Total embedding width.
        grid_f: Number of patches along the frequency axis.
        grid_t: Number of patches along the time axis.

    Returns:
        Float tensor of shape ``(grid_f * grid_t, 4 * (d_model // 4))``
        (i.e. ``d_model`` columns when ``d_model`` is divisible by 4).
    """
    axis_width = d_model // 2
    n_freq = axis_width // 2

    # Geometric frequency ladder 1 / 10000^(k / n_freq) for k = 0 .. n_freq-1.
    inv_freq = 1.0 / (10000 ** (torch.arange(n_freq) / n_freq))

    def _encode(positions: torch.Tensor) -> torch.Tensor:
        angles = torch.outer(positions.float(), inv_freq)
        return torch.cat((angles.sin(), angles.cos()), dim=1)

    # Row-major cell indices (identical to meshgrid(indexing="ij") + flatten).
    freq_idx = torch.arange(grid_f).repeat_interleave(grid_t)
    time_idx = torch.arange(grid_t).repeat(grid_f)
    return torch.cat((_encode(freq_idx), _encode(time_idx)), dim=1)

class MaskingGenerator:
    """Random per-sample token masking in the MAE style.

    Each call draws uniform noise per token and argsorts it, giving an
    independent random permutation of the tokens for every sample.
    """

    def __init__(self, mask_ratio: float):
        # Fraction of tokens to hide; visible count is int(N * (1 - mask_ratio)).
        self.mask_ratio = mask_ratio

    def __call__(self, B: int, N: int, device):
        """Return ``(keep_idx, mask_idx, restore_idx)`` for ``B`` samples of ``N`` tokens.

        ``keep_idx`` is ``(B, n_keep)``, ``mask_idx`` is ``(B, N - n_keep)``,
        and ``restore_idx`` is the inverse permutation that maps the shuffled
        order back to the original token order.
        """
        n_keep = int(N * (1 - self.mask_ratio))
        permutation = torch.rand(B, N, device=device).argsort(dim=1)
        inverse = permutation.argsort(dim=1)
        keep_idx = permutation[:, :n_keep]
        mask_idx = permutation[:, n_keep:]
        return keep_idx, mask_idx, inverse

class CNNEncoder(nn.Module):
    """Single strided conv stem: Conv2d(3x3, stride 2, pad 1) -> BatchNorm2d -> ReLU.

    Odd frequency/time sizes are cropped down to the nearest even size first,
    so the output is exactly half the cropped input size on both axes.
    """

    def __init__(self, c=32):
        super().__init__()
        # Layer order defines the state_dict keys (net.0.*, net.1.*); keep it
        # unchanged so pretrained checkpoints still load.
        self.net = nn.Sequential(
            nn.Conv2d(1, c, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(c),
            nn.ReLU(),
        )
        self.out_c = c

    def forward(self, x):
        """Map ``(B, 1, F, T)`` to ``(B, c, F_even // 2, T_even // 2)``."""
        _, _, n_freq, n_time = x.shape
        even_freq = n_freq - n_freq % 2
        even_time = n_time - n_time % 2
        return self.net(x[:, :, :even_freq, :even_time])

class OptimalPatchEmbed(nn.Module):
    """Non-overlapping patch embedding: Conv2d(kernel = stride = patch) + optional BN / activation.

    Args:
        c_in: Input channels.
        d_model: Embedding width (conv output channels).
        patch: ``(patch_f, patch_t)``; used as both kernel size and stride.
        use_bn: Apply ``BatchNorm2d`` after the convolution.
        use_activation: Apply the activation chosen by ``activation_type``.
        activation_type: ``'gelu'``, ``'relu'`` or ``'silu'``; any other value
            silently falls back to identity.
        bias: Whether the convolution has a bias term.
    """

    # Name -> activation class; unknown names map to nn.Identity.
    _ACTIVATIONS = {'gelu': nn.GELU, 'relu': nn.ReLU, 'silu': nn.SiLU}

    def __init__(self, c_in: int, d_model: int, patch: Tuple[int, int], 
                 use_bn: bool = True, use_activation: bool = True, 
                 activation_type: str = 'gelu', bias: bool = False):
        super().__init__()
        self.p = patch
        self.net = nn.Conv2d(c_in, d_model, kernel_size=patch, stride=patch, bias=bias)
        self.bn = nn.BatchNorm2d(d_model) if use_bn else nn.Identity()
        if use_activation:
            act_cls = self._ACTIVATIONS.get(activation_type, nn.Identity)
        else:
            act_cls = nn.Identity
        self.act = act_cls()
        self._init_weights()

    def _init_weights(self):
        """Redraw the conv kernel from a truncated normal (std 0.02) and zero its bias."""
        kernel = self.net.weight.data
        nn.init.trunc_normal_(kernel.view(kernel.size(0), -1), std=0.02)
        if self.net.bias is not None:
            nn.init.zeros_(self.net.bias)

    def forward(self, x):
        """Embed ``(B, C, F, T)`` into a token sequence.

        Returns:
            ``(tokens, (H, W))`` with ``tokens`` shaped ``(B, H * W, d_model)``
            in row-major (frequency-major) order, ``H = F // patch[0]`` and
            ``W = T // patch[1]``.
        """
        _, _, n_freq, n_time = x.shape
        grid = (n_freq // self.p[0], n_time // self.p[1])
        feats = self.act(self.bn(self.net(x)))
        tokens = feats.flatten(start_dim=2).transpose(1, 2)
        return tokens, grid

class Encoder(nn.Module):
    """Stack of ``nn.TransformerEncoderLayer`` blocks with a learnable CLS token.

    Args:
        d_model: Token width.
        nhead: Attention heads per layer.
        depth: Number of layers.
        drop: Dropout rate, applied to the input tokens and inside each layer.
        return_all: If True, ``forward`` also returns the CLS vector after each layer.
    """

    def __init__(self, d_model, nhead, depth, drop=0., return_all=False):
        super().__init__()
        self.return_all = return_all
        self.cls = nn.Parameter(torch.zeros(1, 1, d_model))
        self.dp = nn.Dropout(drop)
        layers = []
        for _ in range(depth):
            layers.append(nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=4 * d_model,
                dropout=drop,
                activation='gelu',
                batch_first=True,
            ))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Prepend the CLS token to ``x`` ``(B, N, d_model)`` and run every block.

        Returns ``(B, N + 1, d_model)``; when ``return_all`` is set, returns
        ``(output, per_layer_cls)`` where ``per_layer_cls`` holds one
        ``(B, d_model)`` tensor per block.
        """
        cls_tok = self.cls.expand(x.size(0), -1, -1)
        hidden = torch.cat((cls_tok, self.dp(x)), dim=1)
        per_layer_cls = []
        for block in self.blocks:
            hidden = block(hidden)
            per_layer_cls.append(hidden[:, 0, :])
        if self.return_all:
            return hidden, per_layer_cls
        return hidden

class CNNDecoder(nn.Module):
    """Project tokens onto a ``base``-channel grid and upsample it 32x with transposed convs.

    Five ``ConvTranspose2d(k=4, s=2, p=1)`` stages each double H and W. Channel
    widths go base -> base -> base//2 -> base//4 -> base//8 -> ``out``. Every
    stage except the last is followed by BatchNorm2d + ReLU.
    """

    def __init__(self, patch_dim, base=128, out=1):
        super().__init__()
        self.base = base
        self.fc = nn.Linear(patch_dim, base)
        widths = [base, base, base // 2, base // 4, base // 8]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]
        stages.append(nn.ConvTranspose2d(widths[-1], out, 4, 2, 1))
        # Same flat layout (up.0 .. up.12) as an explicitly written Sequential.
        self.up = nn.Sequential(*stages)

    def forward(self, x, H, W):
        """Decode tokens ``(B, H * W, patch_dim)`` into ``(B, out, 32 * H, 32 * W)``."""
        batch, _, _ = x.shape
        grid = self.fc(x).transpose(1, 2).contiguous()
        grid = grid.view(batch, self.base, H, W)
        return self.up(grid)

class EATMAE(nn.Module):
    """EATMAE backbone used here as a frozen CLS-feature extractor.

    Only the CNN stem (``cnn``), patch embedding (``embed``) and Transformer
    encoder (``enc``) run in :meth:`forward_features`. The masking generator
    and the decoder-side modules (``enc2dec``, ``mask_tok``, ``dec``) are still
    built so that checkpoints containing them load with matching keys.
    """

    def __init__(self, cfg: AUCTestConfig):
        super().__init__()
        self.cfg = cfg
        self.masker = MaskingGenerator(cfg.mask_ratio)

        # --- encoder path (used at inference) ---
        self.cnn = CNNEncoder(cfg.base_channels)
        self.embed = OptimalPatchEmbed(
            c_in=self.cnn.out_c,
            d_model=cfg.embed_dim,
            patch=cfg.patch_size,
            use_bn=cfg.patch_embed_use_bn,
            use_activation=cfg.patch_embed_use_activation,
            activation_type=cfg.patch_embed_activation,
            bias=cfg.patch_embed_bias,
        )
        self.enc = Encoder(cfg.embed_dim, cfg.nhead, cfg.depth, cfg.dropout, return_all=False)

        # Fixed sin-cos tables keyed by (H, W). Stored in a plain dict (not a
        # registered buffer): they are not in state_dict and are not moved by
        # .to(); get_pos_embed() moves them on demand.
        self.pos_embed_cache = {}
        self._precompute_common_sizes()

        # --- decoder path (unused here; kept for checkpoint compatibility) ---
        self.enc2dec = nn.Linear(cfg.embed_dim, cfg.decoder_dim, bias=False)
        self.mask_tok = nn.Parameter(torch.zeros(1, 1, cfg.decoder_dim))
        self.dec = CNNDecoder(cfg.decoder_dim)

        self._init_weights()

    def _init_weights(self):
        """Initialise the decoder-side parameters created in ``__init__``."""
        nn.init.normal_(self.mask_tok, std=0.02)
        nn.init.xavier_uniform_(self.enc2dec.weight)

    def _precompute_common_sizes(self):
        """Warm the position-embedding cache for a fixed list of grid sizes."""
        for grid in ((8, 8), (7, 8), (8, 7), (6, 8), (8, 6), (4, 4)):
            self.pos_embed_cache[grid] = self._create_pos_embed(*grid)

    def _create_pos_embed(self, H, W):
        """Build ``(pe_enc, pe_dec)`` for an ``H x W`` grid.

        Each table is ``(1, 1 + H * W, dim)`` (for dims divisible by 4) with an
        all-zero row 0 reserved for the CLS slot.
        """
        tables = []
        for dim in (self.cfg.embed_dim, self.cfg.decoder_dim):
            grid_pe = build_2d_sincos_pos_embed(dim, H, W)
            tables.append(torch.cat([torch.zeros(1, dim), grid_pe], 0).unsqueeze(0))
        return tuple(tables)

    def get_pos_embed(self, H, W, device, dtype):
        """Return the (encoder, decoder) position tables for ``H x W`` on ``device``/``dtype``.

        Tables are memoised per grid size and converted on every call.
        """
        cached = self.pos_embed_cache.get((H, W))
        if cached is None:
            cached = self._create_pos_embed(H, W)
            self.pos_embed_cache[(H, W)] = cached
        pe_enc, pe_dec = cached
        return pe_enc.to(device, dtype), pe_dec.to(device, dtype)

    def forward_features(self, spec):
        """Encode a spectrogram batch without masking and return the CLS embedding.

        Args:
            spec: ``(B, 1, F, T)`` spectrogram; the single input channel is
                required by the CNN stem's ``Conv2d(1, ...)``.

        Returns:
            ``(B, embed_dim)`` CLS vector from the final encoder layer.
        """
        tokens, (H, W) = self.embed(self.cnn(spec))
        pos_enc, _ = self.get_pos_embed(H, W, tokens.device, tokens.dtype)
        # Row 0 is the CLS slot; the CLS token itself is prepended inside Encoder.
        encoded = self.enc(tokens + pos_enc[:, 1:, :])
        return encoded[:, 0, :]

class SpecConverter(nn.Module):
    """Turn raw waveforms into rescaled log-mel spectrograms on the module's device.

    Move the module with ``.to(device)`` so the mel filterbank and window
    buffers live on the same device as the input audio.
    """

    def __init__(self, cfg: AUCTestConfig):
        super().__init__()
        self.cfg = cfg
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=cfg.sr,
            n_fft=cfg.n_fft,
            win_length=cfg.win_length,
            hop_length=cfg.hop_length,
            n_mels=cfg.n_mels,
            power=2.0,
        )
        self.db = torchaudio.transforms.AmplitudeToDB("power", top_db=cfg.top_db)

    def forward(self, waveform):
        """Convert ``(B, T)`` or ``(B, 1, T)`` audio into ``(B, 1, n_mels, <=128)``.

        The waveform is RMS-normalised and then z-scored along time, matching
        the training pipeline. The result is truncated to at most 128 frames
        (shorter inputs are not padded). dB values are then linearly rescaled
        so that ``-top_db`` maps to 0 and ``+top_db`` maps to 1. Values outside
        that dB range land outside [0, 1].
        """
        if waveform.dim() == 2:
            waveform = waveform.unsqueeze(1)

        rms = waveform.pow(2).mean(dim=-1, keepdim=True).sqrt()
        waveform = waveform / (rms + 1e-6)
        centred = waveform - waveform.mean(dim=-1, keepdim=True)
        waveform = centred / (waveform.std(dim=-1, keepdim=True) + 1e-6)

        log_mel = self.db(self.mel(waveform))
        log_mel = log_mel[:, :, :, :128]
        top_db = self.cfg.top_db
        return (log_mel + top_db) / (2 * top_db)

class SimpleAE(nn.Module):
    """MLP autoencoder: input_dim -> 256 -> latent_dim -> 256 -> input_dim (ReLU hidden layers).

    Sequential indices (encoder.0/.2, decoder.0/.2) must match the training
    code so saved checkpoints load.
    """

    def __init__(self, input_dim=384, latent_dim=128):
        super().__init__()
        hidden = 256
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden), nn.ReLU(), nn.Linear(hidden, latent_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden), nn.ReLU(), nn.Linear(hidden, input_dim)
        )

    def forward(self, x):
        """Return the reconstruction of ``x`` (same trailing width as the input)."""
        return self.decoder(self.encoder(x))

# ============================================================================
# Dataset (same method as training code)
# ============================================================================

class AUCTestDataset(Dataset):
    """Fixed-length waveform dataset for the AUC test, driven by a CSV index.

    The CSV must provide ``filepath``, ``start_sample``, ``label`` and
    ``class_name`` columns. Rows whose audio file does not exist are skipped
    with a warning. Each item is ``(waveform, label, class_name)``, where
    ``waveform`` is a 1-D float32 tensor of exactly ``cfg.sample_len`` samples.

    Loading mirrors the training code:
    - read ``sample_len`` frames starting at ``start_sample``;
    - resample to ``cfg.sr`` when the file rate differs;
    - down-mix to mono;
    - zero-pad or truncate to ``sample_len``.
    """

    def __init__(self, csv_path: str | os.PathLike, cfg: AUCTestConfig, max_retries: int = 3):
        super().__init__()
        self.cfg = cfg
        self.sample_length = cfg.sample_len
        self.max_retries = max_retries
        self.error_count = defaultdict(int)
        self.index_list = []
        self.class_to_idx = {}
        self.csv_path = csv_path
        # Resample transforms keyed by source sample rate; building the
        # resampling kernel once per rate instead of once per item.
        self._resamplers = {}
        class_names = set()

        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        if cfg.seed is not None:
            random.seed(cfg.seed)

        logging.info(f"Starting to read CSV file: {csv_path}")
        with open(self.csv_path, newline="") as f:
            for row in csv.DictReader(f):
                filepath = row["filepath"]
                if not os.path.exists(filepath):
                    logging.warning(f"File does not exist: {filepath}")
                    continue
                self.index_list.append(
                    (filepath, int(row["start_sample"]),
                     int(row["label"]), row["class_name"])
                )
                class_names.add(row["class_name"])

        # NOTE: indices enumerate the *sorted class names*; they are not
        # guaranteed to coincide with the integer `label` column.
        self.class_to_idx = {name: idx for idx, name in enumerate(sorted(class_names))}

        logging.info(f"Dataset loading complete - Total {len(self.index_list)} samples, {len(class_names)} classes")

    def __len__(self):
        return len(self.index_list)

    def _resample(self, audio: torch.Tensor, orig_sr: int) -> torch.Tensor:
        """Resample ``audio`` from ``orig_sr`` to ``cfg.sr`` using a cached transform."""
        resampler = self._resamplers.get(orig_sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(orig_sr, self.cfg.sr)
            self._resamplers[orig_sr] = resampler
        return resampler(audio)

    def _load_waveform(self, fp: str, start: int) -> torch.Tensor:
        """Load one mono clip of ``cfg.sample_len`` samples; raises on any load failure."""
        if not os.path.exists(fp):
            raise FileNotFoundError(f"File not found: {fp}")

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                audio, sr = torchaudio.load(
                    fp,
                    frame_offset=start,
                    num_frames=self.cfg.sample_len
                )
            except Exception:
                # Fallback: load the whole file and slice the window manually.
                audio, sr = torchaudio.load(fp)
                audio = audio[:, start:start + self.cfg.sample_len]

        if sr != self.cfg.sr:
            audio = self._resample(audio, sr)

        audio = audio.numpy()
        # Average channels for multi-channel audio; otherwise drop the channel axis.
        audio = audio.mean(0) if audio.shape[0] > 1 else audio[0]

        n = self.cfg.sample_len
        audio = np.pad(audio, (0, max(0, n - len(audio))))[:n]
        return torch.tensor(audio, dtype=torch.float32)

    def __getitem__(self, idx):
        """Return ``(waveform, label, class_name)`` for ``idx``.

        Each sample is attempted up to ``max_retries`` times, with a short
        back-off between attempts. If it keeps failing, the next index
        (wrapping around) is tried instead, as in the training code. Once every
        sample has been tried, a ``RuntimeError`` is raised, replacing the
        previous unbounded recursion.

        Raises:
            IndexError: If the dataset is empty or ``idx`` is out of range.
            RuntimeError: If no sample could be loaded.
        """
        n_items = len(self.index_list)
        if n_items == 0:
            raise IndexError("AUCTestDataset is empty")
        n_attempts = max(1, self.max_retries)

        candidate = idx
        for _ in range(n_items):
            # An out-of-range `idx` raises IndexError here, as a list would.
            fp, start, label, folder = self.index_list[candidate]
            for retry in range(n_attempts):
                try:
                    waveform = self._load_waveform(fp, start)
                except Exception as e:
                    self.error_count[fp] += 1
                    logging.error(f"Error loading {fp} (attempt {retry+1}/{n_attempts}): {e}")
                    if retry < n_attempts - 1:
                        time.sleep(0.1 * (retry + 1))
                    continue
                return waveform, label, folder  # Return original binary label (0/1)

            logging.error(f"Max retries exceeded for {fp}, using next sample")
            candidate = (candidate + 1) % n_items

        raise RuntimeError(
            f"Could not load any of the {n_items} samples (starting from index {idx}); "
            "see the logged errors above"
        )

# ============================================================================
# Model Loading Functions
# ============================================================================

def load_pretrained_model(model: EATMAE, pretrained_path: str, device: str):
    """Load a pretrained checkpoint into ``model`` and freeze it for inference.

    The file may contain either a raw state_dict or a dict that holds one under
    ``'model_state_dict'``. Loading is strict, so every key must match.

    Args:
        model: Module to populate in place.
        pretrained_path: Path to a ``torch.save`` checkpoint.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        ``True`` on success, with the model in eval mode and every parameter
        set to ``requires_grad=False``. ``False`` if loading failed; the error
        is printed.

    Raises:
        FileNotFoundError: If ``pretrained_path`` does not exist.
    """
    if not os.path.exists(pretrained_path):
        raise FileNotFoundError(f"Pretrained model not found: {pretrained_path}")

    print(f"Loading pretrained model: {pretrained_path}")

    try:
        # weights_only=False unpickles arbitrary objects: only use it with
        # trusted checkpoint files (needed for PyTorch 2.6+ compatibility).
        checkpoint = torch.load(pretrained_path, map_location=device, weights_only=False)
        state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
        model.load_state_dict(state_dict, strict=True)
        print("Pretrained model loaded successfully")

        model.eval()
        model.requires_grad_(False)
        return True
    except Exception as e:
        print(f"Model loading failed: {e}")
        return False

def load_downstream_ae_model(ae_model: SimpleAE, model_path: str, device: str):
    """Load a trained downstream autoencoder checkpoint and switch it to eval mode.

    The checkpoint must be a dict with a ``'model_state_dict'`` entry.
    Parameters are NOT frozen; only ``eval()`` is called.

    Args:
        ae_model: Autoencoder to populate in place.
        model_path: Path to the ``torch.save`` checkpoint.
        device: ``map_location`` forwarded to ``torch.load``.

    Returns:
        ``{'epoch': ..., 'loss': ...}`` read from the checkpoint (``None`` for
        missing entries), or ``None`` if loading failed (the error is printed).

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Downstream AE model not found: {model_path}")

    print(f"Loading downstream AE model: {model_path}")

    try:
        # weights_only=False unpickles arbitrary objects: trusted files only.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        ae_model.load_state_dict(checkpoint['model_state_dict'])

        metadata = {key: checkpoint.get(key, None) for key in ('epoch', 'loss')}
        print(f"Downstream AE model loaded successfully - Epoch: {metadata['epoch']}, Loss: {metadata['loss']}")

        ae_model.eval()
        return metadata
    except Exception as e:
        print(f"Downstream AE model loading failed: {e}")
        return None

# ============================================================================
# Utility Functions
# ============================================================================

def calculate_reconstruction_error(original, reconstructed, method='mse'):
    """Per-row distance between ``original`` and ``reconstructed``, reduced over dim 1.

    Args:
        original: Reference tensor (presumably ``(B, D)`` features).
        reconstructed: Tensor broadcast-compatible with ``original``.
        method: ``'mse'`` (mean squared error), ``'mae'`` (mean absolute
            error) or ``'cosine'`` (1 - cosine similarity). Any other value
            falls back to ``'mse'``.

    Returns:
        Tensor with dim 1 reduced away, e.g. ``(B,)`` for ``(B, D)`` inputs.
    """
    if method == 'cosine':
        return 1 - F.cosine_similarity(original, reconstructed, dim=1)
    if method == 'mae':
        return (original - reconstructed).abs().mean(dim=1)
    # 'mse' and any unrecognised method.
    return (original - reconstructed).pow(2).mean(dim=1)

def find_optimal_threshold(y_true, scores):
    """Pick the score threshold that maximises binary F1 (positive class = 1).

    Candidate thresholds are the 0th, 1st, ..., 100th percentiles of
    ``scores``. A sample counts as positive when ``score >= threshold``. Among
    tied F1 values the lowest candidate wins. If every candidate gives F1 = 0,
    the minimum score (the 0th percentile) is returned.

    Args:
        y_true: Array-like of 0/1 ground-truth labels.
        scores: Array-like of anomaly scores (higher = more anomalous), the
            same length as ``y_true``. Plain lists are accepted.

    Returns:
        The selected threshold (a NumPy scalar).

    Raises:
        ValueError: If ``scores`` is empty.
    """
    # Accept lists / tuples as well as arrays; `list >= float` is not defined.
    y_true = np.asarray(y_true)
    scores = np.asarray(scores)
    if scores.size == 0:
        raise ValueError("find_optimal_threshold() needs at least one score")

    thresholds = np.percentile(scores, np.linspace(0, 100, 101))
    best_threshold = thresholds[0]
    best_f1 = 0

    for threshold in thresholds:
        y_pred = (scores >= threshold).astype(int)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        # Strict '>' keeps the lowest threshold among equally good candidates.
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold

# ============================================================================
# Test Function
# ============================================================================

def _label_display_names(y_true, class_names):
    """Map each label value in ``y_true`` to the CSV class name(s) observed with it.

    When several class names share one label, they are joined with '/'.
    ``AUCTestDataset.class_to_idx`` is deliberately NOT used here: it indexes
    sorted class names, which need not coincide with the integer labels.
    """
    names_by_label = defaultdict(set)
    for lbl, name in zip(np.asarray(y_true).tolist(), class_names):
        names_by_label[lbl].add(str(name))
    return {lbl: "/".join(sorted(names)) for lbl, names in names_by_label.items()}


def _class_error_stats(y_true, scores, label_names):
    """Summarise ``scores`` per label value: count, mean, std, min and max."""
    stats = {}
    for lbl in np.unique(y_true).tolist():
        errs = scores[y_true == lbl]
        stats[label_names[lbl]] = {
            'count': int(errs.size),
            'mean_error': float(np.mean(errs)),
            'std_error': float(np.std(errs)),
            'min_error': float(np.min(errs)),
            'max_error': float(np.max(errs)),
        }
    return stats


def _save_auc_plots(fold_dir, fold, cm, y_true, scores, label_names, threshold):
    """Save the confusion-matrix and per-label error-histogram PNGs to ``fold_dir``."""
    disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                                  display_labels=['Normal', 'Abnormal'])
    fig, ax = plt.subplots(figsize=(6, 6))
    disp.plot(ax=ax, cmap="Blues", values_format='d')
    ax.set_title(f"Confusion Matrix - Fold {fold}")
    fig.tight_layout()
    fig.savefig(fold_dir / "IMPACT_AUC_confusion_matrix.png")
    plt.close(fig)

    fig, ax = plt.subplots(figsize=(10, 6))
    for lbl in np.unique(y_true).tolist():
        errs = scores[y_true == lbl]
        ax.hist(errs, alpha=0.7, label=f'{label_names[lbl]} (n={len(errs)})',
                bins=50, density=True)
    ax.axvline(threshold, color='red', linestyle='--',
               label=f'Optimal Threshold: {threshold:.6f}')
    ax.set_xlabel('Reconstruction Error')
    ax.set_ylabel('Density')
    ax.set_title(f'Reconstruction Error Distribution - Fold {fold}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(fold_dir / "IMPACT_AUC_error_distribution.png")
    plt.close(fig)


def test_downstream_auc_model(fold: int, downstream: str, csv_path: str, cfg: AUCTestConfig):
    """Evaluate one fold of the reconstruction-error anomaly detector.

    Pipeline:
    - waveform -> ``SpecConverter``;
    - frozen ``EATMAE`` CLS features;
    - ``SimpleAE`` reconstruction;
    - per-sample feature MSE as the anomaly score (higher = more anomalous).

    Computes ROC-AUC / AP and picks the F1-optimal threshold. Writes CSVs and
    PNGs to ``<results_dir>/<downstream>/fold_<fold>``.

    Returns:
        Dict of metrics plus ``class_stats`` and a 2x2 ``confusion_matrix``
        (rows/columns ordered [0=normal, 1=abnormal]). Returns ``None`` if the
        dataset or either model cannot be loaded, or the dataset is empty.
    """
    print(f"\n==================== Fold {fold} AUC Test ====================")

    if cfg.seed is not None:
        torch.manual_seed(cfg.seed)
        np.random.seed(cfg.seed)
        random.seed(cfg.seed)

    # Load dataset
    try:
        dataset = AUCTestDataset(csv_path, cfg)
        dataloader = DataLoader(
            dataset,
            cfg.batch_size,
            shuffle=False,  # No shuffle during test
            num_workers=cfg.num_workers,
            # Compare the device *type* so "cuda:1" etc. also enable pinning.
            pin_memory=(torch.device(cfg.device).type == "cuda"),
            drop_last=False  # Don't drop last batch for complete evaluation
        )
        print(f"Dataset loaded successfully - Classes: {len(dataset.class_to_idx)}, Samples: {len(dataset)}")
    except Exception as e:
        print(f"Dataset loading failed: {e}")
        return None

    if len(dataset) == 0:
        print("Dataset is empty - nothing to evaluate")
        return None

    # Initialize models
    try:
        spec_converter = SpecConverter(cfg).to(cfg.device)

        # Pretrained EATMAE (frozen feature extractor)
        feature_extractor = EATMAE(cfg).to(cfg.device)
        if not load_pretrained_model(feature_extractor, cfg.pretrained_model_path, cfg.device):
            print("Failed to load pretrained model")
            return None

        # Downstream AutoEncoder
        ae_model = SimpleAE(input_dim=cfg.embed_dim, latent_dim=128).to(cfg.device)
        model_path = Path(cfg.downstream_model_dir) / downstream / f"{downstream}_auc_fold_{fold:02d}_{cfg.dataset}.pth"
        metadata = load_downstream_ae_model(ae_model, str(model_path), cfg.device)
        if metadata is None:
            print("Failed to load downstream AE model")
            return None

        print("Model initialization complete")
    except Exception as e:
        print(f"Model initialization failed: {e}")
        return None

    # Run test
    feature_extractor.eval()
    ae_model.eval()

    all_labels = []
    all_reconstruction_errors = []
    all_class_names = []

    print("Running AUC test...")
    with torch.no_grad():
        for batch_idx, (waveform, label, class_names) in enumerate(dataloader):
            spec = spec_converter(waveform.to(cfg.device))
            features = feature_extractor.forward_features(spec)
            reconstructed_features = ae_model(features)
            # Anomaly score = per-sample MSE between CLS features and their reconstruction.
            mse_errors = calculate_reconstruction_error(features, reconstructed_features, 'mse')

            all_labels.extend(label.cpu().numpy())
            all_reconstruction_errors.extend(mse_errors.cpu().numpy())
            all_class_names.extend(class_names)

            if (batch_idx + 1) % 10 == 0:
                print(f"Batch {batch_idx + 1}/{len(dataloader)} complete")

    y_true = np.array(all_labels)
    # Higher reconstruction error => more abnormal, so it is the anomaly score.
    anomaly_scores = np.array(all_reconstruction_errors)

    print(f"Test complete - Processed {len(y_true)} samples")

    # Ranking metrics
    try:
        auc_score = roc_auc_score(y_true, anomaly_scores)
        ap_score = average_precision_score(y_true, anomaly_scores)
    except ValueError as e:
        # e.g. a fold containing a single class. 0.0 is a sentinel, not a real score.
        print(f"Warning: AUC calculation failed: {e}")
        auc_score = 0.0
        ap_score = 0.0

    # Threshold-based predictions and metrics
    optimal_threshold = find_optimal_threshold(y_true, anomaly_scores)
    y_pred = (anomaly_scores >= optimal_threshold).astype(int)
    # Fixed label order (0=normal, 1=abnormal) keeps the matrix 2x2 even for a
    # single-class fold, matching the two display labels used when plotting.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # Class-wise statistics, named from the class names actually seen per label
    label_names = _label_display_names(y_true, all_class_names)
    class_stats = _class_error_stats(y_true, anomaly_scores, label_names)

    # Print results
    print(f"\n==================== Results ====================")
    print(f"AUC Score: {auc_score:.4f}")
    print(f"Average Precision: {ap_score:.4f}")
    print(f"Optimal Threshold: {optimal_threshold:.6f}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    print(f"\nClass-wise Reconstruction Error Statistics:")
    for class_name, stats in class_stats.items():
        print(f"{class_name}: Count={stats['count']}, "
              f"Mean={stats['mean_error']:.6f}±{stats['std_error']:.6f}, "
              f"Range=[{stats['min_error']:.6f}, {stats['max_error']:.6f}]")

    # Save results
    fold_dir = Path(cfg.results_dir) / downstream / f"fold_{fold}"
    fold_dir.mkdir(exist_ok=True, parents=True)

    overall_metrics = {
        "fold": fold,
        "auc_score": auc_score,
        "average_precision": ap_score,
        "optimal_threshold": optimal_threshold,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }
    pd.DataFrame([overall_metrics]).to_csv(
        fold_dir / "IMPACT_AUC_overall_metrics.csv", index=False
    )

    class_stats_df = pd.DataFrame.from_dict(class_stats, orient='index')
    class_stats_df.reset_index(inplace=True)
    class_stats_df.rename(columns={'index': 'class_name'}, inplace=True)
    class_stats_df.to_csv(fold_dir / "IMPACT_AUC_class_statistics.csv", index=False)

    _save_auc_plots(fold_dir, fold, cm, y_true, anomaly_scores, label_names, optimal_threshold)

    print(f"Results saved: {fold_dir}")

    return {
        'auc_score': auc_score,
        'average_precision': ap_score,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'optimal_threshold': optimal_threshold,
        'class_stats': class_stats,
        'confusion_matrix': cm
    }

# ============================================================================
# Main Function
# ============================================================================

def set_deterministic_mode(seed: int = 42):
    """Seed all RNGs and request deterministic cuDNN / cuBLAS / Torch kernels.

    All effects are process-global:
    - the Python, NumPy and Torch (CPU and every CUDA device) RNGs are reseeded;
    - cuDNN autotuning is disabled and deterministic cuDNN is requested;
    - the memory-efficient SDPA backend is switched off;
    - ``torch.use_deterministic_algorithms(True, warn_only=True)`` is enabled
      when the running PyTorch provides it.

    NOTE: setting ``PYTHONHASHSEED`` here does not change the already-running
    interpreter's hash seed; it only reaches processes started afterwards.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        # cuBLAS workspace setting that deterministic cuBLAS relies on
        # (see PyTorch's reproducibility notes).
        'CUBLAS_WORKSPACE_CONFIG': ':4096:8',
    })
    torch.backends.cuda.enable_mem_efficient_sdp(False)

    # Older PyTorch releases lack use_deterministic_algorithms; skip it there.
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except AttributeError:
        pass

def main():
    """Evaluate the downstream AUC model on every fold and aggregate the results.

    Steps:
      1. Build an ``AUCTestConfig`` and check that ``pretrained_model_path``
         and ``downstream_model_dir`` exist; return early otherwise.
      2. Prompt on stdin for the downstream task (1 = ColdSpray, 2 = Yornew).
      3. Seed all RNGs through ``set_deterministic_mode``. The seed is
         ``cfg.seed`` when it is set (0 included); otherwise it is 42.
      4. For folds 1-10, run ``test_downstream_auc_model`` on
         ``<dataset_root>/<downstream>/test_auc_list_<fold>.csv``. A fold is
         reported and skipped when its CSV is missing, when it returns
         ``None``, or when it raises. A skipped fold contributes no partial
         result.
      5. Print mean/std/min/max of the scalar fold metrics and of the
         per-class reconstruction-error statistics. Write them, plus an
         averaged confusion-matrix plot, to
         ``<results_dir>/<downstream>/aggregated``.

    Every problem is reported with ``print``; the function always returns None.
    """
    import traceback  # local import: only used to report failing folds

    cfg = AUCTestConfig()

    # Check required files existence
    if not os.path.exists(cfg.pretrained_model_path):
        print(f"Pretrained model file does not exist: {cfg.pretrained_model_path}")
        return

    if not os.path.exists(cfg.downstream_model_dir):
        print(f"Downstream model directory does not exist: {cfg.downstream_model_dir}")
        return

    # Select downstream task
    downstreams_map = {1: 'ColdSpray', 2: 'Yornew'}
    downstream_input = input("Select downstream task (1.ColdSpray, 2.Yornew): ")
    try:
        downstream_input = int(downstream_input)
    except ValueError:
        print("Please enter a number.")
        return
    if downstream_input not in downstreams_map:
        print("Invalid selection.")
        return
    downstream = downstreams_map[downstream_input]

    print(f"\n=== Starting {downstream} Downstream AUC Task Test ===")

    # 0 is a valid seed, so fall back to 42 only when no seed is configured.
    seed = getattr(cfg, 'seed', None)
    set_deterministic_mode(42 if seed is None else seed)

    error_keys = ('mean_error', 'std_error', 'min_error', 'max_error')
    exclude_keys = {'class_stats', 'confusion_matrix'}

    # Variables for storing results
    all_fold_metrics = []
    classwise_metrics_by_class = defaultdict(lambda: defaultdict(list))
    # Running sum of per-fold confusion matrices; None until a fold contributes.
    confusion_matrix_accum = None
    confusion_matrix_count = 0

    # Run test for all folds
    for fold in range(1, 11):
        print(f"\n--- Testing AUC Fold {fold}... ---")

        csv_path = Path(cfg.dataset_root) / downstream / f"test_auc_list_{fold}.csv"
        if not csv_path.exists():
            print(f"CSV file not found: {csv_path}")
            continue

        try:
            metrics = test_downstream_auc_model(fold, downstream, str(csv_path), cfg)

            if metrics is None:
                print(f"Fold {fold} AUC test failed")
                continue

            # Extract everything first so a malformed result cannot leave
            # partial data in the aggregates.
            fold_metrics = {
                k: v for k, v in metrics.items()
                if k not in exclude_keys and np.isscalar(v)
            }
            fold_metrics['fold'] = fold
            # astype() always copies, so the caller's array is never mutated.
            fold_cm = np.asarray(metrics['confusion_matrix']).astype(np.float32)
            fold_class_stats = {
                class_name: [(key, stats[key]) for key in error_keys]
                for class_name, stats in metrics['class_stats'].items()
            }
        except Exception as e:
            print(f"Fold {fold} AUC test failed: {e}")
            traceback.print_exc()
            continue

        # Commit this fold's results.
        all_fold_metrics.append(fold_metrics)

        for class_name, pairs in fold_class_stats.items():
            for key, value in pairs:
                classwise_metrics_by_class[class_name][key].append(value)

        if confusion_matrix_accum is None:
            confusion_matrix_accum = fold_cm
            confusion_matrix_count = 1
        elif fold_cm.shape == confusion_matrix_accum.shape:
            confusion_matrix_accum += fold_cm
            confusion_matrix_count += 1
        else:
            # Matrices of different shapes cannot be averaged together.
            print(f"Fold {fold}: confusion matrix shape {fold_cm.shape} differs from "
                  f"{confusion_matrix_accum.shape}; excluded from the average")

        print(f"Fold {fold} AUC test complete")

    # Analyze and save overall results
    if not all_fold_metrics:
        print("All fold AUC tests failed.")
        return

    df = pd.DataFrame(all_fold_metrics)

    print("\n==================== Final Aggregated AUC Metrics ====================")
    for col in ['auc_score', 'average_precision', 'accuracy', 'precision', 'recall', 'f1_score']:
        if col in df.columns:
            mean = df[col].mean()
            std = df[col].std()  # sample std (ddof=1); NaN with a single fold
            min_val = df[col].min()
            max_val = df[col].max()
            print(f"{col:18s} | Mean: {mean:.4f} | Std: {std:.4f} | Min: {min_val:.4f} | Max: {max_val:.4f}")

    print("\n==================== Per-class Aggregated Error Statistics ====================")
    for cname, stats_by_key in classwise_metrics_by_class.items():
        print(f"\nClass: {cname}")
        for key, values in stats_by_key.items():
            mean = np.mean(values)
            std = np.std(values)  # population std (ddof=0)
            min_val = np.min(values)
            max_val = np.max(values)
            print(f"  {key:>12s} | Mean: {mean:.6f} | Std: {std:.6f} | Min: {min_val:.6f} | Max: {max_val:.6f}")

    # Save aggregated results
    agg_dir = Path(cfg.results_dir) / downstream / "aggregated"
    agg_dir.mkdir(exist_ok=True, parents=True)

    # Save fold-wise overall metrics
    df.to_csv(agg_dir / "IMPACT_AUC_all_folds_results.csv", index=False)

    # Save summary statistics
    summary_stats = df.describe()
    summary_stats.to_csv(agg_dir / "IMPACT_AUC_summary_statistics.csv")

    # Save per-class aggregated metrics
    per_class_rows = []
    for cname, stats_by_key in classwise_metrics_by_class.items():
        row = {"class": cname}
        for k, vals in stats_by_key.items():
            row[f"{k}_mean"] = float(np.mean(vals))
            row[f"{k}_std"] = float(np.std(vals))
            row[f"{k}_min"] = float(np.min(vals))
            row[f"{k}_max"] = float(np.max(vals))
        per_class_rows.append(row)

    pd.DataFrame(per_class_rows).to_csv(
        agg_dir / "IMPACT_AUC_per_class_error_stats.csv", index=False
    )

    # Save average confusion matrix
    if confusion_matrix_accum is not None:
        # Round to the nearest integer count (instead of truncating) so the
        # 'd' display format shows the closest value to the true mean.
        cm_avg = np.rint(confusion_matrix_accum / confusion_matrix_count).astype(int)
        fig, ax = plt.subplots(figsize=(6, 6))
        # NOTE(review): the labels assume a 2x2 (Normal/Abnormal) matrix.
        disp = ConfusionMatrixDisplay(confusion_matrix=cm_avg,
                                      display_labels=['Normal', 'Abnormal'])
        disp.plot(ax=ax, cmap="Blues", xticks_rotation=45, colorbar=False, values_format='d')
        plt.title(f"Average Confusion Matrix - {downstream}")
        plt.tight_layout()
        plt.savefig(agg_dir / "IMPACT_AUC_confusion_matrix_avg.png")
        plt.close()

    print(f"\n=== All {downstream} AUC Tests Complete ===")
    print(f"Results saved to: {cfg.results_dir}/{downstream}")

if __name__ == '__main__':
    # Script entry point: only runs when the file is executed directly,
    # never on import.
    main()