import torch
import torch.nn as nn
import torch.nn.functional as F
import time
# import torchvision  # No longer needed for datasets
import torchvision.transforms as transforms
import torchvision.transforms.v2 as T  # GPU transforms
import numpy as np
import os
import math # Import math for cosine annealing
from torch.utils.data import Dataset, DataLoader
import torch.distributions as dist  # Import distributions for Beta
# import h5py # No longer needed
from tqdm import tqdm # Import tqdm for loading progress
import multiprocessing # Import multiprocessing
import functools # Import functools for partial
import yaml  # Import YAML library
import argparse # Import argparse for command-line arguments
from typing import Dict # Import Dict for type hinting
import shutil # Import shutil for file copying

def cpu_clone_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Return a detached CPU copy of every entry in the model's state dict.

    Each tensor is detached, moved to the CPU and cloned, so the returned
    dict does not share storage with the live model.

    Args:
        model: Module whose ``state_dict()`` is snapshotted.

    Returns:
        A plain dict mapping each state-dict key to its copied tensor.
    """
    snapshot: Dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().cpu().clone()
    return snapshot

# --- Argument Parsing ---
def parse_args():
    """Parse the command-line options of the training script.

    Both options are required strings:
      --config-path  path to the YAML configuration file
      --save-name    directory that receives the history and summary

    Returns:
        The argparse namespace with ``config_path`` and ``save_name``.
    """
    cli = argparse.ArgumentParser(
        description='Train a SimpleCNN on downsampled ImageNet-1k',
    )
    required_options = (
        ('--config-path', 'Path to the YAML configuration file.'),
        ('--save-name', 'Directory path to save the training history and summary.'),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()

# Parse command-line arguments early, at module level: the configuration
# section further down reads args.config_path and args.save_name.
args = parse_args()

# --- Worker function for parallel loading ---
def load_image_worker(index: int, image_dir: str) -> np.ndarray:
    """Read ``<image_dir>/<index>.npy`` and return the stored array.

    Nothing is caught here on purpose: a failing load is left to propagate
    so the main process sees it when consuming the pool's results.

    Args:
        index: Integer stem of the file name.
        image_dir: Directory containing the ``.npy`` files.

    Returns:
        The array returned by ``np.load``.
    """
    filename = f"{index}.npy"
    return np.load(os.path.join(image_dir, filename))

# --- Helper function for CutMix ---
def rand_bbox(size, lam):
    """Sample a random CutMix patch, clipped to the image extent.

    Each side of the patch is sqrt(1 - lam) times the matching extent
    (truncated to an int), so the unclipped patch covers roughly a
    (1 - lam) fraction of the area. The centre is drawn uniformly and the
    box is then clipped, which can shrink it near the borders.

    Args:
        size: 4-element shape; only ``size[2]`` (the "x" extent) and
            ``size[3]`` (the "y" extent) are read.
        lam: Mixing coefficient; 1 - lam is the target patch area fraction.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: the x bounds index axis 2 and the
        y bounds index axis 3.
    """
    extent_x, extent_y = size[2], size[3]

    # Side-length ratio that yields an area ratio of (1 - lam)
    side_ratio = np.sqrt(1. - lam)
    half_x = int(extent_x * side_ratio) // 2
    half_y = int(extent_y * side_ratio) // 2

    # Centre of the patch; x is drawn before y to keep the RNG stream order
    centre_x = np.random.randint(extent_x)
    centre_y = np.random.randint(extent_y)

    # Corners, clamped to the image boundaries
    x_low = np.clip(centre_x - half_x, 0, extent_x)
    y_low = np.clip(centre_y - half_y, 0, extent_y)
    x_high = np.clip(centre_x + half_x, 0, extent_x)
    y_high = np.clip(centre_y + half_y, 0, extent_y)
    return x_low, y_low, x_high, y_high

class SimpleCNN(nn.Module):
    """Four strided conv stages followed by a two-layer classifier head.

    Every stage is Conv2d(kernel 3, stride 3, no padding) -> BatchNorm2d ->
    ReLU, with output channels width, 3*width, 9*width and 27*width. The
    first Linear layer takes 27*width features after flattening, so the
    last stage has to leave a 1x1 map; the shape comments in ``forward``
    trace an 81x81 input (81 -> 27 -> 9 -> 3 -> 1).

    Args:
        width: Number of output channels of the first convolution.
        num_classes: Number of logits produced per sample.
    """

    def __init__(self, width: int, num_classes: int):
        super().__init__()
        self.width = width
        self.num_classes = num_classes

        # Channel count after each of the four stages
        c1, c2, c3, c4 = width, 3 * width, 9 * width, 27 * width

        # Convolutional stages (registered in conv/bn order, stage by stage)
        self.conv1 = nn.Conv2d(3, c1, kernel_size=3, stride=3)
        self.bn1 = nn.BatchNorm2d(c1)
        self.conv2 = nn.Conv2d(c1, c2, kernel_size=3, stride=3)
        self.bn2 = nn.BatchNorm2d(c2)
        self.conv3 = nn.Conv2d(c2, c3, kernel_size=3, stride=3)
        self.bn3 = nn.BatchNorm2d(c3)
        self.conv4 = nn.Conv2d(c3, c4, kernel_size=3, stride=3)
        self.bn4 = nn.BatchNorm2d(c4)

        # Classifier head
        self.fc1 = nn.Linear(c4, c4)
        self.fc2 = nn.Linear(c4, num_classes)

    def forward(self, x):
        """Map a batch of images to class logits.

        For a (batch_size, 3, 81, 81) input the stages produce
        (width, 27, 27) -> (3*width, 9, 9) -> (9*width, 3, 3) ->
        (27*width, 1, 1) feature maps per sample.
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in stages:
            x = F.relu(bn(conv(x)))

        # (batch_size, 27*width, 1, 1) -> (batch_size, 27*width)
        features = torch.flatten(x, 1)

        hidden = F.relu(self.fc1(features))  # (batch_size, 27*width)
        return self.fc2(hidden)              # (batch_size, num_classes)


