import numpy as np
from numpy.random import default_rng

import argparse

import os

# Module-level experiment settings. Only the uncommented `datasets` list is
# live; the other selections are kept commented out so they can be swapped in.
# NOTE(review): none of these names is read by main() or by the CLI block in
# this file (main() takes `noises` as a parameter) -- confirm whether anything
# outside this file imports them.
#
# Alternative dataset selections:
# datasets = ['banknote', 'bcc', 'column', 'iris', 'planning_relax', 'student', 'survival', 'wine', 'somerville', 'wine4']
# datasets = ['digits4', 'penguin', 'seeds']
# datasets = ['digits2', 'digits3', 'penguin2', 'penguin3']
# datasets = ['digits4', 'iris', 'penguin', 'student', 'wine']
# datasets = ['digits4_pca4', 'penguin_pca4', 'wine_pca4']
# datasets = ['monks1', 'monks2', 'monks3']
datasets = [
    'seeds_pca4',
]

# Noise rates; same values as the CLI default for --noises.
# Alternative grid: noises = [0.02, 0.05, 0.1, 0.15, 0.2, 0.25]
noises = [
    0.02, 0.04, 0.06,
    0.08, 0.1, 0.15,
]

# Same value as the CLI default for --num_draws.
num_trials = 5

def main(data_dir, noises, num_draws):
    """Write label-noised copies of a dataset, one per noise rate and draw.

    Loads ``X_data.npy`` and ``y_data.npy`` from ``data_dir``. For every rate
    ``p`` in ``noises`` and every draw ``t`` in ``range(num_draws)``, it picks
    ``int(n_samples * p)`` distinct sample indices (without replacement),
    adds 1 to the labels at those indices, reduces all labels modulo 2, and
    saves the untouched features together with the new labels under
    ``{data_dir}/noise_{p}/trial_{t}/``.

    Args:
        data_dir: Directory containing the input arrays; output
            sub-directories are created inside it.
        noises: Iterable of noise rates (fraction of samples to alter).
        num_draws: Number of draws generated for each rate.

    Returns:
        None. Results are written to disk, and the shape of each drawn index
        array is printed to stdout.
    """
    features = np.load(f'{data_dir}/X_data.npy')
    labels = np.load(f'{data_dir}/y_data.npy')

    for p in noises:
        for t in range(num_draws):
            # The seed depends only on (p, t), so re-running the script
            # reproduces exactly the same draws.
            generator = default_rng(int((p * 1432) * (t+1.4) ** 2))

            n_samples = labels.shape[0]
            flipped = generator.choice(
                range(n_samples),
                size=int(n_samples * p),
                replace=False,
            )

            # 1 at the selected positions, 0 elsewhere. np.zeros defaults to
            # float64, so the saved labels are float64 as well.
            flip_mask = np.zeros(labels.shape)
            flip_mask[flipped] = 1

            print(flipped.shape)

            # Adding the mask and wrapping modulo 2 swaps 0 <-> 1 at the
            # selected positions.
            noisy_labels = np.mod(labels + flip_mask, 2)

            os.makedirs(f'{data_dir}/noise_{p}/trial_{t}', exist_ok=True)

            # The features are stored unchanged next to each noisy label set.
            np.save(f'{data_dir}/noise_{p}/trial_{t}/X_data', features)
            np.save(f'{data_dir}/noise_{p}/trial_{t}/y_data', noisy_labels)

if __name__ == '__main__':
    # Command-line entry point: read the dataset directory, the noise rates
    # and the number of draws per rate, then hand them to main().
    cli = argparse.ArgumentParser(
        description='Creates multiple draws of noise for a given dataset',
    )
    cli.add_argument('-d', '--data_dir', default='./datasets/wine_pca')
    cli.add_argument(
        '-n', '--noises',
        nargs='*',
        type=float,
        default=[0.02, 0.04, 0.06, 0.08, 0.1, 0.15],
    )
    cli.add_argument('-nd', '--num_draws', default=5, type=int)

    options = cli.parse_args()

    # Positional order matches main(data_dir, noises, num_draws).
    main(options.data_dir, options.noises, options.num_draws)