import numpy as np
from scipy.signal import square

'''
References:
  Data Augmentation for Electrocardiogram Classification with Deep Neural Network
  (https://arxiv.org/abs/2009.04398)
  RandAugment implementation for PyTorch in timm (PyTorch Image Models)
  (https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/auto_augment.py)
  ECG augmentation implementation in torch_ecg
  (https://github.com/DeepPSP/torch_ecg/tree/master/torch_ecg/augmenters)
'''

# Magnitude levels are divided by this before use, so level 10 maps to a factor of 1.0.
_LEVEL_DENOM = 10.0

# Per-op scaling constants; each is multiplied by (level / _LEVEL_DENOM).
_HPARAMS_DEFAULT = dict(
    max_proportion_const=0.3,     # max fraction of the signal touched by drop/cutout/shift/partial ops
    wave_amplitude_const=0.3,     # sine/square amplitude at level factor 1.0
    wave_frequency_const=0.5,     # scales the wave period in samples (see _wave_arg_from_level)
    white_noise_scale_const=0.3,  # upper bound of the noise std at level factor 1.0
)


def erase(signal):
    '''
    Erase: Zero out every sample of one randomly chosen lead.

    ``signal`` is indexed as ``signal[lead, :]``, with leads along axis 0.
    Returns a new array; ``signal`` itself is not modified.
    '''
    erased = signal.copy()
    victim = np.random.randint(erased.shape[0])
    erased[victim, :] = 0
    return erased

def flip(signal):
    '''
    Flip: Invert the signal's polarity (multiply every sample by -1).

    The op itself is deterministic; any randomness comes from the caller
    deciding whether to apply it. Returns a new array.
    '''
    negated = signal.copy() * -1
    return negated

def drop(signal, max_proportion):
    '''
    Drop: Randomly set a subset of sample positions to 0.

    The same positions are zeroed on every lead. The number of dropped
    positions is drawn uniformly from ``[0, max_count)``, where
    ``max_count = int(len_signal * max_proportion)`` clipped to ``len_signal``.
    When ``max_count`` is 0 (small proportion or short signal), nothing is
    dropped instead of raising. Returns a new array.
    '''
    output = signal.copy()
    len_signal = output.shape[-1]
    # Clip so a proportion > 1 cannot ask for more distinct indices than exist.
    max_count = min(int(len_signal * max_proportion), len_signal)
    # randint's upper bound is exclusive and must exceed the lower bound;
    # max(..., 1) turns the max_count == 0 case into "drop nothing".
    count = np.random.randint(0, max(max_count, 1))
    indices = np.random.choice(len_signal, count, replace=False)
    output[:, indices] = 0
    return output

def cutout(signal, max_proportion):
    '''
    Cutout: Set one randomly placed contiguous interval to 0 on every lead.

    The interval length is ``int(U(0, max_proportion) * len_signal)``, clipped
    to ``[0, len_signal]``. It may start anywhere it still fits, including
    ending exactly on the last sample. Returns a new array.
    '''
    output = signal.copy()
    len_signal = output.shape[-1]
    target_len = int(np.random.uniform(0, max_proportion) * len_signal)
    # Clip so proportions outside [0, 1] cannot produce an interval that doesn't fit.
    target_len = min(max(target_len, 0), len_signal)
    # randint's high is exclusive: +1 allows the last valid start and keeps the
    # call valid when target_len == len_signal.
    cutout_start_pt = np.random.randint(0, len_signal - target_len + 1)
    output[:, cutout_start_pt:cutout_start_pt + target_len] = 0
    return output

def shift(signal, max_proportion):
    '''
    Shift: Move the whole signal forward or backward in time by a random offset.

    The direction is drawn from ``'forward'``/``'backward'`` and the offset is
    ``int(U(0, max_proportion) * len_signal)`` samples. Samples pushed past
    either end are discarded and the vacated positions become 0. Returns a new
    array; ``signal`` is left untouched.
    '''
    shifted = signal.copy()
    n_samples = shifted.shape[-1]
    direction = np.random.choice(['forward', 'backward'])
    offset = int(np.random.uniform(0, max_proportion) * n_samples)
    keep = n_samples - offset
    if direction != 'forward':
        # Backward: pull later samples toward the start, pad the tail with zeros.
        shifted[:, :keep] = signal[:, offset:]
        shifted[:, keep:] = 0
        return shifted
    # Forward: push samples toward the end, pad the head with zeros.
    shifted[:, offset:] = signal[:, :keep]
    shifted[:, :offset] = 0
    return shifted

def entire_sine(signal, amplitude, frequency):
    '''
    Sine: Add a sine wave to the entire sample, identically on every lead.

    Despite its name, ``frequency`` is the wave period in samples: sample ``t``
    receives ``amplitude * sin(2*pi*t / frequency)``. Returns a new array.
    '''
    output = signal.copy()
    phase = np.arange(len(output[0])) * 2 * np.pi / frequency
    output += amplitude * np.sin(phase)[np.newaxis, :]
    return output

