import pandas as pd
import numpy as np
import openml
from sklearn.datasets import fetch_california_housing, make_regression, load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
import os
import zipfile
import wget

import torch

# Kaggle support is switched off in this module: download_kaggle_dataset raises
# RuntimeError while KAGGLE_AVAILABLE is False. `api` is only used through
# api.dataset_download_files / api.competition_download_files — presumably a
# Kaggle API client object when enabled (TODO confirm where it gets assigned).
KAGGLE_AVAILABLE = False
api = None

# UCI Dataset Definitions
# Keys read by load_dataset / download_uci_dataset:
#   url           - file or archive to download
#   file_type     - 'csv', 'xlsx' or 'zip' (load_dataset defaults to 'csv')
#   extract_dir   - extraction directory for 'zip' archives
#                   (load_dataset falls back to ./data/<name>)
#   specific_file - file to read inside an extracted archive; honoured when
#                   load_dataset reads an archive that is already extracted
#   sep, header, names, na_values - forwarded to the pandas reader
#   target_col    - column label, or an int position resolved via df.columns
UCI_DATASETS = {
    "forest_fires": {
        "url": "https://archive.ics.uci.edu/ml/machine-learning-databases/forest-fires/forestfires.csv",
        "file_type": "csv",
        "target_col": "area"
    },
    "superconductivity": {
        "url": "https://archive.ics.uci.edu/ml/machine-learning-databases/00464/superconduct.zip",
        "file_type": "zip",
        "extract_dir": "./data/superconduct",
        "target_col": "critical_temp"
    },
    # No header row: columns are named attr0..attr127 and '?' marks missing values;
    # the target is the last column.
    "communities_crime": {
        "url": "https://archive.ics.uci.edu/ml/machine-learning-databases/communities/communities.data",
        "file_type": "csv",
        "header": None,
        "names": [f"attr{i}" for i in range(128)],
        "na_values": "?",
        "target_col": "attr127"
    },
    # Semicolon-separated with no header row; the target is the column at position 8.
    "qsar_aquatic_toxicity": {
        "url": "https://archive.ics.uci.edu/static/public/505/qsar+aquatic+toxicity.zip",
        "file_type": "zip",
        "extract_dir": "./data/qsar_aquatic_toxicity",
        "sep": ";",
        "header": None,
        "target_col": 8
    },
    "energy_efficiency": {
        "url": "https://archive.ics.uci.edu/ml/machine-learning-databases/00242/ENB2012_data.xlsx",
        "file_type": "xlsx",
        "target_col": "Y1"
    },
    # 'specific_file' selects day.csv from the extracted archive.
    "bike_sharing": {
        "url": "https://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip",
        "file_type": "zip",
        "extract_dir": "./data/bike_sharing",
        "target_col": "cnt",
        "specific_file": "day.csv"
    },
    "combined_cycle_power_plant": {
        "url": "https://archive.ics.uci.edu/ml/machine-learning-databases/00294/CCPP.zip",
        "file_type": "zip",
        "extract_dir": "./data/combined_cycle/CCPP",
        "target_col": "PE"
    },
    "parkinsons_telemonitoring": {
        "url": "https://archive.ics.uci.edu/ml/machine-learning-databases/parkinsons/telemonitoring/parkinsons_updrs.data",
        "file_type": "csv",
        "target_col": "motor_UPDRS"
    },
}

# Kaggle Dataset Definitions (currently empty).
# Entries need 'dataset_name', 'path' and 'target_col'; 'is_competition' is optional.
KAGGLE_DATASETS = {}

# OpenML Dataset Definitions
# OpenML-CTR23 - A curated tabular regression benchmarking suite
# https://www.openml.org/search?type=study&study_type=task&sort=tasks_included&id=353
# name -> OpenML dataset id passed to openml.datasets.get_dataset; the target
# column is the dataset's default_target_attribute.
OPENML_DATASETS = {
    'grid_stability_regression': 44973,
    'diamond_regression': 44979,
    'miami_housing_regression': 44983,
    'solar_flare': 44966,
    'space_ga': 45402,
    'pumadyn32nh': 44981,
    'geographical_origin_of_music': 44965,
    'kin8nm': 44980,
    'Moneyball': 41021,
    'red_wine': 44972,
    'socmob': 44987,
    'white_wine': 44971,
    'ecoli70': 46618,
    'magic_irri': 46619,
    'nhanes_age': 46946,
    'cps88wages': 44984,
    'fps_benchmark': 44992,
    'kings_county': 44989,
    'online_news_popularity': 4545,
    'bank32nh': 558,
}

