import os

import numpy as np
import torch
from astropy.timeseries import LombScargleMultiband, LombScargle
from dtaidistance import dtw


# Names exported by ``from <module> import *``. calc_corr_p1, calc_corr_single and
# calc_distance_single are not listed here, so star-imports do not expose them.
__all__ = ['calc_corr', 'calc_corr_from_batch', 'calc_corr_penalty', 'calc_corr_p']


def _tonumpy(item):
    return item.cpu().detach().numpy()


def calc_corr_p1(loader, name, path: str = './', freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10, is_tuple: bool = False):
    """Estimate a channel-similarity matrix from DTW distances between gap-filled channels.

    The result is cached in ``path`` as ``<name>_fp<freq_point>_fs<freq_start>_fe<freq_end>.npy``.
    If that file exists it is loaded and returned and ``loader`` is not read.
    ``freq_point``/``freq_start``/``freq_end`` only take part in the cache file name.

    Args:
        loader: iterable of batches; ``len(loader)`` is used for progress output. Each batch
            (``batch[0]`` when ``is_tuple``) is a tensor whose last axis is sliced as
            ``[values (dim) | mask (dim) | ... | timestamp (last)]``. It is presumably
            ``(batch, seq, 2 * dim + 1)``; TODO confirm with the data pipeline.
        name: prefix of the cache file.
        path: directory of the cache file.
        is_tuple: take ``batch[0]`` instead of the batch itself.

    Returns:
        ``(dim, dim)`` array of ``1 / distance``, min-max normalised to [0, 1].
        Larger values mean smaller (penalised, weighted) DTW distance.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    file = os.path.join(path, '%s_fp%d_fs%s_fe%s.npy' % (name, freq_point, str(freq_start), str(freq_end)))

    if os.path.exists(file):
        return np.load(file)

    from scipy.interpolate import CubicSpline

    dt_dtw, w = 0, 0
    for idx, item in enumerate(loader):
        val = item[0] if is_tuple else item
        dim = val.size(-1) // 2
        val = val.detach().cpu().numpy()
        val, mask, ts = val[..., :dim], val[..., dim:2 * dim], val[..., -1]

        data = val.copy()
        for i in range(val.shape[0]):
            # Valid length = index of the largest timestamp + 1. This assumes timestamps
            # increase and are then padded with smaller values (TODO confirm).
            timestamp, length = ts[i, :], np.argmax(ts[i, :]) + 1
            for c in range(dim):
                xs, m = data[i, :, c], mask[i, :, c]
                # Integer positions of missing steps inside the valid prefix. Integer indexing
                # writes straight into `data`. The previous form
                # `data[i, :, c][m == 0][:length] = ...` assigned into a temporary copy made by
                # boolean indexing, so no value was ever filled in.
                missing = np.flatnonzero(m[:length] == 0)
                if m.sum() > 1:
                    cs = CubicSpline(timestamp[m == 1], xs[m == 1])
                    data[i, missing, c] = cs(timestamp[missing])
                elif m.sum() == 1:
                    # A single observation: reuse its value for every missing step.
                    data[i, missing, c] = xs[m == 1]
            x, m = data[i, :length], mask[i, :length]
            distance = dtw.distance_matrix_fast(x.swapaxes(1, 0).astype(np.float64))

            num = m.sum(axis=0).reshape(1, -1)
            miss_num = np.tile(length - num, (dim, 1))
            # Penalise each channel pair by 0.1 per missing step in either channel.
            distance += (miss_num + miss_num.transpose()) * 0.1

            # Weight each pair by the mean observed fraction of its two channels.
            num = np.tile(num, (dim, 1))
            w_matrix = (num + num.transpose()) / (2 * length)

            dt_dtw += distance * w_matrix
            w += w_matrix
        print('[%d/%d] done...' % (idx + 1, len(loader)))

    if not isinstance(dt_dtw, np.ndarray):
        raise ValueError('loader yielded no samples; cannot estimate correlation')

    # Non-finite entries are zeroed and the matrix is symmetrised with its transpose.
    # NOTE(review): presumably distance_matrix_fast only fills one triangle (inf elsewhere); confirm.
    dt_dtw[np.isinf(dt_dtw)] = 0
    dt_dtw[np.isnan(dt_dtw)] = 0

    dt_dtw /= w + 1e-5
    dt_dtw += dt_dtw.transpose()
    off_diag = ~np.eye(dt_dtw.shape[0], dtype=bool)
    # Lift the (zero) diagonal to the smallest off-diagonal distance so 1 / dt_dtw stays finite.
    dt_dtw[~off_diag] += np.min(dt_dtw[off_diag])

    def normalization(data):
        _range = np.max(data) - np.min(data)
        return (data - np.min(data)) / _range

    corr = normalization(1 / dt_dtw)
    np.save(file, corr)
    return corr


def calc_corr_p(loader, name, path: str = './', freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10, is_tuple: bool = False):
    """Estimate a channel-similarity matrix from global-alignment-kernel scores of gap-filled channels.

    The result is cached in ``path`` as ``<name>_fp<freq_point>_fs<freq_start>_fe<freq_end>.npy``.
    If that file already exists it is loaded and returned and ``loader`` is not read.
    ``freq_point``/``freq_start``/``freq_end`` only take part in the cache file name.

    Args:
        loader: iterable of batches; ``len(loader)`` is used for progress output. Each batch
            (``batch[0]`` when ``is_tuple``) is a tensor whose last axis is sliced as
            ``[values (dim) | mask (dim) | ... | timestamp (last)]``. It is presumably
            ``(batch, seq, 2 * dim + 1)``; TODO confirm with the data pipeline.
        name: prefix of the cache file.
        path: directory of the cache file.
        is_tuple: take ``batch[0]`` instead of the batch itself.

    Returns:
        ``(dim, dim)`` array, min-max normalised to [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.

    Side effects:
        On a fresh computation, writes the cache file. It also writes ``rate.npy`` (per-channel
        observed ratio) into the current working directory.
    """
    file = os.path.join(path, '%s_fp%d_fs%s_fe%s.npy' % (name, freq_point, str(freq_start), str(freq_end)))

    # Cache hit. The previous check was inverted (`if not os.path.exists`): a missing cache
    # crashed in np.load, and an existing one was silently recomputed and overwritten.
    if os.path.exists(file):
        return np.load(file)

    # Only needed on a cache miss.
    from scipy.interpolate import CubicSpline
    from tslearn.metrics import cdist_gak, sigma_gak

    dt_dtw, w = 0, 0
    # Per-channel totals of valid steps / observed steps. They are sized from the data instead
    # of the previously hard-coded 36 channels, which broke for any other channel count.
    n_num, n_observed = None, None
    for idx, item in enumerate(loader):
        val = item[0] if is_tuple else item
        dim = val.size(-1) // 2
        val = val.detach().cpu().numpy()
        val, mask, ts = val[..., :dim], val[..., dim:2 * dim], val[..., -1]
        if n_num is None:
            n_num, n_observed = np.zeros(dim), np.zeros(dim)

        data = val.copy()
        for i in range(val.shape[0]):
            # Valid length = index of the largest timestamp + 1. This assumes timestamps
            # increase and are then padded with smaller values (TODO confirm).
            timestamp, length = ts[i, :], np.argmax(ts[i, :]) + 1
            for c in range(dim):
                xs, m = data[i, :, c], mask[i, :, c]
                if m.sum() > 1:
                    # Fill the missing steps of the valid prefix with a cubic spline through the
                    # observed points. Observed values are kept and the padded tail is zeroed.
                    cs = CubicSpline(timestamp[m == 1], xs[m == 1])
                    missing = m[:length] == 0
                    prefix = xs[:length].copy()
                    prefix[missing] = cs(timestamp[:length][missing])
                    data[i, :length, c] = prefix
                    data[i, length:, c] = 0.
                elif m.sum() == 1:
                    # A single observation: broadcast it over the whole valid prefix.
                    data[i, :length, c] = data[i, m == 1, c]
            x, m = data[i, :length], mask[i, :length]

            n_num += length
            n_observed += m.sum(axis=0)

            # Pairwise GAK between channels. x.T has one row per channel, presumably treated by
            # tslearn as `dim` univariate series.
            distance = cdist_gak(x.T, x.T, sigma_gak(x.T))

            num = m.sum(axis=0).reshape(1, -1)
            miss_num = np.tile(length - num, (dim, 1))
            penalty_matrix = miss_num + miss_num.transpose()

            # Add 1 / (0.1 * total missing steps) per channel pair. Pairs with no missing steps
            # get 0 instead of a division by zero.
            zeros_index = penalty_matrix == 0
            penalty_matrix[zeros_index] = 1.
            penalty_matrix = 1 / (penalty_matrix * 0.1)
            penalty_matrix[zeros_index] = 0
            distance += penalty_matrix

            distance[np.isnan(distance)] = 0
            distance[np.isinf(distance)] = 0

            # Weight each pair by the mean observed fraction of its two channels.
            num = np.tile(num, (dim, 1))
            w_matrix = (num + num.transpose()) / (2 * length)

            dt_dtw += distance * w_matrix
            w += w_matrix
        print('[%d/%d] done...' % (idx + 1, len(loader)))

    if not isinstance(dt_dtw, np.ndarray):
        raise ValueError('loader yielded no samples; cannot estimate correlation')

    dt_dtw[np.isinf(dt_dtw)] = 0
    dt_dtw[np.isnan(dt_dtw)] = 0

    dt_dtw /= w + 1e-5
    dt_dtw += dt_dtw.transpose()

    def normalization(data):
        _range = np.max(data) - np.min(data)
        return (data - np.min(data)) / _range

    corr = normalization(dt_dtw)
    np.save(file, corr)

    # NOTE(review): written to the current working directory rather than `path`. Kept as-is for
    # any caller that reads it from there.
    np.save('rate.npy', n_observed / n_num)
    return corr


def calc_corr_penalty(loader, name, path: str = './', freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10):
    """Build (or load from cache) a channel-similarity matrix from spectral DTW distances.

    Cache file: ``<path>/<name>_fp<freq_point>_fs<freq_start>_fe<freq_end>.npy``. On a cache hit
    it is returned directly and ``loader`` is never iterated.

    On a miss, every batch is handed to ``calc_distance_single`` as
    ``(batch[1], batch[0].unsqueeze(-1), batch[2])``, i.e. values, time, mask. ``flag`` is set
    when ``name`` contains ``'mimiciii'`` (case-insensitive). The weighted distances are summed
    over batches and divided by the summed weights. The result is symmetrised and inverted into
    a similarity, which is min-max normalised to [0, 1] and saved to the cache file.
    """
    cache_path = os.path.join(path, '%s_fp%d_fs%s_fe%s.npy' % (name, freq_point, str(freq_start), str(freq_end)))
    if os.path.exists(cache_path):
        return np.load(cache_path)

    dist_sum, weight_sum = 0, 0
    for step, batch in enumerate(loader):
        batch_dist, batch_weight = calc_distance_single(
            batch[1], batch[0].unsqueeze(-1), batch[2],
            freq_point, freq_start, freq_end,
            flag='mimiciii' in name.lower(),
        )
        dist_sum = dist_sum + batch_dist
        weight_sum = weight_sum + batch_weight
        if step % 5 == 0:
            print('%d/%d end...' % (step + 1, len(loader)))

    # Non-finite accumulated distances are treated as zero.
    dist_sum[np.isinf(dist_sum)] = 0
    dist_sum[np.isnan(dist_sum)] = 0

    mean_dist = dist_sum / (weight_sum + 1e-5)
    mean_dist = mean_dist + mean_dist.transpose()

    # Lift the diagonal by the smallest off-diagonal distance before inverting.
    diag = np.eye(mean_dist.shape[0], dtype=bool)
    mean_dist[diag] += mean_dist[~diag].min()

    similarity = 1 / mean_dist
    lo, hi = np.min(similarity), np.max(similarity)
    corr = (similarity - lo) / (hi - lo)
    np.save(cache_path, corr)
    return corr


def calc_corr_from_batch(loader, name, path: str = './', freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10, is_tuple: bool = False):
    """Average the per-batch ``calc_corr`` matrices over ``loader``, with an on-disk cache.

    Cache file: ``<path>/<name>_fp<freq_point>_fs<freq_start>_fe<freq_end>.npy``. On a hit it is
    loaded and returned without reading ``loader``.

    Each batch (``batch[0]`` when ``is_tuple``) is split along its last axis into values
    ``[..., :dim]``, mask ``[..., dim:2*dim]`` and timestamp ``[..., -1:]``, where
    ``dim = size(-1) // 2``. Every ``calc_corr`` result is divided by ``len(loader)`` and summed.
    """
    target = os.path.join(path, '%s_fp%d_fs%s_fe%s.npy' % (name, freq_point, str(freq_start), str(freq_end)))
    if os.path.exists(target):
        return np.load(target)

    corr = 0
    for step, batch in enumerate(loader):
        series = batch[0] if is_tuple else batch
        half = series.size(-1) // 2
        values = series[..., :half]
        stamps = series[..., -1:]
        observed = series[..., half:2 * half]
        corr += calc_corr(values, stamps, observed, freq_point, freq_start, freq_end) / len(loader)
        if step % 5 == 0:
            print('%d/%d end...' % (step + 1, len(loader)))
    np.save(target, corr)
    return corr


def calc_corr_single(val, time, mask, freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10):
    """Per-sample channel-similarity matrices from Lomb-Scargle spectra compared with DTW.

    Args:
        val: values indexed as ``(num, ts, dim)``; NumPy array or torch tensor.
        time: timestamps; ``time[i, :, 0]`` is used for sample ``i``.
        mask: same shape as ``val``; non-zero marks an observed value.
        freq_point, freq_start, freq_end: frequency grid
            ``np.linspace(freq_start, freq_end, freq_point)``.

    Returns:
        ``(num, dim, dim)`` array. Slice ``i`` is ``1 / distance`` for sample ``i``, min-max
        normalised to [0, 1].
    """
    if isinstance(val, torch.Tensor):
        val, time, mask = _tonumpy(val), _tonumpy(time), _tonumpy(mask)

    def normalization(data):
        _range = np.max(data) - np.min(data)
        return (data - np.min(data)) / _range

    num, _, dim = val.shape
    frequency = np.linspace(freq_start, freq_end, freq_point)
    power, corr = np.zeros((dim, freq_point)), np.zeros((num, dim, dim))
    diag = np.eye(dim, dtype=bool)

    for i in range(num):
        for d in range(dim):
            if mask[i, :, d].sum() > 0:
                observed = mask[i, :, d] != 0
                power[d] = LombScargle(time[i, :, 0][observed], val[i, :, d][observed]).power(frequency)
            else:
                # No observations: flat zero spectrum.
                power[d] = np.zeros(freq_point)
        power[np.isnan(power)] = 0
        # Replace infinities by 100x the largest *finite* power. The previous
        # `np.max(power) * 1e2` evaluated to inf whenever an inf was present, so it changed nothing.
        inf_mask = np.isinf(power)
        if inf_mask.any():
            finite = power[~inf_mask]
            power[inf_mask] = (finite.max() if finite.size else 0.) * 1e2
        distance = dtw.distance_matrix_fast(power)
        # NOTE(review): if distance_matrix_fast leaves non-computed entries as inf, np.max here
        # is inf; confirm that this is the intended treatment of zero off-diagonal distances.
        distance[~diag & (distance == 0)] += np.max(distance)
        distance[diag] += np.min(distance[~diag])
        corr[i] = normalization(1 / distance)
    return corr


def calc_distance_single(val, time, mask, freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10, flag=False):
    """Accumulate weighted DTW distances between channel periodograms over a batch.

    For every sample, a Lomb-Scargle power spectrum is computed per channel from its observed
    steps. The spectra are compared pairwise with DTW, and the distance matrix is weighted by
    the mean observed fraction of each channel pair.

    Args:
        val: values indexed as ``(num, ts, dim)``; NumPy array or torch tensor.
        time: timestamps; ``time[i, :, 0]`` is used for sample ``i``.
        mask: same shape as ``val``; non-zero marks an observed value.
        freq_point, freq_start, freq_end: frequency grid
            ``np.linspace(freq_start, freq_end, freq_point)``.
        flag: when True, pairs whose weight is 0 still add their distance (at weight 1) to the
            distance sum. The returned weight sum is unaffected.

    Returns:
        ``(df_dtw, w)``: the summed weighted ``(dim, dim)`` distances and the summed weights.
        Both are the int ``0`` when ``num == 0``.
    """
    if isinstance(val, torch.Tensor):
        val, time, mask = _tonumpy(val), _tonumpy(time), _tonumpy(mask)

    num, ts, dim = val.shape
    frequency, power = np.linspace(freq_start, freq_end, freq_point), np.zeros((dim, freq_point))
    off_diag = ~np.eye(dim, dtype=bool)

    df_dtw, w = 0, 0
    for i in range(num):
        for d in range(dim):
            if mask[i, :, d].sum() > 0:
                observed = mask[i, :, d] != 0
                power[d] = LombScargle(time[i, :, 0][observed], val[i, :, d][observed]).power(frequency)
            else:
                power[d] = np.zeros(freq_point)
        power[np.isnan(power)] = 0
        # Replace infinities by 100x the largest *finite* power. The previous
        # `np.max(power) * 1e2` evaluated to inf whenever an inf was present, so it changed nothing.
        inf_mask = np.isinf(power)
        if inf_mask.any():
            finite = power[~inf_mask]
            power[inf_mask] = (finite.max() if finite.size else 0.) * 1e2
        distance = dtw.distance_matrix_fast(power)
        # NOTE(review): if distance_matrix_fast leaves non-computed entries as inf, np.max here
        # is inf and these entries are zeroed again by the isinf cleanup below; confirm intent.
        distance[off_diag & (distance == 0)] += np.max(distance)

        # Observed-step count per channel. Pair weight = mean observed fraction of both channels.
        observed_cnt = np.tile(mask[i].sum(axis=0), (dim, 1))
        w_matrix = (observed_cnt + observed_cnt.transpose()) / (2 * ts)

        distance[np.isnan(distance)] = 0
        distance[np.isinf(distance)] = 0

        if flag:
            # Pairs with no observations at all still contribute their distance at weight 1.
            w_matrix_renew = w_matrix.copy()
            w_matrix_renew[w_matrix == 0] = 1.
            df_dtw += distance * w_matrix_renew
        else:
            df_dtw += distance * w_matrix
        w += w_matrix
    return df_dtw, w


def calc_corr(val, time, mask, freq_point: int = 1000, freq_start: float = 0, freq_end: float = 10):
    """Channel-similarity matrix from multi-band Lomb-Scargle periodograms compared with DTW.

    Each sample of the batch acts as one band of a single multi-band periodogram per channel,
    so the whole batch yields one matrix.

    Args:
        val: values indexed as ``(num, ts, dim)``; NumPy array or torch tensor.
        time: timestamps. After flattening, only the first entry of its last axis is used
            (presumably ``(num, ts, 1)``; TODO confirm).
        mask: same shape as ``val``; non-zero marks an observed value.
        freq_point, freq_start, freq_end: frequency grid
            ``np.linspace(freq_start, freq_end, freq_point)``.

    Returns:
        ``(dim, dim)`` array of ``1 / distance``, min-max normalised to [0, 1].
    """
    if isinstance(val, torch.Tensor):
        val, time, mask = _tonumpy(val), _tonumpy(time), _tonumpy(mask)

    num, ts, dim = val.shape

    def _comb(item: np.ndarray):
        # (num, ts, k) -> (k, num * ts): one row per feature with all samples concatenated.
        return item.transpose((2, 0, 1)).reshape(item.shape[-1], -1)

    val, time, mask = _comb(val), _comb(time), _comb(mask)
    # Band label of every concatenated step: sample index d, repeated ts times.
    band = np.concatenate([d * np.ones(ts) for d in range(num)])
    frequency, power = np.linspace(freq_start, freq_end, freq_point), np.zeros((dim, freq_point))

    for d in range(dim):
        observed = mask[d] != 0
        # A channel with no observations keeps an all-zero spectrum (as in calc_corr_single)
        # instead of handing empty arrays to LombScargleMultiband.
        if observed.any():
            power[d] = LombScargleMultiband(time[0][observed], val[d][observed], band[observed]).power(frequency)

    power[np.isnan(power)] = 0

    distance = dtw.distance_matrix_fast(power)
    diag = np.eye(distance.shape[0], dtype=bool)
    # Off-diagonal zero distances are raised by the largest distance. The diagonal is then
    # lifted by the smallest off-diagonal value so that 1 / distance stays finite there.
    distance[~diag & (distance == 0)] += np.max(distance)
    distance[diag] += np.min(distance[~diag])

    def normalization(data):
        _range = np.max(data) - np.min(data)
        return (data - np.min(data)) / _range

    return normalization(1 / distance)


if __name__ == '__main__':
    # Demo: channel ch is sin((2 + ch) * pi * t) on random per-sample timestamps normalised to
    # [0, 1] and sorted. The calc_corr matrix is printed and shown as a heatmap.
    n_samples, n_steps, n_channels = 6, 200, 5
    val = np.random.random((n_samples, n_steps, n_channels))
    # Drawn only to keep the RNG sequence unchanged; replaced by an all-ones mask below.
    mask = (np.random.random((n_samples, n_steps, n_channels)) > 0.5).astype(np.float32)

    stamps = np.random.random((n_samples, n_steps, 1))
    stamps = (stamps - np.min(stamps)) / (np.max(stamps) - np.min(stamps))
    stamps = np.sort(stamps, axis=1)

    def make_waves(grid):
        waves = np.empty_like(grid)
        for ch in range(grid.shape[-1]):
            waves[..., ch] = np.sin((2 + ch) * np.pi * grid[..., ch])
        return waves

    mask = np.ones_like(val)
    val = make_waves(np.tile(stamps, (1, 1, n_channels)))

    c = calc_corr(val, stamps, mask)

    from matplotlib import pyplot as plt
    import seaborn as sns

    print(c)

    plt.figure()
    sns.heatmap(abs(c), cmap='Reds')
    plt.show()

