import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
import random
from tqdm import tqdm
import logging
import wandb
from datetime import datetime
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from transformers import get_linear_schedule_with_warmup
from model.fsl_model import FSLModel
from utils.data_utils import simple_data_collator, Example
from data.uci_yaml import UCIDataset
from collections import OrderedDict

# For memory limiting
import resource
import psutil
import threading
import time

# ------------------------
# SETUP: CUDA, DDP, LOGGING
# ------------------------
import warnings
import os

# Process-wide configuration applied at import time, before any model code runs.
#
# HF_HOME is set to the value of TRANSFORMERS_CACHE when that variable exists,
# otherwise to the hard-coded cache directory below (the default contains no
# "~", so expanduser leaves it unchanged). The deprecation warning about
# TRANSFORMERS_CACHE is filtered out.
os.environ["HF_HOME"] = os.environ.get("TRANSFORMERS_CACHE", os.path.expanduser("/playpen-nvme/scribble/shbhat/.cache/huggingface"))
warnings.filterwarnings("ignore", message="Using `TRANSFORMERS_CACHE` is deprecated")
if torch.cuda.is_available():
    # GPU-only tuning knobs: float32 matmul precision mode and cuDNN autotuning.
    torch.set_float32_matmul_precision('high')
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
# Environment switches read by the tokenizers library and by NCCL.
# NOTE(review): presumably these must be set before those libraries
# initialise; confirm if this block is ever moved.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "1"
os.environ["NCCL_P2P_DISABLE"] = "0"

# Root logging configuration (INFO level, timestamped) and the module logger
# used by every function in this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO
)
logger = logging.getLogger(__name__)

def watch_memory(max_bytes, interval=1.0):
    """Poll this process's resident memory and hard-exit once it exceeds a limit.

    Intended to run forever in a background thread. The function never
    returns normally: it either keeps sleeping or terminates the whole process
    via ``os._exit(1)``.

    Args:
        max_bytes: RSS threshold in bytes; exceeding it kills the process.
        interval: Seconds to sleep between two consecutive checks.
    """
    proc = psutil.Process(os.getpid())
    used = proc.memory_info().rss
    # Keep sampling until the resident set size crosses the threshold.
    while used <= max_bytes:
        time.sleep(interval)
        used = proc.memory_info().rss
    logger.error(f"Out of RAM! using {used/1e9:.1f} GB > limit {max_bytes/1e9:.1f} GB")
    # os._exit skips interpreter cleanup, so the process stops immediately.
    os._exit(1)

# Limit the process address space (RLIMIT_AS) to 100 GiB and start a watchdog
# thread that kills the process once its resident memory exceeds the same value.
MAX_RAM = 100 * 1024**3
soft, hard = resource.getrlimit(resource.RLIMIT_AS)
# setrlimit raises ValueError when the requested soft limit is above a finite
# hard limit (e.g. inside a constrained container), which would abort the
# import. Clamp the soft limit to the hard limit in that case.
_soft_limit = MAX_RAM if hard == resource.RLIM_INFINITY else min(MAX_RAM, hard)
resource.setrlimit(resource.RLIMIT_AS, (_soft_limit, hard))
threading.Thread(target=watch_memory, args=(MAX_RAM,), daemon=True).start()

# Smart Cache Implementation
class SmartCache:
    """LRU cache bounded by both an entry count and an estimated memory budget.

    Tensors are charged their storage size (``element_size * nelement``) in
    GiB; any other value is charged a flat 0.001 GiB (roughly 1 MB). The size
    charged for each entry is remembered, so evicting an entry releases
    exactly what was charged for it.

    Attributes:
        cache: ``OrderedDict`` of key -> value, least recently used first.
        max_size: Maximum number of entries kept.
        max_memory_gb: Memory budget in GiB for the estimated sizes.
        memory_usage: Sum of the estimated sizes (GiB) of the current entries.
    """

    # Flat size estimate (in GiB) for values that are not tensors.
    _NON_TENSOR_ESTIMATE_GB = 0.001

    def __init__(self, max_size=10000, max_memory_gb=5.0):
        self.cache = OrderedDict()
        self.max_size = max_size
        self.max_memory_gb = max_memory_gb
        self.memory_usage = 0
        # key -> size (GiB) charged when the entry was inserted.
        self._sizes = {}

    def _estimate_gb(self, value):
        """Return the estimated size of ``value`` in GiB."""
        if isinstance(value, torch.Tensor):
            return value.element_size() * value.nelement() / 1024**3
        return self._NON_TENSOR_ESTIMATE_GB

    def _evict_oldest(self):
        """Drop the least recently used entry and release its charged size."""
        key, _ = self.cache.popitem(last=False)
        self.memory_usage -= self._sizes.pop(key, 0.0)
        if not self.cache:
            # Reset exactly to zero so float rounding cannot accumulate.
            self.memory_usage = 0

    def get(self, key):
        """Return the cached value for ``key`` (marking it recently used), or None."""
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def put(self, key, value):
        """Insert ``value`` under ``key``, evicting old entries to stay in budget.

        If ``key`` is already cached, the stored value is kept as is and only
        its recency is refreshed. A value larger than the whole memory budget
        is still stored, after every other entry has been evicted.
        """
        if key in self.cache:
            self.cache.move_to_end(key)
            return

        size_gb = self._estimate_gb(value)

        # Evict oldest entries until the new one fits (or nothing is left).
        while self.cache and self.memory_usage + size_gb > self.max_memory_gb:
            self._evict_oldest()

        self.cache[key] = value
        self._sizes[key] = size_gb
        self.memory_usage += size_gb

        # Enforce the entry-count limit, releasing memory for what is dropped.
        while len(self.cache) > self.max_size:
            self._evict_oldest()

    def clear(self):
        """Remove every entry and reset the memory accounting."""
        self.cache.clear()
        self._sizes.clear()
        self.memory_usage = 0
