import torch
import torch.nn as nn
import torch.nn.functional as F
import time
# import torchvision  # No longer needed for datasets
import torchvision.transforms as transforms
import torchvision.transforms.v2 as T  # GPU transforms
import numpy as np
import os
import math  # Added for cosine LR schedule in continuation mode
# import math # Import math for schedule calculation - No longer needed? (Check get_params)
from torch.utils.data import Dataset, DataLoader
import torch.distributions as dist  # Import distributions for Beta
# import h5py # No longer needed
from tqdm import tqdm # Import tqdm for loading progress
import multiprocessing # Import multiprocessing
import functools # Import functools for partial
import yaml  # Import YAML library
import argparse # Import argparse for command-line arguments
from typing import Dict # Import Dict for type hinting
import shutil # Import shutil for file copying

from ldKAN.dkan_2d import DKAN_2D_Layer # Import DKAN layer

def cpu_clone_state_dict(model: nn.Module) -> Dict[str, torch.Tensor]:
    """Return a CPU copy of ``model.state_dict()``.

    Every tensor is passed through ``.detach().cpu().clone()``, so the returned
    tensors are detached from autograd, live on the CPU, and are clones rather
    than the entries of the live state dict.

    Args:
        model: Module whose state dict is copied.

    Returns:
        Dict mapping each state-dict key to its cloned CPU tensor, in the same
        key order as ``model.state_dict()``.
    """
    snapshot: Dict[str, torch.Tensor] = {}
    for name, tensor in model.state_dict().items():
        snapshot[name] = tensor.detach().cpu().clone()
    return snapshot

