import pandas as pd
import numpy as np
import scipy.stats
import scipy.special
import pysam
import h5py
import multiprocessing as mp
import itertools as it
import sequence_tools 
import nb_model
import math

# Complement lookup: the base at DNA53[i] pairs with the base at DNA35[i]
# ('N' maps to itself).
DNA53 = 'NTCGA'
DNA35 = 'NAGCT'
# Translation table consumed by reverse_complement().
trans = str.maketrans(DNA53, DNA35)

# Alphabets used to enumerate trinucleotide contexts.
DNA = 'ACGT'
NUC = 'CT'

# One key per (upstream base, C/T centre base, downstream base) combination.
prod_items = [DNA, NUC, DNA]
keys = [''.join(combo) for combo in it.product(*prod_items)]

def reverse_complement(seq):
    """Return the reverse complement of `seq`.

    Each character is mapped through the module-level `trans` table and the
    result is reversed; characters missing from the table are left unchanged.
    """
    complemented = seq.translate(trans)
    return complemented[::-1]

def trans_to_context(seq, baseix=1):
    """Normalise a 'REF>ALT' substitution string to its canonical strand.

    seq     -- string of the form '<ref context>><alt context>'
    baseix  -- index, within the reference context, of the mutated base

    Returns '' when `seq` contains an 'N'. When the reference base at
    `baseix` is 'G' or 'A', both sides are reverse-complemented and joined
    with '>' again; otherwise `seq` is returned unchanged.
    """
    parts = seq.split('>')
    ref_ctx, alt_ctx = parts[0], parts[1]

    if 'N' in seq:
        return ''

    if ref_ctx[baseix] not in ('G', 'A'):
        return seq

    return '>'.join((reverse_complement(ref_ctx), reverse_complement(alt_ctx)))

def seq_model_regions(all_idx_df, chrom, intervals):
    """Return region strings for the windows that overlap a gene's CDS.

    all_idx_df -- window table handed to overlapping_windows(); it must
                  provide CHROM, START and END columns
    chrom      -- chromosome of the gene (converted with int() downstream)
    intervals  -- 2 x n array; column i holds (start, end) of CDS interval i

    Returns a list of 'chr<c>:<start>-<end>' strings built from the first
    three columns of every selected row (assumed to be CHROM, START, END in
    that order -- TODO confirm against the caller's frame layout).

    NOTE(review): the comment previously above this function said it returns
    regions that do NOT overlap the CDS; the code selects the overlapping
    ones, and that behaviour is kept.
    """
    # overlapping_windows takes (df, chrom, intervals); the original call
    # passed intervals and chrom in swapped positions.
    idxs = overlapping_windows(all_idx_df, chrom, intervals)
    return ['chr{}:{}-{}'.format(row[0], row[1], row[2]) for row in idxs.values]


def add_model(df, est_str, d_pr, f_model_str, f_genic_str, fasta_str, tbx_str, all_windows_df, cancer, mapp, window):
    """Attach model expectations and p-values to per-gene observed counts.

    df             -- frame with columns Genes, CHROM, OBS_MIS, OBS_NONS.
                      Values are read by position: column 0 is the gene,
                      column 2 the missense count, column 3 the nonsense count.
    est_str        -- format string for per-fold estimate files (formatted
                      with the fold number), or None to interpolate missing
                      windows from their neighbours instead
    d_pr           -- sequence-context probabilities passed to genic_seq_model
    f_model_str    -- forwarded to genic_seq_model
    f_genic_str    -- HDF5 file with 'chr', 'cds_intervals' and 'L_data' groups
    fasta_str      -- reference FASTA path
    tbx_str        -- tabix-indexed mutation file path
    all_windows_df -- per-window predictions indexed by region string
    cancer         -- cancer key inside the estimate files
    mapp           -- forwarded to the region-parameter helpers
    window         -- window size

    Returns a copy of `df` with the columns EXP_MIS, EXP_NONS, R_OBS, MU,
    SIGMA, PVAL_MIS and PVAL_NONS added.
    """
    df_obs = df.copy().astype({'Genes':str, 'CHROM':int, 'OBS_MIS':int, 'OBS_NONS':int})
    df_substmodel = pd.read_csv('/scratch1/priebeo/neurIPS/supressor_det/substmodel.csv')

    ##patch for sub_map bug. will remove for final
    est = None
    if est_str is not None:
        # Open read-only and close before get_gp_submap_df re-opens the file.
        with h5py.File(est_str.format(0), 'r') as f:
            mismatched = f[cancer]['test/chr_locs'].shape[0] != f[cancer]['test/y_true'].shape[0]
        if mismatched:
            with h5py.File('/scratch1/maxas/ICGC_Roadmap/DNN_predictors/10kb/roadmap_tracks_735_10kb.unzipped.h5', 'r') as f2:
                idx_all = f2['idx'][:]
                # Kept in its own name so the caller's `mapp` is not overwritten.
                mapp_all = f2['mappability'][:]
            idx_lowmap = idx_all[mapp_all < 0.5]
            est, _ = get_gp_submap_df(est_str, cancer, test_idx = idx_lowmap)
        else:
            est, _ = get_gp_submap_df(est_str, cancer)

    mu_lst = []
    s_lst = []
    pvals_mis = []
    pvals_nons = []
    R_obs_lst = []
    exp_mis_lst = []
    exp_nons_lst = []

    # Context managers release the FASTA, tabix and HDF5 handles even when a
    # gene fails part-way through the loop.
    with pysam.FastaFile(fasta_str) as fasta, \
            pysam.TabixFile(tbx_str) as tbx, \
            h5py.File(f_genic_str, 'r') as f_genic:
        for _, row in df_obs.iterrows():
            # .iloc keeps the original positional meaning without relying on
            # the deprecated integer fallback of Series.__getitem__.
            gene = row.iloc[0]
            obs_mis = row.iloc[2]
            obs_nons = row.iloc[3]
            chrom = f_genic['chr'][gene][:][0].decode("utf-8")
            intervals = f_genic['cds_intervals'][gene][:]
            L = pd.DataFrame(f_genic['L_data'][gene][:].T, index = df_substmodel['Unnamed: 0'])
            t_pi = genic_seq_model(d_pr, f_genic, f_model_str, fasta, gene, cancer, window)

            # Columns 1 and 2 of the product feed the missense and nonsense
            # totals respectively.
            pi_sums = t_pi * L
            p_mis = pi_sums[1].sum()
            p_nons = pi_sums[2].sum()

            if est_str is None:
                mu, sigma = get_region_params_interp(all_windows_df, chrom, intervals, mapp, window)
            else:
                mu, sigma = get_region_params_est(all_windows_df, est, chrom, intervals, mapp, window)
            mu_lst.append(mu)
            s_lst.append(sigma)

            pvals_mis.append(calc_pvalue(mu, sigma, p_mis, obs_mis))
            pvals_nons.append(calc_pvalue(mu, sigma, p_nons, obs_nons))
            R_obs_lst.append(get_R_obs(chrom, intervals, tbx, all_windows_df, window))

            alpha, theta = nb_model.normal_params_to_gamma(mu, sigma)
            exp_mis_lst.append(alpha*theta*p_mis)
            exp_nons_lst.append(alpha*theta*p_nons)

    df_obs['EXP_MIS'] = exp_mis_lst
    df_obs['EXP_NONS'] = exp_nons_lst
    df_obs['R_OBS'] = R_obs_lst
    df_obs['MU'] = mu_lst
    df_obs['SIGMA'] = s_lst
    df_obs['PVAL_MIS'] = pvals_mis
    df_obs['PVAL_NONS'] = pvals_nons

    return df_obs

