import random
import warnings

import h5py
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertTokenizerFast
from transformers.data.data_collator import DataCollatorMixin, _torch_collate_batch


def tolist(x):
    """
    Convert `x` to a plain Python list.

    A `list` is returned unchanged (the same object, not a copy). Any other object is
    converted through its `.tolist()` method; if the object exposes a `.numpy()` method
    (e.g. a framework tensor), that is called first and `.tolist()` is applied to its result.
    """
    if isinstance(x, list):
        return x
    # Duck-typed so that tensor types are handled without importing their framework.
    array_like = x.numpy() if hasattr(x, "numpy") else x
    return array_like.tolist()


def torch_mask_tokens(inputs, tokenizer, mlm_probability, special_tokens_mask=None):
    """
    Build masked-language-modeling inputs and labels in which every selected token is
    replaced by the tokenizer's mask token (no random-token or keep-as-is variants).

    Each position that is not flagged as a special token is selected independently with
    probability `mlm_probability`. When `special_tokens_mask` is None, the flags are obtained
    from `tokenizer.get_special_tokens_mask` for each row of `inputs`.

    `inputs` is modified in place and returned together with `labels`, a copy of the original
    ids in which every non-selected position is set to -100.
    """
    labels = inputs.clone()

    if special_tokens_mask is not None:
        is_special = special_tokens_mask.bool()
    else:
        per_row_flags = [
            tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)
            for ids in labels.tolist()
        ]
        is_special = torch.tensor(per_row_flags, dtype=torch.bool)

    # Special tokens get probability 0 so they are never selected for masking.
    selection_prob = torch.full(labels.shape, mlm_probability)
    selection_prob.masked_fill_(is_special, value=0.0)
    masked_indices = torch.bernoulli(selection_prob).bool()

    # The loss is only computed on the selected positions.
    labels.masked_fill_(~masked_indices, -100)

    mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    inputs.masked_fill_(masked_indices, mask_token_id)

    return inputs, labels


class DataCollatorForProteinWithStructureWithPadding(DataCollatorMixin):
    """
    Collate pre-tokenized protein examples that carry per-node 3D coordinates.

    Every example must provide tensors under "input_ids", "attention_mask", "label" and
    "node_position". The first three are stacked along a new batch dimension; the coordinates
    are zero-padded to the token length of the batch.
    """

    def __init__(self, protein_tokenizer):
        self.return_tensors = "pt"
        self.protein_tokenizer = protein_tokenizer

    def torch_call(self, examples):
        """
        Return a dict with 'input_ids', 'attention_mask', 'node_position' and 'labels'.

        'node_position' has shape (batch, token_length, 3). Row 0 of each example is left at
        zero and the example's coordinates are written starting at row 1.
        NOTE(review): this requires each "node_position" to have at most token_length - 1
        rows, otherwise the slice assignment fails - confirm the dataset guarantees this.
        """
        input_ids = torch.stack([ex["input_ids"] for ex in examples])
        attention_mask = torch.stack([ex["attention_mask"] for ex in examples])
        labels = torch.stack([ex['label'] for ex in examples])

        batch_size = len(examples)
        token_length = input_ids.shape[-1]
        padded_positions = torch.zeros((batch_size, token_length, 3))
        for idx, ex in enumerate(examples):
            coords = ex["node_position"]
            end = coords.shape[0] + 1
            padded_positions[idx, 1:end, :] = coords

        batch = {}
        batch['input_ids'] = input_ids
        batch['attention_mask'] = attention_mask
        batch['node_position'] = padded_positions
        batch['labels'] = labels
        return batch


