import torch
import torch.nn as nn
import numpy as np
import collections
import time
import sys
import os
import argparse
import argparse
import importlib
import network as network_legacy

# Add path for utils
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from dataset_fast import load_dataset
from metrics import compute_inv_propensity, compute_prop_metrics
from scipy import sparse
from tqdm import tqdm

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# Datasets processed by default when --datasets is not passed on the command line.
DATASETS = ["bibtex", "mediamill"] # Add others here: "eurlex", "wiki10", etc.

# Parameter set selected by train_model() for every model except "osc"
# (HLB, VTB, MAP). train_model() reads "hidden" from here; out_features,
# drop_rate and factor are taken from get_dataset_config() instead.
COMMON_PARAMS = {
    "hidden": 512,
    "out_features": 400,
    "drop_rate": 0.0,
    "factor": 1,
}

# Parameter set selected by train_model() for the "osc" model.
# out_features for OSC is derived as OSC_d ** 2 unless the dataset config
# overrides it.
OSC_PARAMS = {
    "hidden": 512,
    "OSC_d": 125,
    "drop_rate": 0.2,
    "num_role_vecs": 125
}

# Optimisation settings: "lr" feeds Adam, "gamma" feeds ExponentialLR, and
# "batch_size" is the default batch size returned by get_dataset_config().
TRAIN_PARAMS = {
    "epochs": 2, # Matching base repo - large datasets need more training
    "batch_size": 64,
    "lr": 1e-3,
    "gamma": 0.98
}

# Model key (also the valid --models choices) -> name of the module that is
# imported with importlib and must expose a `Network` class.
MODELS = {
    "hlb": "network",
    "vtb": "network_vtb",
    "map": "network_map",
    "osc": "network_clifford"
}


# -----------------------------------------------------------------------------
# Dataset Specific Configurations (Extracted from repo scripts)
# -----------------------------------------------------------------------------
# Per-dataset overrides, keyed by lower-cased dataset name.
# get_dataset_config() merges the matching entry over its defaults, so any key
# missing here falls back to the default value.
# Optional keys read by train_model():
#   "batch_size"          - replaces the default batch size
#   "vtb_requires_grad"   - replaces "requires_grad" for the "vtb" model only
#   "osc_d"               - replaces OSC_PARAMS["OSC_d"] (out_features = osc_d ** 2)
#   "osc_num_role_vecs"   - replaces OSC_PARAMS["num_role_vecs"]
# "requires_grad" and "negative" are forwarded to the Network constructor for
# hlb/map/vtb; their exact meaning is defined by the network modules
# (not visible in this file).
DATASET_CONFIGS = {
    "bibtex": {
        "out_features": 400, "drop_rate": 0.0, "factor": 1,
        "requires_grad": True, "negative": False
    },
    "mediamill": {
        "out_features": 400, "drop_rate": 0.0, "factor": 1,
        "requires_grad": False, "negative": True
    },
    "delicious": {
        "out_features": 400, "drop_rate": 0.25, "factor": 1,
        "requires_grad": False, "negative": True
    },
    "eurlex": {
        "out_features": 1600, "drop_rate": 0.35, "factor": 2, # out_features=1600 matching train_vtb.py
        "requires_grad": False, "negative": True, "batch_size": 256,
        "vtb_requires_grad": False, # Explicitly disable learning for VTB on Eurlex
        "osc_d": 200,
        "osc_num_role_vecs": 200
    },
}

def get_dataset_config(dataset_name):
    """Return the effective configuration dict for ``dataset_name``.

    The lookup is case-insensitive and "amazon-13k" is treated as an alias of
    "amazoncat-13k". Values from DATASET_CONFIGS (if the dataset is listed
    there) take precedence over the built-in defaults; a fresh dict is
    returned on every call.
    """
    key = dataset_name.lower()
    if key == "amazon-13k":
        key = "amazoncat-13k"

    # Defaults applied to datasets without an explicit entry.
    defaults = {
        "out_features": 400,
        "drop_rate": 0.0,
        "factor": 1,
        "requires_grad": True,
        "negative": False,
        "batch_size": TRAIN_PARAMS["batch_size"],
    }

    overrides = DATASET_CONFIGS.get(key, {})
    return {**defaults, **overrides}

# -----------------------------------------------------------------------------
# Data Loader Proxy (Handles Fast vs Sparse formats)
# -----------------------------------------------------------------------------
from torch.utils.data import DataLoader

