import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F

class EMA:
    """Exponential-moving-average rule with a fixed decay factor ``beta``."""

    def __init__(self, beta):
        """Store the decay factor applied to the previous average."""
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        """Blend ``new`` into the running average ``old``.

        Returns ``new`` unchanged when there is no previous average
        (``old is None``); otherwise returns
        ``old * beta + (1 - beta) * new``.
        """
        if old is not None:
            decay = self.beta
            return old * decay + (1 - decay) * new
        # No history yet: the first observation becomes the average.
        return new

def update_moving_average(ema_updater, ma_model, current_model):
    """Move every parameter of ``ma_model`` towards ``current_model``.

    Parameters of the two models are paired positionally, in the order
    yielded by ``parameters()``. For each pair the averaged parameter's
    ``.data`` is replaced by
    ``ema_updater.update_average(averaged.data, current.data)``.
    Only parameters are visited; ``current_model`` is left untouched.
    """
    param_pairs = zip(current_model.parameters(), ma_model.parameters())
    for online_param, averaged_param in param_pairs:
        averaged_param.data = ema_updater.update_average(
            averaged_param.data, online_param.data
        )

class TargetNetwork(nn.Module):
    """Lazily created, EMA-updated copy of an online ``F`` / ``Z`` module pair.

    The target copies do not exist until the first ``forward`` call, at
    which point the online modules passed to that call are deep-copied and
    switched to evaluation mode. ``update_moving_average`` then moves the
    copies' parameters towards the online modules' parameters using an
    exponential moving average with decay ``moving_average_decay``.

    The copies are kept in evaluation mode even when ``train()`` is called
    on this module or on a parent that contains it.

    NOTE: the copies are not frozen here and ``forward`` does not disable
    autograd; wrap the call in ``torch.no_grad()`` if gradients through the
    target are not wanted.
    """

    def __init__(self, moving_average_decay = 0.99):
        """Create an empty target network.

        Args:
            moving_average_decay: decay factor handed to the EMA updater;
                the weight given to the existing target parameters.
        """
        super().__init__()

        # Populated on the first forward() call.
        self.target_F = None
        self.target_Z = None
        self.target_ema_updater = EMA(moving_average_decay)

    def _existing_targets(self):
        """Return the target copies that have been created so far."""
        candidates = (
            getattr(self, 'target_F', None),
            getattr(self, 'target_Z', None),
        )
        return [module for module in candidates if module is not None]

    def train(self, mode=True):
        """Set the training flag on this module, keeping targets in eval mode.

        ``nn.Module.train`` propagates ``mode`` to every registered
        submodule, which would silently undo the ``eval()`` applied when the
        target copies were created. Re-pin them to eval mode afterwards.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        for target in self._existing_targets():
            target.eval()
        return self

    def forward(self, x, online_F, online_Z, return_feature=True):
        """Run ``x`` through the target copies.

        Args:
            x: input forwarded to the target copy of ``online_F``.
            online_F: online module; deep-copied on the first call only.
            online_Z: online module; deep-copied on the first call only.
            return_feature: if true, return the output of the ``F`` copy
                and skip the ``Z`` copy.

        Returns:
            The output of the ``F`` copy, or the output of the ``Z`` copy
            applied to it when ``return_feature`` is false.
        """
        if self.target_F is None and self.target_Z is None:
            self.target_F = copy.deepcopy(online_F)
            self.target_Z = copy.deepcopy(online_Z)
            self.target_F.eval()
            self.target_Z.eval()

        x = self.target_F(x)
        if return_feature:
            return x

        # The meaning of the ``None`` positional argument and of
        # ``return_meta_dist`` is defined by online_Z's forward signature,
        # which is not visible in this file.
        return self.target_Z(x, None, return_meta_dist=True)

    def reset_moving_average(self):
        """Drop the target copies so the next ``forward`` re-creates them."""
        # Deleting first unregisters the copies as submodules; the
        # attributes are then restored as plain ``None`` placeholders.
        del self.target_F
        del self.target_Z
        self.target_F = None
        self.target_Z = None

    def update_moving_average(self, online_F, online_Z):
        """Move the target parameters towards the online parameters.

        Args:
            online_F: online module paired with the ``F`` target copy.
            online_Z: online module paired with the ``Z`` target copy.

        Raises:
            AssertionError: if the target copies have not been created yet
                (``forward`` has not run since construction or the last
                reset).
        """
        # Raised explicitly (not via ``assert``) so the check is still
        # performed when Python runs with -O.
        if self.target_F is None:
            raise AssertionError('Target F has not been created')
        if self.target_Z is None:
            raise AssertionError('Target Z has not been created')
        update_moving_average(self.target_ema_updater, self.target_F, online_F)
        update_moving_average(self.target_ema_updater, self.target_Z, online_Z)