################################################################################
# Script to split an input data file (e.g., one obtained with simulate_data.py
# or transform_data.py) into reference and query datasets with 50% of the
# samples each, select the fraction of query features to be manipulated, and
# transform them according to a specified type of manipulation.
################################################################################

import sys
import os
from pathlib import Path
import numpy as np
import pandas as pd
import random
import torch.nn as nn
import torch

from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

from numpy import ndarray
from pandas import DataFrame

sys.path.append('../')

from src.preprocessing._transform_data import (
    reference_query_split, 
    indexes_to_manipulate,
    MLP,
    manipulate_features,
    impute_features,
    compute_rmse_of_manipulation
)

def _parse_bool_flag(text):
    """Interpret a command-line token as a boolean.

    ``bool("False")`` is True because any non-empty string is truthy, so the
    token is matched against explicit spellings instead.

    Args:
        text: Raw command-line string.

    Returns:
        True for 'true'/'t'/'yes'/'y'/'1', False for 'false'/'f'/'no'/'n'/'0'
        or an empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        ValueError: If the token matches none of the accepted spellings.
    """
    token = text.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise ValueError(f"cannot interpret {text!r} as a boolean")


def main():
    """Split a dataset into reference/query halves and manipulate the query.

    Reads exactly five command-line arguments:
    ``[data_path] [output_path] [transformation] [fraction] [maxStd]``.

    * ``transformation`` is first parsed as a float; if that succeeds the
      selected query features are passed through ``manipulate_features``.
      Otherwise the string is evaluated as a Python expression and the result
      is handed to ``impute_features``.
    * ``fraction`` is parsed as a float and ``maxStd`` as a boolean; both are
      forwarded to ``indexes_to_manipulate``.

    The resulting arrays are written to ``output_path`` as
    ``<data file stem>_reference.npy`` and ``<data file stem>_query.npy``.

    Exits with status 1 when the arguments are missing, superfluous or cannot
    be parsed.
    """
    # The unpacking below needs exactly six entries, so reject both too few
    # and too many arguments up front.
    if len(sys.argv) != 6:
        print(len(sys.argv))
        print("Error - Incorrect input")
        print("Expecting python3 transform_data.py [data_path] [output_path] "
              "[transformation] [fraction] [maxStd]")
        # Non-zero status so that calling shells/pipelines see the failure.
        sys.exit(1)

    # Parse the input arguments
    _, data_path, output_path, transformation, fraction, maxStd = sys.argv
    try:
        fraction = float(fraction)
        maxStd = _parse_bool_flag(maxStd)
    except ValueError as err:
        print(f"Error - Incorrect input: {err}")
        sys.exit(1)
    try:
        transformation = float(transformation)
    except ValueError:
        # NOTE(security): eval() executes arbitrary code taken from the
        # command line. Only run this script with trusted arguments.
        transformation = eval(transformation)

    # Read data from given data_path
    # NOTE(security): allow_pickle=True can execute code embedded in the file;
    # only load trusted data files.
    data = np.load(data_path, allow_pickle=True)

    # Split the data into reference and query datasets with 50% of rows each
    reference, query = reference_query_split(data)

    # Obtain the indexes of the features to manipulate in the query dataset
    manipulated_idxs = indexes_to_manipulate(query, fraction, maxStd)

    # Apply the selected transformation to the query dataset
    if isinstance(transformation, float):
        query[:, manipulated_idxs] = manipulate_features(
            query[:, manipulated_idxs], transformation)
    else:
        query[:, manipulated_idxs] = impute_features(
            reference, query, manipulated_idxs, transformation)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(output_path, exist_ok=True)

    # Save the reference and query datasets as .npy arrays inside output_path,
    # named after the input file.
    stem = os.path.splitext(os.path.basename(data_path))[0]
    np.save(os.path.join(output_path, f'{stem}_reference.npy'), reference)
    np.save(os.path.join(output_path, f'{stem}_query.npy'), query)
    
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
