import os
import sys
import time
import logging
from tqdm import tqdm

import torch
from fairseq import utils, tasks, options
from fairseq.checkpoint_utils import load_model_ensemble_and_task
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

from torch import Tensor
from typing import Dict, List, Optional

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s |  [%(filename)s:%(lineno)d] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("inference")


def write_result(results, output_file):
    with open(output_file, 'w') as f:
        for line in results:
            f.write(line + '\n')


@torch.no_grad()
def fairseq_generate(data_lines, cfg, models, task, batch_size, device):
    # fairseq original decoding implementation
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    generator = task.build_generator(models, cfg.generation)
    data_size = len(data_lines)
    all_results = []
    logger.info(f'Fairseq generate batch {batch_size}')
    start = time.perf_counter()
    for start_idx in tqdm(range(0, data_size, batch_size)):
        batch_lines = [line for line in data_lines[start_idx: min(start_idx + batch_size, data_size)]]
        batch_ids = [src_dict.encode_line(sentence, add_if_not_exist=False).long() for sentence in batch_lines]
        lengths = torch.LongTensor([t.numel() for t in batch_ids])
        batch_dataset = task.build_dataset_for_inference(batch_ids, lengths)
        batch = batch_dataset.collater(batch_dataset)
        batch = utils.apply_to_sample(lambda t: t.to(device), batch)
        translations = generator.generate(models, batch, prefix_tokens=None)
        results = []
        for id, hypos in zip(batch["id"].tolist(), translations):
            results.append((id, hypos))
        batched_hypos = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
        all_results.extend([tgt_dict.string(hypos[0]['tokens']) for hypos in batched_hypos])
    delta = time.perf_counter() - start
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    return remove_bpe_results, delta


@torch.no_grad()
def baseline_forward_decoder(model,
                             input_tokens,
                             encoder_out: Dict[str, List[Tensor]],
                             incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
                             parallel_forward_start_pos=None,
                             temperature: float = 1.0):
    decoder_out = model.decoder.forward(input_tokens,
                                        encoder_out=encoder_out,
                                        incremental_state=incremental_state,
                                        parallel_forward_start_pos=parallel_forward_start_pos)
    decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
    pred_tokens = torch.argmax(decoder_out_tuple[0], dim=-1).squeeze(0)
    return pred_tokens


@torch.no_grad()
def baseline_generate(data_lines, model, task, device, max_len=200):
    # simplified AR greedy decoding
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    data_size = len(data_lines)
    all_results = []
    logger.info(f'Baseline generate')
    start = time.perf_counter()
    for start_idx in tqdm(range(0, data_size)):
        bpe_line = data_lines[start_idx]
        src_tokens = src_dict.encode_line(bpe_line, add_if_not_exist=False).long()
        net_input = {'src_tokens': src_tokens.unsqueeze(0).to(device),
                     'src_lengths': torch.LongTensor([src_tokens.numel()]).to(device)}
        encoder_out = model.encoder.forward_torchscript(net_input)
        incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]],
                                               torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}))
        tokens = [tgt_dict.eos()]

        for step in range(0, max_len):
            cur_input_tokens = torch.tensor([tokens]).to(device).long()
            pred_token = baseline_forward_decoder(model,
                                                  cur_input_tokens,
                                                  encoder_out,
                                                  incremental_state).item()
            if pred_token == tgt_dict.eos():
                break
            else:
                tokens.append(pred_token)
        all_results.append(tgt_dict.string(tokens[1:]))
    delta = time.perf_counter() - start
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    return remove_bpe_results, delta


def cut_incremental_state(incremental_state, keep_len, encoder_state_ids):
    for n in incremental_state:
        if n[: n.index('.')] in encoder_state_ids:
            continue
        for k in incremental_state[n]:
            if incremental_state[n][k] is not None:
                if incremental_state[n][k].dim() == 4:
                    incremental_state[n][k] = incremental_state[n][k][:, :, :keep_len]
                elif incremental_state[n][k].dim() == 2:
                    incremental_state[n][k] = incremental_state[n][k][:, :keep_len]