# --- Argument Parsing ---
def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Options:
        --config-path        (required) path to the YAML configuration file.
        --save-name          (required) directory for this run's history/summary.
        --calc-type          (required) either 'initial' or 'continuation'.
        --start-calc-folder  (optional, default None) previous run directory.

    Returns:
        The ``argparse.Namespace`` produced by the parser; dashes in option
        names become underscores in attribute names (e.g. ``config_path``).
    """
    # Declared as data so every option is registered through the same call.
    option_specs = (
        ('--config-path', dict(
            type=str, required=True,
            help='Path to the YAML configuration file.')),
        ('--save-name', dict(
            type=str, required=True,
            help='Directory path to save the training history and summary for this run.')),
        ('--calc-type', dict(
            type=str, required=True, choices=['initial', 'continuation'],
            help="Type of calculation: 'initial' starts from scratch, 'continuation' loads a previous run.")),
        ('--start-calc-folder', dict(
            type=str, default=None,
            help="Directory path of the previous run to load for 'continuation' mode.")),
    )

    cli = argparse.ArgumentParser(
        description='Train a SimpleCNN with DKAN layers on downsampled ImageNet-1k'
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

# Parse command-line arguments early: this runs at module level, and the
# configuration setup further down reads `args` (config path, save name,
# calculation type, start folder).
args = parse_args()

# --- Worker function for parallel loading ---
def load_image_worker(index: int, image_dir: str) -> np.ndarray:
    """Read the array stored in ``<image_dir>/<index>.npy`` and return it.

    Errors raised by ``np.load`` (e.g. a missing file) are intentionally not
    handled here; they propagate to the caller.
    """
    filename = "{}.npy".format(index)
    return np.load(os.path.join(image_dir, filename))

# --- Helper function for CutMix ---
def rand_bbox(size, lam):
    """Sample a random rectangular patch for CutMix.

    Args:
        size: Indexable shape of the batch; only ``size[2]`` and ``size[3]``
            are read.
        lam: Mixing ratio. The patch extent along each of the two axes is the
            axis length multiplied by ``sqrt(1 - lam)``, truncated to an int.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)``: the ``bbx*`` pair bounds the patch along
        ``size[2]`` and the ``bby*`` pair along ``size[3]``. Each value is
        clipped to ``[0, axis length]``.
    """
    extent_a, extent_b = size[2], size[3]

    # Half-extent of the patch along each axis.
    ratio = np.sqrt(1. - lam)
    half_a = int(extent_a * ratio) // 2
    half_b = int(extent_b * ratio) // 2

    # Patch centre, drawn uniformly; the size[2] axis is sampled first so the
    # sequence of random draws is fixed.
    centre_a = np.random.randint(extent_a)
    centre_b = np.random.randint(extent_b)

    # Clamp the corners so the patch never leaves the image.
    low_a = np.clip(centre_a - half_a, 0, extent_a)
    low_b = np.clip(centre_b - half_b, 0, extent_b)
    high_a = np.clip(centre_a + half_a, 0, extent_a)
    high_b = np.clip(centre_b + half_b, 0, extent_b)

    return low_a, low_b, high_a, high_b


class PatchDKAN(nn.Module):
    """
    k×k, stride‑k "convolution" built from a single DKAN_2D_Layer that is
    applied independently to every non‑overlapping k×k patch of the input.

    * Input side: each patch is flattened to ``in_channels * k * k`` features
      and zero‑padded up to the next multiple of ``tile_size_forward``.
    * Output side: ``out_channels`` is handed to the DKAN layer as is — there
      is no padding or slicing here, so callers are expected to pick a
      compatible width (this is not checked).
    * A non‑affine BatchNorm1d is applied to the DKAN output.
    """
    def __init__(self, in_channels: int, out_channels: int, k: int,
                 n_chunks: int, block_size_forward: int, block_size_backward: int,
                 tile_size_forward: int, tile_size_backward: int, init_scale: float):
        super().__init__()
        self.k = k
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.tile_size_forward = tile_size_forward

        # Features per patch, and that count rounded up to a whole number of tiles.
        self.patch_feature_dim = in_channels * k * k
        n_tiles = (self.patch_feature_dim - 1) // tile_size_forward + 1
        padded_input_dim = n_tiles * tile_size_forward
        self.input_padding = padded_input_dim - self.patch_feature_dim

        # Normalises the DKAN *output*; registered before the DKAN layer.
        self.bn = nn.BatchNorm1d(out_channels, affine=False)

        layer_config = dict(
            n_chunks=n_chunks,
            input_dim=padded_input_dim,
            output_dim=out_channels,
            block_size_forward=block_size_forward,
            block_size_backward=block_size_backward,
            tile_size_forward=tile_size_forward,
            tile_size_backward=tile_size_backward,
            apply_scale=False,
            apply_bias=False,
            cdf_grid=True,
            apply_tanh=False,
            init_scale=init_scale,
            batch_last=True,
            backward_fast_mode=True,
        )
        self.dkan = DKAN_2D_Layer(**layer_config)

    def forward(self, x: torch.Tensor, weight_dkan: float, apply_relu_linear: bool) -> torch.Tensor:
        """Map ``(B, C, H, W)`` to ``(B, out_channels, H // k, W // k)``.

        ``weight_dkan`` and ``apply_relu_linear`` are forwarded unchanged to the
        DKAN layer (always with ``relu_last=False``). H and W must be multiples
        of ``k``.
        """
        batch, _, height, width = x.shape
        stride = self.k
        assert height % stride == 0 and width % stride == 0, "H and W must be divisible by kernel/stride"
        rows, cols = height // stride, width // stride
        patches_per_image = rows * cols

        # Fold each k×k patch into the channel axis, then make one row per patch.
        unfolded = F.pixel_unshuffle(x, stride).permute(0, 2, 3, 1)
        feats = unfolded.reshape(batch * patches_per_image, self.patch_feature_dim)

        # Zero-pad the feature axis up to the DKAN input width.
        if self.input_padding > 0:
            feats = F.pad(feats, (0, self.input_padding))

        # The DKAN layer is called with features first and samples last.
        feats = feats.transpose(0, 1).contiguous()
        out = self.dkan(feats, weight_dkan=weight_dkan, apply_relu_linear=apply_relu_linear, relu_last=False)

        # Back to one row per patch, then normalise over the output features.
        out = self.bn(out.transpose(0, 1))

        # Rebuild the (B, out_channels, rows, cols) layout.
        out = out.reshape(batch, patches_per_image, self.out_channels)
        return out.transpose(1, 2).reshape(batch, self.out_channels, rows, cols)

    def get_frobenius_regularization(self):
        """Return the Frobenius regularization term reported by the wrapped DKAN layer."""
        return self.dkan.get_frobenius_regularization()


class SimpleCNNDKAN(nn.Module):
    """
    Four PatchDKAN stages followed by two DKAN fully connected layers.

    Input:  (B, 3, H, W). Each PatchDKAN stage divides H and W by ``k`` and the
            first FC layer expects ``27 * width`` features after flattening, so
            the four stages must reduce the image to 1×1 (H = W = k**4).
    Output: (B, num_classes)

    The last DKAN layer's output width is rounded up to a multiple of
    ``tile_size_forward``; the surplus columns are sliced off in ``forward``.
    """
    def __init__(self, width: int, num_classes: int, k: int,
                 n_chunks: int, block_size_forward: int, block_size_backward: int,
                 tile_size_forward: int, tile_size_backward: int, init_scale: float):
        super().__init__()
        self.width = width
        self.num_classes = num_classes
        self.tile_size_forward = tile_size_forward
        self.k = k

        # Settings shared by every PatchDKAN stage.
        stage_kwargs = dict(
            k=k, n_chunks=n_chunks,
            block_size_forward=block_size_forward, block_size_backward=block_size_backward,
            tile_size_forward=tile_size_forward, tile_size_backward=tile_size_backward,
            init_scale=init_scale,
        )
        # Channel widths triple at every stage: 3 -> w -> 3w -> 9w -> 27w.
        self.pd1 = PatchDKAN(in_channels=3, out_channels=width, **stage_kwargs)
        self.pd2 = PatchDKAN(in_channels=width, out_channels=3 * width, **stage_kwargs)
        self.pd3 = PatchDKAN(in_channels=3 * width, out_channels=9 * width, **stage_kwargs)
        self.pd4 = PatchDKAN(in_channels=9 * width, out_channels=27 * width, **stage_kwargs)

        # FC1 keeps the feature width; no padding or slicing is applied, so
        # 27 * width is assumed to suit the DKAN tile size (not checked here).
        hidden_dim = 27 * width
        self.dkan_fc1 = DKAN_2D_Layer(
            n_chunks, hidden_dim, hidden_dim, block_size_forward, block_size_backward,
            tile_size_forward, tile_size_backward, False, False, True, False, init_scale, True, True
        )

        # FC2 output is padded up to a multiple of tile_size_forward; remember
        # how many surplus columns to drop and the real class count.
        padded_num_classes = ((num_classes - 1) // tile_size_forward + 1) * tile_size_forward
        self.fc2_output_slicing = padded_num_classes - num_classes
        self.target_output_dim = num_classes

        # BatchNorm applied to the FC2 input (i.e. the FC1 output).
        self.bn_fc2 = nn.BatchNorm1d(hidden_dim, affine=False)

        self.dkan_fc2 = DKAN_2D_Layer(
            n_chunks, hidden_dim, padded_num_classes,
            block_size_forward, block_size_backward,
            tile_size_forward, tile_size_backward, False, False, True, False, init_scale, True, True
        )

    def forward(self, x, weight_dkan: float):
        """Run the network; ``weight_dkan`` is forwarded to every DKAN layer."""
        # Patch stages: only the first one runs with apply_relu_linear=False.
        stages = (
            (self.pd1, False),
            (self.pd2, True),
            (self.pd3, True),
            (self.pd4, True),
        )
        for stage, relu_linear in stages:
            x = stage(x, weight_dkan, apply_relu_linear=relu_linear)

        x = torch.flatten(x, 1)

        # FC1: DKAN layers are fed (features, batch), hence the transposes.
        hidden = self.dkan_fc1(x.transpose(0, 1).contiguous(), weight_dkan,
                               apply_relu_linear=True, relu_last=False)
        hidden = self.bn_fc2(hidden.transpose(0, 1))

        # FC2 (final layer).
        # NOTE(review): an earlier comment on this call said "no relu in linear
        # part", yet apply_relu_linear=True is what is passed — confirm intent.
        logits = self.dkan_fc2(hidden.transpose(0, 1).contiguous(), weight_dkan,
                               apply_relu_linear=True, relu_last=False)
        logits = logits.transpose(0, 1)

        # Drop the padding columns added to reach a multiple of the tile size.
        if self.fc2_output_slicing > 0:
            logits = logits[:, :self.target_output_dim]

        return logits.contiguous()

    def get_frobenius_regularization(self):
        """Sum the Frobenius regularization terms of all six DKAN-based layers."""
        total = 0.0
        for layer in (self.pd1, self.pd2, self.pd3, self.pd4, self.dkan_fc1, self.dkan_fc2):
            total += layer.get_frobenius_regularization()
        return total


# --- Global Constants ---
# Path to the YAML configuration file (value of --config-path)
CONFIG_PATH = args.config_path
# Save location for this run (value of --save-name)
SAVE_NAME = args.save_name

# Load configuration from YAML file (safe_load: plain data only, no arbitrary objects)
print(f"Loading configuration from: {CONFIG_PATH}")
with open(CONFIG_PATH, 'r') as f:
    config_values = yaml.safe_load(f)

# --- Configuration Processing (Common) ---
# Settings used by both calculation modes. Every key read with [...] is
# mandatory: a missing key raises KeyError right here.
BATCH_SIZE = config_values['BATCH_SIZE']
MIXUP_ALPHA = config_values['MIXUP_ALPHA']
MIXUP_PROB = config_values['MIXUP_PROB']
CUTMIX_BETA = config_values['CUTMIX_BETA']
CUTMIX_PROB = config_values['CUTMIX_PROB']
PREPROCESSED_PATH = config_values['PREPROCESSED_PATH']
WIDTH = config_values['WIDTH']
NUM_CLASSES = config_values['NUM_CLASSES']
NUM_WORKERS = config_values['NUM_WORKERS']
CROP_SIZE = config_values['CROP_SIZE']
N_CHUNKS = config_values['N_CHUNKS']
TILE_SIZE_FORWARD = config_values['TILE_SIZE_FORWARD']
TILE_SIZE_BACKWARD = config_values['TILE_SIZE_BACKWARD']
BLOCK_SIZE_FORWARD = config_values['BLOCK_SIZE_FORWARD']
BLOCK_SIZE_BACKWARD = config_values['BLOCK_SIZE_BACKWARD']
# Coerced with float(); presumably so values YAML parses as strings (e.g. "1e-3") still work — confirm.
INIT_SCALE = float(config_values['INIT_SCALE'])
PATCH_KERNEL_SIZE = config_values.get('PATCH_KERNEL_SIZE', 3) # Optional key; patch size defaults to 3 if not specified

# --- Calculation Mode Specific Setup ---
# Depending on --calc-type this section defines:
#   * 'initial':      the schedule constants read by get_params(), an empty
#                     `history`, and `total_epochs` (sum of the phase lengths).
#   * 'continuation': the continuation constants, `history` restored from the
#                     previous run's logs.npz, `start_epoch` / `total_epochs`,
#                     the constant DKAN weight, and the loaded `checkpoint`.
start_epoch = 0
if args.calc_type == 'initial':
    print("Running in 'initial' mode.")
    # Load schedule parameters for initial training
    WARMUP_EPOCHS = config_values['WARMUP_EPOCHS']
    PURE_MLP_EPOCHS = config_values['PURE_MLP_EPOCHS']
    DKAN_TURN_ON_EPOCHS = config_values['DKAN_TURN_ON_EPOCHS']
    DKAN_TURN_ON_SCALE = config_values['DKAN_TURN_ON_SCALE']
    DKAN_TURN_ON_CAP = float(config_values['DKAN_TURN_ON_CAP'])
    DKAN_FROBENIUS_DECAY_EPOCHS = config_values['DKAN_FROBENIUS_DECAY_EPOCHS']
    DKAN_FROBENIUS_DECAY_SCALE = config_values['DKAN_FROBENIUS_DECAY_SCALE']
    FROBENIUS_WEIGHT_CAP = float(config_values['FROBENIUS_WEIGHT_CAP'])
    INITIAL_FROBENIUS_WEIGHT = float(config_values['INITIAL_FROBENIUS_WEIGHT'])
    DKAN_BASE_LR = float(config_values['DKAN_BASE_LR'])
    PURE_MLP_BASE_LR = float(config_values['PURE_MLP_BASE_LR'])

    # Calculate total epochs for initial training
    total_epochs = (
        WARMUP_EPOCHS + PURE_MLP_EPOCHS + DKAN_TURN_ON_EPOCHS +
        DKAN_FROBENIUS_DECAY_EPOCHS
    )

    # Initialize history (one list per tracked quantity, one entry per epoch)
    history = {
        'train_loss': [], 'train_pure_loss': [], 'train_acc1': [],
        'train_acc5': [], 'val_acc1': [], 'val_acc5': [], 'lr': [],
        'dkan_weight': [], 'frobenius_weight': []
    }

elif args.calc_type == 'continuation':
    print("Running in 'continuation' mode.")
    # The flag is spelled with dashes on the command line; name it that way in
    # the error so the user can copy it verbatim.
    if not args.start_calc_folder:
        raise ValueError("--start-calc-folder is required for 'continuation' mode.")
    if not os.path.isdir(args.start_calc_folder):
        raise FileNotFoundError(f"Start calculation folder not found: {args.start_calc_folder}")

    # Load continuation parameters
    LEARNING_RATE_CONTINUATION = float(config_values['LEARNING_RATE_CONTINUATION'])
    FROBENIUS_WEIGHT_CONTINUATION = float(config_values['FROBENIUS_WEIGHT_CONTINUATION'])
    N_EPOCHS_CONTINUATION = config_values['N_EPOCHS_CONTINUATION']

    # Load history
    start_logs_path = os.path.join(args.start_calc_folder, 'logs.npz')
    if not os.path.exists(start_logs_path):
        raise FileNotFoundError(f"Previous logs not found: {start_logs_path}")
    print(f"Loading previous history from: {start_logs_path}")
    # Context manager guarantees the archive is closed even if the conversion
    # below raises.
    with np.load(start_logs_path) as loaded_history_np:
        # Convert loaded numpy arrays back to lists
        history = {k: list(v) for k, v in loaded_history_np.items()}

    # Determine start epoch and total epochs for continuation.
    # .get() so a log file without the key yields the explanatory ValueError
    # rather than a bare KeyError.
    if not history.get('lr'):
        raise ValueError("Loaded history is empty, cannot determine start epoch.")
    start_epoch = len(history['lr']) # Epochs are 0-indexed, length gives next epoch number
    total_epochs = start_epoch + N_EPOCHS_CONTINUATION
    print(f"Resuming from epoch {start_epoch}. Will train for {N_EPOCHS_CONTINUATION} epochs until epoch {total_epochs}.")

    # Get constant DKAN weight from the last recorded value
    if not history.get('dkan_weight'):
        raise ValueError("Loaded history does not contain 'dkan_weight'. Cannot determine continuation weight.")
    DKAN_WEIGHT_CONTINUATION = history['dkan_weight'][-1]
    print(f"Using constant DKAN weight from last history entry: {DKAN_WEIGHT_CONTINUATION:.3f}")

    # Load checkpoint
    start_checkpoint_path = os.path.join(args.start_calc_folder, 'model_last_epoch.pth')
    if not os.path.exists(start_checkpoint_path):
        raise FileNotFoundError(f"Previous checkpoint not found: {start_checkpoint_path}")
    print(f"Loading checkpoint from: {start_checkpoint_path}")
    # NOTE(review): torch.load unpickles the file, so only point
    # --start-calc-folder at checkpoints you trust.
    checkpoint = torch.load(start_checkpoint_path, map_location='cpu') # Load to CPU first

else:
    # argparse `choices` already restricts --calc-type; kept as a safety net.
    raise ValueError(f"Invalid calc_type: {args.calc_type}")

print("Configuration loaded successfully.")

# --- Device Setup ---
# Use CUDA when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# -----------------------------
# Augmentation pipelines (CPU part first, GPU part below)
# -----------------------------
# --- CPU-side transforms (ensure fixed spatial size before collation) ---
# Applied per sample inside the Dataset: crop to CROP_SIZE and convert to a
# float tensor.
cpu_train_tf = transforms.Compose([
    transforms.ToPILImage(),          # (H, W, C) uint8 → PIL
    transforms.RandomCrop(CROP_SIZE), # Random crop for training
    transforms.ToTensor(),            # → (C, H, W) float32 in [0,1]
])

cpu_val_tf = transforms.Compose([
    transforms.ToPILImage(),          # (H, W, C) uint8 → PIL
    transforms.CenterCrop(CROP_SIZE), # Center crop for validation
    transforms.ToTensor(),
])

# Training pipeline on GPU (mirrors original Compose order & hyper-params).
# It is called on whole batches in the training loop.
# NOTE(review): because each transform is invoked once per batch, the random
# parameters are presumably shared by every image in that batch — confirm
# against the torchvision.transforms.v2 documentation.
train_gpu_tf = nn.Sequential(
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    T.RandAugment(),
    T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
).to(device)

# Validation pipeline on GPU: normalisation only, same mean/std as training.
val_gpu_tf = nn.Sequential(
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
).to(device)

# --- Data Preparation (Loading individual NPY files into memory) ---
print(f"Loading preprocessed data from: {PREPROCESSED_PATH}")
start_load_time = time.time()  # Timer covers label loading and image loading

# Define paths for label files and image directories.
# Layout under PREPROCESSED_PATH: train_labels.npy, val_labels.npy,
# train_images/, val_images/.
train_labels_path = os.path.join(PREPROCESSED_PATH, "train_labels.npy")
val_labels_path = os.path.join(PREPROCESSED_PATH, "val_labels.npy")
train_images_dir = os.path.join(PREPROCESSED_PATH, "train_images")
val_images_dir = os.path.join(PREPROCESSED_PATH, "val_images")

# Load labels first: their length defines how many image files are read
# (files 0.npy .. N-1.npy in the matching image directory).
print(f"Loading labels from {train_labels_path} and {val_labels_path}...")
train_labels_np = np.load(train_labels_path)
val_labels_np = np.load(val_labels_path)
num_train_images = len(train_labels_np)
num_val_images = len(val_labels_np)
print(f"Found {num_train_images} train labels and {num_val_images} val labels.")

# Load individual image .npy files into lists using multiprocessing
train_images_list_np = []
val_images_list_np = []

# Reuse NUM_WORKERS for the loading pool, but never go below one process:
# num_workers=0 is a valid DataLoader setting, whereas multiprocessing.Pool
# raises ValueError for processes=0.
loading_workers = max(1, NUM_WORKERS)
print(f"Using {loading_workers} workers for parallel image loading...")

print(f"Loading {num_train_images} training images from {train_images_dir}...")
# Bind the directory so the pool only has to pass the image index
train_load_func = functools.partial(load_image_worker, image_dir=train_images_dir)
with multiprocessing.Pool(processes=loading_workers) as pool:
    # Use imap to preserve order and wrap with tqdm for progress.
    # A failure in any worker is re-raised here while the results are consumed.
    train_iterator = pool.imap(train_load_func, range(num_train_images))
    train_images_list_np = list(tqdm(train_iterator, total=num_train_images, desc="Loading train images"))

print(f"Loading {num_val_images} validation images from {val_images_dir}...")
# Same procedure for the validation images
val_load_func = functools.partial(load_image_worker, image_dir=val_images_dir)
with multiprocessing.Pool(processes=loading_workers) as pool:
    # Use imap to preserve order and wrap with tqdm for progress
    val_iterator = pool.imap(val_load_func, range(num_val_images))
    val_images_list_np = list(tqdm(val_iterator, total=num_val_images, desc="Loading val images"))


# Strict check to ensure loaded image count matches label count.
# Loading no longer swallows per-file errors, so this is expected to pass;
# it is kept as an explicit guard in case something unexpected occurred.
if len(train_images_list_np) != num_train_images:
    raise ValueError(
        f"Mismatch after loading: Number of train labels ({num_train_images}) "
        f"does not match number of loaded train images ({len(train_images_list_np)}). "
        f"Check data integrity in {train_images_dir}"
    )
if len(val_images_list_np) != num_val_images:
    raise ValueError(
        f"Mismatch after loading: Number of val labels ({num_val_images}) "
        f"does not match number of loaded val images ({len(val_images_list_np)}). "
        f"Check data integrity in {val_images_dir}"
    )

# Elapsed time since start_load_time (labels + images)
load_time = time.time() - start_load_time
print(f"Data loaded in {load_time:.2f}s")

# Print info about the loaded lists (first-image shape/dtype only when non-empty)
print(f"Loaded {len(train_images_list_np)} training images (list of np arrays)")
if train_images_list_np:
    print(f"  Example shape of first train image: {train_images_list_np[0].shape}, "
          f"dtype: {train_images_list_np[0].dtype}")
print(
    f"Train Labels Shape: {train_labels_np.shape}, "
    f"dtype: {train_labels_np.dtype}"
)
print(f"Loaded {len(val_images_list_np)} validation images (list of np arrays)")
if val_images_list_np:
    print(f"  Example shape of first val image: {val_images_list_np[0].shape}, "
          f"dtype: {val_images_list_np[0].dtype}")
print(
    f"Val Labels Shape: {val_labels_np.shape}, "
    f"dtype: {val_labels_np.dtype}"
)


# Convert Labels to Tensors (Images remain as list of NumPy arrays).
# .long() yields int64 labels, which the training loop passes to F.one_hot.
print("Converting labels NumPy arrays to Tensors...")
train_labels = torch.from_numpy(train_labels_np).long()
val_labels = torch.from_numpy(val_labels_np).long()

print(
    f"Train Labels Tensor Shape: {train_labels.shape}, "
    f"dtype: {train_labels.dtype}"
)
print(
    f"Val Labels Tensor Shape: {val_labels.shape}, "
    f"dtype: {val_labels.dtype}"
)


# Define a custom Dataset that handles a list of NumPy arrays in memory
# Renaming back to something descriptive
class InMemoryListDataset(Dataset):
    """Map-style dataset over an in-memory list of images and aligned labels.

    Args:
        images_list: Sequence of images; in this script, NumPy arrays that the
            transform turns into tensors.
        labels: Label container of the same length (a tensor in this script).
        transform: Optional callable applied to each image. When it is falsy,
            ``transforms.ToTensor()`` is applied instead.

    Raises:
        ValueError: If ``images_list`` and ``labels`` differ in length.
    """
    def __init__(self, images_list, labels, transform=None):
        self.images_list = images_list
        self.labels = labels
        self.transform = transform

        # Images and labels are paired by index, so their counts must agree.
        n_images, n_labels = len(self.images_list), len(self.labels)
        if n_images != n_labels:
            raise ValueError(
                f"Number of images ({n_images}) does not match "
                f"number of labels ({n_labels}). "
                "Check data loading and potential skipped files."
            )

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``."""
        raw_image = self.images_list[idx]
        target = self.labels[idx]

        if not self.transform:
            # Basic fallback: plain tensor conversion without augmentation.
            return transforms.ToTensor()(raw_image), target

        # The supplied transform is responsible for any PIL/tensor conversion.
        return self.transform(raw_image), target


# --- Transforms for In-Memory NumPy Arrays: CPU/GPU split ---
#   • On the CPU, inside the Dataset (cpu_train_tf / cpu_val_tf): crop to
#     CROP_SIZE and convert the uint8 NumPy arrays to float tensors in [0,1].
#   • On the GPU (train_gpu_tf / val_gpu_tf, defined earlier right after the
#     `device` variable): flip, colour jitter, RandAugment, random erasing and
#     normalisation for training; normalisation only for validation.


# --- Create Datasets and DataLoaders ---
print("Creating InMemoryListDataset instances and DataLoaders...")

train_dataset = InMemoryListDataset(
    train_images_list_np, train_labels, transform=cpu_train_tf
)
val_dataset = InMemoryListDataset(
    val_images_list_np, val_labels, transform=cpu_val_tf
)

# DataLoader loads batches from the InMemoryListDataset.
# Training batches are shuffled; validation keeps the stored order.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=True
)

