from posthoc import solvers

track_builders = {}

# Module paths searched on the node for each QCAI track, listed in the exact
# order their search results are handed to ``node.tracks.add``.
# NOTE(review): the epitope entry uses 'decoderE->crossattention->key/query'
# while alpha/beta use 'decoderE->crossattention->self->...'; confirm the
# missing '->self' segment is intentional for search_modules matching.
_QCAI_TRACK_MODULES = (
    (('QCAI', 'epitope'), (
        'decoderE->BertSelfAttention->Dropout',
        'decoderE->crossattention->key',
        'decoderE->crossattention->query',
        'encoderE->BertSelfAttention->Dropout',
    )),
    (('QCAI', 'alpha'), (
        'decoderE->crossattention->self->Dropout',
        'decoderE->crossattention->self->key',
        'encoderA->BertSelfAttention->Dropout',
    )),
    (('QCAI', 'beta'), (
        'decoderE->crossattention->self->Dropout',
        'decoderE->crossattention->self->key',
        'encoderB->BertSelfAttention->Dropout',
    )),
)


def _qcai_track_build(node):
    """Create the QCAI epitope, alpha and beta tracks on *node*.

    For each entry of ``_QCAI_TRACK_MODULES`` the listed module paths are
    resolved one by one with ``node.nodes.search_modules`` and the results
    are passed positionally, in order, to ``node.tracks.add``.

    Returns:
        dict mapping each ``('QCAI', <name>)`` key to whatever
        ``node.tracks.add`` returned for it, built in epitope, alpha, beta
        order.
    """
    return {
        key: node.tracks.add(*[node.nodes.search_modules(path) for path in paths])
        for key, paths in _QCAI_TRACK_MODULES
    }


track_builders['QCAI'] = _qcai_track_build

flow_builders = {}


def _qcai_flow_build(tracks, alen, blen, discard_ratio=0.9, multihead_reduce='max'):
    """Attach solver chains and collectors to the three QCAI tracks.

    Args:
        tracks: mapping keyed by ``('QCAI', 'epitope')``, ``('QCAI', 'alpha')``
            and ``('QCAI', 'beta')``. Each value must expose ``.flow`` with
            ``set_flow(*solvers)`` and ``set_collect(fn)``.
        alen: bounds the alpha clip window ``(1, alen + 1)`` and starts the
            beta window. Presumably the alpha-chain token count, with index 0
            reserved for a leading token -- TODO confirm.
        blen: extends the beta clip window to ``alen + blen + 1``. Presumably
            the beta-chain token count -- TODO confirm.
        discard_ratio: forwarded to every gradient-attention solver.
        multihead_reduce: forwarded to every gradient-attention solver.
    """
    # Options shared by every gradient-weighted attention solver below; the
    # insertion order matches the keyword order the solvers always received.
    attn_opts = {
        'discard_ratio': discard_ratio,
        'multihead_reduce': multihead_reduce,
        'residual_connect': True,
        'norm': True,
    }

    def grad_attn():
        return solvers.GradAttn(**attn_opts)

    def summed_query():
        return solvers.QuantifyQuery(reduce_method='sum', norm=True)

    def cross_rollout(start, stop):
        # compress_reduce keeps only index 0 along the second axis.
        return solvers.GradCrossAttnRollout(
            **attn_opts,
            clip=(start, stop),
            compress_reduce=lambda x, dim=1: x[:, 0],
        )

    def summed_query_in(start, stop):
        return solvers.QuantifyQueryIn(reduce_method='sum', norm=True, clip=(start, stop))

    def configure_chain(chain, start, stop):
        # Alpha and beta share one solver layout; only the clip window differs.
        tracks[('QCAI', chain)].flow.set_flow(
            cross_rollout(start, stop),
            summed_query_in(start, stop),
            cross_rollout(start, stop),
            summed_query_in(start, stop),
            grad_attn(),
            grad_attn(),
        )
        tracks[('QCAI', chain)].flow.set_collect(lambda x: x.flatten(1))

    tracks[('QCAI', 'epitope')].flow.set_flow(
        grad_attn(),
        solvers.KeyDecomposeQuantifyQuery(),
        summed_query(),
        grad_attn(),
        solvers.GradAttnRolloutQuantifyQuery(**attn_opts),
        solvers.KeyDecomposeQuantifyQuery(),
        summed_query(),
        grad_attn(),
        grad_attn(),
        grad_attn(),
    )
    tracks[('QCAI', 'epitope')].flow.set_collect(lambda x: x.flatten(1))

    configure_chain('alpha', 1, alen + 1)
    configure_chain('beta', alen + 1, alen + blen + 1)


flow_builders['QCAI'] = _qcai_flow_build

def set_flow(flow_builders, tracks, alen, blen, discard_ratio=0.9, multihead_reduce='max'):
    """Run every registered flow builder against *tracks*.

    Each callable in ``flow_builders.values()`` is invoked, in mapping
    order, with the positional arguments
    ``(tracks, alen, blen, discard_ratio, multihead_reduce)``.
    Any return values from the builders are ignored, and this function
    returns None.
    """
    builder_args = (tracks, alen, blen, discard_ratio, multihead_reduce)
    for build in flow_builders.values():
        build(*builder_args)