class LossSmoothing:
    """Exponential moving average of a stream of loss values.

    ``alpha`` is the weight kept from the previous average; the newest loss
    contributes ``1 - alpha``. The first value seeds the average unchanged.
    """

    def __init__(self, alpha=0.99):
        self.smoothed_loss = None
        self.alpha = alpha

    def update(self, loss):
        """Fold ``loss`` into the running average and return the new average."""
        previous = self.smoothed_loss
        if previous is not None:
            loss = self.alpha * previous + (1 - self.alpha) * loss
        self.smoothed_loss = loss
        return self.smoothed_loss
# Dataset wrapper for DataLoader
class CollatedDataset(Dataset):
    """Pass-through view over another dataset.

    No collation happens here: items are returned exactly as the wrapped
    dataset produces them, and batching is left to the DataLoader.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __getitem__(self, idx):
        # Hand back the wrapped dataset's item untouched.
        wrapped = self.dataset
        return wrapped[idx]

    def __len__(self):
        wrapped = self.dataset
        return len(wrapped)

def custom_collate_fn(batch):
    """Collate ``batch`` by delegating to ``simple_data_collator``.

    Exists as a module-level function so it can be passed as a DataLoader
    ``collate_fn``.
    """
    collated = simple_data_collator(batch)
    return collated
class RandomSamplingDataset(Dataset):
    """Fixed-length epoch view that cycles over several datasets.

    Despite the name, item selection is deterministic: consecutive runs of
    ``examples_per_dataset`` indices map to the same source dataset, and the
    source datasets are visited round-robin. Within a source dataset the row
    is ``idx % len(dataset)``. Any randomness therefore comes only from the
    order in which a sampler supplies indices.

    Args:
        datasets: Source datasets; empty ones are ignored.
        examples_per_epoch: Upper bound on the reported length.
        seed: Seed for ``self.rng`` (kept for compatibility; not used by
            ``__getitem__``).
        batch_size: Stored on the instance; not used by ``__getitem__``.
        examples_per_dataset: Number of consecutive indices served from one
            source dataset before moving to the next (default 500).

    Raises:
        ValueError: If ``examples_per_dataset`` is smaller than 1.
    """
    def __init__(self, datasets, examples_per_epoch, seed=42, batch_size=32,
                 examples_per_dataset=500):
        if examples_per_dataset < 1:
            raise ValueError(
                f"examples_per_dataset must be >= 1, got {examples_per_dataset}"
            )

        # Don't wrap datasets in CollatedDataset here - do it lazily
        self.dataset_configs = [(ds, len(ds)) for ds in datasets if len(ds) > 0]
        self.examples_per_epoch = examples_per_epoch
        self.examples_per_dataset = examples_per_dataset
        self.total_examples = sum(size for _, size in self.dataset_configs)
        self.rng = random.Random(seed)
        self.batch_size = batch_size
        # Pre-calculate dataset weights based on size
        self.weights = [size / self.total_examples for _, size in self.dataset_configs]

        # Lazy initialization
        self._datasets = None

        if is_main_process():
            logger.info(f"RandomSamplingDataset: {len(self.dataset_configs)} datasets, "
                       f"{self.total_examples} total examples, "
                       f"sampling {self.examples_per_epoch} per epoch")

    def _init_datasets(self):
        """Wrap the source datasets on first use."""
        if self._datasets is None:
            self._datasets = [CollatedDataset(ds) for ds, _ in self.dataset_configs]

    def __len__(self):
        return min(self.examples_per_epoch, self.total_examples)

    def __getitem__(self, idx):
        """Return the example that index ``idx`` maps to.

        Raises:
            IndexError: If there is no non-empty source dataset to draw from.
        """
        self._init_datasets()
        if not self._datasets:
            raise IndexError("RandomSamplingDataset has no non-empty source datasets")

        # Stay on one dataset for a whole run of indices, then move on to the
        # next one, wrapping around after the last dataset.
        run_idx = idx // self.examples_per_dataset
        dataset = self._datasets[run_idx % len(self._datasets)]

        # Cycle through the dataset examples
        return dataset[idx % len(dataset)]
import torch
import gc
import psutil
import os

def diagnose_memory():
    """Print a snapshot of host RAM, per-GPU memory and live tensor statistics.

    Output goes to stdout. The "Tensor memory" figure only sums tensors that
    live on a CUDA device; CPU tensors are counted but not sized.
    """
    gib = 1024**3
    print("\n=== Memory Diagnostics ===")

    # Host memory: resident set size of the current process.
    rss_gb = psutil.Process(os.getpid()).memory_info().rss / gib
    print(f"System RAM usage: {rss_gb:.2f} GB")

    if torch.cuda.is_available():
        # One report per visible CUDA device.
        for dev in range(torch.cuda.device_count()):
            print(f"\nGPU {dev} ({torch.cuda.get_device_name(dev)}):")
            allocated = torch.cuda.memory_allocated(dev) / gib
            reserved = torch.cuda.memory_reserved(dev) / gib
            total = torch.cuda.get_device_properties(dev).total_memory / gib

            print(f"  Allocated: {allocated:.2f} GB")
            print(f"  Reserved:  {reserved:.2f} GB")
            print(f"  Total:     {total:.2f} GB")
            print(f"  Free:      {total - reserved:.2f} GB")

            # Allocator block counters
            allocator_stats = torch.cuda.memory_stats(dev)
            print(f"  Active blocks: {allocator_stats.get('active.all.current', 0)}")
            print(f"  Reserved blocks: {allocator_stats.get('reserved_blocks.all.current', 0)}")

    # Number of objects tracked by the garbage collector.
    print(f"\nPython objects: {len(gc.get_objects())}")

    # Tensors tracked by the garbage collector, and the bytes held on GPU.
    live_tensors = [candidate for candidate in gc.get_objects() if torch.is_tensor(candidate)]
    cuda_bytes = sum(
        t.element_size() * t.nelement() for t in live_tensors if t.is_cuda
    )

    print(f"Live tensors: {len(live_tensors)}")
    print(f"Tensor memory: {cuda_bytes / gib:.2f} GB")
    print("========================\n")

def clear_memory():
    """Run three garbage-collection passes, then flush the CUDA cache if present."""
    # Several passes, as in the original aggressive-clear behaviour.
    passes_left = 3
    while passes_left:
        gc.collect()
        passes_left -= 1

    # Release cached CUDA blocks and wait for the device when a GPU exists.
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    print("Memory cleared!")

def setup_distributed():
    """Read the launcher's environment and start an NCCL process group if needed.

    Returns:
        Tuple ``(rank, world_size, local_rank)``. Without both ``RANK`` and
        ``WORLD_SIZE`` in the environment this is ``(0, 1, 0)``. The process
        group is only initialised when ``world_size`` is greater than 1.
    """
    env = os.environ
    launched = 'RANK' in env and 'WORLD_SIZE' in env
    if not launched:
        rank, world_size, local_rank = 0, 1, 0
    else:
        rank = int(env["RANK"])
        world_size = int(env["WORLD_SIZE"])
        local_rank = int(env.get("LOCAL_RANK", 0))

    multi_process = world_size > 1
    if multi_process:
        # Select this process's GPU before creating the process group.
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', init_method='env://')

    return rank, world_size, local_rank

def cleanup_distributed():
    """Tear down the process group; does nothing when none was initialised."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()