# name -> scikit-learn loader, called as loader(as_frame=True) by load_dataset.
SKLEARN_DATASETS = {
    'diabetes': load_diabetes,
}

# Local DIR Datasets
# name -> CSV path read with pd.read_csv; load_dataset resolves each name's
# target column and raises if that column is missing from the file.
LOCAL_DIR_DATASETS = {
    'Abalone': './data/Abalone.csv',
    'availPwr': './data/availPwr.csv',
    'bank8FM': './data/bank8FM.csv',
    'cpuSm': './data/cpuSm.csv',
    'fuelCons': './data/fuelCons.csv',
    'dAiler': './data/dAiler.csv',
    'maxTorque': './data/maxTorque.csv',
    'machineCpu': './data/machineCpu.csv',
    'servo': './data/servo.csv',
    'airfoild': './data/airfoild.csv',
    'concreteStrength': './data/concreteStrength.csv',
}

def download_uci_dataset(url, file_type='csv', extract_dir=None, sep=',', header=0, names=None, na_values=None):
    """Download and load UCI dataset."""
    base_name = os.path.basename(url)
    download_path = f"./data/{base_name}"
    
    os.makedirs('./data', exist_ok=True)
    
    print(f"Downloading UCI dataset: {url}")
    wget.download(url, download_path)
    
    if file_type == 'zip':
        os.makedirs(extract_dir, exist_ok=True)
        with zipfile.ZipFile(download_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
        os.remove(download_path)
        
        files = os.listdir(extract_dir)
        data_files = []
        
        for ext in ['.csv', '.xlsx', '.xls', '.data', '.txt']:
            for file in files:
                if file.endswith(ext):
                    if 'readme' not in file.lower():
                        data_files.append(file)
                        break
            if data_files:
                break
        
        if not data_files:
            raise FileNotFoundError(f"No data file found in {extract_dir}.")
        
        file = data_files[0]
        file_path = os.path.join(extract_dir, file)
        
        if file.endswith(('.xlsx', '.xls')):
            print(f"Loading Excel file: {file}")
            df = pd.read_excel(file_path)
        else:
            print(f"Loading text file: {file}")
            df = pd.read_csv(file_path, sep=sep, header=header, names=names, na_values=na_values)
            
    elif file_type == 'xlsx':
        df = pd.read_excel(download_path)
        os.remove(download_path)
    elif file_type == 'csv':
        df = pd.read_csv(download_path, sep=sep, header=header, names=names, na_values=na_values)
        os.remove(download_path)
    else:
        raise ValueError(f"Unsupported file type: {file_type}")
    
    print(f"UCI dataset loaded: {url}, shape: {df.shape}")
    print(f"Columns: {list(df.columns)}")
    return df

def download_kaggle_dataset(dataset_name, path, is_competition=False, unzip=True):
    """Fetch a Kaggle dataset (or competition files) into ``path``.

    Args:
        dataset_name: Kaggle dataset slug or competition name.
        path: Destination directory; created if missing.
        is_competition: Use the competition download endpoint instead of the dataset one.
        unzip: Forwarded to the Kaggle API call.

    Raises:
        RuntimeError: If ``KAGGLE_AVAILABLE`` is false.
    """
    if not KAGGLE_AVAILABLE:
        raise RuntimeError("Kaggle API is not available. Please check Kaggle API settings.")

    os.makedirs(path, exist_ok=True)
    print(f"Downloading Kaggle dataset: {dataset_name}")

    # Both endpoints share the same call signature; pick one and call it once.
    fetch = api.competition_download_files if is_competition else api.dataset_download_files
    fetch(dataset_name, path=path, unzip=unzip)

    print(f"Kaggle dataset downloaded: {dataset_name} -> {path}")

# Target column inside each scikit-learn loader's ``frame``.
_SKLEARN_TARGET_COLUMNS = {
    'california_housing': 'MedHouseVal',
    'diabetes': 'target',
}

# Target column of each local DIR-benchmark CSV. Names without an entry map to
# None, which is then reported as a missing target column.
_LOCAL_DIR_TARGET_COLUMNS = {
    'a1': 'a1',
    'a2': 'a2',
    'a3': 'a3',
    'a4': 'a4',
    'a5': 'a5',
    'a6': 'a6',
    'a7': 'a7',
    'Abalone': 'Rings',
    'acceleration': 'acceleration',
    'availPwr': 'available.power',
    'bank8FM': 'rej',
    'cpuSm': 'usr',
    'fuelCons': 'fuel.consumption.country',
    'dAiler': 'Sa',
    'boston': 'HousValue',
    'maxTorque': 'maximal.torque',
    'machineCpu': 'class',
    'servo': 'class',
    'airfoild': 'ScaledSoundPressure',
    'concreteStrength': 'ConcreteCompressiveStrength',
    'STS-B-DIR': 'score',
}


def _load_sklearn_dataset(name):
    """Load a scikit-learn dataset as one DataFrame (features + target column)."""
    data = SKLEARN_DATASETS[name](as_frame=True)
    df = data.frame
    target_col = _SKLEARN_TARGET_COLUMNS.get(name)
    if target_col is None:
        # Unlisted single-target loaders: with as_frame=True the target is a named Series.
        target_col = data.target.name
    print(f"Target column identified: {target_col}")
    return df, target_col


def _load_openml_dataset(name):
    """Fetch an OpenML dataset and return it with its default target appended as a column."""
    dataset_id = OPENML_DATASETS[name]
    dataset = openml.datasets.get_dataset(dataset_id, download_data=True, download_qualities=True, download_features_meta_data=True)
    target_col = dataset.default_target_attribute
    print(f"Target column identified: {target_col}")
    X, y, _categorical_indicator, _attribute_names = dataset.get_data(
        dataset_format='dataframe',
        target=target_col
    )
    if not (isinstance(X, pd.DataFrame) and isinstance(y, pd.Series)):
        raise TypeError("OpenML data was not loaded as DataFrame and Series.")
    if y.name is None:
        y.name = target_col
    if y.name in X.columns:
        X = X.drop(columns=[y.name])
    df = pd.concat([X, y], axis=1)
    print(f"OpenML dataset {name} loaded, shape: {df.shape}")
    return df, target_col


def _read_extracted_uci_file(params, extract_dir):
    """Read the data file of an already-extracted UCI archive into a DataFrame.

    Uses ``params['specific_file']`` when present, otherwise the alphabetically
    first non-readme file, trying .csv, .xlsx, .xls, .data, .txt in that order.
    Reader options (sep/header/names/na_values) come from ``params``.
    """
    files = os.listdir(extract_dir)
    if 'specific_file' in params:
        file = params['specific_file']
        if file not in files:
            raise FileNotFoundError(f"Specified file {file} not found in {extract_dir}.")
    else:
        ordered = sorted(files)
        file = next(
            (f for ext in ('.csv', '.xlsx', '.xls', '.data', '.txt')
             for f in ordered
             if f.endswith(ext) and 'readme' not in f.lower()),
            None,
        )
        if file is None:
            raise FileNotFoundError(f"No data file found in {extract_dir}.")

    path = os.path.join(extract_dir, file)
    if file.endswith(('.xlsx', '.xls')):
        return pd.read_excel(
            path,
            header=params.get('header', 0),
            names=params.get('names'),
            na_values=params.get('na_values')
        )
    return pd.read_csv(
        path,
        sep=params.get('sep', ','),
        header=params.get('header', 0),
        names=params.get('names'),
        na_values=params.get('na_values')
    )


def _load_uci_dataset(name):
    """Load a UCI dataset, downloading (and extracting) it when not yet on disk."""
    params = UCI_DATASETS[name]
    file_type = params.get('file_type', 'csv')
    reader_args = (
        params.get('sep', ','),
        params.get('header', 0),
        params.get('names'),
        params.get('na_values'),
    )

    if file_type == 'zip':
        extract_dir = params.get('extract_dir', f'./data/{name}')
        if not os.path.exists(extract_dir) or not os.listdir(extract_dir):
            # Download + extract only. download_uci_dataset is not told about
            # 'specific_file', so its own file pick is discarded and the frame is
            # read below exactly as on later (cached) runs.
            download_uci_dataset(params['url'], file_type, extract_dir, *reader_args)
        df = _read_extracted_uci_file(params, extract_dir)
    else:
        df = download_uci_dataset(params['url'], file_type, params.get('extract_dir'), *reader_args)

    target_col = params['target_col']
    if isinstance(target_col, int):
        # Positional target (e.g. header-less files): resolve to the real column label.
        target_col = df.columns[target_col]

    print(f"UCI dataset {name} loaded, shape: {df.shape}, target column: {target_col}")
    return df, target_col


def _load_kaggle_dataset(name):
    """Load a Kaggle dataset, downloading it into its configured path if that is empty."""
    params = KAGGLE_DATASETS[name]
    dataset_path = params['path']

    if not os.path.exists(dataset_path) or not os.listdir(dataset_path):
        download_kaggle_dataset(
            params['dataset_name'],
            dataset_path,
            params.get('is_competition', False)
        )

    csv_files = [f for f in os.listdir(dataset_path) if f.endswith('.csv')]
    if not csv_files:
        raise FileNotFoundError(f"No CSV file found in {dataset_path}.")

    # The largest CSV is taken to be the main table.
    main_csv = max(csv_files, key=lambda f: os.path.getsize(os.path.join(dataset_path, f)))
    df = pd.read_csv(os.path.join(dataset_path, main_csv))

    target_col = params['target_col']
    print(f"Kaggle dataset {name} loaded, shape: {df.shape}, target column: {target_col}")
    return df, target_col


def _load_local_dir_dataset(name):
    """Load a local DIR-benchmark CSV and validate its target column."""
    filepath = LOCAL_DIR_DATASETS[name]
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"Dataset file not found: {filepath}. Please download the DIR benchmark.")
    print(f"Loading local file: {filepath}")
    df = pd.read_csv(filepath)

    target_col = _LOCAL_DIR_TARGET_COLUMNS.get(name)
    if target_col not in df.columns:
        raise ValueError(f"Target column '{target_col}' not found in {filepath}. Please specify correctly.")
    print(f"Assumed target column: {target_col}")
    return df, target_col


