import os
import sys
import time
import logging
from tqdm import tqdm

import torch
from fairseq import utils, tasks, options
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

from torch import Tensor
from typing import Dict, List, Optional

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s |  [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("inference")


def write_result(results, output_file):
    with open(output_file, 'w') as f:
        for line in results:
            f.write(line + '\n')


@torch.no_grad()
def fairseq_generate(data_lines, cfg, models, task, batch_size, device):
    # fairseq original decoding implementation
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    generator = task.build_generator(models, cfg.generation)
    data_size = len(data_lines)
    all_results = []
    logger.info(f'Fairseq generate batch {batch_size}')
    start = time.perf_counter()
    for start_idx in tqdm(range(0, data_size, batch_size)):
        batch_lines = [line for line in data_lines[start_idx: min(start_idx + batch_size, data_size)]]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch = batch_dataset.collater(batch_dataset)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        translations = generator.generate(models, batch, prefix_tokens=None)
        results = []
        for id, hypos in zip(batch["id"].tolist(), translations):
            results.append((id, hypos))
        batched_hypos = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
        all_results.extend([tgt_dict.string(hypos[0]['tokens']) for hypos in batched_hypos])
    delta = time.perf_counter() - start
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    return remove_bpe_results, delta


@torch.no_grad()
def baseline_forward_decoder(model,
                             input_tokens,
                             encoder_out: Dict[str, List[Tensor]],
                             incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
                             parallel_forward_start_pos=None,
                             temperature: float = 1.0):
    decoder_out = model.decoder.forward(input_tokens,
                                        encoder_out=encoder_out,
                                        incremental_state=incremental_state,
                                        parallel_forward_start_pos=parallel_forward_start_pos)
    decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
    pred_tokens = torch.argmax(decoder_out_tuple[0], dim=-1).squeeze(0)
    return pred_tokens


@torch.no_grad()
def baseline_generate(data_lines, model, task, batch_size, device, max_len=200):
    # simplified AR greedy decoding
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    data_size = len(data_lines)
    all_results = []
    start = time.perf_counter()
    logger.info(f'Baseline generate')
    for start_idx in tqdm(range(0, data_size, batch_size)):
        batch_size = min(data_size - start_idx, batch_size)
        batch_lines = [line for line in data_lines[start_idx: start_idx + batch_size]]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch_dataset.left_pad_source = False
        batch = batch_dataset.collater(batch_dataset)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        net_input = batch['net_input']
        encoder_out = model.encoder.forward(net_input['src_tokens'], net_input['src_lengths'])
        incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]],
                                               torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}))
        batch_tokens = [[tgt_dict.eos()] for _ in range(batch_size)]
        finish_list = []
        for step in range(0, max_len):
            cur_input_tokens = torch.tensor(batch_tokens).to(device).long()
            pred_tokens = baseline_forward_decoder(model,
                                                   cur_input_tokens,
                                                   encoder_out,
                                                   incremental_state=incremental_state)
            for i, pred_tok in enumerate(pred_tokens):
                if len(batch_tokens[i]) == 1:
                    batch_tokens[i].append(pred_tok.item())
                else:
                    if batch_tokens[i][-1] != tgt_dict.eos():
                        batch_tokens[i].append(pred_tok.item())
                    else:
                        if i not in finish_list:
                            finish_list.append(i)
                        batch_tokens[i].append(tgt_dict.eos())
            if len(finish_list) == batch_size:
                break
        batch_tokens = [y for x, y in sorted(zip(batch['id'].cpu().tolist(), batch_tokens))]
        for tokens in batch_tokens:
            all_results.append(tgt_dict.string(tokens[1:]))
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    delta = time.perf_counter() - start
    return remove_bpe_results, delta


