import torch
import pandas as pd
from cortex.model import predict_ab_properties
from cortex.utils import s3_glob
from prescient.io import load_parquet
from cortex.model.tree.ab_model import predict_ab_properties
import os
#from cortex.utils import add_anarci_metadata
#from cortex.tokenization.ab_tokenizer import tokenize_igg_ag_complex_df
import edlib
from cortex.transforms.functional import assemble_igg_ag_complex

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from logos import plot_logos

from prescient.constants import CDR_RANGES_AHO
from prescient.transforms.functional import anarci_numbering

# Checkpoint-round selection. The commented-out pair below is an alternative
# configuration (round_18 checkpoints with an empty directory tag).
#cortex_str=''
#s3_dir = "s3://prescient-data-dev/oracles/round_18/cortex"
# Tag embedded in the inference/results directory names built by get_predictions().
cortex_str='r21'
# S3 location handed to s3_glob() to enumerate the checkpoints.
s3_dir = "s3://prescient-data-dev/oracles/round_21/cortex"
# NOTE: this call runs at import time. run_inference() iterates the result and
# appends ".yaml" / ".pt" to every entry to build the config and weight paths.
checkpoint_list = s3_glob(s3_dir)
# Not referenced anywhere else in the visible code.
tokenization_kwargs = {}
# Value written to the 'source' column by prepare_dataset().
source='mahajs17_structpropen'


def add_percentage_annotations(data, ax):
    """Label every bar of ``ax`` with its share of ``len(data)`` and its count.

    Each patch in ``ax.patches`` is annotated, centred on top of the bar, with
    the text ``"<pct>%; N=<count>"`` where ``pct = 100 * height / len(data)``.

    Args:
        data: Sized collection of the plotted observations; only its length
            is used (as the denominator of the percentage).
        ax: Axes-like object exposing ``patches`` and ``annotate``.

    Returns:
        None. ``ax`` is annotated in place.
    """
    import math  # local import so the function does not rely on file-level imports

    total = len(data)
    if total == 0:
        # A percentage of nothing is undefined; avoid ZeroDivisionError.
        return

    for bar in ax.patches:
        height = bar.get_height()
        if not math.isfinite(height):
            # int(nan) / int(inf) would raise, so bars without a finite
            # height are left unlabelled.
            continue

        share = 100 * height / total
        anchor = (bar.get_x() + bar.get_width() / 2., height)
        ax.annotate(f'{share:.1f}%; N={int(height)}',
                    anchor,
                    ha='center', va='bottom')


def prepare_dataset(csvfile, seed_file, method='StructPropEnHOO', write_to_csv=False):
    """Load sampled designs and annotate them with metadata of their seed.

    The designs are read from ``csvfile`` (parquet when the path ends with
    ``.parquet``, CSV otherwise) and the seed from the parquet ``seed_file``.
    Only the FIRST row of the seed table is used: its sequences must equal
    the ``fv_heavy_seed`` / ``fv_light_seed`` of the first design row, and its
    ``affinity_antigen_sequence``, ``target`` and seed identifier are
    broadcast to every design. Designs whose ``fv_heavy`` or ``fv_light``
    contains an ``X`` are dropped.

    Args:
        csvfile: Path of the designs table (``.csv`` or ``.parquet``).
        seed_file: Path of the seed parquet file.
        method: Value stored in the ``method`` column.
        write_to_csv: When true, also write the result next to ``csvfile``
            with the suffix ``_litlsub.csv``.

    Returns:
        The annotated and filtered ``pandas.DataFrame``.

    Raises:
        AssertionError: If the seed sequences of the two inputs disagree.
        KeyError: If a required column is missing from either input.
    """
    if csvfile.endswith('.parquet'):
        ext = '.parquet'
        df = pd.read_parquet(csvfile)
    else:
        ext = '.csv'
        df = pd.read_csv(csvfile)
    print(df.columns)
    print(df.shape)

    df_seed = pd.read_parquet(seed_file)
    print(df_seed)
    print(df_seed.columns)

    seed_heavy = df_seed['fv_heavy'].iloc[0]
    seed_light = df_seed['fv_light'].iloc[0]
    print(seed_light)
    print(df['fv_light_seed'].iloc[0])
    assert seed_light == df['fv_light_seed'].iloc[0], \
        'fv_light of the seed file differs from fv_light_seed of the designs'
    assert seed_heavy == df['fv_heavy_seed'].iloc[0], \
        'fv_heavy of the seed file differs from fv_heavy_seed of the designs'

    df['affinity_antigen_sequence'] = df_seed['affinity_antigen_sequence'].iloc[0]
    df['target'] = df_seed['target'].iloc[0]

    # Seed files carry either `seed_id` or `seed_alias`. The identifier is
    # looked up on the untouched seed table so that alias-only files work;
    # `seed_id` keeps precedence when both columns are present.
    if 'seed_id' in df_seed:
        df['seed_id'] = df_seed['seed_id'].iloc[0]
    elif 'seed_alias' in df_seed:
        df['seed_id'] = df_seed['seed_alias'].iloc[0]

    df['method'] = method
    df['source'] = source  # module-level constant

    # Drop designs containing the placeholder residue 'X' in either chain.
    print(df.shape[0])
    df = df[~df['fv_heavy'].str.contains('X')]
    df = df[~df['fv_light'].str.contains('X')]
    print(df.shape[0])

    if write_to_csv:
        outfile = csvfile.split(ext)[0] + '_litlsub.csv'
        df.to_csv(outfile, index=False)
    return df


