import os
import sys
import copy
import yaml
import math
import torch
import logging
import argparse
import sentencepiece as spm

sys.path.append("../asr_lm")

from nn_lm.models.transformer.transformer_lm import TransformerLM
from nn_lm.models.transformer.transformer_class import TransformerClassLM

from utils.text import convert_to_token
from utils.checkpoint import load_checkpoint
from nn_lm.models.init_model import init_model
from utils.common import load_dict


def get_args():
    """Parse the command-line options of the language-model inference script.

    Returns:
        argparse.Namespace: carries the string options ``text``, ``ppl``,
        ``model``, ``config``, ``dict`` and ``bpe_model`` plus the integer
        options ``gpu`` and ``seed``.
    """
    cli = argparse.ArgumentParser(description="Language model inference")

    # (flag, default, help) for every string-valued option, listed in the
    # order they should appear in ``--help``.
    string_options = (
        ('--text', "", "location of text file"),
        ('--ppl', "", "location of ppl metric file"),
        ('--model',
         "checkpoints/test/class_transformer/TransformerClassLM_38/TransformerClassLM_avg_5.pt",
         "location of .pt model"),
        ('--config',
         "checkpoints/test/class_transformer/train.yaml",
         "location of model config"),
        ('--dict',
         "checkpoints/test/class_transformer/asr_lm_nwp_dict.txt",
         "location of word dict file"),
        ('--bpe_model',
         "data/bpe_model/spm_giga_xmly_1500.model",
         "location of bpe model for bpe tokenize"),
    )
    for flag, default_value, help_text in string_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    # Integer options are registered without a help string.
    for flag, default_value in (('--gpu', 0), ('--seed', 1111)):
        cli.add_argument(flag, type=int, default=default_value)

    return cli.parse_args()


def set_logger():
    """Create the log directory and configure the root logger at DEBUG level.

    File logging is currently disabled: ``filename`` is ``None``, so
    ``logging.basicConfig`` attaches its default stream handler and the
    ``filemode`` argument has no effect.
    """
    os.makedirs("log/test_log", exist_ok=True)

    logging.basicConfig(
        filename=None,
        filemode='a',
        format='%(asctime)s %(levelname)s %(message)s',
        level=logging.DEBUG,
    )


def _require_file(path, what):
    """Raise FileNotFoundError unless ``path`` exists on disk."""
    if not os.path.exists(path):
        raise FileNotFoundError("{} not found: {!r}".format(what, path))


def _safe_exp(value):
    """Return ``exp(value)``, mapping float overflow to ``inf`` instead of raising."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


def _ids_to_text(token_ids, idx2word):
    """Render token ids as text through ``idx2word``, falling back to the raw id."""
    pieces = []
    for token_id in token_ids:
        try:
            pieces.append(str(idx2word[token_id]))
        except (KeyError, IndexError, TypeError):
            # Unknown id or unexpected container type: keep the id visible.
            pieces.append(str(token_id))
    return "".join(pieces)


def _encode_text(raw_text, word2idx, sp, to_lower, cn_en_symbol, min_len):
    """Case-normalize ``raw_text`` and tokenize it; return ``(text, token_ids or None)``."""
    text = raw_text.lower() if to_lower else raw_text.upper()
    encoded = convert_to_token(txt=text,
                               word2idx=word2idx,
                               word2idx_cls=dict(),
                               min_len=min_len,
                               max_len=100,
                               bpe_sp=sp,
                               token_type="bpe",
                               to_lower=to_lower,
                               cn_en_symbol=cn_en_symbol)
    if encoded is None:
        return text, None
    # NOTE(review): assumes convert_to_token returns a dict with a "tokens"
    # entry (as the debug path has always read it); a bare sequence is
    # accepted too in case an older helper version is in use -- confirm.
    tokens = encoded["tokens"] if isinstance(encoded, dict) else encoded
    if tokens is None:
        return text, None
    return text, list(tokens)


def _build_lm_tensors(token_list, sos, eos, device):
    """Build the batch-of-one ``(input, target)`` int64 tensors for next-token prediction."""
    lm_input = torch.tensor([sos] + token_list,
                            dtype=torch.int64,
                            device=device).unsqueeze(0)
    lm_target = torch.tensor(token_list + [eos],
                             dtype=torch.int64,
                             device=device).unsqueeze(0)
    return lm_input, lm_target


def _forward_lm(lm_model, lm_input):
    """Run one forward pass and return the first element of the model's output pair."""
    if isinstance(lm_model, TransformerLM):
        lm_out, _ = lm_model(lm_input, None)
    elif isinstance(lm_model, TransformerClassLM):
        lm_out, _ = lm_model(lm_input)
    else:
        raise NotImplementedError(
            "unsupported model class: {}".format(type(lm_model).__name__))
    return lm_out


