import os
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
from os import listdir
from os.path import isfile, join
import torchvision.transforms as transforms


class TIAlaph(Dataset):
    """Labelled dataset of numeric text files, one sub-directory per class.

    Files are read from ``<dataset_path>/<label>/`` for every label in
    ``range(n_class)``; a sample's label is the name of its directory.
    Each item is the file's first ``n_rows`` rows (parsed with
    ``np.genfromtxt``), scaled by 10000 and cropped / zero-padded along the
    column axis to exactly ``n_steps`` columns, returned as a float32 tensor
    of shape ``(n_rows, 1, 1, n_steps)`` together with its integer label.

    Args:
        dataset_path: Root directory containing sub-directories ``0 .. n_class-1``.
        n_steps: Number of columns kept per sample.
        sample_per_class: Maximum number of files loaded per class; ``-1``
            (or any negative value) loads every file.
        n_class: Number of class sub-directories to read.
        transform: Optional callable applied to the ``(n_rows, n_steps)``
            float64 array; it must return a ``torch.Tensor`` holding
            ``n_rows * n_steps`` elements.
        n_rows: Number of leading rows taken from each file (defaults to 78,
            the value this loader always used).
    """

    def __init__(self, dataset_path, n_steps, sample_per_class, n_class, transform=None, n_rows=78):
        self.path = dataset_path
        self.samples = []
        self.labels = []
        self.transform = transform
        self.n_steps = n_steps
        self.n_rows = n_rows
        for label in range(n_class):
            sample_dir = join(dataset_path, str(label))
            # os.listdir order is arbitrary; sorting makes the subset kept by
            # the cap (and the sample order) reproducible across machines.
            files = [join(sample_dir, name) for name in sorted(listdir(sample_dir))]
            files = [path for path in files if isfile(path)]
            # The cap counts regular files only. Applying it after filtering
            # means a cap of 0 really loads nothing.
            if sample_per_class >= 0:
                files = files[:sample_per_class]
            self.samples.extend(files)
            self.labels.extend([label] * len(files))

    def __getitem__(self, index):
        """Return ``(tensor of shape (n_rows, 1, 1, n_steps), label)`` for ``index``."""
        data_path = self.samples[index]
        label = self.labels[index]
        raw = np.genfromtxt(data_path) * 10000
        # Raise ValueError, not IndexError: Dataset defines no __iter__, so a
        # plain `for x in dataset` loop would treat an IndexError from a
        # malformed file as the end of the data and stop silently.
        if raw.ndim != 2 or raw.shape[0] < self.n_rows:
            raise ValueError(
                f"{data_path}: expected a 2-D table with at least {self.n_rows} rows, "
                f"got shape {raw.shape}"
            )

        # Keep the last n_steps columns. Shorter rows are left-aligned and
        # zero-padded on the right.
        n_cols = raw.shape[1]
        start = max(n_cols - self.n_steps, 0)
        data = np.zeros((self.n_rows, self.n_steps))
        data[:, :n_cols - start] = raw[:self.n_rows, start:]

        if self.transform:
            data = self.transform(data)
            data = data.type(torch.float32).view(self.n_rows, 1, 1, self.n_steps)
        else:
            data = torch.from_numpy(data).float().view(self.n_rows, 1, 1, self.n_steps)

        return data, label

    def __len__(self):
        """Number of loaded sample files."""
        return len(self.samples)


def get_tialpha(data_path, network_config):
    """Build the TI-Alpha train / validation / test datasets.

    Loads ``<data_path>/train`` and ``<data_path>/test`` with ``TIAlaph``
    (using a ``ToTensor`` transform) and randomly splits the full training
    set into two halves, reported below as the "arch" training and
    validation sets.

    Args:
        data_path: Root directory holding ``train/`` and ``test/``. It is
            created (empty) if it does not exist.
        network_config: Mapping providing ``'n_steps'``, ``'n_class'`` and
            ``'sample_per_class'``.

    Returns:
        ``(all_train_set, train_set, val_set, test_set)``.
    """
    print("loading TI46")
    if not os.path.exists(data_path):
        os.mkdir(data_path)
    n_steps = network_config['n_steps']
    n_class = network_config['n_class']
    sample_per_class = network_config['sample_per_class']
    train_path = os.path.join(data_path, 'train')
    test_path = os.path.join(data_path, 'test')
    transform = transforms.Compose([transforms.ToTensor()])
    all_train_set = TIAlaph(train_path, n_steps, sample_per_class, n_class, transform)
    # random_split needs lengths that sum exactly to len(dataset). Two copies
    # of int(n/2) fall one short (and raise) whenever n is odd.
    n_train = len(all_train_set)
    half = n_train // 2
    train_set, val_set = torch.utils.data.random_split(all_train_set, [half, n_train - half])
    test_set = TIAlaph(test_path, n_steps, sample_per_class, n_class, transform)
    print(f"training samples loaded: {len(all_train_set)}")
    print(f"testing samples loaded: {len(test_set)}")
    print(f"arch training samples loaded: {len(train_set)}")
    print(f"arch validating samples loaded: {len(val_set)}")
    return all_train_set, train_set, val_set, test_set

