import copy
import torch
import torch.nn as nn

__all__ = ["SWAModel"]


class SWAModel(object):
    """Stochastic Weight Averaging (SWA) container.

    Holds a private deep copy of a model whose parameters are kept equal to
    the running arithmetic mean of the parameters of every model passed to
    :meth:`update_state`. Buffers are not averaged; they are overwritten with
    the values of the most recently passed model.
    """

    def __init__(self, model: nn.Module, device: str = "cpu"):
        """Create the averaged copy.

        Args:
            model: template module. It is deep-copied, so this object never
                modifies the caller's module.
            device: device on which the averaged copy is stored. Defaults to
                CPU, where SWA state is almost always kept.
        """
        super().__init__()
        # Deep copy first, then move: the caller's model stays where it is.
        self.model = copy.deepcopy(model).to(device)
        # Number of models folded into the average so far (0 = none yet).
        self.count = 0

    @torch.no_grad()
    def update_state(self, new_model: nn.Module) -> None:
        """Fold ``new_model`` into the running average.

        Parameters are averaged incrementally with
        ``p_swa <- p_swa + (p_new - p_swa) / (count + 1)``, which equals the
        arithmetic mean of all ``count + 1`` models seen so far. On the first
        call the parameters are copied directly. Buffers are always copied
        from ``new_model``. Incoming tensors are moved to the device of the
        corresponding tensor in the averaged copy before use.

        Args:
            new_model: module with the same architecture as the template.

        Raises:
            ValueError: if ``new_model`` has a different number of parameters
                or buffers than the averaged copy. The check runs before any
                tensor is modified, so the average and ``count`` are unchanged.
        """
        swa_params = list(self.model.parameters())
        new_params = list(new_model.parameters())
        swa_buffers = list(self.model.buffers())
        new_buffers = list(new_model.buffers())

        # zip() would silently stop at the shorter sequence and leave part of
        # the averaged model stale; reject mismatched architectures up front.
        if len(swa_params) != len(new_params):
            raise ValueError(
                f"parameter count mismatch: SWA model has {len(swa_params)}, "
                f"new model has {len(new_params)}"
            )
        if len(swa_buffers) != len(new_buffers):
            raise ValueError(
                f"buffer count mismatch: SWA model has {len(swa_buffers)}, "
                f"new model has {len(new_buffers)}"
            )

        # Parameters: running mean.
        for p_swa, p_new in zip(swa_params, new_params):
            p_new_ = p_new.detach().to(p_swa.device)
            if self.count == 0:
                p_swa.copy_(p_new_)
            else:
                # p' = (p_swa * n + p_new * 1) / (n + 1)
                #    = p_swa + (p_new - p_swa) / (n + 1)
                p_swa.copy_(p_swa + (p_new_ - p_swa) / (self.count + 1))

        # Buffers: copied from the latest model, not averaged.
        for b_swa, b_new in zip(swa_buffers, new_buffers):
            b_swa.copy_(b_new.detach().to(b_swa.device))

        self.count += 1

    def state_dict(self) -> dict:
        """Return a checkpoint dict with keys ``"network"`` and ``"count"``."""
        return {"network": self.model.state_dict(), "count": self.count}

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore from a dict produced by :meth:`state_dict`.

        ``"network"`` is loaded with ``strict=True``. A missing ``"count"`` is
        treated as 0, in which case the next :meth:`update_state` call
        overwrites the loaded parameters instead of averaging with them.
        """
        self.model.load_state_dict(state_dict["network"], strict=True)
        self.count = state_dict.get("count", 0)