class SparseToDenseWrapper:
    """Adapt a loader of ``(x, y_indices)`` batches to ``(x, y_hot, y_indices)``.

    ``y_indices`` holds label indices per sample (padded rows are expected);
    ``y_hot`` is the matching dense multi-hot float32 matrix of shape
    ``(batch, num_labels)`` created on the CPU. Any index outside
    ``[0, num_labels)`` is treated as padding and ignored.
    """

    def __init__(self, loader, num_labels):
        self.loader = loader
        self.num_labels = num_labels

    def __iter__(self):
        """Yield each underlying batch exactly once, with a dense label matrix added."""
        for x, y_idxs in self.loader:
            batch_size = x.shape[0]
            y_hot = torch.zeros((batch_size, self.num_labels), dtype=torch.float32)

            # Row-by-row so that any padding convention (or ragged rows) works.
            for i in range(batch_size):
                indices = y_idxs[i]
                # Keep only real label ids; negative padding would otherwise
                # wrap around and mark the last labels as positive.
                in_range = (indices >= 0) & (indices < self.num_labels)
                y_hot[i, indices[in_range].long()] = 1.0

            yield x, y_hot, y_idxs

    def __len__(self):
        """Number of batches, identical to the wrapped loader."""
        return len(self.loader)

class PreloadedSparseLoader:
    """
    Loads a whole sparse dataset onto ``device`` as one sparse COO tensor and
    serves shuffled dense batches from it, avoiding DataLoader workers.

    The text file is expected to contain one sample per line in the form
    ``label,label,... idx:val idx:val ...`` with an optional leading header
    line of three integers. Parsed tensors are cached next to the data file
    as ``<data_path>.pt`` and reused on later runs.

    Iterating yields ``(x_batch, y_hot, y_idxs)``:
      * ``x_batch`` - dense features on ``device``
      * ``y_hot``   - dense multi-hot labels on ``device``
      * ``y_idxs``  - CPU label indices padded with ``num_labels``
    """
    def __init__(self, data_path, num_features, num_labels, batch_size, device, desc="Parsing Data"):
        self.batch_size = batch_size
        self.device = device
        self.num_labels = num_labels

        cache_path = data_path + ".pt"
        if os.path.exists(cache_path):
            print(f"Loading cached tensors from {cache_path}...")
            # NOTE(review): torch.load unpickles the file; only load caches
            # written by this class from trusted locations.
            cache = torch.load(cache_path)
            self.data_size = cache["data_size"]
            i_tens = cache["indices"]
            v_tens = cache["values"]
            print(f"Loaded cache! Data size: {self.data_size}")
            self.y_list = cache["y_list"]
        else:
            print(f"Parsing raw text file from {data_path}...")
            with open(data_path, 'r') as f:
                lines = f.readlines()

            # Drop the optional "N D L" header. All three tokens must be
            # integers so that a sample such as "5 0:0.5 3:0.2" is kept.
            if lines and self._is_header_line(lines[0]):
                lines = lines[1:]

            self.data_size = len(lines)

            rows = []
            indices = []
            values = []
            y_list = []
            for i, line in tqdm(enumerate(lines), total=self.data_size, desc=desc):
                parts = line.strip().split(' ')

                # Labels (comma separated)
                labs = [int(l) for l in parts[0].split(',') if l]
                y_list.append(torch.tensor(labs, dtype=torch.long))

                # Features (idx:val); malformed items are skipped as a whole
                # so the three coordinate lists always stay aligned.
                for item in parts[1:]:
                    parsed = self._parse_feature(item)
                    if parsed is None:
                        continue
                    rows.append(i)
                    indices.append(parsed[0])
                    values.append(parsed[1])

            i_tens = torch.LongTensor([rows, indices])
            v_tens = torch.FloatTensor(values)

            print(f"Saving cache to {cache_path}...")
            torch.save({
                "data_size": self.data_size,
                "y_list": y_list,
                "indices": i_tens,
                "values": v_tens
            }, cache_path)
            self.y_list = y_list

        self.X_gpu = torch.sparse_coo_tensor(i_tens, v_tens, size=(self.data_size, num_features)).to(device)

    @staticmethod
    def _is_header_line(line):
        """Return True if ``line`` consists of exactly three integer tokens."""
        tokens = line.strip().split(' ')
        return len(tokens) == 3 and all(tok.isdigit() for tok in tokens)

    @staticmethod
    def _parse_feature(item):
        """Parse ``"idx:val"`` into ``(int, float)``; return None if malformed."""
        try:
            idx, val = item.split(':')
            return int(idx), float(val)
        except ValueError:
            return None

    def __iter__(self):
        # A new random order is drawn on every pass over the data.
        perm = torch.randperm(self.data_size)

        for start in range(0, self.data_size, self.batch_size):
            batch_idxs = perm[start:start + self.batch_size]

            # Select the batch rows from the sparse matrix and densify: (B, F).
            x_batch = torch.index_select(self.X_gpu, 0, batch_idxs.to(self.device)).to_dense()

            batch_y_raw = [self.y_list[idx] for idx in batch_idxs.tolist()]
            max_len = max(len(y) for y in batch_y_raw)

            # Label indices padded with num_labels (an out-of-range id).
            y_idxs = torch.full((len(batch_y_raw), max_len), self.num_labels, dtype=torch.long)
            for b, y in enumerate(batch_y_raw):
                y_idxs[b, :len(y)] = y

            # Dense multi-hot labels, filled one sample at a time.
            y_hot = torch.zeros((len(batch_y_raw), self.num_labels), dtype=torch.float32, device=self.device)
            for b, y in enumerate(batch_y_raw):
                y_hot[b, y.to(self.device)] = 1.0

            yield x_batch, y_hot, y_idxs

    def __len__(self):
        """Number of batches per pass (last batch may be smaller)."""
        return (self.data_size + self.batch_size - 1) // self.batch_size

