import torch
from .protein import constants


def find_cdrs(structure):
    """List the CDR labels present in an antibody structure.

    Heavy-chain labels (``'H_CDR1'``..``'H_CDR3'``) are reported before
    light-chain labels (``'L_CDR1'``..``'L_CDR3'``), each chain in CDR order.
    A chain whose entry is ``None`` contributes nothing. Otherwise, a label is
    reported when the integer value of the matching ``constants.CDR`` member
    occurs in that chain's ``'cdr_flag'``.

    Args:
        structure: Mapping with ``'heavy'`` and ``'light'`` keys. Each value is
            ``None`` or a mapping holding a ``'cdr_flag'`` container that
            supports ``in`` (presumably a per-residue tensor -- confirm
            against the caller).

    Returns:
        list of str: the CDR labels found, in the order described above.
    """
    # (chain key, ((CDR enum member name, reported label), ...))
    layout = (
        ('heavy', (('H1', 'H_CDR1'), ('H2', 'H_CDR2'), ('H3', 'H_CDR3'))),
        ('light', (('L1', 'L_CDR1'), ('L2', 'L_CDR2'), ('L3', 'L_CDR3'))),
    )
    present = []
    for chain_key, members in layout:
        chain = structure[chain_key]
        if chain is None:
            continue
        cdr_flag = chain['cdr_flag']
        # Enum members are resolved lazily, only for chains that exist.
        present.extend(
            label
            for member, label in members
            if int(getattr(constants.CDR, member)) in cdr_flag
        )
    return present


def get_residue_first_last(data):
    """Return identifiers of the first and last residues flagged for generation.

    Args:
        data: Mapping with a 1-D ``'generate_flag'`` mask (bool, or 0/1
            integers) plus per-residue ``'chain_id'``, ``'resseq'`` (a tensor;
            ``.item()`` is called on its elements) and ``'icode'`` entries,
            all indexed by residue position.

    Returns:
        ``(residue_first, residue_last)``. Each is a
        ``(chain_id, resseq, icode)`` tuple for the lowest and highest flagged
        position respectively, with ``resseq`` converted to a Python number.

    Raises:
        ValueError: if ``'generate_flag'`` selects no residue.
    """
    loop_flag = data['generate_flag']
    # nonzero() yields positions on the mask's own device (torch.arange was
    # always CPU) and treats bool and 0/1 integer masks the same way; integer
    # masks used as an index would be interpreted as indices, not a mask.
    loop_idx = torch.nonzero(loop_flag, as_tuple=True)[0]
    if loop_idx.numel() == 0:
        raise ValueError("'generate_flag' does not select any residue")
    idx_first, idx_last = loop_idx.min().item(), loop_idx.max().item()
    residue_first = (data['chain_id'][idx_first], data['resseq'][idx_first].item(), data['icode'][idx_first])
    residue_last = (data['chain_id'][idx_last], data['resseq'][idx_last].item(), data['icode'][idx_last])
    return residue_first, residue_last


class RemoveNative(object):
    """Callable transform that hides native information for flagged residues.

    Residues are selected by ``data['generate_flag']``. Depending on the
    constructor switches, their amino-acid type is replaced by
    ``constants.AA.UNK`` and/or their heavy-atom coordinates are replaced by
    Gaussian noise. ``data`` is updated in place and also returned.

    Args:
        remove_structure: If true, overwrite ``data['pos_heavyatom']`` at
            flagged residues with ``randn * noise_std``.
        remove_sequence: If true, overwrite ``data['aa']`` at flagged residues
            with ``int(constants.AA.UNK)``.
        noise_std: Scale applied to the standard-normal noise used when
            removing structure (default ``10.0``, the previously hard-coded
            value).
    """

    def __init__(self, remove_structure, remove_sequence, noise_std=10.0):
        super().__init__()
        self.remove_structure = remove_structure
        self.remove_sequence = remove_sequence
        self.noise_std = noise_std

    def __call__(self, data):
        # torch.where needs a boolean condition. .bool() also accepts uint8/0-1
        # masks. The mask is only read, so no defensive clone is needed.
        generate_flag = data['generate_flag'].bool()
        if self.remove_sequence:
            data['aa'] = torch.where(
                generate_flag,
                torch.full_like(data['aa'], fill_value=int(constants.AA.UNK)),
                data['aa'],
            )

        if self.remove_structure:
            pos = data['pos_heavyatom']
            # The [:, None, None] indexing assumes pos is 3-D with residues on
            # dim 0. torch.where broadcasts the mask over the trailing dims.
            data['pos_heavyatom'] = torch.where(
                generate_flag[:, None, None],
                torch.randn_like(pos) * self.noise_std,
                pos,
            )

        return data