def add_model_parallel(mut_obs_df, est_str, d_pr, f_model_str, f_genic_str, fasta_str, tbx_str, all_idx_df, cancer, mapp, N_procs, window):
    """Run add_model over row chunks of `mut_obs_df` in a process pool.

    The frame is cut into at most `N_procs` contiguous chunks, each chunk is
    handed to add_model in a worker process, and the per-chunk results are
    concatenated in submission order. All other arguments are forwarded to
    add_model unchanged.

    Raises whatever a worker raised (re-raised by AsyncResult.get), and
    ValueError from pd.concat when `mut_obs_df` has no rows.
    """
    n_rows = len(mut_obs_df)
    # max(1, ...) keeps the range step valid for an empty frame.
    chunksize = max(1, int(np.ceil(n_rows / N_procs)))
    shared_args = (est_str, d_pr, f_model_str, f_genic_str, fasta_str, tbx_str, all_idx_df, cancer, mapp, window)

    # The with-block tears the pool down even when a worker raises; results
    # are collected inside it because leaving the block terminates the pool.
    with mp.Pool(N_procs) as pool:
        pending = [
            pool.apply_async(add_model, (mut_obs_df.iloc[i:i+chunksize],) + shared_args)
            for i in range(0, n_rows, chunksize)
        ]
        res_lst = [r.get() for r in pending]

    return pd.concat(res_lst)

def reformat_probs(Pr_mut_train):
    """Relabel mutation probabilities as '<context>><mutated context>'.

    `Pr_mut_train` is expected to be indexed by pairs whose first element is
    a 'REF>ALT' string and whose second element is the sequence context.
    Each row is relabelled '<context>><context[0] + ALT + context[2]>'.
    The value column 0 is then copied into columns 1, 2 and 3.
    """
    probs = pd.DataFrame(Pr_mut_train).reset_index()

    labels = []
    for key in probs['index']:
        context = key[1]
        alt_base = key[0].split('>')[1]
        labels.append(context + '>' + context[0] + alt_base + context[2])
    # Object dtype matches what element-wise assignment would leave behind.
    probs['index'] = np.array(labels, dtype=object)

    probs = probs.set_index('index')
    for col in (1, 2, 3):
        probs[col] = probs[0]
    return probs

def get_actual_muts(intervals, chrom, tbx):
    """Count missense and nonsense records inside a set of intervals.

    intervals -- 2 x n array; column i holds (start, end) of interval i
    chrom     -- reference name passed to tbx.fetch
    tbx       -- object with fetch(chrom, start, end) yielding tab-separated
                 lines whose 7th field (index 6) is the variant class

    Returns (missense_count, nonsense_count) summed over all intervals.
    """
    n_missense = 0
    n_nonsense = 0
    for bounds in intervals.T:
        start, end = bounds.astype(int)
        records = pd.DataFrame([line.split('\t') for line in tbx.fetch(chrom, start, end)])
        if len(records) == 0:
            continue
        variant_class = records[6]
        n_missense += int((variant_class == 'Missense_Mutation').sum())
        n_nonsense += int((variant_class == 'Nonsense_Mutation').sum())
    return n_missense, n_nonsense

def get_obs_cds_muts(f_genic, tbx):
    """Tabulate observed missense/nonsense counts for every gene.

    f_genic -- open HDF5-like mapping with 'L_data', 'cds_intervals' and
               'chr' groups keyed by gene
    tbx     -- tabix handle passed on to get_actual_muts

    Returns a DataFrame with one row per key of f_genic['L_data'] and the
    columns Genes, CHROM, OBS_MIS, OBS_NONS.
    """
    genes = list(f_genic['L_data'].keys())
    chroms = []
    missense = []
    nonsense = []
    for gene in genes:
        cds = f_genic['cds_intervals'][gene][:]
        # The chromosome is stored as a bytes array; the first entry is used.
        gene_chrom = f_genic['chr'][gene][:][0].decode("utf-8")
        chroms.append(gene_chrom)
        n_mis, n_nons = get_actual_muts(cds, gene_chrom, tbx)
        missense.append(n_mis)
        nonsense.append(n_nons)

    table = {'Genes': genes, 'CHROM': chroms, 'OBS_MIS': missense, 'OBS_NONS': nonsense}
    return pd.DataFrame(table, index = None)

def get_all_mut_df(tbx_str, f_genic_str):
    """Return per-gene observed mutation counts, excluding chrX and chrY.

    tbx_str     -- path of the tabix-indexed mutation file
    f_genic_str -- path of the genic HDF5 file

    Both files are opened read-only and closed before returning, including
    when counting fails.
    """
    with pysam.TabixFile(tbx_str) as tbx, h5py.File(f_genic_str, 'r') as f_genic:
        mut_counts = get_obs_cds_muts(f_genic, tbx)
    #remove muts from X and Y chromosomes
    mut_counts = mut_counts.loc[(mut_counts['CHROM'] != 'X') & (mut_counts['CHROM'] != 'Y')]
    return mut_counts