# --- Global Constants ---
# Both paths come from the command-line arguments parsed above.
CONFIG_PATH = args.config_path
SAVE_NAME = args.save_name

# Load configuration from YAML file
print(f"Loading configuration from: {CONFIG_PATH}")
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config_values = yaml.safe_load(f)

# Assign constants from the loaded configuration.
# Every key is required: the script fails with KeyError if one is missing.
#
# Float-valued hyper-parameters are converted explicitly. PyYAML resolves
# exponent literals written without a decimal point (e.g. 1e-3) as strings,
# which would otherwise only surface later as a TypeError in a comparison or
# inside torch.distributions.Beta.
NUM_EPOCHS_COSINE = config_values['NUM_EPOCHS_COSINE']
LEARNING_RATE = float(config_values['LEARNING_RATE'])
BATCH_SIZE = config_values['BATCH_SIZE']
MIXUP_ALPHA = float(config_values['MIXUP_ALPHA'])
MIXUP_PROB = float(config_values['MIXUP_PROB'])
CUTMIX_BETA = float(config_values['CUTMIX_BETA'])
CUTMIX_PROB = float(config_values['CUTMIX_PROB'])
PREPROCESSED_PATH = config_values['PREPROCESSED_PATH']
WIDTH = config_values['WIDTH']
NUM_CLASSES = config_values['NUM_CLASSES']
NUM_WORKERS = config_values['NUM_WORKERS']
CROP_SIZE = config_values['CROP_SIZE']
N_WARMUP = config_values['N_WARMUP']

print("Configuration loaded successfully.")

# --- Device Setup ---
# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# -----------------------------
# CPU- and GPU-side augmentation pipelines
# -----------------------------
# Per-channel normalisation statistics shared by both GPU pipelines
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _cpu_pipeline(crop):
    """Build the per-sample CPU pipeline around the given crop transform.

    (H, W, C) uint8 array -> PIL image -> crop -> (C, H, W) float32 tensor
    in [0, 1]. Cropping on the CPU gives every sample the same spatial size
    before collation.
    """
    return transforms.Compose([
        transforms.ToPILImage(),
        crop,
        transforms.ToTensor(),
    ])


