# Copyright (C) Authors of submission, all rights reserved


from dataclasses import dataclass, field
from abc import ABC
from typing import ClassVar, List, Tuple, Type
import numpy as np


class UniVarAugment(ABC):
    """Base class for augmentations applied to a single univariate series.

    Subclasses implement ``__call__``. The base class provides a lazily created,
    optionally seeded random generator (``rng``) and a helper that returns the
    augmented sample with or without an info dict.
    """

    def set_seed(self, seed: int):
        """Replace the instance's generator with one seeded by ``seed``."""
        self._rng = np.random.default_rng(seed)

    @property
    def rng(self) -> np.random.Generator:
        """Return the instance's generator, creating an unseeded one on first use."""
        rng = getattr(self, "_rng", None)
        if rng is None:
            # Only construct a generator when none exists yet; an eager getattr
            # default would build (and discard) a new generator on every access.
            rng = np.random.default_rng()
            self._rng = rng
        return rng

    def __call__(self, sample, return_info=False, **kwargs):
        """Apply the augmentation to ``sample``; subclasses must override this.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Not implemented!")

    def return_with_opt_info(self, sample, info, return_info):
        """Return ``(sample, info)`` if ``return_info`` is truthy, else just ``sample``."""
        if return_info:
            return sample, info
        else:
            return sample


@dataclass
class CensorAugmentionConfig:
    """Configuration for ``CensorAugmention``."""

    # Probability that censoring is applied to a given sample.
    censor_prob: float = field(default=0)

@dataclass
class CensorAugmention(UniVarAugment):
    """Clamp a series on one side of a randomly drawn quantile.

    With probability ``config.censor_prob`` a quantile level is drawn uniformly
    from [0, 1). Then, with equal probability, either all values below the
    corresponding quantile value are raised to it (bottom censoring), or all
    values above it are lowered to it (top censoring).
    """

    config_class: ClassVar[Type] = CensorAugmentionConfig
    config: CensorAugmentionConfig

    def __call__(self, sample: np.ndarray, return_info=False, **kwargs) -> np.ndarray:
        """Censor ``sample`` (on a copy) and optionally return info.

        Args:
            sample: numpy array to augment; it is not modified in place.
            return_info: if truthy, return ``(sample, info)`` instead of ``sample``.

        Returns:
            The (possibly) censored sample, or ``(sample, info)``. ``info`` is empty
            when no censoring happened, otherwise it holds the drawn quantile,
            its value and the censoring side.
        """
        # Empty input has no quantile; skip it (the draw above still consumes rng).
        if self.rng.random() >= self.config.censor_prob or np.size(sample) == 0:
            return self.return_with_opt_info(sample, {}, return_info)

        sample = np.array(sample, dtype=sample.dtype)  # copy: caller's array is untouched
        quantile = self.rng.uniform(0, 1)
        quantile_value = np.quantile(sample, quantile)
        bottom_censor = self.rng.random() >= 0.5
        if bottom_censor:
            sample[sample < quantile_value] = quantile_value
        else:
            sample[sample > quantile_value] = quantile_value

        # Honour return_info on this path too, so the return shape is consistent.
        info = {
            "quantile": quantile,
            "quantile_value": quantile_value,
            "bottom_censor": bottom_censor,
        }
        return self.return_with_opt_info(sample, info, return_info)


@dataclass
class AmplitudeTrendAugmentionConfig:
    """Configuration for ``ApmplitudeTrendAugmentation``."""

    # Probability that the amplitude trend is applied to a given sample.
    apply_prob: float = field(default=0)
    # Exclusive upper bound passed to ``rng.integers`` for the number of change
    # points (additionally limited for short series).
    max_num_changepoint: int = field(default=5)
    # Standard deviation of the Normal(1, scale) amplitudes at the change points.
    amplitude_scale: float = field(default=1)


@dataclass
class ApmplitudeTrendAugmentation(UniVarAugment):
    """Multiply a series by a random piecewise-linear amplitude trend.

    With probability ``config.apply_prob`` a random number of change points is
    drawn, an amplitude ~ Normal(1, ``config.amplitude_scale``) is sampled at
    each change point and at both ends, and the linear interpolation between
    them is multiplied onto the sample.
    """

    config_class: ClassVar[Type] = AmplitudeTrendAugmentionConfig
    config: AmplitudeTrendAugmentionConfig

    def __call__(self, sample: np.ndarray, return_info=False, **kwargs) -> np.ndarray:
        """Apply the random amplitude trend to ``sample``.

        Args:
            sample: numpy array to augment; the trend is cast to its dtype.
            return_info: if truthy, return ``(sample, info)`` instead of ``sample``.

        Returns:
            ``sample * trend`` (a new array) or the unchanged sample, optionally
            with an info dict whose ``"type"`` is ``"trend"`` or ``"none"``.
        """
        # ``>=`` so that apply_prob == 0 never applies and apply_prob == 1 always
        # does, consistent with the other augmentations in this module.
        if self.rng.random() >= self.config.apply_prob:
            return self.return_with_opt_info(sample, {"type": "none"}, return_info)

        length = len(sample)
        dtype = sample.dtype

        # Exclusive upper bound for the number of change points: the configured
        # maximum, limited to one per 64 steps for short series. Clamped to >= 1
        # so rng.integers never receives high <= 0 (e.g. max_num_changepoint=0).
        high = max(1, min(self.config.max_num_changepoint, length // 64))
        num_change_points = self.rng.integers(low=0, high=high)
        change_pos = np.sort(self.rng.integers(1, length, num_change_points))
        change_pos = np.concatenate(([0], change_pos, [length]))
        amplitudes = self.rng.normal(1, scale=self.config.amplitude_scale, size=num_change_points + 2)
        trend = np.interp(np.arange(length), change_pos, amplitudes).astype(dtype)

        info = {"type": "trend", "change_pos": change_pos, "amplitudes": amplitudes}
        return self.return_with_opt_info(sample * trend, info, return_info)



def tophat_kernel(x, centers, width, amplitude):
    """Rectangular kernel: ``amplitude`` where ``|x - centers| <= width / 2``, else 0."""
    inside = np.abs(x - centers) <= width / 2
    return inside.astype(int) * amplitude

def rbf_kernel(x, centers, gamma, amplitude):
    """Gaussian bump ``amplitude * exp(-gamma * (x - centers)**2)``."""
    squared_distance = (x - centers) ** 2
    return amplitude * np.exp(-gamma * squared_distance)

def linear_kernel(x, centers, width, amplitude):
    """Triangular kernel peaking at ``amplitude`` on ``centers``, reaching 0 at distance ``width``."""
    distance = np.abs(x - centers)
    triangle = np.maximum(1 - distance / width, 0)
    return amplitude * triangle


# Kernel name -> kernel function. Insertion order matters: kernel types are
# drawn from ``list(SPIKE_KERNEL_MAP.keys())`` with positional probabilities.
SPIKE_KERNEL_MAP = dict(
    tophat=tophat_kernel,
    rbf=rbf_kernel,
    linear=linear_kernel,
)
# Kernel name -> {parameter name -> f(periodicity) returning a (low, high) range
# for a uniform draw}. Parameter order is the order in which values are drawn.
SPIKE_KERNEL_PARAM_RANGE = {
    "tophat": dict(
        width=lambda period: (1, period // 3),
        amplitude=lambda period: (-4, 4),
    ),
    "rbf": dict(
        gamma=lambda period: (0.5 / period, 5 / period),
        amplitude=lambda period: (-4, 4),
    ),
    "linear": dict(
        width=lambda period: (1, period // 2),
        amplitude=lambda period: (-4, 4),
    ),
}



@dataclass
class SpikeAugmentConfig:
    """Configuration for ``SpikeAugmention``."""

    # Probability that spikes are added to a given sample.
    apply_prob: float = field(default=0.1)
    # Probability of each pattern group in ``SpikeAugmention.spike_pattern``
    # (one entry per group; checked in ``SpikeAugmention.__post_init__``).
    pattern_prob: List[float] = field(default=(0.75, 0.1, 0.1, 0.05))
    # Probabilities of the kernel types, in ``SPIKE_KERNEL_MAP`` key order.
    spike_type_prob: List[float] = field(default=(0.4, 0.4, 0.2))
    # Probability of placing the last spike at or after ``past_end_idx``.
    spike_in_fc_prob: float = field(default=0.5)
    # Periodicity bounds (upper bound exclusive, also capped by the series length).
    min_periodicty: int = field(default=10)
    max_periodicty: int = field(default=512)


@dataclass
class SpikeAugmention(UniVarAugment):
    """Add a periodic pattern of spikes to a univariate series.

    With probability ``config.apply_prob`` (and only for series longer than
    ``config.min_periodicty``), a periodicity and a spike pattern are sampled.
    Pattern labels are laid out every ``periodicity`` steps. Each distinct
    label is rendered with its own randomly parameterised kernel; the kernel
    type is chosen once per call from ``SPIKE_KERNEL_MAP``. The resulting sum
    is added to a copy of the sample.
    """

    config_class: ClassVar[Type] = SpikeAugmentConfig
    config: SpikeAugmentConfig

    # Groups of candidate patterns. A group is chosen with ``config.pattern_prob``
    # and a pattern within the group uniformly. Equal integers in a pattern are
    # spikes that share the same kernel parameters.
    spike_pattern: ClassVar[List[Tuple[List[int], ...]]] = [
        (   # Simple
            [0],
            [0, 1],
        ),
        (   # 3-periodicity
            [0, 1, 2],
            [0, 0, 1],
        ),
        (   # 4-periodicity
            [0, 0, 1, 1],
            [0, 0, 1, 2],
            [0, 1, 0, 2],
        ),
        (   # weekly periodicity
            [0, 0, 0, 0, 0, 1, 1],  # week, weekend
            [0, 0, 0, 0, 0, 1, 2],  # week, saturday, sunday
        ),
    ]

    def __post_init__(self):
        # One probability per pattern group.
        assert len(self.spike_pattern) == len(self.config.pattern_prob)

    def _sample_pattern(self):
        """Draw a pattern (group by ``pattern_prob``, member uniformly) with a random rotation."""
        pattern_idx = self.rng.choice(len(self.spike_pattern), p=self.config.pattern_prob)
        pattern_sub_idx = self.rng.integers(0, len(self.spike_pattern[pattern_idx]))
        pattern = self.spike_pattern[pattern_idx][pattern_sub_idx]
        # Random rotation so the pattern does not always start with the same label.
        used_pattern = np.roll(pattern, self.rng.integers(0, high=len(pattern)))
        return used_pattern

    def _sample_periodicity(self, sample_length):
        """Draw a periodicity in ``[min_periodicty, min(sample_length, max_periodicty))``.

        The range must be non-empty, otherwise ``rng.integers`` raises ``ValueError``.
        """
        periodicity = self.rng.integers(self.config.min_periodicty, min(sample_length, self.config.max_periodicty)).item()
        return periodicity

    def _gen_spike_param(self, kernel_type, periodicity):
        """Draw each kernel parameter uniformly from its periodicity-dependent range."""
        params_range = SPIKE_KERNEL_PARAM_RANGE[kernel_type]
        params = {}
        for p_name, p_range in params_range.items():
            low, high = p_range(periodicity)
            params[p_name] = self.rng.uniform(low, high)
        return params

    def _calc_spike(self, spike_axis, positions, kernel_type, spike_param):
        """Evaluate the ``kernel_type`` kernel on ``spike_axis`` centred at ``positions``."""
        return SPIKE_KERNEL_MAP[kernel_type](x=spike_axis, centers=positions, **spike_param)

    def _apply_pattern(self, length, pattern, periodicity, last_spike_pos):
        """Lay out ``pattern`` every ``periodicity`` steps over ``length`` positions.

        The layout is built on an array padded by ``periodicity`` on both sides.
        It is rolled so that the spike grid is aligned with ``last_spike_pos``,
        and then the padding is cut off again. Positions between spikes are 0.
        """
        pattern_len = len(pattern)
        padded_length = length + 2 * periodicity  # pad with periodicity so that after shifting the periodicity stays
        last_spike_pos = last_spike_pos + 2 * periodicity  # padding needs to account for last spike pos
        pattern_with_zeros_len = pattern_len * periodicity
        num_full_repetitions = padded_length // pattern_with_zeros_len

        # Full repetitions
        if num_full_repetitions > 0:
            base = np.zeros(pattern_len * periodicity, dtype=pattern.dtype)
            base[::periodicity] = pattern
            base = np.tile(base, num_full_repetitions)
        else:
            base = np.array([], dtype=pattern.dtype)

        # Remaining partial repetition
        remaining_length = padded_length % pattern_with_zeros_len
        if remaining_length > 0:
            num_pattern_elements_remaining = min(pattern_len, (remaining_length + periodicity - 1) // periodicity)
            remaining_pattern = np.zeros(remaining_length, dtype=pattern.dtype)
            remaining_pattern[:(num_pattern_elements_remaining * periodicity):periodicity] = pattern[:num_pattern_elements_remaining]
            result = np.concatenate((base, remaining_pattern))
            last_spike_current_pos = len(base) + np.flatnonzero(remaining_pattern)[-1].item()
        else:
            result = base
            last_spike_current_pos = padded_length - periodicity
        assert len(result) == length + 2 * periodicity

        # Shift so the spike grid lines up with the requested last spike position
        shift = last_spike_pos - last_spike_current_pos
        assert shift < periodicity
        result_shifted = np.roll(result, shift=shift)
        result_shifted = result_shifted[periodicity: -periodicity]  # cut off padding
        assert len(result_shifted) == length
        return result_shifted

    def _compute_spike_kernels(self, spike_position, periodicity, dtype):
        """Render every non-zero label in ``spike_position`` with its own kernel parameters.

        Returns:
            ``(spike_aug, spike_param)``: the summed kernels (``dtype``) and one
            parameter dict per distinct label, each tagged with the kernel ``"type"``.
        """
        spike_aug = np.zeros_like(spike_position, dtype=dtype)
        kernel_type = self.rng.choice(list(SPIKE_KERNEL_MAP.keys()), p=self.config.spike_type_prob).item()
        # Label 0 means "no spike"; exclude it explicitly instead of assuming it is
        # the smallest value present.
        spikes = np.unique(spike_position[spike_position != 0])
        spike_param = [self._gen_spike_param(kernel_type, periodicity) for _ in spikes]
        spike_axis = np.arange(len(spike_position))
        for i, spike in enumerate(spikes):
            for p in np.flatnonzero(spike_position == spike):
                tmp = self._calc_spike(
                    spike_axis=spike_axis,
                    positions=p,
                    kernel_type=kernel_type,
                    spike_param=spike_param[i]
                )
                spike_aug += tmp
            # Tag only after use: "type" must not be forwarded as a kernel kwarg.
            spike_param[i]["type"] = kernel_type
        return spike_aug, spike_param

    def __call__(self, sample: np.ndarray, past_end_idx, return_info=False, **kwargs) -> np.ndarray:
        """Add periodic spikes to a copy of ``sample``.

        Args:
            sample: numpy array to augment; it is not modified in place.
            past_end_idx: presumably the index where the past/context part ends.
                With probability ``spike_in_fc_prob`` the last spike is placed at
                or after it, when that index lies inside the series.
                NOTE(review): semantics inferred from the name; confirm with callers.
            return_info: if truthy, return ``(sample, info)`` instead of ``sample``.

        Returns:
            The augmented (or unchanged) sample, optionally with an info dict
            holding ``pattern``, ``periodicity`` and ``spike_param``.
        """
        length = len(sample)
        # Periodicity is drawn from [min_periodicty, min(length, max_periodicty)),
        # which is empty unless length > min_periodicty.
        if length <= self.config.min_periodicty or self.rng.random() >= self.config.apply_prob:
            return self.return_with_opt_info(sample, {}, return_info)

        periodicity = self._sample_periodicity(length)
        pattern = self._sample_pattern() + 1  # shift labels so 0 means "no spike"
        place_in_forecast = self.rng.random() <= self.config.spike_in_fc_prob
        if place_in_forecast and past_end_idx < length:
            last_spike_pos = self.rng.integers(max(past_end_idx, length - periodicity), length).item()
        else:
            # Also used when past_end_idx >= length: there is no room after it, and
            # integers(low >= high) would raise.
            last_spike_pos = self.rng.integers(length - periodicity, length).item()
        spike_positions = self._apply_pattern(length, pattern, periodicity, last_spike_pos)
        augmention, spike_param = self._compute_spike_kernels(spike_positions, periodicity=periodicity, dtype=sample.dtype)
        sample = np.array(sample, dtype=sample.dtype)  # copy: caller's array is untouched
        sample += augmention
        return self.return_with_opt_info(sample, {
            "pattern": pattern,
            "periodicity": periodicity,
            "spike_param": spike_param
        }, return_info)
