import edlib
import numpy as np
import pandas as pd

# Silence pandas' SettingWithCopyWarning globally (pandas default is 'warn').
# Presumably needed because find_closest_ref_seqs assigns new columns onto
# caller-supplied DataFrames that may be slices of another frame — confirm.
pd.set_option("mode.chained_assignment", None)

def edist(x, y, region=None):
    """
    Edit distance between two sequences restricted to a set of positions.

    The characters of ``x`` and ``y`` at the positions listed in ``region``
    are concatenated, alignment gaps ("-") are removed, and the two resulting
    strings are compared with ``edlib.align`` (its default global mode).

    Parameters
    ----------
    x, y : str
        Sequences to compare. Presumably both use the same alignment
        numbering (e.g. AHo) so that one index means the same position in
        both — TODO confirm with callers.
    region : iterable of int or None, optional
        Positions to compare. Any iterable is accepted, including one-shot
        iterators. ``None`` compares the whole of both sequences, with gaps
        still removed.

    Returns
    -------
    tuple
        ``(edit_distance, result)`` where ``result`` is the full object
        returned by ``edlib.align``.
    """
    if region is None:
        sub_x = x.replace("-", "")
        sub_y = y.replace("-", "")
    else:
        # Materialise once: a generator/iterator region would otherwise be
        # exhausted while building sub_x, leaving sub_y empty.
        positions = list(region)
        sub_x = "".join(x[p] for p in positions).replace("-", "")
        sub_y = "".join(y[p] for p in positions).replace("-", "")
    out = edlib.align(sub_x, sub_y)
    return out["editDistance"], out


def find_closest_ref_seqs(df, df_ref, region, must_match='ag_id'):
    """
    Pairwise heavy+light chain edit distances between ``df`` and ``df_ref`` rows.

    For every pair (query row i, reference row j) the edit distance
    (``edlib.align``, default global mode) between the heavy-chain sequences
    and the one between the light-chain sequences are summed. To find the
    closest reference for each query row, take ``ed.argmin(axis=1)`` on the
    result.

    Parameters
    ----------
    df, df_ref : pandas.DataFrame
        Query and reference tables. When ``region`` is given they must have
        ``fv_heavy_aho`` / ``fv_light_aho`` columns. Otherwise they must have
        ``fv_heavy`` / ``fv_light`` columns.
    region : iterable of int or None
        Positions to extract from the aligned ``*_aho`` strings, with gaps
        ("-") removed afterwards. ``None`` compares the full ``fv_heavy`` /
        ``fv_light`` sequences.
    must_match : str, optional
        Not used by this function; kept for backward compatibility.

    Returns
    -------
    numpy.ndarray of int, shape (len(df), len(df_ref))
        ``ed[i, j]`` = heavy edit distance + light edit distance between
        ``df`` row i and ``df_ref`` row j (positional order).

    Notes
    -----
    Side effect: adds or overwrites ``h_region`` and ``l_region`` columns on
    both ``df`` and ``df_ref``.
    """
    # `is not None` rather than `!= None`: with a numpy-array region,
    # `!= None` is element-wise and the `if` would raise ValueError.
    if region is not None:
        # Materialise once so one-shot iterators work for every sequence.
        positions = list(region)

        def _extract(seqs):
            # Concatenate the selected positions, then drop alignment gaps.
            return ["".join(s[p] for p in positions).replace("-", "") for s in seqs]

        df["h_region"] = _extract(df.fv_heavy_aho)
        df["l_region"] = _extract(df.fv_light_aho)
        df_ref["h_region"] = _extract(df_ref.fv_heavy_aho)
        df_ref["l_region"] = _extract(df_ref.fv_light_aho)
    else:
        df["h_region"] = df.fv_heavy
        df["l_region"] = df.fv_light
        df_ref["h_region"] = df_ref.fv_heavy
        df_ref["l_region"] = df_ref.fv_light

    ref_heavy = df_ref["h_region"].tolist()
    ref_light = df_ref["l_region"].tolist()

    ed = np.zeros((df.shape[0], df_ref.shape[0]), dtype=int)
    for i, (heavy, light) in enumerate(zip(df["h_region"], df["l_region"])):
        ed[i, :] = [
            edlib.align(ref_h, heavy)["editDistance"]
            + edlib.align(ref_l, light)["editDistance"]
            for ref_h, ref_l in zip(ref_heavy, ref_light)
        ]

    return ed


def get_matching_cols(df, df_ref, col='ag_id'):
    """
    Pairwise equality mask of one column between the rows of two DataFrames.

    Parameters
    ----------
    df, df_ref : pandas.DataFrame
        Query and reference tables, both containing column ``col``.
    col : str, optional
        Column to compare (default ``'ag_id'``).

    Returns
    -------
    numpy.ndarray of int, shape (len(df), len(df_ref))
        ``1`` where ``df[col].iloc[i] == df_ref[col].iloc[j]``, else ``0``.
    """
    query = df[col].to_numpy()
    ref = df_ref[col].to_numpy()
    # Broadcast (n, 1) against (1, m) to compare all pairs in a single
    # vectorised step instead of one Python-level apply per query row.
    return (query[:, None] == ref[None, :]).astype(int)

