'''
UC Irvine Machine Learning Repository Database (https://archive.ics.uci.edu/).

The following datasets are used in this project:


- 186: Wine Quality (https://archive.ics.uci.edu/dataset/186/wine+quality);
- 165: Concrete Compressive Strength (https://archive.ics.uci.edu/dataset/165/concrete+compressive+strength);
- 294: Combined Cycle Power Plant (https://archive.ics.uci.edu/dataset/294/combined+cycle+power+plant);
- 857: Risk Factor Prediction of Chronic Kidney Disease (https://archive.ics.uci.edu/dataset/857/risk+factor+prediction+of+chronic+kidney+disease);
- 464: Superconductivity Data (https://archive.ics.uci.edu/dataset/464/superconductivty+data);
- 291: Airfoil Self-Noise (https://archive.ics.uci.edu/dataset/291/airfoil+self+noise);

The following datasets have categorical features:

- 162: Forest Fires (https://archive.ics.uci.edu/dataset/162/forest+fires);
- 601: AI4I 2020 Predictive Maintenance Dataset (https://archive.ics.uci.edu/dataset/601/ai4i+2020+predictive+maintenance+dataset);
'''
import sys
sys.path.append('.')

import numpy as np

import torch
from typing import Tuple, List
from torch.utils.data import Dataset

from ucimlrepo import dotdict, fetch_ucirepo 


class UCIDataset(Dataset):
    '''
    The UCI Machine Learning Repository Database (https://archive.ics.uci.edu/).
    
    Parameters
    ----------
    id_UCI : int
        ID of the UCI dataset.

    xs, ys : np.ndarray
        Input and output data.

    gpu_id : int
        GPU ID. If None or a negative integer, use CPU.
        
    x_min, x_max : float | np.ndarray | None
        The range of the input data.
        
    y_min, y_max : float | np.ndarray | None
        The range of the output data.
        
    name_x, name_y : List[str] | None
        The names of the input and output data.
        
    scale_x : bool
        If True, scale the input data.
        
    Attributes
    ----------    
    num_samples : int
        Number of samples in the dataset.
        
    dim_input : int
        Dimension of the input.
        
    dim_output : int
        Dimension of the output.
    
    X : torch.Tensor (num_samples, dim_input)
        Input data.
        
    Y : torch.Tensor (num_samples, dim_output)
        Output data.
    '''
    def __init__(self, id_UCI: int, xs: np.ndarray, ys: np.ndarray, gpu_id=None, 
                    x_min=None, x_max=None, y_min=None, y_max=None,
                    name_x: List[str] | None = None, 
                    name_y: List[str] | None = None,
                    scale_x = False) -> None:
        
        self.name = 'UCIDataset'
        self.id_UCI = id_UCI
        self.gpu_id = None if (gpu_id is None or gpu_id < 0) else gpu_id
        
        #* Initialize the dataset
        self.X_cpu = xs.copy()
        self.Y_cpu = ys.copy()
        
        self.num_samples = self.X_cpu.shape[0]
        self.dim_input = self.X_cpu.shape[1]
        self.dim_output = self.Y_cpu.shape[1]
        
        self.x_min = np.min(self.X_cpu, axis=0) if x_min is None else x_min
        self.x_max = np.max(self.X_cpu, axis=0) if x_max is None else x_max
        self.y_min = np.min(self.Y_cpu, axis=0) if y_min is None else y_min
        self.y_max = np.max(self.Y_cpu, axis=0) if y_max is None else y_max
        
        if name_x is None:
            self.name_x = ['x%d'%(i) for i in range(self.dim_input)]
        else:
            self.name_x = name_x
            
        if name_y is None:
            self.name_y = ['y%d'%(i) for i in range(self.dim_output)]
        else:
            self.name_y = name_y
            
        if scale_x:
            self.X_cpu = (self.X_cpu - self.x_min) / np.clip(self.x_max - self.x_min, 1e-8, None)
        
        self.X = torch.tensor(self.X_cpu).float()
        self.Y = torch.tensor(self.Y_cpu).float()
        
        if torch.cuda.is_available() and self.gpu_id is not None:
            self.X = self.X.to(self.gpu_id)
            self.Y = self.Y.to(self.gpu_id)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx]


