# This file is copied from ProTrek
# Original license: MIT License
import copy
import json
import math
import os
import random
import tempfile
import time

import faiss
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torchmetrics
from sklearn.metrics import roc_auc_score
from torch.nn.functional import cross_entropy, normalize
from tqdm import tqdm

from .abstract_model import AbstractModel
from .constants import residue_level, sequence_level
from .model_interface import register_model
from .mpr import MultipleProcessRunnerSimplifier
from .protein_encoder import ProteinEncoder
from .structure_encoder import StructureEncoder
from .text_encoder import TextEncoder


def multilabel_cross_entropy(logits, labels):
    """Multi-label pairwise ranking loss, averaged over the batch.

    For each sample the loss is ``log(1 + sum_{n, p} exp(s_n - s_p))`` where
    ``s_n`` ranges over logits whose label is 0 and ``s_p`` over logits whose
    label is 1. A sample with no positives or no negatives contributes 0.

    Args:
        logits: Iterable of per-sample 1-D score tensors.
        labels: Iterable of per-sample 0/1 label tensors aligned with
            ``logits``.

    Returns:
        The mean per-sample loss (a 0-d tensor).
    """
    loss = 0
    for pred, label in zip(logits, labels):
        pos_logits = pred[label == 1]
        neg_logits = pred[label == 0]

        # Every (negative, positive) difference, flattened
        diff = (neg_logits.unsqueeze(-1) - pos_logits).flatten()
        # log(1 + sum(exp(diff))) == logsumexp([0, *diff]); the logsumexp
        # form cannot overflow for large logit gaps, unlike exp().sum()
        zero = diff.new_zeros(1)
        loss = loss + torch.logsumexp(torch.cat([zero, diff]), dim=0)

    return loss / len(logits)


