from typing import List

import copy
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, BertTokenizer

class PaddingCollateFunction(object):
    """Merge per-sample training tuples into one padded batch.

    Each sample is unpacked as a 6-tuple laid out as
    ``(reference_image, target_image, modifier, length, ref_id, targ_id)``.
    The two trailing ids are unpacked but never returned.
    """

    def __init__(self, padding_idx):
        # Fill value handed to ``pad_sequence`` when modifiers differ in length.
        self.padding_idx = padding_idx

    def __call__(self, batch: List[tuple]):
        """Collate ``batch`` into batched tensors.

        Returns:
            A 5-tuple ``(reference_images, target_images, modifiers,
            seq_lengths, None)``: both image groups stacked along dim 0, the
            modifiers padded batch-first with ``padding_idx``, the lengths as
            a long tensor, and a constant ``None`` in the last slot.
        """
        ref_imgs, tgt_imgs, mod_seqs, mod_lens, _ref_ids, _targ_ids = zip(*batch)

        stacked_refs = torch.stack(ref_imgs, dim=0)
        stacked_tgts = torch.stack(tgt_imgs, dim=0)
        length_tensor = torch.tensor(mod_lens).long()
        padded_mods = pad_sequence(
            mod_seqs,
            batch_first=True,
            padding_value=self.padding_idx,
        )
        # NOTE(review): the trailing None presumably fills the slot where the
        # tokenizer-based collate functions return an attention mask -- confirm.
        return stacked_refs, stacked_tgts, padded_mods, length_tensor, None

class PaddingCollateFunctionTest(object):
    """Collate evaluation samples, choosing the layout from the sample arity.

    Samples with at most two fields are handled as ``(image, id)`` pairs;
    anything longer is handled as a 5-field query sample.
    """

    def __init__(self, padding_idx):
        # Fill value handed to ``pad_sequence`` for query modifiers.
        self.padding_idx = padding_idx

    @staticmethod
    def _collate_test_dataset(batch):
        """Stack the images of ``(image, id)`` samples; ids stay a tuple."""
        images, sample_ids = zip(*batch)
        return torch.stack(images, dim=0), sample_ids

    def _collate_test_query_dataset(self, batch):
        """Batch ``(image, ref_attr, modifier, target_attr, length)`` samples.

        Images are stacked along dim 0, modifiers are padded batch-first with
        ``padding_idx`` and lengths become a long tensor. The attribute
        fields are passed through as tuples and the last slot is ``None``.
        """
        ref_imgs, ref_attrs, mod_seqs, target_attrs, mod_lens = zip(*batch)
        stacked_refs = torch.stack(ref_imgs, dim=0)
        length_tensor = torch.tensor(mod_lens).long()
        padded_mods = pad_sequence(
            mod_seqs,
            batch_first=True,
            padding_value=self.padding_idx,
        )
        return stacked_refs, ref_attrs, padded_mods, target_attrs, length_tensor, None

    def __call__(self, batch: List[tuple]):
        """Dispatch on the number of fields found in the first sample."""
        if len(batch[0]) <= 2:
            return self._collate_test_dataset(batch)
        return self._collate_test_query_dataset(batch)

class BertPaddingCollateFunction(object):
    """Collate training samples and tokenize their modifiers in one pass.

    Each sample is unpacked as a 7-tuple; only the first four fields
    (reference image, target image, modifier, length) contribute to the
    output. The modifiers are handed to the tokenizer as a list.
    """

    def __init__(self, padding_idx):
        # Stored on the instance; the collate call below does not read it.
        self.padding_idx = padding_idx
        # NOTE(review): the checkpoint is "roberta-base" although the class is
        # named after BERT -- confirm this is intended.
        self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    def __call__(self, batch: List[tuple]):
        """Collate ``batch`` into images, token ids, lengths and a mask.

        Returns:
            ``(reference_images, target_images, input_ids, seq_lengths,
            attention_mask)``. Images are stacked along dim 0, lengths become
            a long tensor, and the last two come from the tokenizer output
            (padded to the longest entry, returned as ``'pt'`` tensors).
        """
        ref_imgs, tgt_imgs, raw_modifiers, raw_lengths, _ref_ids, _targ_ids, _extra = zip(*batch)

        stacked_refs = torch.stack(ref_imgs, dim=0)
        stacked_tgts = torch.stack(tgt_imgs, dim=0)
        length_tensor = torch.tensor(raw_lengths).long()

        encoded = self.tokenizer.batch_encode_plus(
            list(raw_modifiers),
            padding='longest',
            return_tensors='pt',
        )
        mask = encoded['attention_mask']
        token_ids = encoded['input_ids']
        return stacked_refs, stacked_tgts, token_ids, length_tensor, mask


