import pandas as pd
import sys
import edlib
import numpy as np
from affinityenhancer.preprocess.utils import read_yaml, lookup_target_sequence
from affinityenhancer.preprocess.constants import DEFAULT_DATASET_PATH, DEFAULT_PAIRING_SETTINGS
from prescient.transforms.functional import anarci_numbering
from prescient.constants import CDR_RANGES_AHO, LENGTH_FV_HEAVY_AHO, RANGES_AHO

import seaborn as sns
import matplotlib.pyplot as plt

def get_common_ids(property_th_lb=DEFAULT_PAIRING_SETTINGS['property_th_lb'],
                   property_th_ub=DEFAULT_PAIRING_SETTINGS['property_th_ub'],
                   edist_th=DEFAULT_PAIRING_SETTINGS['edist_th'],
                   property_to_match=DEFAULT_PAIRING_SETTINGS['property_to_match'],
                   min_prop=None
                   ):
    """Print (and return) the sequence ids labelled with pKd in both SKEMPI and Prescient.

    Reads the epitome dataset CSV, keeps the rows whose ``affinity_pkd`` is not
    null for the ``skempi`` and the ``prescient`` affinity sources, and reports
    the ``seq_id`` values present in both subsets.

    Args:
        property_th_lb, property_th_ub, edist_th: only used to build the SKEMPI
            paired-dataset path that is printed for reference.
        property_to_match, min_prop: currently unused; kept so the signature
            matches the other helpers in this module.

    Returns:
        list: the common ``seq_id`` values, in the order in which they first
        appear in the SKEMPI subset.
    """
    # NOTE(review): the property tag is hard-coded to "15" here, whereas
    # print_paired_dataset_stats uses ``property_to_match`` — confirm intended.
    file_1 = f'{DEFAULT_DATASET_PATH}/skempi_matched_th{property_th_lb}-{property_th_ub}_edth{edist_th}_prop15.parquet'

    dataset = pd.read_csv(
        's3://prescient-data-dev/sandbox/vasilaks/sabdab_vs_world/DATA_v5.1/epitome_dataset.csv',
        low_memory=False)

    print(file_1)
    has_pkd = dataset['affinity_pkd'].notna()
    source = dataset['affinity_datasource']

    # Select via .loc on the full frame instead of assigning into filtered
    # slices, which avoids pandas' SettingWithCopy pitfalls.
    skempi_ids = dataset.loc[has_pkd & (source == "skempi"), 'seq_id'].unique()
    # A set gives O(1) membership tests; the previous list lookup was O(n^2).
    prescient_ids = set(dataset.loc[has_pkd & (source == "prescient"), 'seq_id'].unique())

    common_ids = [t for t in skempi_ids if t in prescient_ids]

    print(len(common_ids))
    print(common_ids)
    return common_ids

def print_paired_dataset_stats(name,
                               property_th_lb=DEFAULT_PAIRING_SETTINGS['property_th_lb'],
                               property_th_ub=DEFAULT_PAIRING_SETTINGS['property_th_ub'],
                               edist_th=DEFAULT_PAIRING_SETTINGS['edist_th'],
                               property_to_match=DEFAULT_PAIRING_SETTINGS['property_to_match'],
                               min_prop=None,
                               settings=None,
                               ):
    """Print summary statistics for one paired (matched) dataset.

    Builds the parquet path for ``name`` from the pairing settings. It then prints:
    the number of pairs, the first four (first_property, second_property) rows,
    and the number of distinct (seqid, heavy, light) antibodies across both
    members of every pair.

    Args:
        name: data-source prefix of the parquet file (e.g. 'skempi'). For
            'skempi' the file name encodes ``property_to_match``; for any other
            source it uses the literal tag ``ED``.
        property_th_lb, property_th_ub, edist_th, property_to_match: pairing
            settings encoded in the file name.
        min_prop: when not None, ``_minprop{min_prop}`` is appended to the
            file name.
        settings: optional mapping (e.g. loaded from YAML). Any of the keys
            ``property_th_lb``, ``property_th_ub``, ``edist_th``,
            ``property_to_match`` and ``property_min`` it contains override the
            corresponding arguments.
    """
    # Previously this function read an undefined global ``settings`` (NameError)
    # while its caller passed ``settings=`` as a keyword (TypeError).
    if settings is not None:
        property_th_lb = settings.get('property_th_lb', property_th_lb)
        property_th_ub = settings.get('property_th_ub', property_th_ub)
        edist_th = settings.get('edist_th', edist_th)
        property_to_match = settings.get('property_to_match', property_to_match)
        min_prop = settings.get('property_min', min_prop)

    suffix = '' if min_prop is None else f'_minprop{min_prop}'
    prop_tag = property_to_match if name == 'skempi' else 'ED'
    datafile = (f'{DEFAULT_DATASET_PATH}/{name}_matched_th{property_th_lb}-{property_th_ub}'
                f'_edth{edist_th}_prop{prop_tag}{suffix}.parquet')

    print(datafile)
    df_paired = pd.read_parquet(datafile)
    print('Unique_pairs:', df_paired.shape[0])
    print(df_paired[:4][['first_property', 'second_property']])

    # Stack both members of each pair, then deduplicate to count distinct antibodies.
    df = pd.DataFrame()
    df['seqid'] = df_paired['first_seqid'].values.tolist() + df_paired['second_seqid'].values.tolist()
    df['heavy'] = df_paired['first_HeavyAA'].values.tolist() + df_paired['second_HeavyAA'].values.tolist()
    df['light'] = df_paired['first_LightAA'].values.tolist() + df_paired['second_LightAA'].values.tolist()
    df = df.drop_duplicates()
    print('Unique_Abs:', df.shape[0])


