#!/usr/bin/env python

import pandas as pd
import numpy as np
import h5py
import argparse
import pysam
import json

import sequence_tools
import nb_model
import gp_tools

def countGenomeContext(args):
    """Count nucleotide contexts genome-wide and save them to an HDF5 file.

    The counts DataFrame returned by ``sequence_tools.count_contexts_in_genome``
    is stored under key ``genome_counts`` in ``args.fout``. The context sizes
    (``args.up`` / ``args.down``) are then recorded as the root attributes
    ``n_up`` / ``n_down`` of the same file, which the other subcommands read back.

    Parameters
    ----------
    args : argparse.Namespace
        Uses ``fasta``, ``mapDict``, ``window``, ``up``, ``down``, ``n_procs``,
        ``collapse`` and ``fout``.
    """
    print('Counting nucleotide contexts genome-wide')
    # NOTE(review): the chunk count is tied to the process count here — confirm intended.
    df = sequence_tools.count_contexts_in_genome(args.fasta, args.mapDict, args.window,
                                                 n_up=args.up, n_down=args.down,
                                                 N_proc=args.n_procs, N_chunk=args.n_procs,
                                                 collapse=args.collapse)

    print('Saving context counts to {}'.format(args.fout))
    df.to_hdf(args.fout, key='genome_counts', mode='a')

    # Use a context manager so the attribute writes are flushed and the handle
    # released deterministically, instead of relying on garbage collection.
    with h5py.File(args.fout, 'a') as h5:
        h5.attrs['n_up'] = args.up
        h5.attrs['n_down'] = args.down

def annotateMutationFile(args):
    """Add sequence-context annotations to a mutation file and index the result.

    The context sizes are not taken from the command line. They are read from the
    ``n_up`` / ``n_down`` attributes of ``args.h5genome``, so the annotation
    matches the genome-wide counts stored there.

    Parameters
    ----------
    args : argparse.Namespace
        Uses ``h5genome``, ``fmut``, ``fasta``, ``n_procs``, ``collapse`` and ``fout``.

    Raises
    ------
    KeyError
        If ``args.h5genome`` lacks the ``n_up`` / ``n_down`` attributes.
    """
    # The context manager closes the file even if an attribute lookup fails.
    with h5py.File(args.h5genome, 'r') as h5:
        n_up = h5.attrs['n_up']
        n_down = h5.attrs['n_down']

    print('Reading in mutation file')
    # Column names are supplied explicitly, so the file is read as headerless,
    # tab-separated data with these seven columns.
    df_mut = pd.read_csv(args.fmut, sep="\t", low_memory=False,
                         names=['CHROM', 'START', 'END', 'REF', 'ALT', 'ID', 'ANNOT'])
    print(df_mut[0:5])

    print('Extracting mutation contexts')
    df_mut2 = sequence_tools.add_context_to_mutations(args.fasta, df_mut,
                                                      n_up=n_up, n_down=n_down,
                                                      N_proc=args.n_procs,
                                                      collapse=args.collapse)

    print('Saving annotated mutation file')
    df_mut2.to_csv(args.fout, sep="\t", index=False)
    # Presumably bgzip writes ``<fout>.gz``, which is then tabix-indexed — confirm
    # against sequence_tools.
    sequence_tools.bgzip(args.fout)
    sequence_tools.tabix_index(args.fout + '.gz')

def countMutationContext(args):
    """Count mutations by sequence context genome-wide and store them in the genome h5 file.

    The context sizes are read from the ``n_up`` / ``n_down`` attributes of
    ``args.h5genome``. The resulting DataFrame is written back to that same file
    under the key ``<keyPrefix>_mutation_counts``.

    Parameters
    ----------
    args : argparse.Namespace
        Uses ``h5genome``, ``fmut``, ``mapDict``, ``window``, ``n_procs``,
        ``collapse`` and ``keyPrefix``.

    Raises
    ------
    KeyError
        If ``args.h5genome`` lacks the ``n_up`` / ``n_down`` attributes.
    """
    # Close the read handle before pandas reopens the same file for writing below.
    with h5py.File(args.h5genome, 'r') as h5:
        n_up = h5.attrs['n_up']
        n_down = h5.attrs['n_down']

    print('Counting mutations by context genome-wide')
    df = sequence_tools.count_mutations_in_genome(args.fmut, args.mapDict, args.window,
                                                  n_up=n_up, n_down=n_down, N_procs=args.n_procs,
                                                  collapse=args.collapse)

    key = '{}_mutation_counts'.format(args.keyPrefix)
    print('Saving context counts to {} under key {}'.format(args.h5genome, key))
    df.to_hdf(args.h5genome, key=key, mode='a')

