import os
import torch
import sys
from torch import nn
from os.path import join

from fastNLP.core.callback import Callback


class lrCallback(Callback):
    """Advance a learning-rate scheduler once per ``on_step_end`` call.

    Args:
        scheduler: any object exposing a no-argument ``step()`` method
            (e.g. a ``torch.optim.lr_scheduler`` instance).
    """

    def __init__(self, scheduler):
        super().__init__()
        self.scheduler = scheduler

    def on_step_end(self):
        # Hook name suggests the trainer calls this after each training step;
        # all it does is delegate to the wrapped scheduler.
        sched = self.scheduler
        sched.step()


class SaveCkptCallback(Callback):
    """Periodically checkpoint the model/optimizer and log averaged loss terms.

    The hooks read ``self.epoch``, ``self.step``, ``self.update_every``,
    ``self.model``, ``self.trainer`` and ``self.pbar``. These are presumably
    provided by the fastNLP ``Callback`` base class (not visible here).

    Args:
        args: configuration namespace. This class reads ``args.save_path``
            (checkpoint root directory) and ``args.save_steps`` (checkpoint
            interval, counted in ``step // update_every`` units).
    """

    def __init__(self, args):
        super(SaveCkptCallback, self).__init__()
        self.args = args
        # Per-batch loss components accumulated between two log lines.
        self.cl_loss_list = []
        self.nll_loss_list = []

    def _save_this_model(self):
        """Save a checkpoint named after the current epoch and update count.

        Files go to ``<args.save_path>/<trainer.start_time>/``. Saving is
        best-effort: any exception is printed instead of stopping training.
        """
        name = "epoch-{}_step-{}.pt".format(self.epoch, self.step // self.update_every)
        save_dir = os.path.join(self.args.save_path, self.trainer.start_time)
        try:
            self._save_model(self.model, model_name=name, save_dir=save_dir)
        except Exception as e:
            print(f"The following exception:{e} happens when save model to {save_dir}.")

    def _save_model(self, model, model_name, save_dir):
        """Write ``model`` and the trainer's optimizer state dict into ``save_dir``.

        The model is saved as ``save_dir/model_name``. The optimizer state is
        saved beside it, with the file extension replaced by ``.optm``.
        The directory is created if missing.
        """
        model_path = os.path.join(save_dir, model_name)
        os.makedirs(save_dir, exist_ok=True)
        # The whole module object is pickled (not only its state_dict), so
        # loading it back requires the model's class to be importable.
        torch.save(model, model_path)
        # Swap only the extension. str.replace(".pt", ...) would also rewrite a
        # ".pt" inside the directory part. It would also leave the path
        # unchanged when there is no ".pt", so the optimizer would overwrite
        # the model.
        optm_path = os.path.splitext(model_path)[0] + ".optm"
        torch.save(self.trainer.optimizer.state_dict(), optm_path)
        print(f" ============= save model at {model_path} ============= ")
        print(f" ============= save optimizer at {optm_path} ============= ")

    def on_step_end(self):
        """Checkpoint every ``args.save_steps`` updates (``step // update_every``)."""
        # NOTE(review): originally labelled "warm up", but ``epoch >= 0`` looks
        # always true. Presumably a leftover gate for skipping early epochs.
        if self.epoch < 0:
            return
        # Act only on steps that close an update cycle. ``update_every`` is
        # presumably the number of micro-batches per optimizer update.
        if self.step <= 0 or self.step % self.update_every != 0:
            return
        if self.step // self.update_every % self.args.save_steps == 0:
            self._save_this_model()

    def on_epoch_end(self):
        """Print an end-of-epoch marker through the progress bar."""
        self.pbar.write('Epoch {} is done !!!'.format(self.epoch))

    def on_loss_begin(self, batch_y, predict_y):
        """Record the loss components and periodically print their means.

        ``predict_y["cl_loss"]`` and ``predict_y["loss"]`` must be
        single-element tensors (``.item()`` is called on them). The NLL part
        is computed as ``loss - cl_loss``. This assumes the total loss is
        nll + contrastive; confirm against the model.
        """
        cl_loss = predict_y["cl_loss"].detach().cpu().item()
        nll_loss = predict_y["loss"].detach().cpu().item() - cl_loss
        self.cl_loss_list.append(cl_loss)
        self.nll_loss_list.append(nll_loss)
        # Log roughly four times per checkpoint interval. Clamp to 1 so that
        # save_steps < 4 does not produce a zero modulus (ZeroDivisionError).
        log_every = max(1, self.args.save_steps // 4)
        # Require an update boundary too. Otherwise every micro-step sharing
        # the same ``step // update_every`` value would print (and reset) again.
        if self.step % self.update_every == 0 and self.step // self.update_every % log_every == 0:
            self.pbar.write(
                f'Contrastive loss is {sum(self.cl_loss_list) / len(self.cl_loss_list)}, nll_loss is {sum(self.nll_loss_list) / len(self.nll_loss_list)}')
            self.cl_loss_list = []
            self.nll_loss_list = []

    def on_exception(self, exception):
        """On Ctrl-C, kill training processes by name; re-raise anything else.

        ``pkill -f`` matches against full command lines. It therefore signals
        every process whose command line contains ``train_distributed.py``,
        presumably including this one. The KeyboardInterrupt is not re-raised
        here.
        """
        if isinstance(exception, KeyboardInterrupt):
            os.system("pkill -f train_distributed.py")
        else:
            raise exception
