import numpy as np
import pandas as pd
import os
from affinityenhancer.preprocess.paired_alignment_utils import *
import pickle
import torch
from smart_open import open as sopen


# AHo region table: region name -> list of consecutive integer positions.
# Each region is written as a half-open [start, stop) span; the spans are
# contiguous and together cover positions 0..147.
# NOTE(review): presumably these index columns of an AHo-aligned sequence --
# confirm against the alignment utilities.
AHo_regions = {
    region: list(range(start, stop))
    for region, start, stop in (
        ("FR1", 0, 24),
        ("CDR1", 24, 40),
        ("FR2", 40, 57),
        ("CDR2", 57, 77),
        ("FR3a", 77, 93),
        ("FR3b", 93, 108),
        ("CDR3", 108, 137),
        ("FR4", 137, 148),
    )
}


def match_on_edist(target_df, save_dir='.', property_th_lb=0.5,
                   property_th_ub=3.5, edist_th=15,
                   property_to_match='affinity_pkd',
                   property_min = None,
                   col='ag_id'):
    """Pair up close sequences whose property differs by a bounded amount.

    Every row of ``target_df`` is compared against every other row using
    ``find_closest_ref_seqs``. Two distinct rows form a candidate pair when
    their distance is ``<= edist_th`` and, if ``col`` is given, when
    ``get_matching_cols`` also flags them. Candidate pairs are kept when the
    property difference lies strictly between ``property_th_lb`` and
    ``property_th_ub``; every kept pair is oriented so that
    ``second_property - first_property`` is that positive difference.

    Args:
        target_df: DataFrame holding the columns ``fv_heavy``, ``fv_light``,
            ``fv_heavy_aho``, ``fv_light_aho``, ``seqid``,
            ``affinity_antigen_sequence`` and ``property_to_match``.
            NOTE(review): ``col`` and whatever columns the two helper
            functions read are presumably required as well -- confirm.
        save_dir: Currently unused; kept for interface compatibility.
        property_th_lb: Exclusive lower bound on the property difference.
        property_th_ub: Exclusive upper bound on the property difference.
        edist_th: Inclusive upper bound on the pairwise distance.
            NOTE(review): presumably an edit distance -- confirm against
            ``find_closest_ref_seqs``.
        property_to_match: Name of the numeric property column.
        property_min: If not ``None``, rows whose property is below this value
            are dropped before any matching is done.
        col: Column name forwarded to ``get_matching_cols``; ``None`` skips
            that restriction.

    Returns:
        DataFrame with one row per retained pair and the columns
        ``first_*`` / ``second_*`` for ``HeavyAA``, ``LightAA``,
        ``HeavyAHoAA``, ``LightAHoAA``, ``property``, ``seqid`` and
        ``affinity_antigen_sequence``. Exact duplicate rows are removed. The
        two property columns are numeric; the remaining columns take the
        common dtype NumPy picks when all columns are stacked into a single
        array (typically strings).

    Raises:
        AssertionError: If ``property_to_match`` is not a column of
            ``target_df``.
    """
    assert property_to_match in target_df.columns, (
        f"column '{property_to_match}' not found in target_df")

    if property_min is not None:
        target_df = target_df[target_df[property_to_match] >= property_min]

    # Self-comparison: both axes of the result refer to rows of the (possibly
    # filtered) target_df, which is why positional .iloc lookups are used below.
    distances = find_closest_ref_seqs(target_df, target_df, region=None)

    A = np.where(distances <= edist_th, 1, 0)
    if col is not None:
        matching_ags = get_matching_cols(target_df, target_df, col=col)
        A = A & matching_ags

    matched_tuples = np.where(A == 1)
    first_idx = np.asarray(matched_tuples[0])
    second_idx = np.asarray(matched_tuples[1])

    # A row always matches itself; those pairs carry no information.
    distinct = first_idx != second_idx
    first_idx = first_idx[distinct]
    second_idx = second_idx[distinct]

    # (source column in target_df, suffix of the first_/second_ output column),
    # listed in output column order.
    fields = [
        ('fv_heavy', 'HeavyAA'),
        ('fv_light', 'LightAA'),
        ('fv_heavy_aho', 'HeavyAHoAA'),
        ('fv_light_aho', 'LightAHoAA'),
        (property_to_match, 'property'),
        ('seqid', 'seqid'),
        ('affinity_antigen_sequence', 'affinity_antigen_sequence'),
    ]

    stacked, names = [], []
    for source, label in fields:
        for side, positions in (('first', first_idx), ('second', second_idx)):
            # Going through a plain list lets NumPy infer the column dtype
            # from the values themselves (str -> unicode, float -> float64).
            values = target_df[source].iloc[positions].tolist()
            stacked.append(np.asarray(values).reshape(-1, 1))
            names.append(f'{side}_{label}')

    # hstack promotes every column to one common dtype, so the property
    # columns are converted back to numbers right after.
    df_matched = pd.DataFrame(np.hstack(stacked), columns=names)
    df_matched['first_property'] = pd.to_numeric(df_matched['first_property'])
    df_matched['second_property'] = pd.to_numeric(df_matched['second_property'])
    print('Matched', df_matched.shape)

    gain = df_matched['second_property'] - df_matched['first_property']
    loss = df_matched['first_property'] - df_matched['second_property']

    # Pairs where the second member is better by a margin inside the window.
    df_f_st_s = df_matched[(gain > property_th_lb) & (gain < property_th_ub)]
    # Pairs where the first member is the better one ...
    df_s_st_f = df_matched[(loss > property_th_lb) & (loss < property_th_ub)]

    # ... are flipped so the better member is always "second". Every
    # first_/second_ column is swapped, including the antigen sequences, so
    # each antigen stays attached to its own antibody.
    swapped = {}
    for name in names:
        side, label = name.split('_', 1)
        swapped[name] = ('second_' if side == 'first' else 'first_') + label
    df_s_st_f = df_s_st_f.rename(columns=swapped)

    # concat aligns on column labels, so the flipped frame lines up correctly.
    final_df_target = pd.concat([df_f_st_s, df_s_st_f])
    final_df_target = final_df_target.drop_duplicates()
    print('Matched filtered', final_df_target.shape)

    return final_df_target