def load_dataset_proxy(dataset_name, batch_size):
    """
    Build train/test loaders for ``dataset_name``.

    Names listed as sparse are read through ``dataset_sparse``; everything
    else goes through ``dataset_fast.load_dataset``. Returns a 4-tuple
    ``(train_loader, test_loader, info, inv_prop)``. All four are None when
    the expected data file is missing; ``inv_prop`` is None for dataset_fast
    datasets and when the propensity cache could not be built or read.
    """
    name_lc = dataset_name.lower()
    sparse_names = {
        "eurlex", "amazon13k", "wiki10-31k", "delicious200k",
        "wiki10", "eurlex-4.3k", "deliciouslarge", "amazoncat-13k", "amazon-13k",
    }

    # ---- dataset_fast format (Bibtex, Mediamill, Delicious) ----
    if name_lc not in sparse_names:
        folder = dataset_name.capitalize()
        data_path = f"../data/{folder}/{folder}_data.txt"
        tr_split = f"../data/{folder}/{name_lc}_trSplit.txt"
        tst_split = f"../data/{folder}/{name_lc}_tstSplit.txt"

        if not os.path.exists(data_path):
            print(f"Skipping {dataset_name}: Data file not found at {data_path}")
            return None, None, None, None

        tr, tst, info = load_dataset(
            data_file=data_path,
            train_file=tr_split,
            test_file=tst_split,
            batch_size=batch_size,
            num_workers=0,
            pin_memory=True
        )
        return tr, tst, info, None

    # ---- sparse format ----
    import dataset_sparse as ds_sparse

    # Directory names do not follow a single convention.
    known_roots = {
        "eurlex": "../data/Eurlex",
        "wiki10": "../data/Wiki10",
        "eurlex-4.3k": "../data/EURLex-4.3K",
        "deliciouslarge": "../data/DeliciousLarge",
        "amazoncat-13k": "../data/AmazonCat-13K",
        "amazon-13k": "../data/AmazonCat-13K",
    }
    root = known_roots.get(name_lc)
    if root is None:
        # Unknown layout: prefer the capitalised folder, else the raw name.
        root = f"../data/{dataset_name.capitalize()}"
        if not os.path.exists(root):
            root = f"../data/{dataset_name}"

    # Some datasets prefix their train/test files with the dataset name.
    file_prefixes = {"eurlex": "eurlex_", "deliciouslarge": "deliciousLarge_"}
    prefix = file_prefixes.get(name_lc, "")
    train_file = f"{root}/{prefix}train.txt"
    test_file = f"{root}/{prefix}test.txt"

    if not os.path.exists(train_file):
        print(f"Skipping {dataset_name}: File not found {train_file}")
        return None, None, None, None

    # Loaded in the main process (num_workers=0).
    tr_loader, tst_loader, info = ds_sparse.load_dataset(
        train_file=train_file,
        test_file=test_file,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=True
    )

    # Propensity scores are cached under <root>/cache; any failure here is
    # downgraded to a warning and the caller computes them itself.
    try:
        cache_dir = f"{root}/cache"
        if not os.path.exists(cache_dir):
            os.makedirs(cache_dir, exist_ok=True)
        inv_prop = ds_sparse.cache_propensity(tr_loader, info, name_lc, root=root)
    except Exception as e:
        print(f"Warning: Failed to cache/load propensity: {e}. Will compute on fly.")
        inv_prop = None

    preloaded_names = ("amazoncat-13k", "amazon-13k", "eurlex-4.3k", "deliciouslarge", "wiki10")
    if name_lc in preloaded_names:
        # These datasets are parsed once and kept on the device as a sparse tensor.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        n_features = info["features"]
        n_labels = info["labels"]

        tr_wrap = PreloadedSparseLoader(
            train_file, n_features, n_labels, batch_size, device,
            desc=f"Loading {dataset_name}",
        )
        # Evaluation batches are capped at 128 samples.
        test_batch_size = min(batch_size, 128)
        tst_wrap = PreloadedSparseLoader(
            test_file, n_features, n_labels, test_batch_size, device,
            desc=f"Loading {dataset_name} (Test)",
        )
        return tr_wrap, tst_wrap, info, inv_prop

    # Remaining sparse datasets: add a dense label matrix to each batch.
    tr_wrap = SparseToDenseWrapper(tr_loader, info["labels"])
    tst_wrap = SparseToDenseWrapper(tst_loader, info["labels"])
    return tr_wrap, tst_wrap, info, inv_prop

