#coding=utf-8
from datautil.imgdata.util import image_test
from torchvision import transforms
import numpy as np
import math
import torch
import random

# DSADS 0:sitting, 1:standing, 2-3:lying, 4 ascending, 5:descending, 8: walking
# usc 0 walking, 3 upstairs, 4 downstairs, 7 sitting, 8 standing, 9 lying
# har WALKING, WALKING_UPSTAIRS, WALKING_DOWNSTAIRS, SITTING, STANDING, LAYING
# pamap 0 lying, 1 sitting, 2 standing, 3 walking, 7 ascending, 8 descending

def load_shuffle(x, puzzle=True, grid=5):
    """Optionally apply a jigsaw-style shuffle along the last axis of ``x``.

    The last axis is cut into ``grid`` contiguous tiles of equal length.
    These are permuted with :func:`random.shuffle` and concatenated back
    with :func:`torch.cat`.

    Args:
        x: Tensor of any rank; only the last axis is tiled.
        puzzle: When False, ``x`` is returned unchanged.
        grid: Number of equal-length tiles to shuffle.

    Returns:
        A tensor with the same shape as ``x``. If the last-axis length is not
        a multiple of ``grid``, the leftover trailing samples are kept,
        unshuffled, at the end rather than being dropped.
    """
    if not puzzle:
        return x
    tile_len = x.shape[-1] // grid
    tiles = [x[..., i * tile_len:(i + 1) * tile_len] for i in range(grid)]
    random.shuffle(tiles)
    # Previously these samples were silently discarded, shrinking the output.
    remainder = x[..., grid * tile_len:]
    if remainder.shape[-1] > 0:
        tiles.append(remainder)
    return torch.cat(tiles, dim=-1)

def split_trian_val_test(da, rate=0.8, seed=0):
    """Randomly split a dataset into two reproducible subsets.

    Args:
        da: Any object with ``len()`` accepted by ``random_split``.
        rate: Fraction of items placed in the first subset (rounded down).
        seed: Seed for the dedicated ``torch.Generator``, so the split is
            reproducible regardless of the global torch RNG state.

    Returns:
        ``(first_subset, second_subset)`` as a tuple.
    """
    total = len(da)
    n_first = int(rate * total)
    lengths = [n_first, total - n_first]
    gen = torch.Generator().manual_seed(seed)
    parts = torch.utils.data.random_split(da, lengths, generator=gen)
    return parts[0], parts[1]

def act_train():
    """Return the transform pipeline for training samples.

    The pipeline currently consists of a single ``transforms.ToTensor()``.
    """
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)

def act_test():
    """Return the transform pipeline for evaluation samples.

    The pipeline currently consists of a single ``transforms.ToTensor()``.
    """
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)

def loaddata_from_numpy(dataset='dsads', task='cross_people', root_dir='./data/act/'):
    """Load the feature array and the three label columns of a dataset.

    Files are read from ``<root_dir>/<dataset>/<dataset>_x.npy`` and
    ``..._y.npy``. For ``dataset == 'pamap'`` with ``task == 'cross_people'``,
    the ``_x1.npy`` / ``_y1.npy`` variants are used instead.

    Paths are joined with :func:`os.path.join`, so ``root_dir`` works with or
    without a trailing separator.

    Returns:
        ``(x, cy, py, sy)``, where ``cy``, ``py`` and ``sy`` are columns 0, 1
        and 2 of the label array. Presumably these are class / person /
        another grouping label; confirm against the data producer.
    """
    import os

    suffix = '1' if (dataset == 'pamap' and task == 'cross_people') else ''
    base = os.path.join(root_dir, dataset)
    x = np.load(os.path.join(base, dataset + '_x' + suffix + '.npy'))
    ty = np.load(os.path.join(base, dataset + '_y' + suffix + '.npy'))
    cy, py, sy = ty[:, 0], ty[:, 1], ty[:, 2]
    return x, cy, py, sy

def seq_cut(x, tl):
    """Centre-crop ``x`` along its last axis to length ``tl``.

    When the number of samples to remove is odd, the extra sample is removed
    from the end, i.e. the crop starts at ``(len - tl) // 2``.

    Args:
        x: Array of any rank; only the last axis is cropped.
        tl: Target length (converted with ``int``).

    Returns:
        The cropped view ``x[..., start:start + tl]``.

    Raises:
        ValueError: If ``tl`` is negative or longer than the last axis. This
            previously produced a silently wrong slice via negative indices.
    """
    tl = int(tl)
    length = x.shape[-1]
    if not 0 <= tl <= length:
        raise ValueError(f"target length {tl} must be within [0, {length}]")
    start = (length - tl) // 2
    return x[..., start:start + tl]