@torch.no_grad()
def forward_decoder(model, input_tokens, encoder_out, incremental_state=None,
                    parallel_forward_start_pos=None, temperature=1.0, beta=1, tau=0.0):
    decoder_out = model.decoder.forward(input_tokens,
                                        encoder_out=encoder_out,
                                        incremental_state=incremental_state,
                                        parallel_forward_start_pos=parallel_forward_start_pos)
    decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
    topk_scores, indexes = torch.topk(decoder_out_tuple[0], beta, dim=-1)
    topk_scores_list = topk_scores.tolist()
    indexes_list = indexes.tolist()
    for i in range(indexes.size(0)):
        for j in range(indexes.size(1)):
            for k, s in enumerate(topk_scores_list[i][j]):
                if topk_scores_list[i][j][0] - s > tau:
                    indexes_list[i][j][k] = -1
    return indexes_list


def gad_generate(data_lines, model, AR_model, task, block_size, batch_size, device, beta=1, tau=0, max_len=200):
    # Speculative Decoding
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    data_size = len(data_lines)
    all_results = []
    logger.info(f'GAD generate')
    start = time.perf_counter()
    for start_idx in tqdm(range(0, data_size, batch_size)):
        batch_size = min(data_size - start_idx, batch_size)
        batch_lines = [line for line in data_lines[start_idx: start_idx + batch_size]]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch_dataset.left_pad_source = False
        batch = batch_dataset.collater(batch_dataset)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        net_input = batch['net_input']
        AR_encoder_out = AR_model.encoder.forward(net_input['src_tokens'], net_input['src_lengths'])
        encoder_out = model.encoder.forward(net_input['src_tokens'], net_input['src_lengths'])
        sentences = [[tgt_dict.eos()] for _ in range(batch_size)]
        prev_output_tokens = [[tgt_dict.unk()] * block_size for _ in range(batch_size)]
        start_pos_list = [0] * batch_size
        finish_list = []
        for step in range(0, max_len):
            prev_output_tokens, start_pos_list = gad_forward(start_pos_list, block_size, batch_size,
                                                             tgt_dict, prev_output_tokens,
                                                             encoder_out, AR_encoder_out, model, AR_model, beta, tau)
            for i, start_pos in enumerate(start_pos_list):
                if i not in finish_list:
                    if start_pos == -1:
                        finish_list.append(i)
                        sentences[i] = prev_output_tokens[i]

            if len(finish_list) == batch_size:
                break
        batch_sents = [y for x, y in sorted(zip(batch['id'].cpu().tolist(), sentences))]
        for s in batch_sents:
            all_results.append(tgt_dict.string(s))
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    delta = time.perf_counter() - start
    return remove_bpe_results, delta


