import torch
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import Encoder_Decoder as ED
from params import vocab_in, vocab_out  # path and vocab


# Shape notation used in the comments below: r = number of rows (sequences), c = number of columns (sequence length).

def read_files(file_path):
    """Read a UTF-8 text file and return its lines with surrounding whitespace stripped.

    Args:
        file_path: path of the text file to read.

    Returns:
        list[str] with one entry per line of the file (r entries for r lines);
        a blank line becomes ''.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        stripped_lines = list(map(str.strip, handle))
    return stripped_lines


def BuildDataset(path_query, path_value):
    """Load a pair of parallel source/target files, one sequence per line.

    Args:
        path_query: path of the source-side file.
        path_value: path of the target-side file.

    Returns:
        (src_sequences, tgt_sequences): two lists of stripped lines, where
        src_sequences[i] is paired with tgt_sequences[i].

    Raises:
        ValueError: if the files have different numbers of lines. Pairing is
            positional, so a length mismatch would silently misalign every
            pair after the first missing line (or fail later during indexing).
    """
    src_sequences = read_files(path_query)
    tgt_sequences = read_files(path_value)
    if len(src_sequences) != len(tgt_sequences):
        raise ValueError(
            "==> Error! The sequences are not matched!!!!! "
            f"({len(src_sequences)} source lines vs {len(tgt_sequences)} target lines)"
        )
    print("==> The sequences are matched!")
    return src_sequences, tgt_sequences


def TokenSplit(sequences):
    """Split one sequence into a list of its elements.

    For a string such as 'AGT' this yields one token per character:
    ['A', 'G', 'T'].
    """
    return list(sequences)


def BuildTokens(src_sequences, tgt_sequences):
    """Tokenize every source and target sequence with ``TokenSplit``.

    Returns:
        (src_tokens, tgt_tokens): lists of token lists, e.g.
        [['A', 'G', ...], ['G', 'T', ...], ...] for string input.
    """
    def _tokenize_all(lines):
        # One token list per input line, order preserved.
        return [TokenSplit(line) for line in lines]

    src_tokens = _tokenize_all(src_sequences)
    tgt_tokens = _tokenize_all(tgt_sequences)
    return src_tokens, tgt_tokens


def numericalize(tokens, vocab):
    """Convert token lists into one right-padded tensor of vocabulary ids.

    Each sequence becomes ``[vocab['<sos>'], ids..., vocab['<eos>']]``.
    Tokens that are not keys of ``vocab`` are silently dropped; there is no
    ``<unk>`` substitution. All rows are then padded on the right with
    ``vocab['<pad>']`` to the longest row.

    Args:
        tokens: iterable of token sequences, e.g. [['A', 'G', ...], ...].
        vocab: mapping token -> integer id; must contain '<sos>', '<eos>'
            and '<pad>'.

    Returns:
        torch.LongTensor of shape (r, c), where c is the longest sequence
        length after adding <sos>/<eos>. An empty ``tokens`` cannot be padded
        and pad_sequence will error.
    """
    sos_id = vocab['<sos>']
    eos_id = vocab['<eos>']
    rows = [
        torch.tensor([sos_id] + [vocab[word] for word in seq if word in vocab] + [eos_id],
                     dtype=torch.long)
        for seq in tokens
    ]
    return pad_sequence(rows, batch_first=True, padding_value=vocab['<pad>'])


def BuildNumerical(src_tokens, src_vocab, tgt_tokens, tgt_vocab):
    """Numericalize the source and target tokens, each with its own vocabulary.

    Returns:
        (src_numerical, tgt_numerical): the padded id tensors produced by
        ``numericalize`` for each side.
    """
    sides = ((src_tokens, src_vocab), (tgt_tokens, tgt_vocab))
    src_numerical, tgt_numerical = (numericalize(side_tokens, side_vocab)
                                    for side_tokens, side_vocab in sides)
    return src_numerical, tgt_numerical


def BuildData(path_query, path_value):
    """Run the file -> tokens -> padded-id pipeline for a pair of parallel files.

    The vocabularies are the module-level ``vocab_in`` / ``vocab_out``
    imported from ``params``; they are returned alongside the tensors.

    Returns:
        (src_numerical, tgt_numerical, vocab_in, vocab_out)
    """
    raw_pair = BuildDataset(path_query, path_value)
    src_tokens, tgt_tokens = BuildTokens(*raw_pair)
    numerical_pair = BuildNumerical(src_tokens, vocab_in, tgt_tokens, vocab_out)
    return numerical_pair[0], numerical_pair[1], vocab_in, vocab_out


class TranslationDataset(Dataset):
    """Map-style dataset over pre-numericalized source/target id tensors.

    Args:
        src_numerical: source id tensor, presumably (r, c) as produced by
            ``numericalize`` (indexing by row and ``unsqueeze(-2)`` assume a
            trailing length dimension) — TODO confirm for other callers.
        tgt: target id tensor used to build the teacher-forcing pair
            ``tgt[:, :-1]`` (decoder input) / ``tgt[:, 1:]`` (decoder output),
            or ``None`` for source-only data.
        tgt_numerical: tensor returned per item under ``'gt'``, or ``None``.
        src_vocab, tgt_vocab: vocabularies, stored unchanged on the instance.
        pad: id treated as padding when building masks and counting tokens.
            NOTE(review): defaults to 0, while ``numericalize`` pads with
            ``vocab['<pad>']`` — confirm those ids are 0.
    """

    def __init__(self, src_numerical, tgt, tgt_numerical, src_vocab, tgt_vocab, pad=0):
        super(TranslationDataset, self).__init__()
        self.src = src_numerical
        # True at non-pad source positions; the inserted dim gives (r, 1, c).
        self.src_mask = (src_numerical != pad).unsqueeze(-2)
        print("Shape of Mask in dataset:", np.shape(self.src_mask))
        self.groundTruth = tgt_numerical
        # Target-side fields default to None so a source-only dataset
        # (tgt=None) can still be indexed instead of raising AttributeError.
        self.tgt = None
        self.tgt_y = None
        self.tgt_mask = None
        self.ntokens = None
        if tgt is not None:
            self.tgt = tgt[:, :-1]   # decoder input: all but the last position
            self.tgt_y = tgt[:, 1:]  # decoder target: shifted left by one
            self.tgt_mask = self.make_std_mask(self.tgt, pad)
            # NOTE(review): non-pad count of tgt_y over the WHOLE dataset. It is
            # returned unchanged for every item, so a default-collated batch
            # holds batch_size copies of this total rather than a per-batch
            # count — confirm the training loop expects that.
            self.ntokens = (self.tgt_y != pad).sum()
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    @staticmethod
    def make_std_mask(tgt, pad):
        """Combine a padding mask with ``ED.subsequent_mask`` for decoder self-attention.

        ``tgt`` is an id tensor with sequence length on the last axis
        (presumably (r, c)). Per the original shape notes, the subsequent mask
        is (1, c, c) and the broadcast result is (r, c, c).
        """
        pad_mask = (tgt != pad).unsqueeze(-2)  # (r, 1, c)
        future_mask = ED.subsequent_mask(tgt.size(-1)).type_as(pad_mask)
        return pad_mask & future_mask

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        """Return the tensors for row ``idx`` as a dict.

        Always contains 'src' (c,) and 'src_mask' (1, c). 'tgt', 'tgt_mask',
        'tgt_y' and 'ntokens' are present only when a ``tgt`` tensor was given.
        'gt' is present only when ``tgt_numerical`` was given.
        """
        item = {
            'src': self.src[idx].type(torch.long),
            'src_mask': self.src_mask[idx],
        }
        if self.tgt is not None:
            item['tgt'] = self.tgt[idx].type(torch.long)      # [sos, 2, 3, ...]
            item['tgt_mask'] = self.tgt_mask[idx]             # (c-1, c-1)
            item['tgt_y'] = self.tgt_y[idx].type(torch.long)  # [2, 3, ..., eos]
            item['ntokens'] = self.ntokens
        if self.groundTruth is not None:
            item['gt'] = self.groundTruth[idx].type(torch.long)
        return item


def BuildDataLoader(batch_size=32, path_query=None, path_value=None):
    """Build a shuffled DataLoader over the parallel files at the given paths.

    Args:
        batch_size: number of items per batch.
        path_query: path of the source-side file.
        path_value: path of the target-side file.

    Returns:
        (dataloader, src_vocab, tgt_vocab)
    """
    src_numerical, tgt_numerical, src_vocab, tgt_vocab = BuildData(path_query, path_value)
    # The target tensor is used both to build the teacher-forcing pair (tgt)
    # and as the per-item ground truth (tgt_numerical).
    # NOTE(review): ``pad`` is left at the dataset default of 0 — confirm
    # that src_vocab['<pad>'] and tgt_vocab['<pad>'] are 0.
    dataset_kwargs = dict(
        src_numerical=src_numerical,
        tgt=tgt_numerical,
        tgt_numerical=tgt_numerical,
        src_vocab=src_vocab,
        tgt_vocab=tgt_vocab,
    )
    translation_set = TranslationDataset(**dataset_kwargs)
    loader = DataLoader(translation_set, batch_size=batch_size, shuffle=True)
    return loader, src_vocab, tgt_vocab