def run_inference(df, outdir):
    """Score ``df`` with every checkpoint of ``checkpoint_list`` and cache tensors.

    Rows lacking ``fv_heavy_aho`` or ``fv_light_aho`` are discarded, the rest
    is passed through ``assemble_igg_ag_complex`` and then to
    ``predict_ab_properties`` once per checkpoint. The per-checkpoint outputs
    are collected per key. Keys whose outputs are tensors of identical shape
    (``full_seq_logits`` excepted) are concatenated along dim 0, stored back
    into the result and saved as ``<outdir>/<key>.pt``.

    Args:
        df: Designs table with ``fv_heavy``, ``fv_heavy_aho`` and
            ``fv_light_aho`` columns.
        outdir: Directory receiving the ``.pt`` files (created if missing).

    Returns:
        Dict mapping each prediction key to the concatenated tensor, or to the
        raw list of per-checkpoint outputs for keys that were not merged.
    """
    print(df.shape)
    print(df.columns)
    print(df['fv_heavy'])

    # Keep only the rows for which both AHO alignments are available.
    has_both_chains = df['fv_heavy_aho'].notna() & df['fv_light_aho'].notna()
    complexes = assemble_igg_ag_complex(
        df[has_both_chains],
        include_anarci_metadata=False,
        randomize_metadata_order=False,
        randomize_chain_order=False,
        use_custom_chain_tokens=False,
        use_custom_format_tokens=False,
    )

    res = {}
    for ckpt in checkpoint_list:
        ckpt_output = predict_ab_properties(
            data=complexes,
            cfg_fpath=ckpt + ".yaml",
            weight_fpath=ckpt + ".pt",
            format_for_leaderboard=False,
            return_as_df=False,
        )
        for name in ckpt_output:
            if name not in res:
                res[name] = []
            res[name].append(ckpt_output[name])

    os.makedirs(outdir, exist_ok=True)
    for name in res:
        per_checkpoint = res[name]
        reference = per_checkpoint[0]
        # Only tensor outputs are merged; the full sequence logits are skipped.
        if not torch.is_tensor(reference) or name == 'full_seq_logits':
            continue
        # Outputs of differing shape cannot be concatenated; leave them as a list.
        if any(t.shape != reference.shape for t in per_checkpoint):
            continue
        merged = torch.cat(per_checkpoint, dim=0)
        res[name] = merged
        print(name, merged.shape, merged.dtype)
        torch.save(merged, f'{outdir}/{name}.pt')

    return res


def get_affinity_ami(seed_file):
    """Print how far the ``anti-MET.71G2`` batch is from the seed sequences.

    Loads the ``gred_affinity`` table, keeps rows that have both AHO
    alignments, derives gap-free ``fv_heavy`` / ``fv_light`` by removing
    ``-`` from the AHO strings, selects the rows whose ``batch_alias`` is
    ``anti-MET.71G2`` and prints, per chain, the ``edlib.align`` results of
    those sequences against the first row of ``seed_file``.

    The alignments use the gap-free sequences: the seed's ``fv_heavy`` /
    ``fv_light`` are compared as-is, so aligning the gapped AHO strings
    against them would count every ``-`` as an edit.
    NOTE(review): assumes the seed file's ``fv_*`` columns are gap-free,
    like the ``fv_*`` columns derived here -- confirm.

    Args:
        seed_file: Path of the seed parquet file.

    Returns:
        None. Results are printed only.
    """
    df = pd.read_parquet("s3://prescient-data-dev/designdb/lake/products/gred_affinity")
    # .copy() so that the column assignments below act on an owned frame.
    df = df[df['fv_heavy_aho'].notna() & df['fv_light_aho'].notna()].copy()
    for chain in ['heavy', 'light']:
        df[f'fv_{chain}'] = df[f'fv_{chain}_aho'].str.replace('-', '', regex=False)

    amivantamab = df[df["batch_alias"] == "anti-MET.71G2"]
    print(amivantamab)
    print(amivantamab.columns)

    df_seed = pd.read_parquet(seed_file)
    seed_seqs = {
        'heavy': df_seed['fv_heavy'].iloc[0],
        'light': df_seed['fv_light'].iloc[0],
    }

    for chain in ['heavy', 'light']:
        edit_distance = [edlib.align(seq, seed_seqs[chain])
                         for seq in amivantamab[f'fv_{chain}'].values.tolist()]
        print(edit_distance)
    