def entire_square(signal, amplitude, frequency):
    '''
    Square: Add a square pulse train to the entire sample, identically on every lead.

    ``frequency`` is the pulse period in samples: sample ``t`` receives
    ``amplitude * square(2*pi*t / frequency)``, where ``square`` is
    ``scipy.signal.square`` (values +1/-1). Returns a new array.
    '''
    output = signal.copy()
    phase = np.arange(len(output[0])) * 2 * np.pi / frequency
    output += amplitude * square(phase)[np.newaxis, :]
    return output

def partial_sine(signal, max_proportion, amplitude, frequency):
    '''
    Partial sine: Add a sine wave to one random interval only, on every lead.

    The interval length is ``int(U(0, max_proportion) * len_signal)``, clipped
    to ``[0, len_signal]``. It may start anywhere it still fits, including
    ending on the last sample. ``frequency`` is the wave period in samples
    (phase restarts at the interval start). Returns a new array.
    '''
    output = signal.copy()
    len_signal = output.shape[-1]
    target_len = int(np.random.uniform(0, max_proportion) * len_signal)
    # Clip so proportions outside [0, 1] cannot produce an interval that doesn't fit.
    target_len = min(max(target_len, 0), len_signal)
    # randint's high is exclusive: +1 allows the last valid start and keeps the
    # call valid when target_len == len_signal.
    sine_start_pt = np.random.randint(0, len_signal - target_len + 1)
    output[:, sine_start_pt:sine_start_pt + target_len] += amplitude *\
        np.expand_dims(np.sin(np.arange(target_len) * 2 * np.pi / frequency), axis=0)
    return output

def partial_square(signal, max_proportion, amplitude, frequency):
    '''
    Partial square: Add a square pulse train to one random interval only, on every lead.

    The interval length is ``int(U(0, max_proportion) * len_signal)``, clipped
    to ``[0, len_signal]``. It may start anywhere it still fits, including
    ending on the last sample. ``frequency`` is the pulse period in samples
    (phase restarts at the interval start). Returns a new array.
    '''
    output = signal.copy()
    len_signal = output.shape[-1]
    target_len = int(np.random.uniform(0, max_proportion) * len_signal)
    # Clip so proportions outside [0, 1] cannot produce an interval that doesn't fit.
    target_len = min(max(target_len, 0), len_signal)
    # randint's high is exclusive: +1 allows the last valid start and keeps the
    # call valid when target_len == len_signal.
    square_start_pt = np.random.randint(0, len_signal - target_len + 1)
    output[:, square_start_pt:square_start_pt + target_len] += amplitude *\
        np.expand_dims(square(np.arange(target_len) * 2 * np.pi / frequency), axis=0)
    return output

def partial_white_noise(signal, max_proportion, upper_scale):
    '''
    Partial white noise: Add Gaussian white noise to one random interval.

    The interval length is ``int(U(0, max_proportion) * len_signal)``, clipped
    to ``[0, len_signal]``. It may start anywhere it still fits, including
    ending on the last sample. The noise std is drawn from ``U(0, upper_scale)``.
    A single noise row of shape ``(1, target_len)`` is generated and broadcast,
    so every lead receives the same noise. Returns a new array.
    '''
    output = signal.copy()
    len_signal = output.shape[-1]
    target_len = int(np.random.uniform(0, max_proportion) * len_signal)
    # Clip so proportions outside [0, 1] cannot produce an interval that doesn't fit.
    target_len = min(max(target_len, 0), len_signal)
    # randint's high is exclusive: +1 allows the last valid start and keeps the
    # call valid when target_len == len_signal.
    noise_start_pt = np.random.randint(0, len_signal - target_len + 1)
    scale = np.random.uniform(0, upper_scale)
    output[:, noise_start_pt:noise_start_pt + target_len] += np.random.normal(0, scale, size=(1, target_len))
    return output


def _augment_proportion_from_level(level, _hparams):
    '''
    Map a magnitude level to the ``max_proportion`` argument.

    Returns a 1-tuple ``(max_proportion,)`` so the caller can star-unpack it
    into the op's positional arguments.
    '''
    fraction = level / _LEVEL_DENOM
    return (fraction * _hparams['max_proportion_const'],)

def _wave_arg_from_level(level, _hparams):
    '''
    Map a magnitude level to ``(amplitude, frequency)`` for the sine/square ops.

    ``amplitude = (level / _LEVEL_DENOM) * wave_amplitude_const``.
    ``frequency`` is consumed by the wave ops as a period in samples and equals
    ``signal_len_cut_sec * sample_rate_to * wave_frequency_const / (level / _LEVEL_DENOM)``.
    Higher levels give larger and faster waves. The two ``signal_len_cut_sec`` /
    ``sample_rate_to`` keys must be present in ``_hparams``.

    Level 0 is handled explicitly instead of dividing by zero. It returns
    ``(0.0, inf)``: an infinite period with zero amplitude, so the wave adds nothing.
    '''
    level = level / _LEVEL_DENOM
    amplitude = level * _hparams['wave_amplitude_const']
    if level == 0:
        return amplitude, np.inf
    frequency =\
        _hparams['signal_len_cut_sec'] * _hparams['sample_rate_to'] * _hparams['wave_frequency_const'] / level
    return amplitude, frequency