# -----------------------------------------------------------------------------
# Evaluation Helper
# -----------------------------------------------------------------------------
def evaluate_model(network, test_loader, device, model_name, inv_prop, debug=False):
    """Run ``network`` over ``test_loader`` and compute top-5 ranking metrics.

    For every batch the top-5 scores of ``network.inference(network(x))`` are
    collected as sparse (row, col, value) triplets, the ground-truth labels
    are collected the same way, and both are assembled into CSR matrices that
    are handed to ``compute_prop_metrics``.

    Args:
        network: Model exposing ``eval()``, ``__call__`` and ``inference``.
        test_loader: Iterable of batches, either ``(x, y_hot, y_idxs)`` with
            ``y_idxs`` padded by ``num_labels`` or ``(x, y_dense)``.
        device: Device inputs are moved to.
        model_name: Not used inside this function.
        inv_prop: Inverse propensity scores forwarded to the metric routine.
        debug: When True, no progress bar is shown and evaluation stops after
            the batch with index 6.

    Returns:
        Whatever ``compute_prop_metrics`` returns (callers in this file index
        it as ``metrics[i][k]`` with i in 0..3 and k in 0..4).
    """
    network.eval()
    
    # Coordinate lists for the ground-truth label matrix
    row_list_true = []
    col_list_true = []
    
    # Coordinate/value lists for the (top-k only) prediction matrix
    row_list_pred = []
    col_list_pred = []
    val_list_pred = []
    
    # `offset` is the global row index of the first sample in the current batch
    offset = 0
    topk_k = 5
    
    import scipy.sparse as sparse
    import numpy as np # Added numpy import

    with torch.no_grad():
        # Progress bar only outside debug mode
        loader_iter = test_loader
        if not debug:
            loader_iter = tqdm(test_loader, desc="Evaluating", leave=False)
            
        for i, data in enumerate(loader_iter):
            if len(data) == 3:
                # Wrapped sparse loaders: y_idxs carries the true label ids;
                # y_hot is not needed for the metrics.
                x_true, y_hot, y_idxs = data
                x_true = x_true.to(device)
            else:
                # Dense loaders: only a dense label matrix is available.
                x_true, y_true = data[0].to(device), data[1].float().to(device)
                pass 

            y_logits = network(x_true)
            
            # Model-specific conversion of raw outputs to per-label scores
            y_pred = network.inference(y_logits) 
            
            # Keep only the top-k scores per sample; the metric call below
            # uses topk=5 as well.
            # NOTE(review): torch.topk raises if the score width is below 5.
            vals, inds = torch.topk(y_pred, k=topk_k)
            
            vals = vals.cpu().numpy().flatten()
            inds = inds.cpu().numpy().flatten()
            
            batch_size = len(vals) // topk_k
            rows = np.repeat(np.arange(offset, offset + batch_size), topk_k)

            # Drop predicted ids that are not real labels (e.g. a padding
            # slot), but only when the loader exposes num_labels.
            if hasattr(test_loader, 'num_labels'):
                 valid_mask = inds < test_loader.num_labels
                 rows = rows[valid_mask]
                 inds = inds[valid_mask]
                 vals = vals[valid_mask]

            row_list_pred.append(rows)
            col_list_pred.append(inds)
            val_list_pred.append(vals)
            
            # Ground truth triplets
            if len(data) == 3:
                # y_idxs: (B, L_max), padded with num_labels.
                # NOTE(review): unlike the prediction filter above, this
                # accesses test_loader.num_labels without a hasattr guard.
                y_idxs = y_idxs.cpu().numpy()
                valid_mask = y_idxs < test_loader.num_labels # Or info["labels"]
                
                # Flatten the padded matrix and keep the non-padding entries;
                # each row index is repeated once per column.
                curr_rows = np.repeat(np.arange(offset, offset + batch_size), y_idxs.shape[1])
                curr_cols = y_idxs.flatten()
                mask = valid_mask.flatten()
                
                row_list_true.append(curr_rows[mask])
                col_list_true.append(curr_cols[mask])
                
            else:
                # Dense labels: take the non-zero coordinates and shift the
                # rows to global sample indices.
                y_true = y_true.cpu().numpy()
                y_csr = sparse.csr_matrix(y_true)
                y_coo = y_csr.tocoo()
                row_list_true.append(y_coo.row + offset)
                col_list_true.append(y_coo.col)
            
            offset += batch_size
            
            if debug and i > 5: 
                break

    # Assemble the final CSR matrices.
    # NOTE(review): np.concatenate raises ValueError if the loader was empty.
    true_rows = np.concatenate(row_list_true)
    true_cols = np.concatenate(col_list_true)
    true_vals = np.ones_like(true_rows, dtype=np.float32)
    
    num_samples = offset
    # Column count: loader attribute, then dataset attribute, then fallback.
    if hasattr(test_loader, 'num_labels'):
        num_labels = test_loader.num_labels
    elif hasattr(test_loader, 'dataset') and hasattr(test_loader.dataset, 'num_labels'):
        num_labels = test_loader.dataset.num_labels
    else:
        # NOTE(review): the fallback never goes below 1,000,000 columns
        # because max() starts from that value; confirm this width is
        # compatible with the length of inv_prop.
        num_labels = 1000000 # Fallback safe upper bound? Or infer max col?
        if len(col_list_true) > 0:
            num_labels = max(num_labels, true_cols.max() + 1)
        if len(col_list_pred) > 0:
            num_labels = max(num_labels, np.concatenate(col_list_pred).max() + 1)

    y_true_csr = sparse.csr_matrix((true_vals, (true_rows, true_cols)), shape=(num_samples, num_labels))
    
    pred_rows = np.concatenate(row_list_pred)
    pred_cols = np.concatenate(col_list_pred)
    pred_vals = np.concatenate(val_list_pred)
    
    y_pred_csr = sparse.csr_matrix((pred_vals, (pred_rows, pred_cols)), shape=(num_samples, num_labels))

    metrics = compute_prop_metrics(
        y_true_csr,
        y_pred_csr,
        inv_prop_scores=inv_prop,
        topk=5
    )
    return metrics

