import torch

from ._base import _mask_select_data, register_transform
from ..protein import constants


@register_transform('patch_around_anchor')
class PatchAroundAnchor(object):
    """Crop a structure dict to the residues surrounding the anchor residues.

    The transform reads the following entries of ``data``:
    ``'anchor_flag'``, ``'generate_flag'``, ``'fragment_type'``,
    ``'pos_heavyatom'`` and ``'mask_heavyatom'``. It returns the dict
    produced by ``_mask_select_data`` with coordinates re-centered and an
    ``'origin'`` entry added. When at least one anchor is flagged, a
    ``'patch_idx'`` entry (indices of the kept residues in the input) is
    added as well; the no-anchor branch does not set ``'patch_idx'``.

    Args:
        initial_patch_size: Maximum number of residues (of any fragment
            type) closest to the anchors that are kept.
        antigen_size: Maximum number of antigen residues closest to the
            anchors that are kept.
    """

    def __init__(self, initial_patch_size=128, antigen_size=128):
        super().__init__()
        self.initial_patch_size = initial_patch_size
        self.antigen_size = antigen_size

    def _center(self, data, origin):
        """Translate heavy-atom coordinates so that ``origin`` becomes zero.

        Positions whose ``'mask_heavyatom'`` entry is zero/False are set to
        zero after the shift. The origin used is stored under ``'origin'``
        as a tensor of shape ``(3,)``.

        Args:
            data: Dict holding ``'pos_heavyatom'`` and ``'mask_heavyatom'``.
            origin: Tensor with exactly 3 elements.

        Returns:
            The same ``data`` dict, with ``'pos_heavyatom'`` replaced and
            ``'origin'`` added.
        """
        origin = origin.reshape(1, 1, 3)
        # Subtract out-of-place: the tensor stored in `data` may share storage
        # with the caller's input (``_mask_select_data`` is opaque here), and
        # an in-place update would silently shift the caller's coordinates.
        centered = data['pos_heavyatom'] - origin
        data['pos_heavyatom'] = centered * data['mask_heavyatom'][:, :, None]
        data['origin'] = origin.reshape(3)
        return data

    def __call__(self, data):
        """Select the patch around the anchors and center it.

        Args:
            data: Dict of per-residue tensors (see the class docstring for
                the keys that are read).

        Returns:
            The cropped and centered dict. Without anchors, all non-antigen
            residues are kept and centered on the mean of their CA
            coordinates. With anchors, the patch is centered on the mean
            anchor CA coordinate.
        """
        anchor_flag = data['anchor_flag']
        antigen_mask = (data['fragment_type'] == constants.Fragment.Antigen)
        antibody_mask = torch.logical_not(antigen_mask)

        if anchor_flag.sum().item() == 0:
            # No anchors: keep every non-antigen residue and center the
            # selection on the mean CA position of what was kept.
            data_patch = _mask_select_data(
                data=data,
                mask=antibody_mask,
            )
            data_patch = self._center(
                data_patch,
                origin=data_patch['pos_heavyatom'][:, constants.BBHeavyAtom.CA].mean(dim=0),
            )
            return data_patch

        # CA coordinates of the anchors and of every residue.
        anchor_points = data['pos_heavyatom'][anchor_flag, constants.BBHeavyAtom.CA]
        pos_alpha = data['pos_heavyatom'][:, constants.BBHeavyAtom.CA]

        # Distance from each residue to its nearest anchor.
        dist_anchor = torch.cdist(pos_alpha, anchor_points).min(dim=1)[0]
        initial_patch_idx = torch.topk(
            dist_anchor,
            k=min(self.initial_patch_size, dist_anchor.size(0)),
            largest=False,
        )[1]

        # Push non-antigen residues to +inf so that only antigen residues can
        # be picked; k is capped by the antigen count, so no +inf entry is
        # ever selected.
        dist_anchor_antigen = dist_anchor.masked_fill(
            mask=antibody_mask,
            value=float('+inf'),
        )
        antigen_patch_idx = torch.topk(
            dist_anchor_antigen,
            k=min(self.antigen_size, antigen_mask.sum().item()),
            largest=False, sorted=True,
        )[1]

        # Residues flagged for generation and the anchors themselves are
        # always kept, regardless of distance.
        patch_mask = torch.logical_or(
            data['generate_flag'],
            data['anchor_flag'],
        )
        patch_mask[initial_patch_idx] = True
        patch_mask[antigen_patch_idx] = True

        # Build the index tensor on the mask's device; a default-device
        # arange cannot be indexed by a mask living on another device.
        patch_idx = torch.arange(
            0, patch_mask.shape[0], device=patch_mask.device
        )[patch_mask]

        data_patch = _mask_select_data(data, patch_mask)
        data_patch = self._center(
            data_patch,
            origin=anchor_points.mean(dim=0),
        )
        data_patch['patch_idx'] = patch_idx
        return data_patch