def load_dataset(name, data_dir='./data'):
    """Load dataset ``name`` from whichever registry lists it.

    Registries are checked in order: SKLEARN_DATASETS, OPENML_DATASETS,
    UCI_DATASETS, KAGGLE_DATASETS, LOCAL_DIR_DATASETS.

    Args:
        name: Dataset key.
        data_dir: Currently unused; file locations come from the registry entries.

    Returns:
        tuple: ``(df, target_col)`` — the full frame (features and target) and
        the label of the target column.

    Raises:
        ValueError: If ``name`` is in no registry, or a local CSV lacks its target column.
        FileNotFoundError: If an expected data file is missing.
        TypeError: If OpenML returns data in an unexpected format.
    """
    print(f"Loading dataset: {name}")
    if name in SKLEARN_DATASETS:
        return _load_sklearn_dataset(name)
    if name in OPENML_DATASETS:
        return _load_openml_dataset(name)
    if name in UCI_DATASETS:
        return _load_uci_dataset(name)
    if name in KAGGLE_DATASETS:
        return _load_kaggle_dataset(name)
    if name in LOCAL_DIR_DATASETS:
        return _load_local_dir_dataset(name)
    raise ValueError(f"Unknown dataset name: {name}")


def preprocess_data(df, target_col):
    """Separate features from the target and build an unfitted preprocessing transformer.

    Numeric columns: median imputation, then standardization.
    Non-numeric columns: most-frequent imputation, then dense one-hot encoding
    that ignores categories unseen during fitting.

    Returns:
        tuple: ``(X, y, preprocessor)`` — feature frame, target series and an
        unfitted ``ColumnTransformer``.
    """
    print("Preprocessing data...")
    target = df[target_col]
    features = df.drop(columns=[target_col])

    numeric_cols = features.select_dtypes(include=np.number).columns.tolist()
    categorical_cols = features.select_dtypes(exclude=np.number).columns.tolist()

    print(f"Numeric features: {numeric_cols}")
    print(f"Categorical features: {categorical_cols}")

    numeric_steps = [
        ('imputer', SimpleImputer(strategy='median')),
        ('scaler', StandardScaler()),
    ]
    categorical_steps = [
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore', sparse_output=False)),
    ]

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', Pipeline(steps=numeric_steps), numeric_cols),
            ('cat', Pipeline(steps=categorical_steps), categorical_cols),
        ],
        remainder='passthrough',
    )

    print("Preprocessor created.")
    return features, target, preprocessor


