#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""
import checkpoint_utils_append
import logging
import math
import os
import random
import sys

import numpy as np
import torch
import Append_tranformer
from fairseq.data import iterators
from fairseq import (
    checkpoint_utils, distributed_utils, metrics, options, progress_bar, tasks, utils
)

from fairseq.trainer import Trainer
from fairseq.meters import StopwatchMeter


_LOG_FORMAT = '%(asctime)s | %(levelname)s | %(name)s | %(message)s'
_LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

# Route log records to stdout at INFO level with timestamped, pipe-separated fields.
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
    stream=sys.stdout,
)
logger = logging.getLogger('fairseq_cli.train')


def save_Append_(args, trainer, epoch_itr, valid_losses):
    """Save a checkpoint and then re-attach the Append_tranformer extensions.

    ``Append_tranformer.Del_append`` is called on the model first; its result
    is handed to ``checkpoint_utils_append.save_checkpoint`` together with
    ``valid_losses[0]`` and afterwards its first four items are passed back to
    ``Append_tranformer.add_parameters``. Presumably ``Del_append`` detaches
    the appended modules so the base checkpoint can be written (TODO confirm).
    """
    detached = Append_tranformer.Del_append(trainer.get_model())
    checkpoint_utils_append.save_checkpoint(
        args, trainer, epoch_itr, valid_losses[0], detached
    )

    # Re-install the append-aware forward passes and the retrain hook.
    trainer.get_model()._is_generation_fast = False
    forward_hooks = dict(
        Tran_for=Append_tranformer.Append_transformer_forward,
        en_for=Append_tranformer.Append_encoder_forward,
        de_for=Append_tranformer.Append_decoder_forward,
        de_ext=Append_tranformer.Append_extract_features,
    )
    Append_tranformer.Append_TransformerModel(trainer.get_model(), **forward_hooks)
    Append_tranformer.Append_retrain(trainer.get_model(),
                                     Re_train=Append_tranformer.Re_train)

    # Restore the appended parameters that were detached before saving.
    Append_tranformer.add_parameters(
        trainer.get_model(),
        detached[0], detached[1], detached[2], detached[3],
    )






def _model_device(model):
    """Return the device of the second entry of ``model.state_dict()``.

    Index 1 (not 0) is probed to match the lookup this file always used.
    Raises ``RuntimeError`` if the state dict has fewer than two entries.
    """
    state = model.state_dict()
    for position, key in enumerate(state):
        if position == 1:
            return state[key].device
    raise RuntimeError(
        'cannot infer model device: state_dict has fewer than two entries')


def _install_append_hooks(model):
    """Clear ``_is_generation_fast`` and install the Append_tranformer hooks on ``model``."""
    model._is_generation_fast = False
    Append_tranformer.Append_TransformerModel(
        model,
        Tran_for=Append_tranformer.Append_transformer_forward,
        en_for=Append_tranformer.Append_encoder_forward,
        de_for=Append_tranformer.Append_decoder_forward,
        de_ext=Append_tranformer.Append_extract_features,
    )
    Append_tranformer.Append_retrain(model, Re_train=Append_tranformer.Re_train)