def _run_text_inference(text_path, ppl_path, lm_model, criterion, encode,
                        idx2word, sos, eos, device):
    """Score every non-blank line of ``text_path`` and write a loss/PPL report to ``ppl_path``."""
    num_seen_seq = 0
    total_loss = 0.0

    # ``with`` guarantees both files are closed even if scoring raises.
    with open(text_path, "r", encoding="utf-8") as fr, \
            open(ppl_path, "w", encoding="utf-8") as fw, \
            torch.no_grad():
        for raw_line in fr:
            line = raw_line.strip()
            if not line:
                # A blank line is skipped; it must not end the whole run.
                continue

            text, token_list = encode(line, min_len=1)
            if token_list is None:
                logging.warning("filter this line, text=%s", text)
                continue

            fw.write("text: " + _ids_to_text(token_list, idx2word))
            fw.write("\n")
            fw.write("ids: " + " ".join(str(_t) for _t in token_list))
            fw.write("\n")

            lm_input, lm_target = _build_lm_tensors(token_list, sos, eos, device)
            lm_out = _forward_lm(lm_model, lm_input)

            loss_value = criterion(x=lm_out, target_x=lm_target).item()
            fw.write("Loss: {:.4f} \t PPL: {:.4f}".format(loss_value, _safe_exp(loss_value)))
            fw.write("\n")
            fw.write("\n")

            # Non-finite losses are reported per line but kept out of the average.
            if math.isfinite(loss_value):
                num_seen_seq += 1
                total_loss += loss_value

        if num_seen_seq:
            average_loss = total_loss / num_seen_seq
        else:
            # Nothing was scored: report NaN rather than dividing by zero.
            logging.warning("no sequence was scored from %s", text_path)
            average_loss = float("nan")

        fw.write("\n")
        fw.write("Total Loss: {:.4f} \t Total PPL: {:.4f}".format(average_loss, _safe_exp(average_loss)))
        fw.write("\n")


def _run_debug_inference(lm_model, encode, sos, eos, device):
    """Run a single forward pass on a fixed sample sentence and return the model output."""
    text = "给"
    with torch.no_grad():
        _, token_list = encode(text, min_len=0)
        if token_list is None:
            raise ValueError("debug sample could not be tokenized: {!r}".format(text))
        lm_input, _ = _build_lm_tensors(token_list, sos, eos, device)
        return _forward_lm(lm_model, lm_input)


def main():
    """Load a trained language model and run inference.

    Two modes are supported:

    * text mode -- selected when ``--text`` is given; every non-blank line of
      that file is tokenized and scored, and per-line plus averaged
      loss/perplexity are written to the ``--ppl`` file.
    * debug mode -- the default when ``--text`` is empty; a single forward
      pass is run on a built-in sample sentence.

    Raises:
        FileNotFoundError: if the model, config, dict, BPE model or (in text
            mode) the input text file does not exist.
        ValueError: if text mode is requested without ``--ppl``, or the debug
            sample cannot be tokenized.
        NotImplementedError: if the loaded model is neither ``TransformerLM``
            nor ``TransformerClassLM``.
    """
    args = get_args()
    set_logger()

    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    _require_file(args.model, "model checkpoint")
    _require_file(args.config, "model config")
    _require_file(args.dict, "word dict")
    _require_file(args.bpe_model, "bpe model")

    inference_type = "text" if args.text else "debug"
    if inference_type == "text":
        _require_file(args.text, "text file")
        if not args.ppl:
            raise ValueError("--ppl must be given together with --text")

    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    torch.manual_seed(args.seed)

    with open(args.config, 'r', encoding="utf-8") as f:
        # NOTE(security): FullLoader can construct more than plain data;
        # only point --config at trusted files (or switch to yaml.safe_load
        # once the training configs are confirmed to be plain YAML).
        configs = yaml.load(f, Loader=yaml.FullLoader)

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    model_type = configs["model_type"]
    token_conf = configs['lm_dataset_conf']['token_conf']
    to_lower = token_conf['to_lower']
    cn_en_symbol = token_conf['cn_en_symbol']

    idx2word, word2idx = load_dict(dict_path=args.dict, to_lower=to_lower)

    # Special-token keys follow the same case convention as the dictionary.
    blank_key, sos_eos_key = '<blank>', '<sos/eos>'
    if not to_lower:
        blank_key, sos_eos_key = blank_key.upper(), sos_eos_key.upper()
    pad_idx = word2idx[blank_key]
    # Start- and end-of-sentence share one dictionary entry.
    sos = eos = word2idx[sos_eos_key]

    lm_model, criterion = init_model(model_type=model_type,
                                     n_tokens=configs["n_tokens"],
                                     n_classes=configs.get("n_classes", 0),
                                     ignore_idx=pad_idx,
                                     configs=configs,
                                     param_init=False)
    load_checkpoint(lm_model, args.model)

    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    lm_model = lm_model.to(device)
    lm_model.eval()

    def encode(raw_text, min_len):
        # Bind the tokenizer configuration once so both modes tokenize alike.
        return _encode_text(raw_text, word2idx, sp, to_lower, cn_en_symbol, min_len)

    if inference_type == "text":
        _run_text_inference(args.text, args.ppl, lm_model, criterion, encode,
                            idx2word, sos, eos, device)
    else:
        _run_debug_inference(lm_model, encode, sos, eos, device)


# Script entry point: run inference only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
