################################################################################
# Script to fetch one or multiple datasets and save the data in .npy format.
# Applies preprocessing to the data.
################################################################################

import sys
import numpy as np
import shutil
import pandas as pd
from pathlib import Path
import requests
import urllib.request
import zipfile
import unlzw3
from io import StringIO

from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PowerTransformer, FunctionTransformer

# Location of the original Covid CSV file; fetch_covid_data passes this value
# to pd.read_csv. It is empty by default, so it has to be set to a local file
# before covid.npy can be built.
path_to_original_covid_data = ''


def transform_data(X, do_diff=True, do_power_transform=True):
    """
    Transform data with np.diff and/or a power transform.

    The enabled steps are chained in a scikit-learn ``Pipeline`` in the order
    first difference -> power transform, and the pipeline is fitted on ``X``
    itself.

    Parameters
    ----------
    X : array-like
        Data to transform. Differences are taken along axis 0.
    do_diff : bool, default=True
        If True, apply np.diff along axis 0.
    do_power_transform : bool, default=True
        If True, apply a standardized Yeo-Johnson power transform.

    Returns
    -------
    array-like
        The transformed data. When both steps are disabled, ``X`` itself is
        returned (same object, not a copy).
    """
    def diff(X, y=None):
        return np.diff(X, axis=0)

    # Only build the transformers that are actually requested
    steps = []
    if do_diff:
        steps.append(('first_difference',
                      FunctionTransformer(func=diff, validate=True)))
    if do_power_transform:
        steps.append(('power_transform',
                      PowerTransformer(method='yeo-johnson', standardize=True)))

    # Nothing to do: hand the input back untouched
    if not steps:
        return X

    return Pipeline(steps).fit_transform(X)

def fetch_gas_data(data_path):
    """
    Fetch Gas data and apply transformations.

    Downloads the HT Sensor archive from UCI into a temporary folder, keeps
    the rows whose first column equals 0, drops the leading id, time and gas
    columns, shuffles the rows with a fixed seed and saves the result as
    ``gas.npy``. Nothing is downloaded when that file already exists.

    Parameters
    ----------
    data_path : pathlib.Path
        Path to folder to save data. Must support the ``/`` operator.
    """
    # Set the random seed for reproducibility
    np.random.seed(0)

    # Define name of output file in .npy format
    final_filename = data_path/'gas.npy'
    if final_filename.exists():
        print(f'{final_filename} file already exsits. Will not '\
              'overwrite. Terminating.')
        return

    # Create folder with temporary data if it does not already exist
    # All data inside this folder will be removed automatically
    temp_path = data_path/'temp-data'
    temp_path.mkdir(exist_ok=True)
    try:
        # Download Gas data from UCI
        uci_gas_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/'\
                      '00362/HT_Sensor_UCIsubmission.zip'
        print('Downloading Gas Dataset')
        urllib.request.urlretrieve(uci_gas_url, temp_path/'gas.zip')

        # Extract data (the outer archive contains a second, nested archive)
        with zipfile.ZipFile(temp_path/'gas.zip') as zipped_file:
            zipped_file.extractall(temp_path/'gas-unzipped')
        with zipfile.ZipFile(temp_path/'gas-unzipped'/'HT_Sensor_dataset.zip') as zipped_file:
            zipped_file.extractall(temp_path)

        # Read Gas data and preprocess
        do_diff = False
        do_power_transform = False
        dataset = np.loadtxt(temp_path/'HT_Sensor_dataset.dat', skiprows=1)
        # Get only tests with id = 0 and then drop the id
        dataset = dataset[dataset[:, 0] == 0.][:, 1:]
        # The first 3 columns are time, gas1 pcc, gas2 pcc so irrelevant
        dataset = dataset[:, 3:]
        # Permute rows to remove time dependencies
        dataset = np.random.permutation(dataset)
        # Transform data
        dataset = transform_data(dataset, do_diff=do_diff,
                                 do_power_transform=do_power_transform)
        # Save in .npy format
        np.save(final_filename, dataset)
    finally:
        # Remove the temporary folder even when a step above failed
        shutil.rmtree(temp_path)

    print(f'Gas dataset with {dataset.shape[0]} rows and {dataset.shape[1]} '\
          f'columns saved at {final_filename}')