# Uses a median-absolute-deviation threshold on r^2 to drop outlier runs.
def load_fold_avg(f, cancer, test_idx=None):
    """Average the per-run test predictions stored for one fold.

    f        -- path of the HDF5 results file
    cancer   -- top-level group name inside the file
    test_idx -- optional window coordinates to use in place of the file's
                'test/chr_locs'; None or an empty sequence means "use the
                file's own coordinates"

    Runs are the all-digit keys under '<cancer>/test'. For each run the
    squared Pearson correlation between 'y_true' and the run's 'mean' is
    computed; runs with a NaN r^2 are skipped, and of the rest only those
    with r^2 > median - 2 * MAD are averaged.

    Returns a DataFrame with columns CHROM, START, END, Y_TRUE, Y_PRED, STD.
    Raises ValueError when no run yields a finite r^2.
    """
    # The with-block guarantees the file is closed even if a dataset is missing.
    with h5py.File(f, 'r') as hf:
        test_grp = hf[cancer]['test']
        runs = [int(key) for key in test_grp.keys() if key.isdigit()]

        test_Y = test_grp['y_true'][:].reshape(-1, 1)
        if test_idx is None or not len(test_idx):
            test_idx = test_grp['chr_locs'][:]

        test_Yhat_lst = []
        test_std_lst = []
        r2_lst = []
        for run in runs:
            run_grp = test_grp['{}'.format(run)]
            y_hat = run_grp['mean'][:].reshape(-1, 1)
            r2 = scipy.stats.pearsonr(test_Y.squeeze(), y_hat.squeeze())[0]**2

            # A constant prediction vector gives a NaN correlation; skip it.
            if np.isnan(r2):
                continue

            r2_lst.append(r2)
            test_Yhat_lst.append(y_hat)
            test_std_lst.append(run_grp['std'][:].reshape(-1, 1))

    if not r2_lst:
        raise ValueError('no run with a finite r^2 for {!r} in {}'.format(cancer, f))

    r2s = np.array(r2_lst)
    med = np.median(r2s)
    mad = np.median(np.abs(r2s - med))
    idx = np.where(r2s > (med - 2*mad))
    test_Yhat = np.array(test_Yhat_lst)[idx].mean(axis = 0)
    test_std = np.array(test_std_lst)[idx].mean(axis = 0)
    vals = np.hstack([test_idx, test_Y, test_Yhat, test_std])
    df = pd.DataFrame(vals, columns=['CHROM', 'START', 'END', 'Y_TRUE', 'Y_PRED', 'STD'])

    return df

# Loads the GP results of every fold and stacks them into one frame.
def get_gp_results_df(f_gp_base, cancer_str, fold_num=5, drop_pos_cols=True, sort=False):
    """Concatenate the fold-averaged GP results of all folds.

    f_gp_base     -- format string for the per-fold result file; formatted
                     with the fold number
    cancer_str    -- cancer key passed to load_fold_avg
    fold_num      -- number of folds to load (0 .. fold_num - 1)
    drop_pos_cols -- drop CHROM/START/END once the Region index is built
    sort          -- sort rows by CHROM then START

    Returns (df, window): `df` is indexed by 'chr<c>:<start>-<end>' region
    strings and `window` is END - START of the first loaded row.
    """
    per_fold = []
    for fold in np.arange(fold_num):
        per_fold.append(load_fold_avg(f_gp_base.format(fold), cancer=cancer_str))

    column_types = {'CHROM':int, 'START':int, 'END':int, 'Y_TRUE':float, 'Y_PRED':float, 'STD':float}
    df = pd.concat(per_fold).astype(column_types)

    first_row = df.iloc[0]
    window = int(first_row['END'] - first_row['START'])

    df['Region'] = ['chr{}:{}-{}'.format(c, s, e) for c, s, e in zip(df.CHROM, df.START, df.END)]

    if sort:
        df = df.sort_values(by=['CHROM', 'START'])
    if drop_pos_cols:
        df = df.drop(['CHROM', 'START', 'END'], axis = 1)

    df = df.set_index('Region')
    return df, window

def get_gp_submap_df(f_gp_base, cancer_str, fold_num=5, drop_pos_cols=True, sort=False, test_idx=None):
    """Average the fold-level predictions window by window.

    Unlike get_gp_results_df, the folds are not stacked: Y_PRED and STD are
    averaged element-wise across folds, so every fold must return the same
    number of rows. CHROM, START, END and Y_TRUE are taken from fold 0.

    f_gp_base     -- format string for the per-fold result file
    cancer_str    -- cancer key passed to load_fold_avg
    fold_num      -- number of folds to load
    drop_pos_cols -- drop CHROM/START/END once the Region index is built
    sort          -- sort rows by CHROM then START
    test_idx      -- optional coordinates forwarded to load_fold_avg; None
                     (the default) lets each file supply its own

    Returns (df, window) with `df` indexed by region string.
    """
    # None replaces the former shared mutable default; load_fold_avg treats
    # an empty sequence as "not provided".
    if test_idx is None:
        test_idx = []

    fold_nums = np.arange(fold_num)
    df_lst = [load_fold_avg(f_gp_base.format(fold), cancer=cancer_str, test_idx=test_idx) for fold in fold_nums]

    mean = np.mean(np.array([fold_df.Y_PRED.values for fold_df in df_lst]), axis=0)
    std = np.mean(np.array([fold_df.STD.values for fold_df in df_lst]), axis=0)

    reference = df_lst[0]
    df = pd.DataFrame({'CHROM': reference.CHROM.values,
                       'START': reference.START.values,
                       'END': reference.END.values,
                       'Y_TRUE': reference.Y_TRUE.values,
                       'Y_PRED': mean,
                       'STD': std,
                       }
                      ).astype({'CHROM':int, 'START':int, 'END':int, 'Y_TRUE':float, 'Y_PRED':float, 'STD':float})

    window = int(df.iloc[0]['END'] - df.iloc[0]['START'])
    df['Region'] = ['chr{}:{}-{}'.format(row[0], row[1], row[2]) for row in zip(df.CHROM, df.START, df.END)]

    if sort:
        df = df.sort_values(by=['CHROM', 'START'])

    if drop_pos_cols:
        df = df.drop(['CHROM', 'START', 'END'], axis = 1)

    df.set_index('Region', inplace=True)
    return df, window

# Combines the estimated region parameters (mu, sigma) of the windows that a
# gene's CDS intervals overlap.
def get_region_params(df, intervals, chrom):
    """Sum the window-level predictions overlapping a gene.

    df        -- window table with CHROM, START, END, Y_PRED and STD columns
    intervals -- 2 x n array of CDS (start, end) positions
    chrom     -- chromosome of the gene

    Returns (mu, sigma): mu is the sum of Y_PRED and sigma the square root of
    the summed STD**2 over the de-duplicated overlapping windows. Returns
    (-1, -1) when no window overlaps.
    """
    hits = overlapping_windows(df, chrom, intervals)
    if len(hits) == 0:
        return -1, -1
    total_mu = hits['Y_PRED'].sum()
    total_var = (hits['STD'] ** 2).sum()
    return total_mu, np.sqrt(total_var)

