from copy import deepcopy

import torch
import torch.nn as nn
import numpy as np


class BaseActor(nn.Module):
    """Base actor network whose weights can be read/written as one flat vector.

    Subclasses register their layers as usual. This class adds helpers to
    get/set all parameters as a single 1-D NumPy array (concatenated in
    ``self.parameters()`` order) and to save/load the ``state_dict`` to
    ``'<dir>/<net_name>.pkl'``.
    """

    def __init__(self):
        super().__init__()

    def set_params(self, params):
        """Set the params of the network to the given parameters.

        Args:
            params: 1-D array-like with the new values for every parameter,
                concatenated in ``self.parameters()`` order (the layout
                returned by :meth:`get_params`). Each chunk is copied into
                the existing parameter tensor via ``copy_``, which casts to
                the parameter's dtype/device. Entries past the network's
                total size are ignored; a too-short input fails when the
                chunk is reshaped to the parameter's shape.
        """
        # ascontiguousarray accepts lists and strided / negative-stride
        # arrays, which torch.from_numpy and .view() would otherwise reject.
        flat = np.ascontiguousarray(params)
        offset = 0
        for param in self.parameters():
            # numel() replaces np.product, which was removed in NumPy 2.0.
            count = param.numel()
            chunk = torch.from_numpy(flat[offset:offset + count])
            param.data.copy_(chunk.view(param.size()))
            offset += count

    def get_params(self):
        """Return parameters of the actor as a new flat 1-D NumPy array.

        Each parameter is flattened and the results are concatenated in
        ``self.parameters()`` order. The returned array is a copy, so
        mutating it does not affect the network. A network with no
        parameters yields an empty array.
        """
        chunks = [p.detach().cpu().numpy().ravel() for p in self.parameters()]
        if not chunks:
            # np.hstack / np.concatenate raise on an empty sequence.
            return np.empty(0, dtype=np.float32)
        # concatenate always allocates a fresh array, so no extra .copy().
        return np.concatenate(chunks)

    def get_size(self):
        """Return the number of parameters of the network (total element count)."""
        # Count elements directly instead of materializing a copy of all weights.
        return sum(p.numel() for p in self.parameters())

    def load_model(self, filename, net_name):
        """Load the model's state_dict from ``'<filename>/<net_name>.pkl'``.

        Args:
            filename: directory containing the checkpoint; ``None`` makes
                this a no-op.
            net_name: checkpoint file stem (``.pkl`` is appended).
        """
        if filename is None:
            return

        # The map_location lambda returns each storage unchanged, i.e. as
        # deserialized on CPU, so GPU-saved checkpoints load on CPU-only hosts.
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True if the torch version allows).
        self.load_state_dict(
            torch.load('{}/{}.pkl'.format(filename, net_name),
                       map_location=lambda storage, loc: storage)
        )

    def save_model(self, output, net_name):
        """Save the model's state_dict to ``'<output>/<net_name>.pkl'``.

        Args:
            output: existing directory to write into.
            net_name: checkpoint file stem (``.pkl`` is appended).
        """
        torch.save(
            self.state_dict(),
            '{}/{}.pkl'.format(output, net_name)
        )
