import os

import numpy as np
import torch
from astropy.timeseries import LombScargleMultiband, LombScargle
from dtaidistance import dtw


__all__ = ['calc_corr', 'calc_corr_from_batch', 'calc_corr_penalty']


def _tonumpy(item):
    """Return the torch tensor *item* as a NumPy array, detached from autograd and moved to the CPU."""
    host_tensor = item.detach().cpu()
    return host_tensor.numpy()


def calc_corr_penalty(loader, name, path: str = './', freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10, is_tuple: bool = False):
    """Build (or load from cache) a channel-correlation matrix from observation-weighted DTW distances.

    Each batch from ``loader`` (or ``batch[0]`` when ``is_tuple``) is split with
    ``dim = val.size(-1) // 2`` into values ``val[..., :dim]``, time ``val[..., -1:]``
    and mask ``val[..., dim:2 * dim]``, and passed to ``calc_distance_single``.
    Its weighted distance sums and weights are accumulated over the whole loader.
    The accumulated distances are then processed as follows:

    - non-finite sums are zeroed;
    - the sums are divided by the accumulated weights;
    - the matrix is symmetrised (``D + D.T``);
    - the diagonal is raised by the smallest off-diagonal distance;
    - the result is inverted and min-max normalised to ``[0, 1]``.

    The result is cached in ``<path>/<name>_fp<freq_point>_fs<freq_start>_fe<freq_end>.npy``.
    If that file already exists it is loaded and returned without touching ``loader``.

    NOTE(review): ``calc_corr_from_batch`` in this module builds its cache path
    with the identical pattern, so the two must not share ``name``/``path``.

    Args:
        loader: iterable of batches supporting ``len()``; each batch is presumably a
            torch tensor (``.size(-1)`` is used) — TODO confirm against callers.
        name: prefix of the cache file name.
        path: directory of the cache file; created if missing.
        freq_point: number of evaluated frequencies.
        freq_start: lowest evaluated frequency.
        freq_end: highest evaluated frequency.
        is_tuple: take ``batch[0]`` instead of the batch itself.

    Returns:
        ``np.ndarray`` of shape ``(dim, dim)``.

    Raises:
        ValueError: if the loader produced no distance matrices (e.g. it is empty).
    """
    file = os.path.join(path, '%s_fp%d_fs%s_fe%s.npy' % (name, freq_point, str(freq_start), str(freq_end)))

    if os.path.exists(file):
        return np.load(file)

    n_batches = len(loader)
    distance, weight = 0, 0
    for idx, item in enumerate(loader):
        val = item[0] if is_tuple else item
        dim = val.size(-1) // 2
        df_dtw, w = calc_distance_single(val[..., :dim], val[..., -1:], val[..., dim:2 * dim], freq_point, freq_start, freq_end)
        distance, weight = distance + df_dtw, weight + w
        if idx % 5 == 0:
            print('%d/%d end...' % (idx + 1, n_batches))

    # With no batches (or only empty ones) the accumulators are still the scalar 0.
    if np.ndim(distance) < 2:
        raise ValueError('calc_corr_penalty: loader produced no distance matrices')

    # Drop inf/nan contributions before averaging.
    distance = np.where(np.isfinite(distance), distance, 0)
    distance = distance / (weight + 1e-5)

    distance = distance + distance.T
    eye = np.eye(distance.shape[0], dtype=bool)
    # Self-distance is 0; lift it to the smallest cross-channel distance so that
    # 1 / distance stays finite on the diagonal. Requires dim >= 2.
    distance[eye] += np.min(distance[~eye])

    inverse = 1 / distance
    corr = (inverse - np.min(inverse)) / (np.max(inverse) - np.min(inverse))

    # Create the cache directory only now, so a failed run leaves nothing behind.
    cache_dir = os.path.dirname(file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    np.save(file, corr)
    return corr


def calc_corr_from_batch(loader, name, path: str = './', freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10, is_tuple: bool = False):
    """Average per-batch ``calc_corr`` matrices over a loader, caching the result on disk.

    Each batch (or ``batch[0]`` when ``is_tuple``) is split with ``dim = val.size(-1) // 2``
    into values ``val[..., :dim]``, time ``val[..., -1:]`` and mask
    ``val[..., dim:2 * dim]``. ``calc_corr`` is evaluated on each batch, and the
    matrices are averaged with equal weight per batch, regardless of batch size.

    The result is cached in ``<path>/<name>_fp<freq_point>_fs<freq_start>_fe<freq_end>.npy``.
    If that file already exists it is loaded and returned without touching ``loader``.

    NOTE(review): ``calc_corr_penalty`` in this module builds its cache path with
    the identical pattern, so the two must not share ``name``/``path``.

    Args:
        loader: iterable of batches supporting ``len()``; each batch is presumably a
            torch tensor (``.size(-1)`` is used) — TODO confirm against callers.
        name: prefix of the cache file name.
        path: directory of the cache file; created if missing.
        freq_point: number of evaluated frequencies.
        freq_start: lowest evaluated frequency.
        freq_end: highest evaluated frequency.
        is_tuple: take ``batch[0]`` instead of the batch itself.

    Returns:
        ``np.ndarray`` of shape ``(dim, dim)``.

    Raises:
        ValueError: if the loader is empty (previously a scalar 0 was cached).
    """
    file = os.path.join(path, '%s_fp%d_fs%s_fe%s.npy' % (name, freq_point, str(freq_start), str(freq_end)))

    if os.path.exists(file):
        return np.load(file)

    n_batches = len(loader)
    corr = 0
    for idx, item in enumerate(loader):
        val = item[0] if is_tuple else item
        dim = val.size(-1) // 2
        c_matrix = calc_corr(val[..., :dim], val[..., -1:], val[..., dim:2 * dim], freq_point, freq_start, freq_end)
        corr += c_matrix / n_batches
        if idx % 5 == 0:
            print('%d/%d end...' % (idx + 1, n_batches))

    # Never cache the scalar placeholder; it would be returned by every later call.
    if np.ndim(corr) < 2:
        raise ValueError('calc_corr_from_batch: loader produced no batches')

    cache_dir = os.path.dirname(file)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    np.save(file, corr)
    return corr


def calc_corr_single(val, time, mask, freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10):
    """Compute one channel-similarity matrix per sample from Lomb-Scargle spectra.

    For each sample ``i`` and channel ``d``, a Lomb-Scargle periodogram is evaluated
    on the observed points (``mask[i, :, d] != 0``) at ``freq_point`` frequencies
    spanning ``[freq_start, freq_end]``. A channel with no observed points gets an
    all-zero spectrum. The spectra are compared pairwise with DTW, and the inverse
    distances are min-max normalised per sample.

    Args:
        val: values of shape ``(num, ts, dim)`` (torch tensor or ndarray).
        time: timestamps; only ``time[i, :, 0]`` is read.
        mask: same indexing as ``val``; non-zero marks an observed entry.
        freq_point: number of evaluated frequencies.
        freq_start: lowest evaluated frequency.
        freq_end: highest evaluated frequency.

    Returns:
        ``np.ndarray`` of shape ``(num, dim, dim)``. Requires ``dim >= 2``, because
        the diagonal is filled from the off-diagonal minimum.
    """
    if isinstance(val, torch.Tensor):
        val, time, mask = _tonumpy(val), _tonumpy(time), _tonumpy(mask)

    def normalization(data):
        _range = np.max(data) - np.min(data)
        return (data - np.min(data)) / _range

    num, _, dim = val.shape
    frequency, power, corr = np.linspace(freq_start, freq_end, freq_point), np.zeros((dim, freq_point)), np.zeros((num, dim, dim))
    eye = np.eye(dim, dtype=bool)

    for i in range(num):
        for d in range(dim):
            if mask[i, :, d].sum() > 0:
                observed = mask[i, :, d] != 0
                power[d] = LombScargle(time[i, :, 0][observed], val[i, :, d][observed]).power(frequency)
            else:
                power[d] = np.zeros(freq_point)

        power[np.isnan(power)] = 0
        inf_idx = np.isinf(power)
        if inf_idx.any():
            # np.max(power) is itself inf when an inf is present, which made the
            # old replacement a no-op; scale the largest finite value instead.
            finite = power[~inf_idx]
            power[inf_idx] = finite.max() * 1e2 if finite.size else 0.0

        distance = dtw.distance_matrix_fast(power)
        # Distinct channels with identical spectra (distance 0) get the largest
        # distance instead, so they do not dominate 1 / distance.
        distance[~eye & (distance == 0)] += np.max(distance)
        # Lift the zero self-distances to the smallest cross-channel distance.
        distance[eye] += np.min(distance[~eye])
        corr[i] = normalization(1 / distance)
    return corr


def calc_distance_single(val, time, mask, freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10):
    """Accumulate observation-weighted DTW distances between channels over one batch.

    For each sample ``i`` and channel ``d``, a Lomb-Scargle periodogram is evaluated
    on the observed points (``mask[i, :, d] != 0``). A channel with no observed
    points gets an all-zero spectrum. The spectra are compared pairwise with DTW.
    Each pair ``(r, c)`` is weighted by ``(n_r + n_c) / (2 * ts)``, where ``n`` is
    ``mask[i].sum(axis=0)``. For a 0/1 mask this is the mean observed fraction of
    the two channels (assumes a binary mask — TODO confirm).

    Args:
        val: values of shape ``(num, ts, dim)`` (torch tensor or ndarray).
        time: timestamps; only ``time[i, :, 0]`` is read.
        mask: same indexing as ``val``; non-zero marks an observed entry.
        freq_point: number of evaluated frequencies.
        freq_start: lowest evaluated frequency.
        freq_end: highest evaluated frequency.

    Returns:
        ``(df_dtw, w)``: the sum over samples of ``distance * w_matrix`` and the sum
        of ``w_matrix``. Both have shape ``(dim, dim)``, or are the int ``0`` when
        ``num == 0``. Divide ``df_dtw`` by ``w`` for a weighted mean distance.
    """
    if isinstance(val, torch.Tensor):
        val, time, mask = _tonumpy(val), _tonumpy(time), _tonumpy(mask)

    num, ts, dim = val.shape
    frequency, power = np.linspace(freq_start, freq_end, freq_point), np.zeros((dim, freq_point))
    off_diag = ~np.eye(dim, dtype=bool)

    df_dtw, w = 0, 0
    for i in range(num):
        for d in range(dim):
            if mask[i, :, d].sum() > 0:
                observed = mask[i, :, d] != 0
                power[d] = LombScargle(time[i, :, 0][observed], val[i, :, d][observed]).power(frequency)
            else:
                power[d] = np.zeros(freq_point)

        power[np.isnan(power)] = 0
        inf_idx = np.isinf(power)
        if inf_idx.any():
            # np.max(power) is itself inf when an inf is present, which made the
            # old replacement a no-op; scale the largest finite value instead.
            finite = power[~inf_idx]
            power[inf_idx] = finite.max() * 1e2 if finite.size else 0.0

        distance = dtw.distance_matrix_fast(power)
        # Distinct channels with identical spectra get the largest distance instead of 0.
        # NOTE(review): if the DTW matrix contains inf entries, np.max is inf here and
        # the zeroing below erases this penalty — confirm the dtaidistance output format.
        distance[off_diag & (distance == 0)] += np.max(distance)

        observed_count = np.tile(mask[i].sum(axis=0), (dim, 1))
        w_matrix = (observed_count + observed_count.transpose()) / (2 * ts)

        distance[np.isnan(distance)] = 0
        distance[np.isinf(distance)] = 0

        df_dtw += distance * w_matrix
        w += w_matrix
    return df_dtw, w


def calc_corr(val, time, mask, freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10):
    """Compute one channel-similarity matrix for a whole batch using multiband Lomb-Scargle.

    All samples of a channel are pooled into a single multiband periodogram, with
    each sample acting as its own band. A channel with no observed points gets an
    all-zero spectrum. The spectra are compared pairwise with DTW, and the inverse
    distances are min-max normalised.

    Args:
        val: values of shape ``(num, ts, dim)`` (torch tensor or ndarray).
        time: timestamps indexed like ``val``; only its first channel is used.
        mask: same shape as ``val``; non-zero marks an observed entry.
        freq_point: number of evaluated frequencies.
        freq_start: lowest evaluated frequency.
        freq_end: highest evaluated frequency.

    Returns:
        ``np.ndarray`` of shape ``(dim, dim)``. Requires ``dim >= 2``.
    """
    if isinstance(val, torch.Tensor):
        val, time, mask = _tonumpy(val), _tonumpy(time), _tonumpy(mask)

    num, ts, dim = val.shape

    def _comb(item: np.ndarray):
        # (num, ts, C) -> (C, num * ts): channel-major, samples laid end to end.
        return item.transpose((2, 0, 1)).reshape(item.shape[-1], -1)

    val, time, mask = _comb(val), _comb(time), _comb(mask)
    # Band label of each flattened point = index of the sample it came from.
    band = np.concatenate([s * np.ones(ts) for s in range(num)])
    frequency, power = np.linspace(freq_start, freq_end, freq_point), np.zeros((dim, freq_point))

    for d in range(dim):
        observed = mask[d] != 0
        if observed.any():
            power[d] = LombScargleMultiband(time[0][observed], val[d][observed], band[observed]).power(frequency)
        # Otherwise the row stays all-zero rather than fitting on empty arrays.

    power[np.isnan(power)] = 0
    inf_idx = np.isinf(power)
    if inf_idx.any():
        # Replace inf with 100x the largest finite power (0 if none is finite).
        finite = power[~inf_idx]
        power[inf_idx] = finite.max() * 1e2 if finite.size else 0.0

    distance = dtw.distance_matrix_fast(power)
    eye = np.eye(distance.shape[0], dtype=bool)
    # Distinct channels with identical spectra get the largest distance instead of 0.
    distance[~eye & (distance == 0)] += np.max(distance)
    # Lift the zero self-distances to the smallest cross-channel distance.
    distance[eye] += np.min(distance[~eye])

    def normalization(data):
        _range = np.max(data) - np.min(data)
        return (data - np.min(data)) / _range

    return normalization(1 / distance)


if __name__ == '__main__':
    # Demo: channel d of every sample is sin((2 + d) * pi * t) on per-sample sorted
    # random timestamps in [0, 1], with every point observed; plot |calc_corr|.
    num, ts, dim = 6, 200, 5

    # Keep the sequence length (ts) and the timestamp array (time) under separate names.
    time = np.random.random((num, ts, 1))
    time = (time - np.min(time)) / (np.max(time) - np.min(time))
    time = np.sort(time, axis=1)

    def func(x):
        out = np.empty_like(x)
        for d in range(x.shape[-1]):
            out[..., d] = np.sin((2 + d) * np.pi * x[..., d])
        return out

    mask = np.ones((num, ts, dim))
    val = func(np.tile(time, (1, 1, dim)))

    c = calc_corr(val, time, mask)

    from matplotlib import pyplot as plt
    import seaborn as sns

    print(c)

    plt.figure()
    sns.heatmap(abs(c), cmap='Reds')
    plt.show()