class BertPaddingCollateFunctionTest(object):
    """Collate evaluation samples, tokenizing modifiers for query batches.

    Samples with at most two fields are handled as ``(image, id)`` pairs;
    anything longer is handled as a 5-field query sample whose modifiers are
    passed through the tokenizer.
    """

    def __init__(self, padding_idx):
        # Stored on the instance; none of the collate paths below read it.
        self.padding_idx = padding_idx
        # NOTE(review): the checkpoint is "roberta-base" although the class is
        # named after BERT -- confirm this is intended.
        self.tokenizer = AutoTokenizer.from_pretrained("roberta-base")

    @staticmethod
    def _collate_test_dataset(batch):
        """Stack the images of ``(image, id)`` samples; ids stay a tuple."""
        images, sample_ids = zip(*batch)
        return torch.stack(images, dim=0), sample_ids

    def _collate_test_query_dataset(self, batch):
        """Batch ``(image, ref_attr, modifier, target_attr, length)`` samples.

        Returns ``(reference_images, ref_attrs, input_ids, target_attrs,
        seq_lengths, attention_mask)``. The attribute fields are passed
        through as tuples; ids and mask come from the tokenizer output.
        """
        ref_imgs, ref_attrs, raw_modifiers, target_attrs, raw_lengths = zip(*batch)
        stacked_refs = torch.stack(ref_imgs, dim=0)
        length_tensor = torch.tensor(raw_lengths).long()

        encoded = self.tokenizer.batch_encode_plus(
            list(raw_modifiers),
            padding='longest',
            return_tensors='pt',
        )
        mask = encoded['attention_mask']
        token_ids = encoded['input_ids']

        return stacked_refs, ref_attrs, token_ids, target_attrs, length_tensor, mask

    def __call__(self, batch: List[tuple]):
        """Dispatch on the number of fields found in the first sample."""
        if len(batch[0]) <= 2:
            return self._collate_test_dataset(batch)
        return self._collate_test_query_dataset(batch)


def init_tokenizer():
    """Load the ``bert-base-uncased`` tokenizer and register extra tokens.

    ``[DEC]`` is registered as the beginning-of-sequence token and ``[ENC]``
    as an additional special token. The id of the first additional special
    token is then stored on the tokenizer as ``enc_token_id``.

    Returns:
        The configured ``BertTokenizer`` instance.
    """
    bert_tok = BertTokenizer.from_pretrained('bert-base-uncased')
    # Registered one mapping at a time, in this order.
    special_token_specs = (
        {'bos_token': '[DEC]'},
        {'additional_special_tokens': ['[ENC]']},
    )
    for spec in special_token_specs:
        bert_tok.add_special_tokens(spec)
    # presumably index 0 is '[ENC]', the only additional token added here --
    # verify the base tokenizer ships with none of its own.
    bert_tok.enc_token_id = bert_tok.additional_special_tokens_ids[0]
    return bert_tok

class BLIPPaddingCollateFunction(object):
    """Collate training samples that carry raw sentences and a negative image.

    Each sample is unpacked as a 5-tuple laid out as
    ``(reference_image, target_image, negative_target_image, targ_id, sentence)``.
    """

    def __init__(self):
        # Built via init_tokenizer(); the collate call itself does not use it.
        self.tokenizer = init_tokenizer()

    def __call__(self, batch: List[tuple]):
        """Stack the three image groups and pass the sentences through.

        Returns:
            ``(reference_images, target_images, negative_target_images,
            sentences)``. The images are stacked along dim 0 and the sentences
            stay a tuple; the target ids are dropped.
        """
        ref_imgs, tgt_imgs, neg_imgs, _targ_ids, captions = zip(*batch)

        stacked_refs = torch.stack(ref_imgs, dim=0)
        stacked_tgts = torch.stack(tgt_imgs, dim=0)
        stacked_negs = torch.stack(neg_imgs, dim=0)
        return stacked_refs, stacked_tgts, stacked_negs, captions