def gad_forward(start_pos_list, block_size, batch_size, tgt_dict, prev_output_tokens,
                encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
    pad_tokens = [[tgt_dict.pad()] * (max_len + block_size) for _ in range(batch_size)]
    for i in range(batch_size):
        pad_tokens[i][:len(prev_output_tokens[i])] = prev_output_tokens[i]
    output_tokens = torch.tensor(pad_tokens).to(device)
    output_tokens = output_tokens[:, : output_tokens.ne(tgt_dict.pad()).sum(1).max()]

    _, tensor_tokens = model.decoder(
        normalize=False,
        prev_output_tokens=output_tokens,
        encoder_out=encoder_out,
    ).max(-1)

    _tokens = tensor_tokens.tolist()
    for i, start_pos in enumerate(start_pos_list):
        if start_pos_list[i] != -1:
            output_tokens[i, start_pos:start_pos + block_size] = tensor_tokens[i, start_pos:start_pos + block_size]
            prev_output_tokens[i][start_pos:start_pos + block_size] = _tokens[i][start_pos:start_pos + block_size]

    append_eos = torch.tensor([[tgt_dict.eos()] for _ in range(batch_size)]).to(device)
    cur_span_input_tokens = torch.cat((append_eos, output_tokens), dim=-1)

    AR_verify_tokens = forward_decoder(AR_model, cur_span_input_tokens, AR_encoder_out, beta=beta, tau=tau)

    next_output_tokens = prev_output_tokens.copy()
    for i in range(batch_size):
        if start_pos_list[i] != -1:
            bifurcation = block_size
            for j, (token, AR_verify_token) in enumerate(
                    zip(prev_output_tokens[i][start_pos_list[i]:], AR_verify_tokens[i][start_pos_list[i]:-1])):
                if token not in AR_verify_token:
                    bifurcation = j
                    break
            next_output_tokens[i] = prev_output_tokens[i][:start_pos_list[i] + bifurcation] + \
                                    [AR_verify_tokens[i][start_pos_list[i] + bifurcation][0]] + \
                                    [tgt_dict.unk()] * block_size

            find_eos = False
            for j, o in enumerate(next_output_tokens[i][start_pos_list[i]:start_pos_list[i] + bifurcation + 1]):
                if o == tgt_dict.eos() or start_pos_list[i] + j == max_len:
                    next_output_tokens[i] = next_output_tokens[i][:start_pos_list[i] + j]
                    start_pos_list[i] = -1
                    find_eos = True
                    break
            if not find_eos:
                start_pos_list[i] = start_pos_list[i] + bifurcation + 1

    return next_output_tokens, start_pos_list


if __name__ == '__main__':
    parser = options.get_generation_parser()
    parser.add_argument('--input-path', type=str, required=True,
                        help='path to eval file')
    parser.add_argument('--output-path', type=str, default=None,
                        help='path to output file')
    parser.add_argument('--AR-path', type=str, default=None,
                        help='path to AR model')
    parser.add_argument('--strategy', type=str, default='fairseq',
                        help='decoding strategy, choose from: fairseq, AR, gad')
    parser.add_argument('--batch', type=int, default=None,
                        help='batch size')
    parser.add_argument('--block-size', type=int, default=5,
                        help='block size')
    parser.add_argument('--beta', type=int, default=1,
                        help='top-beta hyperparameter')
    parser.add_argument('--tau', type=float, default=0,
                        help='tolerance hyperparameter')
    cmd_args = options.parse_args_and_arch(parser)
    cmd_args.input_path = os.path.expanduser(cmd_args.input_path)
    cmd_args.output_path = os.path.expanduser(cmd_args.output_path)

    cfg = convert_namespace_to_omegaconf(cmd_args)

    task = tasks.setup_task(cfg.task)

    # NAR drafter
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task)

    if cmd_args.cpu:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda')
    model = models[0].to(device).eval()
    if cfg.common.fp16:
        logging.info("NAR fp16 enabled!")
        model.half()

    # AR verifier
    AR_model = None
    AR_models = None
    _AR_model_task = None
    if cmd_args.AR_path is not None:
        AR_models, _AR_model_args, _AR_model_task = load_model_ensemble_and_task(filenames=[cmd_args.AR_path],
                                                                                 arg_overrides={'data': cfg.task.data})
        if cfg.common.fp16:
            logging.info("AR fp16 enabled!")
            for AR_model in AR_models:
                AR_model.half()
        AR_model = AR_models[0].to(device).eval()
        logging.info("AR model loaded!")

    with open(cmd_args.input_path, 'r') as f:
        bpe_sents = [l.strip() for l in f.readlines()]

    if cmd_args.strategy == 'AR':
        logger.info("Decoding Strategy: Simplified AR")
        remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, cmd_args.batch, device)
        logger.info(f'Simplified AR generate: {delta}')
    elif cmd_args.strategy == 'specdec':
        logger.info("Decoding Strategy: Speculative")
        remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, cmd_args.batch,
                                                 device, beta=cmd_args.beta, tau=cmd_args.tau)
        logger.info(f'Speculative generate: {delta}')
    else:
        logger.info("Decoding Strategy: fairseq")
        remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device)
        logger.info(f'Fairseq generate batch {cmd_args.batch}, beam {cfg.generation.beam}: {delta}')

    if cmd_args.output_path is not None:
        write_result(remove_bpe_results, cmd_args.output_path)
