import copy

import torch

from .abstract import TargetNetwork
from all2.nn import RLNetwork


class PolyakTarget(TargetNetwork):
    """TargetNetwork that updates using polyak averaging.

    Holds deep copies of a source encoder/decoder pair. Each call to
    ``update()`` moves every target parameter toward its source counterpart::

        target <- (1 - rate) * target + rate * source

    All forward passes through the target networks run under
    ``torch.no_grad()``, so they never build an autograd graph.
    """

    def __init__(self, rate):
        """
        Args:
            rate: Weight given to the source parameters on each ``update()``.
                Presumably a number in [0, 1]; it is not validated here.
        """
        # All four networks are populated by init(); they are None until then.
        self._source_encoder = None
        self._source_decoder = None
        self._target_encoder = None
        self._target_decoder = None
        self._rate = rate

    def __call__(self, *inputs):
        """Run the target encoder, then the target decoder, without gradients."""
        with torch.no_grad():
            return self._target_decoder(self._target_encoder(*inputs))

    def encode(self, *inputs):
        """Run only the target encoder, without gradients."""
        with torch.no_grad():
            return self._target_encoder(*inputs)

    def decode(self, *inputs):
        """Run only the target decoder, without gradients."""
        with torch.no_grad():
            return self._target_decoder(*inputs)

    def init(self, encoder, decoder):
        """Register the source networks and create their target copies.

        Args:
            encoder: Source encoder. It is wrapped in ``RLNetwork`` both as the
                source reference and as its deep-copied target.
            decoder: Source decoder. It is referenced directly as the source
                and deep-copied for the target.
        """
        # NOTE(review): this assumes RLNetwork(encoder) shares encoder's
        # parameters rather than copying them, so update() tracks the live
        # source weights. Confirm against RLNetwork.
        self._source_encoder = RLNetwork(encoder)
        self._source_decoder = decoder
        self._target_encoder = RLNetwork(copy.deepcopy(encoder))
        self._target_decoder = copy.deepcopy(decoder)

    def update(self):
        """Blend source parameters into the target networks in place."""
        self._polyak_update(self._target_encoder, self._source_encoder, self._rate)
        self._polyak_update(self._target_decoder, self._source_decoder, self._rate)

    @staticmethod
    def _polyak_update(target, source, rate):
        """Set ``t <- (1 - rate) * t + rate * s`` in place for each parameter pair.

        Parameters are paired positionally via ``zip``, which stops at the
        shorter sequence. Buffers (anything not returned by ``parameters()``)
        are left untouched.
        """
        # float() lets a 0-dim tensor or numpy scalar rate work as a plain
        # Python number for the ``alpha`` argument.
        rate = float(rate)
        keep = 1.0 - rate
        # In-place ops under no_grad replace the discouraged ``.data`` access.
        # They also avoid allocating two temporary tensors per parameter.
        with torch.no_grad():
            for target_param, source_param in zip(
                target.parameters(), source.parameters()
            ):
                target_param.mul_(keep).add_(source_param, alpha=rate)
