import os
import pandas as pd
import numpy as np
import collections
import logging
import math
import unicodedata
import itertools
import random
import argparse
import pysam
import csv
from os.path import join
import h5py
import Bio
import collections
import multiprocessing
from multiprocessing import Pool
import time
import subprocess
import pickle


class TrieNode:
    '''One node of the prefix tree used to store motif strings.'''

    def __init__(self):
        # Outgoing edges: next character -> child TrieNode.
        self.children = dict()
        # Becomes True once an inserted word terminates at this node.
        self.is_end_of_word = False
        # Feature payloads attached to the word(s) ending at this node.
        self.features = list()


class Trie:
    '''
    Prefix tree over strings.

    Words are stored character by character below ``root``; a word may carry
    feature payloads that ``search`` hands back when the word is found.
    '''

    def __init__(self):
        self.root = TrieNode()
        self.lookup_table = {}

    def insert(self, word, features = None):
        '''
        Add ``word`` to the trie.

        Parameters:
            word: iterable of hashable symbols (normally a string)
            features: optional payload appended to the terminal node; falsy
                values (None, 0, '', []) are not stored
        '''
        node = self.root
        for symbol in word:
            child = node.children.get(symbol)
            if child is None:
                child = TrieNode()
                node.children[symbol] = child
            node = child
        node.is_end_of_word = True
        if features:
            node.features.append(features)

    def print_trie(self, node=None, prefix="", level=0):
        '''
        Print the subtree below ``node`` (the root when omitted), one symbol
        per line, indented two spaces per level; word ends are marked "(end)".
        '''
        start = self.root if node is None else node
        indent = "  " * level
        for symbol, child in start.children.items():
            marker = " (end)" if child.is_end_of_word else ""
            print(indent + "'{}'{}".format(symbol, marker))
            self.print_trie(child, prefix + symbol, level + 1)

    def search(self, word):
        '''
        Look up ``word``.

        Returns:
            False if ``word`` was never inserted as a complete word,
            the list of stored features if the word has any,
            True otherwise.
        '''
        node = self.root
        for symbol in word:
            node = node.children.get(symbol)
            if node is None:
                return False         # Path breaks off: word not found
        if not node.is_end_of_word:
            return False             # Only a prefix of some stored word
        return node.features if len(node.features) > 0 else True

def save_str_to_Trie(strings, features = None):
    '''
    Build a Trie from ``strings``.

    Parameters:
        strings: iterable of words to insert
        features: optional iterable paired with ``strings`` via zip; each
            element is stored as the feature of the matching word. When it is
            falsy (None or empty) the words are inserted without features.

    Returns:
        the populated Trie
    '''
    built = Trie()

    if not features:
        for entry in strings:
            built.insert(entry)
        return built

    for entry, payload in zip(strings, features):
        built.insert(entry, payload)
    return built

def save_trie_to_file(trie, filename):
    '''Pickle ``trie`` into ``filename``, overwriting any existing file.'''
    with open(filename, mode='wb') as sink:
        pickle.dump(trie, sink)

def load_trie_from_file(filename):
    '''Return the object unpickled from ``filename``.'''
    # NOTE: unpickling can execute arbitrary code; only load files this
    # pipeline wrote itself.
    with open(filename, mode='rb') as source:
        restored = pickle.load(source)
    return restored


def writetxt(tokens, coordinates, names, savefile):
    '''
    Write one "token coordinate name" line per token to ``savefile``.

    Parameters:
        tokens, coordinates, names: parallel sequences; they are combined
            with zip, so output stops at the shortest one
        savefile: path of the text file to (over)write
    '''
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(savefile, 'w') as f:
        # Columns are separated by a single space.
        f.writelines(
            f"{token} {coord} {name}\n"
            for token, coord, name in zip(tokens, coordinates, names)
        )