@torch.no_grad()
def forward_decoder(model,
                    input_tokens,
                    encoder_out: Dict[str, List[Tensor]],
                    incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
                    parallel_forward_start_pos=None,
                    temperature: float = 1.0,
                    beta: int = 1,
                    tau: float = 0.0):
    decoder_out = model.decoder.forward(input_tokens,
                                        encoder_out=encoder_out,
                                        incremental_state=incremental_state,
                                        parallel_forward_start_pos=parallel_forward_start_pos)
    decoder_out_tuple = (decoder_out[0].div_(temperature), decoder_out[1])
    topk_scores, indexes = torch.topk(decoder_out_tuple[0], beta, dim=-1)
    topk_scores = topk_scores[0].tolist()
    indexes = indexes[0].tolist()
    for i in range(len(topk_scores)):
        for j, s in enumerate(topk_scores[i]):
            if topk_scores[i][0] - s > tau:
                indexes[i][j] = -1
    return indexes


def gad_generate(data_lines, model, AR_model, task, block_size, device, beta=1, tau=0, max_len=200):
    # Speculative Decoding
    src_dict = task.source_dictionary
    tgt_dict = task.target_dictionary
    encoder_state_ids = []
    for i in range(len(AR_model.decoder.layers)):
        encoder_state_ids.append(AR_model.decoder.layers[i].encoder_attn._incremental_state_id)
    data_size = len(data_lines)
    all_results = []
    logger.info(f'GAD generate')
    pass_tokens = [0] * max_len
    sent_nums = [0] * max_len
    start = time.perf_counter()
    for start_idx in tqdm(range(0, data_size)):
        bpe_line = data_lines[start_idx]
        src_tokens = src_dict.encode_line(bpe_line, add_if_not_exist=False).long()
        net_input = {'src_tokens': src_tokens.unsqueeze(0).to(device),
                     'src_lengths': torch.LongTensor([src_tokens.numel()]).to(device)}
        AR_encoder_out = AR_model.encoder.forward_torchscript(net_input)
        encoder_out = model.encoder.forward_torchscript(net_input)
        incremental_state = torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]],
                                               torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}))
        prev_output_tokens = [tgt_dict.unk()] * block_size
        start_pos = 0
        for step in range(0, max_len):
            start_pos, prev_output_tokens, pass_token = gad_forward(incremental_state, encoder_state_ids,
                                                                    start_pos, block_size, tgt_dict,
                                                                    prev_output_tokens,
                                                                    encoder_out, AR_encoder_out, model,
                                                                    AR_model, beta, tau)
            pass_tokens[step] += pass_token
            sent_nums[step] += 1
            if start_pos == -1:
                break
        all_results.append(tgt_dict.string(prev_output_tokens))
    total_pass_tokens = 0
    total_sent_nums = 0
    for step in range(max_len):
        if sent_nums[step] > 0:
            total_pass_tokens += pass_tokens[step]
            total_sent_nums += sent_nums[step]
    print("Avg accepted tokens:", total_pass_tokens / total_sent_nums)
    total_iter = 0
    for step in range(max_len):
        if sent_nums[step - 1] > 0:
            if step == 0:
                last_num = data_size
            else:
                last_num = sent_nums[step - 1]
            if (last_num - sent_nums[step]) > 0:
                total_iter += (last_num - sent_nums[step]) * (step)
    print("Avg decoding iteration:", total_iter / data_size)
    delta = time.perf_counter() - start
    remove_bpe_results = [line.replace('@@ ', '') for line in all_results]
    return remove_bpe_results, delta