# -----------------------------------------------------------------------------
# Runner
# -----------------------------------------------------------------------------

def train_model(model_name, dataset_name, trial_id, device, debug=False, batch_size_override=None):
    """Train one model on one dataset for a single trial and evaluate it.

    Args:
        model_name: Key of MODELS ("hlb", "vtb", "map" or "osc").
        dataset_name: Dataset identifier; "amazon-13k" is treated as
            "amazoncat-13k".
        trial_id: Trial number, used only in log output.
        device: torch device the network and batches are moved to.
        debug: When True, each epoch stops after 11 batches and the
            evaluation is shortened (see evaluate_model).
        batch_size_override: If truthy, used instead of the configured
            batch size.

    Returns:
        The metrics produced by evaluate_model(), or None when the dataset
        files could not be found.
    """
    print(f"\n--- Running {model_name.upper()} on {dataset_name} (Trial {trial_id}) ---")

    if dataset_name.lower() == "amazon-13k":
        dataset_name = "amazoncat-13k"

    ds_config = get_dataset_config(dataset_name)
    if batch_size_override:
        batch_size = batch_size_override
        print(f"DEBUG: Using overridden batch size: {batch_size}")
    else:
        batch_size = ds_config["batch_size"]

    # Datasets that use the specialised sparse networks and early stopping.
    large_datasets = ("wiki10", "eurlex-4.3k", "deliciouslarge", "amazoncat-13k", "amazon-13k")
    is_large = dataset_name.lower() in large_datasets

    # Single source of truth for the epoch count (previously hard-coded to
    # the same value in two identical branches).
    epochs = TRAIN_PARAMS["epochs"]

    # 1. Load data
    train_loader, test_loader, info, cached_inv_prop = load_dataset_proxy(dataset_name, batch_size)
    if train_loader is None:
        return None

    ds_size = info.get('data_size', '?')
    print(f"Dataset size: {ds_size}, No. of features: {info['features']}, No. of labels: {info['labels']}")

    # Inverse propensity scores: use the cache if available, otherwise
    # compute them from one full pass over the training labels.
    if cached_inv_prop is not None:
        inv_prop = cached_inv_prop
    else:
        print("Computing propensity scores (uncached)...")
        # Element 1 of every batch is the dense label matrix for both batch
        # layouts. Move each batch to the CPU first: some loaders yield CUDA
        # tensors, which cannot be converted to NumPy directly.
        train_labels_list = [data[1].cpu() for data in train_loader]
        train_labels = torch.cat(train_labels_list).numpy()
        inv_prop = compute_inv_propensity(train_labels, A=0.55, B=1.5)

    # 2. Pick and import the network module
    sparse_modules = {
        "hlb": "network_sparse_unified",
        "vtb": "network_spvtb",
        "map": "network_spmap",
    }
    if is_large:
        # Models without a sparse variant (OSC) use their regular module.
        module_name = sparse_modules.get(model_name, MODELS[model_name])
    elif model_name == "vtb" and dataset_name.lower() == "eurlex":
        # Eurlex is not in the large list but still uses the sparse VTB network.
        module_name = "network_spvtb"
    elif model_name == "hlb":
        module_name = "network"
    else:
        module_name = MODELS[model_name]

    module = importlib.import_module(module_name)
    importlib.reload(module)

    params = OSC_PARAMS if model_name == "osc" else COMMON_PARAMS

    requires_grad = ds_config["requires_grad"]
    negative = ds_config["negative"]
    drop_rate = ds_config["drop_rate"]
    factor = ds_config["factor"]
    out_features = ds_config["out_features"]

    if model_name == "vtb" and "vtb_requires_grad" in ds_config:
        # Dataset-specific learning mode for VTB (e.g. Eurlex sets False).
        requires_grad = ds_config["vtb_requires_grad"]

    num_role_vecs = params.get("num_role_vecs", None)

    if model_name == "osc":
        # Global OSC defaults first, then dataset-specific "osc_*" overrides.
        osc_d = ds_config.get("osc_d", params.get("OSC_d", 125))
        out_features = osc_d ** 2
        drop_rate = params["drop_rate"]
        if "osc_out_features" in ds_config:
            out_features = ds_config["osc_out_features"]
        if "osc_drop_rate" in ds_config:
            drop_rate = ds_config["osc_drop_rate"]
        if "osc_num_role_vecs" in ds_config:
            num_role_vecs = ds_config["osc_num_role_vecs"]

    hidden = ds_config.get("hidden", params["hidden"])

    init_params = {
        "in_features": info["features"],
        "labels": info["labels"],
        "hidden": hidden,
        "out_features": out_features,
        "drop_rate": drop_rate,
        "factor": factor
    }

    if model_name in ("hlb", "map", "vtb"):
        init_params["requires_grad"] = requires_grad
        init_params["negative"] = negative

    if "reduce_dim" in ds_config:
        init_params["reduce_dim"] = ds_config["reduce_dim"]
        init_params["kernel_dim"] = ds_config.get("kernel_dim", 7)

    if model_name == "vtb":
        init_params["batch_size"] = batch_size

    if model_name == "osc":
        init_params["num_role_vecs"] = num_role_vecs

    # Only the sparse HLB network and OSC receive `device` in __init__;
    # the legacy HLB, MAP and VTB constructors are called without it.
    if model_name == "osc" or (model_name == "hlb" and is_large):
        init_params["device"] = device

    network = module.Network(**init_params)
    network.to(device)

    optimizer = torch.optim.Adam(network.parameters(), lr=TRAIN_PARAMS["lr"])
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=TRAIN_PARAMS["gamma"])

    if model_name == "osc":
        # Imported lazily: only the OSC loss needs it.
        from hlb_utils import cosine_similarity

    # 3. Train loop
    min_improvement = 5e-4  # early-stopping threshold on the mean epoch loss
    prev_loss = None
    for epoch in range(1, epochs + 1):
        network.train()
        train_loss = []
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [Train]", leave=False)
        for data in progress_bar:
            if len(data) == 3:
                # Wrapped sparse loaders return (x, y_hot, y_idxs).
                x_true, y_hot, y_idxs = data
                x_true = x_true.to(device)

                # MAP and VTB are trained on label indices, the rest on the
                # dense multi-hot matrix.
                if model_name in ("map", "vtb"):
                    y_true = y_idxs.long().to(device)
                else:
                    y_true = y_hot.to(device)
            else:
                # dataset_fast loaders return (x, y_hot).
                x_true, y_true = data[0].to(device), data[1].float().to(device)

            if debug and len(train_loss) > 10:
                print("DEBUG: Breaking training loop early")
                break

            optimizer.zero_grad()
            y_logits = network(x_true)

            if model_name == "osc":
                # Cosine distance between the output and the combination of
                # the bound label vectors of the positive labels.
                target_vecs = torch.mm(y_true.float(), network.bound_labels)
                sim = cosine_similarity(y_logits, target_vecs, dim=-1)
                loss = torch.mean(1.0 - sim)
            else:
                loss = network.loss(y_logits, y_true)

            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}', 'avg_loss': f'{np.mean(train_loss):.4f}'})
        scheduler.step()

        if is_large:
            # Stop once the mean epoch loss no longer improves enough.
            current_loss = np.mean(train_loss)
            if prev_loss is not None and (prev_loss - current_loss) < min_improvement:
                print(f"Early stopping: Loss improvement {prev_loss - current_loss:.6f} < {min_improvement:g}")
                break
            prev_loss = current_loss

    # 4. Final evaluation
    metrics = evaluate_model(network, test_loader, device, model_name, inv_prop, debug=debug)

    # Full metric table for this trial (values in percent).
    metric_names = ["Precision@k", "nDCG@k", "PSprec@k", "PSnDCG@k"]
    header = f"{'Metric':<15}" + "".join([f"{k:<10}" for k in range(1, 6)])
    print(header)
    print("-" * len(header))

    for i, m_name in enumerate(metric_names):
        row_str = f"{m_name:<15}"
        for k in range(5):
            val = metrics[i][k] * 100
            row_str += f"{val:<10.2f}"
        print(row_str)
    print("")

    # One-line summary (P@1 and nDCG@5) for progress logging.
    print(f"Trial {trial_id} Result: P@1={metrics[0][0]*100:.2f}, nDCG@5={metrics[1][4]*100:.2f}")
    return metrics

