from distributions.GMM import GMM
from optimization.VonVIPS import *
import numpy as np
import os
import tensorflow_probability as tfp
from recording.Recorder import RecorderKeys as rec_keys
from recording.Recorder import Recorder
from recording.modules.UpdateModules import WeightUpdateRecMod, ComponentUpdateRecMod
tfd = tfp.distributions

def construct_initial_mixture(num_dimensions, num_initial_components, prior_scale):
    """Build the initial Gaussian mixture model used to start VIPS.

    Every component gets weight ``1 / num_initial_components`` and the
    covariance of a zero-mean diagonal Gaussian prior. With a single component
    its mean is the origin; otherwise each mean is drawn from that prior.

    Args:
        num_dimensions: Dimensionality of the mixture.
        num_initial_components: Number of components; must be at least 1.
        prior_scale: Scalar, or per-dimension array, used as the diagonal
            scale (standard deviation) of the prior.

    Returns:
        A ``GMM`` built from the weights, means and covariances.

    Raises:
        ValueError: If ``num_initial_components`` is smaller than 1.
    """
    if num_initial_components < 1:
        raise ValueError(
            f"num_initial_components must be >= 1, got {num_initial_components}")

    # A scalar scale is expanded to an explicit diagonal: the
    # `scale_identity_multiplier` argument has been removed from newer TFP
    # releases, and a constant `scale_diag` describes the same distribution.
    if np.isscalar(prior_scale):
        scale_diag = np.full(num_dimensions, prior_scale, dtype=np.float64)
    else:
        scale_diag = prior_scale
    prior = tfd.MultivariateNormalDiag(loc=np.zeros(num_dimensions), scale_diag=scale_diag)

    # Every component starts with the covariance the means are sampled from.
    # The float32 cast (later widened to float64) is kept from the original code.
    initial_cov = prior.covariance().numpy().astype(np.float32)

    weights = np.ones(num_initial_components) / num_initial_components
    means = np.zeros((num_initial_components, num_dimensions), dtype=np.float64)
    covs = np.repeat(initial_cov[np.newaxis].astype(np.float64),
                     num_initial_components, axis=0)

    if num_initial_components > 1:
        # One sample(1) call per component keeps the original sampling order.
        for i in range(num_initial_components):
            means[i] = prior.sample(1).numpy()[0]

    return GMM(weights, means, covs)


_COMPONENT_OPTIMIZERS = ("MORE", "gMORE", "gMORE_NFO", "VON", "VOGN", "GM")


def _build_vips(component_optimizer, vips_kwargs):
    """Construct the VIPS learner selected by ``component_optimizer``.

    ``vips_kwargs`` are the constructor keyword arguments shared by every
    learner (config, savepath, target, initial mixture, samples, recorder).

    Raises:
        ValueError: If ``component_optimizer`` is not a supported name.
    """
    if component_optimizer in ("MORE", "gMORE", "gMORE_NFO"):
        # "MORE" is built with withgrad=False; the g-variants with withgrad=True.
        vips = MoreVIPS(withgrad=component_optimizer != "MORE", **vips_kwargs)
        if component_optimizer == "gMORE_NFO":
            vips._gmm_learner._no_first_order = True
        return vips
    if component_optimizer == "VON":
        return VonVIPS(withgrad=True, **vips_kwargs)
    if component_optimizer in ("VOGN", "GM"):
        # Both use VognVIPS; only "VOGN" sets exploit_Bayesian.
        return VognVIPS(exploit_Bayesian=component_optimizer == "VOGN",
                        withgrad=True, **vips_kwargs)
    raise ValueError(
        f"Unknown component_optimizer {component_optimizer!r}; "
        f"expected one of {_COMPONENT_OPTIMIZERS}")


def learn(target_dist_maker, model, config, groundtruth,
          path, do_plots=False):
    """Run VIPS on a target distribution with the configured component optimizer.

    Args:
        target_dist_maker: Zero-argument callable returning the target
            distribution.
        model: Initial mixture; its ``weights``, ``means`` and ``covars``
            must provide ``.numpy()``.
        config: Experiment configuration. ``config.component_optimizer``
            selects the learner (one of ``_COMPONENT_OPTIMIZERS``); the whole
            object is forwarded to the learner.
        groundtruth: Reference samples forwarded as ``target_sample``. If None
            or empty, 10000 samples are drawn from the target instead.
        path: Output directory (created if missing), forwarded as ``savepath``.
        do_plots: If True, switch matplotlib to interactive mode.

    Returns:
        The trained VIPS learner.

    Raises:
        ValueError: If ``config.component_optimizer`` is not supported.
    """
    optimizer = config.component_optimizer
    # Fail fast, before creating directories or sampling the target.
    if optimizer not in _COMPONENT_OPTIMIZERS:
        raise ValueError(
            f"Unknown component_optimizer {optimizer!r}; "
            f"expected one of {_COMPONENT_OPTIMIZERS}")

    if do_plots:
        import matplotlib.pyplot as plt
        plt.ion()

    # Files are written *into* `path`, so that directory itself must exist
    # (creating only its parent fails for bare names like "out").
    os.makedirs(path, exist_ok=True)
    target_distribution = target_dist_maker()

    # Save the target's mixture parameters when it exposes them.
    target_means = getattr(target_distribution, "target_means", None)
    target_covars = getattr(target_distribution, "target_covars", None)
    if target_means is not None and target_covars is not None:
        try:
            np.savez(os.path.join(path, "target_mixture.npz"),
                     target_mean=target_means, target_cov=target_covars)
        except OSError as err:
            # Best effort: a failed save does not abort training.
            print(f"Warning: could not save target mixture: {err}")

    if groundtruth is None or len(groundtruth) == 0:
        groundtruth = target_distribution.sample(10000).numpy()

    # Recording
    recorder_dict = {
        rec_keys.WEIGHTS_UPDATE: WeightUpdateRecMod(plot=False),
        rec_keys.COMPONENT_UPDATE: ComponentUpdateRecMod(plot=False, summarize=True),
    }
    # Real-time plotting is currently always enabled, independent of `do_plots`.
    recorder = Recorder(recorder_dict, plot_realtime=True, save=False)

    vips_kwargs = dict(
        config=config,
        savepath=path,
        target_distribution=target_distribution,
        initial_w=model.weights.numpy(),
        initial_m=model.means.numpy(),
        initial_c=model.covars.numpy(),
        target_sample=groundtruth,
        recorder=recorder,
    )
    vips = _build_vips(optimizer, vips_kwargs)

    vips.train()
    print("done")
    return vips