val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True
)
print("InMemoryListDataset and DataLoaders created successfully.")


# --- Training Schedule Function (for 'initial' mode) ---
def get_params(epoch):
    """Return ``(lr, dkan_weight, frobenius_weight)`` for ``epoch`` in 'initial' mode.

    The schedule is split into four consecutive phases whose lengths come from
    module-level constants (which must already be defined):

      0. Warmup        — LR ramps linearly up to PURE_MLP_BASE_LR; DKAN weight 0.
      1. Pure MLP      — LR fixed at PURE_MLP_BASE_LR; DKAN weight 0.
      2. DKAN turn-on  — LR is DKAN_BASE_LR; DKAN weight grows by
                         1 / DKAN_TURN_ON_SCALE per epoch, capped at DKAN_TURN_ON_CAP.
      3. Frobenius decay — DKAN weight held at its end-of-turn-on value; the
                         Frobenius weight drops by a factor of 10 every
                         DKAN_FROBENIUS_DECAY_SCALE epochs, floored at
                         FROBENIUS_WEIGHT_CAP.

    Raises:
        RuntimeError: If ``epoch`` lies at or beyond the end of the last phase.
    """
    # Exclusive end epoch of each phase.
    warmup_end = WARMUP_EPOCHS
    mlp_end = warmup_end + PURE_MLP_EPOCHS
    turn_on_end = mlp_end + DKAN_TURN_ON_EPOCHS
    decay_end = turn_on_end + DKAN_FROBENIUS_DECAY_EPOCHS

    # 0. Warmup Phase
    if epoch < warmup_end:
        ramped_lr = PURE_MLP_BASE_LR * (epoch + 1) / max(1, WARMUP_EPOCHS)
        return ramped_lr, 0.0, INITIAL_FROBENIUS_WEIGHT

    # 1. Pure MLP Phase
    if epoch < mlp_end:
        return PURE_MLP_BASE_LR, 0.0, INITIAL_FROBENIUS_WEIGHT

    # 2. DKAN Turn-on Phase
    if epoch < turn_on_end:
        epochs_into_phase = epoch - mlp_end + 1
        ramped_weight = epochs_into_phase / max(1, DKAN_TURN_ON_SCALE)
        return DKAN_BASE_LR, min(ramped_weight, DKAN_TURN_ON_CAP), INITIAL_FROBENIUS_WEIGHT

    # 3. Frobenius Decay Phase (the last phase)
    if epoch < decay_end:
        # DKAN weight stays where the turn-on phase left it.
        held_weight = min(DKAN_TURN_ON_EPOCHS / max(1, DKAN_TURN_ON_SCALE), DKAN_TURN_ON_CAP)
        decades = (epoch - turn_on_end + 1) / max(1, DKAN_FROBENIUS_DECAY_SCALE)
        decayed = INITIAL_FROBENIUS_WEIGHT / (10 ** decades)
        return DKAN_BASE_LR, held_weight, max(decayed, FROBENIUS_WEIGHT_CAP)

    # Not reached when total_epochs equals the sum of the phase lengths.
    raise RuntimeError(f"Epoch {epoch} is outside the expected range [0, {decay_end}) for initial mode.")