# Random crop for training, deterministic centre crop for validation
cpu_train_tf = _cpu_pipeline(transforms.RandomCrop(CROP_SIZE))
cpu_val_tf = _cpu_pipeline(transforms.CenterCrop(CROP_SIZE))

# Training-time augmentation followed by normalisation.
# NOTE(review): the training loop calls this on the whole batch tensor;
# confirm whether the random parameters are drawn per batch or per sample.
_train_gpu_stages = [
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    T.RandAugment(),
    T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0),
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_gpu_tf = nn.Sequential(*_train_gpu_stages).to(device)

# Validation only normalises
val_gpu_tf = nn.Sequential(
    T.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
).to(device)

# --- Data Preparation (Loading individual NPY files into memory) ---
# Tasks handed to a pool worker at a time; the default of 1 is slow for very
# long iterables because every image becomes its own round trip.
_LOAD_CHUNKSIZE = 64


def _load_images_parallel(image_dir, count, split, desc, workers):
    """Load ``0.npy`` .. ``<count-1>.npy`` from image_dir with a process pool.

    Args:
        image_dir: Directory holding the per-image ``.npy`` files.
        count: Number of images to load (indices 0 .. count-1).
        split: Split name used in the error message ("train" or "val").
        desc: Label shown on the tqdm progress bar.
        workers: Number of pool processes (must be >= 1).

    Returns:
        List of arrays in index order (imap preserves input order).

    Raises:
        ValueError: If the number of loaded images differs from count.
        Any exception raised by a worker propagates to the caller.
    """
    load_func = functools.partial(load_image_worker, image_dir=image_dir)
    with multiprocessing.Pool(processes=workers) as pool:
        iterator = pool.imap(load_func, range(count), chunksize=_LOAD_CHUNKSIZE)
        images = list(tqdm(iterator, total=count, desc=desc))

    # Defensive check: the list above should always hold exactly `count` items.
    if len(images) != count:
        raise ValueError(
            f"Mismatch after loading: Number of {split} labels ({count}) "
            f"does not match number of loaded {split} images ({len(images)}). "
            f"Check data integrity in {image_dir}"
        )
    return images


print(f"Loading preprocessed data from: {PREPROCESSED_PATH}")
start_load_time = time.time()

# Define paths for label files and image directories
train_labels_path = os.path.join(PREPROCESSED_PATH, "train_labels.npy")
val_labels_path = os.path.join(PREPROCESSED_PATH, "val_labels.npy")
train_images_dir = os.path.join(PREPROCESSED_PATH, "train_images")
val_images_dir = os.path.join(PREPROCESSED_PATH, "val_images")

# Load labels first: their length is the number of image files to read
print(f"Loading labels from {train_labels_path} and {val_labels_path}...")
train_labels_np = np.load(train_labels_path)
val_labels_np = np.load(val_labels_path)
num_train_images = len(train_labels_np)
num_val_images = len(val_labels_np)
print(f"Found {num_train_images} train labels and {num_val_images} val labels.")

# NUM_WORKERS may be 0 (valid for the DataLoaders), but a process pool needs
# at least one worker, so clamp it here.
loading_workers = max(1, NUM_WORKERS)
print(f"Using {loading_workers} workers for parallel image loading...")

print(f"Loading {num_train_images} training images from {train_images_dir}...")
train_images_list_np = _load_images_parallel(
    train_images_dir, num_train_images, "train", "Loading train images",
    loading_workers,
)

print(f"Loading {num_val_images} validation images from {val_images_dir}...")
val_images_list_np = _load_images_parallel(
    val_images_dir, num_val_images, "val", "Loading val images",
    loading_workers,
)

load_time = time.time() - start_load_time
print(f"Data loaded in {load_time:.2f}s")

