# # import os
# # import torch
# # import random
# # import numpy as np
# # import pandas as pd
# # import yaml
# # from torch.utils.data import Dataset
# # from typing import List, Dict, Tuple, Optional, Union
# # from sklearn.model_selection import train_test_split
# # from sklearn.preprocessing import LabelEncoder, StandardScaler
# # import logging

# # from utils.data_utils import Feature, TabularData, Example, parse_missing_setting, parse_kshot_setting, sample_missing_columns
# # from utils.embedding_utils import get_embedding_model

# # logger = logging.getLogger(__name__)

# # class UCIDataset(Dataset):
# #     def __init__(self, 
# #                  dataset_id: Union[int, str], 
# #                  model_names: List[str], 
# #                  split: str, 
# #                  kshot_setting: str, 
# #                  column_missing_setting: str, 
# #                  test_size: float = 0.2, 
# #                  seed: int = 42,
# #                  embedding_model: str = "bert-base-uncased",
# #                  yaml_config_path: str = ""):
# #         """
# #         Initialize UCI dataset from YAML configuration
        
# #         Args:
# #             dataset_id: Dataset ID matching a key in YAML
# #             model_names: List of model names for embeddings 
# #             split: 'train' or 'test'
# #             kshot_setting: K-shot setting string
# #             column_missing_setting: Column missing setting string
# #             test_size: Proportion for test split
# #             seed: Random seed
# #             embedding_model: Name of embedding model to use
# #             yaml_config_path: Path to YAML configuration file
# #         """
# #         super().__init__()

# #         self.kshot_setting = kshot_setting
# #         self.column_missing_setting = column_missing_setting
# #         self.dataset_id = dataset_id
# #         self.embedding_model = embedding_model
# #         self.yaml_config_path = yaml_config_path
        
# #         # Load YAML and get all dataset IDs
# #         with open(self.yaml_config_path, 'r') as f:
# #             yaml_config = yaml.safe_load(f)
        
# #         # Get all dataset IDs from YAML keys
# #         self.dataset_ids = list(yaml_config.keys())
# #         logger.info(f"Found {len(self.dataset_ids)} datasets in YAML configuration")
        
# #         # Create train/val/test splits of dataset IDs
# #         random.seed(seed)
# #         random.shuffle(self.dataset_ids)
# #         total = len(self.dataset_ids)
# #         train_size = int(0.8 * total)
# #         val_size = int(0.15 * total)
        
# #         self.train_dataset_ids = self.dataset_ids[:train_size]
# #         self.val_dataset_ids = self.dataset_ids[train_size:train_size+val_size]
# #         self.test_dataset_ids = self.dataset_ids[train_size+val_size:]
        
# #         logger.info(f"Train datasets: {len(self.train_dataset_ids)}")
# #         logger.info(f"Val datasets: {len(self.val_dataset_ids)}")
# #         logger.info(f"Test datasets: {len(self.test_dataset_ids)}")

# #         # Preprocess dataset
# #         self.tabular_data = self.preprocess_dataset(dataset_id, model_names, test_size, seed)
# #         self.data = getattr(self.tabular_data, f"{split}_rows")

# #     def __len__(self):
# #         return len(self.data)

# #     def __getitem__(self, idx):
# #         # Parse k-shot setting
# #         kshot = parse_kshot_setting(self.kshot_setting)
        
# #         # Sample k examples from training data
# #         if len(self.tabular_data.train_rows) <= kshot:
# #             # If not enough examples, just use all with replacement
# #             fewshot_row_ids = [random.randint(0, len(self.tabular_data.train_rows)-1) for _ in range(kshot)]
# #         else:
# #             # Sample without replacement
# #             fewshot_row_ids = random.sample(range(len(self.tabular_data.train_rows)), k=kshot)
            
# #         fewshot_rows = [self.tabular_data.train_rows[i] for i in fewshot_row_ids]

# #         # Get target row
# #         target_row = self.data[idx]
        
# #         # Sample missing columns
# #         num_cols = len(target_row)
# #         target_column_id = random.randrange(num_cols)
# #         sampled = sample_missing_columns(num_cols, self.column_missing_setting)
# #         missing_column_ids = [i for i in sampled if 0 <= i < num_cols]
# #         if target_column_id in missing_column_ids:
# #             missing_column_ids.remove(target_column_id)

# #         return Example(
# #             description=self.tabular_data.description,
# #             features=self.tabular_data.features,
# #             fewshot_rows=fewshot_rows,
# #             target_row=target_row,
# #             target_column_id=target_column_id,
# #             missing_column_ids=missing_column_ids
# #         )

# #     def preprocess_dataset(self, dataset_id, model_names: List[str], test_size: float, seed: int) -> TabularData:
# #         """
# #         Preprocess dataset from YAML configuration and local CSV files
# #         """
# #         # Load YAML configuration
# #         with open(self.yaml_config_path, 'r') as f:
# #             yaml_config = yaml.safe_load(f)
        
# #         # Find dataset by ID
# #         if dataset_id not in yaml_config:
# #             raise ValueError(f"Dataset {dataset_id} not found in YAML configuration")
        
# #         dataset_config = yaml_config[dataset_id]
        
# #         # Read CSV file
# #         df = pd.read_csv(dataset_config['path'])
        
# #         # Create dataset description
# #         description = dataset_config['dataset_description']
        
# #         # Create feature list
# #         features = []
# #         feature_descriptions = []
# #         categories_list = []
        
# #         # Process feature columns
# #         for col in df.columns:
# #             feature_name = col
            
# #             # Get feature description from YAML
# #             if col in dataset_config.get('feature_descriptions', {}):
# #                 feature_desc = f"{col}: {dataset_config['feature_descriptions'][col]}"
# #             else:
# #                 feature_desc = f"{col}: A feature in the dataset"
                
# #             feature_descriptions.append(feature_desc)
            
# #             # Determine feature type based on data
# #             col_dtype = df[col].dtype
# #             if col_dtype.name in ['object', 'category']:
# #                 dtype = 'categorical'
# #                 categories = list(df[col].astype(str).unique())
# #                 categories_list.append(categories)
# #                 value_range = []
# #             else:
# #                 dtype = "real"
# #                 categories = []
# #                 categories_list.append([])
# #                 values = df[col].values
# #                 if values.size > 0:
# #                     value_range = [float(values.min()), float(values.max())]
# #                 else:
# #                     value_range = []
            