def main(args, init_distributed=False):
    """Train a model with the Append_tranformer extensions attached.

    Sets up task, model, criterion and trainer; restores state either from
    ``args.cur_cpt`` (``checkpoint_best.pt`` plus ``Append_best.pt``) or via
    ``checkpoint_utils_append.load_checkpoint`` with freshly initialised
    append module lists; rewires the append connections at random; then
    trains until the learning rate drops to ``args.min_lr`` or the epoch /
    update limit is hit.

    Args:
        args: parsed fairseq arguments, extended by ``cli_main`` with
            ``modelO``, ``cur_cpt``, ``normal``, ``change_normal_update``
            and ``early_epoch``.
        init_distributed: initialise the distributed process group first.

    Returns:
        ``checkpoint_utils_append.save_checkpoint.best`` (presumably the best
        ``args.best_checkpoint_metric`` value seen by the checkpoint writer),
        or ``None`` when that attribute was never set.
    """
    utils.import_user_module(args)

    assert args.max_tokens is not None or args.max_sentences is not None, \
        'Must specify batch size either with --max-tokens or --max-sentences'

    # Initialize CUDA and distributed training
    if torch.cuda.is_available() and not args.cpu:
        torch.cuda.set_device(args.device_id)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if init_distributed:
        args.distributed_rank = distributed_utils.distributed_init(args)

    if distributed_utils.is_master(args):
        checkpoint_utils_append.verify_checkpoint_directory(args.save_dir)

    # Print args
    logger.info(args)

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(args)

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    for valid_sub_split in args.valid_subset.split(','):
        task.load_dataset(valid_sub_split, combine=False, epoch=0)

    # Build model and criterion. A model is always built; when one was
    # injected via ``args.modelO`` it replaces the built one (the build is
    # presumably kept for its side effects on the task — TODO confirm).
    model = task.build_model(args)
    if args.modelO is not None:
        model = args.modelO

    criterion = task.build_criterion(args)
    logger.info('model {}, criterion {}'.format(args.arch, criterion.__class__.__name__))
    logger.info('num. model params: {} (num. trained: {})'.format(
        sum(p.numel() for p in model.parameters()),
        sum(p.numel() for p in model.parameters() if p.requires_grad),
    ))

    # Build trainer
    trainer = Trainer(args, task, model, criterion)
    logger.info('training on {} GPUs'.format(args.distributed_world_size))
    logger.info('max tokens per GPU = {} and max sentences per GPU = {}'.format(
        args.max_tokens,
        args.max_sentences,
    ))

    # Optimizer state is never restored; presumably because the appended
    # modules change the parameter set (TODO confirm).
    args.reset_optimizer = True
    if args.cur_cpt is not None:
        # Resume from a previous run: base weights from checkpoint_best.pt and
        # the append structures saved alongside it in Append_best.pt.
        args.restore_file = os.path.join(args.cur_cpt, "checkpoint_best.pt")
        device = _model_device(trainer.get_model())
        append_state = torch.load(os.path.join(args.cur_cpt, "Append_best.pt"))
        _, epoch_itr = checkpoint_utils_append.load_checkpoint(args, trainer)
        _install_append_hooks(trainer.get_model())
        Append_tranformer.add_parameters(
            trainer.get_model(),
            append_state[0].to(device), append_state[1],
            append_state[2].to(device), append_state[3],
        )
    else:
        _, epoch_itr = checkpoint_utils_append.load_checkpoint(args, trainer)
        _install_append_hooks(trainer.get_model())
        device = _model_device(trainer.get_model())
        # Fresh append structures sized to the current encoder/decoder depth.
        enc_layers, enc_links = Append_tranformer.ini_model_list(
            len(trainer.get_model().encoder.layers), 1)
        dec_layers, dec_links = Append_tranformer.ini_model_list(
            len(trainer.get_model().decoder.layers), 2)
        Append_tranformer.add_parameters(
            trainer.get_model(), enc_layers.to(device), enc_links,
            dec_layers.to(device), dec_links)

    # Rewire the append connections at random before training starts.
    E_num = len(trainer.get_model().encoder.layers)
    D_num = len(trainer.get_model().decoder.layers)
    current = trainer.get_model()
    cur_enc_mods, cur_enc_links, cur_dec_mods, cur_dec_links = (
        current.encoder_modellist, current.encoder_addList,
        current.decoder_modellist, current.decoder_addList,
    )
    Enlayer_ModuleList, Enaddlink_addList, Delayer_ModuleList, Deaddlink_addList, _, _ = \
        Append_tranformer.new_connection_random(
            cur_enc_mods.to(device), cur_enc_links,
            cur_dec_mods.to(device), cur_dec_links, E_num, D_num,
            total_num=5)
    Append_tranformer.add_parameters(
        trainer.get_model(), Enlayer_ModuleList.to(device), Enaddlink_addList,
        Delayer_ModuleList.to(device), Deaddlink_addList)

    # Train until the learning rate gets too small
    max_epoch = args.max_epoch or math.inf
    max_update = args.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = StopwatchMeter()
    train_meter.start()
    valid_subsets = args.valid_subset.split(',')
    returned_to_normal = False
    while (
        lr > args.min_lr
        and (
            epoch_itr.epoch < max_epoch
            # allow resuming training from the final checkpoint
            or epoch_itr._next_epoch_itr is not None
        )
        and trainer.get_num_updates() < max_update
    ):
        # One-shot switch: once the epoch passes args.normal or the update
        # count passes args.change_normal_update, call return_normal on the
        # appended module lists.
        if not returned_to_normal and (
            epoch_itr.epoch > args.normal
            or trainer.get_num_updates() > args.change_normal_update
        ):
            Append_tranformer.return_normal(trainer.get_model().encoder_modellist)
            Append_tranformer.return_normal(trainer.get_model().decoder_modellist)
            returned_to_normal = True

        # train for one epoch
        train(args, trainer, task, epoch_itr)

        if not args.disable_validation and epoch_itr.epoch % args.validate_interval == 0:
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
        else:
            valid_losses = [None]

        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        # save checkpoint
        if epoch_itr.epoch % args.save_interval == 0:
            save_Append_(args, trainer, epoch_itr, valid_losses)

        # early stop
        if should_stop_early(args, valid_losses[0]):
            logger.info('early stop since valid performance hasn\'t improved for last {} runs'.format(args.patience))
            break

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.epoch,
            # sharded data: get train iterator for next epoch
            load_dataset=(os.pathsep in getattr(args, 'data', '')),
        )
    train_meter.stop()
    logger.info('done training in {:.1f} seconds'.format(train_meter.sum))
    # None when no checkpoint ever recorded a best value.
    return getattr(checkpoint_utils_append.save_checkpoint, 'best', None)


