"""
ECG Signal Preprocessing Module

- Cropping
- Resampling
- Filtering
- Lead masking
- Standardization
"""
from typing import Iterable, Optional

import numpy as np
from scipy.signal import butter, resample, sosfiltfilt


# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/metrics/_classification.py#L1287
def prf_divide(numerator, denominator, zero_divide_return=0.):
    """Divide ``numerator`` by ``denominator``, substituting a fixed value on zero division.

    Args:
        numerator: Scalar or array dividend.
        denominator: Scalar or array divisor. When the comparison
            ``denominator == 0.`` yields an iterable mask, the divisor must
            provide ``.copy()`` (e.g. a NumPy array); it is never modified.
        zero_divide_return: Value placed wherever the divisor equals zero.

    Returns:
        The quotient, with ``zero_divide_return`` at every position whose
        divisor is zero.
    """
    zero_mask = denominator == 0.

    if isinstance(zero_mask, Iterable):
        # Element-wise path: divide by a patched copy so no division by zero
        # happens, then overwrite the patched positions afterwards.
        safe_denominator = denominator.copy()
        safe_denominator[zero_mask] = 1
        quotient = numerator / safe_denominator
        if np.any(zero_mask):
            quotient[zero_mask] = zero_divide_return
        return quotient

    # Scalar divisor: a single truth value decides the whole result.
    if zero_mask:
        return np.ones_like(numerator) * zero_divide_return
    return numerator / denominator