class DataCollatorForProteinTextCLIPPretrain(DataCollatorMixin):
    """
    Collate (protein, caption) examples for protein-text pretraining.

    For every example, the protein sequence (and, unless `sequence_only` is set, its node
    coordinates) is looked up in an HDF5 file under the example's "id"; the example's
    "caption" supplies the paired text. Both are tokenized and padded to fixed lengths.

    Args:
        protein_tokenizer: tokenizer called on the list of protein sequences.
        text_tokenizer: tokenizer called on the list of captions.
        pdb_h5_file: path of the HDF5 file; it is opened read-only on every call.
        sequence_only: when true, coordinates are neither read nor returned.
        mlm_probability: per-token masking probability handed to `torch_mask_tokens`;
            a value equal to 0.0 disables the masked-LM outputs.
        protein_max_length: truncation/padding length for protein tokens (default 1024).
        text_max_length: truncation/padding length for caption tokens (default 512).
    """

    def __init__(self, protein_tokenizer, text_tokenizer, pdb_h5_file, sequence_only, mlm_probability,
                 protein_max_length=1024, text_max_length=512):
        self.return_tensors = "pt"
        self.protein_tokenizer = protein_tokenizer
        self.text_tokenizer = text_tokenizer
        self.pdb_h5_file = pdb_h5_file
        self.sequence_only = sequence_only
        self.mlm_probability = mlm_probability
        self.protein_max_length = protein_max_length
        self.text_max_length = text_max_length

    def _load_proteins(self, examples):
        """Read each example's sequence, plus truncated coordinates unless sequence-only."""
        sequences = []
        node_positions = []
        # Two positions are held back from the token budget - presumably for the special
        # tokens the protein tokenizer adds at both ends; confirm against the tokenizer.
        max_nodes = max(self.protein_max_length - 2, 0)
        with h5py.File(self.pdb_h5_file, "r") as pdb_h5:
            for e in examples:
                group = pdb_h5[e["id"]]
                sequences.append(group["sequence"][()][0].decode())
                # Coordinates are only read when they are actually going to be returned.
                if not self.sequence_only:
                    coords = torch.from_numpy(group["node_position"][()])
                    node_positions.append(coords[:max_nodes])
        return sequences, node_positions

    @staticmethod
    def _pad_node_positions(node_positions, token_length):
        """Zero-pad per-example coordinates into one (batch, token_length, 3) tensor."""
        padded = torch.zeros((len(node_positions), token_length, 3))
        for i, coords in enumerate(node_positions):
            # Row 0 stays zero; the coordinates are written starting at row 1.
            padded[i, 1:coords.shape[0] + 1, :] = coords
        return padded

    def torch_call(self, examples):
        """
        Build one batch from `examples` (each a mapping with "id" and "caption").

        Returns a dict that always contains 'protein_input_ids', 'protein_attention_mask',
        'text_input_ids' and 'text_attention_mask'. 'node_position' is added unless
        `sequence_only` is set. 'protein_masked_input_ids' and 'protein_masked_labels' are
        added unless `mlm_probability` equals 0.0.
        """
        sequences, node_positions = self._load_proteins(examples)

        protein_tokenized = self.protein_tokenizer(
            sequences,
            truncation=True,
            max_length=self.protein_max_length,
            padding='max_length',
            return_tensors="pt",
            return_attention_mask=True)

        text_tokenized = self.text_tokenizer(
            [e["caption"] for e in examples],
            truncation=True,
            max_length=self.text_max_length,
            padding='max_length',
            return_tensors="pt",
            return_attention_mask=True)

        batch = {
            'protein_input_ids': protein_tokenized['input_ids'],
            'protein_attention_mask': protein_tokenized['attention_mask'],
            'text_input_ids': text_tokenized['input_ids'],
            'text_attention_mask': text_tokenized['attention_mask'],
        }

        if not self.sequence_only:
            batch['node_position'] = self._pad_node_positions(
                node_positions, protein_tokenized['input_ids'].shape[-1])

        if self.mlm_probability != 0.0:
            # Mask a clone so 'protein_input_ids' keeps the unmasked ids.
            masked_input_ids, masked_labels = torch_mask_tokens(
                protein_tokenized["input_ids"].clone(),
                self.protein_tokenizer,
                self.mlm_probability)
            batch['protein_masked_input_ids'] = masked_input_ids
            batch['protein_masked_labels'] = masked_labels

        return batch