class BLIPPaddingCollateFunctionTest(object):
    """Collate evaluation samples, choosing the layout from the sample arity.

    Samples with at most two fields are handled as ``(image, id)`` pairs;
    longer samples are handled as ``(reference_image, rel_caption,
    target_name)`` query triples.
    """

    def __init__(self):
        # Built via init_tokenizer(); the collate paths below do not use it.
        self.tokenizer = init_tokenizer()

    @staticmethod
    def _collate_test_dataset(batch):
        """Stack the images of ``(image, id)`` samples; ids stay a tuple."""
        images, sample_ids = zip(*batch)
        return torch.stack(images, dim=0), sample_ids

    def _collate_test_query_dataset(self, batch):
        """Batch query triples into ``(images, captions, target_names)``.

        Images are stacked along dim 0, captions stay a tuple and the target
        names are returned as a list.
        """
        ref_imgs, captions, names = zip(*batch)
        stacked_refs = torch.stack(ref_imgs, dim=0)
        return stacked_refs, captions, list(names)

    def __call__(self, batch: List[tuple]):
        """Dispatch on the number of fields found in the first sample."""
        is_query_batch = len(batch[0]) > 2
        handler = (
            self._collate_test_query_dataset
            if is_query_batch
            else self._collate_test_dataset
        )
        return handler(batch)


class BLIPPaddingCollateFunctionTest4CIRR(object):
    """Collate evaluation samples, choosing the layout from the sample arity.

    Samples with at most two fields are handled as ``(image, id)`` pairs;
    longer samples are handled as ``(reference_name, reference_image,
    rel_caption, pair_id, group_members)`` query tuples.
    """

    def __init__(self):
        # Built via init_tokenizer(); the collate paths below do not use it.
        self.tokenizer = init_tokenizer()

    @staticmethod
    def _collate_test_dataset(batch):
        """Stack the images of ``(image, id)`` samples; ids stay a tuple."""
        images, sample_ids = zip(*batch)
        return torch.stack(images, dim=0), sample_ids

    def _collate_test_query_dataset(self, batch):
        """Batch 5-field query samples, keeping the input field order.

        Reference names are returned as a list and images are stacked along
        dim 0. Captions, pair ids and group members stay tuples.
        """
        names, ref_imgs, captions, pair_ids, members = zip(*batch)
        stacked_refs = torch.stack(ref_imgs, dim=0)
        return list(names), stacked_refs, captions, pair_ids, members

    def __call__(self, batch: List[tuple]):
        """Dispatch on the number of fields found in the first sample."""
        if len(batch[0]) <= 2:
            return self._collate_test_dataset(batch)
        return self._collate_test_query_dataset(batch)
        
class BLIPPaddingCollateFunctionTest4FIQ(object):
    """Collate evaluation samples, choosing the layout from the sample arity.

    Samples with at most two fields are handled as ``(image, id)`` pairs;
    longer samples are handled as ``(target_name, reference_image,
    rel_caption)`` query triples.
    """

    def __init__(self):
        # Built via init_tokenizer(); the collate paths below do not use it.
        self.tokenizer = init_tokenizer()

    @staticmethod
    def _collate_test_dataset(batch):
        """Stack the images of ``(image, id)`` samples; ids stay a tuple."""
        images, sample_ids = zip(*batch)
        return torch.stack(images, dim=0), sample_ids

    def _collate_test_query_dataset(self, batch):
        """Batch query triples into ``(images, captions, target_names)``.

        The target name is the first field of each sample but the last item
        of the result. Images are stacked along dim 0, captions stay a tuple
        and the names are returned as a list.
        """
        names, ref_imgs, captions = zip(*batch)
        stacked_refs = torch.stack(ref_imgs, dim=0)
        return stacked_refs, captions, list(names)

    def __call__(self, batch: List[tuple]):
        """Dispatch on the number of fields found in the first sample."""
        is_query_batch = len(batch[0]) > 2
        handler = (
            self._collate_test_query_dataset
            if is_query_batch
            else self._collate_test_dataset
        )
        return handler(batch)



