"""
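Preprocess UCI datasets and save them in a standard format.

Each dataset is fetched with ``ucimlrepo`` and converted into a ``TabularData``
object. The object holds feature metadata with text embeddings plus encoded
train/test rows. It is saved with ``torch.save`` twice in the output directory:
individually as ``<dataset_name>.pt`` (spaces replaced by underscores) and
together with the others as ``all_uci_datasets.pt``.

Saved format:
data = {
    "train": list[TabularData],
    "test": list[TabularData]
}

Note: each TabularData object is currently placed in both lists; the row-level
train/test split is stored inside each TabularData.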
"""

import os
import torch
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple
from dataclasses import dataclass, field
import random
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from ucimlrepo import fetch_ucirepo
from tqdm import tqdm
import argparse
from transformers import AutoTokenizer, AutoModel
from embedding_utils import get_embedding_model
import logging
from utils.data_utils import Feature, TabularData
logger = logging.getLogger(__name__)

# Dataset names passed to ``fetch_ucirepo(name=...)``; main() processes them
# in this order.
UCI_DATASETS = [
    "iris",
    "wine",
    "adult",
    "car evaluation",
    "abalone",
    "bank marketing",
    "heart disease",
]

def _embed_in_batches(embedding_model, texts: List[str], batch_size: int, label: str) -> list:
    """Run ``embedding_model`` over ``texts`` in chunks and return one embedding per text, in order.

    ``label`` is only used in the progress log messages.
    """
    embeddings = []
    num_batches = (len(texts) + batch_size - 1) // batch_size
    for batch_no, start in enumerate(range(0, len(texts), batch_size), start=1):
        batch = texts[start:start + batch_size]
        logger.info("Generating embeddings for %s batch %d/%d", label, batch_no, num_batches)
        batch_embeddings = embedding_model(batch)
        # Index positionally so both list outputs and (batch, ...) tensors work.
        embeddings.extend(batch_embeddings[j] for j in range(len(batch)))
    return embeddings


def get_feature_embeddings(
    descriptions: List[str],
    categories_list: List[List[str]],
    model_names: List[str],
    batch_size: int = 32,
) -> Tuple[Dict[str, Dict[str, torch.Tensor]], List[Dict[str, Dict[str, torch.Tensor]]]]:
    """
    Embed feature descriptions and category labels with a single embedding model.

    All texts are embedded with ``get_embedding_model("bert")``. Every name in
    ``model_names`` is mapped to that same embedding, for compatibility with
    consumers that index embeddings by model name. The names do NOT select
    different models.

    Args:
        descriptions: Feature description strings.
        categories_list: One list of category labels per feature (empty for
            non-categorical features), in the caller's feature order.
        model_names: Keys under which each embedding is stored.
        batch_size: Number of texts passed to the embedding model per call.

    Returns:
        Tuple ``(desc_embeddings, cat_embeddings)``, where:
        - ``desc_embeddings[description][model_name]`` is the embedding for
          that description;
        - ``cat_embeddings[feature_idx][category][model_name]`` is the
          embedding for that category label.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    logger.info("Loading embedding model 'bert'...")
    embedding_model = get_embedding_model("bert")

    # Descriptions: duplicate strings collapse onto one dict key.
    desc_embeddings = {desc: {} for desc in descriptions}
    desc_vectors = _embed_in_batches(embedding_model, descriptions, batch_size, "descriptions")
    for desc, vector in zip(descriptions, desc_vectors):
        for model_name in model_names:
            desc_embeddings[desc][model_name] = vector

    # Categories: flatten across features so batches are full, then scatter
    # the results back to (feature_idx, category).
    cat_embeddings = [{cat: {} for cat in cats} for cats in categories_list]
    flat = [(feature_idx, cat) for feature_idx, cats in enumerate(categories_list) for cat in cats]
    cat_vectors = _embed_in_batches(embedding_model, [cat for _, cat in flat], batch_size, "categories")
    for (feature_idx, cat), vector in zip(flat, cat_vectors):
        for model_name in model_names:
            cat_embeddings[feature_idx][cat][model_name] = vector

    return desc_embeddings, cat_embeddings
def _variable_info(uci_dataset) -> pd.DataFrame:
    """Collect per-variable metadata as a frame indexed by variable name.

    The result exposes ``Type`` and ``Description`` columns when the source has them.
    - Preferred source: ``uci_dataset.variables``, the name/role/type/description
      table returned by ``fetch_ucirepo``.
    - Fallback: the ``metadata.features`` / ``metadata.targets`` frames, if present.
    - An empty frame is returned when neither is available; every column then
      gets its generic description.
    """
    variables = getattr(uci_dataset, "variables", None)
    if isinstance(variables, pd.DataFrame) and "name" in variables.columns:
        return variables.set_index("name").rename(
            columns={"type": "Type", "description": "Description"}
        )
    metadata = getattr(uci_dataset, "metadata", None)
    frames = [
        frame
        for frame in (getattr(metadata, "features", None), getattr(metadata, "targets", None))
        if isinstance(frame, pd.DataFrame)
    ]
    return pd.concat(frames) if frames else pd.DataFrame()


def _column_metadata(variable_info: pd.DataFrame, name) -> Tuple[object, object]:
    """Return ``(description, declared_type)`` for ``name``, or ``(None, None)`` if unknown."""
    if name is None or name not in variable_info.index:
        return None, None
    row = variable_info.loc[name]
    if isinstance(row, pd.DataFrame):  # duplicated variable name: use the first entry
        row = row.iloc[0]
    return row.get("Description"), row.get("Type")


def _build_feature(series: pd.Series, name: str, meta_description, meta_type, default_description: str):
    """Create a ``Feature`` for one column, with empty embedding placeholders.

    The column is categorical when the metadata declares it so or when its
    values are not plain numbers. Otherwise it is real-valued with a
    ``[min, max]`` value range.
    """
    if isinstance(meta_description, str) and meta_description.strip():
        feature_desc = f"{name}: {meta_description}"
    else:
        feature_desc = default_description

    declared_categorical = isinstance(meta_type, str) and meta_type.strip().lower() == "categorical"
    # Strings, booleans and datetimes cannot go through float()/min-max scaling.
    non_numeric = not pd.api.types.is_numeric_dtype(series) or pd.api.types.is_bool_dtype(series)

    if declared_categorical or non_numeric:
        # Sorted so that categories[i] is the label encoded as integer i.
        categories = sorted(series.astype(str).unique())
        return Feature(
            name=name,
            description=feature_desc,
            description_embedding={},
            dtype="categorical",
            categories=categories,
            categories_embedding={},
            value_range=[],
        )
    return Feature(
        name=name,
        description=feature_desc,
        description_embedding={},
        dtype="real",
        categories=[],
        categories_embedding={},
        value_range=[float(series.min()), float(series.max())],
    )


def preprocess_dataset(dataset_name: str, model_names: List[str], test_size: float = 0.2, seed: int = 42):
    """
    Preprocess a UCI dataset and return the TabularData object.

    Steps:
    1. Describe every feature column, plus the single target column, as a
       ``Feature`` with description and category embeddings. The target is
       appended last under the name ``"target"``.
    2. Encode the rows:
       - categorical columns become integer codes that index ``Feature.categories``;
       - real columns are min-max scaled (approximately [0, 1]).
    3. Split the encoded rows into train/test sets.

    Args:
        dataset_name: Name of the UCI dataset
        model_names: List of embedding model names to use
        test_size: Proportion of data to use for test set
        seed: Random seed for reproducibility

    Returns:
        TabularData object for the dataset

    Raises:
        ValueError: If the dataset does not have exactly one target column, or
            already has a feature column named ``"target"``.
    """
    print(f"Processing dataset: {dataset_name}")

    uci_dataset = fetch_ucirepo(name=dataset_name)
    X = uci_dataset.data.features
    y = uci_dataset.data.targets

    # Require exactly one target; otherwise the last feature column would be
    # mistaken for the target further down.
    if isinstance(y, pd.DataFrame):
        if y.shape[1] != 1:
            raise ValueError(f"{dataset_name}: expected exactly one target column, got {y.shape[1]}")
        y = y.iloc[:, 0]
    if not isinstance(y, pd.Series):
        raise ValueError(f"{dataset_name}: dataset has no target column")
    if 'target' in X.columns:
        raise ValueError(f"{dataset_name}: a feature column is already named 'target'")

    df = X.copy()
    df['target'] = y

    description = f"UCI {dataset_name.title()} dataset with {len(df)} instances and {X.shape[1]} features."

    variable_info = _variable_info(uci_dataset)
    features = []
    for col in X.columns:
        meta_desc, meta_type = _column_metadata(variable_info, col)
        features.append(_build_feature(
            X[col], col, meta_desc, meta_type,
            f"{col}: A feature in the {dataset_name} dataset",
        ))
    # The target is stored as "target", but its metadata is keyed by its original name.
    meta_desc, meta_type = _column_metadata(variable_info, y.name)
    features.append(_build_feature(
        df['target'], 'target', meta_desc, meta_type,
        f"target: The target variable in the {dataset_name} dataset",
    ))

    # Generate embeddings for all features (one entry per feature, in order).
    desc_embeddings, cat_embeddings = get_feature_embeddings(
        descriptions=[feature.description for feature in features],
        categories_list=[list(feature.categories) for feature in features],
        model_names=model_names,
    )
    for feature, feature_cat_embeddings in zip(features, cat_embeddings):
        feature.description_embedding = desc_embeddings[feature.description]
        if feature.dtype == "categorical":
            feature.categories_embedding = {
                model_name: {cat: feature_cat_embeddings[cat][model_name] for cat in feature.categories}
                for model_name in model_names
            }

    # Encode rows; df.columns and features are in the same order (X columns, then target).
    preprocessed_df = df.copy()
    for col, feature in zip(df.columns, features):
        if feature.dtype == "categorical":
            codes = {category: code for code, category in enumerate(feature.categories)}
            preprocessed_df[col] = df[col].astype(str).map(codes)
        else:
            col_min, col_max = df[col].min(), df[col].max()
            preprocessed_df[col] = (df[col] - col_min) / (col_max - col_min + 1e-8)

    train_df, test_df = train_test_split(
        preprocessed_df,
        test_size=test_size,
        random_state=seed
    )

    return TabularData(
        description=description,
        features=features,
        train_rows=train_df.values.tolist(),
        test_rows=test_df.values.tolist()
    )

def main():
    """Parse CLI arguments, preprocess the selected UCI datasets and save them.

    Output files:
    - each successfully processed dataset goes to ``<output_dir>/<name>.pt``,
      with spaces in the name replaced by underscores;
    - all of them together go to ``<output_dir>/all_uci_datasets.pt``.

    A dataset that raises is logged with its traceback and skipped, so the
    remaining datasets still run.
    """
    parser = argparse.ArgumentParser(description='Preprocess UCI datasets')
    parser.add_argument('--output_dir', type=str, default='./data',
                        help='Output directory for processed datasets')
    parser.add_argument('--test_size', type=float, default=0.2,
                        help='Proportion of data to use for test set')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed for reproducibility')
    parser.add_argument('--model_names', type=str, nargs='+',
                        default=['bert-base-uncased'],
                        help='Model names for feature embeddings')
    parser.add_argument('--datasets', type=str, nargs='+',
                        default=list(UCI_DATASETS),
                        help='UCI dataset names to process (default: all of UCI_DATASETS)')
    args = parser.parse_args()

    # By default the root logger has no handler and a WARNING level, which would
    # hide the logger.info progress messages; configure it at the entry point.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s %(levelname)s %(name)s: %(message)s')

    os.makedirs(args.output_dir, exist_ok=True)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    all_data = {"train": [], "test": []}
    failed = []

    for dataset_name in args.datasets:
        try:
            tabular_data = preprocess_dataset(
                dataset_name=dataset_name,
                model_names=args.model_names,
                test_size=args.test_size,
                seed=args.seed
            )

            # NOTE: the same TabularData goes into both lists; the row-level
            # train/test split is stored inside each TabularData.
            all_data["train"].append(tabular_data)
            all_data["test"].append(tabular_data)

            torch.save(
                {"train": [tabular_data], "test": [tabular_data]},
                os.path.join(args.output_dir, f"{dataset_name.replace(' ', '_')}.pt")
            )

            print(f"Successfully processed {dataset_name}")
        except Exception:
            # Top-level boundary: continue with the other datasets, but keep
            # the full traceback rather than just the exception message.
            logger.exception("Error processing %s", dataset_name)
            failed.append(dataset_name)

    output_path = os.path.join(args.output_dir, "all_uci_datasets.pt")
    torch.save(all_data, output_path)
    print(f"Saved all datasets to {output_path}")
    if failed:
        logger.warning("Failed to process %d dataset(s): %s", len(failed), ", ".join(failed))


if __name__ == "__main__":
    main()