# Finds all windows that overlap the CDS regions of a given gene.
def overlapping_windows(df, chrom, intervals):
    """Select the rows of `df` whose window overlaps any CDS interval.

    df        -- window table with at least CHROM, START and END columns
    chrom     -- chromosome of the gene; compared as int(chrom)
    intervals -- 2 x n array; column i holds (start, end) of interval i

    A window overlaps an interval when START < end and END > start. This
    covers every case the previous three-way test accepted and additionally
    catches windows that share a boundary with the interval (including a
    window identical to it), which the strict comparisons used to drop.
    Windows that merely touch an interval end-to-start are still excluded.

    Returns the selected rows with duplicates removed, ordered by interval
    and then by their order in `df`. An empty interval set yields an empty
    frame with the columns of `df`.
    """
    chrom_id = int(chrom)
    on_chrom = df['CHROM'] == chrom_id

    accum = []
    for i in range(intervals.shape[1]):
        start, end = intervals[:,i]
        overlaps = (df['START'] < end) & (df['END'] > start)
        accum.append(df.loc[on_chrom & overlaps])

    if not accum:
        # pd.concat rejects an empty list; return an empty selection instead.
        return df.iloc[0:0]
    return pd.concat(accum).drop_duplicates()

def calc_pvalue(mu, sigma, exp_sum, obs):
    """Return the p-value of `obs` under the module's negative-binomial model.

    (mu, sigma) are converted with nb_model.normal_params_to_gamma into an
    (alpha, theta) pair; the probability 1 / (exp_sum * theta + 1) is then
    passed with alpha and `obs` to nb_model.nb_pvalue_exact.
    """
    gamma_alpha, gamma_theta = nb_model.normal_params_to_gamma(mu, sigma)
    prob = 1 / (1 + gamma_theta * exp_sum)
    return nb_model.nb_pvalue_exact(obs, gamma_alpha, prob)

def get_S_prob(f_model_str, regions, cancer):
    """Estimate per-context mutation probabilities from a set of regions.

    Reads the '<cancer>_mutation_counts' and 'genome_counts' tables from
    `f_model_str`, sums both over the rows named in `regions` and passes the
    totals to nb_model.mutation_freq_conditional (third argument 1).

    The result is returned as a DataFrame whose index entries are rewritten
    from (substitution, context) pairs to '<context>><mutated context>'.
    """
    mutation_table = pd.read_hdf(f_model_str, key=cancer + '_mutation_counts')
    genome_table = pd.read_hdf(f_model_str, key='genome_counts')

    mutation_totals = mutation_table.loc[regions].sum(axis=0)  # mutation context counts
    genome_totals = genome_table.loc[regions].sum(axis=0)      # trinucleotide counts
    Pr_mut_train = nb_model.mutation_freq_conditional(mutation_totals, genome_totals, 1)

    probs = pd.DataFrame(Pr_mut_train).reset_index()
    labels = []
    for key in probs['index']:
        context = key[1]
        alt_base = key[0].split('>')[1]
        labels.append(context + '>' + context[0] + alt_base + context[2])
    # Object dtype matches what element-wise assignment would leave behind.
    probs['index'] = np.array(labels, dtype=object)

    return probs.set_index('index')

# Counts the possible occurrences of every SNP transition across the given regions.
def si_by_regions(fasta, trans_idx, regions, n_up=1, n_down=1, normed=True):
    """Count how often each '<context>><mutated context>' key can occur.

    fasta     -- FASTA handle passed to sequence_tools.fetch_sequence
    trans_idx -- iterable of transition keys to count; every transition met
                 in the sequence must be present, otherwise KeyError is raised
    regions   -- iterable of 'chr<c>:<start>-<end>' strings
    n_up      -- number of flanking bases upstream of the mutated base
    n_down    -- number of flanking bases downstream of the mutated base
    normed    -- unused; kept for interface compatibility

    For each position whose context is non-empty after
    sequence_tools.seq_to_context, the three substitutions of the central
    base are counted once each.

    Returns a one-column DataFrame (column 0) of counts indexed by key, in
    the order the keys first appear in `trans_idx`.
    """
    # dict.fromkeys de-duplicates like a set but keeps a reproducible order.
    counts = dict.fromkeys(trans_idx, 0)
    all_nucs = 'ATCG'

    for r in regions:
        splt1 = r.split('-')
        end = int(splt1[1])
        splt2 = splt1[0].split(':')
        start = int(splt2[1])
        chrom = splt2[0]
        seq, _, _ = sequence_tools.fetch_sequence(fasta, chrom, start, end, n_up=n_up, n_down=n_down)

        for i in range(n_up, len(seq)-n_down):
            substr = sequence_tools.seq_to_context(seq[i-n_up:i+n_down+1], baseix=n_up)
            if not substr:
                continue

            # Flanks are kept as whole strings so contexts wider than a
            # trinucleotide work; the suffix starts right after the central
            # base (n_up + 1, not n_down + 1).
            prefix = substr[:n_up]
            ref_base = substr[n_up]
            suffix = substr[n_up + 1:]
            for alt_base in all_nucs:
                if alt_base == ref_base:
                    continue
                counts[substr + '>' + prefix + alt_base + suffix] += 1

    return pd.DataFrame(list(counts.values()), index = list(counts.keys()))

def genic_seq_model(S_probs, f_genic, f_model_str, fasta, gene, cancer, window):
    """Rescale context probabilities to the windows covering one gene.

    The gene's chromosome and CDS intervals are read from `f_genic`, the
    windows of size `window` covering them are derived, and the possible
    transition counts in those windows are obtained from si_by_regions.
    `S_probs` is divided by the sum of column 0 of (counts * S_probs), and
    column 0 of the result is copied into columns 1, 2 and 3.

    `f_model_str` and `cancer` are accepted but not used here.
    """
    gene_chrom = f_genic['chr'][gene][:][0].decode("utf-8")
    cds = f_genic['cds_intervals'][gene][:]
    region_names = [trip_to_str(region) for region in get_ideal_overlaps(gene_chrom, cds, window)]

    possible_counts = si_by_regions(fasta, S_probs.index, region_names)

    weighted = possible_counts * S_probs
    normaliser = weighted[0].sum()
    rescaled = S_probs / normaliser
    for col in (1, 2, 3):
        rescaled[col] = rescaled[0]
    return rescaled

def get_ideal_overlaps(chrom, intervals, window):
    """List the grid windows of size `window` that cover the intervals.

    chrom     -- chromosome; stored as int(chrom) in every returned triple
    intervals -- 2 x n array; column i holds (start, end) of interval i
    window    -- window size; window borders are multiples of it

    Returns a de-duplicated list of (chrom, window_start, window_end)
    triples, in no particular order.
    """
    region_lst = []
    for bounds in intervals.T:
        lower = math.floor(bounds[0].min() / window) * window
        upper = math.ceil(bounds[1].max() / window) * window
        edges = np.arange(lower, upper + window, window)
        # Consecutive edge pairs are the individual windows.
        for left, right in zip(edges[:-1], edges[1:]):
            region_lst.append((int(chrom), left, right))
    return list(set(region_lst))

