# Build edit-distance-matched pair datasets and hold-out / test splits.
import pandas as pd
import os
import seaborn as sns
import numpy as np
from edlib import align as align
import matplotlib.pyplot as plt
import matplotlib
from affinityenhancer.preprocess.paired_propen_utils import match_on_edist
from affinityenhancer.preprocess.constants import DEFAULT_DATASET_PATH, DEFAULT_PAIRING_SETTINGS, DEFAULT_TESTSET_PATH
from affinityenhancer.preprocess.utils import lookup_target_sequence
import boto3
from botocore.exceptions import NoCredentialsError, ClientError


def s3_path_exists(s3_uri: str) -> bool:
    """
    Check whether an object exists at the given S3 URI.

    Parameters:
    s3_uri (str): The S3 URI (e.g., 's3://your_bucket_name/path/to/your/file.txt').

    Returns:
    bool: True if a HEAD request for the object succeeds. False if the
    service answers with error code '404', if any other client error occurs,
    or if no credentials are available (the last two cases also print a
    message).

    Raises:
    ValueError: If the URI does not start with 's3://' or is missing the
    bucket name or the object key.
    """
    scheme = 's3://'
    if not s3_uri.startswith(scheme):
        raise ValueError("Invalid S3 URI. It should start with 's3://'.")

    bucket_name, _, s3_path = s3_uri[len(scheme):].partition('/')
    # An empty bucket or key (e.g. 's3://bucket/') cannot identify an object,
    # so reject it up front together with URIs that have no '/' at all.
    if not bucket_name or not s3_path:
        raise ValueError("Invalid S3 URI. It should be in the format 's3://bucket_name/path/to/object'.")

    s3_client = boto3.client('s3')

    try:
        s3_client.head_object(Bucket=bucket_name, Key=s3_path)
    except ClientError as e:
        # '404' is the expected "no such object" answer; anything else is
        # reported but still treated as "does not exist".
        if e.response['Error']['Code'] != '404':
            print(f"Unexpected error: {e}")
        return False
    except NoCredentialsError:
        print("Credentials not available.")
        return False
    return True

