#!/usr/bin/env python
# encoding: utf-8
'''
@license: (C) Copyright 2021, Hey.
@author: Hey
@email: sanyuan.**@**.com
@tel: 137****6540
@datetime: 2022/11/28 19:31
@project: LucaOneApp
@file: utils.py
@desc: utils
'''
import os
from typing import List, Union
import requests


# Accepted nucleotide symbols: the DNA bases A/T/C/G, U (RNA) and N (unknown base).
common_nucleotide_set = set("ATCGUN")

# Common amino acids: the 20 standard one-letter codes plus 'X' (conventionally an
# unspecified residue). The rare/ambiguous codes 'O', 'U', 'Z', 'J' and 'B' are excluded.
common_amino_acid_set = set("ARNDCQEGHILKMFPSTWYVX")


def gene_seq_replace_re(seqs):
    '''
    Decode token-id nucleotide sequences back to letters (an approximate inverse of
    gene_seq_replace): 1->A, 2->T, 3->C, 4->G, and any other character -> N.
    Note that '2' always decodes to 'T', so RNA 'U' is not recovered.
    :param seqs: a single encoded sequence (str) or an iterable of encoded sequences.
    :return: list of decoded sequences; always a list, even for a single str input.
    '''
    if isinstance(seqs, str):
        seqs = [seqs]
    decode = {'1': 'A', '2': 'T', '3': 'C', '4': 'G'}
    # Anything outside the four known tokens becomes the unknown base 'N'
    # (including each character of a "[MASK]" placeholder).
    # join() avoids the quadratic cost of repeated string concatenation.
    return ["".join(decode.get(ch, "N") for ch in seq) for seq in seqs]


# Nucleotide character -> gene token id (case-insensitive).
_GENE_CHAR_TO_TOKEN = {
    "A": "1", "a": "1",
    "T": "2", "t": "2", "U": "2", "u": "2",
    "C": "3", "c": "3",
    "G": "4", "g": "4",
    "N": "5", "n": "5",
}
# Placeholder passed through verbatim as a single token.
_GENE_MASK_TOKEN = "[MASK]"


def gene_seq_replace(seq: str):
    '''
    Encode a nucleotide sequence as gene token ids:
    A->1, U/T->2, C->3, G->4, N->5 (case-insensitive).
    Any literal "[MASK]" in the input is copied unchanged into the output.
    :param seq: nucleotide sequence, optionally containing "[MASK]" placeholders.
    :return: the encoded string.
    :raises ValueError: for any other character, including a '[' that does not
        start a complete "[MASK]".
    '''
    tokens = []
    idx = 0
    seq_len = len(seq)
    while idx < seq_len:
        # Match the whole placeholder rather than assuming any '[' begins one.
        if seq.startswith(_GENE_MASK_TOKEN, idx):
            tokens.append(_GENE_MASK_TOKEN)
            idx += len(_GENE_MASK_TOKEN)
            continue
        ch = seq[idx]
        token = _GENE_CHAR_TO_TOKEN.get(ch)
        if token is None:  # unknown
            raise ValueError(f'Invalid sequence: {ch} in {seq}.')
        tokens.append(token)
        idx += 1
    return "".join(tokens)

def gene_seqs_replace(seqs: Union[str, List[str]]):
    '''
    Apply gene_seq_replace to one sequence or to each sequence in a list.
    :param seqs: a single sequence or a list of sequences.
    :return: an encoded str for str input, otherwise a list of encoded strs.
    '''
    if not isinstance(seqs, str):
        return list(map(gene_seq_replace, seqs))
    return gene_seq_replace(seqs)


def clean_seq(protein_id, seq, return_rm_index=False):
    '''
    Upper-case a sequence and drop every character that is not an ASCII letter A-Z,
    and also drop 'J'. If anything is dropped, the id, the upper-cased sequence,
    the dropped characters and their positions are printed.
    :param protein_id: identifier, used only in the diagnostic print.
    :param seq: raw sequence string.
    :param return_rm_index: if True, also return the set of removed positions.
    :return: the cleaned sequence, or (cleaned_sequence, removed_index_set) when
        return_rm_index is True. Positions index into seq.upper(), which for a few
        non-ASCII characters (e.g. 'ß' -> 'SS') differs in length from seq.
    '''
    seq = seq.upper()
    kept_chars = []
    invalid_char_set = set()
    return_rm_index_set = set()
    for idx, ch in enumerate(seq):
        # 'J' is an ASCII letter but is still rejected here.
        if 'A' <= ch <= 'Z' and ch != 'J':
            kept_chars.append(ch)
        else:
            invalid_char_set.add(ch)
            return_rm_index_set.add(idx)
    # Build once with join instead of quadratic string concatenation.
    new_seq = "".join(kept_chars)
    if invalid_char_set:
        print("id: %s. Seq: %s" % (protein_id, seq))
        print("invalid char set:", invalid_char_set)
        print("return_rm_index:", return_rm_index_set)
    if return_rm_index:
        return new_seq, return_rm_index_set
    return new_seq