def _write_trial_table(fh, model, trial, res):
    """Append the metric table of a single trial (values in percent) to ``fh``."""
    metric_names = ["Precision@k", "nDCG@k", "PSprec@k", "PSnDCG@k"]
    fh.write(f"\nModel: {model} | Trial: {trial}\n")

    header = f"{'Metric':<15}" + "".join([f"{k:<10}" for k in range(1, 6)]) + "\n"
    fh.write(header)
    fh.write("-" * (len(header) - 1) + "\n")

    for i, m_name in enumerate(metric_names):
        row_str = f"{m_name:<15}"
        for k in range(5):
            val = res[i][k] * 100
            row_str += f"{val:<10.2f}"
        fh.write(row_str + "\n")
    fh.write("-" * 50 + "\n")


def _write_aggregate_table(fh, model_name, trial_results):
    """Append the mean±std table over all successful trials of one model to ``fh``."""
    metric_names = ["Precision@k", "nDCG@k", "PSprec@k", "PSnDCG@k"]
    fh.write(f"--- Model: {model_name} ---\n")

    # Each trial result is indexed as [metric][k]; stack to (trials, 4, 5).
    stack = np.array(trial_results) * 100
    mean = np.mean(stack, axis=0)
    std = np.std(stack, axis=0)

    header = f"{'Metric':<15}" + "".join([f"{k:<16}" for k in range(1, 6)]) + "\n"
    fh.write(header)
    fh.write("-" * (len(header) - 1) + "\n")

    for i, m_name in enumerate(metric_names):
        row_str = f"{m_name:<15}"
        for k in range(5):
            val_str = f"{mean[i, k]:5.2f}±{std[i, k]:<5.2f}"
            row_str += f"{val_str:<16}"
        fh.write(row_str + "\n")
    fh.write("\n")