class BaseMatcher():
    """Build matched pairs for one affinity data source and save them as parquet.

    On construction the CSV at ``csv_file`` is loaded and reduced to the rows
    whose ``source_column`` equals ``source_name`` and whose
    ``property_column`` is not NaN. ``match`` hands that table to
    ``match_on_edist`` in a single call (the pairing logic lives in that
    imported helper and is not visible here), optionally labels train/test
    partitions for a held-out antigen and/or antibody, and writes the result.
    ``get_org_stat`` and ``get_stats`` save diagnostic histograms under
    ``plot_root``.
    """

    # Root folder for the diagnostic plots; every source gets its own sub-folder.
    plot_root = "/data/mahajs17/Propen"

    def __init__(self,
                 csv_file='s3://prescient-data-dev/sandbox/vasilaks/sabdab_vs_world/DATA_v5.1/epitome_dataset.csv',
                 source_name='skempi',
                 dataset_path: str = DEFAULT_DATASET_PATH,
                 source_column: str ='affinity_datasource',
                 property_column: str = 'affinity_pkd',
                 hold_out_ag: str = None,
                 hold_out_ab: str = None,
                 target_name: str = '',
                 seed_name: str = '',
                 hold_out_ag_ed: int = 80,
                 hold_out_ab_ed: int = 50
                 ):
        """Load the dataset and keep the rows belonging to ``source_name``.

        Parameters:
        csv_file: CSV read with ``pandas.read_csv``.
        source_name: Value of ``source_column`` selecting the rows to keep.
        dataset_path: Directory (or URI prefix) under which outputs are written.
        source_column: Column holding the data-source label.
        property_column: Column holding the property; rows with NaN are dropped.
        hold_out_ag: Antigen sequence to hold out, or None for no antigen hold-out.
        hold_out_ab: Antibody sequence (heavy + light) to hold out, or None.
        target_name: Label for the held-out antigen; only used in file names.
        seed_name: Label for the held-out antibody; only used in file names.
        hold_out_ag_ed: Antigens within this edit distance of ``hold_out_ag``
            are labelled 'test'.
        hold_out_ab_ed: Antibodies within this edit distance of ``hold_out_ab``
            are labelled 'test'.
        """
        super().__init__()

        dataset = pd.read_csv(csv_file, low_memory=False)
        keep = (dataset[source_column] == source_name) & dataset[property_column].notna()
        # Take an explicit copy: columns are added to self.df below and in
        # set_partitions, which must not happen on a slice of `dataset`.
        self.df = dataset[keep].copy()
        self.df['seqid'] = self.df['seq_id']
        self.dataset_path = dataset_path
        self.source_name = source_name
        self.property_column = property_column
        self.hold_out_ab = hold_out_ab
        self.hold_out_ag = hold_out_ag
        self.hold_out_ab_ed = hold_out_ab_ed
        self.hold_out_ag_ed = hold_out_ag_ed
        self.seed_name = seed_name
        self.target_name = target_name

    def _plot_dir(self):
        """Return this source's plot directory, creating it if necessary."""
        outdir = f"{self.plot_root}/{self.source_name}/"
        os.makedirs(outdir, exist_ok=True)
        return outdir

    def get_org_stat(self):
        """Print antigen-related counts and save three diagnostic histograms.

        Prints the number of unique ``ag_id``, ``sabdab_idx``, ``ag_id_strict``
        and ``affinity_antigen_sequence`` values, then saves histograms of the
        pairwise antigen edit distances, of ``ag_id``, and of antigen sequence
        versus ``ag_id`` into the plot directory.
        """
        df = self.df
        print(len(df['ag_id'].unique()))
        print(len(df['sabdab_idx'].unique()))
        print(len(df['ag_id_strict'].unique()))
        ag_seqs = list(df['affinity_antigen_sequence'].unique())
        print('Unique Ag seq: ', len(ag_seqs))

        # Only the i < j pairs are plotted, so only those are aligned instead
        # of filling a full N x N matrix and discarding most of it.
        n_seqs = len(ag_seqs)
        pair_dists = np.array(
            [align(ag_seqs[j], ag_seqs[i])["editDistance"]
             for i in range(n_seqs)
             for j in range(i + 1, n_seqs)],
            dtype=int)

        outdir = self._plot_dir()

        sns.histplot(pair_dists)
        plt.savefig(f"{outdir}/distribution_antigenED.png")
        plt.close()

        ax = sns.histplot(df, x="ag_id")
        ax.set_xticklabels([])
        plt.savefig(f"{outdir}/distribution_antigenId.png")
        plt.close()

        ax = sns.histplot(df, x="affinity_antigen_sequence", y="ag_id")
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        plt.savefig(f"{outdir}/distribution_antigenSeqvsAgId.png")
        plt.close()

    def get_stats(self, settings: dict = DEFAULT_PAIRING_SETTINGS):
        """Print property statistics and plot the matched pairs for ``settings``.

        Saves a histogram of the raw property values, then reads the matched
        parquet file for ``settings`` and saves 2-D histograms of the property
        delta against the heavy+light edit distance and against the first
        property value.

        Parameters:
        settings: Pairing settings; must contain 'property_th_lb',
            'property_th_ub' and 'edist_th', and may contain 'property_min'.
        """
        print('Unique antibodies', self.df.shape[0])
        print(self.property_column, self.df[self.property_column].describe())

        outdir = self._plot_dir()
        sns.histplot(self.df, x=self.property_column)
        plt.savefig(f"{outdir}/distribution_org_{self.property_column}.png")
        plt.close()

        property_th_lb = settings['property_th_lb']
        property_th_ub = settings['property_th_ub']
        edist_th = settings['edist_th']
        # NOTE(review): this reads 'edist_th', not a dedicated key; kept as-is
        # because the value is part of the file name that match() writes.
        property_to_match = settings['edist_th']
        # Use the same suffix as match() so the file read here is the one
        # this matcher wrote, including any hold-out tags.
        suffix = self.get_suffix(settings)
        tag = f'th{property_th_lb}-{property_th_ub}_edth{edist_th}'

        if self.source_name == 'skempi':
            pairs_file = f'{self.dataset_path}/{self.source_name}_matched_{tag}_prop{property_to_match}{suffix}.parquet'
        else:
            pairs_file = f'{self.dataset_path}/{self.source_name}_matched_{tag}_propED{suffix}.parquet'

        df = pd.read_parquet(pairs_file)
        print(df.columns)
        # Total edit distance of a pair = heavy-chain distance + light-chain distance.
        df['ED'] = df.apply(lambda row: align(row['first_HeavyAA'], row['second_HeavyAA'])['editDistance']
                            + align(row['first_LightAA'], row['second_LightAA'])['editDistance'],
                            axis=1)
        delta_col = f'delta_{self.property_column}'
        df[delta_col] = df['second_property'] - df['first_property']

        sns.histplot(df, x=delta_col, y='ED', cbar=True, cbar_kws=dict(shrink=.75))
        plt.savefig(f"{outdir}/delta{self.property_column}vsED_{tag}{suffix}.png")
        plt.close()

        sns.histplot(df, x=delta_col, y='first_property', cbar=True, cbar_kws=dict(shrink=.75))
        plt.savefig(f"{outdir}/delta{self.property_column}vs{self.property_column}_{tag}{suffix}.png")
        plt.close()

    def get_suffix(self, settings):
        """Return the file-name suffix encoding 'property_min' and hold-out options.

        Parameters:
        settings: Pairing settings; 'property_min' is optional.

        Returns:
        str: '' when neither 'property_min' nor a hold-out is configured.
        """
        suffix = ''
        if 'property_min' in settings:
            min_prop = settings['property_min']
            suffix = f'_minprop{min_prop}'

        if self.hold_out_ab is not None:
            suffix += f'_testAb{self.seed_name}_ED{self.hold_out_ab_ed}'

        if self.hold_out_ag is not None:
            suffix += f'_testAg{self.target_name}_ED{self.hold_out_ag_ed}'

        return suffix

    def set_partitions(self, df_paired):
        """Label each pair 'test' or 'train' according to the hold-out settings.

        A pair is 'test' when its first antigen sequence is one of the source
        antigen sequences within ``hold_out_ag_ed`` edits of ``hold_out_ag``,
        or when its first heavy+light sequence is one of the source antibody
        sequences within ``hold_out_ab_ed`` edits of ``hold_out_ab``. When both
        hold-outs are configured a pair is 'test' if either condition holds.

        Parameters:
        df_paired: DataFrame of pairs; modified in place and returned.

        Returns:
        The same DataFrame with a 'partition' column (and a 'first_sequence'
        column when an antibody hold-out is set). Returned unchanged when no
        hold-out is configured.
        """
        if self.hold_out_ag is None and self.hold_out_ab is None:
            return df_paired

        is_test = np.zeros(len(df_paired), dtype=bool)

        if self.hold_out_ag is not None:
            assert 'first_affinity_antigen_sequence' in df_paired
            held_out_ag = {
                seq for seq in self.df['affinity_antigen_sequence'].unique()
                if align(seq, self.hold_out_ag)["editDistance"] <= self.hold_out_ag_ed
            }
            is_test |= np.array(
                [seq in held_out_ag
                 for seq in df_paired['first_affinity_antigen_sequence'].values.tolist()],
                dtype=bool)

        if self.hold_out_ab is not None:
            self.df['fv_sequence'] = self.df['fv_heavy'] + self.df['fv_light']
            held_out_ab = {
                seq for seq in self.df['fv_sequence'].unique()
                if align(seq, self.hold_out_ab)["editDistance"] <= self.hold_out_ab_ed
            }
            df_paired['first_sequence'] = df_paired['first_HeavyAA'] + df_paired['first_LightAA']
            is_test |= np.array(
                [seq in held_out_ab
                 for seq in df_paired['first_sequence'].values.tolist()],
                dtype=bool)

        df_paired['partition'] = ['test' if flag else 'train' for flag in is_test]
        return df_paired

    def match(self, settings: dict = DEFAULT_PAIRING_SETTINGS, overwrite: bool = False):
        """Run ``match_on_edist`` on the whole source table and save the pairs.

        Does nothing (apart from printing) when the output already exists and
        ``overwrite`` is False.

        Parameters:
        settings: Pairing settings; forwarded as keyword arguments to
            ``match_on_edist`` and used to build the output file name.
        overwrite: Recompute and overwrite an existing output file.
        """
        property_th_lb = settings['property_th_lb']
        property_th_ub = settings['property_th_ub']
        edist_th = settings['edist_th']
        # NOTE(review): reads 'edist_th' rather than a dedicated key; left
        # unchanged because it determines the name of existing output files.
        property_to_match = settings['edist_th']

        suffix = self.get_suffix(settings)
        # NOTE(review): the prefix is the literal 'skempi', not
        # self.source_name - confirm this matcher is only meant for skempi.
        outfile = f'{self.dataset_path}/skempi_matched_th{property_th_lb}-{property_th_ub}_edth{edist_th}_prop{property_to_match}{suffix}.parquet'
        print(outfile)

        if s3_path_exists(outfile) and not overwrite:
            print(outfile, 'exists. Set overwrite == True to overwrite existing file.')
            return

        df = match_on_edist(self.df,
                            save_dir=f'{self.dataset_path}/matched_distances_{self.source_name}/',
                            **settings, col='ag_id')

        df = df.drop_duplicates()
        print('Final size ', df.shape)

        if (self.hold_out_ab is not None) or (self.hold_out_ag is not None):
            df = self.set_partitions(df)
            assert 'partition' in df

        df.to_parquet(outfile)


