import os
import torch
from termcolor import colored


class ModelSaverLoaderCallback:
    """Save and restore model ``state_dict`` checkpoints, tracking the best metric.

    Checkpoints are written to ``<result_path>/<opt.dataset>/checkpoints``, or to
    ``<result_path>/mosi/checkpoints`` when ``opt.transfer_experiment`` is truthy:

    * ``<model_filename>_best.pt`` whenever the monitored metric improves;
    * ``<model_filename>_last.pt`` on every call to :meth:`save_cpkt`.

    Args:
        result_path: Root directory under which the checkpoint folder is created.
        model_filename: Base filename (without suffix) for the checkpoint files.
        opt: Options object. This class reads ``ckpt_metric``,
            ``transfer_experiment`` and ``dataset``. When ``ckpt_metric == 'loss'``
            a lower value is better; for any other value a higher one is better.
    """

    def __init__(self, result_path, model_filename, opt):
        self.model_filename = model_filename
        self.opt = opt
        self.result_path = self._make_result_path(result_path)

        # 'loss' is minimised; every other metric is maximised.
        self.increase = self.opt.ckpt_metric != 'loss'
        # Start from the worst possible value so the first call to save_cpkt
        # always counts as an improvement, whatever the metric's scale or sign.
        # (A finite sentinel such as +/-1e3 silently never saves a "best"
        # checkpoint for losses above 1000 or scores below -1000.)
        self.metric = float('-inf') if self.increase else float('inf')

    @property
    def inscrease(self):
        """Backward-compatible alias for ``increase`` (original misspelling)."""
        return self.increase

    @inscrease.setter
    def inscrease(self, value):
        self.increase = value

    def _make_result_path(self, result_path):
        """Create (if needed) and return the checkpoint directory."""
        if self.opt.transfer_experiment:
            full_save_path = os.path.join(result_path, 'mosi', 'checkpoints')
        else:
            full_save_path = os.path.join(result_path, self.opt.dataset, 'checkpoints')
        os.makedirs(full_save_path, exist_ok=True)

        return full_save_path

    def _checkpoint_path(self, suffix):
        """Return the path of ``<model_filename>_<suffix>.pt`` in the result dir."""
        return os.path.join(self.result_path, self.model_filename + '_' + suffix + '.pt')

    def _is_improvement(self, metric):
        """Return True if *metric* strictly beats the best value recorded so far."""
        if self.increase:
            return metric > self.metric
        return metric < self.metric

    def save_cpkt(self, model, metric):
        """Save the latest checkpoint, and the best one if *metric* improved.

        Args:
            model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
            metric: Value of the monitored metric for this step.
        """
        state = model.state_dict()
        if self._is_improvement(metric):
            print(colored('Saving best checkpoint...', 'green', 'on_grey'))
            self.metric = metric
            torch.save(state, self._checkpoint_path('best'))
        # The "last" checkpoint is refreshed unconditionally.
        torch.save(state, self._checkpoint_path('last'))

    def load_cpkt(self, model, last=False, map_location=None):
        """Load the best (default) or last checkpoint into *model* in place.

        Args:
            model: Object exposing ``load_state_dict()``.
            last: If True load ``_last.pt``, otherwise ``_best.pt``.
            map_location: Forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
                GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
                torch's default behaviour.

        Returns:
            The same *model* object, for chaining.
        """
        path = self._checkpoint_path('last' if last else 'best')
        model.load_state_dict(torch.load(path, map_location=map_location))

        return model