import os
import numpy as np
import pandas as pd
from scipy.spatial.distance import euclidean
from fastdtw import fastdtw
from tslearn.metrics import dtw, dtw_path,gak
import tqdm

from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances

def find(condition):
    """Return the flat indices at which ``condition`` is truthy.

    ``condition`` is flattened first, so for an N-D input the returned indices
    refer to ``np.ravel(condition)``.
    """
    return np.flatnonzero(condition)

def tam(path, report='full'):
    """
    Calculates the Time Alignment Measurement (TAM) based on an optimal warping path
    between two time series.
    Reference: Folgado et al., Time Alignment Measurement for Time Series, 2018.

    :param path: (sequence of two 1-D integer arrays)
                ``path[0]`` and ``path[1]`` hold, step by step, the indices of
                the reference and of the estimated series along the optimal
                warping path.
    :param report: (string)
                Report mode: ``'instants'``, ``'ratios'``, ``'distance'`` or
                ``'full'`` (default; same result as ``'distance'``).
    :return:    For ``report='instants'`` the number of indexes in advance, delay and phase,
                as ``np.array([advance, delay, phase])``. For ``report='ratios'`` the ratios
                of advance, delay and phase, in the same order. For ``report='distance'``
                or ``report='full'`` only the TAM, ``p_advance + p_delay + (1 - p_phase)``.
    :raises ValueError: if ``report`` is not one of the modes listed above.

    NOTE: the ratios divide by the last index of each series on the path, so a
    path ending at index 0 (a length-1 series) divides by zero in the
    ``'ratios'``/``'distance'``/``'full'`` modes.
    """
    if report not in ('instants', 'ratios', 'distance', 'full'):
        raise ValueError(
            f"report must be 'instants', 'ratios', 'distance' or 'full', got {report!r}")

    ref_steps = np.diff(path[0])
    est_steps = np.diff(path[1])

    # Delay and advance counting: steps where path[0] (resp. path[1]) repeats
    # its index are counted as delay (resp. advance).
    delay = int(np.count_nonzero(ref_steps == 0))
    advance = int(np.count_nonzero(est_steps == 0))

    # Phase counting: both indices move forward by exactly one (diagonal step).
    phase = int(np.count_nonzero((ref_steps == 1) & (est_steps == 1)))

    if report == 'instants':
        # Counts need no normalisation, so return before any division.
        return np.array([advance, delay, phase])

    # Estimated and reference time series duration (last index on the path).
    len_estimation = path[1][-1]
    len_ref = path[0][-1]

    p_advance = advance * 1. / len_ref
    p_delay = delay * 1. / len_estimation
    p_phase = phase * 1. / np.min([len_ref, len_estimation])

    if report == 'ratios':
        return np.array([p_advance, p_delay, p_phase])

    # 'distance' and 'full' both report the TAM value.
    return p_advance + p_delay + (1 - p_phase)

def get_DTW(UTS_tr):
    """Build the pairwise DTW distance matrix for a collection of univariate series.

    Each ``UTS_tr[k]`` is reshaped to a column vector ``(-1, 1)`` before being
    passed to ``tslearn.metrics.dtw``. Only pairs with ``col < row`` are
    computed; each value is mirrored into ``[col, row]``, and the diagonal
    keeps its initial value of 0.

    :param UTS_tr: indexable collection whose elements support ``.reshape``
        (presumably 1-D numpy arrays — TODO confirm against callers).
    :return: ``(N, N)`` float ndarray, ``N = len(UTS_tr)``.
    """
    n_series = len(UTS_tr)
    dist_matrix = np.zeros((n_series, n_series))
    for row in tqdm.tqdm(range(n_series)):
        for col in range(row):
            pair_dist = dtw(UTS_tr[row].reshape(-1, 1), UTS_tr[col].reshape(-1, 1))
            dist_matrix[row, col] = dist_matrix[col, row] = pair_dist
    return dist_matrix

def get_TAM(UTS_tr):
    """Build the pairwise Time Alignment Measurement matrix for univariate series.

    For every pair with ``col < row``, the warping path (first element of the
    ``tslearn.metrics.dtw_path`` result on the column-vector reshaped series)
    is split into its two index sequences and scored with ``tam``. The score
    is mirrored into ``[col, row]``; the diagonal keeps its initial value of 0.

    :param UTS_tr: indexable collection whose elements support ``.reshape``
        (presumably 1-D numpy arrays — TODO confirm against callers).
    :return: ``(N, N)`` float ndarray, ``N = len(UTS_tr)``.
    """
    n_series = len(UTS_tr)
    tam_matrix = np.zeros((n_series, n_series))
    for row in tqdm.tqdm(range(n_series)):
        for col in range(row):
            warping_path = dtw_path(UTS_tr[row].reshape(-1, 1), UTS_tr[col].reshape(-1, 1))[0]
            first_idx = np.array([step[0] for step in warping_path])
            second_idx = np.array([step[1] for step in warping_path])
            score = tam([first_idx, second_idx])
            tam_matrix[row, col] = tam_matrix[col, row] = score
    return tam_matrix

