import numpy as np
from distributions.GMM import GMM
import os


def save_model(model, path: str, filename: str):
    """Persist ``model`` to disk using the saver that matches its type.

    Only ``GMM`` instances are currently supported; they are delegated to
    ``save_gmm`` with the same ``path`` and ``filename``.

    Raises:
        NotImplementedError: if ``model`` is not a ``GMM``.
    """
    # Guard clause: reject unsupported model types up front.
    if not isinstance(model, GMM):
        raise NotImplementedError("Saving not implemented for " + str(model.__class__))
    save_gmm(model, path, filename)


def save_gmm(model: GMM, path: str, filename: str):
    """Write a GMM's parameters to ``<path>/<filename>.npz`` (compressed).

    The archive holds three arrays:
      - ``weights``: ``model.weight_distribution.p`` as-is (presumably the
        mixture weights; confirm against the GMM class).
      - ``means``: every component's ``mean`` stacked along a new leading
        axis, so index 0 selects the component.
      - ``covars``: every component's ``covar`` stacked the same way.

    ``np.stack`` requires all components' means (and all covariances) to
    share one shape.
    """
    # Stack along a new axis 0 so each row/slab corresponds to one component.
    stacked_means = np.stack([comp.mean for comp in model.components], axis=0)
    stacked_covars = np.stack([comp.covar for comp in model.components], axis=0)
    np.savez_compressed(
        os.path.join(path, filename + ".npz"),
        weights=model.weight_distribution.p,
        means=stacked_means,
        covars=stacked_covars,
    )

def load_cpp_gmm(path: str, filename: str):
    """Load a GMM from ``<path>/<filename>.npz``.

    Expects the archive layout written by ``save_gmm``: arrays named
    ``weights``, ``means`` and ``covars``. They are passed positionally to
    ``GMM(weights, means, covars)``. This assumes the GMM constructor takes
    them in that order; confirm against the GMM class.

    Raises:
        FileNotFoundError: if the ``.npz`` file does not exist.
        KeyError: if one of the expected arrays is missing from the archive.
    """
    model_path = os.path.join(path, filename + ".npz")
    # np.load on an .npz returns an NpzFile that keeps the file handle open.
    # Use it as a context manager so the handle is released once the arrays
    # have been read, instead of leaking until garbage collection.
    # Indexing materialises each array in memory, so it stays valid after close.
    with np.load(model_path) as archive:
        weights = archive["weights"]
        means = archive["means"]
        covars = archive["covars"]
    return GMM(weights, means, covars)