import os
import json
import torch
import yaml
from transformers import AutoTokenizer, PretrainedConfig, BertTokenizer
from collections import OrderedDict
from utils.get_embed.lucaone_utils.luca_utils import gene_seqs_replace, gene_seq_replace_re, download_trained_checkpoint_lucaone
from utils.get_embed.lucaone_utils.lucaone_gplm import LucaGPLM, LucaGPLMConfig
from utils.get_embed.lucaone_utils.alphabet import Alphabet


class Args:
    """Mutable attribute bag.

    load_model fills an instance with values read from the YAML args file and
    passes it to the model class as ``args=``.
    """


def _load_lucaone_weights(model_class, model_config, model_dirpath, args):
    '''
    Build ``model_class`` from ``model_dirpath``, trying three layouts in order:

    1. ``model_class.from_pretrained(model_dirpath, args=args)``;
    2. the object stored in ``pytorch.pt`` (returned as-is, so it is expected
       to be a whole model);
    3. a fresh ``model_class(model_config, args=args)`` filled from the state
       dict stored in ``pytorch.pth``.

    A failure in step 1 or 2 only triggers the next step. An error in step 3
    propagates, chained to the step-2 error.
    '''
    try:
        model = model_class.from_pretrained(model_dirpath, args=args)
    except Exception:
        # Deliberate fallback: the directory may hold a raw torch checkpoint
        # instead of a from_pretrained-style layout.
        model = None
    if model is not None:
        return model

    cpu = torch.device("cpu")
    try:
        return torch.load(os.path.join(model_dirpath, "pytorch.pt"), map_location=cpu)
    except Exception:
        model = model_class(model_config, args=args)
        pretrained_net_dict = torch.load(os.path.join(model_dirpath, "pytorch.pth"),
                                         map_location=cpu)
        expected_keys = set(model.state_dict())
        prefix = "module."
        new_state_dict = OrderedDict()
        for key, value in pretrained_net_dict.items():
            # Strip the leading `module.` (added by torch DataParallel/DDP wrappers).
            name = key[len(prefix):] if key.startswith(prefix) else key
            # Keep only entries the model defines. load_state_dict is strict by
            # default, so it still raises if the model has keys the checkpoint lacks.
            if name in expected_keys:
                new_state_dict[name] = value
        model.load_state_dict(new_state_dict)
        return model


def load_model(log_filepath, model_dirpath):
    '''
    Create the tokenizer, model config and model for a LucaOne checkpoint.

    :param log_filepath: YAML file with the training arguments. Keys read here:
        tokenization, model_type, pretrain_tasks, ignore_index, label_size,
        loss_type, output_mode, max_length, classifier_size, and (for the
        AutoTokenizer/BertTokenizer branches) do_lower_case and truncation.
    :param model_dirpath: checkpoint directory. It must exist and contain a
        ``tokenizer`` sub-directory and ``config.json``. For
        ``model_type == "lucaone_gplm"`` the path must contain ``/v2.0/``.
    :return: (args_info, model_config, model, tokenizer), with the model in
        eval mode. Torch checkpoints are loaded with map_location=cpu.
    :raises FileNotFoundError: if model_dirpath is None or missing, or if it
        has no ``tokenizer`` sub-directory.
    :raises Exception: for an unsupported model_type or checkpoint version.
    '''
    with open(log_filepath, "r", encoding="utf-8") as yaml_file:
        # NOTE(review): FullLoader is not intended for untrusted input. If this
        # file only holds plain data, yaml.safe_load would be the stricter choice.
        args_info = yaml.load(yaml_file, Loader=yaml.FullLoader)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if model_dirpath is None or not os.path.exists(model_dirpath):
        raise FileNotFoundError("model dir does not exist: %s" % model_dirpath)
    tokenizer_dir = os.path.join(model_dirpath, "tokenizer")
    if not os.path.exists(tokenizer_dir):
        raise FileNotFoundError("tokenizer dir does not exist: %s" % tokenizer_dir)

    # create tokenizer
    if args_info["tokenization"]:
        print("AutoTokenizer, tokenizer dir: %s" % tokenizer_dir)
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_dir,
            do_lower_case=args_info["do_lower_case"],
            truncation_side=args_info["truncation"]
        )
    elif args_info["model_type"] in ["lucaone_gplm"]:
        # v2.0 checkpoints use the predefined "gene_prot" Alphabet instead of
        # the files in tokenizer_dir.
        if "/v2.0/" in model_dirpath:
            tokenizer = Alphabet.from_predefined("gene_prot")
        else:
            raise Exception("Not support version=%s" % model_dirpath)
    else:
        tokenizer = BertTokenizer.from_pretrained(
            tokenizer_dir,
            do_lower_case=args_info["do_lower_case"],
            truncation_side=args_info["truncation"])

    # resolve config/model classes (only LucaGPLM v2.0 is supported)
    if args_info["model_type"] in ["lucaone_gplm"]:
        if "/v2.0/" in model_dirpath:
            config_class, model_class = LucaGPLMConfig, LucaGPLM
        else:
            raise Exception("Not support version=%s" % model_dirpath)
    else:
        raise Exception("Not support model_type=%s" % args_info["model_type"])

    model_config: PretrainedConfig = config_class.from_json_file(os.path.join(model_dirpath, "config.json"))

    # arguments the model class reads at construction time
    args = Args()
    for key in ("pretrain_tasks", "ignore_index", "label_size", "loss_type",
                "output_mode", "max_length", "classifier_size"):
        setattr(args, key, args_info[key])
    args.pretrained_model_name = None

    model = _load_lucaone_weights(model_class, model_config, model_dirpath, args)
    model.eval()
    return args_info, model_config, model, tokenizer


