# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from IPython import embed as debug_embedded
import logging
import os
from collections.abc import Iterable
from sklearn.metrics import roc_auc_score
from xmlrpc.client import Boolean

import numpy as np
import torch
from tqdm import tqdm
from unicore import checkpoint_utils
import unicore
from unicore.data import (AppendTokenDataset, Dictionary, EpochShuffleDataset,
                          FromNumpyDataset, NestedDictionaryDataset,
                          PrependTokenDataset, RawArrayDataset,LMDBDataset, RawLabelDataset,
                          RightPadDataset, RightPadDataset2D, TokenizeDataset,SortDataset,data_utils)
from unicore.tasks import UnicoreTask, register_task
from unimol.data import (AffinityDataset, CroppingPocketDockingPoseDataset,CroppingPocketDockingPoseTestDataset,
                         CrossDistanceDataset, DistanceDataset,
                         EdgeTypeDataset, KeyDataset, LengthDataset,
                         NormalizeDataset, NormalizeDockingPoseDataset,
                         PrependAndAppend2DDataset, RemoveHydrogenDataset,
                         RemoveHydrogenPocketDataset, RightPadDatasetCoord,
                         RightPadDatasetCross2D, TTADockingPoseDataset, AffinityTestDataset, AffinityValidDataset, AffinityMolDataset, AffinityPocketDataset, ResamplingDataset, PocketFTDataset
                         ,FromStrLabelDataset,ConformerSamplePocketFinetuneDataset, RemoveHydrogenResiduePocketDataset, CroppingResiduePocketDataset)
#from skchem.metrics import bedroc_score
from rdkit.ML.Scoring.Scoring import CalcBEDROC, CalcAUC, CalcEnrichment
from sklearn.metrics import roc_curve
logger = logging.getLogger(__name__)
# Per-target ``mean`` / ``std`` constants, keyed by the values accepted by the
# ``--fpocket-score`` option. ``PocketFT.__init__`` copies the entry selected
# by ``args.fpocket_score`` into ``self.mean`` / ``self.std``.
# NOTE(review): presumably the mean and standard deviation of each Fpocket
# score over the training data -- confirm where these numbers come from.
task_metainfo = {
    "Score": {
        "mean": -0.02113608960384876,
        "std": 0.14467607204629246,
    },
    "Druggability Score": {
        "mean": 0.04279187401338044,
        "std": 0.1338187819653573,
    },
    "Total SASA": {
        "mean": 118.7343246335413,
        "std": 59.82260887999069,
    },
    "Hydrophobicity score": {
        "mean": 16.824823092535517,
        "std": 18.16340833552264,
    },
}

def re_new(y_true, y_score, ratio):
    """Compute the ROC enrichment (TPR / FPR) at a false-positive budget.

    Samples are visited in order of decreasing score. The walk stops as soon
    as the number of false positives reaches ``ratio * n_negatives``; the
    result is ``(tp / p) / (fp / n)`` for the counts accumulated so far.

    Args:
        y_true: sequence of labels; entries equal to ``1`` are positives,
            everything else is counted as a negative.
        y_score: sequence of scores, higher meaning "more likely positive".
        ratio: fraction of the negatives allowed as false positives.

    Returns:
        The enrichment value, or ``nan`` when it is undefined because the
        input contains no positives or no negatives (including empty input).
    """
    fp = 0
    tp = 0
    p = sum(y_true)
    n = len(y_true) - p
    num = ratio * n
    # Highest score first.
    for index in np.argsort(y_score)[::-1]:
        if y_true[index] == 1:
            tp += 1
        else:
            fp += 1
            if fp >= num:
                break
    # Without positives (p == 0) or without negatives (fp stays 0) the
    # ratio is 0/0; report NaN instead of raising ZeroDivisionError.
    if p == 0 or fp == 0:
        return float("nan")
    return (tp * n) / (p * fp)


