import numpy as np
import torch
from typing import Dict, List
from collections import OrderedDict
from hyper_params import CHANNEL

# Module-level alias for ``hyper_params.CHANNEL`` (presumably the number of
# input channels — confirm against hyper_params).
NUM_CHANNELS = CHANNEL


def get_parameters(net: torch.nn.Module) -> List[np.ndarray]:
    """Return a snapshot of every ``net.state_dict()`` entry as a NumPy array.

    The arrays follow ``state_dict()`` order (parameters and buffers), which
    is the positional order ``set_parameters`` consumes.

    Returns:
        One host-memory ``np.ndarray`` per state-dict entry. Each array is an
        independent copy, so later in-place updates to ``net`` (e.g. optimizer
        steps) do not change the returned values.
    """
    # ``Tensor.numpy()`` on a CPU tensor shares storage with it, while
    # ``.cpu()`` on a GPU tensor already copies. Forcing a copy makes the
    # result a true snapshot regardless of the model's device, with a single
    # copy in both cases.
    return [
        tensor.to("cpu", copy=True).numpy()
        for tensor in net.state_dict().values()
    ]

def set_parameters(net, parameters: List[np.ndarray]):
    """Load ``parameters`` into ``net`` in place, matched by position.

    Args:
        net: module whose ``state_dict()`` entries are overwritten.
        parameters: one array per ``net.state_dict()`` entry, in the same
            order (e.g. the output of ``get_parameters``).

    Notes:
        ``zip`` stops at the shorter sequence and ``load_state_dict`` is
        called with ``strict=False``. A shorter ``parameters`` list therefore
        loads only the leading entries and leaves the rest unchanged, and
        surplus arrays are ignored. A shape mismatch on any loaded entry
        still raises ``RuntimeError`` from ``load_state_dict``.
    """
    params_dict = zip(net.state_dict().keys(), parameters)
    # Use ``torch.tensor`` (not the legacy ``torch.Tensor`` constructor) to
    # keep each array's dtype and shape. ``torch.Tensor`` casts everything to
    # float32, which loses float64 precision and large int64 counters. It can
    # also misread a 0-d integer array (e.g. BatchNorm's
    # ``num_batches_tracked``) as a size argument.
    state_dict = OrderedDict((k, torch.tensor(v)) for k, v in params_dict)
    net.load_state_dict(state_dict, strict=False)

def generate_random_protos_gaussian(mean: Dict, covariance: Dict, num_samples_per_class,
                                    *, seed=1, method="cholesky") -> Dict:
    """Draw synthetic prototypes from one multivariate Gaussian per key.

    Args:
        mean: maps each key (presumably a class label — confirm with callers)
            to a 1-D tensor holding that Gaussian's mean.
        covariance: maps the same keys to the matching square covariance
            tensor. It must contain every key of ``mean``.
        num_samples_per_class: number of samples drawn for each key.
        seed: seed (or an existing ``np.random.Generator``) handed to
            ``np.random.default_rng``. The default ``1`` makes repeated calls
            return identical samples. Pass another value, or ``None``, to get
            different draws.
        method: factorization used by ``Generator.multivariate_normal``.
            ``"cholesky"`` (default) is fastest but rejects covariances that
            are not positive-definite. ``"eigh"`` or ``"svd"`` also handle
            singular (positive semi-definite) ones.

    Returns:
        dict mapping each key of ``mean`` to a list of
        ``num_samples_per_class`` float32 tensors, one per sample.
    """
    rng = np.random.default_rng(seed=seed)
    fake_protos = {}
    # Keys are visited in ``mean`` order, so for a given seed the sequence of
    # draws (and hence the output) is reproducible.
    for k, mu in mean.items():
        m = mu.detach().cpu().numpy()
        cov = covariance[k].detach().cpu().numpy()
        output = rng.multivariate_normal(mean=m, cov=cov, size=num_samples_per_class,
                                         check_valid='warn', method=method)
        # numpy samples are float64; convert each row to a float32 tensor.
        fake_protos[k] = [torch.from_numpy(row).float() for row in output]
    return fake_protos

def get_aw_gradients(aw):
    """Return the sum of per-tensor L2 norms of the gradients held in ``aw``.

    Args:
        aw: mapping whose values are tensors (presumably learnable weights —
            confirm with callers) that may carry a ``.grad`` after
            ``backward()``.

    Returns:
        ``sum(torch.norm(t.grad))`` over the entries that have a gradient.
        This is the sum of the individual norms, not the norm of all
        gradients concatenated. Entries whose ``.grad`` is ``None`` (never
        reached by backward) are skipped, as ``torch.nn.utils.clip_grad_norm_``
        does, instead of raising inside ``torch.norm``. Returns the int ``0``
        when no entry has a gradient.
    """
    wp = 0
    for tensor in aw.values():
        if tensor.grad is None:
            continue
        wp += torch.norm(tensor.grad)
    return wp

def copy_proto_dict(protodict:Dict):
    """Deep-copy a dict of tensor lists into fresh, gradient-free tensors.

    Every key of ``protodict`` is kept, in the same order. Each tensor in the
    associated list is detached from the autograd graph and cloned, so the
    copy owns its own storage and has ``requires_grad`` set to False.
    Mutating the result never touches the input.
    """
    with torch.no_grad():
        copied = {
            label: [proto.detach().clone().requires_grad_(False) for proto in protos]
            for label, protos in protodict.items()
        }
    return copied