def _min_edit_distance_to_others(seqs):
    """Return each sequence's smallest edlib edit distance to any *other* sequence.

    edlib's default global alignment yields a symmetric edit distance, so each
    unordered pair is aligned once. Returns [] when fewer than two sequences
    are given, because there is then no other sequence to compare against.
    """
    n = len(seqs)
    if n < 2:
        return []
    best = [float('inf')] * n
    for i in range(n):
        for j in range(i + 1, n):
            dist = edlib.align(seqs[i], seqs[j])['editDistance']
            if dist < best[i]:
                best[i] = dist
            if dist < best[j]:
                best[j] = dist
    return best


def print_original_dataset_stats(name='prescient', ed_threshold=80):
    """Print antigen diversity statistics for one affinity source and plot them.

    Prints the column names and the number of distinct ``ag_id``,
    ``ag_id_strict`` and antigen sequences for rows whose
    ``affinity_datasource`` equals ``name``. It then computes, for each distinct
    antigen sequence, the minimum edit distance to the other sequences. It
    prints the smallest of those values and how many exceed ``ed_threshold``,
    and saves a histogram to ``{name}_antigen_ed_dist.png``.

    Args:
        name: value of ``affinity_datasource`` to select.
        ed_threshold: edit-distance cut-off for the "Over ... ED" count.
    """
    csv_file = 's3://prescient-data-dev/sandbox/vasilaks/sabdab_vs_world/DATA_v5.1/epitome_dataset.csv'
    df = pd.read_csv(csv_file)
    df = df[df['affinity_datasource'] == name]
    print(df.columns)

    unique_ag_ids = df['ag_id'].unique()
    print(len(unique_ag_ids))
    unique_ag_ids_st = df['ag_id_strict'].unique()
    print(len(unique_ag_ids_st))

    # Missing sequences cannot be aligned by edlib, so drop them first.
    unique_ag_seq = list(df['affinity_antigen_sequence'].dropna().unique())
    print(len(unique_ag_seq))

    # Previously each sequence was aligned against the whole array of
    # sequences at once, which does not measure distance to other antigens.
    ed_ag = _min_edit_distance_to_others(unique_ag_seq)
    if not ed_ag:
        print('Need at least two antigen sequences to compare; skipping.')
        return

    min_ed_ag = min(ed_ag)
    print(min_ed_ag)
    n_over = sum(1 for t in ed_ag if t > ed_threshold)
    print(f'Over {ed_threshold} ED: ', n_over)

    ax = sns.histplot(ed_ag)
    ax.set_title(f'Dataset: {name}')
    ax.set_xlabel('Min. Edit Distance to other antigens in dataset')
    plt.savefig(f'{name}_antigen_ed_dist.png')
    plt.close()


    
def find_closest_ag(source_1='prescient', source_2='skempi', target='IL6'):
    """Report the antigen sequences of ``target`` in ``source_1`` and where they recur.

    Prints:
      1. the distinct antigen sequences of rows in ``source_1`` whose ``target``
         column equals ``target``;
      2. how many of those sequences also occur in ``source_2``;
      3. the aliases of accepted seeds whose looked-up antigen sequence is one
         of them.
    """
    # Alternatively: seed_df = cortex.optim._initialization.load_seeds()
    seeds = pd.read_csv("/homefs/home/mahajs17/projects/C1sr/scripts/accepted_seeds.csv")
    seq_by_target = lookup_target_sequence()
    # Unknown targets map to an empty sequence so they never match below.
    seeds["affinity_antigen_sequence"] = seeds.target.apply(
        lambda tgt: seq_by_target.get(tgt, "")).values
    seeds = seeds.rename(columns={"seed_id": "seed_alias"})

    epitome = pd.read_csv(
        's3://prescient-data-dev/sandbox/vasilaks/sabdab_vs_world/DATA_v5.1/epitome_dataset.csv')
    datasource = epitome['affinity_datasource']
    first_src = epitome[datasource == source_1]
    second_src = epitome[datasource == source_2]

    target_ag_seqs = first_src.loc[first_src['target'] == target,
                                   'affinity_antigen_sequence'].unique()
    print(target_ag_seqs)

    shared = second_src.loc[second_src['affinity_antigen_sequence'].isin(target_ag_seqs),
                            'affinity_antigen_sequence']
    print(len(shared.unique()))

    matching_seed_aliases = seeds.loc[seeds['affinity_antigen_sequence'].isin(target_ag_seqs),
                                      'seed_alias']
    print(matching_seed_aliases)



