import ndjson
from nltk.tokenize import word_tokenize
import numpy as np
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader


class ProteinDataset(Dataset):
    """Map-style dataset of protein sequences paired with tokenized function text.

    Records are read from an ndjson file and indexed by their ``'id'`` field;
    only the records named in ``ids_path`` are kept. Each kept record's
    ``'sequence'`` (encoded per character) and ``'function'`` (encoded per
    ``word_tokenize`` word) fields are replaced in place by lists of
    vocabulary indices of fixed length.

    Both vocabulary files must contain a ``'[PAD]'`` line, because padding
    tokens are looked up in the vocabulary like any other token.

    Args:
        ids_path: text file with one record id per line; blank lines are ignored.
        dataset_path: ndjson file; records are read for the keys ``'id'``,
            ``'sequence'``, ``'function'``, ``'length'`` and ``'func_len'``.
        seq_vocab_path: sequence vocabulary file, one token per line.
        func_vocab_path: function vocabulary file, one token per line.
        max_seq_len: encoded sequence length (truncate, or right-pad with ``'[PAD]'``).
        max_func_len: maximum number of function words kept before the
            ``'[START]'`` / ``'[END]'`` markers are added.
        is_cross: if true, function tokens become ``[START] ... [END]``
            (length ``max_func_len + 2``); otherwise only ``[END]`` is appended
            (length ``max_func_len + 1``).
    """

    def __init__(self,
                 ids_path,
                 dataset_path,
                 seq_vocab_path,
                 func_vocab_path,
                 max_seq_len,
                 max_func_len,
                 is_cross):
        self.ids = self._read_ids(ids_path)
        with open(dataset_path, 'r', encoding='utf-8') as f:
            records = ndjson.load(f)

        self.max_seq_len = max_seq_len
        self.max_func_len = max_func_len

        self.seq_encoder = self.word2index(seq_vocab_path)
        self.func_encoder = self.word2index(func_vocab_path)

        # Keep (and encode) only the records this split refers to. The same
        # ndjson file backs every split, so encoding all of it here would
        # repeat the whole corpus' work and memory once per split. If an id
        # occurs more than once in the file, the last record wins. Ids missing
        # from the file are skipped here and, as before, raise KeyError from
        # __getitem__ when accessed.
        wanted = set(self.ids)
        self.dataset = {rec['id']: rec for rec in records if rec['id'] in wanted}

        # Encode in place: the stored record dicts now hold index lists.
        for item in self.dataset.values():
            item['sequence'] = self.encode_seq(item['sequence'])
            item['function'] = self.encode_func(item['function'], is_cross=is_cross)

    @staticmethod
    def _read_ids(ids_path):
        """Return the stripped, non-empty lines of ``ids_path`` in file order."""
        with open(ids_path, 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    def __len__(self):
        """Number of ids in this split (a repeated id counts each time)."""
        return len(self.ids)

    def __getitem__(self, index):
        """Return ``(id, length, sequence, function, func_len)`` for ``index``.

        ``sequence`` and ``function`` are fresh numpy arrays built from the
        encoded index lists. ``length`` and ``func_len`` are passed through
        unchanged from the ndjson record.
        """
        pid = self.ids[index]
        item = self.dataset[pid]
        return (pid,
                item['length'],
                np.array(item['sequence']),
                np.array(item['function']),
                item['func_len'])

    def word2index(self, vocab_path):
        """Build a token -> index mapping from a one-token-per-line vocab file.

        Stripped line ``i`` maps to index ``i``; a repeated token keeps the
        index of its last occurrence. Two markers are then (re)assigned:
        ``'[END]'`` -> ``len(vocab)`` and ``'[START]'`` -> ``len(vocab) + 1``,
        where ``len(vocab)`` counts every line of the file, blank ones included.
        """
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab = [line.strip() for line in f]
        vocab_dict = {word: i for i, word in enumerate(vocab)}

        # Index layout (END before START) is kept as-is; trained checkpoints
        # presumably depend on these exact ids.
        vocab_dict['[START]'] = len(vocab) + 1
        vocab_dict['[END]'] = len(vocab)
        return vocab_dict

    def encode_seq(self, seq):
        """Encode ``seq`` character by character into sequence-vocab indices.

        The characters are truncated/padded to ``max_seq_len`` by ``align_seq``.

        Raises:
            KeyError: if a character (or ``'[PAD]'``) is not in the sequence vocabulary.
        """
        tokens = self.align_seq(list(seq))
        try:
            return [self.seq_encoder[tok] for tok in tokens]
        except KeyError as err:
            raise KeyError(
                f"sequence token {err.args[0]!r} is not in the sequence vocabulary"
            ) from None

    def align_seq(self, seq):
        """Truncate ``seq`` to ``max_seq_len`` tokens, right-padding with ``'[PAD]'``."""
        seq = seq[:self.max_seq_len]
        return seq + ['[PAD]'] * (self.max_seq_len - len(seq))

    def encode_func(self, func, is_cross=False):
        """Word-tokenize ``func`` and encode it into function-vocab indices.

        Raises:
            KeyError: if a word or marker token is not in the function vocabulary.
        """
        tokens = self.align_func(word_tokenize(func), is_cross)
        try:
            return [self.func_encoder[tok] for tok in tokens]
        except KeyError as err:
            raise KeyError(
                f"function token {err.args[0]!r} is not in the function vocabulary"
            ) from None

    def align_func(self, func, is_cross=False):
        """Truncate ``func`` to ``max_func_len`` words, add markers, then pad.

        With ``is_cross`` the result is ``[START] + words + [END]`` padded to
        ``max_func_len + 2``; otherwise ``words + [END]`` padded to
        ``max_func_len + 1``. Padding uses ``'[PAD]'``.
        """
        func = func[:self.max_func_len]
        if is_cross:
            func = ['[START]'] + func + ['[END]']
            max_len = self.max_func_len + 2
        else:
            func = func + ['[END]']
            max_len = self.max_func_len + 1
        return func + ['[PAD]'] * (max_len - len(func))
    

def cycle(iterable):
    """Yield the items of ``iterable`` endlessly, starting a new pass each time.

    Unlike ``itertools.cycle`` nothing is cached: every pass iterates
    ``iterable`` again, so an iterable whose passes differ is re-run rather
    than replayed. An empty iterable, or a one-shot iterator once exhausted,
    makes this loop spin forever without yielding.
    """
    while True:
        yield from iterable
    

def get_dataloader(args):
    """Build distributed train / validation / translation DataLoaders.

    All three splits share ``args.data_path``, both vocabularies, the length
    limits and ``is_cross=True``; they differ only in their ids file. Each
    split gets a ``DistributedSampler``, so ``torch.distributed`` must be
    initialised before this is called.

    Only the training sampler shuffles. Call
    ``train_loader.sampler.set_epoch(epoch)`` every epoch to get a different
    shuffle per epoch. The validation and translation samplers keep dataset
    order and use batch size 1. ``DistributedSampler`` still pads each split
    with repeated samples so that it divides evenly across ranks.

    Returns:
        ``(train_loader, val_loader, translate_loader)``.
    """
    def build(ids_path):
        # Every split is configured identically apart from its ids file.
        return ProteinDataset(ids_path=ids_path,
                              dataset_path=args.data_path,
                              seq_vocab_path=args.seq_vocab_path,
                              func_vocab_path=args.func_vocab_path,
                              max_seq_len=args.seq_len_max,
                              max_func_len=args.func_len_max,
                              is_cross=True)

    train_dataset = build(args.train_ids_path)
    val_dataset = build(args.val_ids_path)
    translate_dataset = build(args.test_ids_path)
    print(f"local rank {args.local_rank} successfully build train dataset")

    train_sampler = DistributedSampler(train_dataset)
    # DistributedSampler shuffles by default, and DataLoader(shuffle=False)
    # does not override a supplied sampler. Evaluation splits must therefore
    # disable shuffling on the sampler itself to stay in dataset order.
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    translate_sampler = DistributedSampler(translate_dataset, shuffle=False)

    # DataLoader's own shuffle must be False whenever a sampler is given;
    # the samplers above decide the order.
    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              shuffle=False,
                              num_workers=args.num_workers,
                              sampler=train_sampler)
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=args.num_workers,
                            sampler=val_sampler)
    translate_loader = DataLoader(translate_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  sampler=translate_sampler)
    return train_loader, val_loader, translate_loader