# #             # Create feature object
# #             features.append(Feature(
# #                 name=col,
# #                 description=feature_desc,
# #                 description_embedding={},  # Will fill later
# #                 dtype=dtype,
# #                 categories=categories,
# #                 categories_embedding={},  # Will fill later
# #                 value_range=value_range
# #             ))
        
# #         # Generate embeddings for features
# #         desc_embeddings, cat_embeddings = self.get_feature_embeddings(
# #             feature_descriptions, 
# #             categories_list, 
# #             model_names
# #         )
        
# #         # Add embeddings to features
# #         for i, feature in enumerate(features):
# #             feature.description_embedding = desc_embeddings[feature.description]
# #             if feature.dtype == "categorical" and feature.categories:
# #                 feature.categories_embedding = {model_name: {} for model_name in model_names}
# #                 for cat in feature.categories:
# #                     if i < len(cat_embeddings) and cat in cat_embeddings[i]:
# #                         for model_name in model_names:
# #                             if model_name in cat_embeddings[i][cat]:
# #                                 feature.categories_embedding[model_name][cat] = cat_embeddings[i][cat][model_name]
        
# #         # Normalize data
# #         preprocessed_df = pd.DataFrame()
        
# #         # Process each column
# #         for i, feature in enumerate(features):
# #             col_name = feature.name
# #             raw = df[col_name]
            
# #             if feature.dtype == "categorical":
# #                 # Encode categorical features
# #                 le = LabelEncoder()
# #                 preprocessed_df[col_name] = le.fit_transform(df[col_name].astype(str))
# #             else:
# #                 vals = raw.values.astype(float)
# #                 min_v = vals.min()
# #                 max_v = vals.max()
# #                 if max_v > min_v:
# #                     preprocessed_df[col_name] = (vals - min_v) / (max_v - min_v)
# #                 else:
# #                     preprocessed_df[col_name] = np.zeros_like(vals)
        
# #         # Split into train and test sets
# #         train_df, test_df = train_test_split(
# #             preprocessed_df, 
# #             test_size=test_size, 
# #             random_state=seed
# #         )
        
# #         # Convert to list format
# #         train_rows = train_df.values.tolist()
# #         test_rows = test_df.values.tolist()
        
# #         # Create final TabularData object
# #         return TabularData(
# #             description=description,
# #             features=features,
# #             train_rows=train_rows,
# #             test_rows=test_rows
# #         )

# #     def get_feature_embeddings(
# #         self, 
# #         descriptions: List[str], 
# #         categories_list: List[List[str]], 
# #         model_names: List[str]
# #     ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], List[Dict[str, Dict[str, torch.Tensor]]]]:
# #         """
# #         Generate embeddings for feature descriptions and categories using the specified embedding model
        
# #         Args:
# #             descriptions: List of feature descriptions
# #             categories_list: List of category lists for each feature
# #             model_names: List of model names (for compatibility)
            
# #         Returns:
# #             Tuple of (description_embeddings, categories_embeddings)
# #         """
# #         # Initialize embedding dictionaries
# #         desc_embeddings = {desc: {} for desc in descriptions}
# #         cat_embeddings = []
        
# #         # Initialize category embeddings structure
# #         for cats in categories_list:
# #             cat_dict = {}
# #             for cat in cats:
# #                 cat_dict[cat] = {}
# #             cat_embeddings.append(cat_dict)
        
# #         logger.info(f"Loading embedding model: {self.embedding_model}")
# #         embedding_model = get_embedding_model(self.embedding_model)
        
# #         # Process descriptions in batches to avoid GPU memory issues
# #         batch_size = 32
# #         for i in range(0, len(descriptions), batch_size):
# #             batch_descriptions = descriptions[i:i+batch_size]
# #             batch_indices = list(range(i, min(i+batch_size, len(descriptions))))
            
# #             logger.info(f"Generating embeddings for descriptions batch {i//batch_size + 1}/{(len(descriptions) + batch_size - 1)//batch_size}")
# #             batch_embeddings = embedding_model(batch_descriptions)
            
# #             # Store embeddings
# #             for batch_idx, desc_idx in enumerate(batch_indices):
# #                 desc = descriptions[desc_idx]
# #                 # Store same embedding for all model names in the list for compatibility
# #                 for model_name in model_names:
# #                     desc_embeddings[desc][model_name] = batch_embeddings[batch_idx]
        
# #         # Process all categories
# #         all_categories = []
# #         category_mapping = []  # To map back to the original structure
        
# #         for feature_idx, cats in enumerate(categories_list):
# #             for cat in cats:
# #                 all_categories.append(cat)
# #                 category_mapping.append((feature_idx, cat))
        
# #         # Process categories in batches
# #         for i in range(0, len(all_categories), batch_size):
# #             batch_categories = all_categories[i:i+batch_size]
# #             batch_indices = list(range(i, min(i+batch_size, len(all_categories))))
            
# #             logger.info(f"Generating embeddings for categories batch {i//batch_size + 1}/{(len(all_categories) + batch_size - 1)//batch_size}")
# #             batch_embeddings = embedding_model(batch_categories)
            
# #             # Store embeddings
# #             for batch_idx, cat_idx in enumerate(batch_indices):
# #                 feature_idx, cat = category_mapping[cat_idx]
# #                 # Store same embedding for all model names in the list for compatibility
# #                 for model_name in model_names:
# #                     cat_embeddings[feature_idx][cat][model_name] = batch_embeddings[batch_idx]
        
# #         return desc_embeddings, cat_embeddings
import os
import torch
import random
import numpy as np
import pandas as pd
import yaml
from torch.utils.data import Dataset
from typing import List, Dict, Tuple, Optional, Union
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
import logging

from utils.data_utils import Feature, TabularData, Example, parse_missing_setting, parse_kshot_setting, sample_missing_columns
from utils.embedding_utils import get_embedding_model

# Module-level logger named after this module (``__name__``); UCIDataset uses it
# to report progress, e.g. how many datasets were found in the YAML configuration.
logger = logging.getLogger(__name__)

# class UCIDataset(Dataset):
#     def __init__(self, 
#                  dataset_id: Union[int, str], 
#                  model_names: List[str], 
#                  split: str, 
#                  kshot_setting: str, 
#                  column_missing_setting: str, 
#                  test_size: float = 0.2, 
#                  seed: int = 42,
#                  embedding_model: str = "bert-base-uncased",
#                 #  yaml_config_path: str = ""):
#                 yaml_config_path: str = ""):
#         """
#         Initialize UCI dataset from YAML configuration
        
