import torch
import random
from typing import List, Optional

from ..protein import constants
from ._base import register_transform


def random_shrink_extend(flag, min_length=5, shrink_limit=1, extend_limit=2):
    """Randomly shrink or extend a flagged segment at both of its ends.

    Each end of the span between the first and last ``True`` entries of
    ``flag`` is moved independently by an integer drawn uniformly from
    ``[-shrink_limit, extend_limit]``. A negative draw moves the boundary
    inward (shrink); a positive draw moves it outward (extend). The new
    boundaries are clamped to the valid index range of ``flag``.

    Args:
        flag: 1-D boolean tensor with at least one ``True`` entry. The result
            covers one span, so ``flag`` is expected to mark a single
            contiguous segment.
        min_length: Shrinking is disabled when removing ``shrink_limit``
            positions from both ends would leave fewer than ``min_length``
            flagged positions.
        shrink_limit: Maximum number of positions removed from each end.
        extend_limit: Maximum number of positions added to each end.

    Returns:
        A new boolean tensor shaped like ``flag`` that is ``True`` exactly on
        the perturbed span. ``flag`` itself is not modified.
    """
    # Indices of the flagged positions, built once and reduced for both ends.
    flagged_idx = torch.arange(0, flag.size(0))[flag]
    first = flagged_idx.min().item()
    last = flagged_idx.max().item()

    length = flag.sum().item()
    if (length - 2 * shrink_limit) < min_length:
        # Segment is too short to lose residues on both sides.
        shrink_limit = 0

    first_ext = max(0, first - random.randint(-shrink_limit, extend_limit))
    last_ext = min(last + random.randint(-shrink_limit, extend_limit), flag.size(0) - 1)
    if first_ext > last_ext:
        # Only reachable with custom limits that allow the two ends to cross;
        # fall back to the unperturbed span rather than return an empty mask.
        first_ext, last_ext = first, last

    # Start from an all-False mask. Cloning ``flag`` here would keep every
    # original position set, so a shrink would silently have no effect.
    flag_ext = torch.zeros_like(flag)
    flag_ext[first_ext : last_ext + 1] = True
    return flag_ext


def continuous_flag_to_range(flag):
    """Return the indices of the first and last ``True`` entries of ``flag``.

    Args:
        flag: 1-D boolean tensor used as a mask over ``arange(flag.size(0))``.
            It must contain at least one ``True`` entry, because the min/max
            reductions are taken over the selected indices.

    Returns:
        Tuple ``(first, last)`` of Python ints: the smallest and largest
        flagged index.
    """
    flagged_positions = torch.arange(0, flag.size(0))[flag]
    lowest = flagged_positions.min().item()
    highest = flagged_positions.max().item()
    return lowest, highest


@register_transform('mask_single_cdr')
class MaskSingleCDR(object):
    """Flag a single CDR on one antibody chain for generation.

    The transform writes two boolean tensors into the data dict of the chosen
    chain: ``generate_flag`` (the selected CDR, optionally perturbed by
    ``random_shrink_extend``) and ``anchor_flag`` (the positions immediately
    before and after the flagged span, clamped to the chain ends).

    Args:
        selection: Name of the CDR to mask, or ``None``. Accepted names are
            ``'H1'``..``'L3'``, ``'H_CDR1'``..``'L_CDR3'`` and ``'CDR3'``.
            ``'CDR3'`` means H3 when a heavy chain is present and L3
            otherwise. ``None`` picks a random available chain and a random
            CDR label found on it.
        augmentation: When true, the CDR mask is passed through
            ``random_shrink_extend`` before being stored.
    """

    def __init__(self, selection=None, augmentation=True):
        super().__init__()
        cdr = constants.CDR
        name_to_cdr = {
            # Short names.
            'H1': cdr.H1,
            'H2': cdr.H2,
            'H3': cdr.H3,
            'L1': cdr.L1,
            'L2': cdr.L2,
            'L3': cdr.L3,
            # Long names map to the same members.
            'H_CDR1': cdr.H1,
            'H_CDR2': cdr.H2,
            'H_CDR3': cdr.H3,
            'L_CDR1': cdr.L1,
            'L_CDR2': cdr.L2,
            'L_CDR3': cdr.L3,
            # Kept as a string; resolved against the structure in __call__.
            'CDR3': 'CDR3',
        }
        assert (selection is None) or (selection in name_to_cdr)
        self.selection = name_to_cdr.get(selection)
        self.augmentation = augmentation

    def perform_masking_(self, data, selection=None):
        """Write ``generate_flag`` and ``anchor_flag`` into ``data`` in place.

        Args:
            data: Chain dict providing ``'cdr_flag'`` (tensor of CDR labels;
                values greater than zero are treated as CDR labels) and
                ``'aa'`` (its shape and first dimension size are used for the
                anchor mask).
            selection: CDR label to mask. ``None`` draws one at random from
                the positive labels present in ``data['cdr_flag']``.
        """
        labels = data['cdr_flag']

        target = selection
        if target is None:
            present = labels[labels > 0].unique().tolist()
            target = random.choice(present)

        mask = (labels == target)
        if self.augmentation:
            mask = random_shrink_extend(mask)

        span_start, span_end = continuous_flag_to_range(mask)
        last_valid = data['aa'].size(0) - 1
        # Anchors are the neighbours just outside the span, clamped to the
        # chain so that a span touching either end stays in range.
        anchors = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchors[max(0, span_start - 1)] = True
        anchors[min(last_valid, span_end + 1)] = True

        data['generate_flag'] = mask
        data['anchor_flag'] = anchors

    def __call__(self, structure):
        """Mask one CDR of ``structure`` in place and return ``structure``.

        ``structure`` is indexed with ``'heavy'`` and ``'light'``; either
        entry may be ``None``.
        """
        heavy_cdrs = (constants.CDR.H1, constants.CDR.H2, constants.CDR.H3, )
        light_cdrs = (constants.CDR.L1, constants.CDR.L2, constants.CDR.L3, )
        choice = self.selection

        if choice is None:
            # No preference: pick any chain that is present; the CDR is then
            # drawn inside perform_masking_.
            chains = [
                structure[key]
                for key in ('heavy', 'light')
                if structure[key] is not None
            ]
            target_chain = random.choice(chains)
            target_cdr = None
        elif choice in heavy_cdrs:
            target_chain = structure['heavy']
            target_cdr = int(choice)
        elif choice in light_cdrs:
            target_chain = structure['light']
            target_cdr = int(choice)
        elif choice == 'CDR3':
            # Prefer the heavy chain's CDR3 and fall back to the light chain.
            if structure['heavy'] is not None:
                target_chain = structure['heavy']
                target_cdr = constants.CDR.H3
            else:
                target_chain = structure['light']
                target_cdr = constants.CDR.L3

        self.perform_masking_(target_chain, selection=target_cdr)
        return structure