def get_region_params_interp(df, chrom, intervals, mapp, window):
    """Sum window predictions over a gene, filling gaps from neighbours.

    df        -- per-window predictions indexed by region string, with
                 Y_PRED and STD columns
    chrom     -- chromosome of the gene
    intervals -- 2 x n array of CDS (start, end) positions
    mapp      -- forwarded to get_right_parent / get_left_parent
    window    -- window size

    Windows present in `df` contribute their own Y_PRED and STD**2. A missing
    window borrows from the nearest available window on each side: when one
    side reports distance 0 the other side is used alone, otherwise the two
    neighbours are combined weighted by their distances.

    Returns (mu, sigma) with sigma = sqrt(summed variance), or (-1, -1) as
    soon as a missing window has a neighbour more than 10 steps away.
    """
    ideal = get_ideal_overlaps(chrom, intervals, window)
    mu_sum = 0
    var_sum = 0
    for row in ideal:
        row_str = trip_to_str(row)
        if row_str in df.index:
            mu_sum += df.loc[row_str, 'Y_PRED']
            var_sum += df.loc[row_str, 'STD']**2
            continue

        right_str, right_dis = get_right_parent(row, df, mapp)
        left_str, left_dis = get_left_parent(row, df, mapp)

        if right_dis > 10 or left_dis > 10:
            return -1, -1

        if right_dis == 0:
            mu_sum += df.loc[left_str, 'Y_PRED']
            var_sum += df.loc[left_str, 'STD']**2
        elif left_dis == 0:
            mu_sum += df.loc[right_str, 'Y_PRED']
            # Variance, not standard deviation: this branch previously added
            # STD unsquared, unlike every other branch.
            var_sum += df.loc[right_str, 'STD']**2
        else:
            total_dis = right_dis + left_dis
            # NOTE(review): each neighbour is weighted by its OWN distance,
            # so the farther window counts more. Linear interpolation would
            # swap the weights -- confirm the intent before changing.
            mu_sum += (right_dis * df.loc[right_str, 'Y_PRED'] + left_dis * df.loc[left_str, 'Y_PRED']) / total_dis
            var_sum += (right_dis * df.loc[right_str, 'STD']**2 + left_dis * df.loc[left_str, 'STD']**2) / total_dis

    mu = mu_sum
    sigma = np.sqrt(var_sum)
    return mu, sigma

def get_R_obs(chrom, intervals, tbx, all_windows_df, window):
    """Count the tabix records in every window covering a gene's CDS.

    The windows come from get_ideal_overlaps(chrom, intervals, window); each
    is queried with tbx.fetch(int(chrom), start, end) and the number of
    returned records is summed. `all_windows_df` is not used.
    """
    total = 0
    for region in get_ideal_overlaps(chrom, intervals, window):
        region_chrom = int(region[0])
        region_start = region[1]
        region_end = region[2]
        total += len(list(tbx.fetch(region_chrom, region_start, region_end)))
    return total

def non_genic_windows(f_genic_str, all_regions, window):
    """Return the regions that are not covered by any autosomal gene's CDS.

    f_genic_str -- path of the genic HDF5 file ('cds_intervals' and 'chr')
    all_regions -- iterable of region strings to filter
    window      -- window size used to derive the covered windows

    Genes on chromosomes 'X' and 'Y' are ignored. The result is a list, in
    no particular order, of the entries of `all_regions` that do not match a
    window derived from any remaining gene.
    """
    overlapped = set()
    # The with-block closes the file even if a gene entry is malformed.
    with h5py.File(f_genic_str, 'r') as f_genic:
        for gene in list(f_genic['cds_intervals'].keys()):
            chrom = f_genic['chr'][gene][:][0].decode("utf-8")
            if chrom in ('X', 'Y'):
                continue
            intervals = f_genic['cds_intervals'][gene][:]
            overlapped.update(trip_to_str(t) for t in get_ideal_overlaps(chrom, intervals, window))
    return list(set(all_regions).difference(overlapped))

def trip_to_str(trip):
    """Format a (chrom, start, end) triple as 'chr<chrom>:<start>-<end>'."""
    chrom, start, end = trip[0], trip[1], trip[2]
    return 'chr{}:{}-{}'.format(chrom, start, end)

def get_right_parent(cur, df, mapp, step=None):
    """Find the nearest window at or to the right of `cur` present in `df`.

    cur  -- (chrom, start, end) triple of the starting window
    df   -- frame whose index holds region strings
    mapp -- unused; kept for interface compatibility
    step -- distance between consecutive windows; defaults to the width of
            `cur` (previously hard-coded to 10000, which is what 10 kb
            windows still yield)

    Returns (region_string, distance) where distance is the number of
    windows stepped over. If no window is found after more than 100 steps,
    the starting region is returned with distance 0.
    Raises ValueError when the step is not positive.
    """
    if step is None:
        step = cur[2] - cur[1]
    if step <= 0:
        raise ValueError('window step must be positive, got {}'.format(step))

    origin = cur
    distance = 0
    while trip_to_str(cur) not in df.index:
        if distance > 100:
            # Gave up: callers read distance 0 as "no right parent".
            return trip_to_str(origin), 0
        cur = (cur[0], cur[2], cur[2] + step)
        distance += 1
    return trip_to_str(cur), distance

def get_left_parent(cur, df, mapp, step=None):
    """Find the nearest window at or to the left of `cur` present in `df`.

    cur  -- (chrom, start, end) triple of the starting window
    df   -- frame whose index holds region strings
    mapp -- unused; kept for interface compatibility
    step -- distance between consecutive windows; defaults to the width of
            `cur` (previously hard-coded to 10000, which is what 10 kb
            windows still yield)

    Returns (region_string, distance) where distance is the number of
    windows stepped over. When the start of the chromosome is reached
    without a match, the current region is returned with distance 0.
    Raises ValueError when the step is not positive.
    """
    if step is None:
        step = cur[2] - cur[1]
    if step <= 0:
        raise ValueError('window step must be positive, got {}'.format(step))

    distance = 0
    while trip_to_str(cur) not in df.index:
        # `<=` rather than `==` so a start that is not a multiple of the
        # step cannot walk past zero and loop forever.
        if cur[1] <= 0:
            return trip_to_str(cur), 0
        cur = (cur[0], cur[1] - step, cur[1])
        distance += 1
    return trip_to_str(cur), distance

