from dataclasses import dataclass
from transformers import DataCollatorWithPadding
from transformers import default_data_collator

import pdb
import torch

@dataclass
class QPCollator(DataCollatorWithPadding):
    """
    Collate query/passage features into two separately padded batches.

    Each feature is a mapping with a ``"query"`` entry and a ``"passages"``
    entry. An entry is either a single encoded item or a list of encoded
    items; lists are flattened across the whole batch before padding, so the
    model receives one flat query batch and one flat passage batch and does
    not need to know how the data was grouped.
    """
    max_q_len: int = 32
    max_p_len: int = 128

    @staticmethod
    def _flatten(groups):
        """Return ``groups`` flattened by one level if its entries are lists.

        Only the first entry is inspected; the remaining entries are assumed
        to have the same kind (TODO confirm against the dataset code). A
        nested comprehension is used instead of ``sum(groups, [])`` because
        the latter copies the accumulated list on every addition (quadratic).
        """
        if isinstance(groups[0], list):
            return [item for group in groups for item in group]
        return groups

    def __call__(self, features):
        """Build the padded query and passage batches for ``features``.

        Args:
            features: Non-empty sequence of mappings, each providing
                ``"query"`` and ``"passages"``.

        Returns:
            A ``(q_collated, d_collated)`` tuple holding the results of
            ``self.tokenizer.pad`` for the queries and the passages.
        """
        qq = self._flatten([f["query"] for f in features])
        dd = self._flatten([f["passages"] for f in features])

        d_collated = self.tokenizer.pad(
            dd,
            padding='max_length',
            max_length=self.max_p_len,
            return_tensors="pt",
        )
        # NOTE(review): queries are padded to max_p_len + max_q_len rather
        # than max_q_len alone -- presumably because a query may carry
        # appended passage tokens; confirm with the dataset code.
        q_collated = self.tokenizer.pad(
            qq,
            padding='max_length',
            max_length=self.max_p_len+self.max_q_len,
            return_tensors="pt",
        )
        return q_collated, d_collated


@dataclass
class EncodeCollator(DataCollatorWithPadding):
    """
    Collate features for encoding, keeping text ids next to their tokens.

    Each feature is a mapping with a ``"text_id"`` entry and a
    ``"tokenized"`` entry. ``"tokenized"`` is either a single encoded item or
    a list of encoded items; lists are flattened across the batch.
    """

    def __call__(self, features):
        """Return the text ids and the tensorised token inputs.

        Args:
            features: Non-empty sequence of mappings, each providing
                ``"text_id"`` and ``"tokenized"``.

        Returns:
            A ``(text_ids, token_inputs)`` tuple: the list of ``"text_id"``
            values in input order, and the result of ``self.tokenizer.pad``.
        """
        text_ids = [x["text_id"] for x in features]
        token_inputs = [x["tokenized"] for x in features]
        if isinstance(token_inputs[0], list):
            # Linear-time flatten; sum(token_inputs, []) re-copies the
            # accumulated list on every addition.
            token_inputs = [item for group in token_inputs for item in group]
        # NOTE(review): no padding is applied here, so conversion to tensors
        # presumably relies on every input in the batch already having the
        # same length -- confirm against the dataset code.
        token_inputs = self.tokenizer.pad(
            token_inputs,
            padding='do_not_pad',
            return_tensors="pt",
        )
        return text_ids, token_inputs


@dataclass
class DistillCollator(DataCollatorWithPadding):
    """
    Collate distillation features into padded batches plus a score tensor.

    Each feature is a mapping with ``"query"``, ``"passages"``,
    ``"ance-passages"`` and ``"scores"`` entries. An entry is either a single
    item or a list of items; lists are flattened across the whole batch, so
    the model receives flat batches and does not need to know how the data
    was grouped.
    """
    max_q_len: int = 32
    max_p_len: int = 128

    @staticmethod
    def _flatten(groups):
        """Return ``groups`` flattened by one level if its entries are lists.

        Only the first entry is inspected; the remaining entries are assumed
        to have the same kind (TODO confirm against the dataset code). A
        nested comprehension is used instead of ``sum(groups, [])`` because
        the latter copies the accumulated list on every addition (quadratic).
        """
        if isinstance(groups[0], list):
            return [item for group in groups for item in group]
        return groups

    def __call__(self, features):
        """Build the padded batches and the score tensor for ``features``.

        Args:
            features: Non-empty sequence of mappings, each providing
                ``"query"``, ``"passages"``, ``"ance-passages"`` and
                ``"scores"``.

        Returns:
            A ``(q_collated, d_collated, ance_d_collated, s_collated)`` tuple:
            the results of ``self.tokenizer.pad`` for the queries, the
            passages and the ``"ance-passages"`` entries, followed by the
            scores wrapped in a ``torch.Tensor``.
        """
        qq = self._flatten([f["query"] for f in features])
        dd = self._flatten([f["passages"] for f in features])
        ance_dd = self._flatten([f["ance-passages"] for f in features])
        ss = self._flatten([f["scores"] for f in features])

        d_collated = self.tokenizer.pad(
            dd,
            padding='max_length',
            max_length=self.max_p_len,
            return_tensors="pt",
        )

        ance_d_collated = self.tokenizer.pad(
            ance_dd,
            padding='max_length',
            max_length=self.max_p_len,
            return_tensors="pt",
        )

        # NOTE(review): queries are padded to max_p_len + max_q_len rather
        # than max_q_len alone -- presumably because a query may carry
        # appended passage tokens; confirm with the dataset code.
        q_collated = self.tokenizer.pad(
            qq,
            padding='max_length',
            max_length=self.max_p_len+self.max_q_len,
            return_tensors="pt",
        )

        s_collated = torch.Tensor(ss)

        return q_collated, d_collated, ance_d_collated, s_collated