def split_via_len(x, cy, py, sy, sl):
    """Split every sequence into non-overlapping segments of length ``sl``.

    The last axis is first centre-cropped with ``seq_cut`` to the largest
    multiple of ``sl``. Each sequence then yields ``d = cropped_len // sl``
    consecutive segments, in order. Each label array is repeated ``d`` times
    per sample, so labels stay aligned with segments.

    Args:
        x: Array indexed as ``(samples, channels, time)``. The reshape below
            fixes this rank; other ranks are rejected by numpy.
        cy, py, sy: Per-sample label sequences of length ``len(x)``.
        sl: Segment length.

    Returns:
        ``(x_seg, cy_seg, py_seg, sy_seg)``, where ``x_seg`` has shape
        ``(samples * d, channels, sl)``.
    """
    length = x.shape[-1]
    x = seq_cut(x, length - length % sl)
    d = x.shape[-1] // sl
    n, c = x.shape[0], x.shape[1]
    # Bug fix: segments were previously sliced with the segment *count* d as
    # their length; each segment must be sl samples long.
    x = x.reshape(n, c, d, sl).swapaxes(1, 2).reshape(n * d, c, sl)
    cy = np.repeat(cy, d)
    py = np.repeat(py, d)
    sy = np.repeat(sy, d)
    return x, cy, py, sy

def split_via_wind_3(x, cy, py, sy, sl, step):
    """Cut sliding windows of length ``sl`` and stride ``step`` from each sample.

    Windows start at ``0, step, 2*step, ...``. Only windows that fit entirely
    inside the last axis are kept. Each window inherits the three labels of
    its source sample.

    Args:
        x: Array indexed as ``x[i, :, start:end]``, i.e. assumed
            ``(samples, channels, time)``; the window bound is taken from
            ``x.shape[-1]``.
        cy, py, sy: Per-sample label sequences of length ``len(x)``.
        sl: Window length.
        step: Stride between window starts; must be positive.

    Returns:
        ``(windows, cy_w, py_w, sy_w)`` as numpy arrays, sample-major order.

    Raises:
        ValueError: If ``step`` is not positive; previously this looped forever.
    """
    if step <= 0:
        raise ValueError(f"step must be a positive integer, got {step}")
    length = x.shape[-1]
    starts = range(0, length - sl + 1, step)
    tx, tcy, tpy, tsy = [], [], [], []
    for i in range(len(x)):
        for start in starts:
            tx.append(x[i, :, start:start + sl])
            tcy.append(cy[i])
            tpy.append(py[i])
            tsy.append(sy[i])
    return np.array(tx), np.array(tcy), np.array(tpy), np.array(tsy)


def split_via_wind_1(x, cy, sl, step):
    """Cut sliding windows of length ``sl`` and stride ``step`` from each sample.

    Same windowing as ``split_via_wind_3``, but carries a single label array.

    Args:
        x: Array indexed as ``x[i, :, start:end]``, i.e. assumed
            ``(samples, channels, time)``; the window bound is taken from
            ``x.shape[-1]``.
        cy: Per-sample label sequence of length ``len(x)``.
        sl: Window length.
        step: Stride between window starts; must be positive.

    Returns:
        ``(windows, cy_w)`` as numpy arrays, sample-major order.

    Raises:
        ValueError: If ``step`` is not positive; previously this looped forever.
    """
    if step <= 0:
        raise ValueError(f"step must be a positive integer, got {step}")
    length = x.shape[-1]
    starts = range(0, length - sl + 1, step)
    tx, tcy = [], []
    for i in range(len(x)):
        for start in starts:
            tx.append(x[i, :, start:start + sl])
            tcy.append(cy[i])
    return np.array(tx), np.array(tcy)

def seq_downsapmle(x, ohz, thz, rng=None):
    """Randomly downsample ``x`` along its last axis from ``ohz`` to ``thz``.

    With ``g = gcd(ohz, thz)``, the last axis is viewed as consecutive blocks
    of ``ohz // g`` samples. From every block, ``thz // g`` distinct samples
    are drawn at random and kept. The kept samples are emitted in their
    original temporal order; previously they came out in random order within
    each block. The axis is first centre-cropped with ``seq_cut`` to a whole
    number of blocks.

    Args:
        x: Array whose last axis is time; other axes are kept as-is.
        ohz: Original sampling rate (positive int).
        thz: Target sampling rate (positive int, ``<= ohz``).
        rng: Optional object with a numpy-style ``choice(a, size, replace)``,
            e.g. ``np.random.default_rng(0)``. Defaults to the global
            ``np.random`` state, matching the previous behaviour.

    Returns:
        The downsampled array.

    Raises:
        ValueError: If a rate is not positive or ``thz > ohz`` (upsampling is
            not supported by drawing without replacement).
    """
    if ohz <= 0 or thz <= 0:
        raise ValueError(f"sampling rates must be positive, got {ohz} and {thz}")
    if thz > ohz:
        raise ValueError(f"target rate {thz} exceeds original rate {ohz}")
    g = math.gcd(ohz, thz)
    keep = thz // g   # samples kept per block
    block = ohz // g  # samples per block
    length = x.shape[-1]
    x = seq_cut(x, length - length % block)
    n_blocks = x.shape[-1] // block
    sampler = np.random if rng is None else rng
    # Sort within each block so the time order of kept samples is preserved.
    picks = [block * i + np.sort(sampler.choice(block, keep, replace=False))
             for i in range(n_blocks)]
    idx = np.concatenate(picks) if picks else np.empty(0, dtype=np.intp)
    return x[..., idx]