import torch


class LoadTextTokens(object):
    """Turn text descriptions into batched token-id tensors.

    Each description is tokenized and wrapped as
    ``[begin_token] + ids + [tokenizer.sep_token_id]``. It is paired with a
    ``need_predict`` mask that is 0 at the begin token and 1 everywhere else
    in the sequence. Calling the instance on a list of descriptions pads all
    sequences with zeros to a common length and stacks them into 2-D tensors.
    """

    def __init__(self, tokenizer, max_text_len=40, padding='do_not_pad'):
        """
        Args:
            tokenizer: callable tokenizer exposing ``sep_token_id``. It is
                called with ``padding``, ``add_special_tokens``,
                ``truncation`` and ``max_length`` keyword arguments, and its
                result is indexed with ``['input_ids']``. This is presumably
                a HuggingFace tokenizer; confirm against callers.
            max_text_len: maximum length of each produced token sequence,
                including the begin and separator tokens (when >= 2).
            padding: forwarded unchanged to the tokenizer's ``padding``
                argument.
        """
        self.tokenizer = tokenizer
        self.max_text_len = max_text_len
        self.padding = padding

    def descriptions_to_text_tokens(self, target, begin_token):
        """Tokenize a single description.

        Args:
            target: the text passed to the tokenizer.
            begin_token: token id prepended to the sequence.

        Returns:
            dict with
              ``'text_tokens'``: 1-D int64 tensor
                ``[begin_token, *ids, sep_token_id]``;
              ``'text_lengths'``: python int, length of that sequence;
              ``'need_predict'``: 1-D int64 tensor of the same length,
                0 at the begin token and 1 elsewhere.
        """
        target_encoding = self.tokenizer(
            target, padding=self.padding,
            add_special_tokens=False,
            truncation=True, max_length=self.max_text_len)

        payload = list(target_encoding['input_ids'])
        # Reserve two slots for the begin and separator tokens. When the
        # payload is too long, keep its LAST `keep` tokens. Explicit
        # start-index slicing is used because `payload[-0:]` would keep
        # everything when keep == 0.
        keep = max(self.max_text_len - 2, 0)
        if len(payload) > keep:
            payload = payload[len(payload) - keep:]
        # Every payload position is a prediction target. The mask is built
        # from the (possibly trimmed) payload so it can never contain token
        # ids and always has the same length.
        need_predict = [1] * len(payload)

        input_ids = [begin_token] + payload + [self.tokenizer.sep_token_id]
        # The begin token is given, not predicted; the separator is predicted.
        need_predict = [0] + need_predict + [1]
        data = {
            'text_tokens': torch.tensor(input_ids),
            'text_lengths': len(input_ids),
            'need_predict': torch.tensor(need_predict),
        }

        return data

    def __call__(self, object_descriptions, box_features, begin_token):
        """Tokenize and batch a list of descriptions.

        Args:
            object_descriptions: iterable of texts, one per output row.
            box_features: tensor used only for its ``.device``. All outputs
                are moved to that device.
            begin_token: token id prepended to every sequence.

        Returns:
            dict with ``'text_tokens'`` and ``'need_predict'`` as 2-D tensors
            of shape (num_descriptions, max_sequence_length), zero-padded on
            the right, and ``'text_lengths'`` as a 1-D tensor of the
            unpadded lengths.

        Raises:
            NotImplementedError: if ``object_descriptions`` is empty
                (propagated from ``collate``).
        """
        encoded = [self.descriptions_to_text_tokens(description, begin_token)
                   for description in object_descriptions]
        device = box_features.device

        text_tokens = torch.cat(
            self.collate([e['text_tokens'] for e in encoded]), dim=0).to(device)
        text_lengths = torch.tensor(
            [e['text_lengths'] for e in encoded]).to(device)
        need_predict = torch.cat(
            self.collate([e['need_predict'] for e in encoded]), dim=0).to(device)

        assert text_tokens.dim() == 2 and need_predict.dim() == 2
        data = {'text_tokens': text_tokens,
                'text_lengths': text_lengths,
                'need_predict': need_predict}

        return data

    def collate(self, batch):
        """Zero-pad same-rank tensors to a common shape and add a batch dim.

        Args:
            batch: non-empty list of tensors that all have the same rank.
                Each dimension is padded at its end, with zeros, up to the
                per-dimension maximum over the batch.

        Returns:
            list of tensors, each of shape ``(1, *max_shape)``, ready for
            ``torch.cat(..., dim=0)``.

        Raises:
            NotImplementedError: if ``batch`` is empty or contains a
                non-tensor.
            AssertionError: if the tensors have different ranks.
        """
        if not batch or not all(isinstance(b, torch.Tensor) for b in batch):
            raise NotImplementedError

        first_shape = batch[0].shape
        if all(b.shape == first_shape for b in batch[1:]):
            return [b[None, ...] for b in batch]

        assert all(b.dim() == batch[0].dim() for b in batch[1:])
        max_shape = tuple(max(dims) for dims in zip(*(b.shape for b in batch)))
        padded = []
        for b in batch:
            if tuple(b.shape) != max_shape:
                canvas = torch.zeros(max_shape, dtype=b.dtype, device=b.device)
                # Copy b into the leading corner; this works for any rank.
                canvas[tuple(slice(0, n) for n in b.shape)] = b
                b = canvas
            padded.append(b[None, ...])
        return padded
