import pandas as pd
import numpy as np
import pysam
import os
from glob import glob
from preprocessing_utils import *
from argparse import ArgumentParser
from argparse import ArgumentDefaultsHelpFormatter
from multiprocessing import Pool


def get_args_parser():
    """Build the option parser for the training-data preparation script.

    The parser is created with ``add_help=False`` so that it can be passed
    through ``parents=[...]`` to another ``ArgumentParser`` without
    registering a second ``-h/--help`` option.

    Returns:
        ArgumentParser: parser exposing the input-path options (all
        required) and the integer tuning options (all with defaults).
    """
    # Every path-like option is mandatory; argparse keeps them as strings.
    mandatory = {"default": None, "required": True}

    # (flag, add_argument keyword arguments), listed in the order the
    # options are registered on the parser.
    option_table = (
        ("--fast5_dir", mandatory),
        ("--bamfile", mandatory),
        ("--tx_info", mandatory),
        ("--fasta_path", mandatory),
        ("--nanopolish_eventalign_index", mandatory),
        ("--nanopolish_summary_fpath", mandatory),
        ("--min_read_counts", {"default": 5, "type": int}),
        ("--max_read_counts", {"default": 100, "type": int}),
        ("--n_processes", {"default": 25, "type": int}),
        ("--signal_length", {"default": 5200, "type": int}),
        ("--nanopolish_eventalign_fpath", mandatory),
        ("--save_dir", mandatory),
    )

    cli = ArgumentParser(
        'prepare_basecalling_training',
        formatter_class=ArgumentDefaultsHelpFormatter,
        add_help=False,
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


def create_tasks(read_info, transcripts, fpath_dict, index_df, eventalign_filepath, fasta_dict, signal_length):
    """Yield one task tuple per read of every selected transcript.

    Args:
        read_info: Mapping of transcript id -> iterable of read ids.
        transcripts: Container of transcript ids *without* version suffix;
            only transcripts whose id (up to the first ``"."``) is in this
            container are emitted.
        fpath_dict: Mapping of read id -> path passed on in the task.
        index_df: DataFrame indexed by ``(transcript id, read id)`` with
            ``pos_start`` and ``pos_end`` columns.
        eventalign_filepath: Passed through unchanged in every task.
        fasta_dict: Mapping of transcript id -> sequence.
        signal_length: Passed through unchanged in every task.

    Yields:
        Tuples ``(tx, read_id, fpath, start, end, eventalign_filepath, seq,
        signal_length)`` in the layout that ``get_signal`` unpacks.
    """
    for tx, reads in read_info.items():
        # Compare on the id without the version suffix ("X.2" -> "X").
        if tx.split(".")[0] not in transcripts:
            continue

        for read_id in reads:
            tx_info = index_df.loc[(tx, read_id)]
            start = tx_info["pos_start"]
            end = tx_info["pos_end"]

            # A 2-D result means the (tx, read) key matched several index
            # rows; keep the row covering the longest [start, end) span.
            if tx_info.ndim > 1:
                spans = end.to_numpy().astype('int') - start.to_numpy().astype('int')
                longest = int(np.argmax(spans))
                # Positional access must go through .iloc: these Series carry
                # the (tx, read) labels as index, so ``series[int]`` would
                # depend on pandas' deprecated integer fallback.
                start, end = start.iloc[longest], end.iloc[longest]

            yield tx, read_id, fpath_dict[read_id], start, end, eventalign_filepath, fasta_dict[tx], signal_length


def extract_signal_info(read_info, transcripts, fpath_dict, index_df, eventalign_filepath, fasta_dict, signal_length,
                        n_processes=None):
    """Run ``get_signal`` over every selected read in a process pool and merge the results.

    Args:
        read_info: Mapping of transcript id -> iterable of read ids.
        transcripts: Container of version-less transcript ids to process.
        fpath_dict, index_df, eventalign_filepath, fasta_dict, signal_length:
            Forwarded unchanged to ``create_tasks``.
        n_processes: Number of worker processes. When ``None`` (the default)
            the value is read from the module-level ``args`` namespace that
            the script entry point creates.

    Returns:
        Tuple ``(all_signals, all_positions, all_segments, all_sequences,
        all_tx)``. ``all_signals`` is the concatenation of the per-read
        signal arrays; the other four are flat lists with one entry per
        chunk. If no chunk was produced, ``all_signals`` is an empty array
        of shape ``(0, signal_length)`` and the lists are empty.
    """
    if n_processes is None:
        # Backward-compatible fallback for callers that do not pass the
        # pool size: use the namespace parsed under ``__main__``.
        n_processes = args.n_processes

    # Same selection rule as create_tasks, so the progress bar total matches
    # the number of tasks that will actually be generated.
    total_tasks = sum(
        1
        for tx, reads in read_info.items()
        if tx.split(".")[0] in transcripts
        for _ in reads
    )

    all_signals = []
    all_positions = []
    all_sequences = []
    all_segments = []
    all_tx = []

    tasks = create_tasks(read_info, transcripts, fpath_dict, index_df,
                         eventalign_filepath, fasta_dict, signal_length)

    # NOTE(review): tqdm is not imported in this file; it is presumably
    # provided by ``from preprocessing_utils import *`` -- confirm.
    with Pool(n_processes) as p:
        for signals, positions, segments, sequences, tx in tqdm(p.imap(get_signal, tasks), total=total_tasks):
            # Reads that produced no chunk contribute nothing.
            if len(signals) > 0:
                all_signals.append(signals)
                all_positions.extend(positions)
                all_segments.extend(segments)
                all_sequences.extend(sequences)
                all_tx.extend(tx)

    if all_signals:
        all_signals = np.concatenate(all_signals)
    else:
        # np.concatenate rejects an empty list; return a well-formed empty
        # array instead. NOTE(review): dtype falls back to numpy's default
        # because no signal is available to copy it from.
        all_signals = np.empty((0, signal_length))
    return all_signals, all_positions, all_segments, all_sequences, all_tx


def get_signal(task):
    """Cut one read's signal into fixed-length chunks with per-chunk alignment data.

    Args:
        task: Tuple ``(tx, read_id, fpath, start, end, eventalign_filepath,
            seq, signal_length)`` as yielded by ``create_tasks``.

    Returns:
        Tuple ``(signals, positions, segments, sequences, tx_labels)``:
        ``signals`` is a 2-D array with one row of ``signal_length`` samples
        per chunk (the last chunk is zero-padded on the right);
        ``positions``, ``segments`` and ``sequences`` are lists holding, per
        chunk, the values returned by ``extract_segment_and_positions``;
        ``tx_labels`` repeats ``tx`` once per chunk. When the signal slice is
        empty, ``signals`` has shape ``(0, signal_length)`` and the other
        members are empty.
    """
    tx, read_id, fpath, start, end, eventalign_filepath, seq, signal_length = task

    eventalign_result = read_event_align(eventalign_filepath, start, end)
    eventalign_result, strand = combine_eventalign_result(eventalign_result)

    # The signal window is bounded by the first and last alignment rows;
    # which row supplies which bound depends on the strand.
    if strand == '+':
        signal_start = eventalign_result["start_idx"].values[-1]
        signal_end = eventalign_result["end_idx"].values[0]
    else:
        signal_start = eventalign_result["start_idx"].values[0]
        signal_end = eventalign_result["end_idx"].values[-1]

    signal = get_read(read_id, get_fast5_file(fpath))[signal_start:signal_end]

    segments = eventalign_result["segment_length"].values
    positions = eventalign_result["position"].values

    if strand == "+":
        positions = positions[::-1]
        segments = segments[::-1]

    # Cumulative segment lengths give each event's end offset inside
    # ``signal``; searchsorted maps chunk sample boundaries to event indices.
    interval_indices = np.cumsum(segments)
    chunk_starts = np.arange(0, len(signal), signal_length)
    chunk_ends = chunk_starts + signal_length
    first_events, last_events = np.searchsorted(interval_indices, [chunk_starts, chunk_ends])

    all_signals = []
    all_segments = []
    all_seqs = []
    all_positions = []
    for first_event, last_event, lo, hi in zip(first_events, last_events, chunk_starts, chunk_ends):
        # searchsorted may return len(positions) for the final boundary.
        last_event = min(last_event, len(positions) - 1)
        sub_positions, sub_segments, sub_seqs = extract_segment_and_positions(
            first_event, last_event, segments, positions, strand, seq)

        # Slicing already clamps ``hi`` to the end of the signal.
        sub_signal = signal[lo:hi]
        if len(sub_signal) < signal_length:
            sub_signal = np.pad(sub_signal, (0, signal_length - len(sub_signal)), 'constant')

        all_signals.append(sub_signal)
        all_segments.append(sub_segments)
        all_seqs.append(sub_seqs)
        all_positions.append(sub_positions)

    if not all_signals:
        # An empty signal slice yields no chunks and np.stack([]) would
        # raise; return an empty result so the caller's length check can
        # skip this read.
        return np.empty((0, signal_length)), [], [], [], np.repeat(tx, 0)

    return np.stack(all_signals), all_positions, all_segments, all_seqs, np.repeat(tx, len(all_signals))


def _transcript_ids_for(tx_info, set_type):
    """Return the set of version-less ``tx_id`` values whose ``set_type`` column equals ``set_type``."""
    tx_ids = tx_info[tx_info["set_type"] == set_type]["tx_id"]
    return set(tx_ids.apply(lambda x: x.split(".")[0]).unique())


def _pad_rows(rows, width):
    """Right-pad every row with zeros to ``width`` and stack the rows into one 2-D array."""
    return np.concatenate([np.pad(row, (0, width - len(row)), 'constant').reshape(1, -1)
                           for row in rows])


def preprocess(args):
    """Extract signal chunks for the train/validation/test splits and save them as ``.npy`` files.

    Reads from ``args``: ``bamfile``, ``fasta_path``, ``fast5_dir``,
    ``min_read_counts``, ``max_read_counts``, ``nanopolish_eventalign_index``,
    ``nanopolish_summary_fpath``, ``nanopolish_eventalign_fpath``,
    ``tx_info``, ``signal_length`` and ``save_dir``.

    For every split the following files are written: ``chunks.npy``,
    ``reference_lengths.npy``, ``references.npy``, ``positions.npy``,
    ``position_lengths.npy``, ``segments.npy`` and ``transcripts.npy``.
    The train split is stored directly in ``save_dir``; the other splits go
    to the ``validation`` and ``test`` sub-directories.

    Args:
        args: Parsed command-line namespace (see ``get_args_parser``).
    """
    fasta_dict = readFasta(args.fasta_path)

    # The BAM handle is only needed by extract_read_info; the ``with`` block
    # guarantees it is closed afterwards, also when an error is raised.
    # NOTE(review): assumes extract_read_info returns plain containers that
    # stay valid after the file is closed -- confirm.
    with pysam.AlignmentFile(args.bamfile, "rb") as samfile:
        fast5_files = glob(os.path.join(args.fast5_dir, "**/*.fast5"), recursive=True)

        index_df = pd.read_csv(args.nanopolish_eventalign_index).set_index(["transcript_id"])
        read_name_by_index = (
            pd.read_csv(args.nanopolish_summary_fpath, delimiter="\t")
            .set_index("read_index")["read_name"]
            .to_dict()
        )
        # Translate the numeric read index of the eventalign index file into
        # the read name listed in the summary file.
        index_df["read_name"] = index_df.read_index.apply(lambda x: read_name_by_index[x])

        valid_read_ids = set(index_df["read_name"].values)

        read_info, allowed_reads = extract_read_info(
            samfile, args.min_read_counts, args.max_read_counts, fasta_dict, valid_read_ids)

    print("Preprocessing {} eligible transcripts".format(len(read_info.keys())))
    fpath_dict = retrieve_fpath(fast5_files, allowed_reads)

    eventalign_filepath = args.nanopolish_eventalign_fpath
    # create_tasks looks rows up by (transcript id, read name); sorting the
    # index keeps those lookups efficient.
    index_df = index_df.reset_index().set_index(["transcript_id", "read_name"])
    index_df = index_df.sort_index()

    tx_info = pd.read_csv(args.tx_info)

    save_dir = args.save_dir
    # exist_ok avoids the check-then-create race and allows nested paths.
    os.makedirs(save_dir, exist_ok=True)

    # (output split name, label used in the ``set_type`` column of tx_info)
    splits = (('train', 'Train'), ('validation', 'Val'), ('test', 'Test'))
    split_transcripts = [(name, _transcript_ids_for(tx_info, label)) for name, label in splits]

    for set_type, transcripts in split_transcripts:
        all_signals, all_positions, all_segments, all_sequences, all_tx = extract_signal_info(
            read_info, transcripts, fpath_dict, index_df, eventalign_filepath, fasta_dict, args.signal_length)

        if len(all_signals) == 0:
            # Nothing to pad or save for this split; np.max below would fail
            # on an empty array.
            print("No signal chunks extracted for the {} set; skipping".format(set_type))
            continue

        all_positions_length = np.array([len(pos) for pos in all_positions])
        all_seq_lengths = np.array([len(seq) for seq in all_sequences])

        max_pos_length = np.max(all_positions_length)
        max_seq_length = np.max(all_seq_lengths)

        # Ragged per-chunk rows are zero-padded to a common width; the true
        # lengths are saved separately so the padding can be stripped later.
        all_positions = _pad_rows(all_positions, max_pos_length)
        all_segments = _pad_rows(all_segments, max_pos_length)
        all_sequences = _pad_rows(all_sequences, max_seq_length)
        all_tx = np.array(all_tx)

        if set_type == 'train':
            save_path = save_dir
        else:
            save_path = os.path.join(save_dir, set_type)
            os.makedirs(save_path, exist_ok=True)

        np.save(os.path.join(save_path, "chunks.npy"), all_signals)
        np.save(os.path.join(save_path, "reference_lengths.npy"), all_seq_lengths)
        np.save(os.path.join(save_path, "references.npy"), all_sequences)
        np.save(os.path.join(save_path, "positions.npy"), all_positions)
        np.save(os.path.join(save_path, "position_lengths.npy"), all_positions_length)
        np.save(os.path.join(save_path, "segments.npy"), all_segments)
        np.save(os.path.join(save_path, "transcripts.npy"), all_tx)

        # Drop the references before the next split is extracted so the
        # arrays of two splits are not alive at the same time.
        del all_signals
        del all_seq_lengths
        del all_sequences
        del all_positions
        del all_positions_length
        del all_segments
        del all_tx


if __name__ == '__main__':
    # Script entry point: parse the command line and run the preprocessing.
    # ``args`` has to stay a module-level name because extract_signal_info
    # reads ``args.n_processes`` from the module globals.
    cli = ArgumentParser('prepare basecalling', parents=[get_args_parser()])
    args = cli.parse_args()
    preprocess(args)