class PerAntigenMatcher(BaseMatcher):
    """Matcher that runs ``match_on_edist`` separately for every ``ag_id``.

    Construction is identical to ``BaseMatcher`` except that ``source_name``
    defaults to 'prescient'. ``match`` writes one parquet file per antigen
    and then a combined file holding all of them.
    """

    def __init__(self,
                 csv_file='s3://prescient-data-dev/sandbox/vasilaks/sabdab_vs_world/DATA_v5.1/epitome_dataset.csv',
                 source_name='prescient',
                 dataset_path: str = DEFAULT_DATASET_PATH,
                 source_column: str = 'affinity_datasource',
                 property_column: str = 'affinity_pkd',
                 hold_out_ag: str = None,
                 hold_out_ab: str = None,
                 target_name: str = '',
                 seed_name: str = '',
                 hold_out_ag_ed: int = 80,
                 hold_out_ab_ed: int = 50
                 ):
        """Forward every argument unchanged to ``BaseMatcher.__init__``."""
        super().__init__(
            csv_file,
            source_name,
            dataset_path=dataset_path,
            source_column=source_column,
            property_column=property_column,
            hold_out_ag=hold_out_ag,
            hold_out_ab=hold_out_ab,
            target_name=target_name,
            seed_name=seed_name,
            hold_out_ab_ed=hold_out_ab_ed,
            hold_out_ag_ed=hold_out_ag_ed,
        )

    def match(self, settings: dict = DEFAULT_PAIRING_SETTINGS, overwrite=False):
        """Match each antigen on its own, then merge and save all pairs.

        Skips all work when the combined output already exists and
        ``overwrite`` is False.

        Parameters:
        settings: Pairing settings; forwarded as keyword arguments to
            ``match_on_edist`` and used to build the output file names.
        overwrite: Recompute and overwrite an existing combined file.
        """
        lower = settings['property_th_lb']
        upper = settings['property_th_ub']
        max_edist = settings['edist_th']

        suffix = self.get_suffix(settings)
        tail = f'matched_th{lower}-{upper}_edth{max_edist}_propED{suffix}.parquet'
        # '{{}}' leaves a '{}' placeholder that is filled with the antigen id below.
        split_template = f'{self.dataset_path}/{self.source_name}_split{{}}_{tail}'
        combined_path = f'{self.dataset_path}/{self.source_name}_{tail}'

        if s3_path_exists(combined_path) and not overwrite:
            print(combined_path, 'exists. Set overwrite == True to overwrite existing file.')
            return

        antigen_ids = list(self.df['ag_id'].unique())

        # First pass: one matched file per antigen.
        for antigen in antigen_ids:
            subset = self.df[self.df['ag_id'] == antigen]
            pairs = match_on_edist(
                subset,
                save_dir=f'{self.dataset_path}/matched_distances_{self.source_name}_{antigen}/',
                **settings
            )
            pairs.drop_duplicates().to_parquet(split_template.format(antigen))

        # Second pass: read the per-antigen files back and merge them.
        frames = []
        for antigen in antigen_ids:
            frames.append(pd.read_parquet(split_template.format(antigen)))
        combined = pd.concat(frames).drop_duplicates()

        has_hold_out = (self.hold_out_ab is not None) or (self.hold_out_ag is not None)
        if has_hold_out:
            combined = self.set_partitions(combined)
            assert 'partition' in combined

        combined.to_parquet(combined_path)