def get_affinity_seed(seed_file, round_number=10.0):
    """Look up the measured pKd of the seed antibody in one design round.

    The first row of ``seed_file`` provides the seed ``fv_heavy`` /
    ``fv_light``. The ``all_rounds`` table of the requested round is searched
    for rows with exactly these two sequences and a non-missing
    ``affinity_pkd``.

    Args:
        seed_file: Path of the seed parquet file.
        round_number: Round partition to read. It is formatted verbatim into
            the path (``round_number=10.0`` for the default), so pass the
            value exactly as the partition is named.

    Returns:
        The first distinct ``affinity_pkd`` of the matching rows, or ``-1.0``
        when no matching row with an affinity exists.
    """
    df_seed = pd.read_parquet(seed_file)
    heavy_seq = df_seed['fv_heavy'].iloc[0]
    light_seq = df_seed['fv_light'].iloc[0]

    df = pd.read_parquet(
        f's3://prescient-data-dev/designdb/lake/products/all_rounds/round_number={round_number}/'
    )
    same_antibody = (df['fv_heavy'] == heavy_seq) & (df['fv_light'] == light_seq)
    measured = df[same_antibody & df['affinity_pkd'].notna()]
    if measured.empty:
        # Sentinel checked by callers with `pkd_seed != -1.0`.
        return -1.0

    print(measured['affinity_pkd'])
    distinct_pkd = measured['affinity_pkd'].unique()
    print(distinct_pkd)
    # Several distinct measurements may exist; the first one is reported.
    return distinct_pkd[0]


# Value returned by get_affinity_seed() when the seed has no measured affinity.
_NO_SEED_PKD = -1.0


def _pairwise_edit_distance(queries, targets):
    """Return edlib's edit distance of every (query, target) pair, in order."""
    return [edlib.align(query, target)['editDistance']
            for query, target in zip(queries, targets)]


def _prediction_for(res, key, n_rows):
    """Return ``res[key]`` if it is a tensor with one entry per row, else None."""
    value = res.get(key)
    if torch.is_tensor(value) and value.dim() > 0 and value.shape[0] == n_rows:
        return value
    print(f'{key}: no prediction with {n_rows} rows available; column not added')
    return None


def _mark_seed_affinity(ax, pkd_seed):
    """Draw a dashed vertical line at the seed pKd unless it is the sentinel."""
    if pkd_seed != _NO_SEED_PKD:
        ax.axvline(x=pkd_seed, ls='dashed', c='black')


def _save_current_figure(path):
    """Apply tight_layout, save the current figure (300 dpi, opaque), close it."""
    plt.tight_layout()
    plt.savefig(path, dpi=300, transparent=False)
    plt.close()


