import copy

import torch
import numpy as np

class TorchUtils:
    """
    Stateless helpers for PyTorch seeding, gradients and parameter transfer.

    Every helper is a static method; instances carry no state.
    """

    def __init__(self):
        pass

    @staticmethod
    def set_device(device_id):
        """
        Select the current CUDA device.

        :param device_id: device index (or ``torch.device``) forwarded to
            ``torch.cuda.set_device``.
        :return: None
        """
        torch.cuda.set_device(device=device_id)

    @staticmethod
    def set_random_seed(seed):
        """
        Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

        ``torch.manual_seed`` seeds the generators of all torch devices. cuDNN
        autotuning (``benchmark``) is disabled, because its algorithm choice
        can vary between runs.

        :param seed: integer seed; it must be accepted by ``np.random.seed``,
            i.e. lie in ``[0, 2**32 - 1]``.
        :return: None
        """
        import random  # local import: the module header does not import it

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    @staticmethod
    def reset_grad(model, grad_dict):
        """
        Overwrite every parameter's ``.grad`` with the tensor stored under its name.

        A parameter whose ``.grad`` is still ``None`` (e.g. before any backward
        pass) first gets a zero gradient allocated like the parameter. The
        given values are then copied into it, so the result never aliases the
        tensors in ``grad_dict``.

        :param model: module whose ``named_parameters()`` are updated.
        :param grad_dict: mapping from parameter name to a gradient tensor
            copy-compatible with that parameter.
        :raises ValueError: if a parameter name is missing from ``grad_dict``.
            Parameters visited before the missing one have already been updated.
        :return: None
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in grad_dict:
                    raise ValueError("cannot find {} in reset_grad".format(name))
                if param.grad is None:
                    # No gradient yet: allocate one with the parameter's
                    # shape/device/dtype instead of crashing on None.copy_().
                    param.grad = torch.zeros_like(param)
                param.grad.copy_(grad_dict[name])

    @staticmethod
    def copy_model_parameters(module_src, module_dest):
        """
        Copy parameter values from ``module_src`` into ``module_dest`` in place.

        Only parameters whose names exist in both modules are copied; the rest
        are ignored. The copy is in place (``Tensor.copy_``), so each
        destination parameter keeps its own storage, device and dtype.
        Optimizers and views referencing it therefore stay valid.

        :param module_src: module providing the values.
        :param module_dest: module whose matching parameters are overwritten.
        :return: None
        """
        dest_params = dict(module_dest.named_parameters())
        with torch.no_grad():
            for name, param in module_src.named_parameters():
                target = dest_params.get(name)
                if target is not None:
                    target.copy_(param)

    @staticmethod
    def update_model_parameters(module_src, module_dest, lr=0.001):
        """
        Blend source parameters into destination parameters in place.

        For every parameter name present in both modules this computes
        ``dest = (1 - lr) * dest + lr * src``.

        :param module_src: module providing the values blended in.
        :param module_dest: module whose matching parameters are updated.
        :param lr: interpolation weight of the source (0 keeps ``dest``,
            1 copies ``src``).
        :return: None
        """
        dest_params = dict(module_dest.named_parameters())
        with torch.no_grad():
            for name, param in module_src.named_parameters():
                target = dest_params.get(name)
                if target is not None:
                    target.mul_(1 - lr).add_(param, alpha=lr)

    @staticmethod
    def get_model_diff(model_src, model_dest):
        """
        Compute per-parameter differences ``src - dest``.

        :param model_src: minuend module.
        :param model_dest: subtrahend module.
        :return: dict mapping each parameter name present in both modules to a
            new tensor ``src - dest`` that does not require grad.
        """
        dest_params = dict(model_dest.named_parameters())
        # Under no_grad the subtraction already yields a fresh, detached
        # tensor, so no extra detach()/clone() is needed.
        with torch.no_grad():
            return {
                name: param - dest_params[name]
                for name, param in model_src.named_parameters()
                if name in dest_params
            }

    @staticmethod
    def set_parameter_requires_grad(model, enable=True):
        """
        Enable or disable gradient tracking for every parameter of ``model``.

        :param model: module whose ``parameters()`` are modified.
        :param enable: value assigned to each parameter's ``requires_grad``.
        :return: None
        """
        for param in model.parameters():
            param.requires_grad = enable