def get_GAK(UTS_tr):
    """Build a pairwise Global Alignment Kernel *distance* matrix for univariate series.

    ``tslearn.metrics.gak`` returns a normalised kernel value: a similarity
    that equals 1 for identical series. This matrix, however, is used as a
    distance: its diagonal is 0, and ``save_dtw_similarity`` maps it to a
    similarity via ``1 - scaled``. Each kernel value is therefore stored as
    ``1 - gak(...)``, so identical series sit at distance 0, consistent with
    the diagonal.

    Only pairs with ``col < row`` are computed; each value is mirrored into
    ``[col, row]``.

    :param UTS_tr: indexable collection whose elements support ``.reshape``
        (presumably 1-D numpy arrays — TODO confirm against callers).
    :return: ``(N, N)`` float ndarray, ``N = len(UTS_tr)``.
    """
    n_series = len(UTS_tr)
    gak_dist = np.zeros((n_series, n_series))
    for row in tqdm.tqdm(range(n_series)):
        for col in range(row):
            kernel = gak(UTS_tr[row].reshape(-1, 1), UTS_tr[col].reshape(-1, 1))
            # Convert the similarity (1 == identical) into a distance (0 == identical).
            gak_dist[row, col] = gak_dist[col, row] = 1.0 - kernel
    return gak_dist


def get_MDTW(MTS_tr):
    """Build the pairwise DTW distance matrix for a collection of multivariate series.

    ``MTS_tr[k]`` is passed to ``tslearn.metrics.dtw`` as-is (no reshape), so
    each element is presumably a 2-D ``(length, n_channels)`` array — TODO
    confirm. Only pairs with ``col < row`` are computed; each value is
    mirrored into ``[col, row]``, and the diagonal keeps its initial value of 0.

    :param MTS_tr: array-like exposing ``.shape[0]`` (number of series).
    :return: ``(N, N)`` float ndarray, ``N = MTS_tr.shape[0]``.
    """
    n_series = MTS_tr.shape[0]
    mdtw_matrix = np.zeros((n_series, n_series))
    for row in tqdm.tqdm(range(n_series)):
        for col in range(row):
            pair_dist = dtw(MTS_tr[row], MTS_tr[col])
            mdtw_matrix[row, col] = mdtw_matrix[col, row] = pair_dist
    return mdtw_matrix

def get_COS(MTS_tr):
    """Return the negated pairwise cosine similarity between the rows of ``MTS_tr``.

    Negation makes larger values mean "less alike", presumably so the result
    can be used wherever a distance matrix is expected.
    """
    similarity = cosine_similarity(MTS_tr)
    return np.negative(similarity)

def get_EUC(MTS_tr):
    """Return the pairwise Euclidean distance matrix between the rows of ``MTS_tr``."""
    pairwise = euclidean_distances(MTS_tr)
    return pairwise

def save_dtw_similarity(X_tr, min_ = 0, max_ = 1, multivariate=False, type_='DTW'):
    """Compute a pairwise similarity matrix for a collection of time series.

    The function does not write anything to disk; it only computes and returns
    the matrix.

    Steps:
    1. Compute a pairwise distance matrix with the metric selected by ``type_``.
    2. Overwrite its diagonal with the smallest off-diagonal distance.
    3. Min-max scale every column to ``[min_, max_]``.
    4. Return ``1 - scaled``.

    :param X_tr: collection of series, forwarded unchanged to the selected
        ``get_*`` function.
    :param min_: lower bound of the ``MinMaxScaler`` feature range.
    :param max_: upper bound of the ``MinMaxScaler`` feature range.
    :param multivariate: if True, use ``get_MDTW``; only ``type_='DTW'`` is
        supported in that case.
    :param type_: one of ``'DTW'``, ``'TAM'``, ``'COS'``, ``'EUC'``, ``'GAK'``.
    :return: ``(N, N)`` ndarray of similarities.
    :raises ValueError: if ``type_`` is unknown, or if it is not ``'DTW'``
        when ``multivariate`` is True.
    """
    if multivariate:
        # get_MDTW is the only multivariate metric implemented here.
        if type_ != 'DTW':
            raise ValueError(
                f"multivariate=True only supports type_='DTW', got {type_!r}")
        dist = get_MDTW(X_tr)
    else:
        univariate_metrics = {
            'DTW': get_DTW,
            'TAM': get_TAM,
            'COS': get_COS,
            'EUC': get_EUC,
            'GAK': get_GAK,
        }
        if type_ not in univariate_metrics:
            raise ValueError(
                f"Unknown type_ {type_!r}; expected one of {sorted(univariate_metrics)}")
        dist = univariate_metrics[type_](X_tr)

    # The trivial self-distance would otherwise set each column's minimum.
    # Replacing it with the smallest off-diagonal distance makes the diagonal
    # scale to exactly min_, i.e. a similarity of 1 - min_.
    off_diagonal = ~np.eye(dist.shape[0], dist.shape[1], dtype=bool)
    dist[np.diag_indices(dist.shape[0])] = dist[off_diagonal].min()

    # NOTE: MinMaxScaler rescales each column independently, so the returned
    # matrix is generally not symmetric even when ``dist`` is.
    scaler = MinMaxScaler(feature_range=(min_, max_))
    dist_scaled = scaler.fit_transform(dist)
    return 1 - dist_scaled