def gad_forward(incremental_state, encoder_state_ids, start_pos, block_size, tgt_dict, prev_output_tokens,
                encoder_out, AR_encoder_out, model, AR_model, beta, tau, max_len=200):
    output_tokens = torch.tensor([prev_output_tokens]).to(device)
    _scores, _tokens = model.decoder(
        normalize=False,
        prev_output_tokens=output_tokens,
        encoder_out=encoder_out,
    ).max(-1)

    prev_output_tokens[start_pos:start_pos + block_size] = _tokens[0].tolist()[start_pos:start_pos + block_size]

    cut_incremental_state(incremental_state, keep_len=start_pos, encoder_state_ids=encoder_state_ids)
    cur_span_input_tokens = torch.tensor([[tgt_dict.eos()] + prev_output_tokens]).to(device)
    AR_topk_tokens = forward_decoder(AR_model,
                                     cur_span_input_tokens,
                                     AR_encoder_out,
                                     incremental_state,
                                     parallel_forward_start_pos=start_pos,
                                     beta=beta,
                                     tau=tau)

    bifurcation = block_size
    for i, (token, AR_topk_token) in enumerate(zip(prev_output_tokens[start_pos:], AR_topk_tokens[:-1][:])):
        if token not in AR_topk_token:
            bifurcation = i
            break

    next_output_tokens = prev_output_tokens[:start_pos + bifurcation] + [AR_topk_tokens[bifurcation][0]] + [
        tgt_dict.unk()] * block_size

    pass_token = 0
    find_eos = False
    for i, o in enumerate(next_output_tokens[start_pos:start_pos + bifurcation + 1]):
        if o == tgt_dict.eos() or i + start_pos == max_len:
            next_output_tokens = next_output_tokens[0:start_pos + i]
            start_pos = -1
            pass_token = i
            find_eos = True
            break

    if not find_eos:
        start_pos = start_pos + bifurcation + 1
        pass_token = bifurcation + 1

    return start_pos, next_output_tokens, pass_token


if __name__ == '__main__':
    parser = options.get_generation_parser()
    parser.add_argument('--input-path', type=str, required=True,
                        help='path to eval file')
    parser.add_argument('--output-path', type=str, default=None,
                        help='path to output file')
    parser.add_argument('--AR-path', type=str, default=None,
                        help='path to AR model')
    parser.add_argument('--strategy', type=str, default='fairseq',
                        help='decoding strategy, choose from: fairseq, AR, gad')
    parser.add_argument('--batch', type=int, default=None,
                        help='batch size')
    parser.add_argument('--block-size', type=int, default=5,
                        help='block size')
    parser.add_argument('--beta', type=int, default=1,
                        help='top-beta hyperparameter')
    parser.add_argument('--tau', type=float, default=0,
                        help='tolerance hyperparameter')
    cmd_args = options.parse_args_and_arch(parser)
    cmd_args.input_path = os.path.expanduser(cmd_args.input_path)
    cmd_args.output_path = os.path.expanduser(cmd_args.output_path)

    cfg = convert_namespace_to_omegaconf(cmd_args)

    task = tasks.setup_task(cfg.task)

    # NAR drafter
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, _model_args, _model_task = load_model_ensemble_and_task(filenames=[cfg.common_eval.path], task=task)

    if cmd_args.cpu:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda')
    model = models[0].to(device).eval()

    # AR verifier
    AR_model = None
    AR_models = None
    _AR_model_task = None
    if cmd_args.AR_path is not None:
        AR_models, _AR_model_args, _AR_model_task = load_model_ensemble_and_task(filenames=[cmd_args.AR_path],
                                                                                 arg_overrides={'data': cfg.task.data})
        AR_model = AR_models[0].to(device).eval()
        logging.info("AR model loaded!")

    with open(cmd_args.input_path, 'r') as f:
        bpe_sents = [l.strip() for l in f.readlines()]

    if cmd_args.strategy == 'AR':
        logger.info("Decoding Strategy: Simplified AR")
        remove_bpe_results, delta = baseline_generate(bpe_sents, AR_model, _AR_model_task, device)
        logger.info(f'Simplified AR generate: {delta}')
    elif cmd_args.strategy == 'specdec':
        logger.info("Decoding Strategy: Speculative")
        remove_bpe_results, delta = gad_generate(bpe_sents, model, AR_model, task, cmd_args.block_size, device,
                                                 beta=cmd_args.beta, tau=cmd_args.tau)
        logger.info(f'Speculative generate: {delta}')
    else:
        logger.info("Decoding Strategy: fairseq")
        remove_bpe_results, delta = fairseq_generate(bpe_sents, cfg, AR_models, _AR_model_task, cmd_args.batch, device)
        logger.info(f'Fairseq generate batch {cmd_args.batch}, beam {cfg.generation.beam}: {delta}')

    if cmd_args.output_path is not None:
        write_result(remove_bpe_results, cmd_args.output_path)