def get_predictions(csvfile, seed_file, method='StructPropEn', title='',
                    single_chain_only=False,
                    ext='.csv'):
    """Predict properties for the designs in ``csvfile`` and write reports.

    Steps:
      1. Load and annotate the designs with ``prepare_dataset``.
      2. Obtain predictions: reuse the ``.pt`` tensors cached in
         ``<dir>/inference_cortex<tag>/<name>`` when all required ones exist,
         otherwise call ``run_inference`` (which writes that cache).
      3. Average the predictions over dim 0 and attach them as columns.
      4. Add edit distances to the seed (per chain, per AHO CDR range, and the
         CDR / framework split derived from them).
      5. Write CSV tables and PNG plots to
         ``<dir>/results_cortex_inference<tag>/<name>``.

    ``<dir>`` is the directory of ``csvfile``, ``<tag>`` the module-level
    ``cortex_str`` and ``<name>`` the file name with ``ext`` removed (plus
    ``_heavy`` when ``single_chain_only`` is set).

    Args:
        csvfile: Path of the designs table.
        seed_file: Path of the seed parquet file.
        method: Stored in the ``method`` column by ``prepare_dataset``.
        title: Title put on every plot.
        single_chain_only: Treat the light chain as unchanged: when inference
            is run, ``fv_light`` / ``fv_light_aho`` are replaced by the seed's,
            and the light CDR / framework distances are reported as 0.
        ext: Extension stripped from the file name to build ``<name>``.

    Returns:
        None. Returns early (writing nothing) when ``fv_heavy_aho`` is missing
        or when inference is needed but no usable design row is left.
    """
    basename = os.path.basename(csvfile).replace(ext, '')
    parent_dir = os.path.dirname(csvfile)
    run_name = f'{basename}_heavy' if single_chain_only else basename
    outdir = f'{parent_dir}/inference_cortex{cortex_str}/{run_name}'
    outdir_results = f'{parent_dir}/results_cortex_inference{cortex_str}/{run_name}'

    # The cache is only trusted when every one of these tensors is present.
    required_keys = ['affinity_mean', 'is_binder_class_probs', 'affinity_st_dev',
                     'is_binder_top_1', 'is_expressor_class_probs']
    results_found = all(os.path.exists(f'{outdir}/{key}.pt') for key in required_keys)

    df = prepare_dataset(csvfile, seed_file, method=method, write_to_csv=False)
    print(df.columns, df.shape)
    if 'fv_heavy_aho' not in df.columns:
        print('fv_heavy_aho not found')
        return
    # prepare_dataset returns a filtered slice; own it before adding columns.
    df = df.copy()

    if not results_found:
        if df.empty:
            return
        if single_chain_only:
            df['fv_light'] = df['fv_light_seed']
            df['fv_light_aho'] = df['fv_light_seed_aho']
            df['edit_distance_light'] = 0

    # run_inference() drops rows without AHO alignments, so the same rows are
    # dropped here; otherwise the row counts differ and no prediction can be
    # attached below. With single_chain_only the light chain used for
    # inference is the seed's, hence only the heavy chain is checked.
    aho_cols = ['fv_heavy_aho'] if single_chain_only else ['fv_heavy_aho', 'fv_light_aho']
    df = df.dropna(subset=[col for col in aho_cols if col in df]).copy()

    if results_found:
        res = {}
        for fname in os.listdir(outdir):
            # Ignore anything in the cache directory that is not a tensor file.
            if not fname.endswith('.pt'):
                continue
            res[os.path.basename(fname).replace('.pt', '')] = torch.load(f'{outdir}/{fname}')
    else:
        if df.empty:
            return
        res = run_inference(df, outdir)

    # Collapse dim 0 of the predictions by averaging. Entries that
    # run_inference() could not merge stay lists and are skipped.
    for key in res:
        if not torch.is_tensor(res[key]):
            continue
        if key in ['affinity_mean', 'affinity_st_dev']:
            print(key, res[key].shape)
            res[key] = res[key].mean(0)
        elif key in ['is_binder_class_probs', 'is_expressor_class_probs',
                     'is_binder_top_1', 'is_expressor_top_1']:
            res[key] = res[key].float().mean(0)

    n_rows = df.shape[0]
    for key in ['is_binder_class_probs', 'is_expressor_class_probs']:
        probs = _prediction_for(res, key, n_rows)
        if probs is None:
            continue
        # `<key>` holds the arg-max class, `<key>_class1` the score of class 1.
        df[key] = probs.argmax(-1).numpy().tolist()
        df[key + '_class1'] = probs[:, 1].numpy().tolist()
        df[[key, key + '_class1']] = df[[key, key + '_class1']].apply(pd.to_numeric)
        print(df[key])

    for key in ['is_binder_top_1', 'is_expressor_top_1']:
        top_1 = _prediction_for(res, key, n_rows)
        if top_1 is None:
            continue
        df[key] = top_1.numpy().tolist()
        print(df[key])

    for key in ['affinity_mean', 'affinity_st_dev']:
        affinity = _prediction_for(res, key, n_rows)
        if affinity is None:
            continue
        df[key] = affinity.squeeze(-1).numpy().tolist()
        print(df[key])

    os.makedirs(outdir_results, exist_ok=True)
    pkd_seed = get_affinity_seed(seed_file)
    df['affinity_pkd_seed'] = pkd_seed

    # `fv_heavy_aho` is guaranteed to exist here (see the early return above),
    # so only the seed alignments may still have to be computed.
    if 'fv_light_seed_aho' not in df:
        df['fv_heavy_seed_aho'] = anarci_numbering(df['fv_heavy_seed'])
        df['fv_light_seed_aho'] = anarci_numbering(df['fv_light_seed'])

    if 'edit_distance_heavy' not in df:
        df['edit_distance_heavy'] = _pairwise_edit_distance(df['fv_heavy'],
                                                            df['fv_heavy_seed'])
        df['edit_distance_light'] = _pairwise_edit_distance(df['fv_light'],
                                                            df['fv_light_seed'])
        print(df.columns)

        # Per-CDR distances: slice the AHO alignment, drop gaps, then align.
        for cdr, cdr_range in CDR_RANGES_AHO.items():
            chain = 'heavy' if cdr.startswith('H') else 'light'
            print(cdr, cdr_range, chain)
            r1, r2 = cdr_range[0], cdr_range[1]+1
            print(df[f'fv_{chain}_aho'])
            print(df[f'fv_{chain}_seed_aho'])
            try:
                df[f'edit_distance_{cdr}'] = _pairwise_edit_distance(
                    [seq[r1:r2].replace('-', '') for seq in df[f'fv_{chain}_aho']],
                    [seq[r1:r2].replace('-', '') for seq in df[f'fv_{chain}_seed_aho']],
                )
            except Exception as err:
                # Best effort: a CDR that cannot be aligned is left out, but
                # the reason is reported instead of being swallowed.
                print(f'edit_distance_{cdr} skipped: {err!r}')
                continue

    df['edit_distance'] = df['edit_distance_heavy'] + df['edit_distance_light']

    # CDR totals need all four per-CDR columns of the chain.
    if all(f'edit_distance_H{i}' in df for i in range(1, 5)):
        df['edit_distance_heavy_cdrs'] = (df['edit_distance_H1']
                                          + df['edit_distance_H2']
                                          + df['edit_distance_H3']
                                          + df['edit_distance_H4'])
    if 'edit_distance_heavy_cdrs' in df:
        df['edit_distance_heavy_fws'] = (df['edit_distance_heavy']
                                         - df['edit_distance_heavy_cdrs'])

    if all(f'edit_distance_L{i}' in df for i in range(1, 5)):
        df['edit_distance_light_cdrs'] = (df['edit_distance_L1']
                                          + df['edit_distance_L2']
                                          + df['edit_distance_L3']
                                          + df['edit_distance_L4'])
    if 'edit_distance_light_cdrs' in df:
        df['edit_distance_light_fws'] = (df['edit_distance_light']
                                         - df['edit_distance_light_cdrs'])

    if single_chain_only:
        df['edit_distance_light_cdrs'] = 0
        df['edit_distance_light_fws'] = 0

    # ------------------------------------------------------------- tables
    stat_cols = ["edit_distance_heavy", "edit_distance_light", "edit_distance"]
    df.to_csv(f'{outdir_results}/{basename}_affinity.csv', index=False)
    df[stat_cols].describe().to_csv(f'{outdir_results}/{basename}_stats.csv')

    binders = df[df['is_binder_class_probs']==1]
    binders.to_csv(f'{outdir_results}/{basename}_binders.csv', index=False)
    binders[stat_cols].describe().to_csv(f'{outdir_results}/{basename}_binders_stats.csv')

    # "Improved" = predicted binder whose predicted affinity reaches the seed's.
    improved_binders = pd.DataFrame()
    if pkd_seed != _NO_SEED_PKD:
        improved_binders = df[(df['is_binder_class_probs']==1) & (df['affinity_mean']>=pkd_seed)]
        improved_binders.to_csv(f'{outdir_results}/{basename}_improved_binders.csv', index=False)
        improved_binders[stat_cols].describe().to_csv(
            f'{outdir_results}/{basename}_improved_binders_stats.csv')

    # ------------------------------------- logos + edit-distance histograms
    matplotlib.rcdefaults()
    for curdf, label in zip([binders, improved_binders],
                            ['binders', 'improved_binders']):
        if curdf.empty:
            continue

        for chain in ['_heavy', '_light']:
            df_filter = None
            for i, ref_seq in enumerate(curdf[f'fv{chain}_seed_aho'].unique()):
                df_filter = curdf[curdf[f'fv{chain}_seed_aho']==ref_seq]
                sequences = df_filter[f'fv{chain}_aho'].values.tolist()
                # NOTE(review): chain.upper()[0] evaluates to '_' for both
                # '_heavy' and '_light'; confirm plot_logos does not expect
                # 'H' / 'L' here.
                plot_logos(sequences,
                           ref_seq=ref_seq,
                           logo_file_base=f'{outdir_results}/logos_seed{i}{chain}_{label}_{basename}',
                           chain=chain.upper()[0]
                           )
            matplotlib.rcdefaults()
            # As before, the histograms are skipped when the last seed group
            # matched no row.
            if df_filter is None or df_filter.empty:
                continue

            x = f'edit_distance{chain}'
            ax = sns.histplot(curdf, x=x, discrete=True)
            ax.set_title(title)
            _save_current_figure(f'{outdir_results}/histplotED{chain}_{label}_{basename}_{x}.png')

            matplotlib.rcdefaults()
            if len(curdf['is_binder_class_probs'].unique())>1:
                ax = sns.histplot(curdf, x=x, hue='is_binder_class_probs', discrete=True)
                ax.set_title(title)
                _save_current_figure(
                    f'{outdir_results}/histplotEDBinding{chain}_{label}_{basename}_{x}.png')

    matplotlib.rcdefaults()

    # ------------------------------------------- class counts / densities
    for x in ['is_binder_class_probs', 'is_expressor_class_probs']:
        ax = sns.countplot(df, x=x)
        add_percentage_annotations(df[x].values.tolist(), ax)
        ax.set_title(title)
        _save_current_figure(f'{outdir_results}/counts_{basename}_{x}.png')

        for suf in ['', '_heavy', '_light']:
            cat_col = f'edit_distance{suf}_cat'
            df[cat_col] = ['<8' if distance < 8 else '>=8'
                           for distance in df[f'edit_distance{suf}']]

            # The axes returned by seaborn is kept so the title lands on the
            # figure that is actually saved.
            ax = sns.histplot(df, x=x, hue=cat_col, multiple="dodge",
                              stat='density', common_norm=False)
            ax.set_title(title)
            _save_current_figure(f'{outdir_results}/histplot_{basename}_{x}ED{suf}.png')

            df_ed = df[df[cat_col]=='<8']
            if df_ed.empty:
                continue
            ax = sns.histplot(df_ed, x=x, hue=cat_col, stat='density')
            ax.set_title(title)
            _save_current_figure(f'{outdir_results}/histplot_{basename}_{x}ED{suf}_lessthan8.png')

            ax = sns.countplot(df_ed, x=x)
            add_percentage_annotations(df_ed[x].values.tolist(), ax)
            ax.set_title(title)
            _save_current_figure(f'{outdir_results}/counts_{basename}_{x}ED{suf}_lessthan8.png')

    # These two groups used to inherit whatever file-name template the loop
    # above assigned last (usually `counts_`, data-dependent). The prefix is
    # now fixed to `counts_` so the names are deterministic and unchanged in
    # the common case.
    for x in ['is_binder_class_probs', 'is_expressor_class_probs']:
        ax = sns.histplot(df, x=x+'_class1')
        ax.set_title(title)
        _save_current_figure(f'{outdir_results}/counts_{basename}_{x}_class1.png')

    for x in ['is_binder_top_1', 'is_expressor_top_1']:
        if x not in df:
            # Prediction not available (e.g. absent from the cache).
            continue
        ax = sns.histplot(df, x=x)
        ax.set_title(title)
        _save_current_figure(f'{outdir_results}/counts_{basename}_{x}.png')

    # -------------------------------------------------- affinity overview
    for x in ['affinity_mean', 'affinity_st_dev']:
        ax = sns.histplot(df, x=x)
        if x == 'affinity_mean':
            _mark_seed_affinity(ax, pkd_seed)
        ax.set_title(title)
        _save_current_figure(f'{outdir_results}/histplot_{basename}_{x}.png')

    for x in ['affinity_mean', 'is_binder_class_probs']:
        ax = sns.stripplot(df, x=x)
        if x == 'affinity_mean':
            _mark_seed_affinity(ax, pkd_seed)
        ax.set_title(title)
        _save_current_figure(f'{outdir_results}/stripplot_{basename}_{x}.png')

    # ------------------------- affinity vs. edit distance, split by region
    # The `affinity_st_dev` part of the three file names below is the value
    # the original code picked up from a leaked loop variable; it is kept
    # verbatim so existing output paths do not change.
    if 'edit_distance_heavy_cdrs' in df:
        plot_light = (not single_chain_only) and ('edit_distance_light_cdrs' in df)
        region_panels = [
            ('edit_distance_heavy_cdrs', 'ED Heavy CDRs', True),
            ('edit_distance_light_cdrs', 'ED Light CDRs', plot_light),
            ('edit_distance_heavy_fws', 'ED Heavy FW', True),
            ('edit_distance_light_fws', 'ED Light FW', plot_light),
        ]
        _, axes = plt.subplots(nrows=4, ncols=2, figsize=(12,6))
        for i, x in enumerate(['affinity_mean', 'affinity_st_dev']):
            for row, (y, ylabel, enabled) in enumerate(region_panels):
                if not enabled:
                    continue
                ax = sns.scatterplot(df, x=x, y=y, ax=axes[row, i])
                ax.set_ylabel(ylabel)
                if x == 'affinity_mean':
                    _mark_seed_affinity(ax, pkd_seed)
        plt.suptitle(title)
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        _save_current_figure(
            f'{outdir_results}/scatterplot2d_{basename}_affinity_st_devvsEDregions.png')

    # --------------------------- affinity vs. edit distance, whole chains
    total_panels = [
        ('edit_distance_heavy', True),
        ('edit_distance_light', not single_chain_only),
        ('edit_distance', True),
    ]
    _, axes = plt.subplots(nrows=3, ncols=2)
    for i, x in enumerate(['affinity_mean', 'affinity_st_dev']):
        for row, (y, enabled) in enumerate(total_panels):
            if not enabled:
                continue
            ax = sns.scatterplot(df, x=x, y=y, ax=axes[row, i])
            if x == 'affinity_mean':
                _mark_seed_affinity(ax, pkd_seed)
    plt.suptitle(title)
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    _save_current_figure(
        f'{outdir_results}/scatterplot2d_{basename}_affinity_st_devvsEDtotal.png')

    g = sns.jointplot(df, x='affinity_mean', y='edit_distance',
                      palette='grey')
    g.plot_joint(sns.kdeplot, color="green", zorder=0, levels=6)
    g.plot_marginals(sns.rugplot, color="green", height=-.15, clip_on=False)
    if pkd_seed != _NO_SEED_PKD:
        g.refline(x=pkd_seed, ls='dashed', c='black')
    plt.suptitle(title)
    _save_current_figure(
        f'{outdir_results}/jointplot_colored_{basename}_affinity_st_devvsEDtotal.png')
        
  