@register_transform('mask_multiple_cdrs')
class MaskMultipleCDRs(object):
    """Flag one or more CDRs on every available antibody chain for generation.

    For each chain that is present, ``generate_flag`` and ``anchor_flag``
    boolean tensors are written into (or OR-ed into) the chain's data dict.

    Args:
        selection: Optional list of CDR names (``'H1'``..``'L3'`` or
            ``'H_CDR1'``..``'L_CDR3'``). When given, every selected CDR that
            is present on a chain is masked. When ``None``, a random
            non-empty subset of the CDRs present on each chain is masked.
        augmentation: When true, each CDR mask is passed through
            ``random_shrink_extend`` before being stored.

    Raises:
        KeyError: If ``selection`` contains an unknown CDR name.
    """

    def __init__(self, selection: Optional[List[str]]=None, augmentation=True):
        super().__init__()
        cdr_str_to_enum = {
            'H1': constants.CDR.H1,
            'H2': constants.CDR.H2,
            'H3': constants.CDR.H3,
            'L1': constants.CDR.L1,
            'L2': constants.CDR.L2,
            'L3': constants.CDR.L3,
            'H_CDR1': constants.CDR.H1,
            'H_CDR2': constants.CDR.H2,
            'H_CDR3': constants.CDR.H3,
            'L_CDR1': constants.CDR.L1,
            'L_CDR2': constants.CDR.L2,
            'L_CDR3': constants.CDR.L3,
        }
        if selection is not None:
            self.selection = [cdr_str_to_enum[s] for s in selection]
        else:
            self.selection = None
        self.augmentation = augmentation

    def mask_one_cdr_(self, data, cdr_to_mask):
        """Add the mask and anchors of one CDR to ``data`` in place.

        Args:
            data: Chain dict providing ``'cdr_flag'`` and ``'aa'``.
            cdr_to_mask: CDR label compared against ``data['cdr_flag']``. At
                least one position must carry this label.
        """
        cdr_flag = data['cdr_flag']

        cdr_to_mask_flag = (cdr_flag == cdr_to_mask)
        if self.augmentation:
            cdr_to_mask_flag = random_shrink_extend(cdr_to_mask_flag)

        cdr_first, cdr_last = continuous_flag_to_range(cdr_to_mask_flag)
        # Anchors are the neighbours just outside the span, clamped to the
        # chain so that a span touching either end stays in range.
        left_idx = max(0, cdr_first-1)
        right_idx = min(data['aa'].size(0)-1, cdr_last+1)
        anchor_flag = torch.zeros(data['aa'].shape, dtype=torch.bool)
        anchor_flag[left_idx] = True
        anchor_flag[right_idx] = True

        # Flags accumulate across the CDRs masked on this chain.
        # NOTE(review): flags already stored on `data` before this transform
        # ran are OR-ed into as well -- confirm callers pass fresh dicts.
        if 'generate_flag' not in data:
            data['generate_flag'] = cdr_to_mask_flag
            data['anchor_flag'] = anchor_flag
        else:
            data['generate_flag'] |= cdr_to_mask_flag
            data['anchor_flag'] |= anchor_flag

    def mask_for_one_chain_(self, data):
        """Choose which CDRs of one chain to mask and mask each of them.

        A chain without any positive label in ``data['cdr_flag']`` is left
        untouched.
        """
        cdr_flag = data['cdr_flag']
        cdr_all = cdr_flag[cdr_flag > 0].unique().tolist()
        if not cdr_all:
            # Nothing to mask on this chain; random.randint(1, 0) below would
            # raise ValueError for the empty range.
            return

        # Drawn unconditionally so the number of RNG calls does not depend on
        # whether a selection is configured.
        num_cdrs_to_mask = random.randint(1, len(cdr_all))

        if self.selection is not None:
            # Mask every requested CDR that this chain actually carries.
            cdrs_to_mask = list(set(cdr_all).intersection(self.selection))
        else:
            random.shuffle(cdr_all)
            cdrs_to_mask = cdr_all[:num_cdrs_to_mask]

        for cdr_to_mask in cdrs_to_mask:
            self.mask_one_cdr_(data, cdr_to_mask)

    def __call__(self, structure):
        """Mask CDRs on the heavy and light chains in place; return ``structure``."""
        if structure['heavy'] is not None:
            self.mask_for_one_chain_(structure['heavy'])
        if structure['light'] is not None:
            self.mask_for_one_chain_(structure['light'])
        return structure

        
