import pandas as pd
import numpy as np
import json
import itertools as it
import pysam
import multiprocessing as mp
import subprocess as sp

# Complement table: each character of DNA53 maps to the character at the same
# index of DNA35 (N->N, T->A, C->G, G->C, A->T). Only upper-case bases are mapped.
DNA53 = 'NTCGA'
DNA35 = 'NAGCT'
trans = str.maketrans(DNA53, DNA35)


def reverse_complement(seq):
    """Return the reverse complement of *seq*.

    Characters outside 'NTCGA' (including lower-case bases) are left as-is and
    only reversed.
    """
    complemented = seq.translate(trans)
    return complemented[::-1]

def fetch_sequence(fasta, CHROM, START, END, n_up=2, n_down=2):
    """ Fetch the region START..END padded by n_up bases upstream and n_down
        bases downstream, so every position in the region can be given a full
        (n_up + 1 + n_down)-mer context.

    Positions closer than n_up to the chromosome start cannot have a full
    upstream context, so START is raised to n_up in that case. The padded
    window then starts at 0 rather than at a negative coordinate. A negative
    coordinate is something pysam's fetch rejects.

    Parameters
    ----------
    fasta : object with ``fetch(chrom, start, end)`` (e.g. pysam.FastaFile)
    CHROM : str
    START, END : int
        Region bounds, passed through to ``fasta.fetch`` after padding.
    n_up, n_down : int
        Number of padding bases before / after the region.

    Returns
    -------
    seq : str
        Upper-cased sequence of the padded window.
    start, end : int
        The padded coordinates that were requested from ``fasta``.
    """
    # The original only handled START == 0; e.g. START == 1 with n_up == 2
    # requested position -1. Clamping covers every START < n_up.
    START = max(START, n_up)

    seq = fasta.fetch(CHROM, START-n_up, END+n_down).upper()
    return seq, START-n_up, END+n_down

def mk_context_sequences(n_up=2, n_down=2, collapse=False):
    """ Build a zero-initialised counter keyed by every possible context string.

    Each key has n_up flanking bases, one central base, and n_down flanking
    bases, enumerated in lexicographic 'ACGT' order. With collapse=True the
    central base is restricted to 'C' or 'T'.
    """
    flank = 'ACGT'
    centers = 'CT' if collapse else 'ACGT'
    alphabet_per_position = [flank] * n_up + [centers] + [flank] * n_down
    return dict.fromkeys(map(''.join, it.product(*alphabet_per_position)), 0)

def seq_to_context(seq, baseix=2, collapse=False):
    """ Map a k-mer to the context key used for counting.

    Returns '' for any sequence containing 'N'. When collapse is True, a
    sequence whose base at index ``baseix`` is 'G' or 'A' is replaced by its
    reverse complement. Otherwise the sequence is returned unchanged.

    kwarg   baseix: index of the base around which the context is built

    NOTE(review): reverse-complementing moves the centre base to index
    len(seq) - 1 - baseix, so for asymmetric windows (n_up != n_down) the
    flipped key does not match the collapsed keys from mk_context_sequences.
    """
    if 'N' in seq:
        return ''
    needs_flip = collapse and seq[baseix] in ('G', 'A')
    return reverse_complement(seq) if needs_flip else seq

def type_mutation(REF, ALT, collapse=False):
    """ Format a substitution as 'REF>ALT'.

    With collapse=True, a 'G' or 'A' reference is complemented together with
    ALT, so the substitution is reported on the C/T strand.
    """
    flip = collapse and REF in ('G', 'A')
    if flip:
        REF, ALT = REF.translate(trans), ALT.translate(trans)
    return '{}>{}'.format(REF, ALT)