def read_pwm(file):
    '''
    Read a motif matrix file.

    The first line is the header: its first character is dropped and the
    rest, stripped, becomes the motif name. Every following non-blank line is
    parsed as one row of whitespace-separated floats.

    Parameters:
        file: path of the matrix file

    Returns:
        (motif_name, matrix) where matrix is a numpy array with one row per
        non-blank data line

    Raises:
        IndexError: if the file is empty
        ValueError: if a data field is not a number
    '''
    with open(file, 'r') as f:
        lines = f.readlines()

    motif_name = lines[0].strip()[1:]
    # Blank lines (typically a trailing newline) would otherwise become empty
    # rows and make the array ragged.
    rows = [list(map(float, line.split())) for line in lines[1:] if line.strip()]
    matrix = np.array(rows)

    return motif_name, matrix


def motif2tokens(motif_path, threshold):
    '''
    Turn every ``*pwm`` file in ``motif_path`` into a consensus string and its
    reverse complement.

    Parameters:
        motif_path: directory holding the matrix files
        threshold: rows at either end of a matrix with no entry above this
            value are treated as wildcards and trimmed away

    Returns:
        (motifs, motifs_reversecomp, names, names_reversecomp): four parallel
        lists; names carry the suffix '.+.' (forward) or '.-.' (reverse
        complement). Files are visited in sorted order so the result is
        reproducible.
    '''
    # The 4 columns in pwm follows the order: A C G T
    nt = ['A', 'C', 'G', 'T']
    nt_comp = ['T', 'G', 'C', 'A']
    motifs = []
    motifs_reversecomp = []
    names = []
    names_reversecomp = []

    # os.listdir order is arbitrary; sort it for a stable output order.
    for file in sorted(os.listdir(motif_path)):
        if not file.endswith('pwm'):
            continue

        motif_name, matrix = read_pwm(join(motif_path, file))

        # Step 1: only consider motifs shorter than 13bp
        if matrix.shape[0] > 12:
            continue

        # Step 2: remove the wildcards at both ends
        informative_rows = np.where(matrix > threshold)[0]
        if informative_rows.size == 0:
            # No position passes the threshold, so there is nothing to emit
            # (previously this raised IndexError).
            print('Skipping motif without any entry above threshold:', motif_name)
            continue
        matrix_trimmed = matrix[informative_rows[0]:informative_rows[-1] + 1]

        # Step 3: consensus = most probable base per position
        best_columns = np.argmax(matrix_trimmed, axis=1)
        motifs.append(''.join(nt[index] for index in best_columns))
        # Reverse complement: complement each base, then reverse the order.
        motifs_reversecomp.append(''.join(nt_comp[index] for index in best_columns[::-1]))
        names.append(motif_name + '.+.')
        names_reversecomp.append(motif_name + '.-.')

    return motifs, motifs_reversecomp, names, names_reversecomp