@register_transform('mask_antibody')
class MaskAntibody(object):
    """Flag whole antibody chains for generation and pick an antigen anchor.

    Every position of each available antibody chain is marked in that chain's
    ``generate_flag``. When an antigen is present, the antigen positions whose
    nearest antibody C-alpha lies within ``contact_cutoff`` are stored as
    ``contact_flag``, and one of them, sampled uniformly, is stored as
    ``anchor_flag``.

    Args:
        contact_cutoff: Maximum antigen-to-antibody C-alpha distance for an
            antigen position to count as a contact. Uses the units of the
            ``pos_heavyatom`` coordinates (presumably Angstrom -- confirm).
            Defaults to ``6.0``.
    """

    def __init__(self, contact_cutoff=6.0):
        super().__init__()
        self.contact_cutoff = contact_cutoff

    def mask_ab_chain_(self, data):
        """Mark every position of the chain ``data`` for generation, in place."""
        data['generate_flag'] = torch.ones(data['aa'].shape, dtype=torch.bool)

    def __call__(self, structure):
        """Apply the masking to ``structure`` in place and return it.

        ``structure`` is indexed with ``'heavy'``, ``'light'`` and
        ``'antigen'``; any of these entries may be ``None``.
        """
        ab_alpha_per_chain = []
        for chain_key in ('heavy', 'light'):
            chain = structure[chain_key]
            if chain is None:
                continue
            self.mask_ab_chain_(chain)
            ab_alpha_per_chain.append(
                chain['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
            )

        antigen = structure['antigen']
        if antigen is None:
            # No antigen, so there is no anchor to choose and the antibody
            # coordinates are not needed.
            return structure

        pos_ab_alpha = torch.cat(ab_alpha_per_chain, dim=0)
        pos_ag_alpha = antigen['pos_heavyatom'][:, constants.BBHeavyAtom.CA]
        ag_ab_dist = torch.cdist(pos_ag_alpha, pos_ab_alpha)
        # For each antigen position, the distance to its nearest antibody CA.
        nn_ab_dist = ag_ab_dist.min(dim=1)[0]
        contact_flag = (nn_ab_dist <= self.contact_cutoff)
        if not contact_flag.any().item():
            # Nothing within the cutoff: fall back to the single closest
            # antigen position so that sampling below has a candidate.
            contact_flag[nn_ab_dist.argmin()] = True

        # Uniform draw among the contact positions.
        anchor_idx = torch.multinomial(contact_flag.float(), num_samples=1).item()
        anchor_flag = torch.zeros(antigen['aa'].shape, dtype=torch.bool)
        anchor_flag[anchor_idx] = True
        antigen['anchor_flag'] = anchor_flag
        antigen['contact_flag'] = contact_flag

        return structure


@register_transform('remove_antigen')
class RemoveAntigen:
    """Transform that clears the antigen entries of a structure."""

    def __call__(self, structure):
        """Set ``'antigen'`` and ``'antigen_seqmap'`` to ``None`` in place.

        Returns:
            The same ``structure`` object that was passed in.
        """
        for key in ('antigen', 'antigen_seqmap'):
            structure[key] = None
        return structure