def is_main_process():
    """Return True on rank 0, or whenever no process group is initialised."""
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True

def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus all GPUs if any)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

# ------------------------
# DATASET + MEDIAN-BINARIZATION
# ------------------------

def compute_median(dataset_id, args):
    """
    Median of the numeric target values in the train split of a UCI dataset.

    The target type is read from the first example only. Returns None when
    the split is empty, when the target dtype is not 'real', or when no
    target value could be converted to float.
    """
    train_split = UCIDataset(
        dataset_id=dataset_id,
        model_names=[args.embedding_model],
        split='train',
        kshot_setting=args.kshot_setting,
        column_missing_setting=args.column_missing_setting,
        test_size=args.test_size,
        seed=args.seed
    )

    if len(train_split) == 0:
        return None

    # The first example tells us what kind of target this dataset has.
    probe = train_split[0]
    if probe.features[probe.target_column_id].dtype != 'real':
        # Categorical targets have no median.
        return None

    numeric_targets = []
    for example in train_split:
        try:
            numeric_targets.append(float(example.target_row[example.target_column_id]))
        except (ValueError, TypeError):
            # Values that do not parse as numbers are left out.
            continue

    return float(np.median(numeric_targets)) if numeric_targets else None

class UCITaskDataset(Dataset):
    """
    Materialised UCIDataset split whose 'real' targets are turned into 0/1 labels.

    When a median is supplied, a real-valued target becomes 1 if it is greater
    than or equal to the median and 0 otherwise. Examples with other target
    dtypes, or any example when no median is given, are returned unchanged.
    """
    def __init__(self, dataset_id, split, args, median=None):
        self.dataset = UCIDataset(
            dataset_id=dataset_id,
            model_names=[args.embedding_model],
            split=split,
            kshot_setting=args.kshot_setting,
            column_missing_setting=args.column_missing_setting,
            test_size=args.test_size,
            seed=args.seed
        )
        # Load every example up front so indexing is a plain list lookup.
        self.examples = list(self.dataset)
        self.median = median

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        target_col = ex.target_column_id
        target_feature = ex.features[target_col]

        # Nothing to binarize: hand the stored example back as is.
        if target_feature.dtype != 'real' or self.median is None:
            return ex

        raw_value = float(ex.target_row[target_col])
        # Work on a copy so the stored example keeps its original target.
        binarized_row = ex.target_row.copy()
        binarized_row[target_col] = 1 if raw_value >= self.median else 0
        return Example(
            description=ex.description,
            features=ex.features,
            fewshot_rows=ex.fewshot_rows,
            target_row=binarized_row,
            target_column_id=ex.target_column_id,
            missing_column_ids=ex.missing_column_ids
        )

# ------------------------
# EVALUATION (CLASSIFICATION ONLY)
# ------------------------

# def evaluate_model(model, device, dataset_ids, split, args):
#     """
#     Evaluate the model on given UCITaskDataset instances.
#     Returns a dict: {dataset_name: {loss, accuracy, f1}}.
#     """
#     import sklearn.metrics as skm

#     model.eval()
#     results = {}
#     all_metrics = {}

#     with torch.no_grad():
#         for name, dsid in dataset_ids.items():
#             try:
#                 median = compute_median(dsid, args)
#                 ds = UCITaskDataset(dsid, split, args, median)
#                 if len(ds) == 0:
#                     logger.warning(f"Skipping empty dataset: {name}")
#                     continue

#                 total_loss = 0.0
#                 count = 0
#                 y_true = []
#                 y_pred = []

#                 for example in ds:
#                     with torch.amp.autocast(device_type='cuda', enabled=args.use_fp16):
#                         inp, mask, pos_ids, pos = model._prepare_example(example)
#                         inp = inp.to(device)
#                         mask = mask.to(device)
#                         pos_ids = pos_ids.to(device)
#                         seq_len = inp.size(1)
#                         max_allowed = model.model.config.max_position_embeddings
#                         if seq_len > max_allowed:
#                             inp = inp[:, :max_allowed, :]
#                             mask = mask[:, :max_allowed]
#                             pos_ids = pos_ids[:, :max_allowed]
#                             pos = min(pos, max_allowed - 1)
#                             seq_len = max_allowed

#                         token_type_ids = torch.zeros((1, seq_len), dtype=torch.long, device=device)
#                         out = model.model(
#                             inputs_embeds=inp,
#                             attention_mask=mask,
#                             token_type_ids=token_type_ids,
#                             position_ids=pos_ids,
#                             return_dict=True
#                         )
#                         hidden = out.last_hidden_state[0, pos]
#                         feature = example.features[example.target_column_id]
#                         has_categories = (
#                             hasattr(feature, 'categories') 
#                             and feature.categories 
#                             and len(feature.categories) > 0
#                         )

#                         # Classification head (always used)
#                         vec = model.class_head(hidden)