#         Args:
#             dataset_id: Dataset ID - either integer index or string key from YAML
#             model_names: List of model names for embeddings 
#             split: 'train' or 'test'
#             kshot_setting: K-shot setting string
#             column_missing_setting: Column missing setting string
#             test_size: Proportion for test split
#             seed: Random seed
#             embedding_model: Name of embedding model to use
#             yaml_config_path: Path to YAML configuration file
#         """
#         super().__init__()

#         self.kshot_setting = kshot_setting
#         self.column_missing_setting = column_missing_setting
#         self.embedding_model = embedding_model
#         self.yaml_config_path = yaml_config_path
        
#         # Load YAML and get all dataset IDs
#         try:
#             with open(self.yaml_config_path, 'r') as f:
#                 yaml_config = yaml.safe_load(f)
#         except yaml.YAMLError as e:
#             # logger.warning(f"YAML parsing error, attempting to load line by line: {e}")
#             yaml_config = self._load_yaml_safe(self.yaml_config_path)
        
#         # Get all dataset IDs from YAML keys
#         all_dataset_ids = list(yaml_config.keys())
#         logger.info(f"Found {len(all_dataset_ids)} datasets in YAML configuration")
        
#         # Validate datasets - only keep ones that can be loaded
#         self.dataset_ids = []
#         for ds_id in all_dataset_ids:
#             try:
#                 # Quick validation - check if path exists and is readable
#                 if 'path' in yaml_config[ds_id] and os.path.exists(yaml_config[ds_id]['path']):
#                     self.dataset_ids.append(ds_id)
#                 else:
#                     logger.warning(f"Skipping dataset {ds_id}: path not found or not specified")
#             except Exception as e:
#                 logger.warning(f"Skipping dataset {ds_id} due to error: {e}")
        
#         # logger.info(f"Successfully validated {len(self.dataset_ids)} datasets")
        
#         # Create train/val/test splits of dataset IDs
#         random.seed(seed)
#         random.shuffle(self.dataset_ids)
#         total = len(self.dataset_ids)
#         train_size = int(0.8 * total)
#         val_size = int(0.15 * total)
        
#         self.train_dataset_ids = self.dataset_ids[:train_size]
#         self.val_dataset_ids = self.dataset_ids[train_size:train_size+val_size]
#         self.test_dataset_ids = self.dataset_ids[train_size+val_size:]
        
#         logger.info(f"Train datasets: {len(self.train_dataset_ids)}")
#         logger.info(f"Val datasets: {len(self.val_dataset_ids)}")
#         logger.info(f"Test datasets: {len(self.test_dataset_ids)}")
        
#         # Handle dataset_id as either integer index or string key
#         if isinstance(dataset_id, int):
#             # If integer, use it as index into dataset_ids list
#             if 0 <= dataset_id < len(self.dataset_ids):
#                 actual_dataset_id = self.dataset_ids[dataset_id]
#                 logger.info(f"Converting integer dataset_id {dataset_id} to string key: {actual_dataset_id}")
#             else:
#                 # logger.warning(f"Integer dataset_id {dataset_id} out of range, using first dataset")
#                 actual_dataset_id = self.dataset_ids[0] if self.dataset_ids else None
#         else:
#             # If string, use directly
#             actual_dataset_id = dataset_id
            
#         if actual_dataset_id is None:
#             raise ValueError("No valid datasets found in YAML configuration")
            
#         self.dataset_id = actual_dataset_id

#         # Preprocess dataset
#         try:
#             self.tabular_data = self.preprocess_dataset(actual_dataset_id, model_names, test_size, seed)
#             self.data = getattr(self.tabular_data, f"{split}_rows")
#         except Exception as e:
#             logger.error(f"Failed to preprocess dataset {actual_dataset_id}: {e}")
#             # Create empty dataset as fallback
#             self.tabular_data = TabularData(
#                 description=f"Failed to load dataset {actual_dataset_id}",
#                 features=[],
#                 train_rows=[],
#                 test_rows=[]
#             )
#             self.data = []

#     def _load_yaml_safe(self, yaml_path):
#         """Attempt to load YAML file by parsing it manually when safe_load fails"""
#         yaml_config = {}
#         current_key = None
#         current_data = {}
        
#         with open(yaml_path, 'r') as f:
#             for line_num, line in enumerate(f, 1):
#                 try:
#                     # Try to identify dataset keys (non-indented lines ending with ':')
#                     if line.strip() and not line.startswith(' ') and line.strip().endswith(':'):
#                         # Save previous dataset if exists
#                         if current_key and current_data:
#                             yaml_config[current_key] = current_data
                        
#                         # Start new dataset
#                         current_key = line.strip()[:-1]
#                         current_data = {
#                             'dataset_description': '',
#                             'path': '',
#                             'feature_descriptions': {}
#                         }
#                     elif current_key:
#                         # Parse dataset properties
#                         if 'dataset_description:' in line:
#                             current_data['dataset_description'] = line.split('dataset_description:', 1)[1].strip().strip('"')
#                         elif 'path:' in line:
#                             current_data['path'] = line.split('path:', 1)[1].strip().strip('"')
#                         elif ':' in line and line.strip().startswith(' ') and 'feature_descriptions' not in line:
#                             # Parse feature descriptions
#                             parts = line.strip().split(':', 1)
#                             if len(parts) == 2:
#                                 feat_name = parts[0].strip()
#                                 feat_desc = parts[1].strip().strip('"')
#                                 current_data['feature_descriptions'][feat_name] = feat_desc
#                 except Exception as e:
#                     # logger.warning(f"Error parsing line {line_num}: {line.strip()[:50]}... - {e}")
#                     continue
        
#         # Save last dataset
#         if current_key and current_data:
#             yaml_config[current_key] = current_data
        
#         return yaml_config

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
#         # Handle empty datasets
#         if not self.data or not self.tabular_data.train_rows:
#             logger.warning(f"Empty dataset encountered for {self.dataset_id}")
#             return Example(
#                 description=self.tabular_data.description,
#                 features=[],
#                 fewshot_rows=[],
#                 target_row=[],
#                 target_column_id=0,
#                 missing_column_ids=[]
#             )
        
#         # Parse k-shot setting
#         kshot = parse_kshot_setting(self.kshot_setting)
        