def create_hold_out_datasets(settings):
    """Create antigen hold-out datasets for every target in the accepted seeds.

    For each distinct (target, antigen sequence) pair of the seed table, runs
    ``PerAntigenMatcher`` for 'prescient' and 'aalphabio' and ``BaseMatcher``
    for 'skempi' with that antigen sequence held out.

    Parameters:
    settings: Pairing settings passed to each matcher's ``match``.
    """
    seeds = pd.read_csv("/homefs/home/mahajs17/projects/C1sr/scripts/accepted_seeds.csv")
    sequence_by_target = lookup_target_sequence()
    # Targets missing from the lookup get an empty sequence.
    seeds["affinity_antigen_sequence"] = [sequence_by_target.get(target, "")
                                          for target in seeds["target"]]
    seeds = seeds.rename(columns={"seed_id": "seed_alias"})

    distinct = seeds[['target', 'affinity_antigen_sequence']].drop_duplicates()
    target_seq_lookup = dict(zip(distinct['target'], distinct['affinity_antigen_sequence']))

    for target_name, ag_seq in target_seq_lookup.items():
        print(target_name)
        for name in ('prescient', 'aalphabio'):
            print(name)
            per_antigen = PerAntigenMatcher(source_name=name,
                                            hold_out_ag=ag_seq,
                                            target_name=target_name)
            per_antigen.match(settings=settings)

        for name in ('skempi',):
            print(name)
            whole_source = BaseMatcher(source_name=name,
                                       hold_out_ag=ag_seq,
                                       target_name=target_name)
            whole_source.match(settings=settings)

def create_test_sets_prescient():
    """Write the accepted seeds as test-set parquet files.

    Saves the full seed table (with antigen sequence, 'seqid' and a constant
    'test' partition added) and then one file per unique 'target' value and
    per unique 'seed_id' value under ``DEFAULT_TESTSET_PATH``.
    """
    seeds = pd.read_csv("/homefs/home/mahajs17/projects/C1sr/scripts/accepted_seeds.csv")
    sequence_by_target = lookup_target_sequence()
    # Targets missing from the lookup get an empty sequence.
    seeds["affinity_antigen_sequence"] = [sequence_by_target.get(target, "")
                                          for target in seeds["target"]]
    seeds['seqid'] = seeds['seed_id']
    seeds['partition'] = 'test'
    seeds.to_parquet(f'{DEFAULT_TESTSET_PATH}/prescient_accepted_seeds.parquet')

    for col in ('target', 'seed_id'):
        for name in seeds[col].unique():
            subset = seeds[seeds[col] == name]
            subset.to_parquet(f'{DEFAULT_TESTSET_PATH}/prescient_accepted_seeds_{col}{name}.parquet')