#                         if has_categories:
#                             # Multi-class classification
#                             cats = feature.categories
#                             protos = torch.stack([
#                                 feature.categories_embedding[model.config.embedding_model][c].to(device)
#                                 for c in cats
#                             ], dim=0)
#                             logits = vec.unsqueeze(0) @ protos.t()
                            
#                             # FIXED: Handle string categorical values
#                             target_value = example.target_row[example.target_column_id]
#                             target_str = str(target_value)
                            
#                             try:
#                                 target_idx = cats.index(target_str)
#                             except ValueError:
#                                 # Try without quotes if present
#                                 if target_str.startswith("'") and target_str.endswith("'"):
#                                     target_str = target_str[1:-1]
#                                     try:
#                                         target_idx = cats.index(target_str)
#                                     except ValueError:
#                                         logger.warning(f"Target value '{target_str}' not found in categories")
#                                         target_idx = 0
#                                 else:
#                                     logger.warning(f"Target value '{target_str}' not found in categories")
#                                     target_idx = 0
                            
#                             target_tensor = torch.tensor([target_idx], device=device)
#                             loss = F.cross_entropy(logits, target_tensor)
#                             pred_idx = logits.argmax(dim=1).item()
#                             y_true.append(target_idx)
#                             y_pred.append(pred_idx)

#                         else:
#                             # Binary classification of continuous target
#                             binary_embs = torch.randn(2, model.config.model_dim, device=device)
#                             binary_embs = F.normalize(binary_embs, dim=1)
#                             logits = vec.unsqueeze(0) @ binary_embs.t()
#                             true_val = float(example.target_row[example.target_column_id])
#                             true_bin = 1 if true_val >= (median if median is not None else 0.0) else 0
#                             target_idx = torch.tensor([true_bin], device=device)
#                             loss = F.cross_entropy(logits, target_idx)
#                             pred_bin = logits.argmax(dim=1).item()
#                             y_true.append(true_bin)
#                             y_pred.append(pred_bin)

#                         total_loss += loss.item()
#                         count += 1

#                 avg_loss = total_loss / (count if count > 0 else 1)
#                 metrics = {'loss': avg_loss}

#                 if y_true:
#                     y_true_np = np.array(y_true)
#                     y_pred_np = np.array(y_pred)
#                     accuracy = accuracy_score(y_true_np, y_pred_np)
#                     metrics['accuracy'] = accuracy
#                     unique_classes = np.unique(np.concatenate([y_true_np, y_pred_np]))
#                     if len(unique_classes) <= 2:
#                         f1 = f1_score(y_true_np, y_pred_np, average='binary', zero_division=0)
#                     else:
#                         f1 = f1_score(y_true_np, y_pred_np, average='macro', zero_division=0)
#                     metrics['f1'] = f1

#                 results[name] = metrics
#                 for metric_name, metric_value in metrics.items():
#                     all_metrics[f"eval/{split}/{name}/{metric_name}"] = metric_value

#                 metrics_str = f"loss={avg_loss:.4f}"
#                 if 'accuracy' in metrics:
#                     metrics_str += f", accuracy={metrics['accuracy']:.4f}"
#                 if 'f1' in metrics:
#                     metrics_str += f", f1={metrics['f1']:.4f}"
#                 logger.info(f"{name} → {metrics_str}")

#             except Exception as e:
#                 logger.error(f"Error evaluating dataset {name}: {e}")
#                 continue

#     if args.use_wandb and is_main_process():
#         wandb.log(all_metrics)