def create_vocabulary(motif_path, threshold, savepath):
    '''
    Build the vocabulary (1-mers, special tokens incl. N, 3-mers, motifs and
    their reverse complements) and write it to ``savepath``.

    Files written into ``savepath``:
        motifs.txt              "motif name" per line, duplicates kept
        motifs_dedup.txt        one line per distinct motif, names comma-joined
        vocab.txt               full vocabulary, duplicates kept
        vocab_dedup.txt         vocabulary without duplicates (first-seen order)
        motifs_hardcode.txt     distinct motifs (first-seen order)
        motifs_hardcode_trie.pkl  pickled Trie of the distinct motifs

    Parameters:
        motif_path: directory with the motif matrix files
        threshold: probability threshold forwarded to motif2tokens
        savepath: existing output directory

    Raises:
        subprocess.CalledProcessError: if the awk step fails
        FileNotFoundError: if ``awk`` is not installed
    '''
    # 1-mer
    bases = ['A', 'T', 'C', 'G']

    # common special tokens
    token_spec = ['N', '[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]']

    # 3-mer
    token_3mer = [''.join(term) for term in itertools.product(bases, repeat=3)]

    # convert motifs to fixed sequence, and its reverse complement
    token_motifs, token_motifs_reversecomp, names, names_reversecomp = motif2tokens(motif_path, threshold)

    motifs = token_motifs + token_motifs_reversecomp
    print('Size of motifs:', len(motifs))
    motifs_names = names + names_reversecomp

    # Save motifs as motifs.txt (two space-separated columns)
    motifs_file = os.path.join(savepath, 'motifs.txt')
    with open(motifs_file, "w") as f:
        for motif, name in zip(motifs, motifs_names):
            f.write(f"{motif} {name}\n")

    # Collapse duplicated motifs, joining their names with commas. awk is run
    # without a shell and fed through file handles, so paths with spaces or
    # shell metacharacters are safe.
    motifs_dedup = os.path.join(savepath, "motifs_dedup.txt")
    awk_program = (
        '{segments[$1] = segments[$1] ? segments[$1] "," $2 : $2} '
        'END {for (seg in segments) print seg, segments[seg]}'
    )
    with open(motifs_file, "r") as src, open(motifs_dedup, "w") as dst:
        subprocess.run(['awk', awk_program], stdin=src, stdout=dst, check=True)
    print('Size of motifs after removing duplicates:', len(set(motifs)))

    vocab = bases + token_spec + token_3mer + token_motifs + token_motifs_reversecomp
    print('Size of entire vocabulary:', len(vocab))
    # Save all as vocab.txt
    with open(os.path.join(savepath, 'vocab.txt'), "w") as f:
        for word in vocab:
            f.write(f"{word}\n")

    # remove duplicates; dict.fromkeys keeps first-seen order, so the file is
    # identical from run to run (set() order varies with string hashing)
    vocab_dedup = list(dict.fromkeys(vocab))
    print('Size of entire vocabulary after removing duplicates: ', len(vocab_dedup))

    with open(os.path.join(savepath, 'vocab_dedup.txt'), "w") as f:
        for word in vocab_dedup:
            f.write(f"{word}\n")

    # distinct motifs that the tokenizer matches literally
    motifs_hardcode = list(dict.fromkeys(motifs))

    with open(os.path.join(savepath, 'motifs_hardcode.txt'), "w") as f:
        for motif in motifs_hardcode:
            f.write(f"{motif}\n")
    print('Size of hardcoded motifs:', len(motifs_hardcode))

    motifs_hardcode_trie = save_str_to_Trie(motifs_hardcode)

    save_trie_to_file(motifs_hardcode_trie, os.path.join(savepath, 'motifs_hardcode_trie.pkl'))


def tokenize(seg, i, maxlen, motif_hardcoded_trie, k3, k1, lookup_table):
    '''
    Pick one token starting at position ``i`` of ``seg``.

    Motifs of length 4..maxlen found in the trie take priority; when several
    match, one is chosen at random. Otherwise the 3-mer at ``i`` is used if it
    is in ``k3``, else the single character if it is in ``k1``. A character
    that is in neither (e.g. a lowercase base) is emitted as-is as a 1-long
    token instead of crashing.

    Parameters:
        seg: a sequence chunk from the chromosome
        i: the start position at this segment
        maxlen: the longest distance considered to find motif, should be the
            longest word in vocabulary
        motif_hardcoded_trie: object with a ``search(str)`` method that is
            truthy for known motifs
        k3: container of 3-mer tokens
        k1: container of 1-mer tokens
        lookup_table: mapping motif -> name

    Returns:
        (token, name, score, next_pos): name is '-' for non-motif tokens;
        score is the motif length, 3 for a 3-mer, 1 for a single character;
        next_pos is ``i + len(token)``.

    Raises:
        IndexError: if ``i`` is at or beyond the end of ``seg``
    '''
    if i >= len(seg):
        raise IndexError('tokenize: start position is beyond the end of seg')

    # Never look past the end of seg: a truncated slice would otherwise be
    # tested (and collected) once per remaining length, skewing the choice.
    longest = min(maxlen, len(seg) - i)
    motif_candidates = []
    for l in range(4, longest + 1):
        segment = seg[i:i + l]
        if motif_hardcoded_trie.search(segment):
            motif_candidates.append(segment)

    token = None
    score = -float('inf')

    if motif_candidates:
        token = random.choice(motif_candidates)
        score = len(token)
    else:
        # if cannot find motifs, tokenize with 3mer then 1mer
        for l in range(3, 0, -1):
            segment = seg[i:i + l]

            if segment in k3:
                token = segment
                score = 3
                break

            if segment in k1:
                token = segment
                score = 1

    if token is None:
        # Character outside the vocabulary: keep it as a single-base token so
        # that coordinates stay contiguous.
        token = seg[i:i + 1]
        score = 1

    name = lookup_table.get(token, '-')   # '-' represent the given name for non-motif tokens
    next_pos = i + len(token)

    return token, name, score, next_pos