def calc_re(y_true, y_score, ratio_list):
    """Compute the ROC enrichment for several false-positive ratios.

    Args:
        y_true: sequence of labels (``1`` marks a positive).
        y_score: sequence of scores, higher meaning "more likely positive".
        ratio_list: iterable of false-positive ratios to evaluate.

    Returns:
        dict mapping ``str(ratio)`` to the value returned by ``re_new`` for
        that ratio.
    """
    # The ROC curve that used to be computed here was never read, so only
    # the per-ratio enrichment values are produced.
    return {str(ratio): re_new(y_true, y_score, ratio) for ratio in ratio_list}

def cal_metrics(y_true, y_score, alpha):
    """Compute AUC, BEDROC, enrichment factors and ROC enrichment.

    Parameters:
    - y_true: true binary labels (0 or 1), one per sample
    - y_score: predicted scores or probabilities, one per sample
    - alpha: accepted for interface compatibility.
      NOTE(review): this value is not used; BEDROC is always computed with
      the fixed value 80.5 below -- confirm whether that is intended.

    Returns:
    - tuple ``(auc, bedroc, ef, re_list)`` where ``ef`` and ``re_list`` are
      dicts keyed by the fractions "0.005", "0.01", "0.02" and "0.05"
    """
    fractions = [0.005, 0.01, 0.02, 0.05]

    # Build an (N, 2) table of [score, label] rows sorted by decreasing
    # score; the RDKit scoring helpers read the label from column 1.
    ranked = np.concatenate(
        (np.expand_dims(y_score, axis=1), np.expand_dims(y_true, axis=1)),
        axis=1,
    )
    ranked = ranked[ranked[:, 0].argsort()[::-1]]

    bedroc = CalcBEDROC(ranked, 1, 80.5)
    auc = CalcAUC(ranked, 1)
    ef_list = CalcEnrichment(ranked, 1, fractions)
    ef = {str(fraction): value for fraction, value in zip(fractions, ef_list)}

    # Pass the caller's 1-D labels (not the column-vector copy used above)
    # so the ROC-enrichment values are scalars rather than 1-element arrays.
    re_list = calc_re(y_true, y_score, fractions)
    return auc, bedroc, ef, re_list



