import os
from pathlib import Path
import pandas as pd
import openml
import torch
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.impute import SimpleImputer
import numpy as np
from typing import Tuple, Dict, Any, Optional
import tomlkit 

# Directory that downloaded CSV files are written to (relative to the current
# working directory).
DATA_DIR = Path("data")
# Created eagerly at import time so writers can assume the directory exists.
os.makedirs(DATA_DIR, exist_ok=True)

# Human-readable dataset name -> OpenML dataset id. The name is also used as the
# CSV file stem when the datasets are downloaded by the __main__ block below.
OPENML_DATASETS = dict(
    [
        ("German_Credit", 31),
        ("Adult_Income", 1590),
        ("compas-two-years", 42192),
        ("Bank_Marketing", 1461),
        ("Planning_Relax", 1490),
        ("EEG_Eye_State", 1471),
        ("electricity", 151),
        ("Wine_Quality_Red", 40691),
        ("Steel_Plates_Faults", 40982),
        ("MAGIC_Gamma_Telescope", 1120),
    ]
)


# Per-dataset preprocessing overrides keyed by OpenML dataset id; preprocess_data
# looks them up with DATASET_CONFIGS.get(dataset_id, {}).
#   drop_columns      - columns removed from the features before splitting.
#   ordinal_mapping   - {column: {str(value): code}}; values are stringified before
#                       lookup, unmapped values become -1, and mapped columns are
#                       not one-hot encoded.
#   categorical_index - column *names* (despite the key) forced to categorical.
#   numerical_index   - column *names* (despite the key) forced to numerical.
# "target_column" is informational: preprocess_data in this module does not read
# it (the target comes from OpenML's default_target_attribute).
DATASET_CONFIGS = {
    31: dict(  # German Credit
        target_column="class",
        drop_columns=[],
        # Only specify ordinal columns that need special mapping
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    1590: dict(  # Adult Income
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    42192: dict(  # compas-two-years
        target_column="two_year_recid",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    1461: dict(  # Bank Marketing
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    1490: dict(  # Planning Relax
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    1471: dict(  # EEG Eye State
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    151: dict(  # electricity
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    40691: dict(  # Wine Quality Red
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    40982: dict(  # Steel Plates Faults
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
    1120: dict(  # MAGIC Gamma Telescope
        target_column="class",
        drop_columns=[],
        ordinal_mapping={},
        categorical_index=[],
        numerical_index=[],
    ),
}


def download_openml_dataset(name: str, did: int, data_dir: Path = DATA_DIR) -> Path:
    """Fetch OpenML dataset *did* and save it as ``<name>.csv`` in *data_dir*.

    The features are written together with the dataset's default target
    attribute (appended as the last column) when OpenML defines one.

    Args:
        name: File stem for the output CSV.
        did: OpenML dataset id.
        data_dir: Output directory (``str`` or ``Path``); created if missing.

    Returns:
        The path of the written CSV file.
    """
    data_dir = Path(data_dir)
    # DATA_DIR is created at import time, but a caller-supplied directory may not exist yet.
    data_dir.mkdir(parents=True, exist_ok=True)

    dataset = openml.datasets.get_dataset(did)
    target = dataset.default_target_attribute
    X, y, _, _ = dataset.get_data(target=target, dataset_format="dataframe")

    df = X.copy()
    if y is not None:
        df[target] = y

    out_path = data_dir / f"{name}.csv"
    df.to_csv(out_path, index=False)
    print(f"Saved → {out_path}")
    return out_path


def load_dataset_from_openml(dataset_id: int) -> Tuple[pd.DataFrame, pd.Series, Dict[str, Any]]:
    """Download OpenML dataset *dataset_id*, separating its default target.

    Returns:
        ``(features, target, info)`` where *info* holds the dataset ``name``,
        ``target_column``, OpenML's per-feature ``categorical_indicator``,
        ``attribute_names`` and ``description``.
    """
    ds = openml.datasets.get_dataset(dataset_id)
    target_name = ds.default_target_attribute
    features, labels, cat_flags, attr_names = ds.get_data(
        target=target_name,
        dataset_format="dataframe",
    )

    info = dict(
        name=ds.name,
        target_column=target_name,
        categorical_indicator=cat_flags,
        attribute_names=attr_names,
        description=ds.description,
    )
    return features, labels, info

def _resolve_column_types(
    X: pd.DataFrame,
    config: Dict[str, Any],
    dataset_info: Optional[Dict[str, Any]],
) -> Tuple[list, list]:
    """Return ``(numerical_cols, categorical_cols)`` for *X* after applying *config*.

    Starts from OpenML's positional ``categorical_indicator``. If that is
    unavailable, it uses pandas dtypes instead: bool or non-numeric columns
    count as categorical. It then moves the columns listed in
    ``categorical_index`` / ``numerical_index`` and removes ``drop_columns``.
    Both lists follow X's column order, so the one-hot layout built from them
    is identical across interpreter runs (a ``set`` would depend on string-hash
    randomization).
    """
    feature_names = X.columns.tolist()
    indicator = (dataset_info or {}).get('categorical_indicator')
    if indicator is None:
        indicator = [
            pd.api.types.is_bool_dtype(dtype) or not pd.api.types.is_numeric_dtype(dtype)
            for dtype in X.dtypes
        ]

    categorical = {name for name, is_cat in zip(feature_names, indicator) if is_cat}
    numerical = {name for name, is_cat in zip(feature_names, indicator) if not is_cat}

    present = set(feature_names)
    for col in config.get('categorical_index', []):
        if col in present:
            numerical.discard(col)
            categorical.add(col)
    for col in config.get('numerical_index', []):
        if col in present:
            categorical.discard(col)
            numerical.add(col)

    dropped = set(config.get('drop_columns', []))
    ordered = [name for name in dict.fromkeys(feature_names) if name not in dropped]
    numerical_cols = [name for name in ordered if name in numerical]
    categorical_cols = [name for name in ordered if name in categorical]
    return numerical_cols, categorical_cols


def _apply_ordinal_mapping(frame: pd.DataFrame, ordinal_mapping: Dict[str, Any]) -> None:
    """Replace each mapped column of *frame* (in place) by its code; unmapped values -> -1."""
    for col, mapping in ordinal_mapping.items():
        if col in frame.columns:
            frame[col] = frame[col].astype(str).map(mapping).fillna(-1)


def preprocess_data(
    X: pd.DataFrame, 
    y: pd.Series, 
    dataset_id: int,
    dataset_info: Optional[Dict[str, Any]] = None,
    test_size: float = 0.3,
    val_ratio: float = 0.666,
    random_state: int = 42,
    device: Optional[str] = None,
    return_tensors: bool = True
) -> Dict[str, Any]:
    """Split, impute, encode and scale a tabular dataset for model training.

    Every fitted transform (imputers, scaler) is fitted on the training split
    only, so no validation/test statistics leak into training.

    Pipeline:
      1. Classify columns as numerical/categorical from
         ``dataset_info['categorical_indicator']`` (or pandas dtypes when it is
         absent), apply ``DATASET_CONFIGS[dataset_id]`` overrides and drop columns.
      2. Label-encode ``y`` and split into train / validation / test,
         stratified whenever there is more than one class.
      3. Impute numerical columns with the median and categorical columns with
         the most frequent value.
      4. Replace ``ordinal_mapping`` columns by their codes (unmapped -> -1).
      5. One-hot encode the remaining categorical columns. Validation and test
         are aligned to the training columns: levels unseen in training are
         dropped, and levels absent from a split are filled with 0.
      6. Standardize numerical columns.

    Args:
        X: Feature frame. It is not modified.
        y: Target labels (anything ``LabelEncoder`` accepts).
        dataset_id: OpenML id used to look up ``DATASET_CONFIGS``.
        dataset_info: Output of ``load_dataset_from_openml``. Only its
            ``categorical_indicator`` entry is used; optional.
        test_size: Fraction of rows held out from training.
        val_ratio: Fraction of the *held-out* rows that go to the **test**
            split; the remainder forms the validation split. The defaults give
            roughly 70% / 10% / 20% train / val / test.
        random_state: Seed used for both splits.
        device: Torch device for returned tensors. Defaults to CUDA when
            available, otherwise CPU.
        return_tensors: If True, features are float32 tensors and labels are
            ``(n, 1)`` long tensors. Otherwise features are DataFrames and
            labels are numpy arrays.

    Returns:
        A dict with these keys:
        - ``X_train``/``X_val``/``X_test`` and ``y_train``/``y_val``/``y_test``.
        - ``y_candidates``: the distinct encoded labels.
        - ``num_classes``.
        - ``input_dim``: the number of columns after encoding.
        - ``feature_names``: the columns before encoding (after drops).
        - ``encoded_feature_names``: the columns after encoding, matching
          ``input_dim``.
        - ``y_encoder``: the fitted LabelEncoder.
        - ``scaler``: a StandardScaler fitted on the training numerical
          columns; left unfitted if there are none.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = DATASET_CONFIGS.get(dataset_id, {})
    ordinal_mapping = config.get('ordinal_mapping', {})

    # Classify against the ORIGINAL columns: the OpenML indicator is positional.
    numerical_cols, categorical_cols = _resolve_column_types(X, config, dataset_info)

    # drop() + copy() yields a private frame, so the caller's X is never mutated.
    drop_cols = [col for col in dict.fromkeys(config.get('drop_columns', [])) if col in X.columns]
    X = X.drop(columns=drop_cols).copy()

    target_encoder = LabelEncoder()
    y_encoded = target_encoder.fit_transform(y)

    # Split BEFORE fitting imputers/scaler so held-out statistics never reach training.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y_encoded, test_size=test_size, random_state=random_state,
        stratify=y_encoded if len(np.unique(y_encoded)) > 1 else None
    )
    # val_ratio is passed as test_size: it is the share of held-out rows that becomes TEST.
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=val_ratio, random_state=random_state,
        stratify=y_temp if len(np.unique(y_temp)) > 1 else None
    )
    # The split frames are row subsets of X and are written to below; copy explicitly.
    X_train, X_val, X_test = X_train.copy(), X_val.copy(), X_test.copy()

    if numerical_cols:
        num_imputer = SimpleImputer(strategy='median')
        X_train[numerical_cols] = num_imputer.fit_transform(X_train[numerical_cols])
        X_val[numerical_cols] = num_imputer.transform(X_val[numerical_cols])
        X_test[numerical_cols] = num_imputer.transform(X_test[numerical_cols])

    if categorical_cols:
        cat_imputer = SimpleImputer(strategy='most_frequent')
        X_train[categorical_cols] = cat_imputer.fit_transform(X_train[categorical_cols])
        X_val[categorical_cols] = cat_imputer.transform(X_val[categorical_cols])
        X_test[categorical_cols] = cat_imputer.transform(X_test[categorical_cols])

    # Mapping runs after imputation, so filled-in values receive their real code instead of -1.
    for frame in (X_train, X_val, X_test):
        _apply_ordinal_mapping(frame, ordinal_mapping)

    cat_enc_cols = [col for col in categorical_cols if col not in ordinal_mapping]
    if cat_enc_cols:
        X_train = pd.get_dummies(X_train, columns=cat_enc_cols, prefix=cat_enc_cols, dtype=int)
        encoded_columns = X_train.columns
        # reindex adds training levels missing from a split (as 0) and drops unseen ones.
        X_val = pd.get_dummies(
            X_val, columns=cat_enc_cols, prefix=cat_enc_cols, dtype=int
        ).reindex(columns=encoded_columns, fill_value=0)
        X_test = pd.get_dummies(
            X_test, columns=cat_enc_cols, prefix=cat_enc_cols, dtype=int
        ).reindex(columns=encoded_columns, fill_value=0)

    scaler = StandardScaler()
    if numerical_cols:
        X_train[numerical_cols] = scaler.fit_transform(X_train[numerical_cols])
        X_val[numerical_cols] = scaler.transform(X_val[numerical_cols])
        X_test[numerical_cols] = scaler.transform(X_test[numerical_cols])

    if return_tensors:
        X_train_data = torch.tensor(X_train.values, dtype=torch.float32, device=device)
        X_val_data = torch.tensor(X_val.values, dtype=torch.float32, device=device)
        X_test_data = torch.tensor(X_test.values, dtype=torch.float32, device=device)

        y_train_data = torch.tensor(y_train, dtype=torch.long, device=device).unsqueeze(1)
        y_val_data = torch.tensor(y_val, dtype=torch.long, device=device).unsqueeze(1)
        y_test_data = torch.tensor(y_test, dtype=torch.long, device=device).unsqueeze(1)

        y_candidates = torch.unique(torch.tensor(y_encoded, device=device)).view(-1)
        input_dim = X_train_data.shape[1]
    else:
        X_train_data, X_val_data, X_test_data = X_train, X_val, X_test
        y_train_data, y_val_data, y_test_data = y_train, y_val, y_test

        y_candidates = np.unique(y_encoded)
        input_dim = X_train.shape[1]

    return {
        'X_train': X_train_data,
        'X_val': X_val_data,
        'X_test': X_test_data,
        'y_train': y_train_data,
        'y_val': y_val_data,
        'y_test': y_test_data,
        'y_candidates': y_candidates,
        'num_classes': len(y_candidates),
        'input_dim': input_dim,
        'feature_names': X.columns.tolist(),
        # Column names of the encoded matrices (one-hot expanded); len == input_dim.
        'encoded_feature_names': X_train.columns.tolist(),
        'y_encoder': target_encoder,
        'scaler': scaler
    }


def load_and_preprocess_dataset(
    dataset_id: int,
    config_path: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    device: Optional[str] = None,
    return_tensors: bool = True
) -> Dict[str, Any]:
    """Download OpenML dataset *dataset_id* and return its preprocessed splits.

    Split settings are read from the ``preprocessing`` table of the
    configuration: ``test_size`` (default 0.3), ``val_ratio`` (default 0.666)
    and ``random_state`` (default 42). See ``preprocess_data`` for their meaning.

    Args:
        dataset_id: OpenML dataset id.
        config_path: Path to a TOML file; used only when *config* is None.
        config: Already-parsed configuration mapping; takes precedence over
            *config_path*.
        device: Forwarded to ``preprocess_data``.
        return_tensors: Forwarded to ``preprocess_data``.

    Returns:
        The ``preprocess_data`` result plus ``'dataset_info'`` as returned by
        ``load_dataset_from_openml``.

    Warns:
        UserWarning: if *config_path* is given but does not exist. The default
        settings are used in that case.
    """
    if config is not None:
        config_data = config
    elif config_path and os.path.exists(config_path):
        # TOML is UTF-8 by specification; do not depend on the platform locale encoding.
        with open(config_path, 'r', encoding='utf-8') as f:
            config_data = tomlkit.load(f)
    else:
        if config_path:
            # Keep the defaults fallback, but make a mistyped path visible.
            import warnings
            warnings.warn(
                f"Config file {config_path!r} not found; using default preprocessing settings.",
                stacklevel=2,
            )
        config_data = {}

    data_config = config_data.get('preprocessing', {})
    test_size = data_config.get('test_size', 0.3)
    val_ratio = data_config.get('val_ratio', 0.666)
    random_state = data_config.get('random_state', 42)

    X, y, dataset_info = load_dataset_from_openml(dataset_id)

    processed_data = preprocess_data(
        X, y, dataset_id,
        dataset_info=dataset_info,
        test_size=test_size,
        val_ratio=val_ratio,
        random_state=random_state,
        device=device,
        return_tensors=return_tensors
    )

    processed_data['dataset_info'] = dataset_info
    return processed_data


if __name__ == "__main__":
    # Download every registered dataset into DATA_DIR, one CSV file per entry.
    for dataset_name in OPENML_DATASETS:
        download_openml_dataset(dataset_name, OPENML_DATASETS[dataset_name])

    print("\nAll datasets downloaded successfully. Files are in ./data")