
import yaml
import time
import torch
import collections
import sys
import os
import importlib
import math
from scipy import sparse
from dataset_fast import load_dataset
# from hlb_utils import evaluate  # evaluate uses threshold 0.5, too harsh for VSA
# Metrics: the goal is a plain DCG/nDCG calculation on the input datasets.
# The repo's existing metrics utilities are used where possible; they depend on
# xclib, which was missing in earlier runs, so a self-contained nDCG
# implementation is embedded below as a fallback to ensure evaluation still works.

def compute_ndcg_at_k_batch(scores, true_labels_sparse, k=5):
    """
    Compute the mean nDCG@k for a batch of predictions (binary relevance).

    scores: (Batch, NumLabels) torch tensor; any array-like accepted by
        ``np.asarray`` also works.
    true_labels_sparse: (Batch, NumLabels) scipy csr_matrix or similar. The
        ``.indices`` of row ``i`` are taken as the relevant labels of sample ``i``.
    k: ranking cutoff.

    Returns the sum of per-row nDCG@k divided by the batch size, as a float.
    Rows without any true label add nothing to the sum but are still counted
    in the denominator. An empty batch, an empty label space, or ``k <= 0``
    yields 0.0.
    """
    # Torch tensors must be detached and moved to host memory before numpy can
    # see them; everything else is handed to numpy directly.
    if hasattr(scores, "detach"):
        scores_np = scores.detach().cpu().numpy()
    else:
        scores_np = np.asarray(scores)

    batch_size = scores_np.shape[0]
    if batch_size == 0 or k <= 0:
        # Nothing to rank; also avoids dividing by a zero batch size.
        return 0.0

    num_labels = scores_np.shape[1]
    depth = min(k, num_labels)  # number of ranks that are actually scored
    if depth == 0:
        return 0.0

    # Positional discounts 1/log2(rank + 2) are the same for every row, and
    # the ideal DCG for n relevant labels is the sum of the first n discounts.
    discounts = 1.0 / np.log2(np.arange(depth) + 2.0)
    ideal_dcg = np.cumsum(discounts)

    ndcg_sum = 0.0
    for i in range(batch_size):
        true_indices = true_labels_sparse[i].indices
        if len(true_indices) == 0:
            continue

        row_scores = scores_np[i]
        if k < num_labels:
            # Partitioning is cheaper than a full sort for large label spaces,
            # but it leaves the top k unordered, so sort just those k.
            top_k_idx = np.argpartition(row_scores, -k)[-k:]
            top_k_idx = top_k_idx[np.argsort(row_scores[top_k_idx])[::-1]]
        else:
            top_k_idx = np.argsort(row_scores)[::-1]

        # Vectorized membership test instead of a linear scan per ranked item.
        hits = np.isin(top_k_idx, true_indices)
        dcg = float(discounts[hits].sum())
        idcg = float(ideal_dcg[min(depth, len(true_indices)) - 1])
        ndcg_sum += dcg / idcg

    return ndcg_sum / batch_size

def compute_precision_at_1(y_true, y_pred):
    """
    Precision@1 (top-1 accuracy) as a percentage.

    y_true: (Batch, NumLabels) label tensor; the value found at the predicted
        index is what gets averaged (presumably 0/1 multi-hot labels -- confirm
        against the dataset loader).
    y_pred: (Batch, NumLabels) score tensor.

    Returns a 0-dim tensor: mean of the looked-up label values times 100.
    """
    # Highest-scoring label per sample, shaped (B, 1) so it can index dim 1.
    best_label = y_pred.argmax(dim=1)[:, None]

    # Pick y_true[b, best_label[b]] for every sample b.
    picked = torch.gather(y_true, 1, best_label)

    return picked.float().mean() * 100.0

# numpy is imported here rather than at the top of the file. The functions above
# only look up `np` when they are called, so this late import is sufficient.
import numpy as np