'''
def save_dtw_similarity(X_tr,multivariate=False):
    if multivariate:
        DTW_dist = get_MDTW(X_tr)
    else:
        DTW_dist = get_DTW(X_tr)
    DTW_dist = DTW_dist/DTW_dist.max()
    DTW_sim = 1-DTW_dist
    return DTW_sim 


def save_dtw_similarity(X_tr, min_ = 0, max_ = 1, multivariate=False):
    if multivariate:
        DTW_dist = get_MDTW(X_tr)
    else:
        DTW_dist = get_DTW(X_tr)
        
    diag_indices = np.diag_indices(DTW_dist.shape[0])
    mask = np.ones(DTW_dist.shape, dtype=bool)
    mask[diag_indices] = False
    temp = DTW_dist[mask].reshape(DTW_dist.shape[0], DTW_dist.shape[1]-1)
    diag_indices = np.diag_indices(DTW_dist.shape[0])
    max_values = np.min(temp, axis=1)
    DTW_dist[diag_indices] = max_values
    scaler = MinMaxScaler(feature_range=(min_, max_))
    
    DTW_dist_scaled = scaler.fit_transform(DTW_dist)
    DTW_sim = 1 - DTW_dist_scaled 
    return DTW_sim 
'''    
def get_example_data(data_name, data_dir='./data/UCR'):
    """Load the TRAIN split of a dataset stored as a header-less TSV file.

    The file is read from ``<data_dir>/<data_name>/<data_name>_TRAIN.tsv``.
    Column 0 is returned as the labels and the remaining columns as the
    series values.

    :param data_name: dataset name, used both as sub-directory and file prefix.
    :param data_dir: root directory holding the datasets. Defaults to
        ``'./data/UCR'``, the previously hard-coded location.
    :return: tuple ``(data_X, data_y)``: a DataFrame with columns 1.. and a
        Series holding column 0.
    """
    ex_data_path = os.path.join(data_dir, data_name, f'{data_name}_TRAIN.tsv')
    data = pd.read_csv(ex_data_path, delimiter='\t', keep_default_na=False, header=None)
    data_X = data.iloc[:, 1:]
    data_y = data.iloc[:, 0]
    return data_X, data_y
    
def set_nan_to_zero(a):
    """Overwrite every NaN entry of ``a`` with 0 in place and return the same object."""
    nan_mask = np.isnan(a)
    a[nan_mask] = 0
    return a

def densify(x, tau, alpha):
    """Map ``x`` through a scaled sigmoid and add a weighted identity matrix.

    Computes ``2*alpha / (1 + exp(-tau * x)) + (1 - alpha) * I``, where ``I``
    is the identity of size ``x.shape[0]`` (so ``x`` is presumably square —
    TODO confirm against callers).
    """
    numerator = 2 * alpha
    denominator = 1 + np.exp(-tau * x)
    identity = np.eye(x.shape[0])
    return numerator / denominator + (1 - alpha) * identity


def convert_top_k_to_one_zero(matrix, k):
    """
    Build a 0/1 mask marking the ``k`` largest off-diagonal entries of each row
    together with the diagonal.

    The diagonal is excluded from the ranking, so every row gets exactly ``k``
    off-diagonal ones even when an off-diagonal value ties with (or exceeds)
    the diagonal value. The diagonal itself is always set to one.

    :param matrix: 2-D array of scores (e.g. a similarity matrix); not modified.
    :param k: number of off-diagonal entries to mark per row, ``0 <= k < n_cols``.
    :return: array with the same shape and dtype as ``matrix``, holding 1 at the
        selected positions and on the diagonal and 0 elsewhere.
    :raises ValueError: if ``k`` is outside ``[0, n_cols - 1]``.
    """
    matrix = np.asarray(matrix)
    n_cols = matrix.shape[1]
    if not 0 <= k < n_cols:
        raise ValueError(f"k must be in [0, {n_cols - 1}], got {k}")

    # Rank on a float copy whose diagonal is pushed to -inf, so self-scores
    # can never take one of the k slots.
    scores = matrix.astype(float, copy=True)
    np.fill_diagonal(scores, -np.inf)

    mask = np.zeros_like(matrix)
    if k > 0:
        # argpartition with kth=-k places the k largest entries of each row last.
        top_k_indices = np.argpartition(scores, -k, axis=1)[:, -k:]
        np.put_along_axis(mask, top_k_indices, 1, axis=1)
    np.fill_diagonal(mask, 1)
    return mask

def hard_DTW_matrix(soft_DTW, pos_ratio):
    """Binarise a soft matrix by keeping a fraction of the other rows' entries per row.

    ``k = int((N - 1) * pos_ratio)`` with ``N = soft_DTW.shape[0]``; the 0/1
    matrix is then produced by ``convert_top_k_to_one_zero(soft_DTW, k)``.
    """
    other_series = soft_DTW.shape[0] - 1
    return convert_top_k_to_one_zero(soft_DTW, int(other_series * pos_ratio))