import os
import pandas as pd
import numpy as np
from edlib import align
import glob


def edit_distance(seq1: str, seq2: str) -> int:
    '''
    Return the edit distance between ``seq1`` and ``seq2``.

    Thin wrapper around ``edlib.align`` called with its default settings;
    only the ``editDistance`` entry of the alignment result is returned.
    '''
    alignment_result = align(seq1, seq2)
    distance = alignment_result['editDistance']
    return distance


def _count_binders(designs_df: pd.DataFrame) -> tuple:
    '''
    Count binders and "improved" binders in one design table.

    A binder is a row with ``is_binder_class_probs == 1``. An improved binder is a binder
    whose ``affinity_mean`` is >= its ``affinity_pkd_seed``; rows with
    ``affinity_pkd_seed == -1`` are never counted as improved (presumably -1 marks an
    unknown seed affinity — confirm).

    Returns:
        ``(num_binders, num_improv_binders)``.
    '''
    binders = designs_df[designs_df['is_binder_class_probs'] == 1]
    if binders.empty:
        # The affinity columns are only consulted when at least one binder exists.
        return 0, 0
    improved_binders = binders[(binders['affinity_mean'] >= binders['affinity_pkd_seed'])
                               & (binders['affinity_pkd_seed'] != -1)]
    return binders.shape[0], improved_binders.shape[0]


def _filtered_outpath(outpath: str, mean_edit_distance_range: tuple) -> str:
    '''
    Build the filtered-summary path ``<stem>_filtered_<low>-<high><ext>`` from ``outpath``.

    Only the final extension is swapped, so a ``.csv`` elsewhere in the path is left alone.
    An ``outpath`` without an extension gets ``.csv``; it can therefore never collide with
    ``outpath`` itself.
    '''
    root, ext = os.path.splitext(outpath)
    low, high = mean_edit_distance_range[0], mean_edit_distance_range[1]
    return f'{root}_filtered_{low}-{high}{ext or ".csv"}'


def select_run_by_edit_distance(csvs: list, outpath: str, mean_edit_distance_range: tuple = (5, 7)) -> "tuple[pd.DataFrame, pd.DataFrame]":
    '''
    Summarise the edit distances of several design CSVs and write the summaries to disk.

    Each CSV that has an ``is_binder_class_probs`` column contributes one summary row.
    The row holds:
      - the seed alias (first ``seed_id``);
      - mean/std of ``edit_distance``;
      - mean/std of the CDR edit distance (sum of the H1-H3 and L1-L3 columns);
      - mean/std of the framework edit distance (total minus CDRs);
      - the number of designs, binders and improved binders;
      - ``iterations``/``temperature``/``step_size``/``steps``/``max_steps`` (first value,
        or None if the column is absent);
      - the CSV path.
    CSVs lacking the binder column, or containing no rows, are skipped.

    Args:
        csvs: Paths of the design CSVs. Each path is read exactly as given.
        outpath: Where the full summary CSV is written. A second CSV containing only runs
            whose mean edit distance lies in ``mean_edit_distance_range`` (inclusive) is
            written next to it as ``<stem>_filtered_<low>-<high><ext>``.
        mean_edit_distance_range: Inclusive ``(low, high)`` bounds on the mean edit distance.

    Returns:
        ``(results_df, results_filt_df)``: the full summary and its filtered subset.
    '''
    results = {'seed_alias': [], 'mean_edit_distance': [], 'std_edit_distance': [],
               'mean_edit_distance_cdrs': [], 'std_edit_distance_cdrs': [],
               'mean_edit_distance_fw': [], 'std_edit_distance_fw': [],
               'num_unique_sequences': [], 'num_binders': [], 'num_improv_binders': [],
               'iterations': [], 'temperature': [], 'step_size': [], 'steps': [],
               'max_steps': [], 'path': []}
    for csv in csvs:
        # Read the path as given: callers pass full paths (from glob). Joining it onto a
        # `path` global only worked when run as a script and raised NameError on import.
        designs_df = pd.read_csv(csv)
        # Skip tables without binder predictions, and header-only tables (no seed_id row).
        if 'is_binder_class_probs' not in designs_df or designs_df.empty:
            continue
        seed_alias = designs_df['seed_id'].iloc[0]

        total = designs_df['edit_distance']
        # CDR edit distance = sum of the six per-CDR columns; framework = total - CDRs.
        cdrs = (designs_df['edit_distance_H1'] + designs_df['edit_distance_H2']
                + designs_df['edit_distance_H3'] + designs_df['edit_distance_L1']
                + designs_df['edit_distance_L2'] + designs_df['edit_distance_L3'])
        fw = total - cdrs

        mean_edit_distance = total.mean()
        std_edit_distance = total.std()
        num_designs = designs_df.shape[0]
        num_binders, num_improv_binders = _count_binders(designs_df)

        print(f'{csv} - Mean edit distance: {mean_edit_distance}, Std edit distance: {std_edit_distance}, Number of designs: {num_designs}')
        print(f'{csv} - Seed alias: {seed_alias}')
        results['seed_alias'].append(seed_alias)
        results['mean_edit_distance'].append(mean_edit_distance)
        results['std_edit_distance'].append(std_edit_distance)
        results['mean_edit_distance_cdrs'].append(cdrs.mean())
        results['std_edit_distance_cdrs'].append(cdrs.std())
        results['mean_edit_distance_fw'].append(fw.mean())
        results['std_edit_distance_fw'].append(fw.std())
        # NOTE(review): this is the row count, not the number of distinct sequences the
        # column name suggests; dedupe on the sequence column once its name is confirmed.
        results['num_unique_sequences'].append(num_designs)
        results['num_binders'].append(num_binders)
        results['num_improv_binders'].append(num_improv_binders)

        # Run parameters are optional; record the first value or None when absent.
        for col in ['iterations', 'temperature', 'step_size', 'steps', 'max_steps']:
            results[col].append(designs_df[col].iloc[0] if col in designs_df.columns else None)

        results['path'].append(csv)

    results_df = pd.DataFrame(results)
    results_df.to_csv(outpath, index=False)

    # Keep runs whose mean edit distance falls inside the (inclusive) range. Return a copy
    # so callers can add columns without pandas' SettingWithCopy warning.
    low, high = mean_edit_distance_range[0], mean_edit_distance_range[1]
    results_filt_df = results_df[results_df['mean_edit_distance'].between(low, high)].copy()
    results_filt_df.to_csv(_filtered_outpath(outpath, mean_edit_distance_range), index=False)
    return results_df, results_filt_df