@register_model
class ProTrekTrimodalModel(AbstractModel):
    """ProTrek contrastive model over protein sequence, text and (optionally)
    structure.

    Each modality is encoded into a shared ``repr_dim`` space. Training uses a
    cross-entropy (InfoNCE-style) loss between every ordered pair of
    modalities, scaled by a learnable temperature, optionally plus masked
    language modeling losses. Validation/test run retrieval evaluation (MAP,
    MRR, ROC-AUC) on the eval split and, optionally, MLM accuracy.
    """

    def __init__(
        self,
        protein_config: str,
        text_config: str,
        structure_config: str = None,
        repr_dim: int = 1024,
        temperature: float = 0.07,
        load_protein_pretrained: bool = True,
        load_text_pretrained: bool = True,
        use_mlm_loss: bool = False,
        use_zlpr_loss: bool = False,
        use_saprot: bool = False,
        gradient_checkpointing: bool = False,
        **kwargs,
    ):
        """
        Args:
            protein_config: Config passed to ``ProteinEncoder``.
            text_config: Config passed to ``TextEncoder``.
            structure_config: Config passed to ``StructureEncoder``. When
                ``None`` the model only uses protein and text.
            repr_dim: Dimension of the shared representation space.
            temperature: Initial value of the learnable temperature.
            load_protein_pretrained: Forwarded to ``ProteinEncoder``.
            load_text_pretrained: Forwarded to ``TextEncoder``.
            use_mlm_loss: Add masked-language-modeling losses and metrics.
            use_zlpr_loss: Stored only; not read by the current loss.
            use_saprot: Interleave amino-acid and foldseek characters when
                building protein inputs for evaluation.
            gradient_checkpointing: Forwarded to the protein and text encoders.
            **kwargs: Forwarded to ``AbstractModel``.
        """
        # Set before super().__init__(): presumably AbstractModel builds the
        # model and metrics there and needs these attributes (TODO confirm).
        self.protein_config = protein_config
        self.structure_config = structure_config
        self.text_config = text_config
        self.repr_dim = repr_dim
        self.temperature = temperature
        self.load_protein_pretrained = load_protein_pretrained
        self.load_text_pretrained = load_text_pretrained
        self.use_mlm_loss = use_mlm_loss
        self.use_zlpr_loss = use_zlpr_loss
        self.use_saprot = use_saprot
        self.gradient_checkpointing = gradient_checkpointing
        super().__init__(**kwargs)

    def initialize_metrics(self, stage: str) -> dict:
        """Return the accuracy metrics tracked for ``stage``, keyed by name."""
        return_dict = {
            f"{stage}_protein_text_acc": torchmetrics.Accuracy(),
            f"{stage}_text_protein_acc": torchmetrics.Accuracy(),
        }

        if self.use_mlm_loss:
            return_dict[f"{stage}_protein_mask_acc"] = torchmetrics.Accuracy(
                ignore_index=-1
            )
            if self.structure_config is not None:
                return_dict[f"{stage}_structure_mask_acc"] = (
                    torchmetrics.Accuracy(ignore_index=-1)
                )

        if self.structure_config is not None:
            return_dict[f"{stage}_structure_protein_acc"] = (
                torchmetrics.Accuracy()
            )
            return_dict[f"{stage}_structure_text_acc"] = torchmetrics.Accuracy()
            return_dict[f"{stage}_text_structure_acc"] = torchmetrics.Accuracy()
            return_dict[f"{stage}_protein_structure_acc"] = (
                torchmetrics.Accuracy()
            )

        return return_dict

    def initialize_model(self):
        """Build the encoders and the learnable temperature parameter."""
        # Initialize encoders
        self.protein_encoder = ProteinEncoder(
            self.protein_config,
            self.repr_dim,
            self.load_protein_pretrained,
            self.gradient_checkpointing,
        )

        self.text_encoder = TextEncoder(
            self.text_config,
            self.repr_dim,
            self.load_text_pretrained,
            self.gradient_checkpointing,
        )

        # Learnable temperature (replaces the float stored in __init__)
        self.temperature = torch.nn.Parameter(torch.tensor(self.temperature))

        # self.model is used for saving and loading
        self.model = torch.nn.ParameterList(
            [self.temperature, self.protein_encoder, self.text_encoder]
        )

        # If the structure encoder is specified
        if self.structure_config is not None:
            self.structure_encoder = StructureEncoder(
                self.structure_config, self.repr_dim
            )
            self.model.append(self.structure_encoder)

    def get_text_repr(
        self, texts: list, batch_size: int = 64, verbose: bool = False
    ) -> torch.Tensor:
        """Encode ``texts`` with the text encoder."""
        return self.text_encoder.get_repr(texts, batch_size, verbose)

    def get_structure_repr(
        self, proteins: list, batch_size: int = 64, verbose: bool = False
    ) -> torch.Tensor:
        """Encode ``proteins`` with the structure encoder."""
        return self.structure_encoder.get_repr(proteins, batch_size, verbose)

    def get_protein_repr(
        self, proteins: list, batch_size: int = 64, verbose: bool = False
    ) -> torch.Tensor:
        """Encode ``proteins`` with the protein (sequence) encoder."""
        return self.protein_encoder.get_repr(proteins, batch_size, verbose)

    def forward(
        self,
        protein_inputs: dict,
        text_inputs: dict,
        structure_inputs: dict = None,
    ):
        """
        Args:
            protein_inputs: A dictionary for protein encoder
            text_inputs   : A dictionary for text encoder
            structure_inputs: A dictionary for structure encoder (only used
                when a structure encoder is configured)

        Returns:
            ``[text_repr, protein_repr, protein_mask_logits]``, extended with
            ``[structure_repr, structure_mask_logits]`` when a structure
            encoder is configured.
        """
        protein_repr, protein_mask_logits = self.protein_encoder(
            protein_inputs, self.use_mlm_loss
        )
        text_repr = self.text_encoder(text_inputs)

        outputs = [text_repr, protein_repr, protein_mask_logits]

        if self.structure_config is not None:
            structure_repr, structure_mask_logits = self.structure_encoder(
                structure_inputs, self.use_mlm_loss
            )
            outputs += [structure_repr, structure_mask_logits]

        return outputs

    def loss_func(self, stage: str, outputs, labels):
        """Compute the averaged contrastive (+ optional MLM) loss.

        Also updates the stage's accuracy metrics; for ``stage == "train"`` the
        metrics and loss are logged and the train metrics are reset.

        Args:
            stage: Metric stage name (e.g. ``"train"``).
            outputs: The list returned by ``forward``.
            labels: Dict read only when ``use_mlm_loss`` is set:
                ``"seq_labels"`` and, with a structure encoder,
                ``"struc_labels"`` (-1 marks ignored positions).

        Returns:
            The mean of all individual loss terms.
        """
        has_structure = self.structure_config is not None
        if has_structure:
            (
                text_repr,
                protein_repr,
                protein_mask_logits,
                structure_repr,
                structure_mask_logits,
            ) = outputs
        else:
            text_repr, protein_repr, protein_mask_logits = outputs

        device = text_repr.device

        # L2-normalized representations of this process's batch
        reprs = {
            "protein": normalize(protein_repr, dim=-1),
            "text": normalize(text_repr, dim=-1),
        }
        mask_logits = {"protein": protein_mask_logits}
        if has_structure:
            reprs["structure"] = normalize(structure_repr, dim=-1)
            mask_logits["structure"] = structure_mask_logits

        # Gather candidates from all GPUs (same order on every rank). They are
        # detached, so gradients only flow through the local query side.
        all_reprs = {
            name: self.all_gather(r).view(-1, r.shape[-1]).detach()
            for name, r in reprs.items()
        }

        # The positive of local row j sits at row rank * bs + j of the gathered
        # candidates; assumes every rank has the same batch size.
        rank = dist.get_rank()
        bs = text_repr.shape[0]
        bs_labels = torch.arange(rank * bs, rank * bs + bs, device=device)

        if has_structure:
            pairs = {
                "protein": ["structure", "text"],
                "structure": ["protein", "text"],
                "text": ["protein", "structure"],
            }
        else:
            pairs = {"protein": ["text"], "text": ["protein"]}

        loss_list = []
        for k, values in pairs.items():
            for v in values:
                # Only the current batch is used as queries
                sim = torch.matmul(reprs[k], all_reprs[v].T).div(
                    self.temperature
                )
                loss_list.append(cross_entropy(sim, bs_labels))
                self.metrics[stage][f"{stage}_{k}_{v}_acc"].update(
                    sim.detach(), bs_labels
                )

        # Masked language modeling loss
        if self.use_mlm_loss:
            k_label = [("protein", labels["seq_labels"])]
            if has_structure:
                k_label.append(("structure", labels["struc_labels"]))

            for k, label in k_label:
                logits = mask_logits[k]
                # merge the first and second dimension of logits
                logits = logits.view(-1, logits.shape[-1])
                label = label.flatten().to(device)
                loss_list.append(cross_entropy(logits, label, ignore_index=-1))
                self.metrics[stage][f"{stage}_{k}_mask_acc"].update(
                    logits.detach(), label
                )

        loss = sum(loss_list) / len(loss_list)

        if stage == "train":
            log_dict = self.get_log_dict("train")
            log_dict["train_loss"] = loss
            self.log_info(log_dict)

            # Reset train metrics
            self.reset_metrics("train")

        return loss

    def padded_gather(self, tensor: torch.Tensor) -> torch.Tensor:
        """Gather tensors from all GPUs, allowing different sizes along dim 0.

        Every process pads its tensor along dim 0 to the largest size across
        processes, all processes gather, and each process's padding rows are
        then removed. The result is the rank-ordered concatenation of the
        original tensors, whichever ranks were short.
        """
        size = tensor.shape[0]
        all_sizes = self.all_gather(torch.tensor(size, device=tensor.device))
        # flatten() also covers a 0-d result when only one process exists
        all_sizes = [int(s) for s in all_sizes.flatten().tolist()]
        max_size = max(all_sizes)

        # Pad the tensor so every process gathers the same shape
        if size != max_size:
            tmp = torch.zeros(
                (max_size, *tensor.shape[1:]),
                dtype=tensor.dtype,
                device=tensor.device,
            )
            tmp[:size] = tensor
            tensor = tmp

        gathered = self.all_gather(tensor).reshape(
            len(all_sizes), max_size, *tensor.shape[1:]
        )
        return torch.cat(
            [chunk[:n] for chunk, n in zip(gathered, all_sizes)], dim=0
        )

    @staticmethod
    def _split_for_rank(items: list) -> list:
        """Return this process's contiguous slice of ``items``.

        Slices have ``ceil(len(items) / world_size)`` elements, so trailing
        ranks may receive fewer (or zero) items.
        """
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        span = math.ceil(len(items) / world_size)
        return items[rank * span : (rank + 1) * span]

    def _get_protein_sequences(self) -> list:
        """Return the eval protein inputs in ``uniprot2label`` order.

        With ``use_saprot`` each amino-acid character is followed by the
        foldseek character at the same position; otherwise the plain amino-acid
        sequence is used.

        Raises:
            ValueError: If a sequence and its foldseek string differ in length.
        """
        if not self.use_saprot:
            return [sub_dict["seq"] for sub_dict in self.uniprot2label.values()]

        proteins = []
        for uniprot_id, sub_dict in self.uniprot2label.items():
            aa_seq = sub_dict["seq"]
            foldseek_seq = sub_dict["foldseek"]
            if len(aa_seq) != len(foldseek_seq):
                raise ValueError(
                    f"{uniprot_id}: sequence length {len(aa_seq)} != "
                    f"foldseek length {len(foldseek_seq)}"
                )
            proteins.append(
                "".join(a + b for a, b in zip(aa_seq, foldseek_seq))
            )
        return proteins

    def _build_faiss_index(self, encoder, proteins: list):
        """Encode this process's shard of ``proteins`` with ``encoder``, gather
        every shard, and return an inner-product faiss index whose rows follow
        the order of ``proteins``."""
        # Display the progress bar on the local rank 0 process
        verbose = self.trainer.local_rank == 0
        sub_repr = encoder.get_repr(
            self._split_for_rank(proteins), batch_size=1, verbose=verbose
        )
        all_repr = self.padded_gather(sub_repr)

        index = faiss.IndexFlatIP(all_repr.shape[-1])
        index.add(all_repr.cpu().numpy())
        return index

    def _get_protein_indices(self):
        """Build the faiss index of protein (sequence) representations."""
        return self._build_faiss_index(
            self.protein_encoder, self._get_protein_sequences()
        )

    def _get_structure_indices(self):
        """Build the faiss index of structure (foldseek) representations."""
        proteins = [
            sub_dict["foldseek"] for sub_dict in self.uniprot2label.values()
        ]
        return self._build_faiss_index(self.structure_encoder, proteins)

    def _get_text_indices(self) -> dict:
        """Build one inner-product faiss index per ``label2text`` subsection.

        Only the first text of each label is encoded, and its representation
        is L2-normalized. The ``"Total"`` index stacks the per-subsection
        embeddings referenced by ``label2text["Total"]`` values, which have the
        form ``"<subsection>|<row index>"``.
        """
        # Display the progress bar on the local rank 0 process
        verbose = self.trainer.local_rank == 0
        subsections = self.label2text.keys()
        if verbose:
            subsections = tqdm(subsections, desc="Get text representations")

        text_embeddings = {}
        for subsection in subsections:
            if subsection == "Total":
                continue

            # Only use the first text for efficiency
            texts = [
                text_list[0:1]
                for text_list in self.label2text[subsection].values()
            ]

            embeddings = []
            for text_list in self._split_for_rank(texts):
                text_repr = self.text_encoder.get_repr(text_list)
                mean_repr = text_repr.mean(dim=0, keepdim=True)
                embeddings.append(normalize(mean_repr, dim=-1))

            if embeddings:
                embeddings = torch.cat(embeddings, dim=0)
            else:
                # This rank got no texts; still take part in the gather
                embeddings = torch.zeros(
                    0, self.repr_dim, dtype=self.dtype, device=self.device
                )

            text_embeddings[subsection] = self.padded_gather(embeddings)

        # Aggregate text embeddings for global retrieval
        total_embeddings = []
        for idx in self.label2text["Total"].values():
            subsection, i = idx.split("|")
            total_embeddings.append(text_embeddings[subsection][int(i)])
        text_embeddings["Total"] = torch.stack(total_embeddings)

        # Construct faiss index
        text_indices = {}
        for subsection, text_repr in text_embeddings.items():
            index = faiss.IndexFlatIP(text_repr.shape[-1])
            index.add(text_repr.cpu().numpy())
            text_indices[subsection] = index

        return text_indices

    @staticmethod
    def _ranking_metrics(sim_scores, rank_inds, positives) -> list:
        """Compute ``[AP, reciprocal rank, ROC-AUC]`` for one query.

        Args:
            sim_scores: Similarity to every candidate, in the order returned
                by the faiss search.
            rank_inds: Candidate ids in the same order as ``sim_scores``.
            positives: Ids of the relevant candidates.
        """
        positives = set(positives)
        # 1-based positions of the relevant candidates in the ranking
        ranks = np.array(
            [i + 1 for i, rk in enumerate(rank_inds) if rk in positives]
        )

        # Average Precision
        ap = np.mean([(i + 1) / rank for i, rank in enumerate(ranks)])
        # Reciprocal rank of the best-ranked positive
        mrr = 1 / ranks[0]

        true_labels = np.zeros_like(sim_scores)
        true_labels[ranks - 1] = 1
        n_pos = true_labels.sum()
        if n_pos == 0 or n_pos == true_labels.shape[0]:
            # ROC-AUC is undefined with a single class; report 0
            auc = 0
        else:
            auc = roc_auc_score(true_labels, sim_scores)

        return [ap, mrr, auc]

    def _run_ranking_jobs(self, inputs: list, do, message: str):
        """Run ``do`` over this process's shard of ``inputs`` and gather results.

        ``do(process_id, idx, row, writer)`` must write one JSON
        ``[ap, mrr, auc]`` line per row. Returns a ``(len(inputs), 3)`` float
        tensor concatenated over processes in rank order.
        NOTE(review): alignment with ``inputs`` assumes ``mpr.run()`` returns
        results in input order.
        """
        rank = dist.get_rank()
        sub_inputs = self._split_for_rank(inputs)

        # Display the progress bar on the local rank 0 process
        verbose = self.trainer.local_rank == 0
        if verbose:
            print(message)

        # Time stamp + rank keep file names unique across runs and processes
        tmp_path = os.path.join(
            tempfile.gettempdir(), f"{time.time()}_{rank}.tsv"
        )
        try:
            mpr = MultipleProcessRunnerSimplifier(
                sub_inputs,
                do,
                save_path=tmp_path,
                n_process=8,
                verbose=verbose,
                return_results=True,
            )
            outputs = mpr.run()
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

        tensor_outputs = torch.tensor(
            [[float(x) for x in json.loads(line)] for line in outputs],
            dtype=torch.float32,
            device=self.device,
        ).reshape(-1, 3)  # keeps (0, 3) on ranks with no inputs for gather
        return self.padded_gather(tensor_outputs)

    @staticmethod
    def _summarize_ranking(
        prefix: str, subsections, input_subsections: list, tensor_outputs
    ) -> dict:
        """Average per-query metrics per subsection and per annotation level.

        Args:
            prefix: Key prefix, e.g. ``"Sequence2Text"``.
            subsections: All subsections that have at least one query; must
                include ``"Total"``.
            input_subsections: Subsection of each query, aligned with
                ``tensor_outputs`` rows.
            tensor_outputs: ``(n_queries, 3)`` tensor of ``[ap, mrr, auc]``.
        """
        avg_results = {s: {"map": [], "mrr": [], "auc": []} for s in subsections}
        for subsection, (ap, mrr, auc) in zip(
            input_subsections, tensor_outputs.tolist()
        ):
            avg_results[subsection]["map"].append(ap)
            avg_results[subsection]["mrr"].append(mrr)
            avg_results[subsection]["auc"].append(auc)

        results = {
            f"{prefix}_Total_mrr": np.mean(avg_results["Total"]["mrr"]),
            f"{prefix}_Total_map": np.mean(avg_results["Total"]["map"]),
            f"{prefix}_Total_auc": np.mean(avg_results["Total"]["auc"]),
        }

        # Macro-average over the subsections belonging to each level
        for level, level_subsections in [
            ("residue-level", residue_level),
            ("sequence-level", sequence_level),
            ("all", residue_level | sequence_level),
        ]:
            present = [s for s in level_subsections if s in avg_results]
            for metric in ("mrr", "map", "auc"):
                results[f"{prefix}_{level}_{metric}"] = np.mean(
                    [np.mean(avg_results[s][metric]) for s in present]
                )

        return results

    def _protein2text(self, modality: str, protein_indices, text_indices: dict):
        """Evaluate protein -> text retrieval for Swiss-Prot proteins.

        Each query is a (subsection, protein) pair; its positives are the text
        ids listed under that subsection in ``uniprot2label``.
        """
        ranking_metrics = self._ranking_metrics

        def do(process_id, idx, row, writer):
            subsection, uniprot_id, prob_idx, label = row

            # Retrieve ranking results over every text in the subsection
            p_embedding = protein_indices.reconstruct(prob_idx).reshape(1, -1)
            text_inds = text_indices[subsection]
            sim_scores, rank_inds = text_inds.search(
                p_embedding, text_inds.ntotal
            )
            metrics = ranking_metrics(sim_scores[0], rank_inds[0], label)
            writer.write(json.dumps(metrics) + "\n")

        inputs = []
        swissprot_subsections = set()
        uniprot_items = list(self.uniprot2label.items())
        # Dict iteration keeps the same order on every rank
        for subsection in text_indices:
            for i, (uniprot_id, labels) in enumerate(uniprot_items):
                if uniprot_id in self.swissprot_ids and subsection in labels:
                    swissprot_subsections.add(subsection)
                    inputs.append(
                        (subsection, uniprot_id, i, labels[subsection])
                    )

        # Same seed on every rank -> identical order before sharding
        random.seed(20000812)
        random.shuffle(inputs)

        tensor_outputs = self._run_ranking_jobs(
            inputs, do, "Evaluating on each subsection..."
        )
        return self._summarize_ranking(
            f"{modality}2Text",
            swissprot_subsections,
            [row[0] for row in inputs],
            tensor_outputs,
        )

    def _text2protein(self, modality: str, protein_indices, text_indices: dict):
        """Evaluate text -> protein retrieval for texts of Swiss-Prot proteins.

        Each query is a (subsection, text id) pair; its positives are the
        positions in ``uniprot2label`` of the proteins carrying that text.
        """
        ranking_metrics = self._ranking_metrics

        def do(process_id, idx, row, writer):
            subsection, text_id, label = row

            # Retrieve ranking results over every protein
            t_embedding = (
                text_indices[subsection].reconstruct(text_id).reshape(1, -1)
            )
            sim_scores, rank_inds = protein_indices.search(
                t_embedding, protein_indices.ntotal
            )
            metrics = ranking_metrics(sim_scores[0], rank_inds[0], label)
            writer.write(json.dumps(metrics) + "\n")

        # subsection -> text id -> positions of proteins annotated with it
        text2label = {}
        for i, (uniprot_id, subsections) in enumerate(
            self.uniprot2label.items()
        ):
            # Only evaluate the texts in Swiss-Prot
            if uniprot_id not in self.swissprot_ids:
                continue

            for subsection, text_ids in subsections.items():
                if subsection in ("seq", "foldseek"):
                    continue
                sub_map = text2label.setdefault(subsection, {})
                for text_id in text_ids:
                    sub_map.setdefault(text_id, []).append(i)

        # Iterate the dict rather than a set of subsection names: str hashing
        # is randomized per process, so set order can differ between ranks,
        # which would make each rank shard a different permutation of inputs.
        inputs = [
            (subsection, text_id, label)
            for subsection, text_map in text2label.items()
            for text_id, label in text_map.items()
        ]

        # Same seed on every rank -> identical order before sharding
        random.seed(20000812)
        random.shuffle(inputs)

        tensor_outputs = self._run_ranking_jobs(
            inputs, do, "Evaluating on each text..."
        )
        return self._summarize_ranking(
            f"Text2{modality}",
            text2label.keys(),
            [row[0] for row in inputs],
            tensor_outputs,
        )

    def retrieval_eval(self) -> dict:
        """Run sequence<->text retrieval evaluation and return all metrics."""
        protein_indices = self._get_protein_indices()
        text_indices = self._get_text_indices()

        # TODO: structure-based retrieval is not evaluated here.
        results = {}
        # Retrieve texts for each protein
        results.update(
            self._protein2text("Sequence", protein_indices, text_indices)
        )
        # Retrieve proteins for each text
        results.update(
            self._text2protein("Sequence", protein_indices, text_indices)
        )
        return results

    def _apply_bert_mask(self, tokens, tokenizer, mask_ratio):
        """Apply BERT-style masking to ``tokens``.

        Each token is selected with probability ``mask_ratio``. A selected
        token becomes the mask token 80% of the time (with ``use_saprot``:
        ``"#"`` followed by the token's last character), a random vocabulary
        token 10% of the time, and is kept otherwise. Sampling is retried until
        at least one token is selected. If nothing can be selected (no tokens
        or ``mask_ratio <= 0``), the tokens are returned unmasked.

        Returns:
            ``(masked_tokens, labels)`` where ``labels`` has
            ``len(tokens) + 2`` entries, offset by one (presumably for the
            tokenizer's special tokens — TODO confirm), holding the original
            token id at selected positions and -1 elsewhere.
        """
        if not tokens or mask_ratio <= 0:
            # Otherwise the retry loop below would never terminate
            return copy.copy(tokens), torch.full(
                (len(tokens) + 2,), -1, dtype=torch.long
            )

        vocab = list(tokenizer.get_vocab().keys())
        while True:
            masked_tokens = copy.copy(tokens)
            labels = torch.full((len(tokens) + 2,), -1, dtype=torch.long)

            for i, token in enumerate(tokens):
                prob = random.random()
                if prob >= mask_ratio:
                    continue

                prob /= mask_ratio
                labels[i + 1] = tokenizer.convert_tokens_to_ids(token)

                if prob < 0.8:
                    # 80% chance to change to the mask token
                    if self.use_saprot:
                        token = "#" + token[-1]
                    else:
                        token = tokenizer.mask_token
                elif prob < 0.9:
                    # 10% chance to change to a random token
                    token = random.choice(vocab)
                # Remaining 10%: keep the current token

                masked_tokens[i] = token

            # Check if there is at least one masked token
            if (labels != -1).any():
                return masked_tokens, labels

    def mlm_eval(self) -> float:
        """Return the protein encoder's masked-token accuracy on the eval set.

        Each process masks (ratio 0.15) and scores its shard of sequences one
        at a time; correct/total counts are summed over processes.
        """
        sub_proteins = self._split_for_rank(self._get_protein_sequences())

        # Display the progress bar on the local rank 0 process
        if self.trainer.local_rank == 0:
            iterator = tqdm(sub_proteins, desc="Computing mlm...")
        else:
            iterator = sub_proteins

        tokenizer = self.protein_encoder.tokenizer
        total = torch.tensor([0], dtype=torch.long, device=self.device)
        correct = torch.tensor([0], dtype=torch.long, device=self.device)
        for seq in iterator:
            tokens = tokenizer.tokenize(seq)
            masked_tokens, labels = self._apply_bert_mask(
                tokens, tokenizer, 0.15
            )

            inputs = tokenizer(" ".join(masked_tokens), return_tensors="pt")
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            _, logits = self.protein_encoder(inputs, get_mask_logits=True)

            logits = logits.squeeze(0)
            labels = labels.to(self.device)

            selector = labels != -1
            preds = logits.argmax(dim=-1)[selector]

            total += len(preds)
            correct += (preds == labels[selector]).sum()

        # Gather all results
        total = self.padded_gather(total).sum()
        correct = self.padded_gather(correct).sum()

        acc = correct / total
        return acc.cpu().item()

    def _load_eval_data(self, stage):
        """Load the ``stage`` eval annotations from the datamodule's LMDB dir.

        Reads ``uniprot2label.json``, ``label2text.json`` and
        ``swissprot_ids.tsv`` from ``datamodule.<stage>_lmdb``.
        """
        lmdb_dir = getattr(self.trainer.datamodule, f"{stage}_lmdb")
        uniprot2label_path = os.path.join(lmdb_dir, "uniprot2label.json")
        label2text_path = os.path.join(lmdb_dir, "label2text.json")
        swissprot_id_path = os.path.join(lmdb_dir, "swissprot_ids.tsv")

        with open(uniprot2label_path, "r") as f:
            self.uniprot2label = json.load(f)
        with open(label2text_path, "r") as f:
            self.label2text = json.load(f)
        self.swissprot_ids = set(
            pd.read_csv(swissprot_id_path, sep="\t", header=None)
            .values.flatten()
            .tolist()
        )
        # Not read in this class; kept in case other code relies on it
        self.k = 3

    def on_test_start(self):
        """Run retrieval (and optional MLM) evaluation on the test split."""
        self._load_eval_data("test")

        log_dict = self.retrieval_eval()
        log_dict = {"test_" + k: v for k, v in log_dict.items()}
        if self.use_mlm_loss:
            log_dict["test_mask_acc"] = self.mlm_eval()
        self.log_info(log_dict)
        print(log_dict)

    def on_validation_start(self):
        """Run retrieval (and optional MLM) evaluation on the valid split."""
        # Clear the cache
        torch.cuda.empty_cache()

        self._load_eval_data("valid")

        log_dict = self.retrieval_eval()
        log_dict = {"valid_" + k: v for k, v in log_dict.items()}
        if self.use_mlm_loss:
            log_dict["valid_mask_acc"] = self.mlm_eval()
        self.log_info(log_dict)

        self.check_save_condition(self.step, mode="max")

    def test_step(self, batch, batch_idx):
        """No-op: evaluation is done entirely in ``on_test_start``."""
        return

    def validation_step(self, batch, batch_idx):
        """No-op: evaluation is done entirely in ``on_validation_start``."""
        return

    def on_train_epoch_end(self):
        """Re-sample the training subset when a fixed dataset size is set."""
        super().on_train_epoch_end()
        if self.trainer.datamodule.train_dataset.fixed_dataset_num is not None:
            self.trainer.datamodule.train_dataset.sample_subset()
