import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import pandas as pd

def UCR_load(data_name):
    """Load a UCR-2018 dataset, pool its two files and re-split it 80/20.

    The ``_TRAIN.tsv`` and ``_TEST.tsv`` files of the dataset are read from
    ``dataset/UCRArchive_2018/<data_name>/`` (relative to the current working
    directory), concatenated, and split again with a stratified, shuffled
    80/20 split. No ``random_state`` is passed, so the split follows NumPy's
    global random state.

    Args:
        data_name: Name of the dataset directory / file prefix.

    Returns:
        Tuple ``(train_x, train_y, test_x, test_y)`` of NumPy arrays.
        ``*_x`` is float32 with shape ``(n_samples, n_columns - 1, 1)``;
        ``*_y`` is float64 holding the class index ``0 .. num_classes - 1``
        as exact whole numbers.
    """
    prefix = 'dataset/UCRArchive_2018/{0}/{0}'.format(data_name)

    train = pd.read_csv(prefix + '_TRAIN.tsv', sep="\t", header=None)
    test = pd.read_csv(prefix + '_TEST.tsv', sep="\t", header=None)

    values = pd.concat((train, test)).values

    # The first column is the class id, which may start at any integer and
    # need not be contiguous. Map each distinct id to its rank among the
    # sorted unique ids. A min/max rescale is not used here because it yields
    # fractional classes for unevenly spaced ids, suffers rounding (e.g.
    # 1 / 49 * 49 == 0.9999999999999999, which truncates to class 0 on an
    # integer cast) and produces NaN when only one class is present.
    raw_labels = values[:, 0].astype(np.int32)
    _, class_index = np.unique(raw_labels, return_inverse=True)
    # float64 is kept so callers see the same label dtype as before.
    y = class_index.reshape(-1).astype(np.float64)

    # The remaining columns are the series values; a slice of a 2-D table is
    # always 2-D, so a trailing axis of size 1 is appended unconditionally.
    x = np.expand_dims(values[:, 1:].astype(np.float32), axis=2)

    train_x, test_x, train_y, test_y = train_test_split(
        x, y, test_size=0.2, shuffle=True, stratify=y)

    return train_x, train_y, test_x, test_y

def semi_setting(train_x, train_y, test_x, test_y, normalization=False, label_ratio=0.1):
    """Derive labeled / unlabeled / validation / test tensors from a train/test pair.

    The training pool is first split 75/25 into a fitting part and a
    validation part, then the fitting part is split again so that a
    ``label_ratio`` share becomes the labeled subset and the rest becomes the
    unlabeled subset. Both splits are shuffled and stratified on the labels;
    no ``random_state`` is passed, so they follow NumPy's global random state.

    Args:
        train_x: Training inputs (expected to be a NumPy array, since the
            result is passed to ``torch.from_numpy``).
        train_y: Training labels aligned with ``train_x``.
        test_x: Test inputs; only normalized (optionally) and converted.
        test_y: Test labels; only converted.
        normalization: When true, z-score every input split with a single
            scalar mean and standard deviation computed over the fitting part
            (labeled and unlabeled together, validation excluded).
        label_ratio: Forwarded as ``test_size`` of the second split; the
            "test" side of that split is the labeled subset.

    Returns:
        Tuple of eight tensors in the order ``(train_labeled_x,
        train_labeled_y, train_unlabeled_x, train_unlabeled_y, val_x, val_y,
        test_x, test_y)``.
    """
    # Step 1: carve the validation set out of the training pool.
    fit_x, val_x, fit_y, val_y = train_test_split(
        train_x, train_y, test_size=0.25, shuffle=True, stratify=train_y)

    # Step 2: the smaller ("test") side of this split keeps its labels.
    unlabeled_x, labeled_x, unlabeled_y, labeled_y = train_test_split(
        fit_x, fit_y, test_size=label_ratio, shuffle=True, stratify=fit_y)

    if normalization:
        # z-score normalization with global (scalar) statistics
        mean = fit_x.mean()
        std = fit_x.std()
        labeled_x, unlabeled_x, val_x, test_x = (
            (part - mean) / std
            for part in (labeled_x, unlabeled_x, val_x, test_x)
        )

    splits = (
        labeled_x, labeled_y,
        unlabeled_x, unlabeled_y,
        val_x, val_y,
        test_x, test_y,
    )
    return tuple(torch.from_numpy(arr) for arr in splits)

# The base class is spelled out in full because this class reuses (and, once
# defined, shadows) the imported name ``Dataset``.
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over paired inputs ``x`` and targets ``y``.

    Each item is a triple ``(x[idx], y[idx], idx)`` with every element
    converted to a tensor, so consumers can tell which sample they received.
    """

    def __init__(self, x, y, train=True):
        """Store the inputs and targets; ``train`` is accepted but not used."""
        self.x = x
        self.y = y

    def __len__(self):
        """Return the number of samples, i.e. the size of ``x``'s first axis."""
        return self.x.shape[0]

    def __getitem__(self, idx):
        """Return ``(sample, target, index)`` as tensors for position ``idx``."""
        # np.array() copies, so the returned tensors do not alias self.x / self.y.
        sample = torch.from_numpy(np.array(self.x[idx]))
        target = torch.from_numpy(np.array(self.y[idx]))
        index = torch.from_numpy(np.array(idx))
        return sample, target, index