#         # Sample k examples from training data
#         if len(self.tabular_data.train_rows) <= kshot:
#             # If not enough examples, just use all with replacement
#             fewshot_row_ids = [random.randint(0, len(self.tabular_data.train_rows)-1) for _ in range(kshot)]
#         else:
#             # Sample without replacement
#             fewshot_row_ids = random.sample(range(len(self.tabular_data.train_rows)), k=kshot)
            
#         fewshot_rows = [self.tabular_data.train_rows[i] for i in fewshot_row_ids]

#         # Get target row
#         target_row = self.data[idx]
        
#         # Sample missing columns
#         num_cols = len(target_row)
#         target_column_id = random.randrange(num_cols)
#         sampled = sample_missing_columns(num_cols, self.column_missing_setting)
#         missing_column_ids = [i for i in sampled if 0 <= i < num_cols]
#         if target_column_id in missing_column_ids:
#             missing_column_ids.remove(target_column_id)

#         return Example(
#             description=self.tabular_data.description,
#             features=self.tabular_data.features,
#             fewshot_rows=fewshot_rows,
#             target_row=target_row,
#             target_column_id=target_column_id,
#             missing_column_ids=missing_column_ids
#         )

#     def preprocess_dataset(self, dataset_id, model_names: List[str], test_size: float, seed: int) -> TabularData:
#         """
#         Preprocess dataset from YAML configuration and local CSV files
#         """
#         # Load YAML configuration
#         try:
#             with open(self.yaml_config_path, 'r') as f:
#                 yaml_config = yaml.safe_load(f)
#         except yaml.YAMLError as e:
#             logger.warning(f"YAML parsing error in preprocess_dataset, using safe loader: {e}")
#             yaml_config = self._load_yaml_safe(self.yaml_config_path)
        
#         # Find dataset by ID
#         if dataset_id not in yaml_config:
#             raise ValueError(f"Dataset {dataset_id} not found in YAML configuration")
        
#         dataset_config = yaml_config[dataset_id]
        
#         # Read CSV file with error handling
#         try:
#             df = pd.read_csv(dataset_config['path'])
#         except Exception as e:
#             logger.error(f"Error reading CSV file for dataset {dataset_id}: {e}")
#             # Try with different encoding or error handling
#             try:
#                 df = pd.read_csv(dataset_config['path'], encoding='latin-1', on_bad_lines='skip')
#             except Exception as e2:
#                 logger.error(f"Failed to read dataset {dataset_id} even with error handling: {e2}")
#                 raise
        
#         # Create dataset description
#         description = dataset_config['dataset_description']
#         if len(description) > 200:  # Limit description length
#             description = description[:197] + "..."
    
#         # Create feature list
#         features = []
#         feature_descriptions = []
#         categories_list = []
        
#         # Process feature columns
#         for col in df.columns:
#             feature_name = col
            
#             # Get feature description from YAML
#             if col in dataset_config.get('feature_descriptions', {}):
#                 feature_desc = f"{col}: {dataset_config['feature_descriptions'][col][:50]}"
#             else:
#                 feature_desc = f"{col}: A feature in the dataset"
#             # feature_desc = col
#             feature_descriptions.append(feature_desc)
            
#             # Determine feature type based on data
#             col_dtype = df[col].dtype
#             if col_dtype.name in ['object', 'category']:
#                 dtype = 'categorical'
#                 categories = list(df[col].astype(str).unique())
#                 categories_list.append(categories)
#                 value_range = []
#             else:
#                 dtype = "real"
#                 categories = []
#                 categories_list.append([])
#                 values = df[col].values
#                 if values.size > 0:
#                     value_range = [float(values.min()), float(values.max())]
#                 else:
#                     value_range = []
            
#             # Create feature object
#             features.append(Feature(
#                 name=col,
#                 description=feature_desc,
#                 # description=col,
#                 description_embedding={},  # Will fill later
#                 dtype=dtype,
#                 categories=categories,
#                 categories_embedding={},  # Will fill later
#                 value_range=value_range
#             ))
        
#         # Generate embeddings for features
#         try:
#             desc_embeddings, cat_embeddings = self.get_feature_embeddings(
#                 feature_descriptions, 
#                 categories_list, 
#                 model_names
#             )
#         except Exception as e:
#             logger.error(f"Error generating embeddings for dataset {dataset_id}: {e}")
#             # Create empty embeddings as fallback
#             desc_embeddings = {desc: {model: torch.zeros(768) for model in model_names} for desc in feature_descriptions}
#             cat_embeddings = [{cat: {model: torch.zeros(768) for model in model_names} for cat in cats} for cats in categories_list]
        
#         # Add embeddings to features
#         for i, feature in enumerate(features):
#             feature.description_embedding = desc_embeddings[feature.description]
#             if feature.dtype == "categorical" and feature.categories:
#                 feature.categories_embedding = {model_name: {} for model_name in model_names}
#                 for cat in feature.categories:
#                     if i < len(cat_embeddings) and cat in cat_embeddings[i]:
#                         for model_name in model_names:
#                             if model_name in cat_embeddings[i][cat]:
#                                 feature.categories_embedding[model_name][cat] = cat_embeddings[i][cat][model_name]
        
#         # Normalize data
#         preprocessed_df = pd.DataFrame()
        
#         # Process each column
#         for i, feature in enumerate(features):
#             col_name = feature.name
#             try:
#                 raw = df[col_name]
                
#                 if feature.dtype == "categorical":
#                     # Encode categorical features
#                     # le = LabelEncoder()
#                     # preprocessed_df[col_name] = le.fit_transform(df[col_name].astype(str))
#                     preprocessed_df[col_name] = df[col_name].astype(str)
#                 else:
#                     vals = raw.values.astype(float)
#                     min_v = vals.min()
#                     max_v = vals.max()
#                     if max_v > min_v:
#                         preprocessed_df[col_name] = (vals - min_v) / (max_v - min_v)
#                     else:
#                         preprocessed_df[col_name] = np.zeros_like(vals)
#             except Exception as e:
#                 logger.warning(f"Error processing column {col_name} in dataset {dataset_id}: {e}")
#                 # Fill with zeros as fallback
#                 preprocessed_df[col_name] = np.zeros(len(df))
        
