import pickle
import random
import torch

class LongRangeDataset(object):
    """Character-level dataset loaded from a pickle file and served in batches.

    The pickle is expected to contain a dict with the keys ``'train'``,
    ``'test'`` and ``'val'`` (each a sequence of example dicts carrying at
    least ``'input'`` and ``'correct_output'``), plus ``'params'``.  The
    ``'withheld'`` key is read only by :meth:`is_in_train_distribution`, and
    only when ``params['num_withheld'] > 0``.

    The vocabulary is fixed: the digits ``0``-``2``, ``A``-``Z``, ``a``-``z``,
    followed by the special tokens ``<bos>`` and ``<pad>``.
    """

    # Splits that are truncated, tensorised and tracked with a batch cursor.
    _SPLITS = ('train', 'test', 'val')

    def __init__(self, file_name, data_percent=1.0):
        """Load the pickle at ``file_name`` and tensorise every split.

        Args:
            file_name: Path of the pickle file to read.
            data_percent: Fraction of each split to keep; the first
                ``int(data_percent * len(split))`` examples are retained.
        """
        self.data_percent = data_percent
        self.all_chars = (
            [str(i) for i in range(3)]
            + [chr(a) for a in range(65, 65 + 26)]
            + [chr(a) for a in range(97, 97 + 26)]
            + ["<bos>", "<pad>"])
        self.vocab = {ch: i for i, ch in enumerate(self.all_chars)}
        self.vocab_size = len(self.vocab)
        self.reverse_vocab = {v: k for k, v in self.vocab.items()}
        self.pad_idx = self.vocab['<pad>']

        # SECURITY: unpickling executes arbitrary code embedded in the file;
        # only load dataset files that come from a trusted source.
        with open(file_name, 'rb') as handle:
            self.all_data = pickle.load(handle)

        for split in self._SPLITS:
            examples = self.all_data[split]
            self.all_data[split] = examples[
                :int(data_percent * len(examples))]

        # NOTE(review): only 'train' and 'test' contribute here; 'val' is not
        # considered. Kept as in the original -- confirm this is intended.
        self.max_out_len = max(
            [len(v['correct_output']) for v in self.all_data['train']]
            + [len(v['correct_output']) for v in self.all_data['test']])
        self.params = self.all_data['params']
        self.tensor_data = {}
        self.create_tensors()
        # Per-split cursor: index of the next example get_batch() will serve.
        self.i = {'train': 0, 'test': 0, 'val': 0}

    def create_tensors(self):
        """Convert every example of every split into index tensors.

        Fills ``self.tensor_data[split]`` with dicts holding:
            ``'idx'``: position of the example inside ``self.all_data[split]``;
            ``'input'``: ``(1, len(input) + 1)`` LongTensor, ``<bos>`` first;
            ``'completion'``: ``(1, len(correct_output))`` LongTensor.

        Raises:
            KeyError: If an example contains a character outside the vocab.
        """
        bos_idx = self.vocab["<bos>"]
        for split in self._SPLITS:
            self.tensor_data[split] = [
                {
                    'idx': idx,
                    'input': torch.LongTensor(
                        [bos_idx] + [self.vocab[c] for c in example['input']]
                    ).unsqueeze(0),
                    'completion': torch.LongTensor(
                        [self.vocab[c] for c in example['correct_output']]
                    ).unsqueeze(0),
                }
                for idx, example in enumerate(self.all_data[split])
            ]

    def _pad_to(self, tensor, length):
        """Right-pad a ``(1, n)`` tensor with ``pad_idx`` to ``length`` columns."""
        missing = length - tensor.shape[1]
        if missing <= 0:
            return tensor
        filler = torch.full((1, missing), self.pad_idx, dtype=tensor.dtype)
        return torch.cat([tensor, filler], dim=1)

    def get_batch(self, batch_size, split='train', separate_in_out=False,
                  with_idx=False):
        """Return the next batch of ``split`` and advance its cursor.

        The final batch of a pass may hold fewer than ``batch_size`` rows.
        Once the cursor reaches the end of the split it is reset to 0 and the
        split is shuffled in place for the next pass.

        Args:
            batch_size: Maximum number of examples in the batch.
            split: One of ``'train'``, ``'test'`` or ``'val'``.
            separate_in_out: If False, input and completion are concatenated
                per example and right-padded to a common length.  If True,
                inputs and completions are returned as two tensors, each
                right-padded with ``pad_idx`` to its own batch maximum.
            with_idx: If True, additionally return a LongTensor with the
                ``'idx'`` of every example in the batch as the last element.

        Returns:
            ``batch`` or ``(batch, idxs)`` when ``separate_in_out`` is False;
            ``(inputs, completions)`` or ``(inputs, completions, idxs)``
            otherwise.

        Raises:
            ValueError: If the batch would be empty (empty split or
                non-positive ``batch_size``).
        """
        data = self.tensor_data[split]
        start = self.i[split]
        # Clamp so the last batch of a pass stops at the end of the split.
        end = min(start + batch_size, len(data))
        current_batch = data[start:end]
        if not current_batch:
            # Fail before touching the cursor, with a clearer message than the
            # bare max() error an empty batch would otherwise trigger.
            raise ValueError(
                "cannot build an empty batch (split={!r}, batch_size={!r})"
                .format(split, batch_size))

        self.i[split] = end
        if end >= len(data):
            # Pass finished: restart and reshuffle. current_batch is a slice
            # copy, so it is unaffected by the in-place shuffle.
            self.i[split] = 0
            random.shuffle(data)

        if not separate_in_out:
            joined = [torch.cat([t['input'], t['completion']], dim=1)
                      for t in current_batch]
            max_len = max(x.shape[1] for x in joined)
            result = (
                torch.cat([self._pad_to(x, max_len) for x in joined], dim=0),)
        else:
            # Pad inputs and completions independently so that examples of
            # different lengths can be stacked along the batch dimension.
            max_in_len = max(t['input'].shape[1] for t in current_batch)
            max_out_len = max(
                t['completion'].shape[1] for t in current_batch)
            result = (
                torch.cat([self._pad_to(t['input'], max_in_len)
                           for t in current_batch], dim=0),
                torch.cat([self._pad_to(t['completion'], max_out_len)
                           for t in current_batch], dim=0),
            )

        if with_idx:
            result += (torch.LongTensor([t['idx'] for t in current_batch]),)
        return result[0] if len(result) == 1 else result

    def num_batches(self, batch_size, split):
        """Return how many ``get_batch`` calls cover ``split`` once.

        This is ``len(split) / batch_size`` rounded up, so a trailing partial
        batch counts as one batch.
        """
        full, remainder = divmod(len(self.tensor_data[split]), batch_size)
        return full + 1 if remainder else full

    def is_in_train_distribution(self, metadata):
        """Tell whether an example lies within the training-distribution limits.

        Args:
            metadata: Mapping with ``'total_length'``, ``'seq_length'`` and
                ``'dep_length_end'``; ``'input'`` and ``'correct_output'`` are
                read as well when ``params['num_withheld'] > 0``.

        Returns:
            False if any length exceeds its ``max_train_*`` limit in
            ``self.params``, or if any withheld item occurs in the input or
            the correct output; True otherwise.
        """
        if metadata['total_length'] > self.params['max_train_length']:
            return False
        if metadata['seq_length'] > self.params['max_train_answer_length']:
            return False
        if metadata['dep_length_end'] > self.params['max_train_dep_length']:
            return False
        if self.params['num_withheld'] > 0:
            return not any(
                c in metadata['input'] or c in metadata['correct_output']
                for c in self.all_data['withheld'])
        return True