# --- Model, Loss, Optimizer Initialization ---
# The model is built from the configuration constants and moved to `device`
# before the optimizer is created from its parameters.
print("Initializing DKAN model...")
model = SimpleCNNDKAN(
    width=WIDTH, num_classes=NUM_CLASSES, k=PATCH_KERNEL_SIZE,
    n_chunks=N_CHUNKS,
    block_size_forward=BLOCK_SIZE_FORWARD, block_size_backward=BLOCK_SIZE_BACKWARD,
    tile_size_forward=TILE_SIZE_FORWARD, tile_size_backward=TILE_SIZE_BACKWARD,
    init_scale=INIT_SCALE
).to(device)

# Loss function remains CrossEntropy for evaluation; the training loop computes
# its loss manually from log-softmax and (possibly mixed) one-hot labels.
criterion = nn.CrossEntropyLoss()

# Optimizer - the LR given here is only a placeholder: the training loop
# overwrites param_group['lr'] at the start of every epoch, using
# get_params(epoch) in 'initial' mode and a cosine schedule in 'continuation' mode.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0) # Initialize with dummy LR

# --- Load State if Continuation ---
# `checkpoint` was loaded onto the CPU during the mode-specific setup above.
if args.calc_type == 'continuation':
    print("Loading model and optimizer state dicts...")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Make sure model is on the correct device after loading state
    model.to(device)
    # NOTE(review): the optimizer state is expected to end up on the same device
    # as the parameters after load_state_dict — confirm. If device-mismatch
    # errors show up, move the state tensors manually:
    # for state in optimizer.state.values():
    #     for k, v in state.items():
    #         if isinstance(v, torch.Tensor):
    #             state[k] = v.to(device)
    print("Model and optimizer states loaded.")


