import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from prescient.constants import CDR_RANGES_AHO
import logomaker

def one_hot_encode(seq, alphabet='ACDEFGHIKLMNPQRSTVWY-'):
    """
    One-hot encode a string of amino acids.

    inputs:
        (1) seq................a string of amino acids to be encoded
                               (characters are upper-cased before lookup)
        (2) alphabet...........a string of all recognised symbols; symbol k of
                               the alphabet is encoded in row k. The default is
                               the 20 standard amino acids followed by the gap
                               character '-'.

    output:
            arr................a float array of shape (n, m), where
                               n = number of distinct symbols in `alphabet`
                               plus one for 'X' (22 with the default alphabet)
                               and m = len(seq). Column j holds a single 1 in
                               the row of seq[j], or is all zeros when seq[j]
                               is not a recognised symbol.

            Example:
                >>> one_hot_encode('AP').shape
                (22, 2)
                >>> one_hot_encode('AP').argmax(axis=0)
                array([ 0, 12])

    Unrecognised characters are skipped (their column stays all-zero) and the
    offending sequence is printed once; the remaining characters are still
    encoded.
    """
    # Map every symbol to its row index.
    aa_to_row = {aa: row for row, aa in enumerate(alphabet)}
    # 'X' is pinned to row 20. With the default alphabet that is the same row
    # as '-', so the last row (21) is never set by this function.
    # NOTE(review): visulaize_seq_logo labels column 21 as 'X'; confirm whether
    # 'X' was meant to map to 21 instead.
    aa_to_row['X'] = 20

    arr = np.zeros((len(aa_to_row), len(seq)))
    n_rows = arr.shape[0]

    has_unexpected = False
    for col, aa in enumerate(seq):
        try:
            row = aa_to_row[aa.upper()]
        except (KeyError, AttributeError, TypeError):
            # Unknown symbol, or an element that is not a string at all.
            row = None
        # The hard-coded row 20 for 'X' can fall outside the array when a
        # custom alphabet has fewer than 20 symbols.
        if row is None or row >= n_rows:
            has_unexpected = True
            continue
        arr[row, col] = 1

    if has_unexpected:
        print('unexpected letters: ', seq)

    return arr

def visulaize_seq_logo(insilico_VH_lib, wt, ax=None):
    """
    Draw a per-position residue-frequency logo for a library of sequences.

    inputs:
        (1) insilico_VH_lib....an iterable of amino-acid strings. Sequences may
                               differ in length; shorter ones contribute
                               nothing to the positions past their end.
        (2) wt.................reference sequence whose glyphs are drawn in
                               silver, or '' to skip the highlighting. It is
                               passed unchanged to
                               Logo.style_glyphs_in_sequence.
        (3) ax.................matplotlib Axes to draw on. When None a new
                               figure is created (45x4 for more than 50
                               positions, otherwise 25x4).

    output:
            the logomaker.Logo object, or None when there is nothing to draw
            (empty library or only empty sequences). No figure is created in
            that case.
    """
    sns.set_palette("Set2")
    sns.set_context("talk")
    sns.set_style("whitegrid")
    # NOTE(review): sns.set() appears to reset palette/context/style to the
    # seaborn defaults, overriding the three calls above - confirm which look
    # is intended. Call order is kept so the rendered output is unchanged.
    sns.set(font_scale=2)

    # one_hot_encode yields 22 rows per position with its default alphabet.
    n_symbols = 22

    # Flatten each (n_symbols, length) encoding column-major, so every run of
    # n_symbols consecutive values describes one sequence position.
    flat = [one_hot_encode(seq).flatten('F') for seq in insilico_VH_lib]

    max_len = max((len(vec) for vec in flat), default=0)
    if max_len == 0:
        # Nothing to plot; bail out before a figure is allocated.
        return None

    # Sum the encodings, implicitly zero-padding sequences that are shorter
    # than the longest one.
    totals = np.zeros(max_len)
    for vec in flat:
        totals[:len(vec)] += vec

    # Convert counts to frequencies: one row per position, one column per symbol.
    freqs = (totals / len(flat)).reshape(-1, n_symbols)
    n_positions = freqs.shape[0]

    if ax is None:
        figsize = [45, 4] if n_positions > 50 else [25, 4]
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Column labels: the 20 amino acids, then '*' for row 20 and 'X' for row 21.
    ww_df = pd.DataFrame(freqs)
    ww_df.columns = list('ACDEFGHIKLMNPQRSTVWY') + ['*', 'X']
    ww_df.index = ww_df.index + 1  # 1-based positions on the x axis

    # create Logo object
    ww_logo = logomaker.Logo(ww_df,
                             color_scheme='chemistry',
                             vpad=.1,
                             width=1,
                             ax=ax,
                             vsep=.005)

    # style using Logo methods
    ww_logo.style_xticks(anchor=1, spacing=5, rotation=45)
    if wt:
        ww_logo.style_glyphs_in_sequence(sequence=wt, color='silver')

    # style using Axes methods
    ww_logo.ax.set_ylabel('Frequency')
    ww_logo.ax.set_xlim([-1, len(ww_df) + 2])

    return ww_logo

def _save_logo(sequences, ref_seq, out_path, title):
    """Render one logo to `out_path`; return True if a file was written.

    Returns False when visulaize_seq_logo has nothing to draw. Any figure
    opened while rendering is closed again, also when rendering or saving
    raises, so repeated calls do not accumulate open figures.
    """
    figs_before = set(plt.get_fignums())
    try:
        logo = visulaize_seq_logo(sequences, wt=ref_seq)
        if logo is None:
            return False
        fig = logo.ax.figure
        if title != '':
            fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(out_path, dpi=300, transparent=True)
        return True
    finally:
        # Close only the figures created by this call, never the caller's.
        for num in set(plt.get_fignums()) - figs_before:
            plt.close(num)


def plot_logos(sequences, ref_seq='', logo_file_base='logo', chain='', title=''):
    """
    Write a sequence logo for the full sequences and one per CDR region.

    inputs:
        (1) sequences..........list of amino-acid strings
        (2) ref_seq............optional reference sequence highlighted in the
                               logos ('' disables highlighting)
        (3) logo_file_base.....output path prefix; the full logo is saved as
                               '<logo_file_base>.png' and each region logo as
                               '<logo_file_base>_<region>.png'
        (4) chain..............only regions of CDR_RANGES_AHO whose name starts
                               with chain.upper() are plotted ('' selects all)
        (5) title..............optional figure title

    output:
            None. Failures are reported on stdout instead of being raised. If
            the full logo fails or has nothing to draw, no region logos are
            attempted; a region with nothing to draw is skipped.
    """
    full_path = logo_file_base + '.png'
    try:
        written = _save_logo(sequences, ref_seq, full_path, title)
    except Exception as err:
        print('Failed logo ,', logo_file_base, err)
        return
    if not written:
        return
    print("Writing ", full_path)

    for reg, reg_ranges in CDR_RANGES_AHO.items():
        if not reg.startswith(chain.upper()):
            continue
        print(reg, reg_ranges)
        # The first two entries of the range are used as slice bounds.
        start, end = reg_ranges[0], reg_ranges[1]
        seqs_reg = [s[start:end] for s in sequences]
        ref_seq_reg = ref_seq[start:end] if ref_seq != '' else ''
        try:
            _save_logo(seqs_reg, ref_seq_reg, logo_file_base + f'_{reg}.png', title)
        except Exception as err:
            print('Failed logo ,', logo_file_base, reg, err)