def get_pair_dataset_stats():
    """Print paired-dataset statistics for every supported data source.

    When a command-line argument is given, it is treated as the path to a YAML
    settings file. Otherwise a built-in configuration is used. The settings
    are printed and then forwarded to ``print_paired_dataset_stats`` for the
    'prescient', 'skempi' and 'aalphabio' sources.
    """
    if len(sys.argv) > 1:
        settings = read_yaml(sys.argv[1])
    else:
        settings = {
            'property_th_lb': 0.5,
            'property_th_ub': 1.5,
            'edist_th': 5,
            'property_to_match': 'affinity_pkd',
            'property_min': 6.5,
        }

    print(settings)

    for source in ('prescient', 'skempi', 'aalphabio'):
        print_paired_dataset_stats(source, settings=settings)


def get_closest_match(csv_paired, seed_name='IL6-1409', partition=None):
    """Print edit distances from an accepted seed to its nearest antibodies in a paired dataset.

    Loads the accepted-seeds CSV and the paired parquet file ``csv_paired``,
    optionally restricted to one ``partition``. It then prints:
      * for every CDR in ``CDR_RANGES_AHO``, the minimum edit distance between
        the seed's gap-stripped AHo CDR and that of the first member of each
        pair (one ``"<cdr>-<dist>;"`` entry per CDR);
      * the minimum edit distance of the heavy, light and concatenated Fv
        sequences of both members of every pair to the seed's.

    Args:
        csv_paired: path to the paired parquet file.
        seed_name: value of the seeds' ``seed_id`` column to compare against.
        partition: optional value of the ``partition`` column to keep.

    Raises:
        ValueError: if ``partition`` is not present in the dataset or no seed
            matches ``seed_name``.
    """
    seed_df = pd.read_csv("/homefs/home/mahajs17/projects/C1sr/scripts/accepted_seeds.csv")
    seed_df = seed_df.rename(columns={"seed_id": "seed_alias"})

    df_paired = pd.read_parquet(csv_paired)
    if partition is not None:
        partitions = df_paired['partition'].unique()
        print(partitions)
        # Raise instead of assert so the check survives `python -O`.
        if partition not in partitions:
            raise ValueError(f'partition {partition!r} not found; available: {list(partitions)}')
        df_paired = df_paired[df_paired['partition'] == partition]
    print('dataset size', df_paired.shape[0])

    seed = seed_df[seed_df['seed_alias'] == seed_name]
    print(seed.shape[0], seed.seed_alias)
    if seed.empty:
        raise ValueError(f'seed {seed_name!r} not found among accepted seeds')
    if df_paired.empty:
        print('No antibodies in the dataset to compare against.')
        return

    fv_heavy, fv_light = seed['fv_heavy'].iloc[0], seed['fv_light'].iloc[0]
    fv_comb = fv_heavy + fv_light
    heavy_seqs = df_paired['first_HeavyAA'].tolist() + df_paired['second_HeavyAA'].tolist()
    light_seqs = df_paired['first_LightAA'].tolist() + df_paired['second_LightAA'].tolist()
    comb_seqs = ((df_paired['first_HeavyAA'] + df_paired['first_LightAA']).tolist()
                 + (df_paired['second_HeavyAA'] + df_paired['second_LightAA']).tolist())
    ed_heavy = [edlib.align(s, fv_heavy)['editDistance'] for s in heavy_seqs]
    ed_light = [edlib.align(s, fv_light)['editDistance'] for s in light_seqs]
    ed_comb = [edlib.align(s, fv_comb)['editDistance'] for s in comb_seqs]

    cdr_summary = []
    for cdr, cdr_range in CDR_RANGES_AHO.items():
        # +1 turns the (presumably inclusive) AHo end position into a slice bound.
        r1, r2 = cdr_range[0], cdr_range[1] + 1
        chain_name = 'Heavy' if cdr.startswith('H') else 'Light'
        # .iloc[0] takes the seed's string; previously the light chain sliced
        # the Series by row instead of slicing the AHo string.
        seed_aho = seed[f'fv_{chain_name.lower()}_aho'].iloc[0]
        segment = seed_aho[r1:r2].replace('-', '')
        # NOTE(review): only the first member of each pair is compared here,
        # while the chain-level distances below use both members — confirm.
        cdr_eds = [edlib.align(aho[r1:r2].replace('-', ''), segment)['editDistance']
                   for aho in df_paired[f'first_{chain_name}AHoAA']]
        cdr_summary.append(f"{cdr}-{min(cdr_eds)};")

    print(''.join(cdr_summary))

    print(f'Min heavy chain ED to seed: {min(ed_heavy)}')
    print(f'Min light chain ED to seed: {min(ed_light)}')
    print(f'Min fv chain ED to seed: {min(ed_comb)}')


