import os
import copy
import logging
import re
import torch
import numpy as np
import codecs
from tqdm import tqdm
from collections import OrderedDict
from torch.utils.data import DataLoader
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler
from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.trainers.context import lifecycle, CtxVar
from federatedscope.core.trainers.enums import LIFECYCLE, MODE
from federatedscope.core.trainers.utils import filter_by_specified_keywords
from federatedscope.core.monitors import MetricCalculator
from federatedscope.core.monitors.metric_calculator import eval_acc
from federatedscope.nlp.hetero_tasks.trainer.utils import AverageMeter, \
    ContrastiveMonitor
from federatedscope.nlp.hetero_tasks.dataset.utils import setup_tokenizer
from federatedscope.nlp.hetero_tasks.dataset.squad import SquadResult
from federatedscope.nlp.hetero_tasks.dataset.newsqa import NewsQAResult

logger = logging.getLogger(__name__)


class ATCTrainer(GeneralTorchTrainer):
    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super().__init__(model, data, device, config, only_for_eval, monitor)
        self.metric_calculator = MetricCalculator(config.eval.metrics)
        self.task = config.model.task
        self.ID = None
        self.load_ckpt = True
        self.pred_file, self.src_file, self.tgt_file = None, None, None
        self.finish_eval = False
        self.ctx.eval_metrics = None
        self.ctx.tokenizer = setup_tokenizer(config.model.model_type)
        self.ctx.grad_accum_count = config.grad.grad_accum_count
        self.ctx.padding_idx = self.ctx.tokenizer.pad_token_id
        self.ctx.init_params = copy.deepcopy(model.state_dict())
        self.pretrain_task = None
        self.use_contrastive_loss = config.model.use_contrastive_loss
        self.ctx.contrast_monitor = ContrastiveMonitor() if \
            self.use_contrastive_loss else None

    def update(self, model_parameters, strict=False):
        super().update(model_parameters, strict=strict)
        self.ctx.init_params = copy.deepcopy(self.ctx.model.state_dict())

    def update_pretrain_task(self, task):
        self.pretrain_task = task

    def update_stat(self, ID):
        self.ID = ID
        if self.task in {'cnndm', 'msqg'}:
            pred_dir = os.path.join(self.cfg.outdir, 'pred')
            src_dir = os.path.join(self.cfg.outdir, 'src')
            tgt_dir = os.path.join(self.cfg.outdir, 'tgt')
            self.ctx.pred_path = os.path.join(pred_dir, '%d.txt' % ID)
            self.ctx.src_path = os.path.join(src_dir, '%d.txt' % ID)
            self.ctx.tgt_path = os.path.join(tgt_dir, '%d.txt' % ID)

            os.makedirs(pred_dir, exist_ok=True)
            os.makedirs(src_dir, exist_ok=True)
            os.makedirs(tgt_dir, exist_ok=True)
            self.pred_file = codecs.open(self.ctx.pred_path, 'w', 'utf-8')
            self.src_file = codecs.open(self.ctx.src_path, 'w', 'utf-8')
            self.tgt_file = codecs.open(self.ctx.tgt_path, 'w', 'utf-8')

        self.ctx.model.update_client_id(ID)

    def update_contrast_monitor(self, contrast_monitor):
        self.ctx.contrast_monitor = contrast_monitor

    def get_model_grads(self, filter_keywords=None):
        if filter_keywords is None:
            filter_keywords = self.ctx.cfg.personalization.local_param
        grads = {}
        for n, p2 in self.ctx.model.state_dict().items():
            if filter_by_specified_keywords(n, filter_keywords):  # preserve
                grads[n] = p2 - self.ctx.init_params[n]
        return grads

    def parse_data(self, data):
        init_dict = dict()
        if isinstance(data, dict):
            all_split = ['train', 'val', 'test'] if not \
                self.cfg.model.use_contrastive_loss else \
                ['train_raw', 'train_contrast', 'val', 'test']
            for split in all_split:
                init_dict['{}_data'.format(split)] = None
                init_dict['{}_loader'.format(split)] = None
                init_dict['num_{}_data'.format(split)] = 0
                init_dict['{}_encoded'.format(split)] = None
                init_dict['{}_examples'.format(split)] = None
                if data.get(split, None) is not None:
                    if isinstance(data.get(split)['dataloader'], DataLoader):
                        init_dict['{}_loader'.format(split)] = \
                            data.get(split)['dataloader']
                        init_dict['num_{}_data'.format(split)] = \
                            len(data.get(split)['dataloader'].dataset)
                        init_dict['{}_encoded'.format(split)] = \
                            data.get(split)['encoded']
                        init_dict['{}_examples'.format(split)] = \
                            data.get(split)['examples']

                        if self.cfg.model.use_contrastive_loss and \
                                split == 'train_raw':
                            init_dict['train_data'] = None
                            init_dict['train_loader'] = \
                                data.get(split)['dataloader']
                            init_dict['num_train_data'] = \
                                len(data.get(split)['dataloader'].dataset)
                            init_dict['train_encoded'] = \
                                data.get(split)['encoded']
                            init_dict['train_examples'] = \
                                data.get(split)['examples']
                    else:
                        raise TypeError('Type {} is not supported.'.format(
                            type(data.get(split))))
        else:
            raise TypeError('Type of data should be dict.')

        return init_dict

    def setup_optimizer_and_scheduler(self, ctx):
        total_steps = getattr(ctx, f'num_total_{ctx.cur_mode}_batch',
                              None) // ctx.cfg.grad.grad_accum_count * \
                      ctx.cfg.federate.total_round_num
        warmup_steps = int(ctx.cfg[ctx.cur_mode].scheduler.warmup_ratio *
                           total_steps)
        optimizer = get_optimizer(ctx.model, **ctx.cfg[ctx.cur_mode].optimizer)
        scheduler = get_scheduler(optimizer,
                                  **ctx.cfg.train.scheduler,
                                  total_steps=total_steps,
                                  warmup_steps=warmup_steps)

        return optimizer, scheduler

    def _load_model(self, ctx):
        load_path = ctx.cfg.federate.atc_load_from
        global_ckpt_path = os.path.join(load_path, 'global_model.pt')
        client_ckpt_path = \
            os.path.join(load_path, 'client', 'client_model_{}.pt'.format(
                self.ID))
        if not os.path.exists(global_ckpt_path):
            global_dir = os.path.join(load_path, 'global')
            global_ckpt_path = \
                os.path.join(global_dir, 'global_model_{}.pt'.format(self.ID))
            if not os.path.exists(global_ckpt_path):
                raise RuntimeError(
                    'Checkpoint NOT found in \'{}\''.format(global_ckpt_path))

        model_ckpt = ctx.model.state_dict()
        logger.info('Loading model from \'{}\''.format(global_ckpt_path))
        global_ckpt = torch.load(global_ckpt_path, map_location='cpu')['model']
        model_ckpt.update({
            k: v
            for k, v in global_ckpt.items()
            if k in model_ckpt and v.size() == model_ckpt[k].size()
        })
        if os.path.exists(client_ckpt_path):
            logger.info('Updating model from \'{}\''.format(client_ckpt_path))
            client_ckpt = torch.load(client_ckpt_path,
                                     map_location='cpu')['model']
            model_ckpt.update({
                k: v
                for k, v in client_ckpt.items()
                if k in model_ckpt and v.size() == model_ckpt[k].size()
            })
        ctx.model.load_state_dict(model_ckpt)

    def _save_model(self, ctx):
        if len(ctx.cfg.personalization.local_param) > 0:
            model_ckpt = OrderedDict({
                k: v
                for k, v in ctx.model.state_dict().items()
                if re.search('|'.join(ctx.cfg.personalization.local_param), k)
                is not None
            })
            ckpt = {
                'model': model_ckpt,
                'epoch': ctx.cur_epoch_i + 1,
                'batch': ctx.cur_batch_i + 1,
            }
            save_dir = os.path.join(ctx.cfg.federate.save_to, 'client')
            os.makedirs(save_dir, exist_ok=True)
            ckpt_path = os.path.join(save_dir,
                                     'client_model_{}.pt'.format(self.ID))
            torch.save(ckpt, ckpt_path)

    def _remove_special_tokens(self, sent):
        return sent.replace('[CLS]', '').replace('[SEP]', '').\
            replace('[PAD]', '').replace('[unused0]', '').\
            replace('[unused3]', '').replace('[unused1]', ''). \
            replace(r' +', ' ').replace(' [unused2] ', '<q>').\
            replace('[unused2]', '').strip()

    @property
    def _in_contrast_prepare(self):
        return self.use_contrastive_loss and \
               self.task != 'pretrain' and \
               self.ctx.cur_split == 'train' and \
               self.ctx.contrast_monitor.stat == 1

    def train(self, target_data_split_name='train', hooks_set=None):
        hooks_set = hooks_set or self.hooks_in_train

        self.ctx.check_split(target_data_split_name)

        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)

        if not self.use_contrastive_loss:
            return \
                num_samples, self.get_model_para(), self.get_model_grads(), \
                self.ctx.eval_metrics
        return \
            num_samples, self.get_model_para(), self.get_model_grads(), \
            self.ctx.contrast_monitor, self.ctx.eval_metrics

    @lifecycle(LIFECYCLE.ROUTINE)
    def _run_routine(self, mode, hooks_set, dataset_name=None):
        if self.finish_eval:
            return self.ctx.num_samples

        raw_num_train_epoch, raw_num_train_batch = None, None
        if self._in_contrast_prepare:
            raw_num_train_epoch, raw_num_train_batch = \
                self.ctx.num_train_epoch, self.ctx.num_train_batch
            batch_size = self.ctx.cfg.data.batch_size
            num_contrast_data = len(self.ctx.contrast_monitor.synth_tokens)
            self.ctx.num_train_epoch = 1
            self.ctx.num_train_batch = \
                num_contrast_data // batch_size + bool(num_contrast_data %
                                                       batch_size)
            self.ctx.num_train_batch_last_epoch = self.ctx.num_train_batch
            self.ctx.num_total_train_batch = \
                self.ctx.num_train_epoch * self.ctx.num_train_batch

        for hook in hooks_set['on_fit_start']:
            hook(self.ctx)

        self._run_epoch(hooks_set)

        for hook in hooks_set["on_fit_end"]:
            hook(self.ctx)

        if raw_num_train_epoch is not None and raw_num_train_batch is not None:
            self.ctx.num_train_epoch = raw_num_train_epoch
            self.ctx.num_train_batch = raw_num_train_batch
            self.ctx.num_train_batch_last_epoch = self.ctx.num_train_batch
            self.ctx.num_total_train_batch = \
                self.ctx.num_train_epoch * self.ctx.num_train_batch

        return self.ctx.num_samples

    @lifecycle(LIFECYCLE.BATCH)
    def _run_batch(self, hooks_set):
        for batch_i in tqdm(range(
                getattr(self.ctx, f"num_{self.ctx.cur_split}_batch", None)),
                            disable=not (self._in_contrast_prepare
                                         or self.ctx.cur_split == "test")):
            self.ctx.cur_batch_i = CtxVar(batch_i, LIFECYCLE.BATCH)

            for hook in hooks_set["on_batch_start"]:
                hook(self.ctx)

            for hook in hooks_set["on_batch_forward"]:
                hook(self.ctx)

            for hook in hooks_set["on_batch_backward"]:
                hook(self.ctx)

            for hook in hooks_set["on_batch_end"]:
                hook(self.ctx)

            # Break in the final epoch
            if self.ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE] and \
                self.ctx.cur_epoch_i == getattr(
                    self.ctx, f'num_{self.ctx.cur_mode}_epoch', None) - 1:
                if batch_i >= \
                        getattr(self.ctx,
                                f'num_{self.ctx.cur_mode}_batch_last_epoch',
                                None) - 1:
                    break

    def _hook_on_fit_start_init(self, ctx):
        ctx.model.to(ctx.device)

        if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]:
            ctx.optimizer = ctx.get(f'{ctx.cur_mode}_optimizer', None)
            ctx.scheduler = ctx.get(f'{ctx.cur_mode}_scheduler', None)
            if ctx.optimizer is None or ctx.scheduler is None:
                ctx.optimizer, ctx.scheduler = \
                    self.setup_optimizer_and_scheduler(ctx)
                setattr(ctx, f'{ctx.cur_mode}_optimizer', ctx.optimizer)
                setattr(ctx, f'{ctx.cur_mode}_scheduler', ctx.scheduler)
            if ctx.cfg.federate.atc_load_from and self.load_ckpt:
                self._load_model(ctx)
                self.load_ckpt = False

        if ctx.cur_split == 'train' and ctx.cfg.federate.atc_load_from \
                and self.load_ckpt:
            self._load_model(ctx)
            self.load_ckpt = False

        # prepare statistics
        ctx.loss_agg = CtxVar(AverageMeter(), LIFECYCLE.ROUTINE)
        ctx.loss_batch_total = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.loss_regular_total = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.accum_steps = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.ys_pred = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.squad_results = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.newsqa_results = CtxVar([], LIFECYCLE.ROUTINE)

        if self.use_contrastive_loss:
            if self._in_contrast_prepare:
                ctx.train_loader = ctx.train_contrast_loader
            else:
                ctx.regular_loss_agg = CtxVar(AverageMeter(),
                                              LIFECYCLE.ROUTINE)
                ctx.contrastive_loss_agg = CtxVar(AverageMeter(),
                                                  LIFECYCLE.ROUTINE)
                ctx.train_loader = ctx.train_raw_loader

    def _hook_on_batch_forward(self, ctx):
        if self.use_contrastive_loss:
            ctx.contrastive_loss_batch = CtxVar(None, LIFECYCLE.BATCH)

        if self.task == 'pretrain':
            token_ids = ctx.data_batch[self.pretrain_task]['token_ids']
            attention_mask = \
                ctx.data_batch[self.pretrain_task]['attention_mask']
            labels = ctx.data_batch[self.pretrain_task]['labels']
            example_indices = \
                ctx.data_batch[self.pretrain_task]['example_indices']

            outputs = ctx.model(
                input_ids=token_ids.to(ctx.device),
                attention_mask=attention_mask.to(ctx.device),
                labels=labels.to(ctx.device),
                pretrain_task=self.pretrain_task,
                example_indices=example_indices,
            )
            ctx.batch_size = CtxVar(len(token_ids), LIFECYCLE.BATCH)
            ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
            if self.pretrain_task == 'mlm':
                y_true = labels
            elif self.pretrain_task == 'denoise':
                y_true = labels[:, 1:]
            else:
                raise KeyError('Unsupported pretrain task: \'{}\''.format(
                    self.pretrain_task))
            count_idx = y_true.ne(-100) & y_true.ne(ctx.padding_idx)
            ctx.y_true = CtxVar(y_true[count_idx], LIFECYCLE.BATCH)
            ctx.y_pred = CtxVar(
                outputs.logits.argmax(dim=-1)[count_idx], LIFECYCLE.BATCH)

        else:
            token_ids = ctx.data_batch.get('token_ids', None)
            token_type_ids = ctx.data_batch.get('token_type_ids', None)
            attention_mask = ctx.data_batch.get('attention_mask', None)
            labels = ctx.data_batch.get('labels', None)
            start_positions = ctx.data_batch.get('start_positions', None)
            end_positions = ctx.data_batch.get('end_positions', None)
            example_indices = ctx.data_batch.get('example_indices', None)

            if self.task in {'imdb', 'agnews'}:
                outputs = ctx.model(
                    input_ids=token_ids.to(ctx.device),
                    token_type_ids=token_type_ids.to(ctx.device),
                    attention_mask=attention_mask.to(ctx.device),
                    labels=labels.to(ctx.device),
                    contrast_monitor=ctx.contrast_monitor,
                    in_contrast_prepare=self._in_contrast_prepare,
                    example_indices=example_indices,
                )
                if not self._in_contrast_prepare:
                    ctx.batch_size = CtxVar(len(token_ids), LIFECYCLE.BATCH)
                    ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
                    if self.use_contrastive_loss:
                        ctx.regular_loss_batch = CtxVar(
                            outputs.regular_loss, LIFECYCLE.BATCH)
                        ctx.contrastive_loss_batch = CtxVar(
                            outputs.contrastive_loss, LIFECYCLE.BATCH)
                    ctx.y_true = CtxVar(labels, LIFECYCLE.BATCH)
                    ctx.y_pred = CtxVar(outputs.logits.argmax(dim=-1),
                                        LIFECYCLE.BATCH)

            elif self.task in {'squad', 'newsqa'}:
                outputs = ctx.model(
                    input_ids=token_ids.to(ctx.device),
                    token_type_ids=token_type_ids.to(ctx.device),
                    attention_mask=attention_mask.to(ctx.device),
                    start_positions=start_positions.to(ctx.device),
                    end_positions=end_positions.to(ctx.device),
                    contrast_monitor=ctx.contrast_monitor,
                    in_contrast_prepare=self._in_contrast_prepare,
                    example_indices=example_indices,
                )
                if not self._in_contrast_prepare:
                    for i, example_idx in enumerate(example_indices):
                        encoded_input = ctx.get('{}_encoded'.format(
                            ctx.cur_split))[example_idx.item()]
                        unique_id = int(encoded_input.unique_id)
                        start_logits = \
                            outputs.logits[0][i].detach().cpu().tolist()
                        end_logits = \
                            outputs.logits[1][i].detach().cpu().tolist()
                        if ctx.cur_split != 'train':
                            if self.task == 'squad':
                                ctx.squad_results.append(
                                    SquadResult(unique_id, start_logits,
                                                end_logits))
                            elif self.task == 'newsqa':
                                ctx.newsqa_results.append(
                                    NewsQAResult(unique_id, start_logits,
                                                 end_logits))

                    ctx.batch_size = CtxVar(len(token_ids), LIFECYCLE.BATCH)
                    ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
                    if self.use_contrastive_loss:
                        ctx.regular_loss_batch = CtxVar(
                            outputs.regular_loss, LIFECYCLE.BATCH)
                        ctx.contrastive_loss_batch = CtxVar(
                            outputs.contrastive_loss, LIFECYCLE.BATCH)
                    ctx.y_true = CtxVar(
                        torch.cat([start_positions, end_positions]),
                        LIFECYCLE.BATCH)
                    ctx.y_pred = CtxVar(
                        torch.cat(
                            [out.argmax(dim=-1) for out in outputs.logits]),
                        LIFECYCLE.BATCH)

            elif self.task in {'cnndm', 'msqg'}:
                if ctx.cur_split != 'test':
                    outputs = ctx.model(
                        input_ids=token_ids.to(ctx.device),
                        token_type_ids=token_type_ids.to(ctx.device),
                        attention_mask=attention_mask.to(ctx.device),
                        labels=labels.to(ctx.device),
                        contrast_monitor=ctx.contrast_monitor,
                        in_contrast_prepare=self._in_contrast_prepare,
                        example_indices=example_indices,
                    )
                    if not self._in_contrast_prepare:
                        ctx.batch_size = CtxVar(len(labels), LIFECYCLE.BATCH)
                        ctx.loss_batch = CtxVar(outputs.loss, LIFECYCLE.BATCH)
                        if self.use_contrastive_loss:
                            ctx.regular_loss_batch = CtxVar(
                                outputs.regular_loss, LIFECYCLE.BATCH)
                            ctx.contrastive_loss_batch = CtxVar(
                                outputs.contrastive_loss, LIFECYCLE.BATCH)

                        y_pred = outputs.logits.argmax(dim=-1)
                        y_true = labels[:, 1:]
                        non_padding_idx = y_true.ne(ctx.padding_idx)
                        ctx.y_true = CtxVar(y_true[non_padding_idx],
                                            LIFECYCLE.BATCH)
                        ctx.y_pred = CtxVar(y_pred[non_padding_idx],
                                            LIFECYCLE.BATCH)
                else:
                    outputs = ctx.model.generate(
                        input_ids=token_ids.to(ctx.device),
                        token_type_ids=token_type_ids.to(ctx.device),
                        attention_mask=attention_mask.to(ctx.device),
                    )
                    # save to file
                    out_str = ctx.tokenizer.batch_decode(outputs)
                    src_str = ctx.tokenizer.batch_decode(token_ids)
                    ref_str = ctx.tokenizer.batch_decode(labels)
                    for out, src, ref in zip(out_str, src_str, ref_str):
                        out = self._remove_special_tokens(out)
                        src = self._remove_special_tokens(src)
                        ref = self._remove_special_tokens(ref)
                        self.pred_file.write(out + '\n')
                        self.src_file.write(src + '\n')
                        self.tgt_file.write(ref + '\n')
                    self.pred_file.flush()
                    self.src_file.flush()
                    self.tgt_file.flush()

                    ctx.batch_size = CtxVar(len(labels), LIFECYCLE.BATCH)
                    ctx.y_pred = CtxVar(outputs, LIFECYCLE.BATCH)
                    ctx.y_true = CtxVar(labels[:, 1:], LIFECYCLE.BATCH)
                    return

        if self._in_contrast_prepare:
            ctx.batch_size = CtxVar(0, LIFECYCLE.BATCH)
            dec_out, dec_hidden, example_indices = \
                outputs.logits, outputs.hidden_states, outputs.example_indices
            if len(example_indices) > 0:
                for ex, out in zip(example_indices, dec_out.detach().cpu()):
                    ctx.contrast_monitor.update_dec_out(out, k=ex.item())
                for ex, hids in zip(example_indices,
                                    dec_hidden.detach().cpu()):
                    ctx.contrast_monitor.update_dec_hidden(hids, k=ex.item())
        else:
            ctx.loss_agg.update(ctx.loss_batch.detach().item(), ctx.batch_size)
            if self.use_contrastive_loss:
                if ctx.regular_loss_batch is not None and \
                        ctx.contrastive_loss_batch is not None:
                    ctx.regular_loss_agg.update(
                        ctx.regular_loss_batch.detach().item(), ctx.batch_size)
                    ctx.contrastive_loss_agg.update(
                        ctx.contrastive_loss_batch.detach().item(),
                        ctx.batch_size)

    def _hook_on_batch_forward_regularizer(self, ctx):
        if self._in_contrast_prepare:
            return
        super()._hook_on_batch_forward_regularizer(ctx)

    def _hook_on_batch_backward(self, ctx):
        if self._in_contrast_prepare:
            return

        cur_step = (ctx.cur_batch_i + 1) // ctx.grad_accum_count
        ctx.accum_steps += 1
        ctx.loss_task /= ctx.grad_accum_count
        ctx.loss_task.backward()

        if ctx.accum_steps == ctx.grad_accum_count:
            if ctx.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
                                               ctx.grad_clip)
            ctx.optimizer.step()
            ctx.scheduler.step()
            ctx.optimizer.zero_grad()
            ctx.accum_steps = CtxVar(0, LIFECYCLE.ROUTINE)

        total_epoch = getattr(ctx, f'num_{ctx.cur_mode}_epoch', None)
        total_batch = getattr(ctx, f'num_{ctx.cur_mode}_batch', None) if \
            ctx.cur_epoch_i + 1 < total_epoch else \
            getattr(ctx, f'num_{ctx.cur_mode}_batch_last_epoch', None)
        if ctx.accum_steps == 0:
            if cur_step > 1 and (cur_step % ctx.cfg.trainer.disp_freq == 0
                                 or ctx.cur_batch_i + 1 == total_batch):
                y_true = ctx.y_true.detach().cpu().numpy()
                y_pred = ctx.y_pred.detach().cpu().numpy()
                if y_true.ndim == 1:
                    y_true = np.expand_dims(y_true, axis=-1)
                if y_pred.ndim == 1:
                    y_pred = np.expand_dims(y_pred, axis=-1)
                cur_acc = eval_acc(y_true, y_pred)

                log_str = 'Epoch: [{}/{}][{}/{}]\t' \
                          'LR: {:.2e}\t' \
                          'Acc: {:.4f}\t' \
                          'Loss: {loss.val:.4f} ({loss.avg:.4f})'\
                    .format(ctx.cur_epoch_i + 1,
                            total_epoch,
                            cur_step,
                            total_batch // ctx.grad_accum_count,
                            ctx.scheduler.get_last_lr()[0],
                            cur_acc,
                            loss=ctx.loss_agg)
                if self.use_contrastive_loss:
                    log_str += \
                        '\tRegular loss: {loss.val:.4f} ' \
                        '({loss.avg:.4f})'.format(loss=ctx.regular_loss_agg)
                    log_str += \
                        '\tContrastive loss: {loss.val:.4f} ' \
                        '({loss.avg:.4f})'.format(
                            loss=ctx.contrastive_loss_agg)
                if self.task == 'pretrain':
                    log_str += '\t({})'.format(self.pretrain_task)

                logger.info(log_str)

            if ctx.cur_batch_i + 1 == total_batch and ctx.cfg.federate.save_to:
                self._save_model(ctx)

    def _hook_on_batch_end(self, ctx):
        if self._in_contrast_prepare:
            return

        # update statistics
        ctx.num_samples += ctx.batch_size
        ctx.loss_batch_total += ctx.get(
            "loss_batch", torch.tensor(0.)).item() * ctx.batch_size
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

        # cache label for evaluate
        if self.task in {'pretrain', 'squad', 'newsqa', 'cnndm', 'msqg'}:
            ctx.ys_true = CtxVar([ctx.y_true.detach().cpu().numpy()],
                                 LIFECYCLE.ROUTINE)
            ctx.ys_pred = CtxVar([ctx.y_pred.detach().cpu().numpy()],
                                 LIFECYCLE.ROUTINE)
        else:
            ctx.ys_true.append(ctx.y_true.detach().cpu().numpy())
            ctx.ys_pred.append(ctx.y_pred.detach().cpu().numpy())

    def _hook_on_fit_end(self, ctx):
        if self.use_contrastive_loss and self.task != 'pretrain' and \
                ctx.cur_split == 'train':
            ctx.contrast_monitor.update_stat(ctx.contrast_monitor.stat + 1)
            return

        if ctx.cur_split != 'train':
            ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true),
                                 LIFECYCLE.ROUTINE)
            ctx.ys_pred = CtxVar(np.concatenate(ctx.ys_pred),
                                 LIFECYCLE.ROUTINE)
            results = self.metric_calculator.eval(ctx)
            setattr(ctx, 'eval_metrics', results)

        if ctx.cur_split == 'test' and not self.finish_eval:
            if self.pred_file is not None:
                self.pred_file.close()
            if self.src_file is not None:
                self.src_file.close()
            if self.tgt_file is not None:
                self.tgt_file.close()
            self.finish_eval = True
