import pandas as pd
import numpy as np
import torch
import os
import joblib
import ujson
import json
from glob import glob
from preprocessing_utils import *
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from multiprocessing import Pool
from itertools import product
from tqdm import tqdm


# IUPAC DRACH motif, one list of allowed bases per position:
# D = A/G/T, R = G/A, A, C, H = A/C/T.
drach_motifs = [['A', 'G', 'T'], ['G', 'A'], ['A'], ['C'], ['A', 'C', 'T']]
# Every concrete 5-mer matching DRACH (3 * 2 * 1 * 1 * 3 = 18 sequences).
drach_kmers = {"".join(bases) for bases in product(*drach_motifs)}


def get_drach_positions(tx, kmers=None):
    """Return the centre index of every motif occurrence in ``tx``.

    Args:
        tx: transcript sequence string.
        kmers: optional set of equal-length motif strings to search for;
            defaults to the module-level ``drach_kmers`` (DRACH 5-mers).

    Returns:
        np.ndarray (integer dtype, possibly empty) holding ``i + k // 2`` for
        each start index ``i`` whose k-mer is in ``kmers``, in increasing
        order. For DRACH (k = 5) this is the index of the central ``A``.
    """
    if kmers is None:
        kmers = drach_kmers
    if not kmers:
        return np.array([], dtype=int)

    k = len(next(iter(kmers)))
    centre = k // 2
    # Start indices run 0 .. len(tx) - k inclusive, so that the last k-mer
    # (ending on the final base) is examined too.
    hits = [i + centre for i in range(len(tx) - k + 1) if tx[i:i + k] in kmers]
    return np.array(hits, dtype=int)


def get_drach_index(fasta_dict, transcripts):
    """Map each unversioned transcript ID to the DRACH positions in its sequence.

    Version suffixes are stripped (``"ID.N"`` -> ``"ID"``) both for the result
    key and for the ``fasta_dict`` lookup; duplicate IDs collapse to one entry.
    """
    index = {}
    for tx in transcripts:
        tx_id = tx.split(".")[0]
        index[tx_id] = get_drach_positions(fasta_dict[tx_id])
    return index

def check_for_drach(start, end, drach_pos):
    """Return whether ``drach_pos`` lies in the closed interval ``[start, end]``."""
    return start <= drach_pos <= end

def index_transcripts(transcripts, drach_pos_per_tx, rstarts, rends):
    """Map every DRACH site to the row indices whose span covers it.

    Args:
        transcripts: per-row transcript IDs; each must be a key of
            ``drach_pos_per_tx``.
        drach_pos_per_tx: transcript ID -> array-like of DRACH positions.
        rstarts, rends: per-row inclusive start/end positions, aligned
            element-wise with ``transcripts``.

    Returns:
        ``{tx: {site: [row, ...]}}`` for every transcript in
        ``drach_pos_per_tx``. A site covered by no row maps to an empty list.
    """
    pos_dict = {tx: {site: [] for site in sites}
                for tx, sites in drach_pos_per_tx.items()}

    for row, (tx, start, end) in enumerate(zip(transcripts, rstarts, rends)):
        sites = np.asarray(drach_pos_per_tx[tx])
        # Vectorised form of check_for_drach over all sites of this transcript.
        covered = sites[(start <= sites) & (sites <= end)]
        for site in covered:
            pos_dict[tx][site].append(row)

    return pos_dict


def t2g(tx_id, fasta_dict, gtf_dict):
    """Build a transcript-position -> genomic-position map for one transcript.

    Args:
        tx_id: transcript ID used to index both ``fasta_dict`` and ``gtf_dict``.
        fasta_dict: transcript ID -> sequence (``None`` when unavailable).
        gtf_dict: transcript ID -> dict with ``'exon'`` (genomic intervals),
            ``'tx_exon'`` (the matching transcript intervals, same order) and
            ``'strand'``.

    Returns:
        dict mapping each transcript position inside an exon to its genomic
        position. Both ends of every exon interval are included. On the ``-``
        strand, transcript positions count down from the interval's end. An
        empty dict is returned when the sequence is ``None`` or there are no
        exons.

    Raises:
        KeyError: if ``tx_id`` is missing from ``fasta_dict`` or ``gtf_dict``.
        ValueError: if the transcript has exons but its strand is neither
            ``'+'`` nor ``'-'``.
    """
    t2g_dict = {}
    if fasta_dict[tx_id] is None:
        # Return the same type as the normal path so callers can always index it.
        return t2g_dict

    tx_info = gtf_dict[tx_id]
    exons = tx_info['exon']
    if len(exons) == 0:
        return t2g_dict

    strand = tx_info['strand']
    if strand not in ("+", "-"):
        raise ValueError("Unsupported strand %r for transcript %s" % (strand, tx_id))

    tx_exons = tx_info['tx_exon']
    for exon_num, g_interval in enumerate(exons):
        tx_interval = tx_exons[exon_num]
        g_start, g_end = g_interval[0], g_interval[1]
        for offset in range(g_end - g_start + 1):
            if strand == "+":
                tx_pos = tx_interval[0] + offset
            else:
                tx_pos = tx_interval[1] - offset
            t2g_dict[tx_pos] = g_start + offset
    return t2g_dict