def freedman_diaconis_bins(data):
    """Return the histogram bin count given by the Freedman-Diaconis rule.

    Bin width is ``2 * IQR / n**(1/3)`` and the count is ``ceil(range / width)``,
    never less than 1. When the width is zero (an IQR of 0, e.g. constant or
    heavily tied data) a single bin is returned, matching NumPy's ``bins='fd'``
    instead of dividing by zero.

    Args:
        data: Array-like of numeric values.

    Returns:
        int: Number of bins (>= 1).

    Raises:
        ValueError: If ``data`` is empty.
    """
    values = np.asarray(data, dtype=float).ravel()
    if values.size == 0:
        raise ValueError("freedman_diaconis_bins requires at least one value.")

    q75, q25 = np.percentile(values, [75, 25])
    bin_width = 2 * (q75 - q25) / (values.size ** (1 / 3))
    # `not > 0` also catches NaN widths.
    if not bin_width > 0:
        return 1

    data_range = values.max() - values.min()
    n_bins = int(np.ceil(data_range / bin_width))
    return max(n_bins, 1)


def classify_bins_by_samples(bin_counts, thres_few, thres_many):
    """Label each histogram bin by how many samples it holds.

    A bin with at most ``thres_few`` samples is 'few'; otherwise one with at
    most ``thres_many`` samples is 'medium'; anything larger is 'many'.

    Returns:
        list[str]: One label per entry of ``bin_counts``, in order.
    """
    def _label(count):
        if count <= thres_few:
            return 'few'
        if count <= thres_many:
            return 'medium'
        return 'many'

    return [_label(count) for count in bin_counts]