#         # Split into train and test sets
#         train_df, test_df = train_test_split(
#             preprocessed_df, 
#             test_size=test_size, 
#             random_state=seed,
#             shuffle=True
#         )
#         train_indices = train_df.index
#         test_indices = test_df.index
#         train_df = train_df.reset_index(drop=True)
#         test_df = test_df.reset_index(drop=True)
#         # Log split sizes, then check whether any rows appear in both train and test
#         logger.info(f"Dataset {dataset_id} split:")
#         logger.info(f"  Original data shape: {preprocessed_df.shape}")
#         logger.info(f"  Train shape: {train_df.shape}")
#         logger.info(f"  Test shape: {test_df.shape}")

#         # Check if they overlap
#         train_set = set(map(tuple, train_df.values))
#         test_set = set(map(tuple, test_df.values))
#         overlap = train_set.intersection(test_set)
#         if overlap:
#             logger.error(f"TRAIN/TEST OVERLAP: {len(overlap)} examples!")
#         # Convert to list format
#         train_rows = train_df.values.tolist()
#         test_rows = test_df.values.tolist()
        
#         # Create final TabularData object
#         return TabularData(
#             description=description,
#             features=features,
#             train_rows=train_rows,
#             test_rows=test_rows
#         )

#     def get_feature_embeddings(
#         self, 
#         descriptions: List[str], 
#         categories_list: List[List[str]], 
#         model_names: List[str]
#     ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], List[Dict[str, Dict[str, torch.Tensor]]]]:
#         """
#         Generate embeddings for feature descriptions and categories using the specified embedding model
        
#         Args:
#             descriptions: List of feature descriptions
#             categories_list: List of category lists for each feature
#             model_names: List of model names (for compatibility)
            
#         Returns:
#             Tuple of (description_embeddings, categories_embeddings)
#         """
#         # Initialize embedding dictionaries
#         desc_embeddings = {desc: {} for desc in descriptions}
#         cat_embeddings = []
        
#         # Initialize category embeddings structure
#         for cats in categories_list:
#             cat_dict = {}
#             for cat in cats:
#                 cat_dict[cat] = {}
#             cat_embeddings.append(cat_dict)
        
#         # logger.info(f"Loading embedding model: {self.embedding_model}")
#         embedding_model = get_embedding_model(self.embedding_model)
        
#         # Process descriptions in batches to avoid GPU memory issues
#         batch_size = 32
#         for i in range(0, len(descriptions), batch_size):
#             batch_descriptions = descriptions[i:i+batch_size]
#             batch_indices = list(range(i, min(i+batch_size, len(descriptions))))
            
#             # logger.info(f"Generating embeddings for descriptions batch {i//batch_size + 1}/{(len(descriptions) + batch_size - 1)//batch_size}")
#             batch_embeddings = embedding_model(batch_descriptions)
            
#             # Store embeddings
#             for batch_idx, desc_idx in enumerate(batch_indices):
#                 desc = descriptions[desc_idx]
#                 # Store same embedding for all model names in the list for compatibility
#                 for model_name in model_names:
#                     desc_embeddings[desc][model_name] = batch_embeddings[batch_idx]
        
#         # Process all categories
#         all_categories = []
#         category_mapping = []  # To map back to the original structure
        
#         for feature_idx, cats in enumerate(categories_list):
#             for cat in cats:
#                 all_categories.append(cat)
#                 category_mapping.append((feature_idx, cat))
        
#         # Process categories in batches
#         for i in range(0, len(all_categories), batch_size):
#             batch_categories = all_categories[i:i+batch_size]
#             batch_indices = list(range(i, min(i+batch_size, len(all_categories))))
            
#             # logger.info(f"Generating embeddings for categories batch {i//batch_size + 1}/{(len(all_categories) + batch_size - 1)//batch_size}")
#             batch_embeddings = embedding_model(batch_categories)
            
#             # Store embeddings
#             for batch_idx, cat_idx in enumerate(batch_indices):
#                 feature_idx, cat = category_mapping[cat_idx]
#                 # Store same embedding for all model names in the list for compatibility
#                 for model_name in model_names:
#                     cat_embeddings[feature_idx][cat][model_name] = batch_embeddings[batch_idx]
        
