import torch
import os
import random
import numpy as np


class TomitaDatasetPartialClassification(torch.utils.data.Dataset):
    """Dataset pairing Tomita-grammar source strings with target tensors.

    Four ``.pt`` files are read with ``torch.load``: a ``src`` and a
    ``partial_target`` file for the ``pos`` split and the same pair for the
    ``neg`` split.  Each file must contain a dict with ``"strings"`` and
    ``"lengths"`` tensors.  The ``pos`` rows are concatenated ahead of the
    ``neg`` rows, and a random access order is drawn once at construction
    time with the ``random`` module (seed it for reproducibility).

    Args:
        grammar: Grammar identifier, substituted into ``Tomita-{grammar}``.
        suffix: Prefix of the ``{suffix}_processed`` data directory.
        ttype: Split name; the value ``"test"`` selects the per-bin files.
        nbin: Bin number, only used in the file name when ``ttype`` is
            ``"test"``.

    Raises:
        ValueError: If the source and target length tensors differ in shape
            or in any element.
    """

    def __init__(self, grammar, suffix, ttype, nbin = 0):
        super().__init__()
        pos_src = torch.load(self.get_fpath(suffix, grammar, ttype, nbin, "src", pos=True))
        pos_target = torch.load(self.get_fpath(suffix, grammar, ttype, nbin, "partial_target", pos=True))
        neg_src = torch.load(self.get_fpath(suffix, grammar, ttype, nbin, "src", pos=False))
        neg_target = torch.load(self.get_fpath(suffix, grammar, ttype, nbin, "partial_target", pos=False))

        self.data = torch.cat([pos_src["strings"], neg_src["strings"]])
        self.lengths = torch.cat([pos_src["lengths"], neg_src["lengths"]])
        self.target = torch.cat([pos_target["strings"], neg_target["strings"]])
        tgt_lengths = torch.cat([pos_target["lengths"], neg_target["lengths"]])

        # Compare shapes first: an element-wise ``==`` alone would broadcast
        # e.g. (N,) against (1,) and could hide a real mismatch.  An explicit
        # raise (rather than ``assert``) keeps the check alive under ``-O``.
        if self.lengths.shape != tgt_lengths.shape:
            raise ValueError(
                "source/target length tensors differ in shape: {} vs {}".format(
                    tuple(self.lengths.shape), tuple(tgt_lengths.shape)
                )
            )
        if not bool((self.lengths == tgt_lengths).all()):
            raise ValueError("source and target lengths do not match element-wise")

        # Fixed random permutation used by __getitem__.
        self.indices = list(range(self.data.shape[0]))
        random.shuffle(self.indices)

    def get_fpath(self, suffix, grammar, ttype, nbin, st, pos = True):
        """Return the relative path of one data file.

        Args:
            suffix: Prefix of the ``{suffix}_processed`` directory.
            grammar: Grammar identifier used in ``Tomita-{grammar}``.
            ttype: Split name; ``"test"`` appends ``_bin{nbin}`` to the stem.
            nbin: Bin number (ignored unless ``ttype == "test"``).
            st: File tag, e.g. ``"src"`` or ``"partial_target"``.
            pos: When false, ``_neg`` is inserted after the split name.

        Returns:
            The path string, rooted at ``../data``.
        """
        polarity = "" if pos else "_neg"
        bin_tag = "_bin{}".format(nbin) if ttype == "test" else ""
        return "../data/{0}_processed/Tomita-{1}/{2}{3}_{4}{5}.pt".format(
            suffix, grammar, ttype, polarity, st, bin_tag
        )

    def __len__(self):
        """Return the number of examples."""
        return len(self.indices)

    def __getitem__(self, sidx):
        """Return ``(data, target, length)`` for shuffled position ``sidx``."""
        idx = self.indices[sidx]
        return self.data[idx], self.target[idx], self.lengths[idx]


class DyckDataLoader(torch.utils.data.Dataset):
    """Binary-labelled dataset built from two NumPy files.

    Each file is read with ``np.load`` and must expose the entries ``"inp"``
    and ``"inp_len"``.  Rows from ``pos_file`` receive label 1.0 and rows
    from ``neg_file`` label 0.0; positive rows are stored first.  Access
    order is a permutation drawn once at construction with ``random.shuffle``.

    Args:
        pos_file: Path (or file object) of the file whose rows are labelled 1.
        neg_file: Path (or file object) of the file whose rows are labelled 0.
    """

    def __init__(self, pos_file, neg_file):
        super().__init__()

        inputs, lengths, labels = [], [], []
        # (file, label factory) pairs, positives first so they come first
        # in the concatenated tensors.
        for source, make_labels in ((pos_file, torch.ones), (neg_file, torch.zeros)):
            loaded = np.load(source)
            # NOTE: only "inp" and "inp_len" are read; reading of a "stack"
            # entry was disabled in an earlier revision of this class.
            inp = torch.from_numpy(loaded["inp"])
            inputs.append(inp)
            lengths.append(torch.from_numpy(loaded["inp_len"]))
            labels.append(make_labels(inp.shape[0]))

        self.inp = torch.cat(inputs, dim=0)
        self.inp_len = torch.cat(lengths, dim=0)
        self.y = torch.cat(labels, dim=0)

        # Fixed random permutation used by __getitem__.
        order = list(range(len(self.y)))
        random.shuffle(order)
        self.indices = order

    def __len__(self):
        """Return the total number of examples."""
        return len(self.indices)

    def __getitem__(self, sidx):
        """Return ``(inp, label, inp_len)`` for shuffled position ``sidx``."""
        row = self.indices[sidx]
        return self.inp[row], self.y[row], self.inp_len[row]