def create_denovo_round5_seed():
    """Write one test-set parquet file per round-5 de novo seed.

    Reads the seed CSV, copies 'seqres_antigen' into
    'affinity_antigen_sequence', renames 'id' to 'seed_alias', adds 'seqid'
    and a constant 'test' partition, and saves the rows of each unique
    'seed_alias' under ``DEFAULT_TESTSET_PATH``.
    """
    seeds = pd.read_csv("s3://prescient-data-dev/sandbox/mahajs17/Propen/denovo/round5_denovo_seeds_allcols.csv")
    seeds["affinity_antigen_sequence"] = seeds["seqres_antigen"]
    seeds = seeds.rename(columns={"id": "seed_alias"})
    seeds['seqid'] = seeds['seed_alias']
    seeds['partition'] = 'test'

    col = 'seed_alias'
    for name in seeds[col].unique():
        subset = seeds[seeds[col] == name]
        subset.to_parquet(f'{DEFAULT_TESTSET_PATH}/prescient_denovo_round5_seeds_{col}{name}.parquet')


def create_hold_out_aabio_datasets(settings, col='matalpha_description'):
    """Create antigen hold-out datasets for every antigen in the Alpha-Bio table.

    Reads the Alpha-Bio parquet file, takes the distinct combinations of
    ``col``, 'affinity_antigen_sequence' and 'target', and for each one runs
    ``PerAntigenMatcher`` for 'prescient' and 'aalphabio' and ``BaseMatcher``
    for 'skempi' with that antigen sequence held out at an edit-distance
    threshold of 40.

    Parameters:
    settings: Pairing settings passed to each matcher's ``match``.
    col: Column of the Alpha-Bio table holding the antigen description; its
        values are looked up in the name mapping below and appended to the
        target name used in output file names.

    Raises:
    KeyError: If a value of ``col`` is not a key of the name mapping.
    """
    # Maps the raw antigen descriptions to cleaned-up names.
    aabio_mapping = \
    {"AARG-002_A0A805RSY6_MACFA": "AARG_002_var10", "AARG-002_A0A2K6T447_SAIBB": "AARG_002_var15", "AARG-002_Q0Z973_CALJA": "AARG_002_var14",
      "AARG-002_G1S4D6_NOMLE": "AARG_002_var17", "AARG-002_A0A2K5Q7P2_CEBIM": "AARG_002_var12", "AARG-002_30-212_1ALU": "AARG_002_var16", 
      "AARG-002_A0A2Y9EHI1_PHYMC": "AARG_002_var11", "AARG-002_O97540_AOTNA": "AARG_002_var19", "AARG-002_A0A2K6NAV7_RHIRO": "AARG_002_var18", 
      "AARG-002_Q9TTH3_AOTLE": "AARG_002_var13", "AARG-010_XP_007965218": "AARG_010_var9", "AARG-010_XP_035162289": "AARG_010_var5", 
      "AARG-010_XP_004427823": "AARG_010_var3", "AARG-010_A0A2K6U601_SAIBB": "AARG_010_var4", "AARG-010_25-146_AP": "AARG_010_var1", 
      "AARG-010_B0LAJ3_MACFA": "AARG_010_var8", "AARG-010_A0A2K5RFC6_CEBIM": "AARG_010_var6", "AARG-010_XP_032616905": "AARG_010_var0", 
      "AARG-010_A0A2K5E9H6_AOTNA": "AARG_010_var7", "AARG-010_XP_008532654": "AARG_010_var2", "AARG_009_var41": "AARG_009_var41", 
      "AARG-023_var33": "AARG_023_var33", "AARG-009_212-630_2DH2": "AARG_009_212_630_2DH2", "AARG-006_var13": "AARG_006_var13", 
      "AARG_009_var27": "AARG_009_var27", "AARG-006_var17": "AARG_006_var17", "AARG-006_var46": "AARG_006_var46", "AARG_009_var23": "AARG_009_var23",
        "AARG-006_var4": "AARG_006_var4", "AARG-023_var14": "AARG_023_var14", "AARG-023_var48": "AARG_023_var48", "AARG-006_var38": "AARG_006_var38", 
        "AARG-013_var40": "AARG_013_var40", "AARG-013_var24": "AARG_013_var24", "AARG-023_var20": "AARG_023_var20", "AARG-013_var13": "AARG_013_var13",
          "AARG-013_var22": "AARG_013_var22", "AARG-006_var6": "AARG_006_var6", "AARG_009_var17": "AARG_009_var17", "AARG-013_var42": "AARG_013_var42", 
          "AARG-006_var40": "AARG_006_var40", "AARG-013_var31": "AARG_013_var31", "AARG-023_var13": "AARG_023_var13", "AARG-006_var19": "AARG_006_var19", 
          "AARG-013_var4": "AARG_013_var4", "AARG-006_var7": "AARG_006_var7", "AARG-006_var49": "AARG_006_var49", "AARG-023_var39": "AARG_023_var39", 
          "AARG-006_var2": "AARG_006_var2", "AARG-023_var24": "AARG_023_var24", "AARG-006_var30": "AARG_006_var30", "AARG-006_var43": "AARG_006_var43", 
          "AARG-013_var6": "AARG_013_var6", "AARG-006_var45": "AARG_006_var45", "AARG-023_var34": "AARG_023_var34", "AARG-023_var15": "AARG_023_var15", 
          "AARG-023_var49": "AARG_023_var49", "AARG-006_var20": "AARG_006_var20", "AARG-023_var6": "AARG_023_var6", "AARG_009_var30": "AARG_009_var30", 
          "AARG-013_var33": "AARG_013_var33", "AARG_009_var3": "AARG_009_var3", "AARG-023_34-202_1PVH_ECD": "AARG_023_34_202_1PVH_ECD", 
          "AARG-013_var34": "AARG_013_var34", "AARG-013_var27": "AARG_013_var27", "AARG-006_var33": "AARG_006_var33", "AARG-013_var19": "AARG_013_var19", 
          "AARG_009_var33": "AARG_009_var33", "AARG-013_var29": "AARG_013_var29", "AARG_009_var24": "AARG_009_var24", "AARG_009_var18": "AARG_009_var18", 
          "AARG_009_var40": "AARG_009_var40", "AARG-013_var11": "AARG_013_var11", "AARG-006_var28": "AARG_006_var28", "AARG-023_var47": "AARG_023_var47", 
          "AARG-006_var32": "AARG_006_var32", "AARG-013_var1": "AARG_013_var1", "AARG-013_var9": "AARG_013_var9", "AARG-013_var15": "AARG_013_var15", 
          "AARG-006_var29": "AARG_006_var29", "AARG-013_var14": "AARG_013_var14", "AARG-023_var31": "AARG_023_var31", "AARG-006_var48": "AARG_006_var48", 
          "AARG-006_var14": "AARG_006_var14", "AARG-013_var36": "AARG_013_var36", "AARG-013_28-252_ECD": "AARG_013_28_252_ECD", "AARG_009_var37": "AARG_009_var37", 
          "AARG-023_var17": "AARG_023_var17", "AARG-013_var2": "AARG_013_var2", "AARG-023_var9": "AARG_023_var9", "AARG-023_var11": "AARG_023_var11", 
          "AARG_009_var20": "AARG_009_var20", "AARG-013_var39": "AARG_013_var39", "AARG-023_var1": "AARG_023_var1", "AARG-023_var18": "AARG_023_var18", "AARG_009_var4": "AARG_009_var4", 
          "AARG-006_var44": "AARG_006_var44", "AARG-006_var1": "AARG_006_var1", "AARG-023_var41": "AARG_023_var41", "AARG_009_var35": "AARG_009_var35", "AARG_009_var26": "AARG_009_var26", 
          "AARG-006_var15": "AARG_006_var15", "AARG-006_var27": "AARG_006_var27", "AARG-013_var32": "AARG_013_var32", "AARG-023_var22": "AARG_023_var22", "AARG_009_var9": "AARG_009_var9", 
          "AARG-006_var3": "AARG_006_var3", "AARG-023_var4": "AARG_023_var4", "AARG_009_var28": "AARG_009_var28", "AARG-006_var21": "AARG_006_var21", "AARG-006_var8": "AARG_006_var8", 
          "AARG_009_var7": "AARG_009_var7", "AARG-013_var37": "AARG_013_var37", "AARG-023_var38": "AARG_023_var38", "AARG-013_var5": "AARG_013_var5", "AARG-006_var36": "AARG_006_var36", 
          "AARG-013_var16": "AARG_013_var16", "AARG-023_var19": "AARG_023_var19", "AARG_009_var29": "AARG_009_var29", "AARG-006_var37": "AARG_006_var37", 
          "AARG-006_197-378_6Y76": "AARG_006_197_378_6Y76", "AARG_009_var15": "AARG_009_var15", "AARG_009_var2": "AARG_009_var2", "AARG-006_var11": "AARG_006_var11", "AARG-006_var18": "AARG_006_var18",
         "AARG-006_var39": "AARG_006_var39", "AARG_009_var38": "AARG_009_var38", "AARG_009_var16": "AARG_009_var16", "AARG-023_var44": "AARG_023_var44", "AARG-006_var35": "AARG_006_var35", 
         "AARG_009_var22": "AARG_009_var22", "AARG-013_var21": "AARG_013_var21", "AARG-006_var23": "AARG_006_var23", "AARG-023_var26": "AARG_023_var26", "AARG-013_var49": "AARG_013_var49",
           "AARG-006_var31": "AARG_006_var31", "AARG_009_var8": "AARG_009_var8", "AARG-006_var34": "AARG_006_var34", "AARG_009_var25": "AARG_009_var25", "AARG-023_var23": "AARG_023_var23", "AARG_009_var45": "AARG_009_var45", "AARG-023_var21": "AARG_023_var21", "AARG-013_var43": "AARG_013_var43", "AARG_009_var1": "AARG_009_var1", "AARG-023_var29": "AARG_023_var29", 
           "AARG_009_var43": "AARG_009_var43", "AARG_009_var46": "AARG_009_var46", "AARG-023_var8": "AARG_023_var8", "AARG-006_var10": "AARG_006_var10", "AARG-013_var20": "AARG_013_var20", "AARG_009_var12": "AARG_009_var12", "AARG-006_var26": "AARG_006_var26", "AARG_009_var49": "AARG_009_var49", "AARG-023_var36": "AARG_023_var36", "AARG_009_var32": "AARG_009_var32", "AARG_009_var6": "AARG_009_var6", "AARG-013_var28": "AARG_013_var28", "AARG-023_var30": "AARG_023_var30", "AARG-013_var12": "AARG_013_var12", "AARG_009_var44": "AARG_009_var44", "AARG_009_var36": "AARG_009_var36", "AARG-023_var32": "AARG_023_var32", "AARG_009_var42": "AARG_009_var42", "AARG-006_var41": "AARG_006_var41", "AARG-023_var7": "AARG_023_var7", "AARG-013_var46": "AARG_013_var46", "AARG-023_var2": "AARG_023_var2", "AARG-013_var17": "AARG_013_var17", "AARG-013_var26": "AARG_013_var26", "AARG-023_var3": "AARG_023_var3", "AARG-013_var10": "AARG_013_var10", "AARG_009_var10": "AARG_009_var10", "AARG-006_var42": "AARG_006_var42", "AARG-023_var27": "AARG_023_var27", "AARG-023_var40": "AARG_023_var40", "AARG-023_var46": "AARG_023_var46", "AARG-023_var12": "AARG_023_var12", "AARG-006_var16": "AARG_006_var16", "AARG-006_var22": "AARG_006_var22", "AARG-013_var45": "AARG_013_var45", "AARG-013_var25": "AARG_013_var25", "AARG-006_var9": "AARG_006_var9", "AARG-006_var24": "AARG_006_var24", "AARG-023_var45": "AARG_023_var45", "AARG_009_var14": "AARG_009_var14", "AARG-006_var47": "AARG_006_var47", "AARG-013_var35": "AARG_013_var35", "AARG-023_var43": "AARG_023_var43", "AARG-013_var18": "AARG_013_var18", "AARG-006_var25": "AARG_006_var25", "AARG_009_var34": "AARG_009_var34", "AARG_009_var5": "AARG_009_var5", "AARG-006_var5": "AARG_006_var5", "AARG_009_var31": "AARG_009_var31", "AARG-013_var30": "AARG_013_var30", "AARG-023_var25": "AARG_023_var25", "AARG-023_var35": "AARG_023_var35", "AARG_009_var11": "AARG_009_var11", "AARG_009_var39": "AARG_009_var39", "AARG_009_var21": "AARG_009_var21", "AARG-013_var47": 
"AARG_013_var47", "AARG-013_var8": "AARG_013_var8", "AARG-013_var48": "AARG_013_var48", "AARG-013_var44": "AARG_013_var44", "AARG_009_var47": "AARG_009_var47", "AARG-023_var5": "AARG_023_var5", "AARG-023_var16": "AARG_023_var16", "AARG-006_var12": "AARG_006_var12", "AARG-013_var41": "AARG_013_var41", "AARG_009_var19": "AARG_009_var19", "AARG-013_var7": "AARG_013_var7", "AARG-023_var42": "AARG_023_var42", "AARG-013_var23": "AARG_013_var23", "AARG_009_var13": "AARG_009_var13", "AARG-013_var38": "AARG_013_var38", "AARG-013_var3": "AARG_013_var3", "AARG-023_var10": "AARG_023_var10", "AARG_009_var48": "AARG_009_var48", "AARG-023_var37": "AARG_023_var37", "AARG-023_var28": "AARG_023_var28"}

    aab_df = pd.read_parquet("s3://prescient-data-dev/sandbox/wanga84/alpha-bio/lib1/all_alpha_bio_data.parquet")
    
    # One row per distinct (description, antigen sequence, target) combination.
    aab_df_ag = aab_df[[col, 'affinity_antigen_sequence', 'target']].drop_duplicates()
    # NOTE(review): 'clean_names' is not read again in this function; the
    # lookup still raises KeyError for a description missing from the mapping.
    # The target name below uses row[col], not the cleaned name - confirm intent.
    aab_df_ag['clean_names'] = [aabio_mapping[t] for t in aab_df_ag[col].values.tolist()]
    
    for i, row in aab_df_ag.iterrows():

        ag_seq = row["affinity_antigen_sequence"]
        # Label used in output file names: target without '-' plus the raw description.
        target_name = row['target'].replace('-','') + '_' + row[col]
        for name in ['prescient', 'aalphabio']:
            matcher = PerAntigenMatcher(source_name=name,
                                        hold_out_ag=ag_seq,
                                        target_name=target_name,
                                        hold_out_ag_ed=40
                                        )
            matcher.match(settings=settings)

        for name in ['skempi']:
            print(name)
            matcher = BaseMatcher(source_name=name,
                                  hold_out_ag=ag_seq,
                                  target_name=target_name,
                                  hold_out_ag_ed=40
                                  )
            matcher.match(settings=settings)


