import torch
import pandas as pd
from brl.utils import *


class Dataset:
    def __init__(
        self,
        src_tok,
        trg_tok,
        src_train_fname,
        trg_train_fname,
        src_valid_fname,
        trg_valid_fname,
        src_test_fname,
        trg_test_fname,
        max_len=1024,
    ):
        """Remember the tokenizers and the paths of the parallel id files.

        Nothing is read from disk here; call ``setup`` to load the splits.

        Args:
            src_tok: source-side tokenizer. Other methods of this class read
                its ``lang`` attribute (used as a DataFrame column name) and
                its ``bos_id`` / ``eos_id`` / ``pad_id`` attributes.
            trg_tok: target-side tokenizer, used the same way as ``src_tok``.
            src_train_fname: path of the source-side training id file.
            trg_train_fname: path of the target-side training id file.
            src_valid_fname: path of the source-side validation id file.
            trg_valid_fname: path of the target-side validation id file.
            src_test_fname: path of the source-side test id file.
            trg_test_fname: path of the target-side test id file.
            max_len: sentence pairs whose source or target length exceeds
                this value are skipped by ``batch_idxs``. Defaults to 1024,
                the value that used to be hard-coded.
        """
        self.src_tok = src_tok
        self.trg_tok = trg_tok

        self.src_train_fname = src_train_fname
        self.trg_train_fname = trg_train_fname
        self.src_valid_fname = src_valid_fname
        self.trg_valid_fname = trg_valid_fname
        self.src_test_fname = src_test_fname
        self.trg_test_fname = trg_test_fname

        # Upper bound on sentence length enforced when batches are built.
        self.max_len = max_len

    def setup(self, train=True, valid=True, test=True):
        if train:
            self.train_df = self.read_id(self.src_train_fname,
                                         self.trg_train_fname)
        if valid:
            self.valid_df = self.read_id(self.src_valid_fname,
                                         self.trg_valid_fname)
        if test:
            self.test_df = self.read_id(self.src_test_fname,
                                        self.trg_test_fname)

    def read_id(self, src_fname, trg_fname):
        """Read a pair of token-id files into a two-column DataFrame.

        Each file is expected to hold one sentence per line, written as
        whitespace-separated integer ids.  Row ``k`` of the result pairs
        line ``k`` of ``src_fname`` with line ``k`` of ``trg_fname``.

        Args:
            src_fname: path of the source-side id file.
            trg_fname: path of the target-side id file.

        Returns:
            A ``pd.DataFrame`` whose columns are named ``self.src_tok.lang``
            and ``self.trg_tok.lang``; every cell is a list of ints.

        Raises:
            ValueError: if a token is not an integer, or (raised by pandas)
                if the two files yield a different number of lines.
        """
        # Use context managers so the file handles are closed right away
        # instead of being left for the garbage collector.
        with open(src_fname) as src_file:
            src_text = src_file.read()
        with open(trg_fname) as trg_file:
            trg_text = trg_file.read()

        def parse(text):
            # Whitespace around the whole file (including leading/trailing
            # blank lines) is dropped before splitting into sentences.
            return [
                [int(piece) for piece in line.split()]
                for line in text.strip().split('\n')
            ]

        return pd.DataFrame({
            self.src_tok.lang: parse(src_text),
            self.trg_tok.lang: parse(trg_text),
        })

    @staticmethod
    def pad(idss, tok, device='cpu'):
        """bos, eos, pad"""
        idss = [[tok.bos_id] + ids + [tok.eos_id] for ids in idss]
        mlen = max(len(ids) for ids in idss)
        idss = [ids + [tok.pad_id] * (mlen - len(ids)) for ids in idss]
        tensor = torch.tensor(idss, device=device).T
        return tensor  # (l, b)

    @staticmethod
    def unpad(tensor, tok):  # (l, b)
        idss = tensor.T.tolist()
        idss = [l[1:] if l[0] == tok.bos_id else l for l in idss]
        idss = [
            l[:l.index(tok.eos_id)] if tok.eos_id in l else l for l in idss
        ]
        return idss

    def batch_idxs(self, src_lens, trg_lens, batch_size):
        """
        Greedily group consecutive sentence pairs into token-budgeted batches.

        The cost of a batch on one side is ``number of sentences * longest
        sentence`` (i.e. its size after padding).  Pairs are taken in input
        order; as soon as adding a pair pushes the cost on either side above
        ``batch_size``, the batch built so far is closed and the pair starts
        the next one.

        src_lens: src sentence lengths (list or pd.Series) including bos and eos
        trg_lens: trg sentence lengths (list or pd.Series) including bos and eos
        batch_size: number of maximum tokens in a batch

        Returns a list of batches, each a non-empty list of positional
        indices into ``src_lens`` / ``trg_lens``.  Pairs longer than
        ``self.max_len`` or ``batch_size`` on either side are left out.  The
        result is an empty list when no pair survives the filter.
        """
        # A pair must fit both the length cap and the token budget on its own.
        limit = min(self.max_len, batch_size)

        batches = []
        current = []
        src_max = trg_max = 0
        for i, (sl, tl) in enumerate(zip(src_lens, trg_lens)):
            if sl > limit or tl > limit:
                continue
            src_max = max(src_max, sl)
            trg_max = max(trg_max, tl)
            current.append(i)
            count = len(current)
            if count * src_max > batch_size or count * trg_max > batch_size:
                # The newest pair overflowed the budget: close the batch
                # without it and let it open the next one.  Because single
                # pairs always fit, the closed batch is never empty.
                batches.append(current[:-1])
                current = current[-1:]
                src_max = sl
                trg_max = tl
        # Only flush a non-empty remainder; an empty batch would break
        # downstream consumers that take max() over or unpack a batch.
        if current:
            batches.append(current)
        return batches

    def train_dataloader(self, batch_size, device):
        """Build a DataLoader over a freshly shuffled copy of the training set.

        The rows of ``self.train_df`` are shuffled, grouped into
        token-budgeted batches by ``batch_idxs``, and the batch statistics
        from ``analyze_batch_idxs`` are printed before the loader is returned.

        Args:
            batch_size: maximum number of tokens per batch (see ``batch_idxs``).
            device: device on which ``collate_fn`` creates the batch tensors.

        Returns:
            A ``torch.utils.data.DataLoader`` that yields ``Batch`` objects.
        """
        shuffled = self.train_df.sample(frac=1)

        # +2 accounts for the bos and eos ids that ``pad`` adds to each sentence.
        src_lens = shuffled[self.src_tok.lang].map(len) + 2
        trg_lens = shuffled[self.trg_tok.lang].map(len) + 2
        sampler = self.batch_idxs(src_lens, trg_lens, batch_size)

        # ``.values`` drops the shuffled index so the positional indices
        # produced by ``batch_idxs`` address the right sentences.
        stats_src, stats_trg = self.analyze_batch_idxs(
            sampler, src_lens.values, trg_lens.values, verbose=False)
        for stats in (stats_src, stats_trg):
            print('')
            for rv in stats:
                print(rv)
        print('')

        loader = torch.utils.data.DataLoader(
            shuffled.values,
            batch_sampler=sampler,
            collate_fn=collate_fn(device, self),
            pin_memory=True,
        )
        return loader

    def analyze_batch_idxs(self, batch_idxs, src_lens, trg_lens, verbose):
        """Collect per-batch size and padding statistics for both sides.

        For every batch and each side the following values are appended to
        an ``RV`` accumulator (``RV`` comes from ``brl.utils``):
        number of real tokens, number of sentences, longest sentence,
        padded size (``longest * sentences``) and padding rate
        (``pads / padded size``).

        Args:
            batch_idxs: list of batches, each a list of positional indices.
            src_lens: source sentence lengths, indexable by those positions.
            trg_lens: target sentence lengths, indexable by those positions.
            verbose: when true, print one line per batch and side, followed
                by the accumulated statistics.

        Returns:
            ``(stats_src, stats_trg)``: two lists of five ``RV`` objects in
            the order tokens, sentences, max length, batch size, padding rate.
        """
        if verbose:
            print('#batch : ', len(batch_idxs))

        def make_stats(side):
            # One accumulator per statistic, in the order get_stats expects.
            titles = (
                '#token/batch in {}',
                '#sentence/batch in {}',
                'max_len/batch in {}',
                'batch size in {}',
                'padding rate in {}',
            )
            return [RV(title.format(side), save_data=True) for title in titles]

        stats_src = make_stats('src')
        stats_trg = make_stats('trg')

        def get_stats(sent_lens, src_or_trg, rv_num_toks, rv_num_sents,
                      rv_max_len, rv_batch_size, rv_pad_rate):
            for batch_no, batch in enumerate(batch_idxs):
                if not batch:
                    # An empty batch has no longest sentence and no padding
                    # rate (0 / 0); skip it instead of failing in max().
                    continue
                lens = [sent_lens[j] for j in batch]
                num_tokens = sum(lens)
                num_sents = len(batch)
                # Computed once per batch rather than once per sentence.
                max_len = max(lens)
                # Every sentence is padded to max_len, so the padded size is
                # max_len * num_sents and the rest of it is padding.
                batch_size = max_len * num_sents
                num_pads = batch_size - num_tokens
                if verbose:
                    print('batch {:2d} ({}) : #tok={} , #sentence={} , max_len={:3d} , #pad={} / {} = {:.1f}%'.format(
                        batch_no, src_or_trg, num_tokens, num_sents, max_len, num_pads, batch_size, num_pads/batch_size*100.0
                    ))
                rv_num_toks.append(num_tokens)
                rv_num_sents.append(num_sents)
                rv_max_len.append(max_len)
                rv_batch_size.append(batch_size)
                rv_pad_rate.append(num_pads / batch_size)

        get_stats(src_lens, 'src', *stats_src)
        get_stats(trg_lens, 'trg', *stats_trg)

        if verbose:
            for stats in (stats_src, stats_trg):
                print('')
                for rv in stats:
                    print(rv)
            print('')
            print('analyze_batch_idxs() finished.')
            print('')
        return stats_src, stats_trg

    def valid_dataloader(self, batch_size, device):
        """Build a DataLoader over the validation set in its stored order.

        Args:
            batch_size: maximum number of tokens per batch (see ``batch_idxs``).
            device: device on which ``collate_fn`` creates the batch tensors.

        Returns:
            A ``torch.utils.data.DataLoader`` that yields ``Batch`` objects.
        """
        frame = self.valid_df
        # +2 accounts for the bos and eos ids that ``pad`` adds to each sentence.
        lengths = [
            frame[tok.lang].map(len) + 2
            for tok in (self.src_tok, self.trg_tok)
        ]
        sampler = self.batch_idxs(lengths[0], lengths[1], batch_size)
        loader = torch.utils.data.DataLoader(
            frame.values,
            batch_sampler=sampler,
            collate_fn=collate_fn(device, self),
            pin_memory=True,
        )
        return loader

    def test_dataloader(self, batch_size, device):
        """Build a DataLoader over the test set in its stored order.

        Args:
            batch_size: maximum number of tokens per batch (see ``batch_idxs``).
            device: device on which ``collate_fn`` creates the batch tensors.

        Returns:
            A ``torch.utils.data.DataLoader`` that yields ``Batch`` objects.
        """
        test_frame = self.test_df

        def padded_lengths(column):
            # +2 accounts for the bos and eos ids that ``pad`` adds.
            return test_frame[column].map(len) + 2

        src_lens = padded_lengths(self.src_tok.lang)
        trg_lens = padded_lengths(self.trg_tok.lang)
        loader_kwargs = dict(
            batch_sampler=self.batch_idxs(src_lens, trg_lens, batch_size),
            collate_fn=collate_fn(device, self),
            pin_memory=True,
        )
        return torch.utils.data.DataLoader(test_frame.values, **loader_kwargs)


class collate_fn:
    """Callable that turns a list of (src_ids, trg_ids) rows into a ``Batch``.

    Both sides are padded with ``dataset.pad`` using the dataset's source and
    target tokenizers, and the resulting tensors are created on ``device``.
    """

    def __init__(self, device, dataset):
        self.dataset = dataset
        self.device = device

    def __call__(self, values):
        ds = self.dataset
        # Transpose the list of row pairs into one tuple per side.
        src_rows, trg_rows = zip(*values)
        src_tensor = ds.pad(src_rows, ds.src_tok, self.device)
        trg_tensor = ds.pad(trg_rows, ds.trg_tok, self.device)
        return Batch(ds.src_tok.lang, src_tensor, ds.trg_tok.lang, trg_tensor)


class Batch:
    """Plain container for one padded batch: a language tag and tensor per side."""

    __slots__ = ['src_lang', 'src_tensor', 'trg_lang', 'trg_tensor']

    def __init__(self, src_lang, src_tensor, trg_lang, trg_tensor):
        # Source side.
        self.src_lang, self.src_tensor = src_lang, src_tensor
        # Target side.
        self.trg_lang, self.trg_tensor = trg_lang, trg_tensor