import torch as T
import numpy as np
import random
import copy


class seq_label_collater:
    """Turn a bucket of sequence-labelling samples into padded mini-batches.

    Each sample is a mapping with the keys ``"sequence"``, ``"sequence_vec"``,
    ``"char_sequence_vec"``, ``"label"`` and ``"feats"``.  ``collate_fn``
    sorts the bucket by sequence length, cuts it into chunks of the
    configured batch size, pads every chunk to its own longest sequence and
    returns the resulting batches in random order.

    Args:
        PAD: Value used to pad ``sequence_vec`` entries.
        config: Mapping that must provide ``"train_batch_size"``,
            ``"dev_batch_size"``, ``"char_pad_id"`` and
            ``"labels2idx"["O"]``.
        train: When truthy the train batch size is used, otherwise the dev
            batch size.
    """

    def __init__(self, PAD, config, train):
        self.PAD = PAD
        self.config = config
        self.train = train

    def pad(self, items, PAD):
        """Right-pad every sequence in ``items`` to the longest one.

        The input sequences are left untouched; new lists are returned.

        Args:
            items: Sequence of sequences to pad.
            PAD: Filler appended to the short sequences.  The same object is
                reused for every padded position.

        Returns:
            ``(padded_items, item_masks)`` where each mask holds ``1`` for an
            original position and ``0`` for a padded one.  Both are empty
            lists when ``items`` is empty.
        """
        if not items:
            return [], []

        max_len = max(len(item) for item in items)

        padded_items = []
        item_masks = []
        for item in items:
            n_real = len(item)
            n_fill = max_len - n_real
            padded_items.append(list(item) + [PAD] * n_fill)
            item_masks.append([1] * n_real + [0] * n_fill)

        return padded_items, item_masks

    def sort_list_by_idx(self, objs, idx):
        """Return the elements of ``objs`` reordered according to ``idx``."""
        return [objs[i] for i in idx]

    @staticmethod
    def _token_width(rows):
        """Return the length of the first per-token vector in ``rows`` (0 if none)."""
        for row in rows:
            if len(row) > 0:
                return len(row[0])
        return 0

    def collate_fn(self, batch):
        """Build length-sorted, padded mini-batches from a bucket of samples.

        Args:
            batch: Sequence of sample mappings (see the class docstring).
                The samples themselves are never modified.

        Returns:
            A randomly ordered list of dicts with the keys
            ``"sequences_vec"`` (long tensor), ``"char_sequences_vec"``
            (long tensor), ``"feats"`` (float tensor), ``"sequences"``
            (unpadded copies), ``"labels_vec"`` (long tensor, padded with the
            ``"O"`` label id), ``"labels"`` (unpadded copies),
            ``"input_masks"`` (float tensor) and ``"batch_size"`` (int).

        Raises:
            AssertionError: If a sample's sequence, vector, label and feature
                lengths disagree.
            ValueError: If the configured batch size is smaller than 1.
        """
        sequences_vec = [obj["sequence_vec"] for obj in batch]
        char_sequences_vec = [obj["char_sequence_vec"] for obj in batch]
        # Only these two fields are handed back verbatim, so they are the only
        # ones that need to be detached from the caller's data; everything
        # else is copied when it is converted to a tensor.
        sequences = copy.deepcopy([obj["sequence"] for obj in batch])
        labels = copy.deepcopy([obj["label"] for obj in batch])
        feats = [obj["feats"] for obj in batch]

        for sequence, label, sequence_vec, feat in zip(sequences, labels, sequences_vec, feats):
            assert len(sequence) == len(label), "sequence/label length mismatch"
            assert len(sequence_vec) == len(label), "sequence_vec/label length mismatch"
            assert len(feat) == len(label), "feats/label length mismatch"

        bucket_size = len(sequences_vec)
        if bucket_size == 0:
            return []

        size_key = "train_batch_size" if self.train else "dev_batch_size"
        batch_size = self.config[size_key]
        if batch_size < 1:
            raise ValueError(f"{size_key} must be a positive integer, got {batch_size!r}")

        # A stable sort keeps equal-length samples in their incoming order,
        # which makes the batch composition deterministic.
        lengths = [len(obj) for obj in sequences_vec]
        sorted_idx = np.argsort(lengths, kind="stable")

        sequences_vec = self.sort_list_by_idx(sequences_vec, sorted_idx)
        char_sequences_vec = self.sort_list_by_idx(char_sequences_vec, sorted_idx)
        sequences = self.sort_list_by_idx(sequences, sorted_idx)
        feats = self.sort_list_by_idx(feats, sorted_idx)
        labels = self.sort_list_by_idx(labels, sorted_idx)

        # Pad vectors for the per-token fields.  The width is taken from the
        # first non-empty sample so that an empty sequence at the front of the
        # sorted bucket cannot cause an IndexError.
        feat_len = self._token_width(feats)
        char_len = self._token_width(char_sequences_vec)
        # All zeros with a trailing 1 -- presumably the last feature dimension
        # flags padding; confirm against the model that consumes "feats".
        feat_pad = [0] * (feat_len - 1) + [1]
        char_pad = [self.config["char_pad_id"]] * char_len
        label_pad = self.config["labels2idx"]["O"]

        batches = []
        for start in range(0, bucket_size, batch_size):
            chunk = slice(start, start + batch_size)

            sequences_vec_, input_masks = self.pad(sequences_vec[chunk], PAD=self.PAD)
            char_sequences_vec_, _ = self.pad(char_sequences_vec[chunk], PAD=char_pad)
            feats_, _ = self.pad(feats[chunk], PAD=feat_pad)
            labels_vec_, _ = self.pad(labels[chunk], PAD=label_pad)

            batches.append({
                "sequences_vec": T.tensor(sequences_vec_).long(),
                "char_sequences_vec": T.tensor(char_sequences_vec_).long(),
                "feats": T.tensor(feats_).float(),
                "sequences": sequences[chunk],
                "labels_vec": T.tensor(labels_vec_).long(),
                "labels": labels[chunk],
                "input_masks": T.tensor(input_masks).float(),
                "batch_size": len(sequences_vec_),
            })

        # Shuffle so that batches are not always served shortest-first.
        random.shuffle(batches)

        return batches