def load_ho_avg(f, cancer):
    """Average the per-run held-out predictions stored in one results file.

    f      -- path of the HDF5 results file
    cancer -- top-level group name inside the file

    Runs are the all-digit keys under '<cancer>/held-out', visited in
    numeric order. As in load_fold_avg, runs with a NaN r^2 are skipped and
    only runs with r^2 > median - 2 * MAD are averaged.

    Returns a DataFrame indexed by 'chr<c>:<start>-<end>' with columns
    Y_TRUE, Y_PRED and STD. Raises ValueError when no run has a finite r^2.
    """
    # The with-block guarantees the file is closed even if a dataset is missing.
    with h5py.File(f, 'r') as hf:
        held_out = hf[cancer]['held-out']
        runs = sorted(int(key) for key in held_out.keys() if key.isdigit())

        test_Y = held_out['y_true'][:].reshape(-1, 1)
        test_idx = held_out['chr_locs'][:]
        test_Yhat_lst = []
        test_std_lst = []
        r2_lst = []
        for run in runs:
            run_grp = held_out['{}'.format(run)]
            y_hat = run_grp['mean'][:].reshape(-1, 1)
            # pearsonr is fed 1-D vectors, matching load_fold_avg; the
            # (n, 1) column arrays are not accepted by every scipy release.
            r2 = scipy.stats.pearsonr(test_Y.squeeze(), y_hat.squeeze())[0]**2
            if np.isnan(r2):
                # A NaN would turn the median, and with it every output, NaN.
                continue
            r2_lst.append(r2)
            test_Yhat_lst.append(y_hat)
            test_std_lst.append(run_grp['std'][:].reshape(-1, 1))

    if not r2_lst:
        raise ValueError('no run with a finite r^2 for {!r} in {}'.format(cancer, f))

    r2s = np.array(r2_lst)
    med = np.median(r2s)
    mad = np.median(np.abs(r2s - med))
    idx = np.where(r2s > (med - 2*mad))
    test_Yhat = np.array(test_Yhat_lst)[idx].mean(axis = 0)
    test_std = np.array(test_std_lst)[idx].mean(axis = 0)
    vals = np.hstack([test_idx, test_Y, test_Yhat, test_std])
    df = pd.DataFrame(vals, columns=['CHROM', 'START', 'END', 'Y_TRUE', 'Y_PRED', 'STD'])
    df['Region'] = ['chr{}:{}-{}'.format(int(row[0]), int(row[1]), int(row[2])) for row in zip(df.CHROM, df.START, df.END)]
    df = df.drop(['CHROM', 'START', 'END'], axis = 1)
    df.set_index('Region', inplace=True)
    return df

def get_region_params_est(df, est, chrom, intervals, mapp, window):
    """Sum window predictions over a gene, falling back to `est` for gaps.

    Each window covering the CDS intervals is looked up in `df`; windows
    missing from `df` are read from `est` instead (a KeyError propagates if
    it is missing there too). `mapp` is not used.

    Returns (mu, sigma): the summed Y_PRED and the square root of the summed
    STD**2.
    """
    total_mu = 0
    total_var = 0
    for region in get_ideal_overlaps(chrom, intervals, window):
        label = trip_to_str(region)
        source = df if label in df.index else est
        total_mu += source.loc[label, 'Y_PRED']
        total_var += source.loc[label, 'STD']**2
    return total_mu, np.sqrt(total_var)

#non_coding tools
def _nonc_mut_chunk(df, tbx_str):
    """Count mutations for one chunk of elements inside a worker process."""
    # The tabix handle is opened here because each worker needs its own.
    with pysam.TabixFile(tbx_str) as tbx:
        return get_obs_nonc_muts(df, tbx)


def get_nonc_mut_parallel(tbx_str, f_nonc_str, f_key, N_procs):
    """Count observed mutations per non-coding element using a process pool.

    tbx_str    -- path of the tabix-indexed mutation file
    f_nonc_str -- HDF file holding the element table
    f_key      -- key of the element table inside `f_nonc_str`
    N_procs    -- number of worker processes / chunks

    The element table is cut into at most `N_procs` contiguous chunks, each
    counted by get_obs_nonc_muts in a worker, and the results are
    concatenated in chunk order.

    Raises whatever a worker raised, and ValueError from pd.concat when the
    element table has no rows.
    """
    f_nonc = pd.read_hdf(f_nonc_str, f_key)
    n_rows = len(f_nonc)
    # max(1, ...) keeps the range step valid for an empty table.
    chunksize = max(1, int(np.ceil(n_rows / N_procs)))

    # The original called get_obs_nonc_muts directly and then .get() on the
    # resulting DataFrames; the work is now actually submitted to the pool.
    with mp.Pool(N_procs) as pool:
        pending = [
            pool.apply_async(_nonc_mut_chunk, (f_nonc.iloc[i:i+chunksize], tbx_str))
            for i in range(0, n_rows, chunksize)
        ]
        res_lst = [r.get() for r in pending]

    return pd.concat(res_lst)

def nonc_mut_df(tbx_str, f_nonc_str, f_key):
    """Count observed mutations for every non-coding element (single process).

    tbx_str    -- path of the tabix-indexed mutation file
    f_nonc_str -- HDF file holding the element table
    f_key      -- key of the element table inside `f_nonc_str`

    Returns the frame built by get_obs_nonc_muts. The tabix file is closed
    before returning, including on error.
    """
    f_nonc = pd.read_hdf(f_nonc_str, f_key)
    with pysam.TabixFile(tbx_str) as tbx:
        mut_counts = get_obs_nonc_muts(f_nonc, tbx)
    # Assumed that no X, Y or M chromosome data is present, so no filtering
    # is applied here.
    return mut_counts

def get_tbx_nonc_muts(start, end, chrom, tbx):
    """Return the number of records tbx.fetch yields for chrom:start-end.

    Records are counted as they stream past; no intermediate table is built.
    """
    return sum(1 for _ in tbx.fetch(chrom, start, end))

def get_obs_nonc_muts(f_nonc, tbx):
    """Count observed mutations for every element row of `f_nonc`.

    Each row is read through keys 0-4: 0 is a 'chr'-prefixed chromosome
    name, 1 and 2 the start and end, 3 the element name and 4 the element
    type. The 'chr' prefix is stripped before querying `tbx`.

    Returns a DataFrame with columns CHROM, START, END, ELT, ELT_TYPE and
    OBS_MUT, one row per input row.
    """
    chroms = []
    starts = []
    ends = []
    elements = []
    element_types = []
    observed = []
    for _, rec in f_nonc.iterrows():
        elt_chrom = rec[0].split('chr')[1]
        elt_start = rec[1]
        elt_end = rec[2]
        chroms.append(elt_chrom)
        starts.append(elt_start)
        ends.append(elt_end)
        elements.append(rec[3])
        element_types.append(rec[4])
        observed.append(get_tbx_nonc_muts(elt_start, elt_end, elt_chrom, tbx))

    table = {'CHROM': chroms, 'START': starts, 'END': ends, 'ELT': elements, 'ELT_TYPE': element_types, 'OBS_MUT': observed}
    return pd.DataFrame(table, index = None)