def _to_numeric_or_keep(column):
    """Return *column* converted to a numeric dtype, or unchanged if conversion fails.

    This is equivalent to ``pd.to_numeric(column, errors='ignore')``. It avoids
    the ``errors='ignore'`` option, which pandas deprecated in version 2.2.
    """
    try:
        return pd.to_numeric(column)
    except (ValueError, TypeError):
        return column

def applySequenceModel(args):
    """Run the NB sequence model for several bin sizes and save the results.

    If no GP run is given (``--run`` is ``None``, ``'None'`` or ``'none'``),
    one is chosen with ``gp_tools.pick_gp_by_calibration``. The context sizes
    are read from the ``n_up`` / ``n_down`` attributes of ``args.fmodel``.

    For each bin size in ``args.bins`` that does not exceed ``args.window``,
    ``nb_model.sequence_model_parallel`` is run. Its DataFrame is written to
    ``args.GPresults`` under ``[<cancer>/]test/nb_model_up..._down..._binsize..._run_...``.
    Finally, the chosen run is recorded as the ``gp_run_for_nb_model`` attribute
    of the ``test`` group.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed ``applySequenceModel`` arguments. ``args.run`` and
        ``args.binsize`` are overwritten in place.
    """
    if args.run is None or args.run in ('None', 'none'):
        print('Picking best GP run')
        args.run = gp_tools.pick_gp_by_calibration(args.GPresults, cancer=args.cancer, dataset=args.dataset)

    print('Selected run is: {}'.format(args.run))

    with h5py.File(args.fmodel, 'r') as h5:
        n_up = h5.attrs['n_up']
        n_down = h5.attrs['n_down']

    # The two lists are complements, so every requested bin size is either run
    # or reported; a bin equal to the window is valid and is run.
    bins = [b for b in args.bins if b <= args.window]
    bins_ignore = [b for b in args.bins if b > args.window]
    if bins_ignore:
        print('Ignoring bin sizes of {} b/c larger than initial window of {}.'.format(bins_ignore, args.window))

    for binsize in bins:
        args.binsize = binsize
        print('\nRunning NB sequence model with binsize: {}'.format(args.binsize))
        N = 1  ## because mutation probs are normalized, the value of N is inconsequential.
        df = nb_model.sequence_model_parallel(N, args.GPresults, args.fmodel,
                                              args.run, args.fasta, args.fmut,
                                              n_up=n_up, n_down=n_down,
                                              binsize=args.binsize, N_procs=args.n_procs,
                                              cancer=args.cancer, key_prefix=args.key_prefix,
                                              collapse=args.collapse)

        print('retyping')
        # Convert every column that parses as numeric; leave the others (e.g. CHROM) as-is.
        df = df.apply(_to_numeric_or_keep)

        # NOTE(review): results are always stored under "test/", even though
        # --dataset is passed when picking the GP run — confirm intended.
        savekey = "test/nb_model_up{}_down{}_binsize{}_run_{}".format(n_up, n_down, args.binsize, args.run)
        if args.cancer:
            savekey = args.cancer + "/" + savekey

        print('\tSaving results to {} under key {}'.format(args.GPresults, savekey))
        df.to_hdf(args.GPresults, key=savekey, mode='a', format='table')

    with h5py.File(args.GPresults, 'a') as h5:
        test_group = h5[args.cancer]['test'] if args.cancer else h5['test']
        test_group.attrs['gp_run_for_nb_model'] = args.run

