import torch
import AMP_Dataset as Data
import Encoder_Decoder as ED
from params import vocab_in, vocab_out
import pyfastx
import time
import argparse
import os


def f_check(path):
    """Ensure that the directory ``path`` exists, creating parents as needed.

    Uses ``exist_ok=True`` instead of a separate existence check so two runs
    racing to create the same output directory cannot crash with
    ``FileExistsError``. If ``path`` exists but is a regular file,
    ``FileExistsError`` is raised because nothing could be written into it.
    """
    os.makedirs(path, exist_ok=True)


def write_fa(path, dic, style):
    """Write ``dic`` to ``path`` as FASTA records.

    Each mapping entry becomes a ``>name`` header line followed by a single
    sequence line. ``style`` is passed straight to ``open`` as the file mode
    (e.g. ``'w'`` to overwrite, ``'a'`` to append).
    """
    with open(path, style, encoding='utf-8') as handle:
        handle.writelines(f'>{name}\n{sequence}\n' for name, sequence in dic.items())


def write_csv(path, l_id, l_src, l_pre, style):
    """Write one ``id,source,prediction`` row per entry of ``l_id``.

    Args:
        path: Output CSV file.
        l_id: Record ids, in output order.
        l_src: Source sequences, aligned index-by-index with ``l_id``.
        l_pre: Mapping from record id to predicted sequence.
        style: File mode passed to ``open`` (``'w'`` or ``'a'``).

    Fields that contain commas, quotes or newlines are quoted by the ``csv``
    module. Plain fields produce exactly ``a,b,c\\n`` rows.
    """
    import csv

    with open(path, style, encoding='utf-8') as file:
        # '\n' terminator keeps the same line endings as a plain text write
        # (csv's default would be '\r\n').
        writer = csv.writer(file, lineterminator='\n')
        for i, seq_id in enumerate(l_id):
            writer.writerow([seq_id, l_src[i], l_pre[seq_id]])


def getkey(vocab, value_need):
    """Reverse-lookup a mapping.

    Returns a list of every key in ``vocab`` whose value equals
    ``value_need``, in the mapping's iteration order. The list is empty
    when nothing matches.
    """
    matches = []
    for token, index in vocab.items():
        if index == value_need:
            matches.append(token)
    return matches


def _parse_args():
    """Parse the command-line options for sequence prediction."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', required=True, type=str, help='Input file (xx.fasta)')
    parser.add_argument('-o', '--output', required=True, type=str, help='The output dir (end of /)')
    parser.add_argument('-l', '--max_len', default=200, type=int, help='The max length of the generated sequence!')
    parser.add_argument('-m', '--path_model', required=True, type=str)
    parser.add_argument('-n', "--file_name", default='model_pred_seq', type=str)
    parser.add_argument('-a', '--layers', default=2, type=int, help='number of layers')
    return parser.parse_args()


def _greedy_decode(model, src, max_len, device, sos_id=1, eos_id=2):
    """Greedily decode up to ``max_len`` tokens for one encoded source.

    Returns a ``(1, L)`` LongTensor that starts with ``sos_id``. It ends with
    ``eos_id`` only if the model emitted it within ``max_len`` steps.
    ``sos_id=1`` / ``eos_id=2`` are the values the script always hard-coded;
    presumably they are ``<sos>`` / ``<eos>`` in ``vocab_out`` (TODO confirm).
    """
    tgt = torch.tensor([[sos_id]], dtype=torch.long, device=device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for _ in range(max_len):
            mask = ED.subsequent_mask(tgt.size(1)).to(device)
            output = model(src, tgt, None, mask)
            next_id = torch.argmax(output[:, -1])
            tgt = torch.cat([tgt, next_id.view(1, 1)], dim=1)
            if next_id.item() == eos_id:
                break
    return tgt


def _ids_to_text(token_ids, id_to_word):
    """Concatenate the words for ``token_ids`` and drop the markers.

    Removes one leading ``<sos>`` and one trailing ``<eos>`` as whole strings.
    ``str.lstrip`` / ``str.rstrip`` are deliberately not used: they treat
    their argument as a character set. They would also strip residues written
    as ``s``, ``o`` or ``e``, or the ``<`` of a ``<pad>`` token.
    """
    text = ''.join(id_to_word[int(t)] for t in token_ids)
    if text.startswith('<sos>'):
        text = text[len('<sos>'):]
    if text.endswith('<eos>'):
        text = text[:-len('<eos>')]
    return text


if __name__ == '__main__':
    # settings
    start_time = time.time()
    args = _parse_args()

    # path: os.path.join works whether or not --output ends with '/'
    path_fa = args.input
    path_med = args.output
    f_check(path_med)
    path_pred = os.path.join(path_med, args.file_name + ".fasta")
    path_csv = os.path.join(path_med, args.file_name + ".csv")

    # setting:
    fa = pyfastx.Fasta(path_fa)
    src_vocab = vocab_in
    tgt_vocab = vocab_out
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    max_len = args.max_len

    # Inverse vocabulary built once, keeping the first word seen for each id.
    id_to_word = {}
    for word, token_id in tgt_vocab.items():
        id_to_word.setdefault(token_id, word)

    l_idx = []
    l_src_seq = []
    # NOTE: keyed by FASTA id, so duplicate ids keep only the last prediction.
    l_pred_seq = {}

    # loading model
    model = ED.make_model(src_vocab=len(src_vocab), tgt_vocab=len(tgt_vocab), N=args.layers, d_model=512, d_ff=2048,
                          h=8)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(args.path_model, map_location=device))
    model.to(device)
    model.eval()

    n_pad = 0
    for num, item in enumerate(fa, start=1):
        idx = item.name.strip()
        seq = item.seq.strip()
        l_idx.append(idx)
        l_src_seq.append(seq)
        test_src = Data.numericalize([Data.TokenSplit(seq)], src_vocab).to(device)

        # testing
        tgt = _greedy_decode(model, test_src, max_len, device)
        tgt_result = _ids_to_text(tgt[0].tolist(), id_to_word)

        # Remove padding BEFORE storing, so the written files are clean.
        if '<pad>' in tgt_result:
            tgt_result = tgt_result.replace('<pad>', '')
            n_pad += 1
        l_pred_seq[idx] = tgt_result

        print('*' * 30)
        print("The current num is: ", num)
        print("Test : => ", seq)
        print("Result: => ", tgt_result)
    print("Number of <pad> elements is: ", n_pad)
    write_fa(path_pred, l_pred_seq, 'w')
    write_csv(path_csv, l_idx, l_src_seq, l_pred_seq, 'w')
    end_time = time.time()
    print("The total time cost is: ", end_time - start_time)
    print("Done!!!")