# --- Variables to track best validation accuracy (for the current run) ---
# All trackers start from zero/None, also in 'continuation' mode: they are not
# restored from the loaded history.
best_val_acc1 = 0.0
best_epoch_acc1 = 0
best_val_acc5 = 0.0
best_epoch_acc5 = 0
best_model_state_dict_top5 = None # Tracks best model *during this run*


# --- Training Loop ---
# Epochs run over range(start_epoch, total_epochs); total_epochs is exclusive.
print(f"Starting training from epoch {start_epoch} to {total_epochs}...")

for epoch in range(start_epoch, total_epochs):
    start_time = time.time()  # Record start time

    # --- Get parameters for this epoch based on mode ---
    if args.calc_type == 'initial':
        current_lr, current_dkan_weight, current_frobenius_weight = get_params(epoch)
    elif args.calc_type == 'continuation':
        # Cosine decay learning rate schedule for continuation phase
        t_cont = epoch - start_epoch  # Epoch index within continuation phase (0-based)
        t_max = max(1, N_EPOCHS_CONTINUATION - 1)  # Ensure no division by zero
        current_lr = 0.5 * LEARNING_RATE_CONTINUATION * (1 + math.cos(math.pi * t_cont / t_max))
        current_dkan_weight = DKAN_WEIGHT_CONTINUATION
        current_frobenius_weight = FROBENIUS_WEIGHT_CONTINUATION
    else: # Should not happen due to earlier check
        raise ValueError(f"Invalid calc_type: {args.calc_type}")

    # Update optimizer's learning rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = current_lr

    model.train()  # Set model to training mode
    running_loss = 0.0
    running_pure_loss = 0.0 # Track pure loss separately
    correct_train = 0
    total_train = 0
    correct_train_top5 = 0
    total_train_acc_samples = 0

    mixup_beta_distribution = dist.Beta(MIXUP_ALPHA, MIXUP_ALPHA)
    cutmix_beta_distribution = dist.Beta(CUTMIX_BETA, CUTMIX_BETA)

    # Progress bar over this epoch's training batches; leave=False removes the
    # bar from the terminal once the epoch is done.
    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{total_epochs}", leave=False)

    for i, (images, labels) in enumerate(train_pbar):
        # Move the batch to the device, then run the GPU-side training transforms.
        images = train_gpu_tf(images.to(device, non_blocking=True))
        labels = labels.to(device, non_blocking=True)
        batch_actual_size = images.size(0)

        # --- Augmentation Selection ---
        # A single uniform draw selects the branch:
        #   [0, MIXUP_PROB)                        -> Mixup
        #   [MIXUP_PROB, MIXUP_PROB + CUTMIX_PROB) -> CutMix
        #   otherwise                              -> clean batch
        aug_draw = torch.rand(1).item()
        apply_mixup = aug_draw < MIXUP_PROB and MIXUP_ALPHA > 0
        apply_cutmix = MIXUP_PROB <= aug_draw < MIXUP_PROB + CUTMIX_PROB and CUTMIX_BETA > 0

        # Each branch only prepares the network input and the soft targets;
        # the forward pass and the soft-label cross-entropy are shared below.
        if apply_mixup:
            # --- Mixup: per-sample blend of the batch with a shuffled copy of itself ---
            mix_coeff = mixup_beta_distribution.sample((batch_actual_size,)).to(device)
            shuffle_idx = torch.randperm(batch_actual_size, device=device)

            img_coeff = mix_coeff.view(batch_actual_size, 1, 1, 1)
            net_input = img_coeff * images + (1.0 - img_coeff) * images[shuffle_idx]

            lbl_coeff = mix_coeff.view(batch_actual_size, 1)
            own_one_hot = F.one_hot(labels, num_classes=NUM_CLASSES).float()
            partner_one_hot = F.one_hot(labels[shuffle_idx], num_classes=NUM_CLASSES).float()
            soft_targets = lbl_coeff * own_one_hot + (1.0 - lbl_coeff) * partner_one_hot

        elif apply_cutmix:
            # --- CutMix: paste a box taken from a shuffled copy of the batch ---
            mix_ratio = cutmix_beta_distribution.sample().item()
            shuffle_idx = torch.randperm(batch_actual_size, device=device)
            # rand_bbox is defined elsewhere in this file; its four return
            # values are used as slice bounds on the last two image dimensions.
            bbx1, bby1, bbx2, bby2 = rand_bbox(images.size(), mix_ratio)

            net_input = images.clone()
            net_input[:, :, bbx1:bbx2, bby1:bby2] = images[shuffle_idx, :, bbx1:bbx2, bby1:bby2]

            # Replace the sampled ratio by the exact fraction of the image
            # that was NOT overwritten by the pasted box.
            mix_ratio = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (images.size()[-1] * images.size()[-2]))

            own_one_hot = F.one_hot(labels, num_classes=NUM_CLASSES).float()
            partner_one_hot = F.one_hot(labels[shuffle_idx], num_classes=NUM_CLASSES).float()
            soft_targets = mix_ratio * own_one_hot + (1.0 - mix_ratio) * partner_one_hot

        else:
            # --- Standard step: untouched images, hard labels expressed as one-hot ---
            net_input = images
            soft_targets = F.one_hot(labels, num_classes=NUM_CLASSES).float()

        # Forward pass with the current DKAN weight, followed by the
        # cross-entropy against the (possibly soft) targets. This is the
        # "pure" loss, i.e. the loss before regularization.
        outputs = model(net_input, current_dkan_weight)
        log_probs = F.log_softmax(outputs, dim=1)
        pure_loss = -torch.sum(soft_targets * log_probs, dim=1).mean()

        # Training accuracy is only tracked on clean (non-augmented) batches.
        if not (apply_mixup or apply_cutmix):
            with torch.no_grad():
                top1_pred = torch.max(outputs.data, 1).indices
                correct_train += (top1_pred == labels).sum().item()
                top5_pred = torch.topk(outputs, 5, dim=1).indices
                # labels broadcast from (B, 1) against the (B, 5) predictions
                correct_train_top5 += torch.sum(top5_pred == labels.view(-1, 1)).item()
            total_train_acc_samples += batch_actual_size

        # --- Add Frobenius Regularization ---
        fro_reg = model.get_frobenius_regularization()
        loss = pure_loss + current_frobenius_weight * fro_reg

        # Backward and optimize (shared by all three branches)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted losses for the epoch averages
        loss_value = loss.item()
        pure_loss_value = pure_loss.item()
        running_loss += loss_value * batch_actual_size
        running_pure_loss += pure_loss_value * batch_actual_size
        total_train += batch_actual_size

        # Refresh the progress-bar stats only every 50th batch
        if i % 50 == 49:
            train_pbar.set_postfix({
                'Loss': f'{loss_value:.4f}',
                'PureLoss': f'{pure_loss_value:.4f}',
                'LR': f'{current_lr:.6f}',
                'DKAN_w': f'{current_dkan_weight:.3f}',
                'Fro_w': f'{current_frobenius_weight:.4g}'
            })

    train_pbar.close()  # Done with this epoch's training bar

    # Epoch-level training metrics: losses are averaged over every sample
    # seen, falling back to 0.0 when no sample was processed.
    if total_train > 0:
        epoch_loss_train = running_loss / total_train
        epoch_pure_loss_train = running_pure_loss / total_train
    else:
        epoch_loss_train = 0.0
        epoch_pure_loss_train = 0.0

    # Accuracies are averaged over clean (non-augmented) samples only; when
    # the epoch had none, they are reported as 0.
    if total_train_acc_samples <= 0:
        epoch_acc_train = 0.0
        epoch_acc_train_top5 = 0.0
    else:
        epoch_acc_train = 100 * correct_train / total_train_acc_samples
        epoch_acc_train_top5 = 100 * correct_train_top5 / total_train_acc_samples

    # --- Validation Loop ---
    # Evaluate on the validation loader in eval mode with gradients disabled.
    model.eval()
    running_loss_val = 0.0
    correct_val = 0
    correct_val_top5 = 0
    total_val = 0
    val_pbar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for images, labels in val_pbar:
            images = val_gpu_tf(images.to(device, non_blocking=True))
            labels = labels.to(device, non_blocking=True)

            # Validation uses the *current* DKAN weight (scheduled or constant)
            outputs = model(images, current_dkan_weight)
            # Plain criterion here: the validation loss carries no regularization term
            loss = criterion(outputs, labels)
            running_loss_val += loss.item() * images.size(0)

            # Top-1 hits
            top1_pred = torch.max(outputs.data, 1).indices
            total_val += labels.size(0)
            correct_val += (top1_pred == labels).sum().item()

            # Top-5 hits: labels broadcast from (B, 1) against the (B, 5) predictions
            top5_pred = torch.topk(outputs, 5, dim=1).indices
            correct_val_top5 += torch.sum(top5_pred == labels.view(-1, 1)).item()

    val_pbar.close()  # Done with the validation bar

    # Per-sample averages; every metric falls back to 0.0 for an empty loader.
    if total_val > 0:
        epoch_loss_val = running_loss_val / total_val
        epoch_acc_val = 100 * correct_val / total_val
        epoch_acc_val_top5 = 100 * correct_val_top5 / total_val
    else:
        epoch_loss_val = 0.0
        epoch_acc_val = 0.0
        epoch_acc_val_top5 = 0.0

    # Wall-clock time measured from start_time up to the end of validation
    end_time = time.time()
    epoch_duration = end_time - start_time

    # --- Track the best validation accuracies seen within the current run ---
    if epoch_acc_val > best_val_acc1:
        best_val_acc1, best_epoch_acc1 = epoch_acc_val, epoch + 1

    # Only a new top-5 best triggers a snapshot of the model weights.
    if epoch_acc_val_top5 > best_val_acc5:
        best_val_acc5, best_epoch_acc5 = epoch_acc_val_top5, epoch + 1
        print(f"  New best top-5 accuracy for this run: {best_val_acc5:.2f}%. Saving model state...")
        best_model_state_dict_top5 = cpu_clone_state_dict(model)  # CPU copy of *this run's* best state

    # Epoch summary, one report line per entry.
    # Note: the "Best Val" line refers to the best achieved *during this specific run*.
    epoch_report = (
        f'Epoch [{epoch+1}/{total_epochs}] -- Time: {epoch_duration:.2f}s',
        f'  LR: {current_lr:.6f} | DKAN_w: {current_dkan_weight:.3f} | Fro_w: {current_frobenius_weight:.4g}',
        f'  Train Loss: {epoch_loss_train:.4f} | Train Pure Loss: {epoch_pure_loss_train:.4f}',
        f'  Train Acc@1: {epoch_acc_train:.2f}% | Train Acc@5: {epoch_acc_train_top5:.2f}%',
        f'  Val Loss: {epoch_loss_val:.4f} | Val Acc@1: {epoch_acc_val:.2f}% | Val Acc@5: {epoch_acc_val_top5:.2f}%',
        f'  Best Val Acc@1 (this run): {best_val_acc1:.2f}% (Epoch {best_epoch_acc1}) | Best Val Acc@5 (this run): {best_val_acc5:.2f}% (Epoch {best_epoch_acc5})',
    )
    for report_line in epoch_report:
        print(report_line)


    # --- Record History (appended to a possibly pre-loaded history) ---
    epoch_record = {
        'train_loss': epoch_loss_train,
        'train_pure_loss': epoch_pure_loss_train,
        'train_acc1': epoch_acc_train,
        'train_acc5': epoch_acc_train_top5,
        'val_acc1': epoch_acc_val,
        'val_acc5': epoch_acc_val_top5,
        'lr': current_lr,
        'dkan_weight': current_dkan_weight,
        'frobenius_weight': current_frobenius_weight,
    }
    for metric_name, metric_value in epoch_record.items():
        history[metric_name].append(metric_value)