def create_hold_out_seed_datasets(settings):
    """Create antibody hold-out datasets for every accepted seed.

    For each distinct (seed_alias, fv_heavy, fv_light) row of the seed table,
    runs ``PerAntigenMatcher`` for 'prescient' and 'aalphabio' and
    ``BaseMatcher`` for 'skempi' with the concatenated heavy+light sequence
    held out.

    Parameters:
    settings: Pairing settings passed to each matcher's ``match``.
    """
    seeds = pd.read_csv("/homefs/home/mahajs17/projects/C1sr/scripts/accepted_seeds.csv")
    sequence_by_target = lookup_target_sequence()
    # Targets missing from the lookup get an empty sequence.
    seeds["affinity_antigen_sequence"] = [sequence_by_target.get(target, "")
                                          for target in seeds["target"]]
    seeds = seeds.rename(columns={"seed_id": "seed_alias"})
    antibodies = seeds[['seed_alias', 'fv_heavy', 'fv_light']].drop_duplicates()

    for seed_name, heavy, light in antibodies.itertuples(index=False, name=None):
        ab_seq = heavy + light
        for name in ('prescient', 'aalphabio'):
            print(name)
            per_antigen = PerAntigenMatcher(source_name=name,
                                            hold_out_ab=ab_seq,
                                            seed_name=seed_name)
            per_antigen.match(settings=settings)

        for name in ('skempi',):
            print(name)
            whole_source = BaseMatcher(source_name=name,
                                       hold_out_ab=ab_seq,
                                       seed_name=seed_name)
            whole_source.match(settings=settings)