class ECGProcessing:
    """ECG preprocessing pipeline.

    Calling an instance runs, in order: cropping, resampling (only when the
    sample rate differs from the target), filtering (only when a filter is
    enabled), lead masking (only when a masking method is set) and
    standardization.

    In training mode a single randomly positioned window is cropped; in
    evaluation mode the fixed windows given by ``eval_cut_section_sec`` are
    used.
    """

    # Immutable default lead order, copied per instance so that no two
    # instances (or the signature default) ever share one mutable list.
    DEFAULT_LEADS = ('I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6')

    def __init__(
        self,
        mode: str,
        sample_rate_to: int,
        high_pass_filter_use: bool,
        high_pass_filter_cut_off: float,
        high_pass_filter_order: int,
        low_pass_filter_use: bool,
        low_pass_filter_cut_off: float,
        low_pass_filter_order: int,
        signal_len_cut_sec: float,
        eval_cut_section_sec: Iterable,
        lead: Optional[list] = None,
        lead_mask_method: Optional[str] = None,
        lead_mask_ratio: Optional[float] = None,
        **kwargs
    ):
        """Configure the pipeline.

        Args:
            mode: ``'train'`` enables training behaviour; any other value
                selects evaluation behaviour.
            sample_rate_to: Target sample rate; also passed as ``fs`` when
                designing the Butterworth filters.
            high_pass_filter_use: Whether to apply the high-pass filter.
            high_pass_filter_cut_off: High-pass cut-off, in the same units as
                ``sample_rate_to``.
            high_pass_filter_order: Order of the high-pass filter.
            low_pass_filter_use: Whether to apply the low-pass filter.
            low_pass_filter_cut_off: Low-pass cut-off, in the same units as
                ``sample_rate_to``.
            low_pass_filter_order: Order of the low-pass filter.
            signal_len_cut_sec: Crop length in seconds.
            eval_cut_section_sec: Iterable of ``(start_sec, end_sec)`` pairs
                used for evaluation crops; each must span
                ``signal_len_cut_sec``.
            lead: Lead names in channel order. Defaults to the standard
                12 leads in ``DEFAULT_LEADS``.
            lead_mask_method: ``'limb'``, ``'single'``, ``'random'`` or
                ``None`` (no masking).
            lead_mask_ratio: Masking probability used by ``lead_masking``.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If an evaluation window does not span
                ``signal_len_cut_sec``.
        """
        self.training = mode == 'train'
        self.crop_signal_time_length = signal_len_cut_sec
        # Materialize once: a one-shot iterable (e.g. a generator) would be
        # exhausted by the validation below and yield no windows afterwards.
        self.crop_signal_time_intervals = list(eval_cut_section_sec)
        # Compare with a tolerance: e.g. 0.3 - 0.1 != 0.2 in floating point.
        # Kept as an assertion so the exception type seen by callers is unchanged.
        assert all(
            np.isclose(end_sec - start_sec, signal_len_cut_sec, rtol=1e-9, atol=1e-9)
            for start_sec, end_sec in self.crop_signal_time_intervals
        ), "Please check \"signal_len_cut_sec\" and \"eval_cut_section_sec\" again."
        self.resampled_sample_rate = sample_rate_to
        self.filter_use = high_pass_filter_use or low_pass_filter_use
        self.filters = []
        if high_pass_filter_use:
            self.filters.append(butter(high_pass_filter_order, high_pass_filter_cut_off, btype='highpass',
                                       fs=sample_rate_to, output='sos'))
        if low_pass_filter_use:
            self.filters.append(butter(low_pass_filter_order, low_pass_filter_cut_off, btype='lowpass',
                                       fs=sample_rate_to, output='sos'))

        self.lead = list(self.DEFAULT_LEADS if lead is None else lead)
        # Maps each lead name to its channel index.
        self.lead_indice = {lead_name: i for i, lead_name in enumerate(self.lead)}
        self.lead_mask_method = lead_mask_method
        self.lead_mask_ratio = lead_mask_ratio

    def train(self):
        """Switch to training behaviour (random crop, training-time masking)."""
        self.training = True

    def eval(self):
        """Switch to evaluation behaviour (fixed crops, deterministic masking)."""
        self.training = False

    def get_crop_params(self, signal, sample_rate):
        """Return the ``[start, end]`` sample indices of every crop window.

        Args:
            signal: Array whose last axis is time.
            sample_rate: Sample rate of ``signal``, used to convert seconds
                to sample indices.

        Returns:
            A list of ``[start, end]`` pairs: one random window in training
            mode, otherwise one pair per configured evaluation interval.

        Raises:
            AssertionError: In training mode, if the signal is shorter than
                the crop length.
        """
        if self.training:  # random cropping
            h = signal.shape[-1]
            th = int(self.crop_signal_time_length * sample_rate)
            assert h >= th, f'{h} is shorter than {th}. set [signal_len_cut_sec] properly in yaml.'
            # Draw a scalar: int() on a size-1 array is deprecated in NumPy >= 1.25.
            # NOTE(review): truncating a draw from the half-open range [0, h - th)
            # never selects the last valid offset h - th; left unchanged here.
            i = int(np.random.uniform(low=0, high=h - th))
            return [[i, i + th]]
        return [[int(start_sec * sample_rate), int(end_sec * sample_rate)]
                for (start_sec, end_sec) in self.crop_signal_time_intervals]

    def cropping(self, signal, sample_rate):
        """Cut the crop windows out of ``signal`` and stack them on a new first axis.

        Args:
            signal: Array indexed as ``signal[:, start:end]``.
            sample_rate: Sample rate of ``signal``.

        Returns:
            A new array holding one crop per window along axis 0.
        """
        crop_points = self.get_crop_params(signal, sample_rate)
        return np.stack([signal[:, start:end] for start, end in crop_points])

    def resampling(self, signal, original_sample_rate):
        """Resample the last axis from ``original_sample_rate`` to the target rate."""
        target_length = int(signal.shape[-1] * self.resampled_sample_rate / original_sample_rate)
        return resample(signal, target_length, axis=-1)

    def filtering(self, signal):
        """Apply every configured filter in sequence with ``sosfiltfilt``."""
        filtered_signal = signal
        for sos_filter in self.filters:
            filtered_signal = sosfiltfilt(sos_filter, filtered_signal)
        return filtered_signal

    def lead_masking(self, signal):
        """Zero out leads according to ``lead_mask_method``.

        * ``'limb'``: candidates are the leads whose name starts with ``'V'``.
        * ``'single'``: candidates are all leads except ``'I'``.
        * ``'random'``: every lead is a candidate, training mode only, with
          an independent random draw per lead.

        For ``'limb'`` and ``'single'`` one draw decides all candidates
        together, and in evaluation mode the candidates are always masked.

        Args:
            signal: Array indexed as ``signal[:, lead, :]``; modified in place.

        Returns:
            Tuple of the (in-place modified) signal and an integer array of
            the masked lead indices.
        """
        num_leads = signal.shape[1]
        method = self.lead_mask_method
        if method == 'limb':  # If eval mode, lead_mask_ratio = 1
            candidates = {self.lead_indice[lead_name]
                          for lead_name in self.lead if lead_name.startswith('V')}
            lead_mask_ratio = self.lead_mask_ratio if self.training else 1
        elif method == 'single':  # If eval mode, lead_mask_ratio = 1
            candidates = {self.lead_indice[lead_name]
                          for lead_name in self.lead if lead_name != 'I'}
            lead_mask_ratio = self.lead_mask_ratio if self.training else 1
        elif method == 'random' and self.training:  # RLM only for training
            candidates = set(self.lead_indice.values())
            lead_mask_ratio = self.lead_mask_ratio
        else:
            candidates = set()
            lead_mask_ratio = 0

        redraw_per_lead = method == "random"
        p = np.random.rand()
        masked_lead_indice = []
        for lead_index in range(num_leads):
            if lead_index in candidates and p < lead_mask_ratio:
                signal[:, lead_index, :] = 0
                masked_lead_indice.append(lead_index)
            if redraw_per_lead:
                p = np.random.rand()
        # Explicit integer dtype so an empty result is still a valid index array.
        return signal, np.array(masked_lead_indice, dtype=int)

    def standardization(self, signal, masked_lead_indice=None):
        """Z-score each sample using the statistics of its unmasked leads.

        Mean and standard deviation are computed jointly over all unmasked
        leads of one sample; masked leads are left untouched. A zero standard
        deviation yields zeros (via ``prf_divide``).

        Args:
            signal: Array iterated sample by sample along axis 0, with leads
                on axis 1. Floating-point (or complex) input is standardized
                in place; any other dtype is first copied to ``float64`` so
                the z-scores are not truncated.
            masked_lead_indice: Indices of leads to exclude, or ``None`` to
                use every lead.

        Returns:
            A new stacked array of the standardized samples.
        """
        signal = np.asarray(signal)
        if not np.issubdtype(signal.dtype, np.inexact):
            # Writing fractional z-scores into an integer array would truncate them.
            signal = signal.astype(np.float64)

        num_leads = signal.shape[1]
        if masked_lead_indice is None:
            unmasked_rows = np.ones(num_leads, dtype=bool)
        else:
            unmasked_rows = ~np.isin(np.arange(num_leads), masked_lead_indice)
        has_unmasked = bool(unmasked_rows.any())

        outputs = []
        for sig_sample in signal:
            if has_unmasked:  # nothing to normalize when every lead is masked
                unmasked = sig_sample[unmasked_rows]
                m = np.mean(unmasked)
                s = np.std(unmasked)
                sig_sample[unmasked_rows] = prf_divide(unmasked - m, s)
            outputs.append(sig_sample)
        return np.stack(outputs)

    def __call__(self,
                 signal: np.ndarray,
                 original_sample_rate: int,
                 ):
        """Run the full preprocessing pipeline on one recording.

        Args:
            signal: Recording with leads on axis 0 and time on the last axis.
            original_sample_rate: Sample rate of ``signal``.

        Returns:
            The processed crops. In training mode the leading crop axis
            (length 1) is squeezed away.
        """
        samples = self.cropping(signal, original_sample_rate)
        if original_sample_rate != self.resampled_sample_rate:
            samples = self.resampling(samples, original_sample_rate)
        if self.filter_use:
            samples = self.filtering(samples)
        if self.lead_mask_method:
            samples, masked_lead_indice = self.lead_masking(samples)
        else:
            masked_lead_indice = None
        samples = self.standardization(samples, masked_lead_indice)
        if self.training:  # Size(1, channels, *rest) -> Size(channels, *rest)
            samples = samples.squeeze(0)
        return samples

    @classmethod
    def init_from_params(cls, mode, params):
        """Build an instance from a nested parameter mapping.

        Args:
            mode: Forwarded to ``__init__`` as ``mode``.
            params: Mapping with a ``'preproc'`` entry holding the
                constructor settings, including a nested ``'lead_masking'``
                mapping with ``'lead_mask_method'`` and ``'lead_mask_ratio'``.

        Returns:
            The configured instance.
        """
        preproc = params['preproc']
        kwargs = {
            'sample_rate_to': preproc['sample_rate_to'],
            'high_pass_filter_use': preproc['high_pass_filter_use'],
            'high_pass_filter_cut_off': preproc['high_pass_filter_cut_off'],
            'high_pass_filter_order': preproc['high_pass_filter_order'],
            'low_pass_filter_use': preproc['low_pass_filter_use'],
            'low_pass_filter_cut_off': preproc['low_pass_filter_cut_off'],
            'low_pass_filter_order': preproc['low_pass_filter_order'],
            'signal_len_cut_sec': preproc['signal_len_cut_sec'],
            'eval_cut_section_sec': preproc['eval_cut_section_sec'],
            'lead': preproc['lead'],
            'lead_mask_method': preproc['lead_masking']['lead_mask_method'],
            'lead_mask_ratio': preproc['lead_masking']['lead_mask_ratio'],
        }
        return cls(mode=mode, **kwargs)