def count_sequence_context(seq, n_up=2, n_down=2, nuc_dict=None, collapse=False):
    """ Tally every (n_up + 1 + n_down)-mer context found in seq.

    A window is slid across seq one base at a time. Windows that
    seq_to_context maps to '' (those containing an 'N') are skipped.

    If nuc_dict is given and non-empty, counts are added to it in place and it
    is returned. Otherwise a fresh counter from mk_context_sequences is used.
    """
    counts = nuc_dict
    if not counts:
        counts = mk_context_sequences(n_up=n_up, n_down=n_down, collapse=collapse)

    width = n_up + n_down + 1
    for left in range(len(seq) - width + 1):
        context = seq_to_context(seq[left:left + width], baseix=n_up, collapse=collapse)
        if context:
            counts[context] += 1

    return counts

def count_contexts_by_regions(f_fasta, chrom_lst, start_lst, end_lst, n_up=2, n_down=2, collapse=False):
    """ Count sequence contexts separately for each region in a list.

    Parameters
    ----------
    f_fasta : str
        Path to a FASTA file opened with pysam.FastaFile.
    chrom_lst, start_lst, end_lst : sequences
        Parallel lists describing the regions.
    n_up, n_down : int
        Context bases upstream / downstream of the central base.
    collapse : bool
        Fold purine-centred contexts onto their reverse complement.

    Returns
    -------
    pandas.DataFrame
        One row per region, indexed by "CHROM:START-END" (the START/END given
        by the caller), with one column per context key.
    """
    idx_lst = []
    dict_lst = []

    # Close the FASTA handle deterministically. This function is also run in
    # pool workers, where relying on garbage collection is fragile.
    with pysam.FastaFile(f_fasta) as fasta:
        for CHROM, START, END in zip(chrom_lst, start_lst, end_lst):
            seq, _, _ = fetch_sequence(fasta, CHROM, START, END, n_up=n_up, n_down=n_down)
            dict_lst.append(count_sequence_context(seq, n_up=n_up, n_down=n_down, collapse=collapse))
            idx_lst.append("{}:{}-{}".format(CHROM, START, END))

    return pd.DataFrame(dict_lst, index=idx_lst)

