from collections import OrderedDict

import torch


class CpuMemoryModel:
    """Hold a CPU-resident copy of a model's named parameters.

    The copy is taken once at construction and can be refreshed (or zeroed)
    later with :meth:`update_cpu_weights`. Parameters are stored by name in
    the order ``model.named_parameters()`` yields them.
    """

    def __init__(self, model: torch.nn.Module) -> None:
        """Copy every parameter of *model* into CPU memory.

        Args:
            model: Object exposing ``named_parameters()``; each parameter is
                detached and copied to the CPU.
        """
        self.layer_params = OrderedDict()

        with torch.no_grad():
            for pname, param in model.named_parameters():
                # copy=True guarantees independent storage even when the
                # parameter already lives on the CPU; without it ``.to("cpu")``
                # would alias the live parameter and later in-place updates
                # (e.g. zeroing) would modify the model itself.
                self.layer_params[pname] = param.detach().to(
                    "cpu", non_blocking=True, copy=True
                )
        self._sync_cuda()

    @staticmethod
    def _sync_cuda() -> None:
        """Wait for pending non-blocking copies, if CUDA is usable at all."""
        # torch.cuda.synchronize() raises when no CUDA runtime/device is
        # present, so only call it when there is something to wait for.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def named_parameters(self):
        """Yield ``(name, cpu_tensor)`` pairs in insertion order."""
        yield from self.layer_params.items()

    def update_cpu_weights(self, model: torch.nn.Module, replace_zeros: bool = False) -> None:
        """Refresh the stored CPU tensors in place from *model*.

        Args:
            model: Object exposing ``named_parameters()``. Every parameter
                name it yields must already be stored, with an identical
                shape.
            replace_zeros: When true, the stored tensors are zero-filled
                instead of being overwritten with the model's values.

        Raises:
            AssertionError: If a parameter name is unknown or its shape does
                not match the stored tensor.
        """
        with torch.no_grad():
            for pname, param in model.named_parameters():
                assert pname in self.layer_params, (
                    f"parameter {pname!r} is not stored in layer_params"
                )
                stored = self.layer_params[pname]
                assert param.shape == stored.shape, (
                    f"shape mismatch for {pname!r}: "
                    f"{tuple(param.shape)} vs {tuple(stored.shape)}"
                )
                if replace_zeros:
                    stored.zero_()
                else:
                    stored.copy_(param, non_blocking=True)
        self._sync_cuda()

    def __getitem__(self, key: str) -> torch.Tensor:
        """Return the stored CPU tensor for parameter name *key*.

        Raises:
            AssertionError: If *key* is not a stored parameter name.
        """
        assert key in self.layer_params, f"{key=} not in self.layer_params"
        return self.layer_params[key]