#     return results
def evaluate_model(model, device, dataset_ids, split, args):
    """
    Evaluate the model, one example at a time, on every dataset in ``dataset_ids``.

    Args:
        model: The model, optionally wrapped in DistributedDataParallel.
        device: Device the prepared inputs are moved to.
        dataset_ids: Mapping of display name -> dataset id.
        split: Split name passed to UCITaskDataset.
        args: Run configuration; ``use_fp16`` and ``use_wandb`` are read here,
            the rest is forwarded to the dataset constructors.

    Returns:
        A dict: {dataset_name: {loss, accuracy, f1}}. Datasets that are empty
        or that raise during evaluation are left out (the error is logged).

    Side effects:
        Puts the model in eval mode, logs per-dataset metrics and the first
        five examples of each dataset, and logs all metrics to wandb when
        ``args.use_wandb`` is set and this is the main process.
    """
    model.eval()
    results = {}
    all_metrics = {}

    # Get the actual model (unwrap DDP if needed)
    actual_model = model.module if isinstance(model, DDP) else model

    with torch.no_grad():
        for name, dsid in dataset_ids.items():
            try:
                median = compute_median(dsid, args)
                ds = UCITaskDataset(dsid, split, args, median)
                if len(ds) == 0:
                    logger.warning(f"Skipping empty dataset: {name}")
                    continue

                total_loss = 0.0
                count = 0
                y_true = []
                y_pred = []

                for example in ds:
                    with torch.amp.autocast(device_type='cuda', enabled=args.use_fp16):
                        inp, mask, pos_ids, pos = actual_model._prepare_example(example)
                        inp = inp.to(device)
                        mask = mask.to(device)
                        pos_ids = pos_ids.to(device)
                        seq_len = inp.size(1)
                        max_allowed = actual_model.model.config.max_position_embeddings
                        if seq_len > max_allowed:
                            # Truncate to what the position embeddings support
                            # and keep the read-out position inside the window.
                            inp = inp[:, :max_allowed, :]
                            mask = mask[:, :max_allowed]
                            pos_ids = pos_ids[:, :max_allowed]
                            pos = min(pos, max_allowed - 1)
                            seq_len = max_allowed

                        token_type_ids = torch.zeros((1, seq_len), dtype=torch.long, device=device)
                        out = actual_model.model(
                            inputs_embeds=inp,
                            attention_mask=mask,
                            token_type_ids=token_type_ids,
                            position_ids=pos_ids,
                            return_dict=True
                        )
                        hidden = out.last_hidden_state[0, pos]
                        feature = example.features[example.target_column_id]
                        has_categories = bool(
                            hasattr(feature, 'categories')
                            and feature.categories
                            and len(feature.categories) > 0
                        )

                        # Classification head (always used)
                        vec = actual_model.class_head(hidden)

                        if has_categories:
                            # Multi-class classification against the category prototypes
                            cats = feature.categories
                            protos = torch.stack([
                                feature.categories_embedding[actual_model.config.embedding_model][c].to(device)
                                for c in cats
                            ], dim=0)
                            logits = vec.unsqueeze(0) @ protos.t()

                            # Targets are matched as strings; if the raw string is
                            # not a category, retry with surrounding single quotes
                            # removed.
                            target_str = str(example.target_row[example.target_column_id])
                            if (
                                target_str not in cats
                                and target_str.startswith("'")
                                and target_str.endswith("'")
                            ):
                                target_str = target_str[1:-1]
                            try:
                                true_label = cats.index(target_str)
                            except ValueError:
                                # NOTE: unknown targets are scored as class 0.
                                logger.warning(f"Target value '{target_str}' not found in categories")
                                true_label = 0
                        else:
                            # Binary classification of continuous target
                            # Use the model's binary prototypes if available
                            if hasattr(actual_model, 'binary_prototypes'):
                                binary_embs = F.normalize(actual_model.binary_prototypes, dim=1)
                            else:
                                # NOTE(review): fresh random prototypes are drawn for
                                # every example, so predictions on this path carry
                                # no signal; confirm this fallback is still wanted.
                                binary_embs = torch.randn(2, actual_model.config.model_dim, device=device)
                                binary_embs = F.normalize(binary_embs, dim=1)

                            logits = vec.unsqueeze(0) @ binary_embs.t()
                            true_val = float(example.target_row[example.target_column_id])
                            true_label = 1 if true_val >= (median if median is not None else 0.0) else 0

                        target_tensor = torch.tensor([true_label], device=device)
                        loss = F.cross_entropy(logits, target_tensor)
                        pred_label = logits.argmax(dim=1).item()
                        y_true.append(true_label)
                        y_pred.append(pred_label)

                        # Debug output for the first 5 examples of each dataset
                        if count < 5:
                            logger.info(f"Example {count}:")
                            logger.info(f"  Logits: {logits.detach().float().cpu().numpy()}")
                            logger.info(f"  Prediction: {pred_label}")
                            logger.info(f"  True label: {true_label}")
                            logger.info(f"  Loss: {loss.item()}")

                        total_loss += loss.item()
                        count += 1

                avg_loss = total_loss / (count if count > 0 else 1)
                metrics = {'loss': avg_loss}

                if y_true:
                    y_true_np = np.array(y_true)
                    y_pred_np = np.array(y_pred)
                    metrics['accuracy'] = accuracy_score(y_true_np, y_pred_np)
                    labels_seen = set(np.unique(np.concatenate([y_true_np, y_pred_np])).tolist())
                    # 'binary' averaging scores the positive class 1, so it is only
                    # valid when every label is 0 or 1. Any other label set (even
                    # one with just two labels, e.g. {0, 2}) uses macro averaging.
                    f1_average = 'binary' if labels_seen <= {0, 1} else 'macro'
                    metrics['f1'] = f1_score(y_true_np, y_pred_np, average=f1_average, zero_division=0)

                results[name] = metrics
                for metric_name, metric_value in metrics.items():
                    all_metrics[f"eval/{split}/{name}/{metric_name}"] = metric_value

                metrics_str = f"loss={avg_loss:.4f}"
                if 'accuracy' in metrics:
                    metrics_str += f", accuracy={metrics['accuracy']:.4f}"
                if 'f1' in metrics:
                    metrics_str += f", f1={metrics['f1']:.4f}"
                logger.info(f"{name} → {metrics_str}")

            except Exception as e:
                # Best effort: one failing dataset must not abort the whole
                # evaluation. logger.exception records the traceback as well.
                logger.exception(f"Error evaluating dataset {name}: {e}")
                continue

    if args.use_wandb and is_main_process():
        wandb.log(all_metrics)

    return results
# ------------------------
# TRAINING LOOP WITH LR SCHEDULE & EARLY STOPPING
# ------------------------

def _average_across_ranks(value, device):
    """Average a Python float over all ranks; return it unchanged when not distributed."""
    if not dist.is_initialized():
        return value
    reduced = torch.tensor(value, device=device)
    dist.all_reduce(reduced, op=dist.ReduceOp.AVG)
    return reduced.item()


def _peek_targets(task_dataset, limit=100):
    """Collect the target-cell value of the first ``limit`` examples of a task dataset."""
    targets = []
    for i in range(min(limit, len(task_dataset))):
        # Index once so the row and the column id come from the same example object.
        example = task_dataset[i]
        targets.append(example.target_row[example.target_column_id])
    return targets


def _load_task_datasets(dataset_ids, args):
    """Build the non-empty train/validation ``UCITaskDataset`` lists for ``dataset_ids``.

    A dataset whose loading raises is reported with a warning (main process
    only) and skipped; splits appended before the failure are kept.

    Returns:
        Tuple ``(train_datasets, valid_datasets)``.
    """
    train_datasets, valid_datasets = [], []
    loading_iter = tqdm(dataset_ids, desc="Loading datasets") if is_main_process() else dataset_ids

    for d_id in loading_iter:
        try:
            median = compute_median(d_id, args)
            tr = UCITaskDataset(d_id, 'train', args, median)
            if len(tr) > 0:
                train_datasets.append(tr)
            va = UCITaskDataset(d_id, 'test', args, median)
            if len(va) > 0:
                valid_datasets.append(va)

            # Diagnostics only; logged once (main process) instead of once per rank.
            if is_main_process():
                logger.info(f"Train dataset size: {len(tr)}")
                logger.info(f"Validation dataset size: {len(va)}")

                # Only probe the first example when both splits actually have one.
                if len(tr) > 0 and len(va) > 0:
                    train_example = tr[0]
                    val_example = va[0]
                    logger.info(f"Train example target: {train_example.target_row[train_example.target_column_id]}")
                    logger.info(f"Val example target: {val_example.target_row[val_example.target_column_id]}")

                # Check label distribution on a small prefix of each split
                logger.info(f"Unique train labels: {set(_peek_targets(tr))}")
                logger.info(f"Unique val labels: {set(_peek_targets(va))}")
        except Exception as e:
            if is_main_process():
                logger.warning(f"Error loading dataset {d_id}: {e}")

    return train_datasets, valid_datasets


