import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
from sklearn_extra.cluster import CLARA
from pathlib import Path
import os
import argparse
from scipy.spatial import distance


def parse_arguments():
    """Build the command-line parser for the noise-injection script.

    Options:
        --noise_portion: float, default 0.1. Fraction of samples to corrupt.
        --snr: int, default 100. Noise parameter forwarded to the noise
            function; larger values mean less noise.
        --data_path: str, required. Path to the input ``.npz`` file.

    Returns:
        An ``argparse.ArgumentParser`` (not yet parsed).
    """
    parser = argparse.ArgumentParser(
        description='Add Gaussian noise to a random subset of samples '
                    'in an .npz dataset.')
    parser.add_argument('--noise_portion', type=float, default=0.1,
                        help='fraction of samples to corrupt (default: 0.1)')
    parser.add_argument('--snr', type=int, default=100,
                        help='noise parameter; larger means less noise '
                             '(default: 100)')
    # The script unconditionally uses data_path, so fail early with a usage
    # message instead of crashing later on None.
    parser.add_argument('--data_path', type=str, required=True,
                        help="input .npz file containing 'data' and 'label'")

    return parser

parser = parse_arguments()
args = parser.parse_args()
# os.path.splitext strips only the final extension, so paths containing other
# dots (e.g. './data/set.v2.npz') no longer break the tuple unpacking.
fname, ext = os.path.splitext(args.data_path)
ext = ext.lstrip('.')  # keep the dot-less form, e.g. 'npz'

# np.load on an .npz returns an NpzFile holding an open file handle; the
# context manager closes it once both arrays have been read.
with np.load(args.data_path) as npz:
    labels = npz['label']
    data = npz['data']


from tqdm import tqdm as tqdm

def random_sampling(data, num_sample, seed=2024):
    """Return ``num_sample`` distinct random indices into ``data``.

    NumPy's *global* RNG is reseeded with ``seed`` on every call. Repeated
    calls with the same arguments therefore return the same indices, and any
    ``np.random`` draws made afterwards are reproducible as well.

    Args:
        data: any sized object; only ``len(data)`` is used.
        num_sample: number of indices to draw (without replacement).
        seed: seed passed to ``np.random.seed`` (default 2024, as before).

    Returns:
        1-D integer ndarray of length ``num_sample``.

    Raises:
        ValueError: from NumPy if ``num_sample > len(data)``.
    """
    np.random.seed(seed)
    rand_ids = np.random.choice(len(data), num_sample, replace=False)
    print(len(rand_ids))
    print('---')
    return rand_ids

def add_noise_gaussian_full_data(X_train, noise_portion = 0.1, snr = 100):
    """Return a copy of ``X_train`` with additive Gaussian noise on some rows.

    ``int(len(X_train) * noise_portion)`` row indices are chosen with
    ``random_sampling``. For each selected row ``x``, the signal power is
    ``sp = mean(x**2)``, and zero-mean normal noise with std
    ``sqrt(sp / snr)`` and the same shape as ``x`` is added. The noise power
    is therefore ``sp / snr``, i.e. ``snr`` is a linear power ratio.

    Args:
        X_train: ndarray whose first axis indexes samples. It is not modified.
        noise_portion: fraction of samples to corrupt.
        snr: target linear signal-to-noise power ratio.

    Returns:
        New ndarray of the same shape. Integer/bool inputs are promoted to
        float64 so the noisy values are not truncated.
    """
    # Debug output.
    print('xxx')
    print(X_train.shape)
    num_noisy = int(len(X_train) * noise_portion)
    print(num_noisy)
    random_ids = random_sampling(X_train, num_noisy)
    print(len(random_ids))

    # Work on a copy so the caller's array is left intact (consistent with
    # add_noise_gaussian_per_key).
    X_res = np.array(X_train, copy=True)
    if X_res.dtype.kind in 'biu':
        X_res = X_res.astype(np.float64)

    for i in tqdm(random_ids):
        x = X_res[i]
        sp = np.mean(x ** 2)  # signal power
        std_n = (sp / snr) ** 0.5  # noise std so that noise power = sp / snr
        # Match x's full shape; x.shape[0] broke broadcasting for rows with
        # more than one dimension.
        n = np.random.normal(0, std_n, x.shape)
        X_res[i] = x + n
    print(X_res.shape)
    return X_res

def add_noise_gaussian_per_key(X_train, noise_portion, snr):
    """Return a copy of ``X_train`` with multiplicative Gaussian noise on some rows.

    ``int(len(X_train) * noise_portion)`` row indices are chosen with
    ``random_sampling``. Each selected row ``x`` is replaced by ``x + n * x``
    (i.e. ``x * (1 + n)``), where ``n`` has the same shape as ``x`` and is
    i.i.d. normal with mean 0 and standard deviation ``1 / snr``. Unselected
    rows are copied unchanged.

    Note: here ``snr`` is the inverse of the *relative* noise standard
    deviation, not a signal-to-noise power ratio (contrast with
    ``add_noise_gaussian_full_data``).

    Args:
        X_train: ndarray whose first axis indexes samples. It is not modified.
        noise_portion: fraction of samples to corrupt.
        snr: noise parameter; the multiplicative noise std is ``1 / snr``.

    Returns:
        New ndarray of the same shape. Integer/bool inputs are promoted to
        float64 so the noisy values are not truncated.
    """
    # Debug output.
    print('xxx')
    print(X_train.shape)
    num_noisy = int(len(X_train) * noise_portion)
    print(num_noisy)
    random_ids = random_sampling(X_train, num_noisy)
    print(random_ids)
    print(len(random_ids))
    print(np.sort(random_ids))

    X_res = np.array(X_train, copy=True)
    if X_res.dtype.kind in 'biu':
        X_res = X_res.astype(np.float64)

    std_n = 1 / snr  # loop-invariant
    for i in tqdm(random_ids):
        x = X_train[i]
        # Match x's full shape; x.shape[0] broke broadcasting for rows with
        # more than one dimension.
        n = np.random.normal(0, std_n, x.shape)
        X_res[i] = x + n * x
    print(X_res.shape)
    return X_res

def add_noise_gaussian(X_train_all, noise_portion = 0.1, snr = 100):
    """Noise every per-key slice of ``X_train_all`` and stack the results.

    Each element obtained by iterating ``X_train_all`` is passed to
    ``add_noise_gaussian_per_key`` with the same ``noise_portion`` and
    ``snr``. The noisy slices are combined with ``np.array`` in their
    original order.

    Debug output: ``noise_portion`` first, then for each key whether the
    noisy slice is identical to the clean one, then the final shape.
    """
    print(noise_portion)

    def _noise_one(clean_slice):
        noisy_slice = add_noise_gaussian_per_key(clean_slice, noise_portion, snr)
        # True here would mean no element of this key was changed.
        print((noisy_slice == clean_slice).all())
        return noisy_slice

    stacked = np.array([_noise_one(clean_slice) for clean_slice in X_train_all])
    print(stacked.shape)
    return stacked




# Forward args.snr so the applied noise level matches the value encoded in
# the output filename (omitting it silently used the default of 100).
data_noise = add_noise_gaussian(data, args.noise_portion, args.snr)
print(data_noise.shape)

out_path = '{}_noise_{}_{}.npz'.format(fname, args.noise_portion, args.snr)
np.savez(out_path, data=data_noise, label=labels)