import torch
import json
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily

class MixtureModel:
    """Gaussian mixture model whose parameters are read from a JSON config file.

    The file ``<config_dir>/mixturemodel.json`` must contain the keys
    ``"means"``, ``"covs"`` and ``"weights"``. They are converted to tensors and
    combined into a ``MixtureSameFamily`` of ``MultivariateNormal`` components,
    mixed by a ``Categorical`` distribution over ``weights``.
    """

    def __init__(self, config_dir="../../configs", device="cpu"):
        """Load the mixture parameters and build the distribution.

        Args:
            config_dir: Directory containing ``mixturemodel.json``. A relative
                path is resolved against the current working directory.
            device: Torch device on which all parameter tensors are created.

        Raises:
            OSError: If the config file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If ``"means"``, ``"covs"`` or ``"weights"`` is missing.
        """
        self.device = device
        # Read and parse the config once rather than once per key.
        config = self._load_config(f"{config_dir}/mixturemodel.json")

        # Force a floating dtype: all-integer JSON arrays (e.g. an identity
        # covariance written as [[1, 0], [0, 1]]) would otherwise become int64
        # tensors, which MultivariateNormal's Cholesky factorisation rejects.
        dtype = torch.get_default_dtype()
        # NOTE(review): presumably means is (K, d), covs is (K, d, d) and
        # weights is (K,) for K components in d dimensions -- shapes are not
        # validated here; MixtureSameFamily will raise if they are inconsistent.
        self.means = torch.tensor(config['means'], dtype=dtype, device=self.device)
        self.covs = torch.tensor(config['covs'], dtype=dtype, device=self.device)
        self.mixture_weights = torch.tensor(
            config['weights'], dtype=dtype, device=self.device
        )

        # Categorical's first positional argument is ``probs``; it normalises
        # them, so the configured weights need not sum to one.
        self.dist = MixtureSameFamily(
            Categorical(self.mixture_weights),
            MultivariateNormal(self.means, self.covs),
        )

    def _load_config(self, filepath):
        """Read ``filepath`` and return its parsed JSON content."""
        # JSON is UTF-8 by specification; don't depend on the locale encoding.
        with open(filepath, 'r', encoding='utf-8') as file:
            return json.load(file)

    def sample(self, num_samples):
        """Draw ``num_samples`` independent samples from the mixture.

        Returns:
            A tensor whose leading dimension is ``num_samples``, followed by the
            distribution's batch and event dimensions.
        """
        return self.dist.sample((num_samples,))

    def log_prob(self, samples):
        """Return the mixture's log-density evaluated at ``samples``.

        Tensors are passed to the distribution unchanged. Any other array-like
        input (e.g. nested lists) is first converted to a tensor with the
        parameters' dtype on ``self.device``.
        """
        if not torch.is_tensor(samples):
            samples = torch.as_tensor(
                samples, dtype=self.means.dtype, device=self.device
            )
        return self.dist.log_prob(samples)