from .utils import *


class BaseNet(object):
    """Common bookkeeping shared by network wrappers.

    ``__init__`` here only prints a banner; the methods below read the
    following attributes, which therefore have to be assigned elsewhere
    (presumably by subclasses -- confirm against them):

    * ``model``, ``optimizer`` -- used by every method except ``__init__``.
    * ``epoch``, ``lr``, ``schedule`` -- used by :meth:`update_lr`.
    * ``lr_hyperparam``, ``input_dim``, ``output_dim`` -- used by
      :meth:`save` / :meth:`load`.
    """

    def __init__(self):
        cprint('c', '\nNet:')

    def get_nb_parameters(self):
        """Return the total number of elements over ``self.model.parameters()``."""
        # Builtin sum: passing a generator to np.sum is deprecated in NumPy
        # and only worked through a fallback to this very builtin.
        return sum(p.numel() for p in self.model.parameters())

    def set_mode_train(self, train=True):
        """Put ``self.model`` in training mode if ``train`` is true, else eval mode."""
        if train:
            self.model.train()
        else:
            self.model.eval()

    def double(self):
        """Call ``double()`` on ``self.model``."""
        self.model.double()

    def update_lr(self, epoch, gamma=0.99):
        """Advance the epoch counter and decay the learning rate if scheduled.

        ``self.epoch`` is incremented on every call, independently of the
        ``epoch`` argument, which is used only for the schedule lookup.

        The learning rate is multiplied by ``gamma`` and written to every
        optimizer param group when ``self.schedule`` is not ``None`` and is
        either empty (decay on every call) or contains ``epoch``.

        Args:
            epoch: Value looked up in ``self.schedule``.
            gamma: Multiplicative decay factor applied to ``self.lr``.
        """
        self.epoch += 1
        if self.schedule is not None:
            if len(self.schedule) == 0 or epoch in self.schedule:
                self.lr *= gamma
                # Both values must be passed as one tuple; formatting with
                # ``% self.lr, epoch`` raised TypeError before the optimizer
                # was ever updated.
                print('learning rate: %f  (%d)\n' % (self.lr, epoch))
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.lr

    def save(self, filename):
        """Write a checkpoint to ``filename`` with ``torch.save``.

        Args:
            filename: Destination passed to ``torch.save``.

        Returns:
            The dictionary that was saved.
        """
        cprint('c', 'Writting %s\n' % filename)
        model_dict = {
            'epoch': self.epoch,
            # NOTE(review): this stores ``lr_hyperparam`` while ``load``
            # restores the same key into ``self.lr`` -- confirm that the
            # asymmetry is intended.
            'lr': self.lr_hyperparam,
            'model': self.model,
            'optimizer': self.optimizer,
            'input_dim': self.input_dim,
            'output_dim': self.output_dim,
        }
        torch.save(model_dict, filename)
        return model_dict

    def load(self, filename):
        """Restore the state written by :meth:`save` from ``filename``.

        Args:
            filename: Source passed to ``torch.load``.

        Returns:
            The restored epoch counter.
        """
        cprint('c', 'Reading %s\n' % filename)
        # SECURITY: the checkpoint holds whole pickled objects (model and
        # optimizer), so loading executes pickle code. Only load files from
        # a trusted source.
        state_dict = torch.load(filename)
        self.epoch = state_dict['epoch']
        self.lr = state_dict['lr']
        self.model = state_dict['model']
        self.optimizer = state_dict['optimizer']
        self.input_dim = state_dict['input_dim']
        self.output_dim = state_dict['output_dim']
        print('  restoring epoch: %d, lr: %f' % (self.epoch, self.lr))
        return self.epoch