import os
import sys
import argparse
import numpy as np

from ca_database_api import DataHandler

sys.path.append('../')
sys.path.append('../utils/')
from utils.default_config import set_default_cpc_config, get_choice_default_config


def process_clean_samples(
        args,
):
    """Write label-noise variants of every available subject's clean database.

    Subjects are visited from id 0 upward, stopping at the first id whose
    level-0 database path (``obtain_database_dir(subject_id, 0)``) does not
    exist. For each subject the clean data are loaded once. Then, for every
    ratio in ``args.noise_ratio``, a noisy copy of the labels is drawn with
    ``__add_random_noise__`` and one compressed ``.npz`` file per level is
    written to ``obtain_database_dir(subject_id, level, random_noise=True)``.

    Args:
        args: Parsed arguments. Reads ``database_save_dir``, ``data_name``,
            ``window_time``, ``slide_time``, ``num_level`` and
            ``noise_ratio`` (an iterable of floats).
    """
    data_handler = DataHandler(
        database_save_dir=args.database_save_dir,
        data_name=args.data_name,
        exp_id=None,
        patient_list=None,
        noise_ratio=0,
        window_time=args.window_time,
        slide_time=args.slide_time,
        num_level=args.num_level,
    )

    subject_id = 0
    # noise_ratio is reset to 0 at the end of each subject, so this existence
    # check always uses the clean (noise_ratio=0) handler state.
    while os.path.exists(data_handler.obtain_database_dir(subject_id, 0)):
        data_handler.patient_list = [subject_id]
        data_pack = data_handler.get_data(clean_label=True)

        for r in args.noise_ratio:
            # Pass a copy: the noise helper may write into the array it
            # receives, and every ratio must start from the clean labels
            # instead of accumulating noise from the previous ratio.
            noise_y = __add_random_noise__(np.array(data_pack.label, copy=True), r)
            # Set once per ratio, before building the save paths (presumably
            # obtain_database_dir reads it for random_noise=True; TODO confirm).
            data_handler.noise_ratio = r

            for level in range(args.num_level):
                save_path = data_handler.obtain_database_dir(subject_id, level, random_noise=True)
                print(f'Save to path: {save_path}')

                np.savez_compressed(save_path,
                                    data=np.array(data_pack.data[level]),
                                    label=np.array(noise_y[level]),
                                    loc=np.array(data_pack.loc[level]))

        data_handler.noise_ratio = 0
        subject_id += 1


def __add_random_noise__(
        y,
        noise_r,
):
    """Return a copy of ``y`` in which a fraction of the labels is flipped.

    Exactly ``int(y.size * noise_r)`` positions are chosen uniformly at random
    without replacement. Each chosen label is moved to a *different* class by
    a random non-zero shift in class-index space, so every selected label is
    guaranteed to change and stays within the set of classes present in ``y``.

    Args:
        y: Label array. In this script it is laid out as
            ``num_level x seg_big_num x seg_small_num``, but any shape is
            accepted and the result has the same shape. Labels need not be
            ``0..num_class-1``.
        noise_r: Fraction of labels to corrupt, in ``[0, 1]``.

    Returns:
        A new array with the same shape as ``y``. The input is not modified.

    Raises:
        ValueError: If ``noise_r`` is outside ``[0, 1]``, or if labels must be
            changed but ``y`` contains fewer than two classes.
    """
    if not 0 <= noise_r <= 1:
        raise ValueError(f'noise_r must lie in [0, 1], got {noise_r}')

    # Work on a private copy. reshape(-1) of a contiguous array is a view, so
    # writing through it would otherwise corrupt the caller's clean labels.
    y = np.array(y, copy=True)
    orig_shape = y.shape
    flat = y.reshape(-1)

    # Map labels to class indices 0..num_class-1 so arbitrary label values
    # (e.g. {1, 2, 3}) are shifted within the existing class set. For labels
    # that already are 0..num_class-1 this is the identity mapping.
    classes, flat_idx = np.unique(flat, return_inverse=True)
    num_class = len(classes)
    print(f'Number of Classes: {num_class}')

    num_changes = int(len(flat) * noise_r)
    if num_changes > 0 and num_class < 2:
        raise ValueError('At least two distinct classes are required to flip labels.')

    replace_position = np.random.choice(len(flat), num_changes, replace=False)
    # A shift in 1..num_class-1 (mod num_class) always lands on another class.
    replace_shift = np.random.choice(np.arange(num_class - 1) + 1, num_changes)
    change_y = classes[(flat_idx[replace_position] + replace_shift) % num_class]

    assert (change_y == flat[replace_position]).sum() == 0
    print(f'Number of changed labels: {len(change_y)}')

    flat[replace_position] = change_y
    y = flat.reshape(orig_shape)
    assert len(np.unique(y)) == num_class

    return y


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='RandomNoise')
    parser = set_default_cpc_config(parser)

    parser.add_argument('--database_save_dir', type=str, default='/data/CL_database/',
                        help='Should give an absolute path to save the database of the subjects.')
    parser.add_argument('--data_name', type=str, default='fNIRS_2',
                        help='Should give the name of the database [fNIRS_2, Sleep].')
    parser.add_argument('--noise_ratio', nargs='*', type=float, default=None,
                        help='Fraction(s) of labels to corrupt, each in [0, 1]; one noisy '
                             'copy of the database is written per value (default: 0.2 0.4).')
    argv = sys.argv[1:]
    args_ = parser.parse_args(argv)
    # NOTE(review): exp_id is fixed to 1 before resolving the default config;
    # presumably get_choice_default_config needs it. TODO confirm.
    args_.exp_id = 1
    args_, config = get_choice_default_config(args_)

    if args_.noise_ratio is None:
        args_.noise_ratio = [.2, .4]

    # Validate up front. An out-of-range ratio would otherwise only fail inside
    # the noise generator, after files for earlier subjects/ratios were written.
    bad_ratios = [r for r in args_.noise_ratio if not 0 <= r <= 1]
    if bad_ratios:
        parser.error(f'--noise_ratio values must lie in [0, 1], got {bad_ratios}')

    process_clean_samples(args_)