def create_hold_out_denovo_seed_datasets(settings):
    """Create antibody hold-out datasets for every round-5 de novo seed.

    For each distinct (seed_alias, fv_heavy, fv_light) row of the de novo
    seed table, runs ``PerAntigenMatcher`` for 'prescient' and
    ``BaseMatcher`` for 'skempi' with the concatenated heavy+light sequence
    held out.

    Parameters:
    settings: Pairing settings passed to each matcher's ``match``.
    """
    seeds = pd.read_csv("s3://prescient-data-dev/sandbox/mahajs17/Propen/denovo/round5_denovo_seeds_allcols.csv")
    seeds["affinity_antigen_sequence"] = seeds["seqres_antigen"]
    seeds = seeds.rename(columns={"id": "seed_alias"})
    antibodies = seeds[['seed_alias', 'fv_heavy', 'fv_light']].drop_duplicates()

    for seed_name, heavy, light in antibodies.itertuples(index=False, name=None):
        ab_seq = heavy + light
        for name in ('prescient',):
            print(name)
            per_antigen = PerAntigenMatcher(source_name=name,
                                            hold_out_ab=ab_seq,
                                            seed_name=seed_name)
            per_antigen.match(settings=settings)

        for name in ('skempi',):
            print(name)
            whole_source = BaseMatcher(source_name=name,
                                       hold_out_ab=ab_seq,
                                       seed_name=seed_name)
            whole_source.match(settings=settings)


def create_iid_datasets(settings):
    """Create the matched datasets without any hold-out.

    Runs ``PerAntigenMatcher`` for 'prescient' and 'aalphabio' (followed by
    ``get_org_stat``) and ``BaseMatcher`` for 'skempi'.

    Parameters:
    settings: Pairing settings passed to each matcher's ``match``.
    """
    for source in ('prescient', 'aalphabio'):
        print(source)
        per_antigen = PerAntigenMatcher(source_name=source)
        per_antigen.match(settings=settings)
        per_antigen.get_org_stat()

    for source in ('skempi',):
        print(source)
        whole_source = BaseMatcher(source_name=source)
        whole_source.match(settings=settings)

        