def encoder(args_info, model_config, seq, seq_type, tokenizer):
    """Tokenize one sequence into a batch of size 1 of model inputs.

    :param args_info: mapping of training arguments. If ``max_length`` is
        present and positive, the encoded tokens are truncated to that many
        tokens *before* BOS/EOS are added.
    :param model_config: object exposing ``no_position_embeddings`` and
        ``no_token_type_embeddings``.
    :param seq: the sequence string passed to ``tokenizer.encode``.
    :param seq_type: 'DNA' / 'RNA' gives token type 0; anything else gives 1.
    :param tokenizer: object with ``encode``, ``prepend_bos``, ``append_eos``,
        ``padding_idx``, ``cls_idx`` and ``eos_idx``.
    :return: ``(encoding, processed_seq_len)``. ``encoding`` has keys
        ``input_ids``, ``token_type_ids`` and ``position_ids``, each an int64
        tensor of shape ``(1, processed_seq_len)`` or None when the matching
        embedding is disabled in ``model_config``. ``processed_seq_len`` is the
        token count plus one per enabled BOS/EOS.
    """
    seq_encoded = tokenizer.encode(seq)
    max_length = args_info.get("max_length")
    if max_length and max_length > 0:
        seq_encoded = seq_encoded[:max_length]

    bos = int(tokenizer.prepend_bos)
    eos = int(tokenizer.append_eos)
    n_tokens = len(seq_encoded)
    processed_seq_len = n_tokens + bos + eos

    # [CLS] tokens... [EOS]; the padding fill is overwritten for a single sequence.
    input_ids = torch.full((1, processed_seq_len), tokenizer.padding_idx, dtype=torch.int64)
    if bos:
        input_ids[0, 0] = tokenizer.cls_idx
    input_ids[0, bos:bos + n_tokens] = torch.tensor(seq_encoded, dtype=torch.int64)
    if eos:
        input_ids[0, bos + n_tokens] = tokenizer.eos_idx

    position_ids = None
    if not model_config.no_position_embeddings:
        # every position (including BOS/EOS) is numbered 0..processed_seq_len-1
        position_ids = torch.arange(processed_seq_len, dtype=torch.int64).unsqueeze(0)

    token_type_ids = None
    if not model_config.no_token_type_embeddings:
        type_value = 0 if seq_type in ('DNA', 'RNA') else 1
        token_type_ids = torch.full((1, processed_seq_len), type_value, dtype=torch.int64)

    encoding = {"input_ids": input_ids, "token_type_ids": token_type_ids, "position_ids": position_ids}

    return encoding, processed_seq_len


def _right_pad_ids(ids, length, pad_value):
    """Right-pad a 2-D tensor along dim 1 to ``length`` columns with ``pad_value``."""
    missing = length - ids.shape[1]
    if missing <= 0:
        return ids
    filler = torch.full((ids.shape[0], missing), pad_value, dtype=ids.dtype, device=ids.device)
    return torch.cat([ids, filler], dim=1)


