import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import edlib
from prescient.transforms.functional import anarci_numbering
from logos import plot_logos
from prescient.constants import CDR_RANGES_AHO

def combine_and_plot_data():
    """Placeholder that is not implemented yet: does nothing and returns None."""
    return None

def _edit_distances(queries, targets):
    """Return the edlib edit distance of each (query, target) pair, in order."""
    return [edlib.align(query, target)['editDistance']
            for query, target in zip(queries, targets)]


def _gapless_segments(aligned_sequences, start, stop):
    """Slice ``[start:stop]`` out of each sequence and strip '-' characters."""
    return [seq[start:stop].replace('-', '') for seq in aligned_sequences]


def process_baseline_data(csvfile, outdir='.', siter=1, temp=0.9, samples=200):
    """Annotate sampled sequences with edit distances to their seeds and plot them.

    The table is read from ``csvfile`` (parquet when the name ends with
    ``.parquet``, CSV otherwise). For every row the function adds:

    * ``*_aho`` columns produced by ``anarci_numbering`` for the sampled and
      seed heavy/light sequences (only the columns that are not already there),
    * ``edit_distance_heavy``, ``edit_distance_light`` and their sum
      ``edit_distance`` (edlib edit distance between sample and seed),
    * one ``edit_distance_<cdr>`` column per entry of ``CDR_RANGES_AHO``,
      computed on the corresponding slice of the ``*_aho`` strings with '-'
      removed,
    * ``temperature`` and ``iterations`` columns holding ``temp`` and ``siter``.

    Rows are then de-duplicated on (``fv_heavy``, ``fv_light``) and written to
    ``{outdir}/samples_.csv``, together with histogram PNGs of the edit
    distances and, for every distinct seed with more than two samples,
    sequence logos generated by ``plot_logos``.

    Args:
        csvfile: Path to the input ``.csv`` or ``.parquet`` file. It must
            contain the columns ``fv_heavy``, ``fv_light``, ``fv_heavy_seed``
            and ``fv_light_seed``.
        outdir: Directory that receives all outputs; created if missing.
        siter: Value stored in the ``iterations`` column.
        temp: Value stored in the ``temperature`` column.
        samples: Currently unused; it only appears in the disabled
            file-name suffix.

    Raises:
        ValueError: If any of the required columns is missing.
    """
    if csvfile.endswith('.parquet'):
        df = pd.read_parquet(csvfile)
    else:
        df = pd.read_csv(csvfile)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    sequence_columns = ('fv_heavy', 'fv_light', 'fv_heavy_seed', 'fv_light_seed')
    missing = [col for col in sequence_columns if col not in df]
    if missing:
        raise ValueError(f'{csvfile} is missing required column(s): {missing}')

    # An empty outdir is left alone: os.makedirs('') would raise.
    if outdir:
        os.makedirs(outdir, exist_ok=True)

    # ADD AHO
    # Each AHO column is checked on its own so that a file carrying only some
    # of them still ends up with all four.
    for col in sequence_columns:
        aho_col = f'{col}_aho'
        if aho_col not in df:
            df[aho_col] = anarci_numbering(df[col])

    # ADD PER CDR ED
    df['edit_distance_heavy'] = _edit_distances(df['fv_heavy'], df['fv_heavy_seed'])
    df['edit_distance_light'] = _edit_distances(df['fv_light'], df['fv_light_seed'])
    df['edit_distance'] = df['edit_distance_heavy'] + df['edit_distance_light']

    print(df.columns)

    for cdr, cdr_range in CDR_RANGES_AHO.items():
        chain = 'heavy' if cdr.startswith('H') else 'light'
        print(cdr, cdr_range, chain)
        # The upper bound of the range is treated as inclusive.
        r1, r2 = cdr_range[0], cdr_range[1] + 1
        print(f'fv_{chain}_aho', f'fv_{chain}_seed_aho')
        try:
            df[f'edit_distance_{cdr}'] = _edit_distances(
                _gapless_segments(df[f'fv_{chain}_aho'], r1, r2),
                _gapless_segments(df[f'fv_{chain}_seed_aho'], r1, r2),
            )
        except Exception as exc:
            # Best effort: a CDR that cannot be processed is skipped, but the
            # reason is reported instead of being swallowed silently.
            print(f'Skipping edit_distance_{cdr}: {type(exc).__name__}: {exc}',
                  file=sys.stderr)
            continue
    df['temperature'] = temp
    df['iterations'] = siter

    # The run-specific suffix (iterations / temperature / sample count) is
    # disabled, so output names end in "_" followed by the extension.
    suffix = ''
    df = df.drop_duplicates(['fv_heavy', 'fv_light'])
    df.to_csv(f'{outdir}/samples_{suffix}.csv', index=False)

    plt.rcdefaults()
    sns.histplot(df, x="edit_distance_light", discrete=True)
    plt.savefig(f"{outdir}/distribution_EDlight_{suffix}.png")
    plt.close()
    sns.histplot(df, x="edit_distance_heavy", discrete=True)
    plt.savefig(f"{outdir}/distribution_EDheavy_{suffix}.png")
    plt.close()

    # Two rows of at least four panels; widened when there are more than
    # eight CDR entries so that indexing `axes` cannot run out of range.
    cdr_names = list(CDR_RANGES_AHO)
    ncols = max(4, -(-len(cdr_names) // 2))
    fig, axes = plt.subplots(nrows=2, ncols=ncols)
    axes = axes.flatten()
    for i, cdr in enumerate(cdr_names):
        if f"edit_distance_{cdr}" in df:
            sns.histplot(df, x=f"edit_distance_{cdr}", ax=axes[i], discrete=True)
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    fig.tight_layout()
    plt.savefig(f"{outdir}/distribution_EDCDRs_{suffix}.png")
    plt.close()
    plt.rcdefaults()

    # One logo per distinct seed sequence and chain.
    for chain in ['heavy', 'light']:
        for i, ref_seq in enumerate(df[f'fv_{chain}_seed_aho'].unique()):
            df_filter = df[df[f'fv_{chain}_seed_aho'] == ref_seq]
            sequences = df_filter[f'fv_{chain}_aho'].values.tolist()
            print(chain, ref_seq, len(sequences))
            print(sequences[:2])
            if len(sequences) > 2:
                plot_logos(sequences,
                           ref_seq=ref_seq,
                           logo_file_base=f'{outdir}/logos_seed{i}_{chain}_{suffix}',
                           chain=chain.upper()[0]
                           )
                
def debug():
    """Ad-hoc check of logo plotting on one hard-coded samples CSV.

    For each chain and each distinct seed AHO string, prints the chain, the
    seed and the number of matching samples, drops samples containing 'X',
    prints up to 20 of the remaining ones and, when more than two remain,
    passes the first 10 to ``plot_logos``.
    """
    samples_path = '/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/PropEn_PrescientED5pkd6.5_EGFRN032_gt/logits/samples_iter2_t0.9_N5000.csv'

    table = pd.read_csv(samples_path)
    for chain in ('heavy', 'light'):
        seed_column = f'fv_{chain}_seed_aho'
        sample_column = f'fv_{chain}_aho'
        for ref_seq in table[seed_column].unique():
            matching = table[table[seed_column] == ref_seq]
            sequences = matching[sample_column].values.tolist()
            print(chain, ref_seq, len(sequences))

            # Keep only sequences without an 'X' character.
            kept = []
            for candidate in sequences:
                if candidate.find('X') == -1:
                    kept.append(candidate)
            print(kept[:20])

            if len(kept) <= 2:
                continue
            # The output base name does not include the seed, so every seed
            # of a chain writes to the same base.
            plot_logos(kept[:10],
                       logo_file_base=f'test_debug_{chain}',
                       chain=chain.upper()[0])

# Example invocation of process_baseline_data, currently disabled:
#csvfile = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/Seq_Mixup_EGFRN032_SkempiED5pkd6.5/EGFR_0.3.parquet"
#outdir = "/data/mahajs17/Propen_sampling/EGFRN032/hoo_models/Seq_Mixup_EGFRN032_SkempiED5pkd6.5/"
#process_baseline_data(csvfile, outdir=outdir)
# NOTE(review): debug() is called unconditionally at module level, so it also
# runs whenever this file is imported. Consider guarding it with
# `if __name__ == "__main__":` once no importer relies on that side effect.
debug()