if __name__ == "__main__":
    # Script entry point. Most of this block is a log of earlier runs kept as
    # commented-out code; the only live statements are the `seed_file`
    # assignments and the two `get_predictions(...)` calls at the very end.
    #csvfile='/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/Mixup_EGFRN032_SkempiPrescientED5pkd6.5/logits/samples_iter2_t0.5_N1000.csv'
    #csvfile='/data/mahajs17/Propen_sampling/EGFRN032/iid_models/Mixup_SkempiPrescientED5pkd6.5/logits/samples_iter2_t0.5_N1000.csv'
    #seed_file ='s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idEGFR-N032.parquet'
    # #prepare_dataset(csvfile, seed_file, method='StructPropEnIID')
    # csvfile='/data/mahajs17/Propen_sampling/EGFRN032/iid_models/Mixup_SkempiPrescientED5pkd6.5/logits/samples_iter1_t0.5_N1000.csv'
    # #get_predictions(csvfile, seed_file, method='StructPropEnIID')
    # csvfile='/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/Mixup_EGFRN032_SkempiPrescientED5pkd6.5/logits/samples_iter1_t0.5_N1000.csv'
    # #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    # csvfile='/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_EGFRN032_SkempiPrescientED5pkd6.5/logits/samples_iter1_t0.5_N1000.csv'
    # #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    # csvfile='/data/mahajs17/Propen_sampling/EGFRN032/iid_models/Mixup_SkempiPrescientAABED5pkd6.5Large/logits/samples_iter2_t0.5_N1000.csv'
    # #get_predictions(csvfile, seed_file, method='StructPropEnIID')
    
    # NOTE: this assignment is overwritten twice below before `seed_file` is used.
    seed_file ='s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idEGFR-N032.parquet'
    #csvfile='/data/mahajs17/Propen_sampling/EGFRN032/iid_models/PropEn_SkempiED5pkd6.5/logits/samples_iter1_t0.5_N5000.csv'
    #get_predictions(csvfile, seed_file, method='StructPropEnIID')
    
    #seed_file = 's3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idEGFR-N032.parquet'
    #csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter10_t1_N5000.csv"
    # get_predictions(csvfile, seed_file, method='StructPropEnHOO')

    #csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gtonly/logits/samples_iter1_t0.7_N5000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')

    #seed_file ='s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idEGFR-N032.parquet'
    #csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter7_t1_N5000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    #csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/iid_models/PropEn_SkempiED5pkd6.5/logits/samples_iter1_t0.5_N5000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    #csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/Seq_Mixup_EGFRN032_SkempiPrescientED5pkd6.5/EGFR_0.9.parquet"
    #get_predictions(csvfile, seed_file, method='SeqPropEnHOO')
    #csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/Seq_Mixup_EGFRN032_SkempiED5pkd6.5/EGFR_0.3.parquet"
    #get_predictions(csvfile, seed_file, method='SeqPropEnHOO')

    #csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiPrescientED5pkd6.5_EGFRN032_gt/logits/samples_iter2_t1_N5000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    # csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter4_t1_N5000.csv"
    # get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    # csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter3_t1_N5000.csv"
    # get_predictions(csvfile, seed_file, method='StructPropEnHOO')

    #seed_file = "s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idAmivantamab.parquet"
    #get_affinity_seed(seed_file)
    #get_affinity_ami(seed_file)
    #seed_file = "s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_denovo_round5_seeds_seed_aliassabdab_7243_updated.parquet"
    #get_affinity_seed(seed_file)
    #csvfile = "/data/mahajs17/Propen_sampling/denovo/sabdab_7243/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter2_t1_N4000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    #csvfile = "/data/mahajs17/Propen_sampling/denovo/sabdab_7243/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter3_t1_N4000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    #csvfile = "/data/mahajs17/Propen_sampling/denovo/sabdab_7243/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter4_t1_N4000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    #csvfile = "/data/mahajs17/Propen_sampling/denovo/sabdab_7243_more/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter9_t1_N5000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    #csvfile = "/data/mahajs17/Propen_sampling/denovo/sabdab_7243_more/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter10_t1_N5000.csv"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    
    # NOTE: also overwritten below (by the EGFR-N032 seed) before it is used.
    seed_file = "s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_denovo_round5_seeds_seed_aliassabdab_7243_updated.parquet"
    #csvfile = "/data/mahajs17/Propen_sampling/Paper/baselines/antifold/20250417_0614443_1n8z_antifold_t0.2_N5000.csv"
    #get_predictions(csvfile, seed_file, method='AntiFold')
    #csvfile = "/data/mahajs17/Propen_sampling/Paper/baselines/antifold/20250416_0761546_1n8z_antifold_t0.5_N5000.csv"
    #get_predictions(csvfile, seed_file, method='AntiFold')

    #csvfile = "/homefs/home/mahajs17/repositories/AntiFold/antifold_output/1n8z_BAC_sequences_t0.2.csv"
    #get_predictions(csvfile, seed_file, method='AntiFold')
    #csvfile = "/homefs/home/mahajs17/repositories/AntiFold/antifold_output_Tpt5/1n8z_BAC_sequences_t0.5.csv"
    #get_predictions(csvfile, seed_file, method='AntiFold')

    #csvfile="/data/mahajs17/Propen_sampling/IL61409/hoo_models/PropEn_SkempiED5pkd6.5_IL61409_gt/logits/samples_iter5_t1_N5000.csv"
    #seed_file = "s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idIL6-1409.parquet"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')
    #csvfile="/data/mahajs17/Propen_sampling/OSMN013/hoo_models/PropEn_SkempiPrescientED5pkd6.5_OsmN013_gt/logits/samples_iter5_t1_N5000.csv"
    #csvfile="/data/mahajs17/Propen_sampling/OSMN013/hoo_models/PropEn_SkempiPrescientED5pkd6.5_OsmN013_gt/logits/samples_iter7_t1_N5000.csv"    
    #seed_file = "s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idOSM-N013.parquet"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')

    #csvfile="/data/mahajs17/Propen_sampling/denovo/sabdab_7243_more/hoo_models/PropEn_SkempiED5pkd6.5_gt/logits/samples_iter10_t1_N5000.csv"
    #seed_file = "s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_denovo_round5_seeds_seed_aliassabdab_7243_updated.parquet"
    #get_predictions(csvfile, seed_file, method='StructPropEnHOO')

    #csvfile="/data/mahajs17/Propen_sampling/sequence_propen/samples_sabdab_7243_ED50/samples_heavy_and_light/samples_heavy_and_light_t1.0_iter1_N5000_clean.csv"
    #get_predictions(csvfile, seed_file, method='SeqPropEn')
    # Active runs: two sample files (t0.7 and t1.0 in the file names), both
    # evaluated against the EGFR-N032 seed with method label 'SeqPropEn'.
    csvfile="/data/mahajs17/Propen_sampling/sequence_propen/samples_egfrn032/samples_heavy_and_light/samples_heavy_and_light_t0.7_iter1_N5000_clean.csv"
    seed_file ='s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idEGFR-N032.parquet'
    get_predictions(csvfile, seed_file, method='SeqPropEn')
    csvfile="/data/mahajs17/Propen_sampling/sequence_propen/samples_egfrn032/samples_heavy_and_light/samples_heavy_and_light_t1.0_iter1_N5000_clean.csv"
    seed_file ='s3://prescient-data-dev/sandbox/mahajs17/paratop_enhancer/testsets/prescient_accepted_seeds_seed_idEGFR-N032.parquet'
    get_predictions(csvfile, seed_file, method='SeqPropEn')

    