def get_UCI_datasets(id_UCI: int, num_total_samples=None, ratio_train_samples=1.0, 
                        seed=None, gpu_id=None, scale_x=False, print_info=False,
                        index_delete_features=None,
                        dataset=None) -> Tuple[UCIDataset, UCIDataset | None, dotdict]:
    '''
    Get the UCI datasets.

    The dataset is fetched (unless an already fetched one is given), some
    features are optionally removed, and a random subset of the samples is
    split into a training set and a testing set.
    
    Parameters
    ----------
    id_UCI : int
        ID of the UCI dataset.
        
    num_total_samples : int | None
        Total number of samples (training + testing). If None, or larger than
        the number of available samples, all available samples are used.
        
    ratio_train_samples : float
        Ratio of the training samples, clamped to [0, 1].
        
    seed : int | None
        Random seed for sampling. If not None, the global NumPy random
        state is seeded with it (`np.random.seed`).
        
    gpu_id : int
        GPU ID. If None or a negative integer, use CPU.
        
    scale_x : bool
        If True, scale the input data.
        
    print_info : bool
        If True, print the dataset information.
        
    index_delete_features : List[int] | None
        Indexes of the features to be deleted. Negative indexes count from
        the last feature. None or an empty list deletes nothing.

    dataset : dotdict | None
        An already fetched UCI dataset. If None, it is fetched with
        `fetch_ucirepo(id=id_UCI)`.
    
    Returns
    -------
    train_set : UCIDataset
        The training UCI dataset.
        
    test_set : UCIDataset | None
        The testing UCI dataset. None if `ratio_train_samples` >= 1.

    dataset : dotdict
        The UCI dataset.
    '''
    
    #* Load the UCI dataset
    if dataset is None:
        dataset = fetch_ucirepo(id=id_UCI)
    
    num_instances = dataset.metadata.num_instances  # Number of rows or samples (metadata)
    info_variable = dataset.variables               # access variable info in tabular format

    name_x = dataset.data.features.columns.tolist() # Feature names
    name_y = dataset.data.targets.columns.tolist()  # Target names

    X_raw = dataset.data.features.values            # Features
    Y_raw = dataset.data.targets.values             # Target(s)
    
    #* Delete the features
    if index_delete_features is not None and len(index_delete_features) > 0:
        print()
        print('>>> Delete some of the features: n_delete_feature= ', len(index_delete_features))
        print()
        
        # Filter the names with a NumPy mask so that negative indexes are
        # resolved in the same way as by `np.delete` on the data below.
        keep_feature = np.ones(len(name_x), dtype=bool)
        keep_feature[index_delete_features] = False
        name_x = [name for name, keep in zip(name_x, keep_feature) if keep]
        
        X_raw = np.delete(X_raw, index_delete_features, axis=1)

    #* Calculate attributes of the dataset
    
    # The metadata count may differ from the actual number of rows (both are
    # printed below). Never draw indexes beyond the actual rows, which would
    # raise an IndexError.
    # NOTE(review): when the metadata count is smaller than the actual number
    # of rows, the rows beyond it are never sampled - confirm this is intended.
    num_available = min(num_instances, X_raw.shape[0])
    
    if num_total_samples is None:
        num_total_samples = num_available
    else:
        num_total_samples = min(num_total_samples, num_available)

    # Ranges over the whole raw data, shared by the training and testing sets
    x_min = np.min(X_raw, axis=0)
    x_max = np.max(X_raw, axis=0)
    y_min = np.min(Y_raw, axis=0)
    y_max = np.max(Y_raw, axis=0)

    #* Print the dataset information
    if print_info:
        
        print()
        print(f'>>> ID {id_UCI}:         {dataset.metadata.name}')
        print(f'>>> Number of instances: {num_total_samples} out of {num_instances} ({X_raw.shape[0]})')
        print(f'>>> Number of features:  {X_raw.shape[1]}')
        print(f'>>> Number of targets:   {Y_raw.shape[1]}')
        print()
        print(info_variable)
        print()
        
    #* Split the dataset [X_raw, Y_raw] into [X_train, Y_train] and [X_test, Y_test]
    
    if seed is not None:
        np.random.seed(seed)
    
    ratio_train_samples = min(1.0, max(0.0, ratio_train_samples))
    num_train_samples = int(num_total_samples * ratio_train_samples)
    
    # The first `num_total_samples` shuffled indexes are used:
    # the head for training and the remainder for testing.
    indexes = np.random.permutation(num_available)
    train_indexes = indexes[:num_train_samples]
    test_indexes = indexes[num_train_samples:num_total_samples]
    
    X_train = X_raw[train_indexes]
    Y_train = Y_raw[train_indexes]
    X_test = X_raw[test_indexes]
    Y_test = Y_raw[test_indexes]
    
    #* Create the UCI datasets
    train_set = UCIDataset(id_UCI, X_train, Y_train, gpu_id, x_min, x_max, y_min, y_max, name_x, name_y, scale_x)
    
    if ratio_train_samples < 1.0:
        test_set = UCIDataset(id_UCI, X_test, Y_test, gpu_id, x_min, x_max, y_min, y_max, name_x, name_y, scale_x)
    else:
        test_set = None

    return train_set, test_set, dataset


if __name__ == '__main__':
    
    #* Demo: fetch every purely numerical dataset, split it 80/20 and print its summary
    
    # UCI dataset IDs (numerical features only) -> dataset names
    dict_UCI = {
        186: 'Wine Quality',
        165: 'Concrete Compressive Strength',
        294: 'Combined Cycle Power Plant',
        464: 'Superconductivity Data',
        291: 'Airfoil Self-Noise'
    }
    
    # UCI dataset IDs with categorical features (not iterated below)
    dict_UCI_has_categorical_features = {
        162: 'Forest Fires',
        601: 'AI4I 2020 Predictive Maintenance Dataset',
    }
    
    # Only the IDs are needed by the loop; the names are kept for reference
    for uci_id in dict_UCI:
        
        print('========================================')

        train_set, test_set, dataset = get_UCI_datasets(
            id_UCI=uci_id,
            ratio_train_samples=0.8,
            scale_x=True,
            print_info=True,
        )
        
        print()
        print()