# Print info about the loaded lists
print(f"Loaded {len(train_images_list_np)} training images (list of np arrays)")
if train_images_list_np:
    print(f"  Example shape of first train image: {train_images_list_np[0].shape}, "
          f"dtype: {train_images_list_np[0].dtype}")
print(
    f"Train Labels Shape: {train_labels_np.shape}, "
    f"dtype: {train_labels_np.dtype}"
)
print(f"Loaded {len(val_images_list_np)} validation images (list of np arrays)")
if val_images_list_np:
    print(f"  Example shape of first val image: {val_images_list_np[0].shape}, "
          f"dtype: {val_images_list_np[0].dtype}")
print(
    f"Val Labels Shape: {val_labels_np.shape}, "
    f"dtype: {val_labels_np.dtype}"
)


# Labels become int64 tensors; the images stay a list of NumPy arrays.
print("Converting labels NumPy arrays to Tensors...")
train_labels = torch.from_numpy(train_labels_np).to(torch.long)
val_labels = torch.from_numpy(val_labels_np).to(torch.long)

# Report the resulting tensors for both splits
print(f"Train Labels Tensor Shape: {train_labels.shape}, " f"dtype: {train_labels.dtype}")
print(f"Val Labels Tensor Shape: {val_labels.shape}, " f"dtype: {val_labels.dtype}")


# Custom Dataset backed by an in-memory list of NumPy image arrays
# and a tensor of labels.
class InMemoryListDataset(Dataset):
    """Map-style dataset over an in-memory image list and matching labels.

    Args:
        images_list: Sequence of images; per the surrounding script these
            are (H, W, C) uint8 NumPy arrays.
        labels: Sequence of labels aligned with ``images_list`` (a tensor in
            the surrounding script).
        transform: Optional callable applied to each image. When it is
            missing (or falsy), ``transforms.ToTensor()`` is applied instead.

    Raises:
        ValueError: If the number of images and labels differ.
    """

    def __init__(self, images_list, labels, transform=None):
        self.images_list = images_list
        self.labels = labels
        self.transform = transform

        # Images and labels must line up one-to-one
        counts_match = len(self.images_list) == len(self.labels)
        if not counts_match:
            raise ValueError(
                f"Number of images ({len(self.images_list)}) does not match "
                f"number of labels ({len(self.labels)}). "
                "Check data loading and potential skipped files."
            )

    def __len__(self):
        """Return the number of samples (taken from the labels)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at position idx."""
        raw_image = self.images_list[idx]
        target = self.labels[idx]

        # The configured transform is expected to handle the array -> tensor
        # conversion itself; the fallback performs only that conversion.
        convert = self.transform if self.transform else transforms.ToTensor()
        return convert(raw_image), target


# --- Transforms for In-Memory NumPy Arrays (defined above) ---
# --- CPU→GPU Data Augmentation Split ---
# The datasets apply only the CPU transforms (crop + conversion to tensor);
# the heavier augmentation and the normalisation run on the GPU pipelines.


# --- Create Datasets and DataLoaders ---
print("Creating InMemoryListDataset instances and DataLoaders...")

# Each dataset pairs its image list with its labels and CPU transform
train_dataset = InMemoryListDataset(
    images_list=train_images_list_np,
    labels=train_labels,
    transform=cpu_train_tf,
)
val_dataset = InMemoryListDataset(
    images_list=val_images_list_np,
    labels=val_labels,
    transform=cpu_val_tf,
)

# Settings common to both loaders; pinned memory helps the GPU transfer
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Only the training loader shuffles
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
print("InMemoryListDataset and DataLoaders created successfully.")


# --- Model, Loss, Optimizer ---
print("Initializing model...")
model = SimpleCNN(width=WIDTH, num_classes=NUM_CLASSES)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Best validation accuracy so far, and the (1-based) epoch it occurred ---
best_val_acc1, best_epoch_acc1 = 0.0, 0
best_val_acc5, best_epoch_acc5 = 0.0, 0
# CPU copy of the weights at the best top-5 epoch; None until one is recorded
best_model_state_dict_top5 = None