def fetch_energy_data(data_path):
    """
    Fetch Energy data and apply transformations.

    Downloads the Energy CSV from UCI into a temporary folder, keeps columns
    1 to 26 as float64, shuffles the rows with a fixed seed, applies the
    power transform and saves the result as ``energy.npy``. Nothing is
    downloaded when that file already exists.

    Parameters
    ----------
    data_path : pathlib.Path
        Path to folder to save data. Must support the ``/`` operator.
    """
    # Set the random seed for reproducibility
    np.random.seed(0)

    # Define name of output file in .npy format
    final_filename = data_path/'energy.npy'
    if final_filename.exists():
        print(f'{final_filename} file already exsits. Will not '\
              'overwrite. Terminating.')
        return

    # Create folder with temporary data if it does not already exist
    # All data inside this folder will be removed automatically
    temp_path = data_path/'temp-data'
    temp_path.mkdir(exist_ok=True)
    try:
        # Download Energy data from UCI
        uci_energy_url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/'\
                         '00374/energydata_complete.csv'
        print('Downloading Energy Dataset')
        urllib.request.urlretrieve(uci_energy_url, temp_path/'energydata.csv')

        # Read Energy data and preprocess
        do_diff = False
        do_power_transform = True

        dataset = pd.read_csv(temp_path/'energydata.csv')
        # Skip column 0 and keep the next 26 columns as floats
        dataset = dataset.to_numpy()[:, 1:27].astype(np.float64)
        # Permute rows to remove time dependencies
        dataset = np.random.permutation(dataset)
        # Transform data
        dataset = transform_data(dataset, do_diff=do_diff,
                                 do_power_transform=do_power_transform)
        # Save in .npy format
        np.save(final_filename, dataset)
    finally:
        # Remove the temporary folder even when a step above failed
        shutil.rmtree(temp_path)

    print(f'Energy dataset with {dataset.shape[0]} rows and {dataset.shape[1]} '\
          f'columns saved at {final_filename}')
        
def fetch_musk2_data(data_path):
    """
    Fetch Musk2 data and save it in .npy format.

    Downloads the LZW-compressed Musk2 file from UCI into a temporary folder,
    decompresses it, drops the first two columns and the last one, and saves
    the remainder as ``musk2.npy``. No further transformation is applied.
    Nothing is downloaded when that file already exists.

    Parameters
    ----------
    data_path : pathlib.Path
        Path to folder to save data. Must support the ``/`` operator.
    """
    # Define name of output file in .npy format
    final_filename = data_path/'musk2.npy'
    if final_filename.exists():
        print(f'{final_filename} file already exsits. Will not '\
              'overwrite. Terminating.')
        return

    # Create folder with temporary data if it does not already exist
    # All data inside this folder will be removed automatically
    temp_path = data_path/'temp-data'
    temp_path.mkdir(exist_ok=True)
    try:
        # Download Musk2 data from UCI
        uci_musk2_url = 'https://archive.ics.uci.edu//ml/machine-learning-'\
                        'databases/musk/clean2.data.Z'
        print('Downloading Musk2 Dataset')
        urllib.request.urlretrieve(uci_musk2_url, temp_path/'musk2.Z')

        # Extract Musk data and preprocess
        uncompressed_data = unlzw3.unlzw((temp_path/'musk2.Z').read_bytes())
        dataset = pd.read_csv(StringIO(uncompressed_data.decode('utf-8')), header=None)

        # Remove the first two columns and the target
        # NOTE(review): to_numpy() runs on the whole frame before slicing, so
        # the saved dtype is the common dtype of all columns - confirm this is
        # numeric and not object.
        dataset = dataset.to_numpy()[:, 2:-1]

        # Save in .npy format
        np.save(final_filename, dataset)
    finally:
        # Remove the temporary folder even when a step above failed
        shutil.rmtree(temp_path)

    print(f'Musk2 dataset with {dataset.shape[0]} rows and {dataset.shape[1]} '\
          f'columns saved at {final_filename}')
        
def fetch_scene_data(data_path):
    """
    Download the Scene dataset from OpenML and store it as ``scene.npy``.

    The last five columns of the downloaded data are dropped before saving.
    An existing ``scene.npy`` is never overwritten.

    Parameters
    ----------
    data_path : pathlib.Path
        Folder in which the output file is written.
    """
    target_file = data_path/'scene.npy'

    # Leave an existing output file alone
    if target_file.exists():
        print(f'{target_file} file already exsits. Will not '
              'overwrite. Terminating.')
        return

    # OpenML dataset 312, fetched without using the local cache
    print('Downloading Scene Dataset')
    bunch = fetch_openml(data_id=312, as_frame=True, cache=False)
    features = bunch['data'].to_numpy()

    # Keep everything except the five trailing columns
    features = features[:, :-5]

    np.save(target_file, features)

    n_rows, n_cols = features.shape
    print(f'Scene dataset with {n_rows} rows and {n_cols} '
          f'columns saved at {target_file}')
        
