import numpy as np
import torch.nn.functional as F
from torch import nn

import params
import torch
from cuda import use_cuda
from model.encoder import * 


class BaseModel(nn.Module):
    """Common base for models: adds checkpoint ``load``/``save`` to ``nn.Module``."""

    def load(self, path):
        """Copy tensors from the checkpoint at ``path`` into this model, in place.

        Two file layouts are accepted:

        * a checkpoint dict that stores the state dict under the ``'model'`` key;
        * a bare state dict, which is what :meth:`save` writes, so ``save``
          followed by ``load`` round-trips.

        Checkpoint entries whose names are not in this model are skipped with a
        warning. Model entries absent from the checkpoint keep their current
        values, also with a warning.

        Args:
            path: Location passed straight to ``torch.load``.

        Raises:
            RuntimeError: if any tensor present in both the model and the
                checkpoint has a different shape. The check runs before
                anything is copied, so the model is left untouched in that case.
        """
        # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
        if use_cuda:
            checkpoint = torch.load(path)
        else:
            # Deserialize every storage onto the CPU so GPU-saved files load without CUDA.
            checkpoint = torch.load(path, map_location='cpu')

        # Unwrap {'model': state_dict, ...}; otherwise treat the file as a bare state dict.
        if isinstance(checkpoint, dict) and isinstance(checkpoint.get('model'), dict):
            saved = checkpoint['model']
        else:
            saved = checkpoint

        state = self.state_dict()

        # Validate every shape first. An explicit raise (not assert) survives
        # `python -O`, where copy_ could otherwise silently broadcast a mismatch.
        for name, val in saved.items():
            if name in state and state[name].shape != val.shape:
                raise RuntimeError("%s size has changed from %s to %s"
                                   % (name, state[name].shape, val.shape))

        for name, val in saved.items():
            if name not in state:
                print("WARNING: %s not in model during model loading!" % name)
                continue
            # state_dict() tensors share storage with the live parameters and
            # buffers, so copying into them updates the model in place.
            state[name].copy_(val)

        for name in state:
            if name not in saved:
                print("WARNING: %s not in checkpoint during model loading!" % name)

    def save(self, path):
        """Write this model's bare state dict to ``path`` with ``torch.save``."""
        # TODO: save my PE model
        torch.save(self.state_dict(), path)