import torch
import random, gc
from torch.utils.data import  Dataset
from datetime import datetime, timedelta
from typing import List
from model.mydataclass import batch_train_data
import copy
import os.path as path

class MyDataSet(Dataset):
    """Placeholder ``torch.utils.data.Dataset`` subclass.

    It adds no state and defines no items of its own; construction simply
    defers to the parent initializer.
    """

    def __init__(self) -> None:
        # Nothing to set up locally; hand off to the base class.
        super(MyDataSet, self).__init__()

class MyDataLoader(object):
    """Serve pre-collated batches that were serialized into a single torch file.

    The whole file is read once with ``torch.load``; its content is annotated
    as ``List[batch_train_data]``. Optionally, a sibling ``<stem>.smiles`` text
    file (one entry per line, presumably SMILES strings) is read and split
    into one list per batch.

    * Iterating (``for b in loader``) yields deep copies of every batch, in
      shuffled order when ``shuffle`` is true. When ``load_smiles`` is true,
      each item is a ``(batch, smiles_list)`` pair.
    * ``next(loader)`` walks the batches cyclically in stored order using the
      ``pt`` cursor, which iteration also advances.
    """

    def __init__(self, data_file, batch_size: int=128, batch_num: int=-1, shuffle: bool=True, load_smiles: bool=False):
        """Load the batches (and optionally their SMILES groups) into memory.

        Args:
            data_file: Path to the torch-serialized list of batches.
            batch_size: Used only to decide how many SMILES groups to build
                (``ceil(n_lines / batch_size)``). The loaded batches are not
                re-batched.
            batch_num: If > 0, keep only the first ``batch_num`` batches
                (and SMILES groups).
            shuffle: Shuffle batch order on every iteration pass.
            load_smiles: Also read ``<data_file stem>.smiles`` and pair each
                batch with its group of lines.

        Raises:
            AssertionError: If the number of SMILES groups does not match the
                number of loaded batches.
        """
        self.data_file = data_file
        self.batch_num = batch_num
        self.shuffle = shuffle
        self.load_smiles = load_smiles
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted files. Recent torch releases default to weights_only=True,
        # which may reject custom batch classes; confirm for your torch version.
        with open(data_file, 'rb') as f:
            self.batches: List[batch_train_data] = torch.load(f)
        if load_smiles:
            smiles_file = path.splitext(data_file)[0] + '.smiles'
            with open(smiles_file) as f:
                smiles = [smi.strip("\r\n") for smi in f]
            num = len(range(0, len(smiles), batch_size))
            # Lines are dealt round-robin (line i -> group i % num), not in
            # contiguous chunks. Presumably this mirrors how the batches file
            # was built; confirm against the preprocessing script.
            self.batch_smiles = [smiles[k::num] for k in range(num)]

            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; AssertionError is kept for compatibility.
            if len(self.batches) != len(self.batch_smiles):
                raise AssertionError(
                    f"{len(self.batches)} batches but {len(self.batch_smiles)} "
                    f"SMILES groups (from {smiles_file})"
                )
        if self.batch_num > 0:
            self.batches = self.batches[:self.batch_num]
            if load_smiles:
                self.batch_smiles = self.batch_smiles[:self.batch_num]
        # Cursor used by __next__; also advanced by __iter__.
        self.pt = 0

    def __len__(self):
        """Return the number of batches held in memory."""
        return len(self.batches)

    def __next__(self):
        """Return a deep copy of the batch at the cursor, then advance it.

        The cursor wraps around, so this never raises ``StopIteration``. It
        does fail with ``ZeroDivisionError`` when no batches are loaded.
        """
        self.pt = (self.pt + 1) % len(self.batches)
        if self.load_smiles:
            return copy.deepcopy(self.batches[self.pt - 1]), self.batch_smiles[self.pt - 1]
        else:
            return copy.deepcopy(self.batches[self.pt - 1])

    def __iter__(self):
        """Yield one pass over all batches, shuffled if ``self.shuffle``.

        Sets ``self.idx`` (the visiting order) and ``self.batches_`` /
        ``self.batch_smiles_`` (the reordered views) for the pass. Batches are
        deep-copied, presumably so callers may mutate them without corrupting
        the cached data.
        """
        if self.shuffle:
            idx = list(range(len(self.batches)))
            random.shuffle(idx)
            self.idx = idx
            self.batches_ = [self.batches[i] for i in idx]
            if self.load_smiles:
                self.batch_smiles_ = [self.batch_smiles[i] for i in idx]
        else:
            self.idx = list(range(len(self.batches)))
            self.batches_ = self.batches
            # Only touch batch_smiles when it exists; it is never created
            # when load_smiles is False.
            if self.load_smiles:
                self.batch_smiles_ = self.batch_smiles

        n_batches = len(self.batches)
        if self.load_smiles:
            for batch, smis in zip(self.batches_, self.batch_smiles_):
                self.pt = (self.pt + 1) % n_batches
                yield copy.deepcopy(batch), smis
        else:
            for batch in self.batches_:
                self.pt = (self.pt + 1) % n_batches
                yield copy.deepcopy(batch)