def get_modified_positions(tx, positions, t2g_dict, gtf_dict, modified_coords):
    """Keep the transcript positions whose genomic coordinate is annotated as modified.

    Args:
        tx: transcript ID; its chromosome is read from ``gtf_dict[tx]['chr']``.
        positions: transcript positions to test; each must be a key of ``t2g_dict``.
        t2g_dict: transcript position -> genomic position (e.g. the output of ``t2g``).
        gtf_dict: per-transcript annotation dict.
        modified_coords: container of ``(genomic_position, chromosome)`` tuples.

    Returns:
        list of the entries of ``positions`` that are modified, in input order.
    """
    chromosome = gtf_dict[tx]['chr']
    kept = []
    for tx_pos in positions:
        if (t2g_dict[tx_pos], chromosome) in modified_coords:
            kept.append(tx_pos)
    return kept


# Integer code per nucleotide; codes start at 1 (0 is not assigned here).
nucleotide_dict = {'A': 1, 'C': 2, 'G': 3, 'T': 4}

def assign_code(fasta_seq):
    """Encode a nucleotide string as a NumPy array of ``nucleotide_dict`` codes.

    Raises KeyError on any character that is not an upper-case A/C/G/T.
    """
    return np.array(list(map(nucleotide_dict.__getitem__, fasta_seq)))

def combine_data(task):
    """Save one transcript's rows as a single 2-D array and index its DRACH sites.

    ``task`` is the tuple ``(signal, seq, seq_length, positions, pos_lengths,
    tx, save_dir, drach_positions)``; row ``i`` of every array belongs together.
    The arrays are concatenated column-wise (``seq_length`` and ``pos_lengths``
    as one column each) and written to ``<save_dir>/<tx>.npy``.

    Returns:
        ``{tx: {site: [row, ...]}}`` listing, for every DRACH site found among a
        row's valid positions (its first ``pos_lengths[row]`` entries), the rows
        containing it. Sites present in no row are omitted.
    """
    (signal, seq, seq_length, positions, pos_lengths,
     tx, save_dir, drach_positions) = task

    combined = np.concatenate(
        [signal, seq, seq_length.reshape(-1, 1), positions, pos_lengths.reshape(-1, 1)],
        axis=1,
    )

    site_rows = {}
    for row, (row_positions, n_valid) in enumerate(zip(positions, pos_lengths)):
        valid = row_positions[:n_valid]
        for site in np.intersect1d(valid, drach_positions):
            site_rows.setdefault(site, []).append(row)

    np.save(os.path.join(save_dir, tx + ".npy"), combined)
    return {tx: site_rows}

def index_tx(all_tx):
    """Group row indices by unversioned transcript ID.

    Returns ``{ID: [row, ...]}`` where ID is the part of each entry before the
    first ``"."``; keys appear in first-seen order and rows in increasing order.
    """
    grouped = {}
    for row, tx_name in enumerate(all_tx):
        grouped.setdefault(tx_name.split(".")[0], []).append(row)
    return grouped


def get_args_parser():
    """Return the argument parser for this preprocessing step.

    Built with ``add_help=False`` so it can be embedded as a parent parser.
    """
    parser = ArgumentParser(
        'prepare_m6araw_training',
        formatter_class=ArgumentDefaultsHelpFormatter,
        add_help=False,
    )
    required_flags = ("--root_dir", "--fasta_path", "--gtf_path",
                      "--annot_path", "--save_dir")
    for flag in required_flags:
        parser.add_argument(flag, default=None, required=True)
    parser.add_argument("--sequence_context", default=10, type=int)
    return parser

def _build_metadata(signals, seqs, positions):
    """Return the ``[start, end)`` column range of each field in a saved ``<tx>.npy``.

    Field order and widths mirror the column-wise concatenation in
    ``combine_data``: signal, seq, seq_length (1 column), positions,
    positions_length (1 column).
    """
    field_widths = [
        ('signal', signals.shape[-1]),
        ('seq', seqs.shape[-1]),
        ('seq_length', 1),
        ('positions', positions.shape[-1]),
        ('positions_length', 1),
    ]
    metadata = {}
    offset = 0
    for field, width in field_widths:
        metadata[field] = [offset, offset + width]
        offset += width
    return metadata


def _write_tx_index(tx_positions, save_dir):
    """Write ``tx_index.json`` (one JSON object per line) and its offset table.

    Each line is ``{"<tx>": {<site>: [rows...]}}``. ``data.index`` records, per
    transcript, the ``f.tell()`` values before and after its line so a reader
    can seek straight to it.
    """
    index_rows = []
    # 'w' rather than 'a'. In append mode a re-run would keep adding
    # duplicate records to the file.
    with open(os.path.join(save_dir, "tx_index.json"), 'w') as f:
        for tx, pos_dict in tx_positions.items():
            pos_start = f.tell()
            f.write('{"%s":' % tx)
            ujson.dump(pos_dict, f)
            f.write("}\n")
            index_rows.append([tx, pos_start, f.tell()])
    index_df = pd.DataFrame(index_rows, columns=['tx_id', 'start_pos', 'end_pos'])
    index_df.to_csv(os.path.join(save_dir, "data.index"), index=False)