def nonc_complement_windows(f_nonc_str, all_regions, f_key, window):
    """Return the regions not covered by any non-coding element.

    The element table is read from `f_nonc_str` under `f_key`; each row
    supplies a 'chr'-prefixed chromosome, a start and an end under keys 0, 1
    and 2. Every window covering an element is removed from `all_regions`.
    The result is a list in no particular order.
    """
    elements = pd.read_hdf(f_nonc_str, f_key)
    covered = []
    for _, rec in elements.iterrows():
        elt_chrom = rec[0].split('chr')[1]
        elt_start = rec[1]
        elt_end = rec[2]
        for region in get_elt_ideal_overlaps(elt_chrom, elt_start, elt_end, window):
            covered.append(trip_to_str(region))

    return list(set(all_regions).difference(set(covered)))
    
def get_elt_ideal_overlaps(chrom, start, end, window):
    """List the grid windows of size `window` that cover [start, end].

    Returns a de-duplicated list of (int(chrom), window_start, window_end)
    triples, in no particular order.
    """
    lower = math.floor(start / window) * window
    upper = math.ceil(end / window) * window
    edges = np.arange(lower, upper + window, window)
    # Consecutive edge pairs are the individual windows.
    region_lst = [(int(chrom), left, right) for left, right in zip(edges[:-1], edges[1:])]
    return list(set(region_lst))

def get_elt_hi_lo(chrom, start, end, window):
    """Return (high, low): `end` rounded up and `start` rounded down to a
    multiple of `window`. `chrom` is not used.
    """
    lower_edge = window * math.floor(start / window)
    upper_edge = window * math.ceil(end / window)
    return upper_edge, lower_edge

# Copied to allow for thresholding later.
def base_probabilities_by_nonc_region(fasta, S_prob, CHROM, START, END, n_up=1, n_down=1, normed=True, collapse=False):
    """Look up a mutation probability for every position of a region.

    fasta    -- FASTA handle passed to sequence_tools.fetch_sequence
    S_prob   -- mapping from sequence context to probability
    CHROM    -- chromosome without the 'chr' prefix (it is added here)
    START    -- region start
    END      -- region end
    n_up     -- number of upstream flanking bases in a context
    n_down   -- number of downstream flanking bases in a context
    normed   -- divide the probabilities by their sum before returning
    collapse -- forwarded to sequence_tools.seq_to_context

    Positions whose context comes back empty get probability 0.

    Returns (probs, poss) as numpy arrays, where poss[k] is the fetched
    start plus the offset of position k in the fetched sequence.
    """
    seq, start, end = sequence_tools.fetch_sequence(fasta, 'chr' + str(CHROM), START, END, n_up=n_up, n_down=n_down)

    prob_values = []
    positions = []
    for offset in range(n_up, len(seq) - n_down):
        positions.append(start + offset)
        context = sequence_tools.seq_to_context(seq[offset-n_up:offset+n_down+1], baseix=n_up, collapse=collapse)
        prob_values.append(S_prob[context] if context else 0)

    probs = np.array(prob_values)
    poss = np.array(positions)

    if normed:
        probs = probs / np.sum(probs)

    return probs, poss

def nonc_model(df, est_str, d_pr, f_model_str, f_nonc_str, fasta_str, tbx_str, all_windows_df, cancer, mapp, window):
    """Attach model expectations and p-values to per-element observed counts.

    df             -- frame with columns CHROM, START, END, ELT, ELT_TYPE and
                      OBS_MUT. Values are read by position: columns 0-2 are
                      chromosome, start and end; column 5 the observed count.
    est_str        -- format string for per-fold estimate files (formatted
                      with the fold number), or None to require every window
                      to be present in `all_windows_df`
    d_pr           -- context probabilities passed to
                      base_probabilities_by_nonc_region
    f_model_str    -- unused here
    f_nonc_str     -- unused here
    fasta_str      -- reference FASTA path
    tbx_str        -- tabix-indexed mutation file path
    all_windows_df -- per-window predictions indexed by region string
    cancer         -- cancer key inside the estimate files
    mapp           -- forwarded to get_nonc_params_nointerp
    window         -- window size

    Returns a copy of `df` with the columns EXP, R_OBS, MU, SIGMA and PVAL
    added.
    """
    df_obs = df.copy().astype({'CHROM':int, 'START':int, 'END':int, 'ELT':str, 'ELT_TYPE':str, 'OBS_MUT':int})

    ##patch for sub_map bug. will remove for final
    est = None
    if est_str is not None:
        # Open read-only and close before get_gp_submap_df re-opens the file.
        with h5py.File(est_str.format(0), 'r') as f:
            mismatched = f[cancer]['test/chr_locs'].shape[0] != f[cancer]['test/y_true'].shape[0]
        if mismatched:
            with h5py.File('/scratch1/maxas/ICGC_Roadmap/DNN_predictors/10kb/roadmap_tracks_735_10kb.unzipped.h5', 'r') as f2:
                idx_all = f2['idx'][:]
                # Kept in its own name so the caller's `mapp` is not overwritten.
                mapp_all = f2['mappability'][:]
            idx_lowmap = idx_all[mapp_all < 0.5]
            est, _ = get_gp_submap_df(est_str, cancer, test_idx = idx_lowmap)
        else:
            est, _ = get_gp_submap_df(est_str, cancer)

    mu_lst = []
    s_lst = []
    pvals_lst = []
    R_obs_lst = []
    exp_lst = []

    # Context managers release the FASTA and tabix handles even when an
    # element fails part-way through the loop.
    with pysam.FastaFile(fasta_str) as fasta, pysam.TabixFile(tbx_str) as tbx:
        for _, row in df_obs.iterrows():
            # .iloc keeps the original positional meaning without relying on
            # the deprecated integer fallback of Series.__getitem__.
            chrom = row.iloc[0]
            start = row.iloc[1]
            end = row.iloc[2]
            obs_mut = row.iloc[5]

            # Probabilities are normalised over the whole windows covering the
            # element, then summed over the element itself.
            hi, lo = get_elt_hi_lo(chrom, start, end, window)
            probs, poss = base_probabilities_by_nonc_region(fasta, d_pr, chrom, lo, hi)
            p_mut = nonc_sum_p(probs, poss, chrom, start, end)

            if est_str is None:
                mu, sigma = get_nonc_params_nointerp(all_windows_df, chrom, start, end, mapp, window)
            else:
                mu, sigma = get_nonc_region_params_est(all_windows_df, est, chrom, start, end, window)
            mu_lst.append(mu)
            s_lst.append(sigma)

            pvals_lst.append(calc_pvalue(mu, sigma, p_mut, obs_mut))
            R_obs_lst.append(get_R_nonc(chrom, start, end, window, tbx, all_windows_df))

            alpha, theta = nb_model.normal_params_to_gamma(mu, sigma)
            exp_lst.append(alpha*theta*p_mut)

    df_obs['EXP'] = exp_lst
    df_obs['R_OBS'] = R_obs_lst
    df_obs['MU'] = mu_lst
    df_obs['SIGMA'] = s_lst
    df_obs['PVAL'] = pvals_lst
    return df_obs