def train(args):
    """Train an ``FSLModel`` on UCI task datasets with early stopping, then evaluate it.

    The function sets up (optional) distributed training, builds the model,
    optimizer, data loaders and a cosine-with-restarts LR schedule, runs the
    epoch loop with gradient accumulation and optional fp16 autocast, keeps the
    checkpoint with the lowest validation loss (``best_model.pt``), always
    writes ``last_model.pt``, and finally runs ``evaluate_model`` on a fixed
    set of held-out datasets from the main process.

    Args:
        args: Parsed command-line namespace as produced by ``main()``.

    Raises:
        ValueError: If no dataset IDs / non-empty datasets are available, or if
            a data loader would yield zero batches.
    """
    # 1) Setup distributed, seeds, and wandb
    rank, world_size, local_rank = setup_distributed()
    set_seed(args.seed + rank)

    if args.use_wandb and is_main_process():
        wandb.init(
            project=args.wandb_project,
            name=f"uim_classif_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            config=vars(args)
        )

    if is_main_process():
        os.makedirs(args.output_dir, exist_ok=True)
    if dist.is_initialized():
        dist.barrier()

    # 2) Load all training dataset IDs
    temp = UCIDataset(
        dataset_id=args.dataset_id,
        model_names=[args.embedding_model],
        split="train",
        kshot_setting=args.kshot_setting,
        column_missing_setting=args.column_missing_setting,
        test_size=args.test_size,
        seed=args.seed
    )
    # NOTE(review): only the first five dataset IDs are used; this looks like a
    # debugging limit -- confirm before a full training run.
    all_ids = temp.train_dataset_ids[:5]
    if len(all_ids) == 0:
        raise ValueError("No training dataset IDs found; nothing to train on")
    if is_main_process():
        logger.info(f"Found {len(all_ids)} total training datasets")

    # Probe one dataset and log its target feature as a sanity check.
    test_ds = UCITaskDataset(all_ids[0], 'train', args, None)
    if is_main_process() and len(test_ds) > 0:
        ex = test_ds[0]
        target_feature = ex.features[ex.target_column_id]
        logger.info(f"Dataset info:")
        logger.info(f"  Features: {len(ex.features)}")
        logger.info(f"  Target feature: {target_feature.name}")
        logger.info(f"  Target type: {target_feature.dtype}")
        if target_feature.categories:
            logger.info(f"  Categories: {target_feature.categories}")

    # 3) Device setup
    if local_rank >= 0:
        device = torch.device(f'cuda:{local_rank}')
    else:
        device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')

    # 4) Model and optimizer
    model_config = argparse.Namespace(
        model_name=args.model_name,
        embedding_model=args.embedding_model,
        embedding_dim=args.embedding_dim,
        model_dim=args.model_dim,
        num_mixtures=args.num_mixtures,
        num_heads=args.num_heads,
        num_inds=args.num_inds,
        num_layers=args.num_layers
    )
    model = FSLModel(model_config).to(device)

    if world_size > 1:
        model = DDP(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True
        )
        if is_main_process():
            logger.info(f"Using DistributedDataParallel with {world_size} GPUs")
    else:
        if is_main_process():
            logger.info("Using single GPU/CPU")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        weight_decay=args.weight_decay
    )

    # 5) Datasets, samplers and loaders
    train_datasets, valid_datasets = _load_task_datasets(all_ids, args)
    if not train_datasets:
        raise ValueError("No non-empty training datasets could be loaded")
    if not valid_datasets:
        raise ValueError("No non-empty validation datasets could be loaded")

    if is_main_process():
        total_train_examples = sum(len(ds) for ds in train_datasets)
        total_val_examples = sum(len(ds) for ds in valid_datasets)
        logger.info(f"Loaded {len(train_datasets)} training datasets with {total_train_examples} total examples")
        logger.info(f"Loaded {len(valid_datasets)} validation datasets with {total_val_examples} total examples")

    # Create sampling datasets
    train_data = RandomSamplingDataset(
        train_datasets,
        args.max_examples_per_epoch,
        seed=args.seed,
        batch_size=args.batch_size
    )
    valid_data = RandomSamplingDataset(
        valid_datasets,
        args.max_val_examples_per_epoch,
        seed=args.seed,
        batch_size=args.batch_size
    )

    # Create samplers for distributed training
    if world_size > 1:
        train_sampler = DistributedSampler(
            train_data,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=args.seed
        )
        valid_sampler = DistributedSampler(
            valid_data,
            num_replicas=world_size,
            rank=rank,
            shuffle=False
        )
    else:
        train_sampler = None
        valid_sampler = None

    # Create DataLoaders (loading happens in the main process: num_workers=0)
    train_loader = DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        num_workers=0,
        pin_memory=False,
        persistent_workers=False,
        collate_fn=custom_collate_fn,
        drop_last=True
    )

    valid_loader = DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        sampler=valid_sampler,
        num_workers=0,
        pin_memory=False,
        persistent_workers=False,
        collate_fn=custom_collate_fn
    )

    if is_main_process():
        logger.info(f"Train batches per epoch: {len(train_loader)} ({len(train_data)} examples / {args.batch_size} batch_size)")
        logger.info(f"Validation batches per epoch: {len(valid_loader)}")

    # Suppress transformers warnings
    import warnings
    warnings.filterwarnings("ignore", message="Using `TRANSFORMERS_CACHE` is deprecated")

    # Fail early with a clear message instead of a ZeroDivisionError (or an
    # opaque scheduler error) after the first epoch.
    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise ValueError("Training DataLoader yields no batches (fewer examples than batch_size?)")
    if len(valid_loader) == 0:
        raise ValueError("Validation DataLoader yields no batches")

    # 6) LR schedule. The scheduler is stepped once per optimizer update, so the
    # restart period must be expressed in updates, not in batches.
    accum_steps = max(1, args.gradient_accumulation_steps)
    updates_per_epoch = (steps_per_epoch + accum_steps - 1) // accum_steps

    from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=updates_per_epoch * 2,  # Restart every 2 epochs
        T_mult=1,
        eta_min=getattr(args, 'min_lr', 1e-6)
    )

    scaler = torch.cuda.amp.GradScaler() if (args.use_fp16 and torch.cuda.is_available()) else None
    use_amp = scaler is not None
    optimizer.zero_grad(set_to_none=True)

    # 7) Early stopping setup
    best_val_loss = float('inf')
    best_saved = False  # whether best_model.pt was written during THIS run
    no_improve_counter = 0
    patience_epochs = args.patience_epochs
    global_step = 0
    # One smoother for the whole run; re-creating it per batch would disable smoothing.
    loss_smoother = LossSmoothing(alpha=0.95)

    # Training loop
    for epoch in range(args.num_epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        # Per-epoch running [sum, count] of the loss, keyed by estimated dataset index.
        dataset_loss_stats = {}

        # Set epoch for distributed sampler
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)

        # Create progress bar
        train_iter = tqdm(train_loader, desc=f"Epoch {epoch+1}/{args.num_epochs} [Train]") if is_main_process() else train_loader

        for batch_idx, batch in enumerate(train_iter):
            global_step += 1
            num_batches += 1

            # Memory diagnostics every 100 steps
            if is_main_process() and (batch_idx + 1) % 100 == 0:
                allocated_before = torch.cuda.memory_allocated() / 1024**3

                # Only clear caches when usage exceeds 20 GB
                if allocated_before > 20.0:
                    logger.warning(f"High memory usage detected: {allocated_before:.2f} GB")
                    master = model.module if isinstance(model, DDP) else model
                    if hasattr(master, 'clear_all_caches'):
                        master.clear_all_caches()
                    clear_memory()

                    allocated_after = torch.cuda.memory_allocated() / 1024**3
                    logger.info(f"Memory reduced from {allocated_before:.2f} GB to {allocated_after:.2f} GB")

            # Periodically log the running per-dataset average loss
            if is_main_process() and batch_idx % 100 == 0:
                for ds_idx, (loss_sum, loss_count) in dataset_loss_stats.items():
                    avg_loss = loss_sum / loss_count
                    logger.info(f"Dataset {ds_idx}: avg_loss={avg_loss:.4f}")

            with torch.amp.autocast(device_type='cuda', enabled=use_amp):
                raw_loss = model(batch)
            # Single host sync per step; this unscaled value is what gets reported.
            loss_value = raw_loss.item()
            # Scale so gradients accumulated over the window average rather than sum.
            loss = raw_loss / accum_steps

            # NOTE(review): heuristic attribution of a batch to a dataset; presumably
            # assumes 500-example blocks per dataset -- verify against RandomSamplingDataset.
            current_dataset_idx = (batch_idx * args.batch_size // 500) % len(train_datasets)
            stats = dataset_loss_stats.setdefault(current_dataset_idx, [0.0, 0])
            stats[0] += loss_value
            stats[1] += 1

            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            # Also flush on the last batch so a partial accumulation window does
            # not leak stale gradients into the next epoch.
            is_update_step = (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == steps_per_epoch
            if is_update_step:
                if scaler:
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            epoch_loss += loss_value

            if is_main_process():
                smoothed_loss = loss_smoother.update(loss_value)
                train_iter.set_postfix({"loss": smoothed_loss})
                if args.use_wandb and (global_step % args.log_interval == 0):
                    wandb.log({
                        "train/loss": loss_value,
                        "train/lr": scheduler.get_last_lr()[0],
                        "step": global_step
                    })

        # Aggregate training loss across all ranks
        avg_train_loss = _average_across_ranks(epoch_loss / max(num_batches, 1), device)

        if is_main_process():
            logger.info(f"Epoch {epoch+1} train loss: {avg_train_loss:.4f}")

        # 8) Validation at end of epoch
        model.eval()
        val_loss_accum = 0.0
        num_val_batches = 0

        val_iter = tqdm(valid_loader, desc=f"Epoch {epoch+1}/{args.num_epochs} [Valid]") if is_main_process() else valid_loader

        with torch.no_grad():
            for batch in val_iter:
                with torch.amp.autocast(device_type='cuda', enabled=use_amp):
                    loss = model(batch)
                val_loss_value = loss.item()
                val_loss_accum += val_loss_value
                num_val_batches += 1
                if is_main_process():
                    val_iter.set_postfix({"val_loss": val_loss_value})

        # Aggregate validation loss
        avg_val_loss = _average_across_ranks(val_loss_accum / max(num_val_batches, 1), device)

        if is_main_process():
            logger.info(f"Epoch {epoch+1} val loss: {avg_val_loss:.4f}")
            if args.use_wandb:
                wandb.log({"val/loss": avg_val_loss, "epoch": epoch+1})

        # 9) Early stopping logic: decided on the main process ...
        should_stop = False
        if is_main_process():
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                no_improve_counter = 0
                # Save best checkpoint
                state_dict = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
                torch.save(state_dict, os.path.join(args.output_dir, "best_model.pt"))
                best_saved = True
                logger.info(f"New best val loss: {best_val_loss:.4f} → saved best_model.pt")
            else:
                no_improve_counter += 1
                logger.info(f"No improvement in val loss for {no_improve_counter} epoch(s).")

            if no_improve_counter >= patience_epochs:
                logger.info(f"Early stopping triggered (patience = {patience_epochs} epochs).")
                should_stop = True

        # ... and shared with every rank, so that all ranks leave the loop in the
        # same epoch. Breaking on one rank only would leave the others blocked in
        # the next collective. MAX works whichever rank is the main process.
        if dist.is_initialized():
            stop_flag = torch.tensor(int(should_stop), device=device)
            dist.all_reduce(stop_flag, op=dist.ReduceOp.MAX)
            should_stop = bool(stop_flag.item())
        if should_stop:
            break

    # 10) Final epoch save if not already saved
    if is_main_process():
        final_ckpt = os.path.join(args.output_dir, "last_model.pt")
        final_state = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
        torch.save(final_state, final_ckpt)
        logger.info(f"Saved final checkpoint to {final_ckpt}")

        # 11) Post-training, load best_model.pt and run full evaluation
        if best_saved:
            logger.info("Loading best_model.pt for final evaluation...")
            best_state = torch.load(os.path.join(args.output_dir, "best_model.pt"), map_location=device)
            if isinstance(model, DDP):
                model.module.load_state_dict(best_state)
            else:
                model.load_state_dict(best_state)
        else:
            # E.g. validation loss was never finite. Do not load a possibly stale
            # best_model.pt left in output_dir by an earlier run.
            logger.warning("No best checkpoint was saved during this run; evaluating the final weights instead.")

        logger.info("Final evaluation on test split:")
        DATASET_IDS = {
            "0153_BNG_cylinder-bands": 849,
            "Raisin": 850,
            "Steel Industry Energy Consumption": 851,
            "Higher Education Students Performance Evaluation": 856,
            "Risk Factor Prediction of Chronic Kidney Disease": 857,
            "Maternal Health Risk": 863,
            "Room Occupancy Estimation": 864,
            "Cirrhosis Patient Survival Prediction": 878,
            "SUPPORT2": 880,
            "NHANES Age Prediction": 887,
            "AIDS Clinical Trials Group Study 175": 890,
            "CDC Diabetes Health Indicators": 891,
            "Differentiated Thyroid Cancer Recurrence": 915,
            "Infrared Thermography Temperature": 925,
            "National Poll on Healthy Aging": 936,
            "Regensburg Pediatric Appendicitis": 938,
            "RT-IoT2022": 942,
        }
        # NOTE(review): under DDP only the main process runs this evaluation while
        # the other ranks proceed to cleanup -- confirm evaluate_model issues no collectives.
        results = evaluate_model(model, device, DATASET_IDS, 'test', args)
        for name, metrics in results.items():
            if 'accuracy' in metrics and 'f1' in metrics:
                logger.info(f"  {name}: loss={metrics['loss']:.4f}, accuracy={metrics['accuracy']:.4f}, f1={metrics['f1']:.4f}")
            else:
                logger.info(f"  {name}: loss={metrics['loss']:.4f}")

    cleanup_distributed()

def main():
    """Parse the command-line options and launch ``train``.

    Options are declared in a single table, in the order they appear in the
    ``--help`` output. Any exception raised during training is logged and then
    re-raised so the process exits with a failure.
    """
    option_table = (
        # Data / task sampling
        ('--kshot_setting', dict(type=str, default='fixed::5')),
        ('--column_missing_setting', dict(type=str, default='range::1:3')),
        ('--test_size', dict(type=float, default=0.2)),
        # Model architecture
        ('--model_name', dict(type=str, default='bert-base-uncased')),
        ('--embedding_model', dict(type=str, default='bert-base-uncased')),
        ('--embedding_dim', dict(type=int, default=768)),
        ('--model_dim', dict(type=int, default=768)),
        ('--num_heads', dict(type=int, default=8)),
        ('--num_inds', dict(type=int, default=16)),
        ('--num_mixtures', dict(type=int, default=5)),
        ('--num_layers', dict(type=int, default=6)),
        # Optimization
        ('--batch_size', dict(type=int, default=32,
                              help="Increased batch size to utilize 44GB GPU")),
        ('--eval_batch_size', dict(type=int, default=32)),
        ('--num_epochs', dict(type=int, default=10)),
        ('--learning_rate', dict(type=float, default=0.00001,
                                 help="Initial learning rate")),
        ('--min_lr', dict(type=float, default=1e-6)),
        ('--weight_decay', dict(type=float, default=0.01,
                                help="Stronger weight decay for regularization")),
        ('--max_grad_norm', dict(type=float, default=0.5)),
        # Runtime
        ('--num_workers', dict(type=int, default=4)),
        ('--seed', dict(type=int, default=42)),
        ('--no_cuda', dict(action='store_true')),
        ('--use_fp16', dict(action='store_true')),
        ('--output_dir', dict(type=str, default='/playpen-nvme/scribble/shbhat/universal_machine/src/universal_inference_machine/outputs_classification_new750')),
        # Logging / housekeeping
        ('--log_interval', dict(type=int, default=100,
                                help="Log to wandb every N steps")),
        ('--cache_clear_interval', dict(type=int, default=1000,
                                        help="Clear model caches every N steps (increased)")),
        ('--use_wandb', dict(action='store_true')),
        ('--wandb_project', dict(type=str, default='universal-inference-machine')),
        ('--dataset_id', dict(type=int, default=1)),
        ('--gradient_accumulation_steps', dict(type=int, default=2,
                                               help='Reduced due to larger batch size')),
        ('--patience_epochs', dict(type=int, default=3,
                                   help="Patience for early stopping (in full epochs)")),
        # Per-epoch sampling budgets
        ('--max_examples_per_epoch', dict(type=int, default=600000,
                                          help="Maximum examples to process per epoch")),
        ('--max_val_examples_per_epoch', dict(type=int, default=50000,
                                              help="Maximum validation examples per epoch")),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    args = cli.parse_args()

    try:
        # Create the output directory up front, before training starts.
        os.makedirs(args.output_dir, exist_ok=True)
        train(args)
    except Exception as exc:
        logger.error(f"Training failed with error: {exc}")
        raise

if __name__ == '__main__':
    # Script entry point. Select the 'spawn' start method for multiprocessing
    # before main() runs, for use with CUDA; force=True replaces any start
    # method that has already been set.
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)
    main()