#%%
import os
import torch
import numpy as np
from omegaconf import OmegaConf
import hydra
from hydra.utils import instantiate

from models.gmm import construct_n_star, GScore, GMScore, GaussianMixture, true_posterior, EDMPrecond, \
    construct_grid
from utils.plots import plot_samples_joint
from inverse_problems.linear import LinearOperator


def _uses_per_sample_inference(target):
    """Return True if the algorithm should be run one sample at a time.

    Args:
        target: the Hydra ``_target_`` string of the algorithm config.

    Returns:
        True when ``target`` contains ``'dpg'`` or ``'scg'``.
    """
    return any(tag in target for tag in ('dpg', 'scg'))


def _draw_algorithm_samples(algo, observation, num_samples, one_at_a_time):
    """Run ``algo.inference`` on ``observation`` and return detached samples.

    Args:
        algo: instantiated algorithm exposing
            ``inference(observation, num_samples=...)``.
        observation: measurement tensor, already on the target device/dtype.
        num_samples: total number of samples to draw.
        one_at_a_time: if True, call ``inference`` ``num_samples`` times with
            ``num_samples=1`` and concatenate the results along dim 0;
            otherwise make a single batched call.

    Returns:
        The detached samples tensor.
    """
    if one_at_a_time:
        # NOTE(review): presumably these samplers only support a batch of
        # one sample per call — confirm.
        draws = [
            algo.inference(observation, num_samples=1).detach()
            for _ in range(num_samples)
        ]
        return torch.cat(draws, dim=0)
    return algo.inference(observation, num_samples=num_samples).detach()


@hydra.main(version_base="1.3", config_path="configs", config_name="config")
def main(config):
    """Compare a posterior-sampling algorithm with the exact posterior of a 2-D GMM.

    Builds a Gaussian-mixture prior, observes it through a fixed 1x2 linear
    operator with Gaussian noise, computes the ground-truth posterior, runs
    the configured algorithm on the same observation and plots both.

    Config fields read: ``seed``, ``exp_name``, ``num_samples``,
    ``algorithm.name`` and ``algorithm.method`` (a Hydra-instantiable node
    with a ``_target_``).

    Outputs under ``exps/gmm/<algorithm.name>/<exp_name>/``:
    ``config.yaml``, ``figs/prior_posterior_samples.png``,
    ``figs/algorithm_samples.png`` and ``figs/results.pt``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

    # ------ experiment directories -----
    exp_dir = os.path.join('exps', 'gmm', config.algorithm.name, config.exp_name)
    fig_dir = os.path.join(exp_dir, 'figs')
    os.makedirs(fig_dir, exist_ok=True)  # also creates exp_dir
    OmegaConf.save(config, os.path.join(exp_dir, 'config.yaml'))

    # ------ construct prior -----
    cov_scale = 2.0
    # Reference draws from the analytic prior/posterior; independent of
    # config.num_samples, which controls the algorithm's sample count.
    num_ref_samples = 1024
    num_modes = 4
    # Alternative mean layout: construct_n_star(num_modes, scale=11.3137)
    means = construct_grid(num_modes)
    covs = [cov_scale * np.eye(2) for _ in range(means.shape[0])]
    # NOTE(review): the prior receives the scalar cov_scale while
    # true_posterior receives the explicit list `covs`; presumably both mean
    # cov_scale * I for every mode — confirm they stay in sync.
    prior_dist = GaussianMixture(mean=means, cov=cov_scale)  # default: uniform prior
    prior_samples = prior_dist.generate(num_ref_samples).cpu().numpy()

    score_model = GMScore(torch.from_numpy(means).to(device), cov_scale)  # uniform prior by default
    model = EDMPrecond(score_model)

    # ------ construct forward model -----
    H = np.array([[-0.1117, -1.1456]])  # 1 x 2: one linear measurement of a 2-D point
    noise_std = 1.5
    y = np.array([-15.5626])  # (1,)
    forward_op = LinearOperator(H, device=device, sigma_noise=noise_std)

    # ------ ground-truth posterior -----
    posterior_dist = true_posterior(y, H, prior_dist.prior, means, covs, noise_std)
    gt_samples = posterior_dist.generate(num_ref_samples).cpu().numpy()
    plot_samples_joint(
        {'prior': prior_samples, 'posterior': gt_samples},
        save_path=os.path.join(fig_dir, 'prior_posterior_samples.png'),
    )

    # ------ run the algorithm -----
    algo = instantiate(config.algorithm.method, forward_op=forward_op, net=model)
    # Move and cast once, rather than on every inference call.
    observation = torch.from_numpy(y).to(device=device, dtype=forward_op.dtype)
    algo_samples = _draw_algorithm_samples(
        algo,
        observation,
        config.num_samples,
        one_at_a_time=_uses_per_sample_inference(config.algorithm.method._target_),
    )
    results = algo_samples.cpu().numpy()

    sample_dict = {
        'posterior': gt_samples,
        f'{config.algorithm.name}': results,
    }
    plot_samples_joint(sample_dict, save_path=os.path.join(fig_dir, 'algorithm_samples.png'))
    # NOTE(review): the saved values are numpy arrays. Since PyTorch 2.6,
    # torch.load defaults to weights_only=True and may reject them; load this
    # trusted local file with weights_only=False (or store tensors instead).
    torch.save(sample_dict, os.path.join(fig_dir, 'results.pt'))


if __name__ == '__main__':
    main()