if __name__ == '__main__':
    # Summarise every <model>/<seed>/results_cortex_inferencer21 run found under `path`.
    # Each seed directory gets its own summary CSV. The summaries are then combined into
    # all-model/all-seed tables written at `path`.
    #/data/mahajs17/Propen_sampling/ablations/${model_type}_models/
    # for model_type in ['cnn', 'cnn_skempiED5', 'cnn_skempiED7']:
    #     for seed
    path = "/data/mahajs17/Propen_sampling/ablations/hoo_models/"
    # Sort directory/glob listings so the combined tables have a deterministic row order.
    subdirs = sorted(t for t in os.listdir(path) if os.path.isdir(f"{path}/{t}"))
    dfs, dfs_filt = [], []
    for model in subdirs:
        subdir_path = f"{path}/{model}"
        subdir_seeds = sorted(t for t in os.listdir(subdir_path) if os.path.isdir(f"{subdir_path}/{t}"))
        for seed_dir in subdir_seeds:
            eval_samples = sorted(glob.glob(f"{subdir_path}/{seed_dir}/results_cortex_inferencer21/*/*_affinity.csv"))
            outpath = f"{subdir_path}/{seed_dir}/results_cortex_inferencer21/designs_edit_distance.csv"
            if eval_samples:
                df, df_filt = select_run_by_edit_distance(eval_samples, outpath=outpath)
                df['seed_tag'] = seed_dir
                df['model_tag'] = model
                df_filt['seed_tag'] = seed_dir
                df_filt['model_tag'] = model
                dfs.append(df)
                dfs_filt.append(df_filt)
    # csvs = glob.glob("/data/mahajs17/Propen_sampling/denovo/sabdab_7243_more/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/results_cortex_inferencer21/*/*_affinity.csv")
    # #["/data/mahajs17/Propen_sampling/denovo/sabdab_7243_more/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/results_cortex_inferencer21/samples_iter7_t1_N5000/samples_iter7_t1_N5000_affinity.csv",
    # #     "/data/mahajs17/Propen_sampling/denovo/sabdab_7243_more/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/results_cortex_inference/samples_iter10_t1_N5000/samples_iter10_t1_N5000_affinity.csv"]:
    # outpath = "/data/mahajs17/Propen_sampling/denovo/sabdab_7243_more/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/results_cortex_inferencer21/designs_edit_distance.csv"
    # df, df_filt = select_run_by_edit_distance(csvs, outpath=outpath)
    # df['seed_tag'] = 'denovo_sabdab_7243'
    # df['model_tag'] = 'gt_skempiED5_trastuzumab_ood'
    # df_filt['seed_tag'] = 'denovo_sabdab_7243'
    # df_filt['model_tag'] = 'gt'
    # dfs.append(df)
    # dfs_filt.append(df_filt)

    # csvs = glob.glob("/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/results_cortex_inferencer21/*/*_N5000_affinity.csv")
    # #"/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/results_cortex_inferencer21/samples_iter7_t1_N5000/samples_iter7_t1_N5000_affinity.csv"]:
    # outpath = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/results_cortex_inferencer21/designs_edit_distance.csv"
    # df, df_filt = select_run_by_edit_distance(csvs, outpath=outpath)
    # df['seed_tag'] = 'egfrseed32'
    # df['model_tag'] = 'gt_skempiED5'
    # df_filt['seed_tag'] = 'egfrseed32'
    # df_filt['model_tag'] = 'gt'
    # dfs.append(df)
    # dfs_filt.append(df_filt)

    # csvs = glob.glob("/data/mahajs17/Propen_sampling/IL61409/hoo_models/PropEn_SkempiED5pkd6.5_IL61409_gt/logits/results_cortex_inferencer21/*/*_N5000_affinity.csv")
    # #["/data/mahajs17/Propen_sampling/IL61409/hoo_models/PropEn_SkempiED5pkd6.5_IL61409_gt/logits/results_cortex_inferencer21/samples_iter5_t1_N5000/samples_iter5_t1_N5000_affinity.csv"]:
    # outpath = "/data/mahajs17/Propen_sampling/IL61409/hoo_models/PropEn_SkempiED5pkd6.5_IL61409_gt/logits/results_cortex_inferencer21/designs_edit_distance.csv"
    # df, df_filt = select_run_by_edit_distance(csvs, outpath=outpath)
    # df['seed_tag'] = 'il6seed1409'
    # df['model_tag'] = 'gt_skempiED5'
    # df_filt['seed_tag'] = 'il6seed1409'
    # df_filt['model_tag'] = 'gt'
    # dfs.append(df)
    # dfs_filt.append(df_filt)

    # csvs = glob.glob("/data/mahajs17/Propen_sampling/OSMN013/hoo_models/PropEn_SkempiPrescientED5pkd6.5_OsmN013_gt/logits/results_cortex_inferencer21/*/*_N5000_affinity.csv")
    # #["/data/mahajs17/Propen_sampling/OSMN013/hoo_models/PropEn_SkempiPrescientED5pkd6.5_OsmN013_gt/logits/results_cortex_inferencer21/samples_iter5_t1_N5000/samples_iter5_t1_N5000_affinity.csv",
    # #            "/data/mahajs17/Propen_sampling/OSMN013/hoo_models/PropEn_SkempiPrescientED5pkd6.5_OsmN013_gt/logits/results_cortex_inferencer21/samples_iter5_t0._N5000/samples_iter5_t1_N5000_affinity.csv"]:
    # outpath = "/data/mahajs17/Propen_sampling/OSMN013/hoo_models/PropEn_SkempiPrescientED5pkd6.5_OsmN013_gt/logits/results_cortex_inferencer21/designs_edit_distance.csv"
    # df, df_filt = select_run_by_edit_distance(csvs, outpath=outpath)
    # df['seed_tag'] = 'OSM-N013'
    # df['model_tag'] = 'gt_skempiED5'
    # df_filt['seed_tag'] = 'OSM-N013'
    # df_filt['model_tag'] = 'gt'
    # dfs.append(df)
    # dfs_filt.append(df_filt)

    # pd.concat([]) raises an unhelpful "No objects to concatenate"; fail with a clear reason.
    if not dfs:
        raise SystemExit(f"No *_affinity.csv files found under {path}; nothing to combine.")

    df = pd.concat(dfs)
    df.to_csv(f"{path}/designs_edit_distance_allmodels_all_seeds.csv")

    df_filt = pd.concat(dfs_filt)
    df_filt.to_csv(f"{path}/designs_edit_distance_allmodels_all_seeds_filtered.csv")
