import numpy as np
import random
import torch
from torchvision.transforms import Lambda
import warnings
# NOTE(review): importing this module installs an 'ignore' filter for *all*
# warnings in the whole process, not just for this file. It may have been added
# to hide the UserWarning torch.tensor() emits when copy-constructing from an
# existing tensor — confirm nothing else relies on it before removing.
warnings.filterwarnings('ignore')

def _weak_view(sample):
    """Weak augmentation: per-column random scaling."""
    return scaling(sample)


def _strong_view(sample):
    """Strong augmentation: segment permutation followed by additive jitter."""
    return jitter(permutation(sample))


def _contrastive_views(sample):
    """Return the ``(weak, strong)`` pair of augmented views of ``sample``."""
    return _weak_view(sample), _strong_view(sample)


def _random_view(sample):
    """Return a single augmented view, weak or strong with equal probability."""
    # Choose first, then compute only the chosen view. Building both and
    # discarding one wasted a full augmentation pass per sample (and advanced
    # NumPy's global RNG for a view that was thrown away).
    return random.choice((_weak_view, _strong_view))(sample)


def TSTransform(augmentation=True, contrastive=True):
    """Build the time-series augmentation transform.

    Args:
        augmentation: Only consulted together with ``contrastive``: when both
            flags are false no transform is built.
        contrastive: If true, the transform maps a sample to the tuple
            ``(scaling(sample), jitter(permutation(sample)))``, i.e. a weak
            and a strong view. If false, it returns one of those two views,
            picked with Python's ``random.choice``.

    Returns:
        A ``torchvision.transforms.Lambda`` wrapping the chosen function, or
        ``None`` when both ``augmentation`` and ``contrastive`` are false.
    """
    if not (augmentation or contrastive):
        return None
    # Module-level functions (unlike lambdas) pickle by reference, so the
    # transform can be shipped to multiprocessing data-loader workers.
    return Lambda(_contrastive_views if contrastive else _random_view)


def jitter(x, sigma=0.05):
    """Add zero-mean Gaussian noise to every element of a tensor.

    Args:
        x: Input ``torch.Tensor`` of any shape, on any device.
        sigma: Standard deviation of the noise.

    Returns:
        A new, detached tensor with the same shape, dtype and device as ``x``.
    """
    # Noise comes from NumPy's global RNG (so np.random.seed controls it) in
    # float64. It is moved to x's device before the add, because adding a CPU
    # ndarray to a CUDA tensor fails. The sum is computed in the promoted
    # dtype and then cast back to x.dtype.
    noise = torch.as_tensor(
        np.random.normal(loc=0., scale=sigma, size=tuple(x.shape)),
        device=x.device,
    )
    return (x.detach() + noise).to(dtype=x.dtype)


def scaling(x, sigma=0.1):
    """Multiply each column of a 2-D tensor by its own random factor.

    One factor per column is drawn from N(1, sigma**2) using NumPy's global
    RNG, and the same factor is applied to every row of that column.

    Args:
        x: Input ``torch.Tensor`` indexed as ``x[row, column]``. Presumably
            (time, channels); TODO confirm against callers.
        sigma: Standard deviation of the scaling factors around 1.0.

    Returns:
        A new, detached tensor with the same shape, dtype and device as ``x``.
    """
    factors = np.random.normal(loc=1.0, scale=sigma, size=(1, x.shape[1]))
    # A (1, C) factor row broadcasts across all rows of x. This gives the same
    # values as the former ones((N, 1)) @ factors outer product without
    # materialising an (N, C) matrix. The factors are placed on x's device so
    # CUDA inputs work.
    scaled = x.detach() * torch.as_tensor(factors, device=x.device)
    return scaled.to(dtype=x.dtype)


def permutation(x, max_segments=5, seg_mode="random"):
    """Cut ``x`` along its first axis into random segments and shuffle them.

    The number of segments is drawn from ``[1, max_segments - 1]`` because
    NumPy's ``randint`` upper bound is exclusive. Cut points are drawn from
    ``[0, len(x))`` and may coincide or be 0, so some segments can be empty
    and the effective number of pieces can be smaller. All randomness comes
    from NumPy's global RNG.

    Args:
        x: Input ``torch.Tensor``. Rows (axis 0) are reordered; presumably
            these are time steps (TODO confirm against callers).
        max_segments: Exclusive upper bound on the number of segments. It must
            be >= 2, otherwise NumPy raises ``ValueError``.
        seg_mode: Currently ignored. Kept so existing callers keep working.

    Returns:
        A new, detached tensor with the same shape, dtype and device as ``x``
        whose rows are a segment-wise reordering of the rows of ``x``.
    """
    n_rows = x.shape[0]
    n_segs = np.random.randint(1, max_segments)
    seg_order = np.random.permutation(n_segs)

    # Segment s spans rows [bounds[s], bounds[s + 1]). The interior bounds are
    # sorted random cut points, and the outer bounds are 0 and n_rows.
    bounds = np.zeros(n_segs + 1, dtype=int)
    bounds[1:-1] = np.sort(np.random.randint(0, n_rows, n_segs - 1))
    bounds[-1] = n_rows

    # Gather with a single row-index vector. This stays on x's device and
    # keeps x's dtype exactly, with no float64 round trip through NumPy (which
    # corrupted int64 values above 2**53).
    row_index = np.concatenate(
        [np.arange(bounds[s], bounds[s + 1]) for s in seg_order]
    )
    return x.detach()[torch.as_tensor(row_index, dtype=torch.long, device=x.device)]