def fetch_mnist_data(data_path):
    """
    Download the MNIST dataset from OpenML and store it as ``mnist.npy``.

    The downloaded data is saved as-is, without any preprocessing.
    An existing ``mnist.npy`` is never overwritten.

    Parameters
    ----------
    data_path : pathlib.Path
        Folder in which the output file is written.
    """
    target_file = data_path/'mnist.npy'

    # Leave an existing output file alone
    if target_file.exists():
        print(f'{target_file} file already exsits. Will not '
              'overwrite. Terminating.')
        return

    # OpenML dataset 554, fetched without using the local cache
    print('Downloading MNIST Dataset')
    bunch = fetch_openml(data_id=554, as_frame=True, cache=False)
    frame = bunch['data']

    # The frame is handed to np.save directly
    np.save(target_file, frame)

    n_rows, n_cols = frame.shape
    print(f'MNIST dataset with {n_rows} rows and {n_cols} '
          f'columns saved at {target_file}')
        
def fetch_dilbert_data(data_path):
    """
    Download the Dilbert dataset from OpenML and store it as ``dilbert.npy``.

    The downloaded data is saved as-is, without any preprocessing.
    An existing ``dilbert.npy`` is never overwritten.

    Parameters
    ----------
    data_path : pathlib.Path
        Folder in which the output file is written.
    """
    target_file = data_path/'dilbert.npy'

    # Leave an existing output file alone
    if target_file.exists():
        print(f'{target_file} file already exsits. Will not '
              'overwrite. Terminating.')
        return

    # OpenML dataset 41163, fetched without using the local cache
    print('Downloading Dilbert Dataset')
    bunch = fetch_openml(data_id=41163, as_frame=True, cache=False)
    frame = bunch['data']

    # The frame is handed to np.save directly
    np.save(target_file, frame)

    n_rows, n_cols = frame.shape
    print(f'Dilbert dataset with {n_rows} rows and {n_cols} '
          f'columns saved at {target_file}')
        
def fetch_founders_data(data_path):
    """
    Fetch Founders data and save it in .npy format.

    Downloads an ``.npz`` archive from Figshare into a temporary folder,
    extracts the ``Founders_ch22_10KSNP`` array from it and saves that array
    as ``founders.npy``. No further transformation is applied. Nothing is
    downloaded when that file already exists.

    Parameters
    ----------
    data_path : pathlib.Path
        Path to folder to save data. Must support the ``/`` operator.

    Raises
    ------
    requests.HTTPError
        If the download responds with an HTTP error status.
    """
    # Define name of output file in .npy format
    final_filename = data_path/'founders.npy'
    if final_filename.exists():
        print(f'{final_filename} file already exsits. Will not '\
              'overwrite. Terminating.')
        return

    print('Downloading Founders Dataset')

    # Create folder with temporary data if it does not already exist
    # All data inside this folder will be removed automatically
    temp_path = data_path/'temp-data'
    temp_path.mkdir(exist_ok=True)
    try:
        # Download Founders data from Figshare and save in .npz
        url = 'https://ndownloader.figshare.com/files/35015998'
        response = requests.get(url)
        # Fail here on an HTTP error instead of writing an error page to disk
        response.raise_for_status()

        archive = temp_path/'Founders_ch22_10KSNP.npz'
        archive.write_bytes(response.content)

        # Read .npz file and access data; the context manager closes the
        # archive handle before the temporary folder is deleted.
        # NOTE(review): allow_pickle=True unpickles downloaded content, which
        # can execute arbitrary code - confirm the source is trusted.
        with np.load(archive, allow_pickle=True) as data:
            dataset = data['Founders_ch22_10KSNP']

        # Save in .npy format
        np.save(final_filename, dataset)
    finally:
        # Remove the temporary folder even when a step above failed
        shutil.rmtree(temp_path)

    print(f'Founders dataset with {dataset.shape[0]} rows and {dataset.shape[1]} '\
          f'columns saved at {final_filename}')
    