def parse_args(text=None):
    """Build the command-line parser and parse the arguments.

    Parameters
    ----------
    text : str or sequence of str, optional
        Arguments to parse instead of ``sys.argv``. A string is split on
        whitespace; a list or tuple is used as-is. If ``text`` is ``None`` or
        empty, the real command line is parsed.

    Returns
    -------
    argparse.Namespace
        Parsed arguments. ``func`` is the handler for the chosen subcommand and
        ``command`` is its name.
    """
    parser = argparse.ArgumentParser()
    # A subcommand is mandatory: without one the Namespace would carry no
    # ``func``, and the caller would crash with an AttributeError.
    subparser = parser.add_subparsers(dest='command', required=True)

    parse_a = subparser.add_parser('countGenomeContext',
                                   help='count the number of occurrences of nucleotide contexts in a genome')
    parse_a.add_argument('mapDict', type=str, help='path to mappable regions file')
    parse_a.add_argument('window', type=int, help='window size of mappable regions')
    parse_a.add_argument('fasta', type=str, help='path to fasta file')
    parse_a.add_argument('fout', type=str, help='output h5 file name')
    parse_a.add_argument('--up', type=int, default=2,
                         help='number of bases upstream of position to include as context')
    parse_a.add_argument('--down', type=int, default=2,
                         help='number of bases downstream of position to include as context')
    parse_a.add_argument('--n-procs', type=int, default=15,
                         help='number of threads to parallelize over')
    parse_a.add_argument('--collapse', action='store_true', default=False,
                         help='collapse mutations to C or T using reverse complementation')
    parse_a.set_defaults(func=countGenomeContext)

    # The subcommands below take no --up/--down: the context sizes are read from
    # the n_up/n_down attributes of the h5 file written by countGenomeContext.
    parse_b = subparser.add_parser('annotateMutationFile',
                                   help='Add sequence context annotations to a file of mutations')
    parse_b.add_argument('fmut', type=str, help='path to mutation file')
    parse_b.add_argument('h5genome', type=str,
                         help='h5 file containing genome-wide context counts (from countGenomeContext)')
    parse_b.add_argument('fasta', type=str, help='path to fasta file')
    parse_b.add_argument('fout', type=str, help='output file name')
    parse_b.add_argument('--n-procs', type=int, default=20,
                         help='number of threads to parallelize over')
    parse_b.add_argument('--collapse', action='store_true', default=False,
                         help='collapse mutations to C or T using reverse complementation')
    parse_b.set_defaults(func=annotateMutationFile)

    parse_c = subparser.add_parser('countMutationContext',
                                   help='count the number of occurrences of mutations by context in a genome')
    parse_c.add_argument('mapDict', type=str, help='path to mappable regions file')
    parse_c.add_argument('fmut', type=str, help='path to mutation file')
    parse_c.add_argument('h5genome', type=str,
                         help='h5 file containing genome-wide context counts (from countGenomeContext)')
    parse_c.add_argument('window', type=int, help='window size of mappable regions')
    parse_c.add_argument('keyPrefix', type=str, help='Prefix for mutation count key in h5 file')
    parse_c.add_argument('--n-procs', type=int, default=20,
                         help='number of threads to parallelize over')
    parse_c.add_argument('--collapse', action='store_true', default=False,
                         help='collapse mutations to C or T using reverse complementation')
    parse_c.set_defaults(func=countMutationContext)

    parse_d = subparser.add_parser('applySequenceModel',
                                   help='apply the NB sequence model to NN+GP results over a range of bin sizes')
    parse_d.add_argument('GPresults', type=str, help='path to h5 file with NN+GP results')
    parse_d.add_argument('fmodel', type=str, help='path to h5 file of context counts (mutation & genome)')
    parse_d.add_argument('fmut', type=str, help='path to mutation file')
    parse_d.add_argument('fasta', type=str, help='path to fasta file')
    parse_d.add_argument('window', type=int, help='window size of CNN+GP analysis')
    parse_d.add_argument('--cancer', type=str, default='', help='Cancer dataset name in gp results file')
    parse_d.add_argument('--key-prefix', type=str, default='', help='mutation count prefix in model file')
    # NOTE(review): --key is accepted but not read by applySequenceModel.
    parse_d.add_argument('--key', type=str, default='',
                         help='Key to use for saving results to h5 file')
    parse_d.add_argument('--run', default='None',
                         help='which GP results to use. Picked automatically if not supplied.')
    parse_d.add_argument('--bins', nargs='+', type=int,
                         default=[50, 100, 1000, 2000, 5000, 10000, 25000, 50000, 100000, 1000000],
                         help='list of binsizes to use, separated by spaces')
    # Kept for command-line compatibility; overwritten for each entry of --bins.
    parse_d.add_argument('--binsize', type=int, default=50,
                         help='IGNORED (use --bins)')
    parse_d.add_argument('--n-procs', type=int, default=20,
                         help='number of threads to parallelize over')
    parse_d.add_argument('--dataset', default='test',
                         help='name of datasets to operate over (test or held-out)')
    parse_d.add_argument('--collapse', action='store_true', default=False,
                         help='collapse mutations to C or T using reverse complementation')
    parse_d.set_defaults(func=applySequenceModel)

    if not text:
        return parser.parse_args()
    if isinstance(text, str):
        return parser.parse_args(text.split())
    return parser.parse_args(list(text))

if __name__ == "__main__":
    args = parse_args()
    # Each subcommand registers its handler as ``func``. Without a subcommand
    # there is none, so exit with a clear message rather than an AttributeError.
    func = getattr(args, 'func', None)
    if func is None:
        raise SystemExit('error: a subcommand is required (see --help)')
    func(args)
