from spaghettini import quick_register

import torch
import copy

from torch.nn import Module


@quick_register
class ProximalRegWrapper(Module):
    """Wrap a module together with a "trailing" snapshot copy of it.

    ``curr_model`` is the wrapped module and is the only one used by
    ``forward``. ``trailing_model`` is a deep copy whose parameters are
    overwritten from ``curr_model`` each time ``sync`` is called. It therefore
    holds the parameter values as of the last sync. Presumably it serves as
    the anchor of a proximal penalty computed elsewhere; no penalty is
    computed in this class.

    Only *parameters* are synced and compared. Buffers (if the wrapped module
    has any) keep the values the deep copy gave them at construction time.

    Both models are registered as submodules, so the trailing model's
    parameters also appear in ``self.parameters()``. They receive no
    gradient from ``forward``, which only calls ``curr_model``.

    Args:
        module: the ``torch.nn.Module`` to wrap.
    """

    def __init__(self, module):
        super().__init__()
        self.curr_model = module
        self.trailing_model = copy.deepcopy(module)
        # The deep copy already matches; syncing makes the invariant explicit.
        self.sync()

    def sync(self):
        """Copy every parameter of ``curr_model`` into ``trailing_model`` in place.

        Parameters are matched by name. Names that exist in only one of the
        two models are left untouched. The copy runs under
        ``torch.no_grad()`` rather than through ``.data``, so autograd does
        not record it, but its version counter still sees the write.
        """
        trailing_params = dict(self.trailing_model.named_parameters())
        with torch.no_grad():
            for name, curr_param in self.curr_model.named_parameters():
                trailing_param = trailing_params.get(name)
                if trailing_param is not None:
                    trailing_param.copy_(curr_param)

    def is_synced(self):
        """Return True if all same-named parameter pairs are ``torch.allclose``.

        On the first mismatch, prints the offending parameter name and returns
        False. Parameters whose names exist in only one model are ignored.
        """
        trailing_params = dict(self.trailing_model.named_parameters())
        for name, curr_param in self.curr_model.named_parameters():
            trailing_param = trailing_params.get(name)
            if trailing_param is None:
                continue
            if not torch.allclose(input=curr_param, other=trailing_param):
                print(f"Parameters named {name} don't match. ")
                return False
        return True

    def forward(self, *args, **kwargs):
        """Run the current (non-trailing) model; arguments are passed through unchanged."""
        return self.curr_model(*args, **kwargs)


if __name__ == "__main__":
    # Ad-hoc manual checks for ProximalRegWrapper. Run from the repository root:
    #     python -m src.dl.models.proximal_reg_wrapper
    # Select which check to run via `test_num`.
    from src.dl.models.fully_connected import FCNetFixedWidth

    test_num = 3

    if test_num == 0:
        # Check that sync() really copies parameters. Right after construction
        # the two copies are already identical (deep copy), so make them
        # diverge first. Otherwise a broken sync() would go unnoticed.
        module_ = FCNetFixedWidth(num_inputs=256)
        dup_mod_ = ProximalRegWrapper(module=module_)
        with torch.no_grad():
            for param in dup_mod_.curr_model.parameters():
                param.add_(1.0)
        print(f"If following False, then the perturbation took effect: {dup_mod_.is_synced()}")
        dup_mod_.sync()
        xs = torch.rand(size=(5, 256))
        out1 = dup_mod_.curr_model(xs)
        out2 = dup_mod_.trailing_model(xs)
        correct = torch.allclose(input=out1, other=out2)
        print(f"If following True, then sync worked: {correct}")

    elif test_num == 1:
        # Test forward pass.
        module_ = FCNetFixedWidth(num_inputs=256)
        dup_mod_ = ProximalRegWrapper(module=module_)
        xs_ = torch.rand(size=(5, 256))
        outs = dup_mod_(xs_)
        print(outs)

    elif test_num == 3:
        # Test the type of ProximalRegWrapper (the result used to be discarded).
        module_ = FCNetFixedWidth(num_inputs=256)
        dup_mod_ = ProximalRegWrapper(module=module_)
        print(f"Is a ProximalRegWrapper: {isinstance(dup_mod_, ProximalRegWrapper)}")