def tokenize_chunk(seg, vocabs, maxlen, chrmname, start_position, savepath=None):
    '''
    Tokenize one chunk of a chromosome.

    At every position the token found there is compared with the tokens found
    one and two bases further on; the highest-scoring one wins (ties keep the
    earlier position). Bases skipped by a shift are emitted as single-base
    tokens named '-'.

    Parameters:
        seg: the sequence chunk
        vocabs: vocabulary list (not used here; kept for the caller's
            argument tuple)
        maxlen: longest motif length to look for
        chrmname: chromosome name used in the coordinates
        start_position: offset of ``seg`` within the chromosome
        savepath: directory holding 'motifs_hardcode_trie.pkl' and
            'motifs_dedup.txt'; defaults to the module-level ``args.savepath``

    Returns:
        (tokens, coordinates, names): three parallel lists; coordinates have
        the form 'chrmname:start-end'.
    '''
    if savepath is None:
        savepath = args.savepath

    motif_hardcoded_trie = load_trie_from_file(join(savepath, 'motifs_hardcode_trie.pkl'))
    # Sets give O(1) membership tests; tokenize() only uses ``in`` on them.
    k1 = frozenset(['A', 'T', 'C', 'G', 'N'])
    # 3-mer
    k3 = frozenset(''.join(term) for term in itertools.product(['A', 'T', 'C', 'G'], repeat=3))

    lookup_table = {}
    with open(join(savepath, "motifs_dedup.txt"), "r") as file:
        for line in file:
            segment, name = line.strip().split(maxsplit=1)  # Split only on the first space
            lookup_table[segment] = name  # Store in dictionary

    tokens = []
    names = []
    coordinates = []

    seg_len = len(seg)
    i = 0       # start position
    while i < seg_len:

        best_token, best_name, best_score, next_pos = tokenize(seg, i, maxlen, motif_hardcoded_trie, k3, k1, lookup_table)
        best_i = i

        # A single-base token cannot be improved by shifting; for anything
        # longer, check whether starting 1 or 2 bases later scores higher.
        if len(best_token) > 1:
            for shift in (1, 2):
                i_shifted = i + shift
                if i_shifted >= seg_len:
                    break
                token_, name_, score_, next_pos_ = tokenize(seg, i_shifted, maxlen, motif_hardcoded_trie, k3, k1, lookup_table)
                # Strictly greater: on a tie the earlier start is kept.
                if score_ > best_score:
                    best_token, best_name, best_i, next_pos, best_score = token_, name_, i_shifted, next_pos_, score_

        # Emit the bases skipped by the shift one by one.
        for skip in range(best_i - i):
            tokens.append(seg[i + skip])
            names.append('-')
            coordinates.append(chrmname + ':' + str(start_position + i + skip) + '-' + str(start_position + i + skip + 1))

        coordinate = chrmname + ':' + str(start_position + best_i) + '-' + str(min(start_position + next_pos, start_position + seg_len))
        tokens.append(best_token)
        names.append(best_name)
        coordinates.append(coordinate)

        i = next_pos

    return tokens, coordinates, names