def _partial_wave_arg_from_level(level, _hparams):
    '''
    Map a magnitude level to ``(max_proportion, amplitude, frequency)``.

    These are the positional arguments of ``partial_sine`` / ``partial_square``.
    '''
    (proportion,) = _augment_proportion_from_level(level, _hparams)
    wave_args = _wave_arg_from_level(level, _hparams)
    amp, period = wave_args
    return proportion, amp, period

def _partial_white_noise_arg_from_level(level, _hparams):
    '''
    Map a magnitude level to ``(max_proportion, upper_scale)`` for ``partial_white_noise``.

    ``upper_scale = level / _LEVEL_DENOM * white_noise_scale_const`` bounds the noise std.
    '''
    (proportion,) = _augment_proportion_from_level(level, _hparams)
    noise_ceiling = level / _LEVEL_DENOM * _hparams['white_noise_scale_const']
    return proportion, noise_ceiling


# Op name -> function mapping (level, hparams) to the op's extra positional
# arguments. None means the op takes only the signal.
LEVEL_TO_ARG = dict(
    erase=None,
    flip=None,
    drop=_augment_proportion_from_level,
    cutout=_augment_proportion_from_level,
    shift=_augment_proportion_from_level,
    sine=_wave_arg_from_level,
    square=_wave_arg_from_level,
    partial_sine=_partial_wave_arg_from_level,
    partial_square=_partial_wave_arg_from_level,
    partial_white_noise=_partial_white_noise_arg_from_level,
)

# Op name -> augmentation function; keys match LEVEL_TO_ARG.
NAME_TO_OP = dict(
    erase=erase,
    flip=flip,
    drop=drop,
    cutout=cutout,
    shift=shift,
    sine=entire_sine,
    square=entire_square,
    partial_sine=partial_sine,
    partial_square=partial_square,
    partial_white_noise=partial_white_noise,
)


class AugmentOp:
    '''
    One named ECG augmentation, applied with probability ``prob`` at magnitude ``level``.

    ``name`` selects the op from NAME_TO_OP and its level-to-argument mapper
    from LEVEL_TO_ARG (KeyError for unknown names). ``hparams`` may override
    any subset of _HPARAMS_DEFAULT. ``sig_len_sec`` and ``sample_rate`` are
    stored as ``signal_len_cut_sec`` / ``sample_rate_to``, which the wave ops use.
    '''

    def __init__(self, name, prob=0.5, level=10, sig_len_sec=9, sample_rate=250, hparams=None):
        self.name = name
        self.aug_fn = NAME_TO_OP[name]
        self.level_fn = LEVEL_TO_ARG[name]
        self.prob = prob
        self.level = level
        # Merge over the defaults so a partial dict only overrides what it sets.
        # Building a new dict also keeps the module-level defaults from being mutated.
        self.hparams = {**_HPARAMS_DEFAULT, **(hparams or {})}
        self.hparams['signal_len_cut_sec'] = sig_len_sec
        self.hparams['sample_rate_to'] = sample_rate

    def __call__(self, signal):
        '''Return the augmented signal, or ``signal`` itself when the op is skipped.'''
        # np.random.random() is in [0, 1), so ``>=`` applies the op with
        # probability exactly ``prob`` (0 never applies, 1 always applies).
        if np.random.random() >= self.prob:
            return signal
        level_args = () if self.level_fn is None else self.level_fn(self.level, self.hparams)
        return self.aug_fn(signal, *level_args)

    def __repr__(self):
        return f'{type(self).__name__}(name={self.name}, p={self.prob}, m={self.level})'


class ECGRandAugment:
    '''
    RandAugment-style policy: on each call, apply ``num_layers`` distinct ops
    drawn at random from ``op_names``, each at magnitude ``level``.

    ``params['preproc']['signal_len_cut_sec']`` and
    ``params['preproc']['sample_rate_to']`` are forwarded to every AugmentOp.
    ``prob`` (per-op apply probability) and ``hparams`` are forwarded as well.
    Their defaults reproduce the previous hard-coded behaviour.
    Raises ValueError on call if ``num_layers`` exceeds the number of ops.
    '''

    def __init__(self, params, op_names, level, num_layers=2, prob=0.5, hparams=None):
        sig_len_sec = params['preproc']['signal_len_cut_sec']
        sample_rate = params['preproc']['sample_rate_to']
        self.ops = [
            AugmentOp(name, prob=prob, level=level, sig_len_sec=sig_len_sec,
                      sample_rate=sample_rate, hparams=hparams)
            for name in op_names
        ]
        self.num_layers = num_layers

    def __call__(self, signal):
        # Sample op indices (no replacement) rather than the op objects, so
        # numpy never has to coerce the ops into an array. The draw is the same.
        chosen = np.random.choice(len(self.ops), self.num_layers, replace=False)
        for idx in chosen:
            signal = self.ops[idx](signal)
        return signal

    def __repr__(self):
        s = self.__class__.__name__ + f'(n={self.num_layers}, ops='
        for op in self.ops:
            s += f'\n\t{op}'
        s += ')'
        return s