def get_closest_match_denovo_seeds(csv_paired,
                                   seed_name=None,
                                   seed_csv="s3://prescient-data-dev/sandbox/mahajs17/Propen/denovo/round5_denovo_seeds_allcols.csv",
                                   partition=None):
    """Print edit distances from a de novo seed to its nearest antibodies in a paired dataset.

    Same report as ``get_closest_match``, but the seed is taken from
    ``seed_csv``. There the seed is identified by its ``id`` column, and the
    antigen comes from ``seqres_antigen``.

    Args:
        csv_paired: path to the paired parquet file.
        seed_name: value of the seed CSV's ``id`` column to compare against.
        seed_csv: CSV with the de novo seeds.
        partition: optional value of the ``partition`` column to keep.

    Raises:
        ValueError: if ``partition`` is not present in the dataset or no seed
            matches ``seed_name`` (including the default ``None``).
    """
    seed_df = pd.read_csv(seed_csv)
    print(seed_df.columns)
    print(seed_df["seqres_antigen"])
    seed_df["affinity_antigen_sequence"] = seed_df["seqres_antigen"]
    seed_df = seed_df.rename(columns={"id": "seed_alias"})

    df_paired = pd.read_parquet(csv_paired)
    if partition is not None:
        partitions = df_paired['partition'].unique()
        print(partitions)
        # Raise instead of assert so the check survives `python -O`.
        if partition not in partitions:
            raise ValueError(f'partition {partition!r} not found; available: {list(partitions)}')
        df_paired = df_paired[df_paired['partition'] == partition]
    print('dataset size', df_paired.shape[0])

    seed = seed_df[seed_df['seed_alias'] == seed_name]
    print(seed.shape[0], seed.seed_alias)
    if seed.empty:
        raise ValueError(f'seed {seed_name!r} not found in {seed_csv}')
    if df_paired.empty:
        print('No antibodies in the dataset to compare against.')
        return

    fv_heavy, fv_light = seed['fv_heavy'].iloc[0], seed['fv_light'].iloc[0]
    fv_comb = fv_heavy + fv_light
    heavy_seqs = df_paired['first_HeavyAA'].tolist() + df_paired['second_HeavyAA'].tolist()
    light_seqs = df_paired['first_LightAA'].tolist() + df_paired['second_LightAA'].tolist()
    comb_seqs = ((df_paired['first_HeavyAA'] + df_paired['first_LightAA']).tolist()
                 + (df_paired['second_HeavyAA'] + df_paired['second_LightAA']).tolist())
    ed_heavy = [edlib.align(s, fv_heavy)['editDistance'] for s in heavy_seqs]
    ed_light = [edlib.align(s, fv_light)['editDistance'] for s in light_seqs]
    ed_comb = [edlib.align(s, fv_comb)['editDistance'] for s in comb_seqs]

    cdr_summary = []
    for cdr, cdr_range in CDR_RANGES_AHO.items():
        # +1 turns the (presumably inclusive) AHo end position into a slice bound.
        r1, r2 = cdr_range[0], cdr_range[1] + 1
        chain_name = 'Heavy' if cdr.startswith('H') else 'Light'
        # .iloc[0] takes the seed's string; previously the light chain sliced
        # the Series by row instead of slicing the AHo string.
        seed_aho = seed[f'fv_{chain_name.lower()}_aho'].iloc[0]
        segment = seed_aho[r1:r2].replace('-', '')
        # NOTE(review): only the first member of each pair is compared here,
        # while the chain-level distances below use both members — confirm.
        cdr_eds = [edlib.align(aho[r1:r2].replace('-', ''), segment)['editDistance']
                   for aho in df_paired[f'first_{chain_name}AHoAA']]
        cdr_summary.append(f"{cdr}-{min(cdr_eds)};")

    print(''.join(cdr_summary))

    print(f'Min heavy chain ED to seed: {min(ed_heavy)}')
    print(f'Min light chain ED to seed: {min(ed_light)}')
    print(f'Min fv chain ED to seed: {min(ed_comb)}')