# --- History Tracking ---
# One independent list per metric; each receives one value per epoch.
history = {
    metric: []
    for metric in (
        'train_loss',
        'train_acc1',
        'train_acc5',
        'val_acc1',
        'val_acc5',
        'lr',
    )
}

# --- Training Loop ---
def _soft_cross_entropy(logits, soft_targets):
    """Return the batch-mean cross-entropy of logits against soft targets.

    Both arguments are indexed as (batch, classes): the class axis (dim 1)
    is summed, then the result is averaged over the batch.
    """
    log_probs = F.log_softmax(logits, dim=1)
    return -torch.sum(soft_targets * log_probs, dim=1).mean()


def _count_correct(logits, targets, k=5):
    """Count top-1 and top-k hits in a batch without tracking gradients.

    k is capped at the number of logits so torch.topk cannot fail when the
    model has fewer than k outputs.

    Returns:
        ``(top1_hits, topk_hits)`` as Python ints.
    """
    with torch.no_grad():
        _, top1 = torch.max(logits, 1)
        _, topk = torch.topk(logits, min(k, logits.size(1)), dim=1)
        top1_hits = (top1 == targets).sum().item()
        topk_hits = torch.sum(
            topk == targets.view(-1, 1).expand_as(topk)
        ).item()
    return top1_hits, topk_hits


print("Starting training...")
total_epochs = N_WARMUP + NUM_EPOCHS_COSINE  # Warmup epochs, then cosine epochs

# Beta requires strictly positive concentrations, so a value of 0 (which the
# per-step checks below treat as "augmentation disabled") must not reach the
# constructor. The distributions do not change between epochs, so they are
# built once here.
mixup_beta_distribution = (
    dist.Beta(MIXUP_ALPHA, MIXUP_ALPHA) if MIXUP_ALPHA > 0 else None
)
cutmix_beta_distribution = (
    dist.Beta(CUTMIX_BETA, CUTMIX_BETA) if CUTMIX_BETA > 0 else None
)

