import numpy as np 
import matplotlib.pyplot as plt
from scipy.stats import norm
import torch

def create_mixture(M, D):
    """Draw random parameters for a mixture with M components in D dimensions.

    Args:
        M: Number of mixture components.
        D: Dimension of each component.

    Returns:
        A tuple ``(means, covs, mixture_weights)``:

        * ``means`` -- ``(M, D)`` tensor of standard-normal draws.
        * ``covs`` -- ``(M, D)`` tensor of squared standard-normal draws, so
          every entry is non-negative.
        * ``mixture_weights`` -- the single Python float ``1 / M`` (one shared
          value, not a length-``M`` vector).
    """
    shape = (M, D)
    # Means are drawn before the spread values; keeping this order keeps
    # results reproducible under a fixed torch seed.
    component_means = torch.normal(0, 1, size=shape)
    # NOTE(review): these are elementwise squares of N(0, 1) draws; whether
    # callers treat them as variances or as standard deviations should be
    # confirmed at the call sites.
    component_covs = torch.normal(0, 1, size=shape).pow(2)
    uniform_weight = 1 / M
    return component_means, component_covs, uniform_weight

def sample_mixture(mus, sigs, weights, numb):
    """Draw ``numb`` samples from a mixture of independent-per-dimension normals.

    Args:
        mus: Per-component means, indexable as ``mus[i]``; ``len(mus[0])``
            is used as the sample dimension.
        sigs: Per-component spread, indexable as ``sigs[i]``. Each entry is
            passed to ``Normal`` as its ``scale`` argument, i.e. it is used
            as a standard deviation.
        weights: Mixing weights, one per component (list, array or tensor).
            A scalar (Python number or 0-d tensor) is also accepted and is
            applied to each of the ``len(mus)`` components. ``Categorical``
            normalises the weights, so they need not sum to one.
        numb: Number of samples to draw.

    Returns:
        A tuple ``(samp, comps)``: ``samp`` is a ``(numb, D)`` tensor of
        samples and ``comps`` is a ``(numb,)`` integer tensor giving the
        component index each sample was drawn from.
    """
    # as_tensor (unlike torch.tensor) does not emit the copy-construct
    # warning when ``weights`` is already a tensor; detach() mirrors the
    # graph-free copy that torch.tensor() produced.
    probs = torch.as_tensor(weights).detach()
    if probs.dim() == 0:
        # A single shared weight means a uniform mixture over all components.
        probs = probs.expand(len(mus))

    cat = torch.distributions.categorical.Categorical(probs)
    comps = cat.sample([numb])

    samp = torch.zeros((numb, len(mus[0])))
    for i in range(len(probs)):
        mask = comps == i
        # Draw exactly as many points as were assigned to component i and
        # scatter them into the matching rows.
        component = torch.distributions.normal.Normal(mus[i], sigs[i])
        samp[mask, :] = component.sample([int(mask.sum())])
    return samp, comps

def mixture_pdf(samp, mus, sigs, weights):
    """Evaluate the mixture density at each row of ``samp``.

    Each component is a product of independent one-dimensional normals with
    means ``mus[j, :]`` and scales (standard deviations) ``sigs[j, :]``.

    Args:
        samp: ``(N, D)`` tensor of points at which to evaluate the density.
        mus: Tensor of component means, indexed as ``mus[j, :]``.
        sigs: Tensor of component scales, indexed as ``sigs[j, :]`` and
            passed to ``Normal`` as its ``scale`` argument.
        weights: Mixing weights, one per component. A scalar (Python number
            or 0-d tensor) is also accepted and is applied to each of the
            ``len(mus)`` components. The weights are used as given; they
            are not normalised here.

    Returns:
        A length-``N`` tensor with the weighted sum of component densities
        for each row of ``samp``.
    """
    if torch.as_tensor(weights).dim() == 0:
        # A single shared weight: repeat it for every component.
        weights = [weights] * len(mus)

    prob = torch.zeros(samp.shape[0])
    for j, weight in enumerate(weights):
        component = torch.distributions.normal.Normal(mus[j, :], sigs[j, :])
        # Summing per-dimension log-densities gives the joint log-density
        # of the independent coordinates; exp() turns it back into a density.
        prob += weight * torch.exp(component.log_prob(samp).sum(1))
    return prob