print('Finished Training')

# --- Final Summary and Saving ---
# Writes into SAVE_NAME: logs.npz (history), summary.txt, model_last_epoch.pth,
# optionally model_best_epoch_top_5.pth, and a copy of the configuration file.
print(f"\nPreparing summary and saving results to directory: {SAVE_NAME}...")

# Create the save directory if it doesn't exist
os.makedirs(SAVE_NAME, exist_ok=True)

# Collect the last-epoch values for the summary. The history lists may be
# empty when the training loop never ran, hence the IndexError fallback.
try:
    last_lr = history['lr'][-1]
    last_dkan_w = history['dkan_weight'][-1]
    last_fro_w = history['frobenius_weight'][-1]
    last_train_loss = history['train_loss'][-1]
    last_train_pure_loss = history['train_pure_loss'][-1]
    last_train_acc1 = history['train_acc1'][-1]
    last_train_acc5 = history['train_acc5'][-1]
    last_val_acc1 = history['val_acc1'][-1]
    last_val_acc5 = history['val_acc5'][-1]
except IndexError:
    # Handle case where training loop didn't run (e.g., resuming finished job)
    print("Warning: History is empty or training loop did not run. Summary might be incomplete.")
    last_lr, last_dkan_w, last_fro_w = 0, 0, 0
    last_train_loss, last_train_pure_loss = 0, 0
    last_train_acc1, last_train_acc5 = 0, 0
    last_val_loss, last_val_acc1, last_val_acc5 = 0, 0, 0
