import torch
import numpy as np
from typing import List

from batlinet.builders import FEATURE_EXTRACTORS
from batlinet.data.battery_data import BatteryData
from batlinet.feature.base import BaseFeatureExtractor
from batlinet.feature.severson import get_Qdlin, smooth
from scipy.ndimage import uniform_filter1d, gaussian_filter1d, median_filter

def smooth_and_reduce(X, s):
    """Drop the first ``s`` samples of ``X`` and stretch the rest back to ``len(X)``.

    Despite the name, no smoothing is performed: the tail ``X[s:]`` is
    linearly resampled with ``np.interp`` onto ``len(X)`` evenly spaced
    points, so the output has the same length as the input.

    Args:
        X: 1-D sample sequence (``np.interp`` only handles 1-D data).
        s: number of leading samples to discard. When ``s <= 0`` the input
            object itself is returned untouched. If ``s >= len(X)`` the tail
            is empty and ``np.interp`` raises.

    Returns:
        ``X`` itself when ``s <= 0``; otherwise a new ``np.ndarray`` of
        length ``len(X)``.
    """
    tail = X[s:]
    if s <= 0:
        return X
    # Map the shortened tail and the full-length output onto the same [0, 1]
    # axis so the tail is stretched to cover the original length.
    target_grid = np.linspace(0, 1, len(X))
    source_grid = np.linspace(0, 1, len(tail))
    return np.interp(target_grid, source_grid, tail)


@FEATURE_EXTRACTORS.register()
class VoltageCapacityMatrixFeatureExtractor(BaseFeatureExtractor):
    """Stack per-cycle ``Qdlin`` differences against a reference cycle.

    For every selected cycle position ``i`` in ``cell_data.cycle_data``, the
    curve produced by ``get_Qdlin`` (optionally smoothed and subsampled) is
    compared with the same curve of the reference cycle ``diff_base``
    (``curve_i - curve_base``). The difference rows are stacked into a 2-D
    tensor of shape ``(n_selected_cycles, row_len)``, where ``row_len`` is the
    length of the prepared reference curve. That is ``interp_dim`` when
    ``cycle_average`` is ``None``, assuming ``get_Qdlin`` returns
    ``interp_dim`` points (TODO confirm).
    """

    def __init__(self,
                 interp_dim: int = 100,
                 diff_base: int = 9,
                 cycles_to_keep: List[int] = None,
                 min_cycle_index: int = 0,
                 max_cycle_index: int = 99,
                 use_precalculated_qdlin: bool = False,
                 smooth: bool = True,
                 cycle_average: int = None,
                 cycle_to_drop: list = None,
                 get_steps_after: int = 0):
        """
        Args:
            interp_dim: number of interpolation points requested from
                ``get_Qdlin``.
            diff_base: position in ``cell_data.cycle_data`` of the reference
                cycle; must lie in ``[min_cycle_index, max_cycle_index]``.
            cycles_to_keep: if given (a list or a single int), only these
                cycle positions produce rows.
            min_cycle_index: first cycle position considered (inclusive).
            max_cycle_index: last cycle position considered (inclusive).
            use_precalculated_qdlin: forwarded to ``get_Qdlin``.
            smooth: apply ``smooth`` to each curve and each difference, and a
                Hampel filter to every output row. (Inside ``__init__`` only,
                this argument shadows the imported ``smooth`` function.)
            cycle_average: if set, keep every ``cycle_average``-th point of
                each curve (strided subsampling, not averaging).
            cycle_to_drop: cycle position(s) whose rows are zero-filled.
            get_steps_after: number of leading samples each difference row
                drops before being stretched back (see ``smooth_and_reduce``).
        """
        self.interp_dim = interp_dim
        self.min_cycle_index = min_cycle_index
        self.max_cycle_index = max_cycle_index
        self.use_precalculated_qdlin = use_precalculated_qdlin
        self.get_steps_after = get_steps_after

        assert diff_base <= max_cycle_index, (diff_base, max_cycle_index)
        assert diff_base >= min_cycle_index, (diff_base, min_cycle_index)
        self.diff_base = diff_base

        if isinstance(cycles_to_keep, int):
            cycles_to_keep = [cycles_to_keep]
        self.cycles_to_keep = cycles_to_keep

        self.smooth = smooth

        # See https://github.com/petermattia/revisit-severson-et-al/blob/main/revisit-severson-et-al.ipynb noqa
        self.cycle_average = cycle_average
        if isinstance(cycle_to_drop, int):
            cycle_to_drop = [cycle_to_drop]
        self.cycle_to_drop = cycle_to_drop or []

    def _prepare_curve(self, cell_data, cycle_data):
        """Interpolate one cycle with ``get_Qdlin``, then optionally smooth and subsample it."""
        curve = get_Qdlin(
            cell_data,
            cycle_data,
            self.use_precalculated_qdlin,
            interp_dim=self.interp_dim)
        if self.smooth:
            curve = smooth(curve)
        if self.cycle_average is not None:
            curve = curve[..., ::self.cycle_average]
        return curve

    def process_cell(self, cell_data: BatteryData) -> torch.Tensor:
        """Build the cycle-difference matrix for one cell.

        Args:
            cell_data: cell whose ``cycle_data`` list is indexed by position.

        Returns:
            Tensor of shape ``(n_selected_cycles, row_len)``. Rows for dropped
            cycles, and rows whose mean difference exceeds 1.0, are all zeros.
            NaN/Inf entries are replaced by 0.
        """
        base_curve = self._prepare_curve(
            cell_data, cell_data.cycle_data[self.diff_base])
        # Zero-filled rows must match the computed rows' length, or
        # torch.stack fails. With cycle_average set, that length is shorter
        # than interp_dim.
        row_len = base_curve.shape[-1]

        feature = []
        for cycle_index, cycle_data in enumerate(cell_data.cycle_data):
            if cycle_index < self.min_cycle_index:
                continue
            if cycle_index > self.max_cycle_index:
                break

            if self.cycles_to_keep is not None \
                    and cycle_index not in self.cycles_to_keep:
                continue

            if cycle_index in self.cycle_to_drop:
                feature.append(torch.zeros(row_len))
                continue

            diff_qdlin = self._prepare_curve(cell_data, cycle_data) - base_curve
            if self.smooth:
                diff_qdlin = smooth(diff_qdlin)

            # Heuristic rejection: a row whose mean difference exceeds 1.0 is
            # replaced by zeros instead of being kept.
            if diff_qdlin.mean() > 1.0:
                feature.append(torch.zeros(row_len))
                continue

            feature.append(torch.from_numpy(
                smooth_and_reduce(diff_qdlin, self.get_steps_after)))
        feature = torch.stack(feature)

        if self.smooth:
            # hampel_smooth defaults to 'cuda:0', which crashes on CPU-only
            # hosts; only use the GPU when one is actually available.
            device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
            feature = torch.stack([
                hampel_smooth(row, device=device)
                for row in feature
            ])
        # Fill NaN / Inf
        feature[~torch.isfinite(feature)] = 0.

        return feature