def fetch_covid_data(data_path):
    """
    Fetch Covid data and apply transformations.

    The data is read from the local CSV file named by the module-level
    ``path_to_original_covid_data``; nothing is downloaded. The selected
    columns are shuffled with a fixed seed, first-differenced,
    power-transformed and saved as ``covid.npy``. When the CSV cannot be
    loaded, a notice about the data's availability is printed and nothing is
    written.

    Parameters
    ----------
    data_path : pathlib.Path
        Path to folder to save data. Must support the ``/`` operator.
    """
    # Set the random seed for reproducibility
    np.random.seed(0)

    # Define name of output file in .npy format
    final_filename = data_path/'covid.npy'
    if final_filename.exists():
        print(f'{final_filename} file already exsits. Will not '\
              'overwrite. Terminating.')
        return

    # Read Covid data and preprocess
    do_diff = True
    do_power_transform = True
    columns = ['MI', 'PA', 'IL', 'NY', 'MA', 'FL', 'TX', 'CA', 'NJ', 'NYC']

    # Only the loading step is treated as "data not available": a missing or
    # unreadable file, missing columns, or unparsable content. Errors in the
    # later transform/save steps are real failures and are left to propagate.
    try:
        dataset = pd.read_csv(path_to_original_covid_data, index_col=0)
        dataset = dataset.loc[:, columns].to_numpy().astype(np.float64)
    except (OSError, KeyError, ValueError):
        print('Due to licencing issues, we are unable to share the COVID dataset. '\
              'Please contact the author via his email on the GitHub page for '\
              'access to the data!')
        return

    # Permute rows to remove time dependencies
    dataset = np.random.permutation(dataset)
    # Transform data
    dataset = transform_data(dataset, do_diff=do_diff,
                             do_power_transform=do_power_transform)
    # Save in .npy format
    np.save(final_filename, dataset)

    print(f'Covid dataset with {dataset.shape[0]} rows and {dataset.shape[1]} '\
          f'columns saved at {final_filename}')

def fetch_phenotypes_data():
    """
    Print a notice that the Phenotypes dataset cannot be shared.

    Nothing is fetched or written to disk.
    """
    notice = ' '.join([
        'Due to licencing issues, we are unable to share the Phenotypes dataset.',
        'Please contact the author via his email on the GitHub page for',
        'access to the data!',
    ])
    print(notice)

def fetch_embark_data():
    """
    Print a notice that the Embark dataset cannot be shared.

    Nothing is fetched or written to disk.
    """
    notice = ' '.join([
        'Due to licencing issues, we are unable to share the Embark dataset.',
        'Please contact the author via his email on the GitHub page for',
        'access to the data!',
    ])
    print(notice)

def main():
    """
    Command-line entry point: fetch one dataset, or all of them.

    Expects exactly two command-line arguments, the dataset name and the
    output folder, e.g. ``python3 fetch_data.py all ../../data/``. The output
    folder (including missing parents) is created once the dataset name has
    been validated.

    Raises
    ------
    SystemExit
        With status 1 when the number of arguments is wrong.
    ValueError
        When the dataset name is not recognized.
    """
    if len(sys.argv) != 3:
        print(f"Error - Incorrect input (got {len(sys.argv) - 1} arguments, "
              "expected 2)")
        print("Expecting python3 fetch_data.py [dataset_name] [data_path]")
        # Non-zero status so that calling scripts can detect the failure
        sys.exit(1)

    # Parse the input arguments
    # E.g. dataset_name = 'all'
    # E.g. data_path = '../../data/'
    _, dataset_name, data_path = sys.argv
    data_path = Path(data_path)

    # Dataset name -> action. Insertion order is the order used by 'all'.
    fetchers = {
        'gas': lambda: fetch_gas_data(data_path),
        'energy': lambda: fetch_energy_data(data_path),
        'musk2': lambda: fetch_musk2_data(data_path),
        'scene': lambda: fetch_scene_data(data_path),
        'mnist': lambda: fetch_mnist_data(data_path),
        'dilbert': lambda: fetch_dilbert_data(data_path),
        'founders': lambda: fetch_founders_data(data_path),
        'covid': lambda: fetch_covid_data(data_path),
        'phenotypes': lambda: fetch_phenotypes_data(),
        'embark': lambda: fetch_embark_data(),
    }

    # Validate before touching the file system
    if dataset_name != 'all' and dataset_name not in fetchers:
        valid_names = ', '.join(['all', *fetchers])
        raise ValueError(f'{dataset_name} not recognized, please pass either: '
                         f'{valid_names}.')

    data_path.mkdir(parents=True, exist_ok=True)

    # Fetch data
    if dataset_name == 'all':
        for fetch in fetchers.values():
            fetch()
    else:
        fetchers[dataset_name]()
        
# Run the command-line entry point only when executed as a script
if __name__ == '__main__':
    main()