else:
    # The validation loss is not stored in `history`; it only exists as a
    # variable assigned inside the epoch loop. With a non-empty (loaded)
    # history and zero epochs run in this session that variable was never
    # bound, so guard the lookup instead of crashing with NameError.
    try:
        last_val_loss = epoch_loss_val  # Use the last calculated validation loss
    except NameError:
        print("Warning: No validation loss was computed in this session; reporting it as nan.")
        last_val_loss = float('nan')

summary_content = (
    f"--- Training Summary ({args.calc_type} mode) ---\n"
    f"Total Epochs Run in this Session: {total_epochs - start_epoch}\n"
    f"Total Epochs in History: {len(history['lr'])}\n"
    f"Last Epoch ({len(history['lr'])}):\n"
    f"  Learning Rate: {last_lr:.6f}\n"
    f"  DKAN Weight: {last_dkan_w:.3f}\n"
    f"  Frobenius Weight: {last_fro_w:.4g}\n"
    f"  Train Loss (Total): {last_train_loss:.4f}\n"
    f"  Train Loss (Pure): {last_train_pure_loss:.4f}\n"
    f"  Train Accuracy@1: {last_train_acc1:.2f}%\n"
    f"  Train Accuracy@5: {last_train_acc5:.2f}%\n"
    f"  Validation Loss: {last_val_loss:.4f}\n"
    f"  Validation Accuracy@1: {last_val_acc1:.2f}%\n"
    f"  Validation Accuracy@5: {last_val_acc5:.2f}%\n"
    "\nBest Validation Performance Recorded (during this run):\n"
    f"  Best Validation Accuracy@1: {best_val_acc1:.2f}% (at Epoch {best_epoch_acc1})\n"
    f"  Best Validation Accuracy@5: {best_val_acc5:.2f}% (at Epoch {best_epoch_acc5})\n"
)

