import numpy as np
import torch
from typing import Union


def preprocess(data, mode: str = 'auto', window_size: list = None, lower: Union[float, int] = 0.1, upper: Union[float, int] = 0.9, ignore: bool = False):
    """Search for window sizes that split the time axis into a useful number of non-empty windows.

    Timestamps are read from the last channel of ``data`` (``data[..., -1]``).
    Each candidate window size ``win`` partitions ``(0, tmax]`` into consecutive
    half-open intervals ``(start, start + win]``; the candidate's *count* is the
    number of intervals that contain at least one timestamp.

    Args:
        data: Tensor whose last dimension holds the timestamp in its final slot.
            Must have at least two dimensions (``data[..., -1].size(-1)`` is used).
        mode: ``'auto'`` tries ``tmax / 2, tmax / 4, tmax / 8, ...``;
            ``'manner'`` tries ``tmax / window_size[i]`` for each entry in order.
        window_size: Divisors of ``tmax`` used in ``'manner'`` mode (required there).
        lower: Minimum accepted count. A value below 1 is treated as a fraction
            of ``timestamp.size(-1)``.
        upper: Maximum accepted count. A value of at most 1 is treated as a
            fraction of ``timestamp.size(-1)``; reaching it also stops the search.
        ignore: When true, accept a candidate even if the median of the
            per-window median counts is zero.

    Returns:
        ``(windows, length)``: the accepted window sizes and their counts
        (``int64``), both on ``data.device`` and in reverse order of discovery.
        Both are empty when no candidate is accepted.

    Raises:
        ValueError: ``mode == 'manner'`` without a non-empty ``window_size``.

    NOTE(review): in ``'auto'`` mode the search only stops once the count
    reaches ``upper`` or exceeds ``timestamp.size(-1)``. Data that can never
    produce that many non-empty windows keeps halving the window; confirm the
    callers' data always satisfies this.
    """
    def obtain_counter(timestamp, window, tmax):
        # Walk (0, tmax] in steps of `window`, counting timestamps per interval.
        start, count, median = 0, 0, []
        while start < tmax:
            in_window = (timestamp > start) & (timestamp <= start + window)
            # The first interval gets one extra count (only while start is still 0).
            stat = in_window.sum(dim=-1).numpy() + (1 if start == 0 else 0)
            median.append(np.median(stat))
            count += 1 if stat.sum() else 0
            start += window
        return count, np.median(median) != 0

    def build_outputs():
        # Results are reported in reverse order of discovery.
        win_tensor = torch.tensor(windows, device=data.device)
        len_tensor = torch.tensor(length, device=data.device, dtype=torch.int64)
        return torch.flip(win_tensor, dims=[0]), torch.flip(len_tensor, dims=[0])

    assert mode in ['auto', 'manner']
    if mode == 'manner' and (window_size is None or len(window_size) == 0):
        raise ValueError("mode 'manner' requires a non-empty window_size")

    # Work on a detached CPU copy: the counting below goes through NumPy, which
    # cannot read CUDA tensors or tensors that require grad.
    timestamp, windows, length = data[..., -1].detach().cpu(), [], []
    steps = timestamp.size(-1)
    # Fractions are scaled by the number of time steps; note the bounds use
    # `< 1` for lower but `<= 1` for upper.
    lower, upper = lower * steps if lower < 1 else lower, upper * steps if upper <= 1 else upper
    t, tmax, index = 1, torch.max(timestamp), 0

    if mode == 'auto' and not bool(tmax > 0):
        # With a non-positive (or NaN) maximum every candidate yields zero
        # windows, so the stop condition below could never be met.
        return build_outputs()

    while True:
        if mode == 'auto':
            t *= 2
            win = tmax / t
        elif mode == 'manner':
            win = tmax / window_size[index]
            index += 1
        else:
            raise NotImplementedError('mode[%s] >_<' % mode)

        counter, flag = obtain_counter(timestamp, win, tmax)
        within_bounds = lower <= counter <= upper and counter < steps
        if (mode == 'manner' or within_bounds) and (bool(ignore) or bool(flag)):
            windows.append(win)
            length.append(counter)

        exhausted = mode == 'manner' and index >= len(window_size)
        if exhausted or counter > steps or counter >= upper:
            break
    # end while
    return build_outputs()


def obtain_window_length(loader, args):
    """Collect every batch from ``loader`` and derive window sizes via ``preprocess``.

    Each batch is indexed as ``batch[0]`` (time), ``batch[1]`` (values) and
    ``batch[2]`` (mask). Values must be 3-D (their size is unpacked into three
    dimensions); time is padded along ``dim=1`` as a 2-D tensor. Batches
    shorter than the longest time axis are zero-padded along ``dim=1`` so that
    all batches can be concatenated along ``dim=0``.

    Args:
        loader: Iterable of batches indexable as described above.
        args: Object providing ``pooling_lower``, forwarded to ``preprocess``
            as ``lower``.

    Returns:
        dict with ``'windows'`` and ``'ts_length'`` (the two results of
        ``preprocess``), ``'input_dim'`` (last dimension of the values) and
        ``'timestamp'`` (padded length of the values' ``dim=1``).
    """
    val, time, mask = [], [], []
    for batch in loader:
        time.append(batch[0])
        val.append(batch[1])
        mask.append(batch[2])
    # end for batch

    # Longest time axis over all batches.
    tm = max((stamps.size(-1) for stamps in time), default=0)

    for ind, values in enumerate(val):
        num, timestamps, channels = values.size()
        pad = tm - timestamps
        if pad > 0:
            # new_zeros keeps each padding block on the device and dtype of the
            # tensor it extends; the mask is padded with its own channel count.
            val[ind] = torch.cat([values, values.new_zeros(num, pad, channels)], dim=1)
            time[ind] = torch.cat([time[ind], time[ind].new_zeros(num, pad)], dim=1)
            mask[ind] = torch.cat([mask[ind], mask[ind].new_zeros(num, pad, mask[ind].size(-1))], dim=1)

    val = torch.cat(val, dim=0)
    # Time becomes a trailing channel so preprocess can read it as data[..., -1].
    time = torch.cat(time, dim=0).unsqueeze(-1)
    mask = torch.cat(mask, dim=0)
    windows, length = preprocess(torch.cat([val, mask, time], dim=-1), lower=args.pooling_lower)
    return {
        'windows': windows,
        'ts_length': length,
        "input_dim": val.size(-1),
        "timestamp": val.size(1),
    }