def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when it is empty."""
    return sum(values) / len(values) if values else float("nan")


def _batch_loss(network, y_logits, y_true, similarity_fn=None):
    """Compute the loss of one batch.

    With ``similarity_fn`` set, the target is ``y_true @ network.bound_labels``
    and the loss is ``mean(1 - similarity_fn(logits, target, dim=-1))``.
    Otherwise the network's own ``loss(logits, y_true)`` is used.
    """
    if similarity_fn is None:
        return network.loss(y_logits, y_true)
    target_vecs = torch.mm(y_true.float(), network.bound_labels)
    sim = similarity_fn(y_logits, target_vecs, dim=-1)
    return torch.mean(1.0 - sim)


def run_experiment(exp_config, datasets_config, networks_config):
    """Train one network on one dataset, then report ranking metrics.

    Args:
        exp_config: mapping with 'dataset' and 'network' names, plus optional
            'batch_size' (default 64), 'epochs' (default 5) and 'lr'
            (default 1e-3).
        datasets_config: iterable of mappings with 'name', 'data_file',
            'train_file' and 'test_file'. Relative paths are resolved against
            this script's directory.
        networks_config: iterable of mappings with 'name', 'module', 'class'
            and optional 'params'.

    Returns:
        None. Progress and metrics are printed. The function returns early
        (after printing a message) when a config entry is missing, the network
        class cannot be imported, or the constructor raises TypeError.

    Raises:
        RuntimeError: if CUDA is not available.
    """
    # Lookup Configs
    dataset_name = exp_config['dataset']
    network_name = exp_config['network']

    d_conf = next((item for item in datasets_config if item['name'] == dataset_name), None)
    n_conf = next((item for item in networks_config if item['name'] == network_name), None)

    if not d_conf or not n_conf:
        print(f"Error: Could not find config for dataset '{dataset_name}' or network '{network_name}'")
        return

    print(f"\n{'='*40}")
    print(f"Running Experiment: {dataset_name} + {network_name}")
    print(f"{'='*40}")

    # 1. Load Data
    # Relative paths in the config are resolved against this script's directory.
    script_dir = os.path.dirname(os.path.abspath(__file__))

    def resolve_path(path):
        if not os.path.isabs(path):
            return os.path.join(script_dir, path)
        return path

    data_path = resolve_path(d_conf['data_file'])
    train_path = resolve_path(d_conf['train_file'])
    test_path = resolve_path(d_conf['test_file'])

    print(f"Loading Dataset: {data_path}")
    train_loader, test_loader, info = load_dataset(
        data_file=data_path,
        train_file=train_path,
        test_file=test_path,
        batch_size=exp_config.get('batch_size', 64),
        num_workers=0
    )

    # 2. Load Network Class
    print(f"Loading Network Module: {n_conf['module']}")
    try:
        module = importlib.import_module(n_conf['module'])
        NetworkClass = getattr(module, n_conf['class'])
    except Exception as e:
        print(f"Failed to import network class: {e}")
        return

    # 3. Instantiate Network
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. This script requires a GPU.")

    device = torch.device("cuda:0")
    print(f"Device: {device}")

    # The 'network_clifford' module gets a different constructor signature and
    # a different loss than the other networks.
    is_clifford = n_conf['module'] == 'network_clifford'

    # `or {}` also covers a config that spells `params:` with no value (None).
    params = n_conf.get('params') or {}
    init_args = {
        # Sizes come from the loaded dataset, not from the config.
        'in_features': info['features'],
        'hidden': params.get('hidden', 512),
        'out_features': params.get('out_features', 400),
        'labels': info['labels'],
        'drop_rate': params.get('drop_rate', 0.0)
    }
    if is_clifford:
        init_args['device'] = device
        if 'num_role_vecs' in params:
            init_args['num_role_vecs'] = params['num_role_vecs']

    try:
        network = NetworkClass(**init_args)
    except TypeError as e:
        # No retry is attempted: report the mismatch and skip this experiment.
        print(f"Init Error (no fallback available, skipping experiment): {e}")
        return

    network.to(device)
    print("Network Initialized.")

    # Import the similarity function once, up front, so both the training and
    # the test loop can use it without depending on which loop ran first.
    similarity_fn = None
    if is_clifford:
        from hlb_utils import cosine_similarity
        similarity_fn = cosine_similarity

    # 4. Train Loop
    epochs = exp_config.get('epochs', 5)
    lr = exp_config.get('lr', 1e-3)

    optimizer = torch.optim.Adam(network.parameters(), lr=lr)
    # Repo uses ExponentialLR with gamma=0.98
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

    print("Starting Training...")

    # Per-epoch history; per-batch values are kept in local lists below.
    logger = collections.defaultdict(list)

    for epoch in range(1, epochs + 1):
        tic = time.time()

        # --- training pass ---
        batch_losses, batch_accs = [], []
        network.train()

        for data in train_loader:
            x_true, y_true = data[0].to(device), data[1].to(device)

            optimizer.zero_grad()
            y_logits = network(x_true)
            loss = _batch_loss(network, y_logits, y_true, similarity_fn)
            loss.backward()
            optimizer.step()

            # Batch metric: precision@1 on the scores produced by inference().
            y_pred = network.inference(y_logits)
            acc = compute_precision_at_1(y_true, y_pred)

            batch_losses.append(loss.item())
            batch_accs.append(acc.item())

        # NaN (instead of a ZeroDivisionError) when the loader yielded nothing.
        logger["train_loss"].append(_mean_or_nan(batch_losses))
        logger["train_acc"].append(_mean_or_nan(batch_accs))

        # --- test pass ---
        batch_losses, batch_accs = [], []
        network.eval()

        with torch.no_grad():
            for data in test_loader:
                x_true, y_true = data[0].to(device), data[1].to(device)
                y_logits = network(x_true)
                loss = _batch_loss(network, y_logits, y_true, similarity_fn)

                y_pred = network.inference(y_logits)
                acc = compute_precision_at_1(y_true, y_pred)

                batch_losses.append(loss.item())
                batch_accs.append(acc.item())

        logger["test_loss"].append(_mean_or_nan(batch_losses))
        logger["test_acc"].append(_mean_or_nan(batch_accs))

        # Step Scheduler
        scheduler.step()

        toc = time.time()

        # Format: Epoch: [  1/25], train loss: 0.6801, train acc:  8.79%, test loss: 0.XXXX, test acc: XX.XX%, etc:  0.44s
        log_str = (f"Epoch: [{epoch:>3d}/{epochs}], "
                   f"train loss: {logger['train_loss'][-1]:>6.4f}, "
                   f"train acc: {logger['train_acc'][-1]:>5.2f}%, "
                   f"test loss: {logger['test_loss'][-1]:>6.4f}, "
                   f"test acc: {logger['test_acc'][-1]:>5.2f}%, "
                   f"etc: {toc-tic:>5.2f}s")
        print(log_str)

    # 5. Evaluation (Repo Metrics or Fallback)
    print("Evaluating...")

    use_repo_metrics = False
    try:
        from metrics import compute_inv_propensity, compute_prop_metrics, display_metrics
        use_repo_metrics = True
        print("Using Repo Metrics (pyxclib).")
    except ImportError:
        print("Repo Metrics (pyxclib) not module found. Using fallback nDCG calculation.")

    if use_repo_metrics:
        # Calculate Inverse Propensity Scores on TRAIN set
        print("Computing Propensity Scores from Train Set...")
        train_labels = train_loader.dataset.labels.numpy()
        inv_prop = compute_inv_propensity(train_labels, A=0.55, B=1.5)

        network.eval()
        logger_metrics = []

        with torch.no_grad():
            for data in test_loader:
                # Labels stay on the CPU here: they are only converted to a
                # scipy sparse matrix for the metrics code.
                x_true, y_true = data[0].to(device), data[1]
                y_logits = network(x_true)

                y_pred = network.inference(y_logits)

                y_true_np = sparse.csr_matrix(y_true.numpy())
                y_pred_np = sparse.csr_matrix(y_pred.cpu().numpy())

                batch_metrics = compute_prop_metrics(y_true_np, y_pred_np, inv_prop_scores=inv_prop, topk=5)
                logger_metrics.append(batch_metrics)

        display_metrics(logger_metrics)

    else:
        # Fallback: Standard nDCG over the whole test set at once.
        network.eval()
        all_scores = []
        all_true = []

        with torch.no_grad():
            for data in test_loader:
                x_true, y_true = data[0].to(device), data[1]
                y_logits = network(x_true)

                scores = network.inference(y_logits)

                all_scores.append(scores.cpu())
                all_true.append(sparse.csr_matrix(y_true.numpy()))

        full_scores = torch.cat(all_scores, dim=0)
        full_true = sparse.vstack(all_true)

        final_ndcg = compute_ndcg_at_k_batch(full_scores, full_true, k=5)
        print(f"Final nDCG@5 (Standard): {final_ndcg:.4f}")

if __name__ == "__main__":
    # Everything is located relative to this script, not the working directory.
    here = os.path.dirname(os.path.abspath(__file__))

    # Network modules are imported by name, so the script's directory must be
    # importable even when the script is launched from somewhere else.
    if here not in sys.path:
        sys.path.append(here)

    with open(os.path.join(here, "config_ndcg.yaml"), "r") as cfg_file:
        settings = yaml.safe_load(cfg_file)

    # Run every configured experiment in order against the shared dataset and
    # network definitions.
    for experiment in settings['experiments']:
        run_experiment(experiment, settings['datasets'], settings['networks'])