def create_balanced_dataset(CONFIG, X, y, shot_mapping):
    """Create a validation/test subset with equally many few/medium/many samples.

    Every shot group is subsampled (without replacement) to the size of the
    smallest group, so the row labels returned in ``balanced_shot_indices``
    always line up with the selected rows.

    Args:
        CONFIG: Dict providing ``"random_state"``; it seeds NumPy's global RNG.
        X, y: Containers indexable by an integer index array (e.g. NumPy arrays).
        shot_mapping: ``{position: shot_type}`` with 'few' / 'medium' / 'many'.

    Returns:
        tuple: ``(X_balanced, y_balanced, balanced_indices, balanced_shot_indices)``
        where ``balanced_indices`` is an integer array of selected positions
        ordered few, medium, many and ``balanced_shot_indices`` maps each row
        of the balanced set to its shot type.
    """
    few_indices = [i for i, shot in shot_mapping.items() if shot == 'few']
    med_indices = [i for i, shot in shot_mapping.items() if shot == 'medium']
    many_indices = [i for i, shot in shot_mapping.items() if shot == 'many']

    n_min_samples = min(len(few_indices), len(med_indices), len(many_indices))

    np.random.seed(CONFIG["random_state"])
    # Medium and many are drawn first and in this order so the result is the
    # same as before whenever 'few' is already the smallest group.
    med_indices_balanced = np.random.choice(med_indices, size=n_min_samples, replace=False)
    many_indices_balanced = np.random.choice(many_indices, size=n_min_samples, replace=False)
    # 'few' must be trimmed as well; otherwise the extra rows shift every label below.
    if len(few_indices) > n_min_samples:
        few_indices_balanced = np.random.choice(few_indices, size=n_min_samples, replace=False)
    else:
        few_indices_balanced = np.asarray(few_indices, dtype=int)

    # astype(int): an empty group yields a float array, which cannot index X/y.
    balanced_indices = np.concatenate(
        [few_indices_balanced, med_indices_balanced, many_indices_balanced]
    ).astype(int)

    X_balanced = X[balanced_indices]
    y_balanced = y[balanced_indices]

    labels = ['few'] * n_min_samples + ['medium'] * n_min_samples + ['many'] * n_min_samples
    balanced_shot_indices = dict(enumerate(labels))

    return X_balanced, y_balanced, balanced_indices, balanced_shot_indices