def nonc_model_parallel(nonc_muts_obs, est_str, d_pr, f_model_str, f_nonc_str, fasta_str, tbx_str, all_windows_df, cancer, mapp, window, N_procs):
    """Run nonc_model over row chunks of `nonc_muts_obs` in a process pool.

    The frame is cut into at most `N_procs` contiguous chunks, each chunk is
    handed to nonc_model in a worker process, and the per-chunk results are
    concatenated in submission order. All other arguments are forwarded to
    nonc_model unchanged.

    Raises whatever a worker raised (re-raised by AsyncResult.get), and
    ValueError from pd.concat when `nonc_muts_obs` has no rows.
    """
    n_rows = len(nonc_muts_obs)
    # max(1, ...) keeps the range step valid for an empty frame.
    chunksize = max(1, int(np.ceil(n_rows / N_procs)))
    shared_args = (est_str, d_pr, f_model_str, f_nonc_str, fasta_str, tbx_str, all_windows_df, cancer, mapp, window)

    # The with-block tears the pool down even when a worker raises; results
    # are collected inside it because leaving the block terminates the pool.
    with mp.Pool(N_procs) as pool:
        pending = [
            pool.apply_async(nonc_model, (nonc_muts_obs.iloc[i:i+chunksize],) + shared_args)
            for i in range(0, n_rows, chunksize)
        ]
        res_lst = [r.get() for r in pending]

    return pd.concat(res_lst)

def get_nonc_region_params_est(df, est, chrom, start, end, window):
    """Sum window predictions over an element, falling back to `est` for gaps.

    Each window covering [start, end] is looked up in `df`; windows missing
    from `df` are read from `est` instead (a KeyError propagates if it is
    missing there too).

    Returns (mu, sigma): the summed Y_PRED and the square root of the summed
    STD**2.
    """
    total_mu = 0
    total_var = 0
    for region in get_elt_ideal_overlaps(chrom, start, end, window):
        label = trip_to_str(region)
        source = df if label in df.index else est
        total_mu += source.loc[label, 'Y_PRED']
        total_var += source.loc[label, 'STD']**2
    return total_mu, np.sqrt(total_var)

def nonc_seq_model(S_probs, chrom, start, end, f_model_str, fasta, window):
    """Rescale context probabilities to the windows covering one element.

    The windows covering [start, end] are derived, the possible transition
    counts in them are obtained from si_by_regions, and `S_probs` is divided
    by the column sums of (counts * S_probs). The total of `S_probs` is
    printed first. `f_model_str` is not used.
    """
    region_names = [trip_to_str(region) for region in get_elt_ideal_overlaps(chrom, start, end, window)]
    print(S_probs.values.sum())
    possible_counts = si_by_regions(fasta, S_probs.index, region_names)

    weighted = possible_counts * S_probs
    normaliser = weighted.sum()
    return S_probs / normaliser

def nonc_sum_p(probs, poss, chrom, start, end):
    """Sum the probabilities of the positions from `start` up to `end`.

    probs -- per-position probabilities
    poss  -- positions matching `probs`; only poss[0] is used, as the offset
             of the first entry
    chrom -- unused

    The slice starts at index start - poss[0] and spans end - start entries.
    """
    first = start - poss[0]
    last = first + (end - start)
    return np.sum(probs[first:last])

def get_nonc_params_nointerp(df, chrom, start, end, mapp, window):
    """Sum window predictions over an element without filling gaps.

    Returns (mu, sigma): the summed Y_PRED and the square root of the summed
    STD**2 over the windows covering [start, end]. Returns (-1, -1) as soon
    as one of those windows is missing from `df`. `mapp` is not used.
    """
    total_mu = 0
    total_var = 0
    for region in get_elt_ideal_overlaps(chrom, start, end, window):
        label = trip_to_str(region)
        if label not in df.index:
            return -1, -1
        total_mu += df.loc[label, 'Y_PRED']
        total_var += df.loc[label, 'STD']**2
    return total_mu, np.sqrt(total_var)

def get_R_nonc(chrom, start, end, window, tbx, all_windows_df):
    """Count the tabix records in every window covering an element.

    The windows come from get_elt_ideal_overlaps(chrom, start, end, window);
    each is queried with tbx.fetch(int(chrom), start, end) and the number of
    returned records is summed. `all_windows_df` is not used.
    """
    total = 0
    for region in get_elt_ideal_overlaps(chrom, start, end, window):
        region_chrom = int(region[0])
        region_start = region[1]
        region_end = region[2]
        total += len(list(tbx.fetch(region_chrom, region_start, region_end)))
    return total

def nonc_train_sequence_model(train_idx, f_model, N, key_prefix=None):
    """Train a trinucleotide sequence model from precomputed count tables.

    train_idx  -- row labels of the training regions
    f_model    -- HDF file holding the count tables
    N          -- third argument of nb_model.mutation_freq_conditional
    key_prefix -- optional prefix; when given, the mutation table is read
                  from '<key_prefix>_mutation_counts' instead of
                  'mutation_counts'

    Returns (Pr_mut_train, d): the probabilities stratified by mutation type
    as returned by nb_model.mutation_freq_conditional, and a dict mapping
    each context (second element of the index entries) to the sum of its
    probabilities.
    """
    key_mut = key_prefix + "_" + 'mutation_counts' if key_prefix else 'mutation_counts'

    mutation_table = pd.read_hdf(f_model, key=key_mut)
    genome_table = pd.read_hdf(f_model, key='genome_counts')

    mutation_totals = mutation_table.loc[train_idx, :].sum(axis=0)  # mutation context counts in train set
    genome_totals = genome_table.loc[train_idx, :].sum(axis=0)      # trinucleotide counts in train set

    ## Probabilities stratified by mutation type
    Pr_mut_train = nb_model.mutation_freq_conditional(mutation_totals, genome_totals, N)

    ## Probabilities by trinucleotide context
    contexts = set([tup[1] for tup in Pr_mut_train.index])
    d = {context: 0 for context in contexts}
    for tup in Pr_mut_train.index:
        d[tup[1]] += Pr_mut_train[tup]

    return Pr_mut_train, d

def legacy_get_gp_submap_df(f_gp, cancer_str):
    """Return the run-averaged results of the single file `f_gp`.

    Thin wrapper around load_fold_avg for callers that pass one file rather
    than a per-fold format string.
    """
    fold_df = load_fold_avg(f_gp, cancer=cancer_str)
    return fold_df