import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

##############################################################################
## 5 different ways of generating time-lag matrices
##############################################################################
def generate_TIMELAG_sigmoid(T, sigma=1):
    """Build a (T, T) time-lag weight matrix with sigmoid-shaped decay.

    Entry (i, j) is ``2 / (1 + exp(sigma * |i - j|))``, which is exactly 1.0
    on the diagonal. Entries below 1e-6 are set to exactly 0.

    Args:
        T: Number of time steps (side length of the square matrix).
        sigma: Scale applied to the lag ``|i - j|`` inside the exponential.

    Returns:
        A (T, T) NumPy float array.
    """
    dist = np.arange(T)
    dist = np.abs(dist - dist[:, np.newaxis])  # dist[i, j] = |i - j|
    # For large lags exp() overflows to inf; 2 / (1 + inf) evaluates to the
    # correct limit 0.0, so the overflow RuntimeWarning is spurious noise.
    with np.errstate(over="ignore"):
        matrix = 2 / (1 + np.exp(dist * sigma))
    matrix = np.where(matrix < 1e-6, 0, matrix)  # set very small values to 0
    return matrix

def generate_TIMELAG_gaussian(T, sigma):
    """Build a (T, T) time-lag weight matrix with Gaussian decay.

    Entry (i, j) is ``exp(-|i - j|**2 / (2 * sigma**2))``, which is 1.0 on
    the diagonal. Entries below 1e-6 are set to exactly 0.

    Args:
        T: Number of time steps (side length of the square matrix).
        sigma: Width of the Gaussian, measured in time steps. The sign is
            irrelevant because it is squared.

    Returns:
        A (T, T) NumPy float array.

    Raises:
        ValueError: If ``2 * sigma**2`` is zero. The formula would then
            compute 0/0 on the diagonal and silently return NaNs.
    """
    denom = 2 * sigma ** 2
    if denom == 0:
        raise ValueError("sigma must be non-zero for the Gaussian time lag")
    dist = np.arange(T)
    dist = np.abs(dist - dist[:, np.newaxis])  # dist[i, j] = |i - j|
    matrix = np.exp(-(dist ** 2) / denom)
    matrix = np.where(matrix < 1e-6, 0, matrix)  # set very small values to 0
    return matrix

def generate_TIMELAG_same_interval(T, sigma=0.3):
    """Build a (T, T) time-lag weight matrix with linear decay.

    Entry (i, j) is ``1 - |i - j| / T``: 1.0 on the diagonal, dropping by
    the same step 1/T for every extra unit of lag.

    Args:
        T: Number of time steps (side length of the square matrix).
        sigma: Not used by this generator. It is presumably kept so all
            generators share a call signature; confirm with callers.

    Returns:
        A (T, T) NumPy float array.
    """
    steps = np.arange(T)
    lag = np.abs(steps[:, None] - steps[None, :])
    return 1 - lag / T

def generate_TIMELAG_sigmoid_window(T, sigma=1, window_ratio=1.0):
    """Build a sigmoid time-lag weight matrix restricted to a diagonal band.

    Entry (i, j) is ``2 / (1 + exp(sigma * |i - j|))``, with entries below
    1e-6 set to 0. Entries are also set to 0 when the lag ``|i - j|``
    exceeds ``T * window_ratio``.

    Args:
        T: Number of time steps (side length of the square matrix).
        sigma: Scale applied to the lag inside the exponential.
        window_ratio: Band half-width as a fraction of ``T``. The default 1.0
            keeps every entry, since ``|i - j| <= T - 1``.

    Returns:
        A (T, T) NumPy float array.
    """
    dist = np.arange(T)
    dist = np.abs(dist - dist[:, np.newaxis])  # dist[i, j] = |i - j|
    # exp() overflows to inf for large lags; 2 / (1 + inf) == 0.0 is the
    # correct limit, so suppress the spurious overflow RuntimeWarning.
    with np.errstate(over="ignore"):
        matrix = 2 / (1 + np.exp(dist * sigma))
    matrix = np.where(matrix < 1e-6, 0, matrix)  # set very small values to 0
    # `dist` already holds each entry's distance from the diagonal.
    matrix[dist > T * window_ratio] = 0
    return matrix

def generate_TIMELAG_sigmoid_threshold(T, sigma=1, threshold=1.0):
    """Build a (T, T) binary band mask of time lags.

    Entry (i, j) is 1.0 when ``|i - j| <= T * threshold`` and 0.0 otherwise.
    Despite the name, no sigmoid is applied.

    Args:
        T: Number of time steps (side length of the square matrix).
        sigma: Not used by this generator. It is presumably kept so all
            generators share a call signature; confirm with callers.
        threshold: Band half-width as a fraction of ``T``.

    Returns:
        A (T, T) NumPy float array of 0.0/1.0 values.
    """
    positions = np.arange(T)
    lag = np.abs(np.subtract.outer(positions, positions))
    return np.where(lag > T * threshold, 0.0, 1.0)
##############################################################################

def duplicate_DTW(DTW):
    """Pair a matrix with its diagonal-free copy in two concatenation orders.

    Let ``off_diag`` be ``DTW`` with each row's diagonal entry removed: the
    strictly-lower part is shifted left and combined with the strictly-upper
    part, which drops one column. The function assumes ``DTW`` is a square
    2-D tensor (T, T) — TODO confirm with callers. Under that assumption
    ``off_diag`` is (T, T-1).

    Returns:
        A tuple ``(cat([off_diag, DTW], dim=1), cat([DTW, off_diag], dim=1))``.
        The input tensor is not modified.
    """
    below = torch.tril(DTW, diagonal=-1)[:, :-1]
    above = torch.triu(DTW, diagonal=1)[:, 1:]
    off_diag = below + above
    left_first = torch.cat([off_diag, DTW], dim=1)
    right_first = torch.cat([DTW, off_diag], dim=1)
    return left_first, right_first