def _build_info_df(tx_positions, fasta_dict, gtf_dict, t2g_dict, annot_cords,
                   sequence_context):
    """Return one row per DRACH site that has a full +/- ``sequence_context`` window.

    ``modification_status`` is 1 when ``(genomic_position, chromosome)`` is in
    ``annot_cords`` and 0 otherwise.
    """
    rows = []
    for tx, pos_dict in tx_positions.items():
        gene_id = gtf_dict[tx]['g_id']
        chr_ = gtf_dict[tx]['chr']
        tx_seq = fasta_dict[tx]
        for pos, indices in pos_dict.items():
            # Skip sites too close to either transcript end for a full window.
            if pos < sequence_context or len(tx_seq) - pos < sequence_context + 1:
                continue
            genomic_position = t2g_dict[tx][int(pos)]
            # annot_cords holds (End, Chr) tuples (see preprocess), so query
            # (position, chromosome) in that order.
            label = int((genomic_position, chr_) in annot_cords)
            seq = tx_seq[pos - sequence_context: pos + sequence_context + 1]
            rows.append((gene_id, genomic_position, tx, pos, label, len(indices), seq))
    return pd.DataFrame(rows, columns=["gene_id", "genomic_position", "tx_id", "tx_position",
                                       "modification_status", "n_reads", "sequence"])


def preprocess(args):
    """Split the chunked training arrays per transcript and build index/label files.

    Reads ``chunks.npy``, ``references.npy``, ``reference_lengths.npy``,
    ``transcripts.npy``, ``positions.npy`` and ``position_lengths.npy`` from
    ``args.root_dir``. Writes the following into ``args.save_dir``, which is
    created if needed:

    * ``metadata.joblib``: the column layout of the per-transcript arrays;
    * ``<tx>.npy``: one array per transcript (see ``combine_data``);
    * ``tx_index.json`` / ``data.index``: DRACH site -> row lists, plus file offsets;
    * ``data.info``: one labelled row per DRACH site with a full sequence window.
    """
    root_dir = args.root_dir
    save_dir = args.save_dir

    fasta_dict = readFasta(args.fasta_path)
    gtf_dict = readGTF(args.gtf_path)

    signals = np.load(os.path.join(root_dir, "chunks.npy"))
    seqs = np.load(os.path.join(root_dir, "references.npy"))
    lengths = np.load(os.path.join(root_dir, "reference_lengths.npy"))
    transcripts = np.load(os.path.join(root_dir, "transcripts.npy"))
    positions = np.load(os.path.join(root_dir, "positions.npy"))
    positions_lengths = np.load(os.path.join(root_dir, "position_lengths.npy"))

    # Drop version suffixes ("ID.N" -> "ID"); every lookup below uses the bare ID.
    fasta_dict = {tx.split(".")[0]: seq for tx, seq in fasta_dict.items()}
    unique_tx = np.unique(transcripts)
    t2g_dict = {}
    for tx in unique_tx:
        tx_id = tx.split(".")[0]
        t2g_dict[tx_id] = t2g(tx_id, fasta_dict, gtf_dict)
    # Iterating this MultiIndex yields (End, Chr) tuples.
    annot_cords = set(pd.read_csv(args.annot_path).set_index(["End", "Chr"]).index)

    # makedirs also creates missing parent directories and tolerates an existing one.
    os.makedirs(save_dir, exist_ok=True)
    joblib.dump(_build_metadata(signals, seqs, positions),
                os.path.join(save_dir, "metadata.joblib"))

    def save_split_data(tx_index, drach_index, save_dir):
        # One combine_data task per transcript. Each task writes <tx>.npy and
        # returns that transcript's {tx: {site: [rows]}} mapping.
        tasks = ((signals[indices], seqs[indices], lengths[indices],
                  positions[indices], positions_lengths[indices],
                  tx, save_dir, drach_index[tx])
                  for tx, indices in tx_index.items())
        tx_dict = {}
        with Pool(25) as p:
            for pos_dict in tqdm(p.imap_unordered(combine_data, tasks), total=len(tx_index)):
                tx_dict.update(pos_dict)
        return tx_dict

    tx_index = index_tx(transcripts)
    drach_index = get_drach_index(fasta_dict, unique_tx)
    tx_positions = save_split_data(tx_index, drach_index, save_dir)

    _write_tx_index(tx_positions, save_dir)

    info_df = _build_info_df(tx_positions, fasta_dict, gtf_dict, t2g_dict,
                             annot_cords, args.sequence_context)
    info_df.to_csv(os.path.join(save_dir, "data.info"), index=False)

if __name__ == '__main__':
    # Wrap the shared argument definitions in a parser that also provides -h/--help.
    cli = ArgumentParser('prepare_m6araw_training', parents=[get_args_parser()])
    preprocess(cli.parse_args())