def should_stop_early(args, valid_loss):
    """Return True once validation has not improved for more than ``args.patience`` runs.

    The running best loss and the no-improvement counter are stored as
    attributes (``best``, ``num_runs``) on this function, so they persist
    across calls in one process. Early stopping is disabled when
    ``args.patience <= 0``. A ``None`` loss (validation skipped this epoch)
    is ignored instead of being compared against the numeric best, which
    would raise ``TypeError``.
    """
    if args.patience <= 0:
        return False
    # skip the check if no validation was done in the current epoch
    if valid_loss is None:
        return False

    def is_better(a, b):
        return a > b if args.maximize_best_checkpoint_metric else a < b

    prev_best = getattr(should_stop_early, 'best', None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    should_stop_early.num_runs += 1
    return should_stop_early.num_runs > args.patience


@metrics.aggregate('train')
def train(args, trainer, task, epoch_itr):
    """Train the model for one epoch.

    ``args.early_epoch``, when not ``None``, caps the number of (grouped)
    batches consumed this epoch — despite its name it counts batches, not
    epochs. When validation is enabled, validation plus ``save_Append_`` run
    every ``args.save_interval_updates`` updates. Training stops mid-epoch
    once ``args.max_update`` is reached.
    """
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=args.fix_batches_to_gpus,
        shuffle=(epoch_itr.epoch >= args.curriculum),
    )
    update_freq = (
        args.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(args.update_freq)
        else args.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.build_progress_bar(
        args, itr, epoch_itr.epoch, no_progress_bar='simple',
    )

    # task specific setup per epoch
    task.begin_epoch(epoch_itr.epoch, trainer.get_model())

    valid_subsets = args.valid_subset.split(',')
    max_update = args.max_update or math.inf
    # Initialise before the loop so the end-of-epoch log has a valid step even
    # when no batch is processed (empty epoch or early_epoch == 0).
    num_updates = trainer.get_num_updates()
    for batch_idx, samples in enumerate(progress, start=1):
        if args.early_epoch is not None and batch_idx > args.early_epoch:
            break

        log_output = trainer.train_step(samples)
        num_updates = trainer.get_num_updates()
        if log_output is None:
            continue

        # log mid-epoch stats
        stats = get_training_stats(metrics.get_smoothed_values('train'))
        progress.log(stats, tag='train', step=num_updates)

        if (
            not args.disable_validation
            and args.save_interval_updates > 0
            and num_updates % args.save_interval_updates == 0
            and num_updates > 0
        ):
            valid_losses = validate(args, trainer, task, epoch_itr, valid_subsets)
            save_Append_(args, trainer, epoch_itr, valid_losses)

        if num_updates >= max_update:
            break

    # log end-of-epoch stats
    stats = get_training_stats(metrics.get_smoothed_values('train'))
    progress.print(stats, tag='train', step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters('train')


def get_training_stats(stats):
    """Add perplexity and wall-clock time to a dict of smoothed training stats.

    ``ppl`` is derived from ``nll_loss`` only when ``nll_loss`` is present and
    ``ppl`` is not already set. ``wall`` is the ``default``/``wall`` meter's
    elapsed time rounded to 0 decimals. ``stats`` is updated in place and
    returned.
    """
    needs_ppl = 'ppl' not in stats and 'nll_loss' in stats
    if needs_ppl:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    wall_meter = metrics.get_meter('default', 'wall')
    stats['wall'] = round(wall_meter.elapsed_time, 0)
    return stats


def return_best(args, trainer, epoch_itr, val_loss):
    """Return ``return_best.best`` if that attribute is set, else ``val_loss``.

    Nothing in this file assigns ``return_best.best``, so as written this
    returns ``val_loss`` unchanged unless a caller sets the attribute.
    ``args``, ``trainer`` and ``epoch_itr`` are accepted for interface
    compatibility but are not used.
    """
    return getattr(return_best, "best", val_loss)


def validate(args, trainer, task, epoch_itr, subsets):
    """Evaluate the model on each validation subset and return the losses.

    For every name in ``subsets`` a non-shuffled, sharded batch iterator is
    built, ``trainer.valid_step`` is run over it inside a fresh root metrics
    aggregator, and the resulting stats are printed. Returns a list holding
    ``stats[args.best_checkpoint_metric]`` for each subset, in order.
    """
    if args.fixed_validation_seed is not None:
        # Reseed so every validation pass uses the same random state.
        utils.set_torch_seed(args.fixed_validation_seed)

    losses = []
    for subset_name in subsets:
        dataset = task.dataset(subset_name)
        max_positions = utils.resolve_max_positions(
            task.max_positions(),
            trainer.get_model().max_positions(),
        )
        batch_iterator = task.get_batch_iterator(
            dataset=dataset,
            max_tokens=args.max_tokens_valid,
            max_sentences=args.max_sentences_valid,
            max_positions=max_positions,
            ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
            required_batch_size_multiple=args.required_batch_size_multiple,
            seed=args.seed,
            num_shards=args.distributed_world_size,
            shard_id=args.distributed_rank,
            num_workers=args.num_workers,
        )
        epoch_batches = batch_iterator.next_epoch_itr(shuffle=False)
        progress = progress_bar.build_progress_bar(
            args, epoch_batches, epoch_itr.epoch,
            prefix="valid on '{}' subset".format(subset_name),
            no_progress_bar='simple',
        )

        # A new root aggregator keeps validation metrics out of the train meters.
        with metrics.aggregate(new_root=True) as agg:
            for batch in progress:
                trainer.valid_step(batch)

        subset_stats = get_valid_stats(args, trainer, agg.get_smoothed_values())
        progress.print(subset_stats, tag=subset_name, step=trainer.get_num_updates())
        losses.append(subset_stats[args.best_checkpoint_metric])
    return losses


def get_valid_stats(args, trainer, stats):
    """Add perplexity, update count and running best metric to validation stats.

    ``ppl`` is derived from ``nll_loss`` when missing. When
    ``checkpoint_utils_append.save_checkpoint`` carries a ``best`` attribute,
    ``stats['best_<metric>']`` is set to the better (max or min, per
    ``args.maximize_best_checkpoint_metric``) of that value and the current
    metric. ``stats`` is updated in place and returned.
    """
    if 'nll_loss' in stats and 'ppl' not in stats:
        stats['ppl'] = utils.get_perplexity(stats['nll_loss'])
    stats['num_updates'] = trainer.get_num_updates()

    saver = checkpoint_utils_append.save_checkpoint
    if not hasattr(saver, 'best'):
        return stats

    metric = args.best_checkpoint_metric
    pick = max if args.maximize_best_checkpoint_metric else min
    stats['best_{0}'.format(metric)] = pick(saver.best, stats[metric])
    return stats


def distributed_main(i, args, start_rank=0):
    """Run ``main`` in one distributed worker bound to device ``i``.

    When ``args.distributed_rank`` is unset (the ``torch.multiprocessing.spawn``
    path), the worker's rank becomes ``start_rank + i``.
    """
    args.device_id = i
    rank_unset = args.distributed_rank is None
    if rank_unset:
        args.distributed_rank = start_rank + i
    main(args, init_distributed=True)


def cli_main(modify_parser=None, modelO=None, extra_stateO=None, epoch_itrO=None,
             early_epoch=None, save_new=None, data_dir=None, normal_up=30000,
             cur_cpt=None):
    """Parse the command line, apply Append-specific overrides and launch training.

    Args:
        modify_parser: forwarded to ``options.parse_args_and_arch``.
        modelO: optional pre-built model, stored as ``args.modelO``.
        extra_stateO, epoch_itrO: stored on ``args``; not read elsewhere in
            this file.
        early_epoch: stored as ``args.early_epoch`` (per-epoch batch cap used
            by ``train``).
        save_new, data_dir: accepted but unused.
        normal_up: stored as ``args.change_normal_update``.
        cur_cpt: checkpoint directory to resume from, stored as ``args.cur_cpt``.

    Returns:
        ``main``'s result for single-process training; ``None`` when training
        runs through one of the distributed launch paths.
    """
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)

    # Hard-coded overrides applied on top of the parsed command line.
    overrides = {
        'eval_bleu_args': '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}',
        'modelO': modelO,
        'extra_stateO': extra_stateO,
        'epoch_itrO': epoch_itrO,
        'early_epoch': early_epoch,
        'validate_interval': 1,
        'normal': args.max_epoch,
        'change_normal_update': normal_up,
        'cur_cpt': cur_cpt,
        'replace_unk': True,
    }
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)

    if args.distributed_init_method is None:
        distributed_utils.infer_init_method(args)

    if args.distributed_init_method is not None:
        # distributed training
        if torch.cuda.device_count() > 1 and not args.distributed_no_spawn:
            start_rank = args.distributed_rank
            args.distributed_rank = None  # assign automatically
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(args, start_rank),
                nprocs=torch.cuda.device_count(),
            )
        else:
            distributed_main(args.device_id, args)
        return None

    if args.distributed_world_size > 1:
        # fallback for single node with multiple GPUs
        assert args.distributed_world_size <= torch.cuda.device_count()
        port = random.randint(10000, 20000)
        args.distributed_init_method = 'tcp://localhost:{port}'.format(port=port)
        args.distributed_rank = None  # set based on device id
        if max(args.update_freq) > 1 and args.ddp_backend != 'no_c10d':
            logger.info('NOTE: you may get faster training with: --ddp-backend=no_c10d')
        torch.multiprocessing.spawn(
            fn=distributed_main,
            args=(args, ),
            nprocs=args.distributed_world_size,
        )
        return None

    # single GPU training
    return main(args)


if __name__ == '__main__':
    # Script entry point: parse the command line and start training.
    cli_main()