def count_contexts_in_genome(f_fasta, f_mapp, window, n_up=2, n_down=2, N_proc=1, N_chunk=10, collapse=False):
    """ Count nucleotide contexts in every window listed in a mappability file.

    f_mapp is a JSON object mapping chromosome name -> list of window start
    coordinates. Each window spans start .. start + window. The windows are
    split into at most N_chunk chunks, which are counted in parallel by
    N_proc worker processes via count_contexts_by_regions.

    Returns
    -------
    pandas.DataFrame
        One row per window, in the order they appear in f_mapp, with one
        column per context. If f_mapp lists no windows, an empty DataFrame is
        returned.
    """
    with open(f_mapp, 'r') as fh:
        map_dict = json.load(fh)

    chrom_lst = []
    start_lst = []
    for key, starts in map_dict.items():
        start_lst += starts
        chrom_lst += [key] * len(starts)

    end_lst = [start + window for start in start_lst]

    n_regions = len(chrom_lst)
    if n_regions == 0:
        return pd.DataFrame()

    # Ceiling division gives at most N_chunk non-empty chunks. The previous
    # int(n / (N_chunk - 1)) divided by zero for N_chunk == 1. It also gave a
    # zero range() step whenever n < N_chunk - 1.
    chunk_size = max(1, -(-n_regions // max(1, N_chunk)))
    chunk_idx = list(range(0, n_regions, chunk_size)) + [n_regions]

    kwargs = dict(n_up=n_up, n_down=n_down, collapse=collapse)
    # Collect results inside the with-block: Pool.__exit__ calls terminate(),
    # so every .get() must finish before leaving it.
    with mp.Pool(N_proc) as pool:
        res = [
            pool.apply_async(
                count_contexts_by_regions,
                (f_fasta, chrom_lst[lo:hi], start_lst[lo:hi], end_lst[lo:hi]),
                kwargs,
            )
            for lo, hi in zip(chunk_idx[:-1], chunk_idx[1:])
        ]
        df_lst = [r.get() for r in res]

    return pd.concat(df_lst)

def mutation_contexts_by_chrom(f_fasta, df, n_up=2, n_down=2, collapse=False):
    """ Annotate the mutations of one chromosome with mutation type and context.

    The chromosome is taken from the first row's CHROM value ('chr' is
    prepended if missing) and loaded once from f_fasta. START is used directly
    as a 0-based index into that chromosome string. A warning is printed when
    the reference base there differs from REF.

    Parameters
    ----------
    f_fasta : str
        Path to a FASTA file opened with pysam.FastaFile.
    df : pandas.DataFrame
        Mutations of a single chromosome with columns CHROM, START, REF, ALT.
        Modified in place.
    n_up, n_down : int
        Context bases upstream / downstream of the mutated base.
    collapse : bool
        Fold purine-centred contexts and mutation types onto the C/T strand.

    Returns
    -------
    pandas.DataFrame
        ``df`` with MUT_TYPE and CONTEXT columns appended. CONTEXT is '' when
        the window contains an N or would extend past either chromosome end.
    """
    CHROM = str(df.CHROM.iloc[0])
    if not CHROM.startswith('chr'):
        CHROM = "chr{}".format(CHROM)

    # Read the chromosome once, then release the FASTA handle right away.
    with pysam.FastaFile(f_fasta) as fasta:
        seq = fasta.fetch(CHROM).upper()
    seq_len = len(seq)

    cntxt_lst = []
    muttype_lst = []
    prev_start = -1
    prev_alt = ''
    prev_ref = ''

    for START, REF, ALT in zip(df.START.values, df.REF.values, df.ALT.values):
        if seq[START] != REF:
            print('WARNING: REF {} does not match sequence {} at {}'.format(REF, seq[START], START))

        if START == prev_start:
            # Consecutive rows at the same position share one context.
            substr = cntxt_lst[-1]

            if ALT == prev_alt:
                mut = muttype_lst[-1]
            else:
                mut = type_mutation(REF, ALT, collapse=collapse)

        else:
            if START < n_up or START + n_down >= seq_len:
                # The window would run past a chromosome end. A negative slice
                # start wraps around and usually yields '', which makes
                # seq_to_context raise IndexError when collapse=True. Near the
                # end the slice is silently truncated to a short context.
                substr = ''
            else:
                substr = seq_to_context(
                            seq[START-n_up:START+n_down+1], baseix=n_up, collapse=collapse
                         )

            if ALT == prev_alt and REF == prev_ref:
                mut = muttype_lst[-1]
            else:
                mut = type_mutation(REF, ALT, collapse=collapse)

        muttype_lst.append(mut)
        cntxt_lst.append(substr)

        prev_start = START
        prev_ref   = REF
        prev_alt   = ALT

    df.insert(df.shape[1], 'MUT_TYPE', muttype_lst)
    df.insert(df.shape[1], 'CONTEXT', cntxt_lst)

    return df

def add_context_to_mutations(f_fasta, df_mut, n_up=2, n_down=2, N_proc=1, collapse=False):
    """ Add MUT_TYPE and CONTEXT annotations to a table of mutations.

    Mutations are grouped by CHROM. Any chromosome whose name contains 'MT' is
    skipped. Each remaining group is annotated by mutation_contexts_by_chrom
    in a pool of N_proc worker processes.

    Returns
    -------
    pandas.DataFrame
        The annotated rows of all processed chromosomes, concatenated. When no
        chromosome is left to process, an empty frame is returned with
        df_mut's columns plus MUT_TYPE and CONTEXT; previously this case
        raised ValueError from pd.concat([]).
    """
    groups = [df for chrom, df in df_mut.groupby('CHROM') if 'MT' not in str(chrom)]

    if not groups:
        return df_mut.iloc[:0].assign(MUT_TYPE=[], CONTEXT=[])

    kwargs = dict(n_up=n_up, n_down=n_down, collapse=collapse)
    # Collect results inside the with-block: Pool.__exit__ calls terminate().
    with mp.Pool(N_proc) as pool:
        res = [
            pool.apply_async(mutation_contexts_by_chrom, (f_fasta, df), kwargs)
            for df in groups
        ]
        df_lst = [r.get() for r in res]

    return pd.concat(df_lst)

def bgzip(filename):
    """Compress *filename* with the external ``bgzip`` tool (``-f`` forces overwrite).

    Raises
    ------
    subprocess.CalledProcessError
        If bgzip exits with a non-zero status. Previously this failure was
        ignored and only surfaced later, e.g. when indexing the output.
    """
    sp.run(['bgzip', '-f', filename], check=True)

def tabix_index(filename, preset="bed", skip=1, comment="#"):
    """Call tabix to create an index for a bgzip-compressed file.

    Parameters
    ----------
    filename : str
        Path to the bgzip-compressed file.
    preset : str
        Value for tabix ``-p`` (e.g. "bed").
    skip : int
        Number of leading lines to skip (``-S``), e.g. a header row.
    comment : str
        Comment-line character passed as ``-c``. Previously this was accepted
        but ignored. Pass '' or None to leave tabix's preset default in place.
        NOTE(review): for a preset whose built-in comment char is not '#',
        pass it explicitly.

    Raises
    ------
    subprocess.CalledProcessError
        If tabix exits with a non-zero status.
    """
    # Option and value are separate argv elements; '-S 1' as one element only
    # worked because tabix tolerates the leading space in the value.
    cmd = ['tabix', '-p', preset, '-S', str(skip)]
    if comment:
        cmd += ['-c', comment]
    cmd.append(filename)
    sp.run(cmd, check=True)

def mk_mutation_context(n_up=2, n_down=2, collapse=False):
    """ Build a zero-initialised counter keyed by (mutation type, context).

    For each reference base, the counter gets every substitution of that base
    crossed with every context centred on it. Mutation types are the outer
    loop and contexts the inner loop. Reference bases are visited in the order
    C, T (collapse=True) or A, C, G, T (collapse=False). This insertion order
    becomes the column order of DataFrames built from these dicts.
    """
    flank_up = ['ACGT'] * n_up
    flank_down = ['ACGT'] * n_down

    def contexts_for(center):
        # Every context string with `center` between the two flanks.
        return [''.join(parts) for parts in it.product(*flank_up, center, *flank_down)]

    substitutions = {
        'A': ['A>T', 'A>C', 'A>G'],
        'C': ['C>A', 'C>G', 'C>T'],
        'G': ['G>T', 'G>C', 'G>A'],
        'T': ['T>A', 'T>G', 'T>C'],
    }
    centers = 'CT' if collapse else 'ACGT'

    counts = {}
    for center in centers:
        keys = contexts_for(center)
        for mut in substitutions[center]:
            for key in keys:
                counts[(mut, key)] = 0

    return counts

def count_mutation_contexts(tbx_mut, CHROM, START, END, n_up=2, n_down=2, collapse=False):
    """ Count mutations by (mutation type, context) within one region.

    Each row returned by ``tbx_mut.fetch(CHROM, START, END)`` is split on tabs.
    Its last two fields are read as the mutation type and the context, which
    is the column order add_context_to_mutations produces by appending
    MUT_TYPE and CONTEXT last.

    Returns
    -------
    dict
        The mk_mutation_context counter with the region's counts added.
    """
    d = mk_mutation_context(n_up=n_up, n_down=n_down, collapse=collapse)

    for row in tbx_mut.fetch(CHROM, START, END):
        vals = row.split("\t")
        mut     = vals[-2]
        context = vals[-1]
        if not context:
            # seq_to_context yields '' for N-containing windows, and such a
            # context has no key in d (previously a KeyError). Skip it, as
            # count_sequence_context does. Presumably written to the file
            # as an empty field — confirm against the writer.
            continue
        d[(mut, context)] += 1

    return d

def count_mutations_by_regions(f_mut, chrom_lst, start_lst, end_lst, n_up=2, n_down=2, collapse=False):
    """ Count mutation contexts separately for each region in a list.

    Parameters
    ----------
    f_mut : str
        Path to a tabix-indexed mutation file (opened with pysam.TabixFile).
    chrom_lst, start_lst, end_lst : sequences
        Parallel lists describing the regions. CHROM values are used as-is
        for the tabix query, and 'chr' is prepended only in the row labels.

    Returns
    -------
    pandas.DataFrame
        One row per region, indexed by "chr{CHROM}:{START}-{END}", with one
        column per (mutation type, context) key.
    """
    idx_lst = []
    dict_lst = []

    # Close the tabix handle deterministically. This function runs in pool
    # workers.
    with pysam.TabixFile(f_mut) as tbx:
        for CHROM, START, END in zip(chrom_lst, start_lst, end_lst):
            dict_lst.append(count_mutation_contexts(tbx, CHROM, START, END,
                                                    n_up=n_up, n_down=n_down, collapse=collapse))
            idx_lst.append("chr{}:{}-{}".format(CHROM, START, END))

    return pd.DataFrame(dict_lst, index=idx_lst)

def count_mutations_in_genome(f_mut, f_mapp, window, n_up=2, n_down=2, N_procs=1, collapse=False):
    """ Count mutation contexts in every window listed in a mappability file.

    f_mapp is a JSON object mapping chromosome name -> list of window start
    coordinates. Each window spans start .. start + window. The 'chr' prefix
    is stripped from chromosome names before querying f_mut. Windows are split
    into about two chunks per process and counted in parallel by
    count_mutations_by_regions.

    Returns
    -------
    pandas.DataFrame
        One row per window, with one column per (mutation type, context). If
        f_mapp lists no windows, an empty DataFrame is returned. Previously
        this case crashed on a zero chunk size.
    """
    with open(f_mapp, 'r') as fh:
        map_dict = json.load(fh)

    chrom_lst = []
    start_lst = []
    for key, starts in map_dict.items():
        start_lst += starts
        chrom_lst += [key.split('chr')[-1]] * len(starts)

    end_lst = [start + window for start in start_lst]

    n_regions = len(chrom_lst)
    if n_regions == 0:
        return pd.DataFrame()

    chunksize = int(np.ceil(n_regions / (N_procs*2)))
    print('Chunk size is: ', chunksize)

    kwargs = dict(n_up=n_up, n_down=n_down, collapse=collapse)
    # Collect results inside the with-block: Pool.__exit__ calls terminate().
    with mp.Pool(N_procs) as pool:
        res = [
            pool.apply_async(
                count_mutations_by_regions,
                (f_mut, chrom_lst[i:i+chunksize], start_lst[i:i+chunksize], end_lst[i:i+chunksize]),
                kwargs,
            )
            for i in range(0, n_regions, chunksize)
        ]
        res_lst = [r.get() for r in res]

    return pd.concat(res_lst)

def base_probabilities_by_region(fasta, S_prob, CHROM, START, END, n_up=2, n_down=2, normed=True, collapse=False):
    """ Get the mutation probability of every position across a region.

    Each position's context is looked up in S_prob (indexable by context
    string). Positions whose context is '' (N-containing windows) get 0.

    Returns
    -------
    probs : numpy.ndarray
        Per-position probabilities, divided by their sum when normed is True.
    poss : numpy.ndarray
        Coordinates of those positions (padded window start + offset).
    """
    seq, start, _ = fetch_sequence(fasta, CHROM, START, END, n_up=n_up, n_down=n_down)

    offsets = range(n_up, len(seq) - n_down)

    prob_values = []
    for offset in offsets:
        context = seq_to_context(seq[offset - n_up:offset + n_down + 1], baseix=n_up, collapse=collapse)
        prob_values.append(S_prob[context] if context else 0)

    probs = np.array(prob_values)
    poss = np.array([start + offset for offset in offsets])

    if normed:
        probs = probs / np.sum(probs)

    return probs, poss