def calculate_shot_wise_mae(y_true, y_pred, shot_mapping):
    """
    Calculate MAE for each shot type (few, medium, many) and overall.

    Args:
        y_true: Actual target values (array-like, pandas Series or torch tensor).
            Indexed by position, not by label.
        y_pred: Model predictions (same kinds as ``y_true``); an ``(n, 1)``
            column is treated as ``(n,)``.
        shot_mapping: Dictionary mapping each position to its shot type

    Returns:
        dict: Keys 'few', 'medium', 'many' (MAE as float, or None when that
        shot type has no samples) and 'overall'.
    """
    def _to_numpy(values):
        if torch.is_tensor(values):
            # detach() so predictions that still carry gradients can be converted.
            values = values.detach().cpu().numpy()
        values = np.asarray(values)
        # Collapse an (n, 1) column (typical model output) to (n,) so it does not
        # broadcast against an (n,) target into an (n, n) error matrix.
        if values.ndim == 2 and values.shape[1] == 1:
            values = values[:, 0]
        return values

    y_true = _to_numpy(y_true)
    y_pred = _to_numpy(y_pred)

    mae_results = {}
    for shot_type in ('few', 'medium', 'many'):
        indices = [i for i, shot in shot_mapping.items() if shot == shot_type]
        if indices:
            mae_results[shot_type] = float(np.mean(np.abs(y_true[indices] - y_pred[indices])))
        else:
            mae_results[shot_type] = None

    mae_results['overall'] = float(np.mean(np.abs(y_true - y_pred)))

    return mae_results


def map_shot_types(y_values, bin_edges, bin_types):
    """
    Map each data point to the shot type (few/medium/many) of the bin it falls in.

    Values below the first edge are assigned to the first bin and values at or
    above the last edge to the last bin.

    Args:
        y_values: Data values (array-like; flattened)
        bin_edges: Bin edges from np.histogram
        bin_types: List of shot types for each bin (few/medium/many)

    Returns:
        dict: {index: shot_type} mapping
    """
    values = np.asarray(y_values).reshape(-1)
    # np.digitize returns 0 below the first edge and len(bin_edges) at/above the
    # last; clip both ends into valid bins (an unclipped -1 would wrap to the last bin).
    bin_indices = np.clip(np.digitize(values, bin_edges) - 1, 0, len(bin_types) - 1)
    return {i: bin_types[b] for i, b in enumerate(bin_indices.tolist())}


def split_data(X, y, test_size=0.2, validation_size=0.2, random_state=42):
    """Randomly split ``X``/``y`` into train, validation and test partitions.

    ``test_size`` and ``validation_size`` are both fractions of the full data;
    the validation fraction is rescaled against the train+validation pool
    left after the test split. Both splits use ``random_state``.

    Returns:
        tuple: ``(X_train, X_val, X_test, y_train, y_val, y_test)``.
    """
    X_pool, X_test, y_pool, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    val_share_of_pool = validation_size / (1.0 - test_size)
    X_train, X_val, y_train, y_val = train_test_split(
        X_pool, y_pool, test_size=val_share_of_pool, random_state=random_state
    )

    print(f"Data split: Train ({X_train.shape[0]}), Validation ({X_val.shape[0]}), Test ({X_test.shape[0]})")
    return X_train, X_val, X_test, y_train, y_val, y_test
    


def split_data_stratified(X, y, test_size=0.2, validation_size=0.2, n_bins=10, random_state=42):
    """
    Split data into train/validation/test sets, stratifying on quantile bins of
    the continuous target.

    The target is cut into up to ``n_bins`` quantile bins (duplicate edges
    dropped). If binning fails, a warning is printed and both splits fall back
    to unstratified sampling.

    Args:
        X (pd.DataFrame): Feature data
        y (pd.Series): Target data (other array-likes are wrapped in a Series)
        test_size (float): Proportion of test set from total data
        validation_size (float): Proportion of validation set from total data
        n_bins (int): Number of quantile bins for the target
        random_state (int): Seed for reproducibility

    Returns:
        Tuple: X_train, X_val, X_test, y_train, y_val, y_test
    """
    y = y if isinstance(y, pd.Series) else pd.Series(y)

    try:
        strata = pd.qcut(y, q=n_bins, labels=False, duplicates='drop')
    except ValueError as e:
        print(f"Warning: Error in qcut. Try reducing n_bins. Error: {e}")
        strata = None

    X_pool, X_test, y_pool, y_test = train_test_split(
        X, y,
        test_size=test_size,
        random_state=random_state,
        stratify=strata
    )

    # Re-express the validation fraction relative to the remaining pool and
    # carry the matching strata over by index label.
    val_share_of_pool = validation_size / (1.0 - test_size)
    pool_strata = None if strata is None else strata.loc[y_pool.index]

    X_train, X_val, y_train, y_val = train_test_split(
        X_pool, y_pool,
        test_size=val_share_of_pool,
        random_state=random_state,
        stratify=pool_strata
    )

    print(f"Data split complete: Total ({len(X)}), Train ({len(X_train)}), Validation ({len(X_val)}), Test ({len(X_test)})")

    return X_train, X_val, X_test, y_train, y_val, y_test