def tokenize_chrm(chrmname, vocabs):
    '''
    Tokenize one chromosome of ``args.fasta`` in parallel.

    The sequence is cut into non-overlapping chunks of 2**22 bases, each chunk
    is tokenized by tokenize_chunk in a worker process, and the per-chunk
    results are concatenated in chromosome order.

    Parameters:
        chrmname: reference name to fetch from the FASTA file
        vocabs: vocabulary sorted longest-first; the length of its first
            entry is used as the maximum motif length

    Returns:
        (tokens, coordinates, names): three parallel lists (empty for a
        zero-length reference)
    '''
    # non-overlapping
    genome = pysam.FastaFile(args.fasta)
    try:
        chrm = genome.fetch(reference=chrmname)
    finally:
        # The sequence is fully in memory now; release the file handle.
        genome.close()

    maxlen = len(vocabs[0])
    chunk_size = 2**22
    start_positions = list(range(0, len(chrm), chunk_size))
    chunks = [chrm[start:start + chunk_size] for start in start_positions]

    # Use at most args.num_processor workers, never more than the machine
    # has, and always at least one (Pool rejects values below 1).
    num_cpus = max(1, min(args.num_processor, multiprocessing.cpu_count()))
    print('number of cpus:', num_cpus)

    with Pool(processes=num_cpus) as pool:
        print('length of chromosome', len(chrm))
        print('number of chunks:', len(chunks))
        jobs = [(chunk, vocabs, maxlen, chrmname, start) for chunk, start in zip(chunks, start_positions)]
        results = pool.starmap(tokenize_chunk, jobs)     # one (tokens, coordinates, names) tuple per chunk

    # Flatten the per-chunk lists
    tokens = []
    coordinates = []
    names = []
    for chunk_tokens, chunk_coordinates, chunk_names in results:
        tokens.extend(chunk_tokens)
        coordinates.extend(chunk_coordinates)
        names.extend(chunk_names)

    print('sequence original length: ', len(chrm))
    if coordinates:
        print('the last token coordinate:', coordinates[-1])
    return tokens, coordinates, names

    
def tokenize_all(fasta, savepath):
    '''
    Tokenize every selected chromosome of ``fasta``.

    Only references whose name starts with 'chr' and is at most 5 characters
    long are processed. Each one is written to
    '<savepath>/<name>_tokenized.txt'.

    Parameters:
        fasta: path of the genome FASTA file whose reference names are listed
        savepath: directory containing 'vocab_dedup.txt'; also receives the
            output files
    '''
    # Only the reference names are needed here, so close the file right away.
    genome = pysam.FastaFile(fasta)
    try:
        references = list(genome.references)
    finally:
        genome.close()

    v = pd.read_csv(join(savepath, 'vocab_dedup.txt'), header = None, names = ['column'])
    # Longest token first: tokenize_chrm reads the maximum length from entry 0.
    v_sorted = v.sort_values(by='column', key=lambda col: col.str.len(), ascending=False)
    vocabs = list(v_sorted['column'])

    # 32GB memory is enough
    for chrm in references:
        if chrm.startswith('chr') and len(chrm) <= 5:
            print('tokenizing ' + chrm + ' ... ')

            tokens, coordinates, names = tokenize_chrm(chrm, vocabs)
            writetxt(tokens, coordinates, names, join(savepath, chrm + '_tokenized.txt'))


def main():
    '''Run both pipeline stages using the parsed ``args`` and time each one.'''

    # Stage 1: create our vocabulary, save as 'vocab.txt'
    vocab_started = time.time()
    create_vocabulary(args.motif_path, args.threshold, args.savepath)
    vocab_elapsed = time.time() - vocab_started
    print(f"Time for creating vocab.txt: {vocab_elapsed:.4f} seconds")

    # Stage 2: tokenize the genome
    genome_started = time.time()
    tokenize_all(args.fasta, args.savepath)
    genome_elapsed = time.time() - genome_started
    print(f"Time for tokenizing whole genome: {genome_elapsed:.4f} seconds")
    
if __name__ == '__main__':

    # ``args`` is deliberately module-level: main(), tokenize_chrm() and
    # tokenize_chunk() read it as a global.
    parser = argparse.ArgumentParser(description="Tokenizer",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--motif_path', type=str, required=True, help='directory containing the motif matrix (*pwm) files')
    parser.add_argument('--threshold', type=float, required=True, help='probability threshold for position probability matrix')
    parser.add_argument('--fasta', type=str, required=True, help='genome to be tokenized')
    # Required: without it os.makedirs(None) below would fail with a TypeError.
    parser.add_argument('--savepath', type=str, required=True, help='output directory for the vocabulary, motif and tokenized files')
    parser.add_argument('--num_processor', type=int, required=False, default = 32, help='maximum number of worker processes')
    args = parser.parse_args()
    os.makedirs(args.savepath, exist_ok = True)

    main()