def get_LUCAONE_embeds(args_info, model_config, tokenizer, model, seqs,
                       seq_type: str = 'DNA', layer_idx: int = -1, device=None):
    """Run LucaOne on one or more sequences.

    :param seqs: a single sequence string or an iterable of sequence strings.
        Sequences that encode to different lengths are right-padded with
        ``tokenizer.padding_idx``, the same value encoder fills unused
        positions with. In the outputs, row ``i`` is meaningful only for its
        own token count. Later positions hold EOS/padding and should be
        ignored by the caller.
    :param seq_type: 'DNA' / 'RNA' inputs are first passed through
        gene_seqs_replace. The value is also forwarded to encoder.
    :param layer_idx: index into ``list(range(num_hidden_layers + 1))``, the
        representation layers requested from the model. -1 selects the last.
    :param device: device for the input tensors. Defaults to the device of the
        model's first parameter, or CUDA if that cannot be determined.
    :return: ``(ATCG_logits, embeds)``. ``ATCG_logits`` is a softmax over
        output channels 5..8. ``embeds`` is the selected layer's
        representation. Both have their first and last positions dropped.
    :raises ValueError: if ``seqs`` is empty.
    """
    if isinstance(seqs, str):
        seqs = [seqs]
    seqs = list(seqs)
    if not seqs:
        raise ValueError("seqs must contain at least one sequence")

    # replace ATUCGN to 12345 for genes, and then we can use all 20 amino acids for proteins
    if seq_type in ['DNA', 'RNA']:
        seqs = gene_seqs_replace(seqs)

    encodings = [encoder(args_info, model_config, seq, seq_type, tokenizer)[0] for seq in seqs]

    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            device = torch.device("cuda")

    new_batch = {}
    for key, first in encodings[0].items():
        if not torch.is_tensor(first):
            # position_ids / token_type_ids are None when disabled in model_config
            continue
        parts = [enc[key] for enc in encodings]
        width = max(part.shape[1] for part in parts)
        # Pad to a common width so torch.cat works for sequences of unequal length.
        padded = [_right_pad_ids(part, width, tokenizer.padding_idx) for part in parts]
        new_batch[key] = torch.cat(padded).to(device)
    new_batch["return_contacts"] = True
    new_batch["return_dict"] = True
    new_batch["repr_layers"] = list(range(args_info["num_hidden_layers"] + 1))
    new_batch['output_keys'] = dict(token_level=['gene'])

    with torch.no_grad():
        output, repr_layers = model(**new_batch)

    # Drop the first (CLS/BOS) and last position. This assumes prepend_bos and
    # append_eos are both enabled — TODO confirm for non-Alphabet tokenizers.
    logits = output[:, 1:-1]
    # Channels 5..8; presumably the four nucleotide tokens of the gene_prot
    # vocabulary — confirm against Alphabet.
    ATCG_logits = logits[:, :, 5:9].softmax(-1)
    embeds = repr_layers[new_batch["repr_layers"][layer_idx]][:, 1:-1]

    return ATCG_logits, embeds


def load_lucaone_model(path,
                       log_filepath="config/embed/lucaone_args_info.yaml",
                       model_subdir=("models/lucagplm/v2.0/"
                                     "token_level,span_level,seq_level,structure_level/"
                                     "lucaone_gplm/20231125113045/checkpoint-step5600000")):
    """Fetch the LucaOne checkpoint under ``path`` and load it with load_model.

    :param path: root directory passed to download_trained_checkpoint_lucaone.
        The model directory is ``os.path.join(path, model_subdir)``.
    :param log_filepath: YAML args file. It is relative to the current working
        directory unless absolute.
    :param model_subdir: checkpoint sub-directory. load_model requires it to
        contain ``/v2.0/``.
    :return: (args_info, model_config, model, tokenizer), as returned by load_model.
    """
    download_trained_checkpoint_lucaone(path)
    cur_model_dirpath = os.path.join(path, model_subdir)
    return load_model(log_filepath, cur_model_dirpath)


def main(seq):
    """Demo: run LucaOne on ``seq`` on the GPU and print its embeddings and predicted nucleotides.

    :param seq: a DNA sequence string.
    """
    args_info, model_config, model, tokenizer = load_lucaone_model("data/ckpt/lucaone/llm")
    model = model.cuda()
    model.eval()

    # get_LUCAONE_embeds returns (probabilities over channels 5..8, embeddings), in that order.
    atcg_probs, emb = get_LUCAONE_embeds(args_info,
                                         model_config,
                                         tokenizer,
                                         model,
                                         seq)
    print(emb, emb.shape)
    idxs = atcg_probs.argmax(-1)
    print(atcg_probs.shape, idxs)
    # The 4-way distribution covers vocabulary ids 5..8, so index j is id 5 + j.
    token_ids = (idxs[0] + 5).tolist()
    # NOTE(review): assumes tokenizer.decode accepts a flat list of ids — confirm for Alphabet.
    seqs = gene_seq_replace_re(tokenizer.decode(token_ids))
    print(seqs)


if __name__ == "__main__":
    # Demo input: a DNA sequence containing one [MASK] token.
    demo_seq = 'ATGAGCGAGTTGAACGACGC[MASK]TACTGGATGAAGCAGGCGTTGGCGTTGGCTCAAAAAGCGCGCGAGCAGGGTGAGGTCCCAGTCGGGGCCATTCTGGTCTTGGATGACGAGGTCATTGGCCAGGGGTGGAATCGCTGCGTTCACAACCACGACCCGACAGCTCACGCGGAGATCATGGCGCTGCAGCAGGGCGGGAAGCGCGTACATAACTACCGCCTGCATGACGCAACGCTATACTCTACATTCGAACCATGCGTGATGTGCGCCGGCGCCATGGTCCATTCGAGGATCAAGCGCCTGGTGTACGGCATGAGCAACAGCAAGCGCGGCGCCGCTGGGAGTCTGTTGAATGTATTGAACTACCCTGGCATGAACCATCAGATCGAAATCACAGCGGGTGTCATGGCGAATGAGGCCTATGAACTTCTTCGGCAGTTTTACCGGCAGCTGCGCTTGGCTCCAGGGGCCGAGCGCGAAGCGCGGCGCGAGGTCAACCTGGATCGCGCAGAT'
    main(demo_seq)