import numpy as np
import torch
from torch import nn

# --------------------------- Init Methods ---------------------------

def l2_init(model, scale=1e-2):
    """Perturb every parameter of ``model`` in place with Gaussian noise.

    Each parameter receives independent standard-normal noise multiplied
    by ``scale``. The update runs under ``torch.no_grad()`` so it is not
    recorded by autograd.

    Args:
        model: Module whose ``parameters()`` are perturbed.
        scale: Multiplier applied to the standard-normal noise.
    """
    with torch.no_grad():
        for param in model.parameters():
            perturbation = torch.randn_like(param) * scale
            param += perturbation

def l2_init_norm_control(model, eps=1e-2):
    """Perturb all parameters in place with noise of total L2 norm ``eps``.

    A single standard-normal vector covering every parameter element is
    drawn, rescaled so that its overall L2 norm equals ``eps``, and then
    added chunk by chunk to the parameters in ``model.parameters()`` order.

    Args:
        model: Module whose ``parameters()`` are perturbed.
        eps: L2 norm of the perturbation taken over all parameters jointly.
    """
    with torch.no_grad():
        params = list(model.parameters())
        if not params:
            # torch.cat([]) raises; with nothing to perturb this is a no-op.
            return

        # reshape(-1) instead of view(-1): view raises on non-contiguous
        # parameters (e.g. transposed or channels-last weights).
        all_params = torch.cat([p.reshape(-1) for p in params])
        noise = torch.randn_like(all_params)
        noise = noise / noise.norm() * eps

        # Carve the flat noise vector into one chunk per parameter.
        chunks = noise.split([p.numel() for p in params])
        for p, chunk in zip(params, chunks):
            p.add_(chunk.view_as(p))

def reset_params_recursive(model: nn.Module):
    """Re-initialise supported layers found anywhere inside ``model``.

    Walks ``model.modules()`` and calls ``reset_parameters()`` on every
    ``nn.Linear``, ``nn.LayerNorm``, ``nn.GRUCell`` and ``nn.GRU``
    instance. Modules of any other type are left untouched.

    Args:
        model: Root module to traverse.
    """
    resettable_types = (nn.Linear, nn.LayerNorm, nn.GRUCell, nn.GRU)
    for submodule in model.modules():
        if not isinstance(submodule, resettable_types):
            continue
        submodule.reset_parameters()

def copy_params_recursive(model_targ: nn.Module, model_src: nn.Module):
    """Copy parameter values from ``model_src`` into ``model_targ`` in place.

    Parameters are paired positionally in ``parameters()`` order, so both
    modules are expected to have the same structure. Pairing stops at the
    shorter of the two parameter sequences. The copy goes through
    ``.data``, so it is not tracked by autograd.

    Args:
        model_targ: Module whose parameters are overwritten.
        model_src: Module whose parameter values are read.
    """
    source_params = model_src.parameters()
    target_params = model_targ.parameters()
    for source, target in zip(source_params, target_params):
        target.data.copy_(source.data)

# --------------------------- NCE related ----------------------------------
def direct_gaussian(x, scale=0.1):
    """Return ``x`` plus standard-normal noise multiplied by ``scale``.

    The input tensor is not modified; a new tensor is returned.

    Args:
        x: Input tensor; the noise matches it via ``torch.randn_like``.
        scale: Multiplier applied to the noise.
    """
    noise = torch.randn_like(x)
    scaled_noise = noise * scale
    return x + scaled_noise

# --------------------------- Sampling related ----------------------------------
def unif_wo_replace_th(shape, device=None):
    """Draw random index orderings along the last dimension (torch).

    Uniform random scores of the given ``shape`` are generated and
    argsorted over the last dimension, so each slice along that dimension
    is an ordering of ``0 .. shape[-1] - 1``.

    Args:
        shape: Shape passed to ``torch.rand``.
        device: Optional device on which the scores are created.

    Returns:
        Integer tensor of the same shape holding the argsort indices.
    """
    scores = torch.rand(shape, device=device)
    return torch.argsort(scores, dim=-1)

def unif_wo_replace_np(shape):
    """Draw random index orderings along the last axis (NumPy).

    Uniform random scores of the given ``shape`` are generated from
    NumPy's global random state and argsorted over the last axis, so each
    slice along that axis is an ordering of ``0 .. shape[-1] - 1``.

    Args:
        shape: An int or a sequence of ints giving the output shape.

    Returns:
        Integer array of the requested shape holding the argsort indices.
    """
    # random_sample(size) draws from the same stream as rand(*shape) but
    # also accepts a bare int, matching what unif_wo_replace_th allows.
    random_scores = np.random.random_sample(shape)
    return random_scores.argsort(axis=-1)

# --------------------------- Reg Losses ----------------------------------
def get_ewc_fisher(loss, params):
    """Return the squared gradient of ``loss`` for each tensor in ``params``.

    Runs ``loss.backward()`` and squares the resulting gradient of every
    parameter; the function name suggests the result is used as a
    diagonal Fisher-style importance weight for EWC. Parameters that
    receive no gradient from ``loss`` get an all-zero entry.

    Gradients already stored in ``.grad`` before the call are set aside
    during the backward pass so they cannot leak into the squared values.
    Afterwards the fresh gradient is added onto them, so each ``.grad``
    ends up as it would after a plain ``loss.backward()``.

    Args:
        loss: Scalar tensor to back-propagate.
        params: Iterable of tensors whose gradients are squared.

    Returns:
        List of tensors, one per entry of ``params``, each shaped like its
        parameter.
    """
    params = list(params)

    # Stash pre-existing gradients: backward() accumulates into .grad, and
    # squaring an accumulated gradient would mix in unrelated losses.
    stale = [p.grad for p in params]
    for p, old in zip(params, stale):
        if old is not None:
            p.grad = None

    loss.backward()

    fisher = [
        torch.zeros_like(p) if p.grad is None else p.grad.detach().pow(2)
        for p in params
    ]

    # Fold the fresh gradient back into the stashed one so the caller sees
    # the same accumulated .grad a plain backward() would have produced.
    with torch.no_grad():
        for p, old in zip(params, stale):
            if old is None or p.grad is old:
                continue
            if p.grad is not None:
                old.add_(p.grad)
            p.grad = old

    return fisher
    
def get_l2_loss(params, rec, fisher=None):
    """Return the (optionally weighted) squared distance between two parameter sets.

    Computes ``sum_i sum(f_i * (p_i - rec_i) ** 2)`` over the positionally
    paired entries of ``params``, ``rec`` and ``fisher``.

    Args:
        params: Iterable of current parameter tensors (a list or a
            generator such as ``model.parameters()``).
        rec: Iterable of reference tensors paired with ``params``.
        fisher: Optional iterable of per-parameter weights, broadcast
            against each squared difference. When ``None`` every element
            is weighted by 1.

    Returns:
        Zero-dimensional tensor holding the summed loss. It is ``0.`` on
        the default device when ``params`` is empty.
    """
    # Materialise once: indexing and len() below do not work on generators.
    params = list(params)
    if not params:
        return torch.tensor(0.)

    if fisher is None:
        fisher = [1] * len(params)

    l2_loss = torch.tensor(0., device=params[0].device)
    for p, p_rec, f in zip(params, rec, fisher):
        # Out-of-place add so a float64 term promotes the accumulator
        # instead of being silently downcast to float32.
        l2_loss = l2_loss + (f * (p - p_rec).pow(2)).sum()
    return l2_loss