def rollingOps1d(x, func, window_size=11):
    """Apply a windowed reduction along the last axis, keeping the input length.

    ``x.unfold(-1, window_size, 1)`` produces every length-``window_size``
    window (stride 1). ``func`` must reduce the last axis of that tensor,
    yielding ``L - window_size + 1`` values per row. Those values are placed
    roughly centred in an output of length ``L``. The positions the windows
    cannot cover at each end are filled by repeating the nearest computed
    value.

    Args:
        x: tensor of shape ``(*, L)`` with ``L >= window_size``.
        func: reduction over the last axis of a ``(*, n_windows, window_size)``
            tensor, returning ``(*, n_windows)``.
        window_size: window length. Even sizes put the extra padding on the
            right.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    length = x.size(-1)
    processed = func(x.unfold(-1, window_size, 1))
    n_out = processed.size(-1)
    left = (length - n_out) // 2
    right = length - n_out - left

    res = torch.zeros_like(x)
    # Use explicit end indices: with right == 0 (window_size == 1), `-right`
    # is `-0 == 0`, which would make `left:-right` empty and `-right:` cover
    # the whole row.
    res[..., left:length - right] = processed
    if left:
        res[..., :left] = res[..., [left]]
    if right:
        res[..., length - right:] = res[..., [length - right - 1]]

    return res

def med1d(x, window_size=10):
    """Rolling median along the last axis, edge-padded back to the input length.

    For even ``window_size`` torch's median returns the lower of the two
    middle values.
    """
    return rollingOps1d(
        x,
        lambda windows: windows.median(dim=-1).values,
        window_size,
    )


def mad1d(x, window_size=10):
    """Rolling median absolute deviation (MAD) along the last axis.

    For each window, computes ``median(|w - median(w)|)``. The result is
    edge-padded back to the input length by ``rollingOps1d``.
    """
    def _window_mad(windows):
        # keepdim leaves a trailing size-1 axis so the subtraction broadcasts
        # across each window.
        centre = windows.median(dim=-1, keepdim=True).values
        spread = (windows - centre).abs()
        return spread.median(dim=-1).values

    return rollingOps1d(x, _window_mad, window_size)

def _hampel_smooth(x, window_size):
    """Hampel filter core: swap outliers for the rolling median.

    A sample is an outlier when its distance from the rolling median exceeds
    ``1.4826 * MAD * 3``. The factor 1.4826 is the conventional
    MAD-to-standard-deviation scale, so this is a 3-sigma test. Returns a
    new tensor; ``x`` is not modified.
    """
    rolling_median = med1d(x, window_size)
    deviation = (x - rolling_median).abs()
    threshold = 1.4826 * mad1d(x, window_size) * 3
    is_outlier = deviation > threshold
    return torch.where(is_outlier, rolling_median, x)


def hampel_smooth(x, window_size=21, device='cuda:0'):
    """Replace outliers along the last axis with the rolling median (Hampel filter).

    NOTE: ``unfold`` materialises every window, so memory use grows roughly
    ``window_size`` times the input size. For very large inputs (e.g.
    ``[B, N, K, L]`` with large B and N), process slices separately, e.g.
    ``torch.stack([hampel_smooth(x_single) for x_single in x])``.

    Args:
        x: ``torch.Tensor`` or ``np.ndarray`` of shape ``(*, L)``.
        window_size: odd window length used for the rolling median and MAD.
        device: device the filtering runs on. If a CUDA device is requested
            but CUDA is unavailable, the computation falls back to CPU instead
            of crashing. ``None`` computes on ``x``'s own device.

    Returns:
        The filtered data with the same shape as ``x``. An ndarray input
        returns an ndarray (on CPU). A tensor input returns a tensor on
        ``x``'s original device.
    """
    assert window_size % 2 == 1, 'Window size must be odd!'
    is_array = False
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        is_array = True

    original_device = x.device
    if device is None:
        compute_device = original_device
    else:
        compute_device = torch.device(device)
        if compute_device.type == 'cuda' and not torch.cuda.is_available():
            # The filter is device-agnostic; don't fail on CPU-only hosts.
            compute_device = torch.device('cpu')

    res = _hampel_smooth(x.to(compute_device), window_size)
    res = res.to(original_device)

    if is_array:
        res = res.cpu().numpy()

    return res