for epoch in range(total_epochs):
    start_time = time.time()  # Record start time

    # --- Learning Rate Schedule ---
    if epoch < N_WARMUP:
        # Linear warmup: reaches LEARNING_RATE on the last warmup epoch
        current_lr = LEARNING_RATE * (epoch + 1) / N_WARMUP
    else:
        # Cosine annealing over epochs 0 .. NUM_EPOCHS_COSINE-1 of this phase
        cosine_epoch = epoch - N_WARMUP
        current_lr = LEARNING_RATE * 0.5 * (1 + math.cos(math.pi * cosine_epoch / NUM_EPOCHS_COSINE))

    # Update optimizer's learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    model.train()  # Set model to training mode
    running_loss = 0.0
    correct_train = 0
    correct_train_top5 = 0
    total_train = 0               # Samples contributing to the loss
    total_train_acc_samples = 0   # Samples contributing to accuracy (clean batches only)

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device, non_blocking=True)
        images = train_gpu_tf(images)
        labels = labels.to(device, non_blocking=True)
        batch_actual_size = images.size(0)  # Last batch may be smaller

        # --- Augmentation Selection ---
        # A single draw picks Mixup, CutMix or a clean step: Mixup owns
        # [0, MIXUP_PROB), CutMix the CUTMIX_PROB-wide interval after it.
        rand_prob = torch.rand(1).item()
        apply_mixup = rand_prob < MIXUP_PROB and MIXUP_ALPHA > 0
        apply_cutmix = MIXUP_PROB <= rand_prob < MIXUP_PROB + CUTMIX_PROB and CUTMIX_BETA > 0

        # One-hot targets; every branch expresses its loss against soft labels
        labels_one_hot = F.one_hot(labels, num_classes=NUM_CLASSES).float()

        if apply_mixup:
            # --- Mixup: one mixing weight per sample ---
            gamma = mixup_beta_distribution.sample((batch_actual_size,)).to(device)
            indices = torch.randperm(batch_actual_size, device=device)

            # Broadcast gamma over (N, C, H, W) images and (N, classes) labels
            gamma_img = gamma.view(batch_actual_size, 1, 1, 1)
            gamma_lbl = gamma.view(batch_actual_size, 1)
            mixed_images = gamma_img * images + (1.0 - gamma_img) * images[indices]
            mixed_labels = (
                gamma_lbl * labels_one_hot
                + (1.0 - gamma_lbl) * labels_one_hot[indices]
            )

            outputs = model(mixed_images)
            loss = _soft_cross_entropy(outputs, mixed_labels)
            # No accuracy is accumulated on Mixup batches.

        elif apply_cutmix:
            # --- CutMix: one lambda and one patch for the whole batch ---
            lam = cutmix_beta_distribution.sample().item()
            rand_index = torch.randperm(batch_actual_size, device=device)

            # Paste the same patch region from the shuffled batch
            bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), lam)
            images_cutmix = images.clone()  # Keep the originals as the paste source
            images_cutmix[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]

            # Recompute lambda from the clipped patch so the label weights
            # match the actual pixel ratio; kept as a plain Python float so
            # the label mixing below is ordinary tensor-scalar arithmetic.
            patch_area = int((bbx2 - bbx1) * (bby2 - bby1))
            lam = 1.0 - patch_area / (images.size(-1) * images.size(-2))
            mixed_labels = lam * labels_one_hot + (1.0 - lam) * labels_one_hot[rand_index]

            outputs = model(images_cutmix)
            loss = _soft_cross_entropy(outputs, mixed_labels)
            # No accuracy is accumulated on CutMix batches.

        else:
            # --- Standard step (no Mixup or CutMix) ---
            outputs = model(images)
            # Same soft-label formulation as above, for consistency
            loss = _soft_cross_entropy(outputs, labels_one_hot)

            # Training accuracy is measured on these clean batches only
            top1_hits, top5_hits = _count_correct(outputs, labels)
            correct_train += top1_hits
            correct_train_top5 += top5_hits
            total_train_acc_samples += batch_actual_size

        # Backward and optimize (common step)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight by the actual batch size since drop_last is not set
        running_loss += loss.item() * batch_actual_size
        total_train += batch_actual_size

        if (i + 1) % 100 == 0:  # Print stats every 100 batches
            print(f'Epoch [{epoch+1}/{total_epochs}], '
                  f'Step [{i+1}/{len(train_loader)}], '
                  f'Loss: {loss.item():.4f}, '
                  f'LR: {current_lr:.6f}')

    # Calculate epoch metrics (ensure denominators are not zero)
    epoch_loss_train = running_loss / total_train if total_train > 0 else 0.0

    if total_train_acc_samples > 0:
        epoch_acc_train = 100 * correct_train / total_train_acc_samples
        epoch_acc_train_top5 = (
            100 * correct_train_top5 / total_train_acc_samples
        )
    else:
        # No clean batch this epoch: accuracy is undefined, reported as 0
        epoch_acc_train = 0.0
        epoch_acc_train_top5 = 0.0

    # --- Validation Loop ---
    model.eval()  # Set model to evaluation mode
    running_loss_val = 0.0
    correct_val = 0
    correct_val_top5 = 0
    total_val = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            images = val_gpu_tf(images)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss_val += loss.item() * images.size(0)
            total_val += labels.size(0)

            top1_hits, top5_hits = _count_correct(outputs, labels)
            correct_val += top1_hits
            correct_val_top5 += top5_hits

    # Guard the denominators the same way as for the training metrics
    if total_val > 0:
        epoch_loss_val = running_loss_val / total_val
        epoch_acc_val = 100 * correct_val / total_val
        epoch_acc_val_top5 = 100 * correct_val_top5 / total_val
    else:
        epoch_loss_val = 0.0
        epoch_acc_val = 0.0
        epoch_acc_val_top5 = 0.0
    end_time = time.time()  # Record end time
    epoch_duration = end_time - start_time  # Covers training and validation

    # --- Check and update best validation accuracies ---
    if epoch_acc_val > best_val_acc1:
        best_val_acc1 = epoch_acc_val
        best_epoch_acc1 = epoch + 1  # epoch is 0-indexed

    if epoch_acc_val_top5 > best_val_acc5:
        best_val_acc5 = epoch_acc_val_top5
        best_epoch_acc5 = epoch + 1  # epoch is 0-indexed
        # Keep a CPU copy of the weights at the best top-5 epoch
        print(f"  New best top-5 accuracy: {best_val_acc5:.2f}%. Saving model state...")
        best_model_state_dict_top5 = cpu_clone_state_dict(model)

    print(f'Epoch [{epoch+1}/{total_epochs}] -- '
          f'Time: {epoch_duration:.2f}s -- '
          f'Train Loss: {epoch_loss_train:.4f}, '
          f'Train Acc@1: {epoch_acc_train:.2f}%, '
          f'Train Acc@5: {epoch_acc_train_top5:.2f}% -- '
          f'Val Loss: {epoch_loss_val:.4f}, '
          f'Val Acc@1: {epoch_acc_val:.2f}%, '
          f'Val Acc@5: {epoch_acc_val_top5:.2f}% -- '
          f'LR: {current_lr:.6f}')

    # --- Record History ---
    history['train_loss'].append(epoch_loss_train)
    history['train_acc1'].append(epoch_acc_train)
    history['train_acc5'].append(epoch_acc_train_top5)
    history['val_acc1'].append(epoch_acc_val)
    history['val_acc5'].append(epoch_acc_val_top5)
    history['lr'].append(current_lr)