def main():
    """Command-line entry point: run every requested model on every dataset.

    For each dataset a results file ``XMLResults_<dataset>_<trials>_<models>.txt``
    is (re)created; each finished trial is appended immediately and an
    aggregated mean±std table per model is appended at the end.

    A trial that fails with a CUDA out-of-memory error is retried with half
    the batch size until the batch size reaches 1; any other error skips the
    trial.
    """
    import gc
    import traceback

    parser = argparse.ArgumentParser()
    parser.add_argument("--trials", type=int, default=3, help="Number of trials per experiment")
    parser.add_argument("--datasets", nargs="+", default=DATASETS, help="Datasets to run")
    parser.add_argument("--models", nargs="+", default=list(MODELS.keys()), choices=list(MODELS.keys()), help="Models to run (hlb, vtb, map, osc)")
    parser.add_argument("--debug", action="store_true", help="Run in fast debug mode (limited batches)")
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Using Device: {device}")

    # Results structure: {dataset: {model: [metrics_trial1, metrics_trial2, ...]}}
    results_db = collections.defaultdict(lambda: collections.defaultdict(list))

    for dataset in args.datasets:
        models_str = "-".join(args.models)
        output_filename = f"XMLResults_{dataset}_{args.trials}_{models_str}.txt"

        # Start a fresh results file for this dataset.
        with open(output_filename, 'w') as f:
            f.write(f"Experiment Results for {dataset}\n")
            f.write("="*100 + "\n\n")

        for model in args.models:
            print(f"\nStarting {model.upper()} on {dataset}...")
            start_time = time.time()
            for t in range(1, args.trials + 1):
                # Every trial starts again from the configured batch size.
                current_bs = get_dataset_config(dataset)["batch_size"]

                while True:
                    try:
                        res = train_model(model, dataset, t, device, debug=args.debug, batch_size_override=current_bs)
                        if res is not None:
                            results_db[dataset][model].append(res)
                            # Log the trial immediately so partial runs keep their results.
                            with open(output_filename, 'a') as f:
                                _write_trial_table(f, model, t, res)
                        break
                    except Exception as e:
                        print(f"Error running {model} on {dataset} trial {t}: {e}")

                        # Only out-of-memory errors are worth retrying.
                        is_oom = "CUDA out of memory" in str(e)

                        if is_oom and current_bs > 1:
                            new_bs = current_bs // 2
                            print(f"⚠️ Reducing batch size from {current_bs} to {new_bs} and retrying...")
                            current_bs = new_bs

                            # Release cached GPU memory before the retry.
                            torch.cuda.empty_cache()
                            gc.collect()
                        else:
                            traceback.print_exc()
                            print(f"❌ Critical Error or Batch size too small ({current_bs}). Skipping trial.")
                            break
            dur = time.time() - start_time
            print(f"Finished {model.upper()} in {dur:.1f}s")

        # --- End of dataset: append aggregated results ---
        print(f"\nAggregating results for {dataset}...")

        with open(output_filename, 'a') as f:
            f.write("\n" + "="*100 + "\n")
            f.write(f"{'FINAL AGGREGATED RESULTS (N=' + str(args.trials) + ')':^100}\n")
            f.write("="*100 + "\n\n")

            models_res = results_db[dataset]

            # Report in the order models are defined in MODELS (hlb, vtb, map, osc).
            for model_name in MODELS.keys():
                if model_name not in args.models:
                    continue

                trial_results = models_res.get(model_name, [])
                if not trial_results:
                    f.write(f"--- Model: {model_name} (FAILED) ---\n")
                    continue

                _write_aggregate_table(f, model_name, trial_results)

        print(f"✓ Results updated in: {output_filename}")


if __name__ == "__main__":
    main()