def update_shot_configs(CONFIG, val_shot_mapping, test_shot_mapping, y_val_sr, y_test_sr,
                        val_balanced_shot_indices, test_balanced_shot_indices):
    """
    Update shot-related configurations in CONFIG (in place).

    Sets CONFIG['shot_info'], CONFIG['shot_eval_info'] and CONFIG['bal_shot_info'].

    Args:
        CONFIG: Configuration dictionary
        val_shot_mapping: {position: shot_type} for validation data
        test_shot_mapping: {position: shot_type} for test data
        y_val_sr: Validation targets (pandas Series or array-like)
        y_test_sr: Test targets (pandas Series or array-like)
        val_balanced_shot_indices: Balanced validation indices, either
            {'few': [...], 'medium': [...], 'many': [...]} or a
            {row: shot_type} mapping such as create_balanced_dataset returns
        test_balanced_shot_indices: Same as above, for test data

    Raises:
        KeyError: If a balanced dict has some but not all of 'few'/'medium'/'many'.
    """
    shot_types = ('few', 'medium', 'many')

    def _group(mapping):
        # Positions of each shot type, in the mapping's iteration order.
        return {shot: [i for i, s in mapping.items() if s == shot] for shot in shot_types}

    def _balanced(bal):
        # Already grouped by shot type: use as-is (missing keys raise KeyError).
        if any(shot in bal for shot in shot_types):
            return {shot: bal[shot] for shot in shot_types}
        # Otherwise treat it as a {row: shot_type} mapping.
        return _group(bal)

    def _to_numpy(values):
        return values.to_numpy() if hasattr(values, 'to_numpy') else np.asarray(values)

    CONFIG['shot_info'] = {
        'validation': {
            'shot_mapping': val_shot_mapping,
            'indices': _group(val_shot_mapping)
        },
        'test': {
            'shot_mapping': test_shot_mapping,
            'indices': _group(test_shot_mapping)
        }
    }

    CONFIG['shot_eval_info'] = {
        'validation': {
            'original_y': _to_numpy(y_val_sr),
            'shot_indices': val_shot_mapping
        },
        'test': {
            'original_y': _to_numpy(y_test_sr),
            'shot_indices': test_shot_mapping
        }
    }

    CONFIG['bal_shot_info'] = {
        'validation': {'indices': _balanced(val_balanced_shot_indices)},
        'test': {'indices': _balanced(test_balanced_shot_indices)}
    }

def download_all_new_datasets():
    """Download every registered UCI dataset and, if the Kaggle API is available, every Kaggle dataset.

    A failure is printed and recorded per dataset; it does not stop the loop.
    NOTE(review): download_uci_dataset deletes downloaded csv/xlsx files after
    reading them, so only extracted zip archives are left on disk afterwards.

    Returns:
        list[str]: Names of datasets whose download failed (empty on full success).
    """
    failed = []

    print("=== Starting UCI dataset download ===")
    for name, params in UCI_DATASETS.items():
        file_type = params.get('file_type', 'csv')
        # Same extraction-directory default that load_dataset uses for zip archives.
        if file_type == 'zip':
            extract_dir = params.get('extract_dir', f'./data/{name}')
        else:
            extract_dir = params.get('extract_dir')
        try:
            print(f"Downloading: {name}")
            download_uci_dataset(
                params['url'],
                file_type,
                extract_dir,
                params.get('sep', ','),
                params.get('header', 0),
                params.get('names'),
                params.get('na_values')
            )
        except Exception as e:  # best-effort batch: report and continue
            print(f"Failed to download {name}: {e}")
            failed.append(name)

    if KAGGLE_AVAILABLE:
        print("\n=== Starting Kaggle dataset download ===")
        for name, params in KAGGLE_DATASETS.items():
            try:
                print(f"Downloading: {name}")
                download_kaggle_dataset(
                    params['dataset_name'],
                    params['path'],
                    params.get('is_competition', False)
                )
            except Exception as e:  # best-effort batch: report and continue
                print(f"Failed to download {name}: {e}")
                failed.append(name)
    else:
        print("\nKaggle API unavailable, skipping Kaggle datasets.")

    if failed:
        print(f"\nFinished with {len(failed)} failed download(s): {', '.join(failed)}")
    else:
        print("\nAll new datasets downloaded successfully!")
    return failed
