import argparse
import logging
import sys
import os
from configparser import ConfigParser
import matplotlib.pyplot as plt
import torch
from torch import optim
import numpy as np

from utils.helpers import FormatterNoDuplicate, check_bounds, set_seed
from utils.visualize import Visualizer
from utils.viz_helpers import get_samples
from main import RES_DIR
from disvae.utils.modelIO import load_model, load_metadata



# Module-level setup: seed the RNGs, load a trained model plus its metadata
# from RES_DIR/<experiment_name>, and build a Visualizer and one fixed batch of
# samples that the plotting helpers below can share.
# NOTE: statement order matters — set_seed runs first so that the later
# get_samples call is reproducible (assumes set_seed seeds whatever RNG
# get_samples uses — TODO confirm).
set_seed(125)
experiment_name = 'debug'
model_dir = os.path.join(RES_DIR, experiment_name)
meta_data = load_metadata(model_dir)
model = load_model(model_dir)
model.eval()  # don't sample from latent: use mean
# Name of the dataset the model was trained on, read from its saved metadata.
dataset = meta_data['dataset']
viz = Visualizer(model=model,
                 model_dir=model_dir,
                 dataset=dataset,
                 max_traversal=2,
                 loss_of_interest='kl_loss_',
                 upsample_factor=1)
# Plot grid shape; num_samples is the number of cells it holds (6 * 7 = 42).
size = (6, 7)
# Draw one fixed batch of 200 samples up front so every plot uses the same
# data; individual plots are meant to take only the first few they need.
# NOTE(review): num_samples is not used in the visible code — confirm whether
# the 200 below was meant to be num_samples.
num_samples = size[0] * size[1]
samples = get_samples(dataset, 200, idcs=[])


def show_images_grid(imgs_, num_images=25):
    """Plot the first ``num_images`` images of ``imgs_`` on a square-ish grid.

    The grid has ``ceil(sqrt(n))`` rows and as many columns as needed to hold
    ``n`` images; each cell is 3x3 inches. Unused trailing cells are hidden.

    Args:
        imgs_: Indexable, sized sequence of images; each element is passed
            directly to ``Axes.imshow`` (presumably 2-D grayscale arrays —
            TODO confirm).
        num_images: Maximum number of images to draw. It is clamped to
            ``len(imgs_)``. Nothing is plotted when the result is <= 0.

    Returns:
        The created ``matplotlib.figure.Figure``, or ``None`` when there was
        nothing to draw.
    """
    # Never index past the supplied images (the original raised IndexError).
    num_images = min(num_images, len(imgs_))
    if num_images <= 0:
        # Avoid a ZeroDivisionError when computing the grid shape below.
        return None

    # Layout: ceil(sqrt(n)) rows, and as many columns as needed to fit n cells.
    n_rows = int(np.ceil(num_images ** 0.5))
    n_cols = int(np.ceil(num_images / n_rows))
    # squeeze=False always returns a 2-D array of Axes. Without it, a 1x1 grid
    # yields a bare Axes object and .flatten() would fail.
    fig, axes = plt.subplots(n_rows, n_cols,
                             figsize=(n_cols * 3, n_rows * 3),
                             squeeze=False)

    for ax_i, ax in enumerate(axes.flatten()):
        if ax_i < num_images:
            ax.imshow(imgs_[ax_i], cmap='Greys_r', interpolation='nearest')
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            # Hide the leftover cells at the end of the grid.
            ax.axis('off')
    return fig


def show_density(imgs):
    """Display the pixel-wise mean of a stack of images as a grayscale image.

    Args:
        imgs: Array-like object supporting ``.mean(axis=0)``. It is presumably
            shaped (N, H, W), so the mean is an (H, W) "density" image —
            TODO confirm.

    Returns:
        The created ``matplotlib.figure.Figure``.
    """
    fig, ax = plt.subplots()
    ax.imshow(imgs.mean(axis=0), interpolation='nearest', cmap='Greys_r')
    # Pass a real bool here. The string 'off' is truthy, and current matplotlib
    # no longer maps 'on'/'off' strings, so grid('off') would turn the grid ON.
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return fig

# n_per_latent = 7  # number of columns, i.e. the number of traversal steps per latent
#
# reconstructions = viz.reconstruct(samples,size=(2, n_per_latent),
#                                            is_force_return=True)
#
#
# n_latents = 10
# latent_samples = [viz._traverse_line(dim, n_per_latent, data=samples)
#                   for dim in range(viz.latent_dim)]
# decoded_traversal = viz._decode_latents(torch.cat(latent_samples, dim=0))
#
# viz.gif_traversals(samples[:6, ...], n_latents=7)
# viz.reconstruct_traverse(samples,
#                      is_posterior=True,
#                      n_latents=size[0],
#                      n_per_latent=7,
#                      is_show_text=True)