# Define file paths within the save directory for this run
logs_path = os.path.join(SAVE_NAME, 'logs.npz')
summary_path = os.path.join(SAVE_NAME, 'summary.txt')

# Save History to logs.npz (includes appended history)
print(f"Saving training history to {logs_path}...")
np.savez(logs_path, **history)

# Save Summary to summary.txt (explicit encoding keeps the file
# independent of the platform's default locale)
print(f"Saving summary to {summary_path}...")
with open(summary_path, 'w', encoding='utf-8') as f:
    f.write(summary_content)

# Save Model Checkpoint at Last Epoch of this run
checkpoint_path = os.path.join(SAVE_NAME, 'model_last_epoch.pth')
print(f"Saving model checkpoint from last epoch of this run to {checkpoint_path}...")
last_epoch_state = {
    'model_state_dict': cpu_clone_state_dict(model),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(last_epoch_state, checkpoint_path)  # Save combined state

# Save Best Model Checkpoint (Top-5 Acc) from this run if available
if best_model_state_dict_top5:
    best_checkpoint_path = os.path.join(SAVE_NAME, 'model_best_epoch_top_5.pth')
    print(f"Saving best model checkpoint (top-5 validation acc from this run) to {best_checkpoint_path}...")
    torch.save(best_model_state_dict_top5, best_checkpoint_path)  # Already on CPU
else:
    print("No better model state (top-5) was recorded during this run.")

# Copy the configuration file used
config_dest_path = os.path.join(SAVE_NAME, 'configuration_used.yaml')
print(f"Copying configuration file to {config_dest_path}...")
try:
    shutil.copy2(CONFIG_PATH, config_dest_path)  # copy2 preserves metadata
except shutil.SameFileError:
    # The run was launched with the config that already lives at the
    # destination path; there is nothing to copy and no reason to fail
    # after all other artifacts have been written.
    print(f"Configuration file is already located at {config_dest_path}; skipping copy.")

print(f"History, summary, model checkpoints (for {args.calc_type} run), and configuration saved successfully to {SAVE_NAME}.")