#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
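Translate pre-processed data with a trained model, evaluating one or more
candidate sub-architecture encodings and printing, for each, its BLEU score
and an estimated FLOP count (in billions) for a 20-token sequence.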
"""

import logging
import math
import os
import sys

import torch

from fairseq import bleu, checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
import pdb

def param_calculator(arch):
    """Estimate the parameter count of a candidate architecture encoding.

    Every entry of ``arch`` is shifted by +4 before use.  The first shifted
    entry is used as the model width (in units of 64-dim heads) for every
    layer; shifted entries 1-6 are the encoder layers' inner head counts and
    shifted entries 7-12 the decoder layers'.  Entries past index 12 are
    ignored.

    Args:
        arch: sequence of ints describing the candidate.

    Returns:
        The estimated number of parameters (an int for int inputs).
    """
    vocab_size = 32768
    embed_dim = 1024
    head_size = 64

    shifted = [value + 4 for value in arch]
    width = shifted[0]
    encoder_cells = shifted[1:7]
    decoder_cells = shifted[7:13]

    # Embedding table: vocab_size x embed_dim.
    total = vocab_size * embed_dim
    # Presumably a projection from embed_dim into the width*64 model space,
    # weight plus bias -- TODO confirm against the model definition.
    total += embed_dim * width * head_size + width * head_size
    # Presumably the projection back to embed_dim, weight plus bias.
    total += embed_dim * width * head_size + embed_dim
    for cells in encoder_cells:
        total += encoder_param_layer(width, cells)
    for cells in decoder_cells:
        total += decoder_param_layer(width, cells)
    return total

def encoder_param_layer(input_head, inner_head):
    """Estimate the parameter count of one encoder layer.

    Both arguments are head counts of 64 dimensions each, so the layer is
    sized by ``input_head * 64`` (model width) and ``inner_head * 64``
    (inner width).

    Returns:
        12 * model*inner + 6 * inner + 4 * model, broken down below.
    """
    model_w = input_head * 64
    inner_w = inner_head * 64

    # Terms in the same order as the original single expression.
    terms = (
        model_w * inner_w * 4,      # four model x inner matrices (presumably attention projections)
        inner_w * 3,                # three inner-width vectors
        model_w,                    # one model-width vector
        inner_w * 2,                # two inner-width vectors
        model_w * inner_w * 4 * 2,  # two model x (4*inner) matrices (presumably the FFN)
        inner_w,                    # one inner-width vector
        model_w,                    # one model-width vector
        model_w * 2,                # two model-width vectors
    )
    return sum(terms)

def decoder_param_layer(input_head, inner_head):
    """Estimate the parameter count of one decoder layer.

    Both arguments are head counts of 64 dimensions each, so the layer is
    sized by ``input_head * 64`` (model width) and ``inner_head * 64``
    (inner width).

    Returns:
        16 * model*inner + 7 * inner + 3 * model, broken down below.
    """
    model_w = input_head * 64
    inner_w = inner_head * 64

    # Terms in the same order as the original single expression.
    terms = (
        model_w * inner_w * 4,      # four model x inner matrices (presumably self-attention)
        inner_w * 3,                # three inner-width vectors
        model_w,                    # one model-width vector
        inner_w * 2,                # two inner-width vectors
        model_w * inner_w * 4 * 2,  # two model x (4*inner) matrices (presumably the FFN)
        inner_w * 2,                # two inner-width vectors
        model_w * inner_w * 4,      # four more model x inner matrices (presumably cross-attention)
        model_w * 2,                # two model-width vectors
    )
    return sum(terms)

def flops_calculator(arch, length, tgt_length=None):
    """Estimate the FLOPs of a candidate architecture encoding.

    Every entry of ``arch`` is shifted by +4 before use.  The first shifted
    entry is used as the model width (in 64-dim heads) for every layer;
    shifted entries 1-6 feed :func:`encoder_flops_layer` and 7-12 feed
    :func:`decoder_flops_layer`.  Entries past index 12 are ignored.

    Args:
        arch: sequence of ints describing the candidate.
        length: source sequence length used for the embedding-projection
            and encoder terms.
        tgt_length: target sequence length for the decoder terms.  Defaults
            to ``length``, which reproduces the previous behaviour.

    Returns:
        The estimated FLOP count.
    """
    dict_dim = 1024
    head_dim = 64
    if tgt_length is None:
        tgt_length = length

    arch = [i + 4 for i in arch]
    arc = arch[0]
    layout = arch[1:]

    # Projection from dict_dim into the arc*64 model width.
    p_sum = length * (2 * dict_dim - 1) * arc * head_dim
    # NOTE(review): a dense (arc*64 -> dict_dim) projection would be
    # length * dict_dim * (2 * arc * head_dim - 1); this uses
    # (2 * head_dim - 1) * arc instead.  Kept unchanged -- confirm intent.
    p_sum += length * (2 * head_dim - 1) * arc * dict_dim
    for cells in layout[0:6]:
        p_sum += encoder_flops_layer(arc, cells, length)
    # Decoder cost is estimated for the whole target at once (not per step).
    for cells in layout[6:12]:
        p_sum += decoder_flops_layer(arc, cells, length, tgt_length)
    return p_sum

def encoder_flops_layer(input_head, inner_head, length):
    """Estimate the FLOPs of one encoder layer over ``length`` tokens.

    ``input_head`` and ``inner_head`` are head counts of 64 dimensions each,
    giving widths ``input_head * 64`` and ``inner_head * 64``.

    The estimate is the sum of:
      * Q/K/V projections (three identical terms);
      * attention: ``inner_head * L^2 * 128 + inner_head * L^3 * 8192``;
      * output re-projection;
      * FFN up and down projections (4x inner / 4x input width);
      * four residual additions of ``L * input width`` each.

    NOTE(review): the L^3 attention term and the ``(2 * a - 1) * b`` role
    assignment of the two widths in the projections differ from the usual
    dense matmul count ``L * out * (2 * in - 1)``.  Reproduced as-is --
    confirm this is the intended cost model.
    """
    width_in = input_head * 64
    width_inner = inner_head * 64

    def projection(fan_a, fan_b):
        # Per-projection cost as counted by this estimator over all tokens.
        return length * (2 * fan_a - 1) * fan_b

    qkv = 3 * projection(width_inner, width_in)
    attention = (inner_head * length * length * (64 * 2)
                 + inner_head * length * length * length * (64 * 64 + 64 * 64))
    reproject = projection(width_in, width_inner)
    feed_forward = (projection(4 * width_inner, width_in)
                    + projection(4 * width_in, width_inner))
    residuals = 4 * length * width_in

    return qkv + attention + reproject + feed_forward + residuals

def decoder_flops_layer(input_head, inner_head, length_src, length_tar):
    """Estimate the FLOPs of one decoder layer.

    ``input_head`` and ``inner_head`` are head counts of 64 dimensions each,
    giving widths ``input_head * 64`` and ``inner_head * 64``.  Projections
    are counted over ``length_tar`` tokens; the cross-attention score terms
    additionally scale with ``length_src``.

    The estimate is self-attention + cross-attention + FFN, where each
    attention block is Q/K/V projections + score terms + output
    re-projection + one residual, and the FFN is up/down projections + one
    residual.

    NOTE(review): as in the encoder estimate, the attention score terms are
    cubic in sequence length and the projection width roles differ from the
    usual dense matmul count.  Reproduced as-is -- confirm intent.
    """
    width_in = input_head * 64
    width_inner = inner_head * 64

    def projection(fan_a, fan_b):
        # Per-projection cost as counted by this estimator over all target tokens.
        return length_tar * (2 * fan_a - 1) * fan_b

    qkv = 3 * projection(width_inner, width_in)
    reproject = projection(width_in, width_inner)
    residual = length_tar * width_in

    self_scores = (inner_head * length_tar * length_tar * (64 * 2)
                   + inner_head * length_tar * length_tar * length_tar * (64 * 64 + 64 * 64))
    cross_scores = (inner_head * length_src * length_tar * (64 * 2)
                    + inner_head * length_src * length_src * length_tar * (64 * 64 + 64 * 64))

    self_attention = qkv + self_scores + reproject + residual
    cross_attention = qkv + cross_scores + reproject + residual
    feed_forward = (projection(4 * width_inner, width_in)
                    + projection(4 * width_in, width_inner)
                    + residual)

    return self_attention + cross_attention + feed_forward



def main(args):
    """Validate generation arguments, pick the output stream, and run ``_main``.

    When ``args.results_path`` is set, the directory is created if needed
    and output goes to ``<results_path>/generate-<gen_subset>.txt`` opened
    line-buffered; otherwise output goes to ``sys.stdout``.

    Raises:
        AssertionError: if ``--path`` is missing, ``--sampling`` is used with
            ``--nbest != --beam``, or ``--replace-unk`` is used without a raw
            dataset.

    Returns:
        Whatever ``_main`` returns.
    """
    assert args.path is not None, '--path required for generation!'
    assert not args.sampling or args.nbest == args.beam, \
        '--sampling requires --nbest to be equal to --beam'
    assert args.replace_unk is None or args.dataset_impl == 'raw', \
        '--replace-unk requires a raw text dataset (--dataset-impl=raw)'

    results_dir = args.results_path
    if results_dir is None:
        return _main(args, sys.stdout)

    os.makedirs(results_dir, exist_ok=True)
    out_name = 'generate-{}.txt'.format(args.gen_subset)
    with open(os.path.join(results_dir, out_name), 'w', buffering=1) as out_file:
        return _main(args, out_file)


def _candidate_architectures(large=True):
    """Return the list of candidate architecture encodings to evaluate.

    Each encoding is a list of 13 ints.  In ``_main`` it is attached to every
    batch as ``sample['index']`` and passed to ``flops_calculator``.

    Args:
        large: if True, return the single hand-picked encoding; otherwise
            return every encoding whose first entry is 12 and whose other 12
            entries are each drawn from ``[12, 11, 10]`` (3**12 = 531441
            candidates).
    """
    if large:
        return [[12, 12, 12, 12, 12, 11, 12, 12, 12, 12, 12, 12, 10]]
    import itertools
    cell_list = [12, 11, 10]
    # product() varies its last position fastest, which reproduces the order
    # of an explicit 12-deep nested comprehension over cell_list.
    return [[12] + list(cells) for cells in itertools.product(cell_list, repeat=12)]


def _main(args, output_file):
    """Generate translations for each candidate architecture and score them.

    For every candidate encoding, the task, dataset and checkpoint ensemble
    are loaded, every batch of ``args.gen_subset`` is decoded with the
    encoding attached as ``sample['index']``, and one line
    ``"<encoding> <BLEU result string> <FLOPs / 1e9, 2 decimals>"`` is printed
    to stdout.  Unless ``args.quiet`` is set, S/T/H/P (and optionally A/I/E)
    lines are written to ``output_file``.  Logging is configured to write to
    ``output_file`` as well.

    Args:
        args: parsed fairseq generation arguments.
        output_file: writable text stream for logs and per-sentence output.

    Returns:
        The BLEU scorer of the last candidate evaluated, or None if no
        candidate was evaluated.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairseq_cli.generate')

    utils.import_user_module(args)

    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    logger.info('loading model(s) from {}'.format(args.path))

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary).
    # It depends only on args, so load it once rather than once per candidate.
    align_dict = utils.load_align_dict(args.replace_unk)

    # True evaluates the single hand-picked candidate; False sweeps 3**12 candidates.
    large = True

    # Bound up front so the summary below is well-defined even if no candidate runs.
    scorer = None
    gen_timer = StopwatchMeter()
    num_sentences = 0
    has_target = True

    for index in _candidate_architectures(large):
        flops = flops_calculator(index, 20)

        # NOTE(review): task, dataset and checkpoint are reloaded for every
        # candidate.  Hoisting them out of the loop would be far faster for
        # large sweeps, but it is unclear from here whether inference with a
        # given sample['index'] leaves state behind in the model -- confirm
        # before hoisting.
        task = tasks.setup_task(args)
        try:
            src_dict = getattr(task, 'source_dictionary', None)
        except NotImplementedError:
            src_dict = None
        task.load_dataset(args.gen_subset)
        tgt_dict = task.target_dictionary

        # SECURITY: eval() of a command-line string; only safe with trusted input.
        models, _model_args = checkpoint_utils.load_model_ensemble(
            args.path.split(os.pathsep),
            arg_overrides=eval(args.model_overrides),
            task=task,
        )

        # Optimize ensemble for generation
        for model in models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
                need_attn=args.print_alignment,
            )
            if args.fp16:
                model.half()
            if use_cuda:
                model.cuda()

        # Load dataset (possibly sharded)
        itr = task.get_batch_iterator(
            dataset=task.dataset(args.gen_subset),
            max_tokens=args.max_tokens,
            max_sentences=args.max_sentences,
            max_positions=utils.resolve_max_positions(
                task.max_positions(),
                *[model.max_positions() for model in models]
            ),
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            num_shards=args.num_shards,
            shard_id=args.shard_id,
            num_workers=args.num_workers,
        ).next_epoch_itr(shuffle=False)

        # Initialize generator
        gen_timer = StopwatchMeter()
        generator = task.build_generator(args)

        # Generate and compute BLEU score
        if args.sacrebleu:
            scorer = bleu.SacrebleuScorer()
        else:
            scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
        num_sentences = 0
        has_target = True

        with progress_bar.build_progress_bar(args, itr) as t:
            wps_meter = TimeMeter()
            for sample in t:
                sample = utils.move_to_cuda(sample) if use_cuda else sample
                if 'net_input' not in sample:
                    continue

                prefix_tokens = None
                if args.prefix_size > 0:
                    prefix_tokens = sample['target'][:, :args.prefix_size]

                gen_timer.start()
                # The candidate encoding travels with the batch; presumably the
                # task/model reads it to select the sub-architecture -- confirm.
                sample['index'] = index
                hypos = task.inference_step(generator, models, sample, prefix_tokens)
                num_generated_tokens = sum(len(h[0]['tokens']) for h in hypos)
                gen_timer.stop(num_generated_tokens)

                for i, sample_id in enumerate(sample['id'].tolist()):
                    has_target = sample['target'] is not None

                    # Remove padding
                    src_tokens = utils.strip_pad(sample['net_input']['src_tokens'][i, :], tgt_dict.pad())
                    target_tokens = None
                    if has_target:
                        target_tokens = utils.strip_pad(sample['target'][i, :], tgt_dict.pad()).int().cpu()

                    # Either retrieve the original sentences or regenerate them from tokens.
                    if align_dict is not None:
                        src_str = task.dataset(args.gen_subset).src.get_original_text(sample_id)
                        target_str = task.dataset(args.gen_subset).tgt.get_original_text(sample_id)
                    else:
                        if src_dict is not None:
                            src_str = src_dict.string(src_tokens, args.remove_bpe)
                        else:
                            src_str = ""
                        if has_target:
                            target_str = tgt_dict.string(target_tokens, args.remove_bpe, escape_unk=True)

                    if not args.quiet:
                        if src_dict is not None:
                            print('S-{}\t{}'.format(sample_id, src_str), file=output_file)
                        if has_target:
                            print('T-{}\t{}'.format(sample_id, target_str), file=output_file)

                    # Process top predictions
                    for j, hypo in enumerate(hypos[i][:args.nbest]):
                        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                            hypo_tokens=hypo['tokens'].int().cpu(),
                            src_str=src_str,
                            alignment=hypo['alignment'],
                            align_dict=align_dict,
                            tgt_dict=tgt_dict,
                            remove_bpe=args.remove_bpe,
                        )

                        if not args.quiet:
                            score = hypo['score'] / math.log(2)  # convert to base 2
                            print('H-{}\t{}\t{}'.format(sample_id, score, hypo_str), file=output_file)
                            print('P-{}\t{}'.format(
                                sample_id,
                                ' '.join(map(
                                    lambda x: '{:.4f}'.format(x),
                                    # convert from base e to base 2 (in place)
                                    hypo['positional_scores'].div_(math.log(2)).tolist(),
                                ))
                            ), file=output_file)

                            if args.print_alignment:
                                print('A-{}\t{}'.format(
                                    sample_id,
                                    ' '.join(['{}-{}'.format(src_idx, tgt_idx) for src_idx, tgt_idx in alignment])
                                ), file=output_file)

                            if args.print_step:
                                print('I-{}\t{}'.format(sample_id, hypo['steps']), file=output_file)

                            if getattr(args, 'retain_iter_history', False):
                                for step, h in enumerate(hypo['history']):
                                    _, h_str, _ = utils.post_process_prediction(
                                        hypo_tokens=h['tokens'].int().cpu(),
                                        src_str=src_str,
                                        alignment=None,
                                        align_dict=None,
                                        tgt_dict=tgt_dict,
                                        remove_bpe=None,
                                    )
                                    print('E-{}_{}\t{}'.format(sample_id, step, h_str), file=output_file)

                        # Score only the top hypothesis
                        if has_target and j == 0:
                            if align_dict is not None or args.remove_bpe is not None:
                                # Convert back to tokens for evaluation with unk replacement and/or without BPE
                                target_tokens = tgt_dict.encode_line(target_str, add_if_not_exist=True)
                            if hasattr(scorer, 'add_string'):
                                scorer.add_string(target_str, hypo_str)
                            else:
                                scorer.add(target_tokens, hypo_tokens)

                wps_meter.update(num_generated_tokens)
                t.log({'wps': round(wps_meter.avg)})
                num_sentences += sample['nsentences']

        print(str(index) + ' ' + scorer.result_string() + ' ' + str(round(flops / 10.**9, 2)))
        # Drop references before loading the next candidate's checkpoint.
        del models, _model_args, task

    logger.info('NOTE: hypothesis and token scores are output in base 2')
    # The counters are reset per candidate, so this summary covers the last
    # candidate only.  Rates are skipped when nothing was generated: both
    # gen_timer.sum and (presumably sum / n) gen_timer.avg would divide by zero.
    if gen_timer.n > 0 and gen_timer.sum > 0:
        logger.info('Translated {} sentences ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)'.format(
            num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. / gen_timer.avg))
    else:
        logger.info('Translated {} sentences ({} tokens) in {:.1f}s'.format(
            num_sentences, gen_timer.n, gen_timer.sum))
    if has_target and scorer is not None:
        logger.info('Generate {} with beam={}: {}'.format(args.gen_subset, args.beam, scorer.result_string()))

    return scorer


def cli_main():
    """Build the fairseq generation parser, parse argv, and run ``main``."""
    generation_parser = options.get_generation_parser()
    main(options.parse_args_and_arch(generation_parser))


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script.
    cli_main()
