import csv
import math
import os
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
from os import listdir
from os.path import isfile, join


class N_Tidigits(Dataset):
    """N-TIDIGITS dataset stored as one text file per sample.

    ``__init__`` expects the layout::

        dataset_path/<label>/<sample file>

    with integer labels ``0`` .. ``10`` (one sub-directory per label).
    ``__getitem__`` parses a sample file line by line: line ``k`` holds the
    whitespace-separated integers for channel ``k``. These integers are binned
    as time stamps into a ``(64, n_steps)`` binary raster.

    Args:
        dataset_path: Root directory containing one sub-directory per label.
        n_steps: Number of time bins in each returned sample.
        transform: Optional callable applied to the ``(64, n_steps)`` NumPy
            array. Its result must support ``.type(torch.float32)`` and
            ``.view(...)``.
    """

    def __init__(self, dataset_path, n_steps, transform=None):
        self.path = dataset_path
        self.samples = []
        self.labels = []
        self.transform = transform
        self.n_steps = n_steps
        for label in tqdm(range(11)):
            sample_dir = os.path.join(dataset_path, str(label))
            # os.listdir order is arbitrary and filesystem-dependent. Sort it so
            # sample indices (and thus any seeded random_split) are reproducible.
            for name in sorted(os.listdir(sample_dir)):
                filename = os.path.join(sample_dir, name)
                if os.path.isfile(filename):
                    self.samples.append(filename)
                    self.labels.append(label)

    def __getitem__(self, index):
        """Load sample ``index`` as a binary spike raster.

        Returns:
            ``(data, label)``. ``data`` is a float32 tensor of shape
            ``(64, 1, 1, n_steps)`` holding 1 at every (channel, bin) that
            received a time stamp and 0 elsewhere. ``label`` is the integer
            class label.

        Raises:
            IndexError: if the file has more than 64 lines (channels).
            ValueError: if a token on a line is not an integer.
        """
        filename = self.samples[index]
        label = self.labels[index]

        data = np.zeros((64, self.n_steps))

        # Context manager so the file handle is released even if parsing fails.
        with open(filename, 'r') as f:
            for channel, line in enumerate(f):
                times = [int(tok) for tok in line.split()]
                for timestamp in times:
                    # Each bin covers two raw time-stamp units; integer floor
                    # division keeps this exact without a float round-trip.
                    t = timestamp // 2
                    # Stops at the first out-of-range value, which assumes the
                    # values on a line are in ascending order.
                    if t >= self.n_steps:
                        break
                    data[channel, t] = 1

        if self.transform:
            data = self.transform(data)
            data = data.type(torch.float32)
        else:
            data = torch.as_tensor(data, dtype=torch.float32)
        data = data.view(64, 1, 1, self.n_steps)

        return data, label

    def __len__(self):
        """Return the number of sample files discovered by ``__init__``."""
        return len(self.samples)


def get_n_tidigits(data_path, network_config):
    """Build the N-TIDIGITS train/validation/test datasets.

    Args:
        data_path: Directory expected to contain ``Train`` and ``Test``
            sub-directories in the layout ``N_Tidigits`` reads. The directory
            is created if missing. A freshly created directory has no samples,
            so ``N_Tidigits`` will then fail to list ``Train/0``.
        network_config: Mapping that provides at least ``'n_steps'``.

    Returns:
        ``(all_train_set, train_set, val_set, test_set)``. ``train_set`` and
        ``val_set`` are a random half/half split of ``all_train_set``; when
        its size is odd, ``val_set`` gets the extra sample.
    """
    n_steps = network_config['n_steps']
    print("loading N-Tidigits")
    os.makedirs(data_path, exist_ok=True)
    train_path = os.path.join(data_path, 'Train')
    test_path = os.path.join(data_path, 'Test')
    all_train_set = N_Tidigits(train_path, n_steps)
    # random_split requires the lengths to sum exactly to the dataset size.
    # n // 2 and n - n // 2 always do. (n-1)/2 and (n+1)/2 truncated to int
    # sum to n-1 for even n, which makes random_split raise.
    n_total = len(all_train_set)
    n_arch_train = n_total // 2
    n_arch_val = n_total - n_arch_train
    train_set, val_set = torch.utils.data.random_split(
        all_train_set, [n_arch_train, n_arch_val]
    )
    test_set = N_Tidigits(test_path, n_steps)
    print("training samples loaded: %d" % len(all_train_set))
    print("testing samples loaded: %d" % len(test_set))
    print("arch training samples loaded: %d" % len(train_set))
    print("arch validating samples loaded: %d" % len(val_set))
    return all_train_set, train_set, val_set, test_set