def download_file(url, local_filename):
    '''
    Stream url to local_filename in 8192-byte chunks, creating the parent directory if needed.
    The body is written to "<local_filename>.part" and renamed into place only after the
    whole response has been written. An interrupted download therefore never leaves a
    truncated file at local_filename, which a caller checking only for existence would
    otherwise mistake for a finished download.
    :param url: URL to fetch.
    :param local_filename: destination path.
    :return: local_filename.
    :raises requests.HTTPError: on an error HTTP status (via raise_for_status).
    '''
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        dir_name = os.path.dirname(local_filename)
        # dirname() is '' for a bare filename, and os.makedirs('') would raise.
        if dir_name and not os.path.exists(dir_name):
            os.makedirs(dir_name)
        tmp_filename = local_filename + ".part"
        try:
            with open(tmp_filename, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
            # The final name only appears once the file is complete (atomic rename on POSIX).
            os.replace(tmp_filename, local_filename)
        except BaseException:
            # Remove the partial file (also on Ctrl-C) and propagate the original error.
            try:
                os.remove(tmp_filename)
            except OSError:
                pass
            raise
    return local_filename


def download_folder(base_url, file_names, local_dir):
    '''
    Download each name in file_names from "<base_url>/<name>" into local_dir, creating
    local_dir first if it does not exist, and print one progress line per file.
    :param base_url: URL prefix; names are appended after a single '/', so a trailing
        slash on base_url would produce '//'.
    :param file_names: relative names, used both as URL suffix and local relative path.
    :param local_dir: destination directory.
    '''
    if not os.path.exists(local_dir):
        os.makedirs(local_dir)

    for name in file_names:
        target_path = os.path.join(local_dir, name)
        download_file(f"{base_url}/{name}", target_path)
        print(f"Downloaded {name}")


def _all_files_exist(directory, file_names):
    '''Return True if every name in file_names exists under directory.'''
    return all(os.path.exists(os.path.join(directory, name)) for name in file_names)


def download_trained_checkpoint_lucaone(
        llm_dir,
        llm_type="lucaone_gplm",
        llm_version="v2.0",
        llm_task_level="token_level,span_level,seq_level,structure_level",
        llm_time_str="20231125113045",
        llm_step="5600000",
        base_url="http://47.93.21.181/lucaone/TrainedCheckPoint"
):
    '''
    Ensure a LucaOne trained checkpoint (logs + model files) is present under llm_dir,
    downloading it from base_url if any expected file is missing.

    The same relative layout is used locally and under base_url:
      logs/lucagplm/<version>/<task_level>/<type>/<time_str>/logs.txt
      models/lucagplm/<version>/<task_level>/<type>/<time_str>/checkpoint-step<step>/
          config.json, pytorch.pth, training_args.bin, tokenizer/alphabet.pkl
    Presence is judged by file existence only. If any file is missing, both the logs
    and the model files are downloaded again.

    :param llm_dir: local root directory for the checkpoint.
    :param llm_type: model type path component.
    :param llm_version: model version path component.
    :param llm_task_level: task-level path component.
    :param llm_time_str: training-run timestamp path component.
    :param llm_step: checkpoint step number.
    :param base_url: URL of the TrainedCheckPoint root that contains logs/ and models/.
    :raises Exception: the original download error, re-raised unchanged after
        manual-download instructions are printed.
    '''
    # NOTE(review): the default base_url is plain HTTP and nothing verifies the
    # downloaded files. training_args.bin and tokenizer/alphabet.pkl are presumably
    # unpickled later, so a tampered download could execute code. Consider HTTPS
    # plus published checksums.
    url_root = base_url.rstrip("/")
    try:
        logs_file_names = ["logs.txt"]
        models_file_names = ["config.json", "pytorch.pth", "training_args.bin", "tokenizer/alphabet.pkl"]
        logs_path = "logs/lucagplm/%s/%s/%s/%s" % (llm_version, llm_task_level, llm_type, llm_time_str)
        models_path = "models/lucagplm/%s/%s/%s/%s/checkpoint-step%s" % (llm_version, llm_task_level, llm_type, llm_time_str, llm_step)
        logs_local_dir = os.path.join(llm_dir, logs_path)
        models_local_dir = os.path.join(llm_dir, models_path)
        exists = (_all_files_exist(logs_local_dir, logs_file_names)
                  and _all_files_exist(models_local_dir, models_file_names))
        if not exists:
            print("*" * 20 + "Downloading" + "*" * 20)
            print("Downloading LucaOne TrainedCheckPoint: LucaOne-%s-%s-%s ..." % (llm_version, llm_time_str, llm_step))
            print("Wait a moment, please.")
            # download logs
            if not os.path.exists(logs_local_dir):
                os.makedirs(logs_local_dir)
            # URLs always use '/', unlike os.path.join, which uses '\\' on Windows.
            download_folder(url_root + "/" + logs_path, logs_file_names, logs_local_dir)
            # download models
            if not os.path.exists(models_local_dir):
                os.makedirs(models_local_dir)
            download_folder(url_root + "/" + models_path, models_file_names, models_local_dir)
            print("LucaOne Downloaded.")
            print("*" * 50)
    except Exception as e:
        print(e)
        print("Download automatically LucaOne Trained CheckPoint failed!")
        # base_url already is the TrainedCheckPoint root, so do not append it again.
        print("You can manually download 'logs/' and 'models/' into local directory: %s/ from %s/" % (os.path.abspath(llm_dir), url_root))
        # Bare raise preserves the original exception type and traceback.
        raise