#         return desc_embeddings, cat_embeddings
class UCIDataset(Dataset):
    def __init__(self,
                 dataset_id: Union[int, str],
                 model_names: List[str],
                 split: str,
                 kshot_setting: str,
                 column_missing_setting: str,
                 test_size: float = 0.2,
                 seed: int = 42,
                 embedding_model: str = "bert-base-uncased",
                 yaml_config_path: str = ""):
        """
        Initialize UCI dataset from YAML configuration.

        The YAML file maps dataset keys to entries that carry a ``path`` to a CSV
        file. Entries whose ``path`` is missing, not a string, or not an existing
        file are skipped. The remaining keys are shuffled with ``seed`` and split
        roughly 80/15/5 into ``train_dataset_ids`` / ``val_dataset_ids`` /
        ``test_dataset_ids``.

        Args:
            dataset_id: Dataset ID. An ``int`` is always treated as an index into
                the shuffled list of valid dataset keys, even if the YAML keys
                themselves are integers. An out-of-range index falls back to the
                first valid dataset and logs a warning. Any other value is used
                directly as a YAML key.
            model_names: List of model names for embeddings
            split: 'train' or 'test'; selects ``{split}_rows`` of the processed data
            kshot_setting: K-shot setting string
            column_missing_setting: Column missing setting string
            test_size: Proportion of rows held out for the test split
            seed: Random seed. It seeds the global ``random`` module, which
                ``__getitem__`` also uses for sampling.
            embedding_model: Name of embedding model to use
            yaml_config_path: Path to YAML configuration file

        Raises:
            ValueError: if an integer ``dataset_id`` is given (or ``dataset_id``
                is None) and no valid datasets were found.

        If preprocessing the chosen dataset fails for any reason, the error is
        logged with its traceback. The instance is then left with an empty
        ``TabularData`` and ``self.data == []``.
        """
        super().__init__()

        self.kshot_setting = kshot_setting
        self.column_missing_setting = column_missing_setting
        self.embedding_model = embedding_model
        self.yaml_config_path = yaml_config_path

        # Load YAML; fall back to the tolerant line-based parser on syntax errors
        try:
            with open(self.yaml_config_path, 'r') as f:
                yaml_config = yaml.safe_load(f)
        except yaml.YAMLError:
            yaml_config = self._load_yaml_safe(self.yaml_config_path)

        # An empty YAML file loads as None; treat that (or any non-mapping) as "no datasets"
        if not isinstance(yaml_config, dict):
            logger.warning(f"YAML configuration {self.yaml_config_path} is empty or not a mapping")
            yaml_config = {}

        # Get all dataset IDs from YAML keys
        all_dataset_ids = list(yaml_config.keys())
        logger.info(f"Found {len(all_dataset_ids)} datasets in YAML configuration")

        # Validate datasets - only keep entries whose 'path' is an existing file
        self.dataset_ids = []
        for ds_id in all_dataset_ids:
            entry = yaml_config[ds_id]
            path = entry.get('path') if isinstance(entry, dict) else None
            if isinstance(path, str) and os.path.exists(path):
                self.dataset_ids.append(ds_id)
            else:
                logger.warning(f"Skipping dataset {ds_id}: path not found or not specified")

        # Create train/val/test splits of dataset IDs (global `random` is seeded on purpose:
        # __getitem__ draws from the same generator)
        random.seed(seed)
        random.shuffle(self.dataset_ids)
        total = len(self.dataset_ids)
        train_size = int(0.8 * total)
        val_size = int(0.15 * total)

        self.train_dataset_ids = self.dataset_ids[:train_size]
        self.val_dataset_ids = self.dataset_ids[train_size:train_size + val_size]
        self.test_dataset_ids = self.dataset_ids[train_size + val_size:]

        logger.info(f"Train datasets: {len(self.train_dataset_ids)}")
        logger.info(f"Val datasets: {len(self.val_dataset_ids)}")
        logger.info(f"Test datasets: {len(self.test_dataset_ids)}")

        # Handle dataset_id as either integer index or key
        if isinstance(dataset_id, int):
            if 0 <= dataset_id < len(self.dataset_ids):
                actual_dataset_id = self.dataset_ids[dataset_id]
                logger.info(f"Converting integer dataset_id {dataset_id} to string key: {actual_dataset_id}")
            elif self.dataset_ids:
                # Fallback kept for compatibility, but make it visible
                actual_dataset_id = self.dataset_ids[0]
                logger.warning(
                    f"Integer dataset_id {dataset_id} out of range "
                    f"(0..{len(self.dataset_ids) - 1}), using first dataset {actual_dataset_id}"
                )
            else:
                actual_dataset_id = None
        else:
            actual_dataset_id = dataset_id

        if actual_dataset_id is None:
            raise ValueError("No valid datasets found in YAML configuration")

        self.dataset_id = actual_dataset_id

        # Preprocess dataset; any failure degrades to an empty dataset (logged with traceback)
        try:
            self.tabular_data = self.preprocess_dataset(actual_dataset_id, model_names, test_size, seed)
            self.data = getattr(self.tabular_data, f"{split}_rows")
        except Exception as e:
            logger.exception(f"Failed to preprocess dataset {actual_dataset_id}: {e}")
            self.tabular_data = TabularData(
                description=f"Failed to load dataset {actual_dataset_id}",
                features=[],
                train_rows=[],
                test_rows=[]
            )
            self.data = []

    def _load_yaml_safe(self, yaml_path):
        """Best-effort line-based parser used when ``yaml.safe_load`` fails.

        Only the layout this class expects is understood::

            <dataset_key>:
              dataset_description: "..."
              path: "..."
              feature_descriptions:
                <column>: "..."

        Every non-indented ``key:`` line starts a new dataset entry. Each entry is
        a dict with ``dataset_description`` (str), ``path`` (str) and
        ``feature_descriptions`` (dict column -> str). Surrounding whitespace and
        simple quotes are stripped from keys and values. Blank lines and ``#``
        comment lines are ignored, and lines that do not fit the layout are
        skipped. Feature lines are recognised by indentation, so a description
        that happens to contain ``path:`` no longer overwrites the dataset path.

        Args:
            yaml_path: Path to the YAML file.

        Returns:
            Dict mapping dataset key (str) -> parsed entry dict.
        """
        yaml_config = {}
        current_data = None
        # Indentation of the 'feature_descriptions:' line while inside that section, else None
        features_indent = None

        def _unquote(text):
            # Drop surrounding whitespace and simple YAML quotes
            return text.strip().strip('"\'')

        with open(yaml_path, 'r') as f:
            for raw_line in f:
                line = raw_line.rstrip('\r\n')
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                indent = len(line) - len(line.lstrip())

                if indent == 0:
                    # Any top-level line ends the previous entry; 'key:' starts a new one
                    current_data = None
                    features_indent = None
                    if stripped.endswith(':'):
                        key = _unquote(stripped[:-1])
                        if key:
                            current_data = {
                                'dataset_description': '',
                                'path': '',
                                'feature_descriptions': {}
                            }
                            yaml_config[key] = current_data
                    continue

                if current_data is None:
                    continue

                key, sep, value = stripped.partition(':')
                if not sep:
                    continue

                if features_indent is not None and indent > features_indent:
                    # Nested under 'feature_descriptions:' -> "<column>: <description>"
                    current_data['feature_descriptions'][_unquote(key)] = _unquote(value)
                    continue

                # A line at (or above) the section's indentation closes the feature block
                features_indent = None
                key = key.strip()
                if key == 'feature_descriptions':
                    features_indent = indent
                elif key in ('dataset_description', 'path'):
                    current_data[key] = _unquote(value)

        return yaml_config

    def __len__(self):
        """Return how many rows the selected split (``self.data``) holds."""
        num_rows = len(self.data)
        return num_rows

    def __getitem__(self, idx):
        """
        Build one few-shot ``Example`` whose target is row ``idx`` of this split.

        Few-shot context rows are drawn with the module-level ``random`` from
        ``self.tabular_data.train_rows``. If this instance serves the training
        split (``self.data`` is the train row list), row ``idx`` is excluded
        from the pool, so the target row and its answer cannot appear in its own
        context. The exception is a single training row, where nothing else
        exists to draw from. Sampling is without replacement when the pool has
        at least ``kshot`` rows (including exactly ``kshot``) and with
        replacement otherwise.

        A target column is picked uniformly at random. Column ids returned by
        ``sample_missing_columns`` are reported as missing, keeping only valid
        indices and dropping every occurrence of the target column.

        Returns:
            Example. If the split or the training rows are empty, a placeholder
            Example with empty features/rows is returned.
        """
        train_rows = self.tabular_data.train_rows

        # Handle empty datasets
        if not self.data or not train_rows:
            logger.warning(f"Empty dataset encountered for {self.dataset_id}")
            return Example(
                description=self.tabular_data.description,
                features=[],
                fewshot_rows=[],
                target_row=[],
                target_column_id=0,
                missing_column_ids=[]
            )

        # Get target row first so an invalid idx fails before any sampling
        target_row = self.data[idx]

        kshot = parse_kshot_setting(self.kshot_setting)
        num_train = len(train_rows)

        # Train row id to keep out of the few-shot pool (None = nothing excluded)
        excluded = None
        if self.data is train_rows and num_train > 1:
            excluded = idx + num_train if idx < 0 else idx
        pool_size = num_train - 1 if excluded is not None else num_train

        if pool_size < kshot:
            # Not enough distinct rows: sample with replacement
            picks = [random.randrange(pool_size) for _ in range(kshot)]
        else:
            # Enough rows: sample without replacement
            picks = random.sample(range(pool_size), k=kshot)
        if excluded is not None:
            # Map positions in the reduced pool back to train row ids, skipping `excluded`
            picks = [p + 1 if p >= excluded else p for p in picks]
        fewshot_rows = [train_rows[i] for i in picks]

        # Sample target and missing columns
        num_cols = len(target_row)
        target_column_id = random.randrange(num_cols)
        sampled = sample_missing_columns(num_cols, self.column_missing_setting)
        missing_column_ids = [
            i for i in sampled if 0 <= i < num_cols and i != target_column_id
        ]

        return Example(
            description=self.tabular_data.description,
            features=self.tabular_data.features,
            fewshot_rows=fewshot_rows,
            target_row=target_row,
            target_column_id=target_column_id,
            missing_column_ids=missing_column_ids
        )

    def preprocess_dataset(self, dataset_id, model_names: List[str], test_size: float, seed: int) -> TabularData:
        """
        Preprocess dataset from YAML configuration and local CSV files.

        Steps:
          1. Look up ``dataset_id`` in the YAML file at ``self.yaml_config_path``,
             falling back to ``_load_yaml_safe`` on YAML syntax errors.
          2. Read the CSV at the entry's ``path``. If the first read fails, retry
             with latin-1 encoding and skip malformed lines.
          3. Build one ``Feature`` per column. A column is categorical if its
             dtype is object/category or any of its first 100 non-null values
             cannot be parsed as a float. Otherwise it is real-valued, with
             ``value_range`` = [min, max] of its numeric values, or [0, 1] if
             none parse.
          4. Attach embeddings from ``get_feature_embeddings``. If that fails,
             fall back to zero vectors of length 768.
          5. Keep categorical columns as strings. For real columns, fill NaNs
             with the median and min-max scale to [0, 1]; constant or all-NaN
             columns become 0.
          6. Shuffle row indices with a local RNG seeded by ``seed`` and hold out
             the last ``test_size`` fraction as test rows. Drop test rows that
             exactly duplicate a training row.

        Args:
            dataset_id: Key of the dataset in the YAML configuration.
            model_names: Model names under which embeddings are stored.
            test_size: Fraction of rows held out for the test split.
            seed: Seed for the train/test row shuffle.

        Returns:
            TabularData with description, features, train_rows and test_rows.

        Raises:
            ValueError: if ``dataset_id`` is not in the YAML configuration.
            Exception: whatever ``pandas.read_csv`` raised on the retry.
        """
        # Load YAML configuration
        try:
            with open(self.yaml_config_path, 'r') as f:
                yaml_config = yaml.safe_load(f)
        except yaml.YAMLError as e:
            logger.warning(f"YAML parsing error in preprocess_dataset, using safe loader: {e}")
            yaml_config = self._load_yaml_safe(self.yaml_config_path)

        # Find dataset by ID (an empty YAML file loads as None)
        if not isinstance(yaml_config, dict) or dataset_id not in yaml_config:
            raise ValueError(f"Dataset {dataset_id} not found in YAML configuration")

        dataset_config = yaml_config[dataset_id]

        # Read CSV file with error handling
        try:
            df = pd.read_csv(dataset_config['path'])
        except Exception as e:
            logger.error(f"Error reading CSV file for dataset {dataset_id}: {e}")
            # Try with different encoding and skip malformed lines
            try:
                df = pd.read_csv(dataset_config['path'], encoding='latin-1', on_bad_lines='skip')
            except Exception as e2:
                logger.error(f"Failed to read dataset {dataset_id} even with error handling: {e2}")
                raise

        # Create dataset description (missing/null description -> empty string)
        description = str(dataset_config.get('dataset_description') or '')
        if len(description) > 200:  # Limit description length
            description = description[:197] + "..."

        # Per-column descriptions from YAML; tolerate a null/non-mapping section
        yaml_feature_descs = dataset_config.get('feature_descriptions')
        if not isinstance(yaml_feature_descs, dict):
            yaml_feature_descs = {}

        features = []
        feature_descriptions = []
        categories_list = []

        for col in df.columns:
            if col in yaml_feature_descs:
                # str(): YAML may hold numbers/None here
                feature_desc = f"{col}: {str(yaml_feature_descs[col])[:50]}"
            else:
                feature_desc = f"{col}: A feature in the dataset"
            feature_descriptions.append(feature_desc)

            # Type detection: inspect actual values, not just the dtype
            sample_values = df[col].dropna().head(100)
            has_non_numeric_strings = False
            for val in sample_values:
                try:
                    float(str(val))
                except ValueError:
                    has_non_numeric_strings = True
                    break

            if has_non_numeric_strings or df[col].dtype.name in ['object', 'category']:
                dtype = 'categorical'
                categories = list(df[col].astype(str).unique())
                categories_list.append(categories)
                value_range = []
                logger.debug(f"Column {col} detected as categorical with {len(categories)} unique values")
            else:
                dtype = "real"
                categories = []
                categories_list.append([])
                # Convert to numeric, handling any remaining string values
                numeric_values = pd.to_numeric(df[col], errors='coerce')
                if numeric_values.notna().any():
                    value_range = [float(numeric_values.min()), float(numeric_values.max())]
                else:
                    value_range = [0.0, 1.0]

            features.append(Feature(
                name=col,
                description=feature_desc,
                description_embedding={},  # Will fill later
                dtype=dtype,
                categories=categories,
                categories_embedding={},  # Will fill later
                value_range=value_range
            ))

        # Generate embeddings for features
        try:
            desc_embeddings, cat_embeddings = self.get_feature_embeddings(
                feature_descriptions,
                categories_list,
                model_names
            )
        except Exception as e:
            logger.error(f"Error generating embeddings for dataset {dataset_id}: {e}")
            # Zero embeddings as fallback (768 assumed to match the default model width)
            desc_embeddings = {desc: {model: torch.zeros(768) for model in model_names} for desc in feature_descriptions}
            cat_embeddings = [{cat: {model: torch.zeros(768) for model in model_names} for cat in cats} for cats in categories_list]

        # Add embeddings to features: categories_embedding is {model_name: {category: embedding}}
        for i, feature in enumerate(features):
            feature.description_embedding = desc_embeddings[feature.description]
            if feature.dtype == "categorical" and feature.categories:
                feature.categories_embedding = {model_name: {} for model_name in model_names}
                for cat in feature.categories:
                    if i < len(cat_embeddings) and cat in cat_embeddings[i]:
                        for model_name in model_names:
                            if model_name in cat_embeddings[i][cat]:
                                feature.categories_embedding[model_name][cat] = cat_embeddings[i][cat][model_name]

        # Normalize data column by column
        preprocessed_df = pd.DataFrame()
        for feature in features:
            col_name = feature.name
            try:
                raw = df[col_name]

                if feature.dtype == "categorical":
                    # Keep categorical features as strings
                    preprocessed_df[col_name] = raw.astype(str)
                else:
                    # For real features, ensure numeric and min-max normalize
                    vals = pd.to_numeric(raw, errors='coerce').values

                    if np.isnan(vals).all():
                        preprocessed_df[col_name] = np.zeros(len(df))
                    else:
                        # Fill NaN with median
                        median_val = np.nanmedian(vals)
                        vals = np.nan_to_num(vals, nan=median_val)

                        min_v = vals.min()
                        max_v = vals.max()
                        if max_v > min_v:
                            preprocessed_df[col_name] = (vals - min_v) / (max_v - min_v)
                        else:
                            preprocessed_df[col_name] = np.zeros_like(vals)
            except Exception as e:
                logger.warning(f"Error processing column {col_name} in dataset {dataset_id}: {e}")
                # Fill with zeros as fallback
                preprocessed_df[col_name] = np.zeros(len(df))

        # Split by shuffled row indices so no row lands in both splits.
        # A local RandomState yields the same permutation as np.random.seed(seed) +
        # np.random.shuffle, without resetting NumPy's global RNG for the rest of the program.
        indices = np.arange(len(preprocessed_df))
        np.random.RandomState(seed).shuffle(indices)

        split_point = int(len(indices) * (1 - test_size))
        train_indices = indices[:split_point]
        test_indices = indices[split_point:]

        train_df = preprocessed_df.iloc[train_indices].reset_index(drop=True)
        test_df = preprocessed_df.iloc[test_indices].reset_index(drop=True)

        logger.info(f"Dataset {dataset_id} split:")
        logger.info(f"  Original data shape: {preprocessed_df.shape}")
        logger.info(f"  Train shape: {train_df.shape}")
        logger.info(f"  Test shape: {test_df.shape}")

        # Duplicate rows in the source data can still put identical values in both splits
        train_values = set(map(tuple, train_df.values))
        test_values = set(map(tuple, test_df.values))
        overlap = train_values & test_values
        if overlap:
            logger.warning(f"Still have {len(overlap)} overlapping examples after index split")
            # Remove test rows whose values also occur in train
            keep = np.array([tuple(row) not in train_values for row in test_df.values], dtype=bool)
            test_df = test_df[keep]
            logger.info(f"  Test shape after removing overlaps: {test_df.shape}")

        train_rows = train_df.values.tolist()
        test_rows = test_df.values.tolist()

        return TabularData(
            description=description,
            features=features,
            train_rows=train_rows,
            test_rows=test_rows
        )

    def get_feature_embeddings(
        self,
        descriptions: List[str],
        categories_list: List[List[str]],
        model_names: List[str],
        batch_size: int = 32,
    ) -> Tuple[Dict[str, Dict[str, torch.Tensor]], List[Dict[str, Dict[str, torch.Tensor]]]]:
        """
        Generate embeddings for feature descriptions and categories.

        Texts are embedded with ``get_embedding_model(self.embedding_model)``,
        calling the model on at most ``batch_size`` strings at a time. The same
        embedding object is stored under every name in ``model_names``. The
        model is not loaded at all when there is nothing to embed.

        Args:
            descriptions: Feature descriptions (used as dict keys).
            categories_list: One list of category strings per feature.
            model_names: Names under which each embedding is stored (compatibility).
            batch_size: Maximum number of strings per model call (must be >= 1).

        Returns:
            Tuple ``(desc_embeddings, cat_embeddings)``:
              - ``desc_embeddings[description][model_name]`` -> embedding
              - ``cat_embeddings[feature_idx][category][model_name]`` -> embedding

        Raises:
            ValueError: if ``batch_size`` < 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        desc_embeddings = {desc: {} for desc in descriptions}
        cat_embeddings = [{cat: {} for cat in cats} for cats in categories_list]

        # Flatten categories into (feature_idx, category) pairs so they batch together
        category_mapping = [
            (feature_idx, cat)
            for feature_idx, cats in enumerate(categories_list)
            for cat in cats
        ]

        if not descriptions and not category_mapping:
            # Nothing to embed: skip loading the embedding model
            return desc_embeddings, cat_embeddings

        embedding_model = get_embedding_model(self.embedding_model)

        def _embed_in_batches(texts):
            # Yield (position in texts, embedding), calling the model batch_size texts at a time
            for start in range(0, len(texts), batch_size):
                batch = texts[start:start + batch_size]
                batch_embeddings = embedding_model(batch)
                for offset in range(len(batch)):
                    yield start + offset, batch_embeddings[offset]

        for pos, embedding in _embed_in_batches(descriptions):
            desc = descriptions[pos]
            for model_name in model_names:
                desc_embeddings[desc][model_name] = embedding

        all_categories = [cat for _, cat in category_mapping]
        for pos, embedding in _embed_in_batches(all_categories):
            feature_idx, cat = category_mapping[pos]
            for model_name in model_names:
                cat_embeddings[feature_idx][cat][model_name] = embedding

        return desc_embeddings, cat_embeddings