@register_task("pocket_ft")
class PocketFT(UnicoreTask):
    """Task for training transformer auto-encoder models."""

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data",
            help="downstream data path",
        )
        parser.add_argument(
            "--finetune-mol-model",
            default=None,
            type=str,
            help="pretrained molecular model path",
        )
        parser.add_argument(
            "--finetune-pocket-model",
            default=None,
            type=str,
            help="pretrained pocket model path",
        )
        parser.add_argument(
            "--dist-threshold",
            type=float,
            default=6.0,
            help="threshold for the distance between the molecule and the pocket",
        )
        parser.add_argument(
            "--max-pocket-atoms",
            type=int,
            default=256,
            help="selected maximum number of atoms in a pocket",
        )
        parser.add_argument(
            "--test-model",
            default=False,
            type=Boolean,
            help="whether test model",
        )
        parser.add_argument(
            "--knn",
            default=False,
            type=Boolean,
            help="whether use knn",
        )
        parser.add_argument("--reg", action="store_true", help="regression task")
        parser.add_argument(
            "--fpocket-score",
            default="Druggability Score",
            help="Select one of the 4 Fpocket scores as the target",
            choices=[
                "Score",
                "Druggability Score",
                "Total SASA",
                "Hydrophobicity score",
            ],
        )

    def __init__(self, args, pocket_dictionary):
        """Store the pocket dictionary and per-target statistics.

        Args:
            args: parsed task arguments; ``seed`` and ``fpocket_score`` are
                read here.
            pocket_dictionary: token dictionary for pocket atoms; a special
                ``[MASK]`` symbol is added to it.
        """
        super().__init__(args)
        self.seed = args.seed
        self.pocket_dictionary = pocket_dictionary
        # add mask token
        self.pocket_mask_idx = pocket_dictionary.add_symbol("[MASK]", is_special=True)
        self.keys = None
        self.mol_reps = None

        # for regression task, pre-compute mean and std
        target_stats = task_metainfo.get(self.args.fpocket_score)
        if target_stats is None:
            self.mean = None
            self.std = None
        else:
            self.mean = target_stats["mean"]
            self.std = target_stats["std"]


    @classmethod
    def setup_task(cls, args, **kwargs):
        """Create the task after loading ``dict_pkt.txt`` from ``args.data``.

        Args:
            args: parsed task arguments; ``args.data`` is the directory that
                holds the pocket dictionary file.

        Returns:
            A new task instance built from ``args`` and the loaded dictionary.
        """
        dict_path = os.path.join(args.data, "dict_pkt.txt")
        pocket_dictionary = Dictionary.load(dict_path)
        num_types = len(pocket_dictionary)
        logger.info("pocket dictionary: {} types".format(num_types))
        return cls(args, pocket_dictionary)

    def _build_pocket_dataset(self, data_path, with_pdbid=False):
        """Build the nested pocket dataset for the LMDB file at ``data_path``.

        Shared by ``load_dataset`` and ``load_dataset_knn`` so the
        pre-processing pipeline is defined in exactly one place.

        Args:
            data_path (str): path of the LMDB file to read.
            with_pdbid (bool): when True, ``PocketFTDataset`` additionally
                receives the ``"pdbid"`` key and a trailing ``True`` flag
                (this is what training splits use).

        Returns:
            tuple: ``(nest_dataset, src_pocket_dataset)`` -- the nested
            dataset with ``net_input`` / ``target`` / ``pocket_name``
            entries, and the tokenized pocket dataset (callers use its
            length to build a shuffle order).
        """

        def prepend_and_append(dataset, pre_token, app_token):
            """Wrap ``dataset`` so each item gets a leading and trailing token."""
            dataset = PrependTokenDataset(dataset, pre_token)
            return AppendTokenDataset(dataset, app_token)

        raw_dataset = LMDBDataset(data_path)

        # Regression target: entry["target"][<selected Fpocket score>],
        # one of "Score", "Druggability Score", "Total SASA",
        # "Hydrophobicity score".
        tgt_dataset = KeyDataset(raw_dataset, "target")
        tgt_dataset = KeyDataset(tgt_dataset, self.args.fpocket_score)
        tgt_dataset = FromStrLabelDataset(tgt_dataset)

        extra_args = ("pdbid", True) if with_pdbid else ()
        dataset = PocketFTDataset(
            raw_dataset,
            self.args.seed,
            "atoms",
            "coordinates",
            "target",
            *extra_args,
        )
        poc_dataset = KeyDataset(dataset, "pocket")

        dataset = RemoveHydrogenPocketDataset(
            dataset,
            "pocket_atoms",
            "pocket_coordinates",
            "holo_pocket_coordinates",
            True,
            True,
        )
        dataset = CroppingPocketDockingPoseDataset(
            dataset,
            self.seed,
            "pocket_atoms",
            "pocket_coordinates",
            "holo_pocket_coordinates",
            self.args.max_pocket_atoms,
        )
        apo_dataset = NormalizeDataset(dataset, "pocket_coordinates")

        # Token branch: atom symbols -> ids, wrapped in BOS/EOS.
        src_pocket_dataset = KeyDataset(apo_dataset, "pocket_atoms")
        pocket_len_dataset = LengthDataset(src_pocket_dataset)
        src_pocket_dataset = TokenizeDataset(
            src_pocket_dataset,
            self.pocket_dictionary,
            max_seq_len=self.args.max_seq_len,
        )
        src_pocket_dataset = prepend_and_append(
            src_pocket_dataset,
            self.pocket_dictionary.bos(),
            self.pocket_dictionary.eos(),
        )
        pocket_edge_type = EdgeTypeDataset(
            src_pocket_dataset, len(self.pocket_dictionary)
        )

        # Coordinate branch: the distance dataset is built from the raw
        # coordinates before both get 0.0 entries for the BOS/EOS slots.
        coord_pocket_dataset = KeyDataset(apo_dataset, "pocket_coordinates")
        coord_pocket_dataset = FromNumpyDataset(coord_pocket_dataset)
        distance_pocket_dataset = DistanceDataset(coord_pocket_dataset)
        coord_pocket_dataset = prepend_and_append(coord_pocket_dataset, 0.0, 0.0)
        distance_pocket_dataset = PrependAndAppend2DDataset(
            distance_pocket_dataset, 0.0
        )

        nest_dataset = NestedDictionaryDataset(
            {
                "net_input": {
                    "pocket_src_tokens": RightPadDataset(
                        src_pocket_dataset,
                        pad_idx=self.pocket_dictionary.pad(),
                    ),
                    "pocket_src_distance": RightPadDataset2D(
                        distance_pocket_dataset,
                        pad_idx=0,
                    ),
                    "pocket_src_edge_type": RightPadDataset2D(
                        pocket_edge_type,
                        pad_idx=0,
                    ),
                    "pocket_src_coord": RightPadDatasetCoord(
                        coord_pocket_dataset,
                        pad_idx=0,
                    ),
                    "pocket_len": RawArrayDataset(pocket_len_dataset),
                },
                "target": {
                    "finetune_target": tgt_dataset,
                },
                "pocket_name": RawArrayDataset(poc_dataset),
            },
        )
        return nest_dataset, src_pocket_dataset

    def load_dataset(self, split, **kwargs):
        """Load a dataset split into ``self.datasets[split]``.

        The data is read from ``<args.data>/<split>.lmdb``. Splits whose name
        starts with ``"train"`` pass the extra ``"pdbid"`` arguments to
        ``PocketFTDataset``; the split named exactly ``"train"`` is also
        shuffled with a seeded permutation and wrapped in
        ``ResamplingDataset``.

        Args:
            split (str): name of the split (file stem of the LMDB file).
        """
        data_path = os.path.join(self.args.data, split + ".lmdb")
        nest_dataset, src_pocket_dataset = self._build_pocket_dataset(
            data_path, with_pdbid=split.startswith("train")
        )

        if split == "train":
            # Fixed seed so the shuffle order is reproducible across runs.
            with data_utils.numpy_seed(self.args.seed):
                shuffle = np.random.permutation(len(src_pocket_dataset))
            self.datasets[split] = ResamplingDataset(
                SortDataset(nest_dataset, sort_order=[shuffle])
            )
        else:
            self.datasets[split] = nest_dataset

    def load_dataset_knn(self, path, **kwargs):
        """Build and return the nested pocket dataset for an explicit path.

        Uses the same pipeline as the non-training branch of
        ``load_dataset`` but returns the dataset instead of registering it
        in ``self.datasets``.

        Args:
            path (str): path of the LMDB file to read.

        Returns:
            The nested dataset (``net_input`` / ``target`` / ``pocket_name``).
        """
        nest_dataset, _ = self._build_pocket_dataset(path)
        return nest_dataset

    

    

    

    def load_mols_dataset(self, data_path, **kwargs):
        """Build and return a nested molecule dataset from an LMDB file.

        The result has a ``net_input`` entry (padded ``mol_src_tokens``,
        ``mol_src_distance`` and ``mol_src_edge_type``) plus ``smi_name``
        and ``mol_len``.

        NOTE(review): this method reads ``self.dictionary``, which the
        visible ``__init__`` does not assign (only ``pocket_dictionary`` is
        set) -- confirm that it is provided elsewhere before relying on it.

        Args:
            data_path (str): path of the LMDB file to read.
        """

        def wrap_with_tokens(inner, first_token, last_token):
            """Add a leading and a trailing token to every item of ``inner``."""
            return AppendTokenDataset(
                PrependTokenDataset(inner, first_token), last_token
            )

        mol_data = AffinityMolDataset(
            LMDBDataset(data_path),
            self.args.seed,
            "atoms",
            "coordinates",
            False,
        )
        smi_dataset = KeyDataset(mol_data, "smi")

        mol_data = RemoveHydrogenDataset(mol_data, "atoms", "coordinates", True, True)
        normalized = NormalizeDataset(mol_data, "coordinates")

        # Token branch.
        atom_dataset = KeyDataset(normalized, "atoms")
        len_dataset = LengthDataset(atom_dataset)
        token_dataset = TokenizeDataset(
            atom_dataset, self.dictionary, max_seq_len=self.args.max_seq_len
        )
        coord_dataset = KeyDataset(normalized, "coordinates")
        token_dataset = wrap_with_tokens(
            token_dataset, self.dictionary.bos(), self.dictionary.eos()
        )
        edge_type = EdgeTypeDataset(token_dataset, len(self.dictionary))

        # Coordinate branch: distances come from the un-padded coordinates.
        coord_dataset = FromNumpyDataset(coord_dataset)
        distance_dataset = DistanceDataset(coord_dataset)
        # Built as in the pocket loaders, although the nested dataset below
        # does not include a coordinate entry.
        coord_dataset = wrap_with_tokens(coord_dataset, 0.0, 0.0)
        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)

        net_input = {
            "mol_src_tokens": RightPadDataset(
                token_dataset,
                pad_idx=self.dictionary.pad(),
            ),
            "mol_src_distance": RightPadDataset2D(
                distance_dataset,
                pad_idx=0,
            ),
            "mol_src_edge_type": RightPadDataset2D(
                edge_type,
                pad_idx=0,
            ),
        }
        return NestedDictionaryDataset(
            {
                "net_input": net_input,
                "smi_name": RawArrayDataset(smi_dataset),
                "mol_len": RawArrayDataset(len_dataset),
            },
        )
    
    

    def load_pockets_dataset(self, data_path, **kwargs):
        """Build and return a nested pocket dataset from an LMDB file.

        The result has a ``net_input`` entry (padded ``pocket_src_tokens``,
        ``pocket_src_distance``, ``pocket_src_edge_type`` and
        ``pocket_src_coord``) plus ``pocket_name`` and ``pocket_len``. No
        target entry is produced.

        Args:
            data_path (str): path of the LMDB file to read.
        """

        def wrap_with_tokens(inner, first_token, last_token):
            """Add a leading and a trailing token to every item of ``inner``."""
            return AppendTokenDataset(
                PrependTokenDataset(inner, first_token), last_token
            )

        pocket_data = AffinityPocketDataset(
            LMDBDataset(data_path),
            self.args.seed,
            "pocket_atoms",
            "pocket_coordinates",
            False,
            "pocket"
        )
        poc_dataset = KeyDataset(pocket_data, "pocket")

        pocket_data = RemoveHydrogenPocketDataset(
            pocket_data,
            "pocket_atoms",
            "pocket_coordinates",
            "holo_pocket_coordinates",
            True,
            True,
        )
        pocket_data = CroppingPocketDockingPoseTestDataset(
            pocket_data,
            self.seed,
            "pocket_atoms",
            "pocket_coordinates",
            "holo_pocket_coordinates",
            self.args.max_pocket_atoms,
        )
        normalized = NormalizeDataset(pocket_data, "pocket_coordinates")

        # Token branch: atom symbols -> ids, wrapped in BOS/EOS.
        atom_dataset = KeyDataset(normalized, "pocket_atoms")
        len_dataset = LengthDataset(atom_dataset)
        token_dataset = TokenizeDataset(
            atom_dataset,
            self.pocket_dictionary,
            max_seq_len=self.args.max_seq_len,
        )
        coord_dataset = KeyDataset(normalized, "pocket_coordinates")
        token_dataset = wrap_with_tokens(
            token_dataset,
            self.pocket_dictionary.bos(),
            self.pocket_dictionary.eos(),
        )
        edge_type_dataset = EdgeTypeDataset(
            token_dataset, len(self.pocket_dictionary)
        )

        # Coordinate branch: distances are derived before the 0.0 entries
        # for the BOS/EOS slots are added.
        coord_dataset = FromNumpyDataset(coord_dataset)
        distance_dataset = DistanceDataset(coord_dataset)
        coord_dataset = wrap_with_tokens(coord_dataset, 0.0, 0.0)
        distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0)

        net_input = {
            "pocket_src_tokens": RightPadDataset(
                token_dataset,
                pad_idx=self.pocket_dictionary.pad(),
            ),
            "pocket_src_distance": RightPadDataset2D(
                distance_dataset,
                pad_idx=0,
            ),
            "pocket_src_edge_type": RightPadDataset2D(
                edge_type_dataset,
                pad_idx=0,
            ),
            "pocket_src_coord": RightPadDatasetCoord(
                coord_dataset,
                pad_idx=0,
            ),
        }
        return NestedDictionaryDataset(
            {
                "net_input": net_input,
                "pocket_name": RawArrayDataset(poc_dataset),
                "pocket_len": RawArrayDataset(len_dataset),
            },
        )

    

    def build_model(self, args):
        """Build the model and optionally load pretrained pocket weights.

        When ``args.finetune_pocket_model`` is set, its weights are loaded
        into ``model.pocket_model`` with ``strict=False``. A path ending in
        ``pocket_model.pt`` is read with ``torch.load`` as a bare state
        dict; any other path is read as a checkpoint whose ``"model"`` entry
        holds the state dict.

        Args:
            args: parsed arguments used to build the model.

        Returns:
            The constructed model.
        """
        from unicore import models

        model = models.build_model(args, self)

        pretrained_path = args.finetune_pocket_model
        if pretrained_path is None:
            return model

        if pretrained_path.endswith("pocket_model.pt"):
            weights = torch.load(pretrained_path, map_location='cpu')
        else:
            checkpoint = checkpoint_utils.load_checkpoint_to_cpu(
                pretrained_path,
            )
            weights = checkpoint["model"]
        model.pocket_model.load_state_dict(weights, strict=False)
        return model

    def train_step(
        self, sample, model, loss, optimizer, update_num, ignore_grad=False
    ):
        """
        Do forward and backward, and return the loss as computed by *loss*
        for the given *model* and *sample*.

        Args:
            sample (dict): the mini-batch. The format is defined by the
                :class:`~unicore.data.UnicoreDataset`.
            model (~unicore.models.BaseUnicoreModel): the model
            loss (~unicore.losses.UnicoreLoss): the loss
            optimizer (~unicore.optim.UnicoreOptimizer): the optimizer
            update_num (int): the current update
            ignore_grad (bool): multiply loss by 0 if this is set to True

        Returns:
            tuple:
                - the loss
                - the sample size, which is used as the denominator for the
                  gradient
                - logging outputs to display while training
        """
        if self.args.knn == True:
            model.eval()
            self.test_knn(model)
            exit()

        
            
        model.train()
        model.set_num_updates(update_num)
        with torch.autograd.profiler.record_function("forward"):
            loss, sample_size, logging_output = loss(model, sample)
        if ignore_grad:
            loss *= 0
        with torch.autograd.profiler.record_function("backward"):
            optimizer.backward(loss)
        return loss, sample_size, logging_output
    
    def valid_step(self, sample, model, loss, test=False):
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = loss(model, sample)
        return loss, sample_size, logging_output

    
    

    def _encode_pocket_batch(self, model, sample):
        """Return the first-position encoder output of one batch as NumPy.

        Args:
            model: model exposing a ``pocket_model`` with ``embed_tokens``,
                ``gbf``, ``gbf_proj``, ``encoder`` and ``padding_idx``.
            sample (dict): collated batch whose ``net_input`` holds
                ``pocket_src_distance``, ``pocket_src_edge_type`` and
                ``pocket_src_tokens``.

        Returns:
            numpy.ndarray: ``encoder_output[:, 0, :]`` moved to the CPU.
        """
        net_input = sample["net_input"]
        dist = net_input["pocket_src_distance"]
        et = net_input["pocket_src_edge_type"]
        st = net_input["pocket_src_tokens"]

        pocket_model = model.pocket_model
        pocket_padding_mask = st.eq(pocket_model.padding_idx)
        pocket_x = pocket_model.embed_tokens(st)
        n_node = dist.size(-1)
        graph_attn_bias = pocket_model.gbf_proj(pocket_model.gbf(dist, et))
        # Move the last axis next to the batch axis, then fold the two
        # leading axes together so the bias is a stack of n_node x n_node
        # matrices.
        graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous()
        graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node)
        pocket_outputs = pocket_model.encoder(
            pocket_x, padding_mask=pocket_padding_mask, attn_mask=graph_attn_bias
        )
        return pocket_outputs[0][:, 0, :].detach().cpu().numpy()

    def _collect_pocket_reps(self, model, data_path, cache_path):
        """Load or compute pocket representations and labels for one dataset.

        If ``cache_path`` exists, the representations are read from it and
        only the labels are collected from the data loader. Otherwise every
        batch is moved to CUDA and encoded, and the stacked representations
        are written to ``cache_path``.

        Args:
            model: model passed on to ``_encode_pocket_batch``.
            data_path (str): LMDB path handed to ``load_dataset_knn``.
            cache_path (str): pickle file used to cache the representations.

        Returns:
            tuple: ``(representations, labels)``.
        """
        import pickle

        pocket_dataset = self.load_dataset_knn(data_path)
        pocket_data = torch.utils.data.DataLoader(
            pocket_dataset, batch_size=32, collate_fn=pocket_dataset.collater
        )
        labels = []

        if os.path.exists(cache_path):
            # NOTE: pickle.load can execute arbitrary code; only load cache
            # files that were written by this method.
            with open(cache_path, "rb") as f:
                reps = pickle.load(f)
            for sample in tqdm(pocket_data):
                labels.extend(sample["target"]["finetune_target"].detach().cpu().numpy())
            return reps, labels

        batch_reps = []
        # Inference only: skip autograd bookkeeping while encoding.
        with torch.no_grad():
            for sample in tqdm(pocket_data):
                sample = unicore.utils.move_to_cuda(sample)
                batch_reps.append(self._encode_pocket_batch(model, sample))
                labels.extend(sample["target"]["finetune_target"].detach().cpu().numpy())
        reps = np.concatenate(batch_reps, axis=0)
        with open(cache_path, "wb") as f:
            pickle.dump(reps, f)
        return reps, labels

    def test_knn(self, model, **kwargs):
        """Evaluate a k-NN regressor on the pocket encoder representations.

        Representations of the test, train and valid sets are computed (or
        read from pickle caches). A distance-weighted cosine
        ``KNeighborsRegressor`` is fitted on train + valid and evaluated on
        test. MSE, RMSE and range-normalized RMSE are printed.

        Args:
            model: model whose ``pocket_model`` produces the representations.

        Returns:
            None
        """
        from math import sqrt

        from sklearn.metrics import mean_squared_error
        from sklearn.neighbors import KNeighborsRegressor

        ckpt_name = self.args.finetune_from_model
        print(ckpt_name)
        # Name of the directory that contains the checkpoint file.
        ckpt_name = ckpt_name.split("/")[-2]
        save_base = "/save/base"
        test_path = "/test/path"
        train_path = "/train/path"
        valid_path = "/valid/path"

        test_pocket_reps, labels = self._collect_pocket_reps(
            model, test_path, save_base + ckpt_name + "_test_pocket_reps.pkl"
        )
        label_range = max(labels) - min(labels)
        print(label_range)

        train_pocket_reps, train_labels = self._collect_pocket_reps(
            model, train_path, save_base + ckpt_name + "_train_pocket_reps.pkl"
        )
        print(train_pocket_reps.shape)
        valid_pocket_reps, valid_labels = self._collect_pocket_reps(
            model, valid_path, save_base + ckpt_name + "_valid_pocket_reps.pkl"
        )

        # The regressor is fitted on the union of train and valid.
        train_pocket_reps = np.concatenate(
            [train_pocket_reps, valid_pocket_reps], axis=0
        )
        train_labels.extend(valid_labels)

        neigh = KNeighborsRegressor(
            # Never ask for more neighbours than there are fitted samples.
            n_neighbors=min(200, len(train_labels)),
            weights='distance',
            algorithm="brute",
            metric='cosine',
            n_jobs=-1,
        )
        neigh.fit(train_pocket_reps, train_labels)
        print("train done")

        preds = neigh.predict(test_pocket_reps)

        # calculate mse and rmse
        mse = mean_squared_error(labels, preds)
        rmse = sqrt(mse)
        # A constant label set has zero range; report NaN instead of dividing.
        nrmse = rmse / label_range if label_range != 0 else float("nan")
        print("mse: ", mse)
        print("rmse: ", rmse)
        print("nrmse: ", nrmse)

        return

    