import os
import pandas as pd
import numpy as np
import collections
import logging
import math
import unicodedata
import itertools
import random
import argparse
import pysam
import csv
from os.path import join
import h5py
import sys

def writetsv(data, label, savefile):
    """Write paired sequences and labels to a tab-separated file.

    The file starts with the header line ``sequence<TAB>label`` followed by
    one ``<item><TAB><label>`` row per pair.

    Args:
        data: Iterable of sequence entries (formatted with ``str()``).
        label: Iterable of labels, paired positionally with ``data``.
            Pairing uses ``zip``, so extra items in the longer iterable
            are ignored.
        savefile: Path of the file to create or overwrite.
    """
    # The ``with`` block closes the file; no explicit close() is needed.
    with open(savefile, 'w') as f:
        f.write('sequence\tlabel\n')
        f.writelines(f'{seq}\t{lab}\n' for seq, lab in zip(data, label))

def writehdf5(data, savefile):
    """Write every entry of a mapping as a byte-string dataset in an HDF5 file.

    Args:
        data: Mapping of dataset name to values. Each value is converted
            with ``np.array(value, dtype='S')`` before being stored.
        savefile: Path of the HDF5 file to create or overwrite.
    """
    # The ``with`` block closes the file; no explicit close() is needed.
    with h5py.File(savefile, 'w') as f:
        for key, value in data.items():
            f.create_dataset(key, data=np.array(value, dtype='S'))

def random_split(tokens, maxlen, tolerance = 0.5):
    """Placeholder for a random splitting strategy (not implemented).

    Args:
        tokens: Token sequence to split.
        maxlen: Number of tokens per output sequence.
        tolerance: Maximum fraction of 'N' tokens allowed per sequence.

    Raises:
        NotImplementedError: Always. Failing here is clearer than returning
            ``None`` and letting the caller fail later when it tries to
            iterate over the result.
    """
    raise NotImplementedError(
        "random_split is not implemented; use the non-overlapping split instead"
    )

def nonoverlap_split(tokens, maxlen, tolerance = 0.5, *, throw_N=None):
    """Cut a token sequence into consecutive, non-overlapping windows.

    Only full windows of ``maxlen`` tokens are considered; trailing tokens
    that do not fill a window are dropped. Each kept window is returned as
    its tokens joined by single spaces. The number of skipped windows is
    printed.

    Args:
        tokens: Sequence of string tokens (converted with ``np.asarray``).
        maxlen: Number of tokens per window.
        tolerance: Used only when ``throw_N`` is false. A window is kept
            when its fraction of 'N' tokens is strictly below this value.
        throw_N: When true, any window containing an 'N' token is skipped.
            When ``None`` (the default), the value is read from the
            module-level ``args.throw_N`` set by the command-line entry
            point, which keeps existing three-argument calls working.

    Returns:
        List of space-joined window strings, in order of appearance.
    """
    if throw_N is None:
        # Backward-compatible fallback to the parsed command-line options.
        throw_N = args.throw_N

    # Coerce once so the element-wise comparison and slicing below work for
    # plain lists as well as for numpy arrays.
    tokens = np.asarray(tokens)
    is_n = (tokens == 'N')
    num_windows = len(tokens) // maxlen

    seqs = []
    skipped = 0
    for window in range(num_windows):
        start = window * maxlen
        end = start + maxlen
        n_count = np.count_nonzero(is_n[start:end])

        if throw_N:
            keep = n_count == 0
        else:
            keep = n_count / maxlen < tolerance

        if keep:
            seqs.append(" ".join(tokens[start:end]))
        else:
            skipped += 1

    print('In this chromosome, the number of skipped sequences is: ', skipped)
    return seqs

def main():
    """Split per-chromosome token files and write shuffled train/val TSVs.

    Reads the command-line options from the module-level ``args``. For every
    reference in the FASTA file whose name starts with 'chr', the file
    ``<token_path>/<chrm>_tokens.txt`` is read, split into sequences with the
    selected method, and labelled with the chromosome name. The pooled
    sequences are shuffled with a fixed seed and written as a 90/10 split to
    ``all_tokenized_train.tsv`` and ``all_tokenized_val.tsv`` in
    ``args.output_path``.

    Raises:
        ValueError: If no split method is selected, or if no sequences were
            produced (nothing to shuffle or write).
    """
    # Resolve the split method up front so a missing option fails clearly
    # instead of leaving the per-chromosome result unbound inside the loop.
    if args.random_split:
        split = random_split
    elif args.nonoverlap_split:
        split = nonoverlap_split
    else:
        raise ValueError("sequence split method required")

    # Only the reference names are needed, so release the FASTA handle
    # as soon as they have been collected.
    genome = pysam.FastaFile(args.fasta)
    try:
        chromosomes = [name for name in genome.references if name.startswith('chr')]
    finally:
        genome.close()

    data = []
    labels = []

    for chrm in chromosomes:
        print('split on ' + chrm)

        token_file = join(args.token_path, chrm + '_tokens.txt')
        with open(token_file, "r") as fh:
            content = " ".join(line.strip() for line in fh)
        tokens = np.array(content.split(' '))

        seqs = split(tokens, args.maxlen, args.tolerance)
        data.extend(seqs)
        labels.extend([chrm] * len(seqs))

    # split numbers:
    print('total number of sequences:', len(data))
    if not data:
        # Unzipping an empty list below would fail with an obscure
        # "not enough values to unpack" error; report the real cause.
        raise ValueError("no sequences were produced; nothing to write")
    train_num = int(len(data) * 0.9)

    # shuffle sequences and labels together, with a fixed seed for reproducibility
    combined = list(zip(data, labels))
    random.seed(42)
    random.shuffle(combined)
    shuffle_data, shuffle_labels = zip(*combined)

    writetsv(shuffle_data[:train_num], shuffle_labels[:train_num], join(args.output_path, 'all_tokenized_train.tsv'))
    writetsv(shuffle_data[train_num:], shuffle_labels[train_num:], join(args.output_path, 'all_tokenized_val.tsv'))


if __name__ == '__main__':
    # Script-only helpers, imported here because they are used nowhere else.
    import contextlib
    import traceback

    parser = argparse.ArgumentParser(description="Tokenizer",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--token_path', type=str, required=True)
    parser.add_argument('--fasta', type=str, required=True, help='genome to be tokenized')
    parser.add_argument('--maxlen', type=int, required=True, help='number of tokens')
    parser.add_argument('--random_split', action="store_true", help='sequence splitting method')
    parser.add_argument('--nonoverlap_split', action="store_true", help='sequence splitting method')
    parser.add_argument('--tolerance', type=float, required=True, help='maximum N percentage')
    parser.add_argument('--throw_N', action="store_true", help='throw segments with N or not?')
    parser.add_argument('--output_path', type=str, required=True)

    # `args` stays a module-level name: main() and nonoverlap_split() read it.
    args = parser.parse_args()

    if not args.random_split and not args.nonoverlap_split:
        raise ValueError("sequence split method required")

    # Send stdout and stderr to the log file while main() runs. The context
    # managers restore the original streams and close the file afterwards,
    # even when main() fails.
    with open(join(args.output_path, 'skip.log'), 'w') as log_file, \
            contextlib.redirect_stdout(log_file), \
            contextlib.redirect_stderr(log_file):
        try:
            main()
        except Exception:
            # Record the traceback in the log (stderr still points at it
            # here) before the error propagates to the real stderr.
            traceback.print_exc()
            raise