print('Finished Training')

# --- Final Summary and Saving ---
print(f"\nPreparing summary and saving results to directory: {SAVE_NAME}...")

# Make sure the output directory exists before anything is written into it
os.makedirs(SAVE_NAME, exist_ok=True)

# Output locations, all inside the save directory
logs_path = os.path.join(SAVE_NAME, 'logs.npz')
summary_path = os.path.join(SAVE_NAME, 'summary.txt')
checkpoint_path = os.path.join(SAVE_NAME, 'model_last_epoch.pth')
config_dest_path = os.path.join(SAVE_NAME, 'configuration_used.yaml')

# Summary text: metrics of the last epoch, then the best validation results
summary_content = "".join([
    "--- Training Summary ---\n",
    f"Last Epoch ({total_epochs}):\n",
    f"  Train Accuracy@1: {epoch_acc_train:.2f}%\n",
    f"  Train Accuracy@5: {epoch_acc_train_top5:.2f}%\n",
    f"  Validation Accuracy@1: {epoch_acc_val:.2f}%\n",
    f"  Validation Accuracy@5: {epoch_acc_val_top5:.2f}%\n",
    "\nBest Validation Performance:\n",
    f"  Best Validation Accuracy@1: {best_val_acc1:.2f}% (Epoch {best_epoch_acc1})\n",
    f"  Best Validation Accuracy@5: {best_val_acc5:.2f}% (Epoch {best_epoch_acc5})\n",
])

# Per-epoch history: one array per key of the history dict
print(f"Saving training history to {logs_path}...")
np.savez(logs_path, **history)

# Human-readable summary
print(f"Saving summary to {summary_path}...")
with open(summary_path, 'w') as f:
    f.write(summary_content)

# Weights as they are after the final epoch
print(f"Saving model checkpoint to {checkpoint_path}...")
torch.save(model.state_dict(), checkpoint_path)

# Weights from the best top-5 validation epoch, if one was recorded
if not best_model_state_dict_top5:
    print("No best model state (top-5) was recorded during training.")
else:
    best_checkpoint_path = os.path.join(SAVE_NAME, 'model_best_epoch_top_5.pth')
    print(f"Saving best model checkpoint (top-5 validation acc) to {best_checkpoint_path}...")
    torch.save(best_model_state_dict_top5, best_checkpoint_path)

# Keep a copy of the configuration next to the results (copy2 preserves metadata)
print(f"Copying configuration file to {config_dest_path}...")
shutil.copy2(CONFIG_PATH, config_dest_path)

print("History, summary, model checkpoints, and configuration saved successfully.")