import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torchvision.utils import save_image

from code.composed_model import create_vi_model
from code.exp_utils import compute_variance, load_base_model, ForwardOp
from code.utils import print_model_info
from code.vi import train_cond_model


def _plot_variances(steps_to_variances, temps, out_path):
    """Save a grid of variance heatmaps: one row per step, one column per temperature.

    Args:
        steps_to_variances: Nested dict ``{step: {temp: variance}}`` where each
            ``variance`` is the tensor produced by ``compute_variance``.
        temps: Temperatures to plot; columns follow ``sorted(temps)`` and every
            temperature must be a key of each inner dict.
        out_path: File path the figure is written to.
    """
    steps = sorted(steps_to_variances.keys())
    temps = sorted(temps)
    if not steps or not temps:
        # plt.subplots rejects a zero-sized grid; there is nothing to draw.
        return
    n_steps, n_temps = len(steps), len(temps)
    # squeeze=False keeps `axes` 2-D even for a single step or a single
    # temperature; the default would return a 1-D array (or a bare Axes) and
    # break the [row, col] indexing below.
    fig, axes = plt.subplots(n_steps, n_temps, figsize=(4 * n_temps, 4 * n_steps), squeeze=False)
    try:
        for idx_step, step in enumerate(steps):
            for idx_temp, temp in enumerate(temps):
                ax = axes[idx_step, idx_temp]
                v = steps_to_variances[step][temp]
                # detach()/cpu() so .numpy() also works for GPU or grad-tracking tensors.
                sns.heatmap(v.squeeze().detach().cpu().numpy(), vmin=0.0, vmax=1.0, cbar=False, ax=ax)
                ax.set(xlabel=None, ylabel=None, xticks=[], yticks=[], title=f'Step: {step}, Temp: {temp}', aspect=1.0)
        fig.tight_layout(pad=2.5)
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if plotting fails, so figures do not
        # accumulate across digits.
        plt.close(fig)


def run_for_digit(args, digit, base_model, forward_op):
    """Train a conditional model for one digit and dump its samples and statistics.

    Results go to ``<args.root_dir>/digit_<digit>/``. If that directory already
    exists and is non-empty, the digit is skipped, so an interrupted sweep can
    be resumed without overwriting finished digits.

    Files written:
        samples_step=<step>_temp=<temp>.png  sample grid per (step, temperature)
        samples.pt                           raw ``steps_to_samples`` dict
        variances.pt                         ``{step: {temp: compute_variance(samples)}}``
        variances.png                        heatmap grid of those variances
        stats.pt                             ``stats`` returned by ``train_cond_model``
        cond_model_state_dict.pt             trained conditional model weights

    Args:
        args: Parsed CLI namespace (reads root_dir, cond_model, seed, device,
            sigma, batch_size, learning_rate, num_steps, temps, ema, sample_steps).
        digit: Digit label; used for the output directory name and log messages.
        base_model: Pretrained base model passed through to ``train_cond_model``.
        forward_op: Forward operator for this digit; must have ``type_ == 'module'``.
    """
    cur_dir = os.path.join(args.root_dir, f'digit_{digit}')
    if os.path.exists(cur_dir) and len(os.listdir(cur_dir)) > 0:
        print(f'Skipping digit {digit} -- nonempty directory: {cur_dir}')
        return
    os.makedirs(cur_dir, exist_ok=True)

    # Create model and train
    cond_model = create_vi_model('mnist', args.cond_model, seed=args.seed).to(args.device)
    print_model_info(cond_model)

    assert forward_op.type_ == 'module'
    # Single-element tensor of 1.0 passed as ``x_star_obs``; its meaning is
    # defined by train_cond_model / forward_op (presumably the target output of
    # the module operator -- TODO confirm).
    x_star_obs = torch.ones(1, dtype=torch.float32, device=args.device)
    cond_model, steps_to_samples, stats = train_cond_model(
        base_model=base_model, cond_model=cond_model, x_star=None,
        forward_op=forward_op, device=args.device, sigma=args.sigma,
        batch_size=args.batch_size, learning_rate=args.learning_rate,
        num_steps=args.num_steps, temps=args.temps, use_ema=args.ema,
        sample_steps=args.sample_steps, x_star_obs=x_star_obs,
    )

    # Dump sample grids and per-(step, temp) variances.
    steps_to_variances = {}
    for step, temps_to_samples in steps_to_samples.items():
        steps_to_variances[step] = {}
        for temp, samples in temps_to_samples.items():
            # No `range=` kwarg: it only applies with normalize=True (not used
            # here), and newer torchvision releases reject it in favour of
            # `value_range`.
            save_image(samples,
                       os.path.join(cur_dir, f'samples_step={step}_temp={temp}.png'),
                       nrow=8, pad_value=1)
            steps_to_variances[step][temp] = compute_variance(samples)
    torch.save(steps_to_samples, os.path.join(cur_dir, 'samples.pt'))
    # Saved inside cur_dir so each digit keeps its own variances file.
    torch.save(steps_to_variances, os.path.join(cur_dir, 'variances.pt'))

    # Plot variances
    _plot_variances(steps_to_variances, args.temps, os.path.join(cur_dir, 'variances.png'))

    torch.save(stats, os.path.join(cur_dir, 'stats.pt'))
    torch.save(cond_model.state_dict(), os.path.join(cur_dir, 'cond_model_state_dict.pt'))


def main(args):
    """Load the base model once, then run the per-digit experiment for digits 0-9.

    Args:
        args: Parsed command-line namespace. ``base_ckpt`` and ``device`` are
            read here; the whole namespace is forwarded to ``run_for_digit``.
    """
    # Create & load base model (shared by every digit).
    print('Loading base model...')
    model = load_base_model('mnist', args.base_ckpt)
    model.to(args.device)
    print_model_info(model)

    for label in range(10):
        op = ForwardOp(f'digit{label}')
        print(f'Running for digit {label}...')
        run_for_digit(args, label, model, op)
        # Blank line separates the log output of consecutive digits.
        print()


def str2bool(value):
    """Parse a command-line boolean value for argparse.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"`` and ``"0"``) as True; this converter accepts the usual
    spellings instead.

    Args:
        value: The raw string from the command line (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Basic stuff
    parser.add_argument('--root_dir', type=str, required=True)

    # Model config
    parser.add_argument('--base_ckpt', type=str, required=True)
    parser.add_argument('--cond_model', type=str, default='realnvp')

    # Task parameters
    parser.add_argument('--sigma', type=float, default=0.1)
    parser.add_argument('--temps', type=float, nargs='+', default=[0.7, 0.8, 0.9, 1.0])
    parser.add_argument('--dist', type=str, default='l2', choices=['l2', 'lpips_vgg'])

    # Training related
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_steps', type=int, default=2000)
    parser.add_argument('--sample_steps', type=int, nargs='+', required=True)
    # Accepts a bare `--ema` (-> True) as well as `--ema true/false`;
    # `type=bool` would have parsed `--ema False` as True.
    parser.add_argument('--ema', type=str2bool, nargs='?', const=True, default=False)

    # Misc
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--device', type=str, default='cuda:0')

    args = parser.parse_args()

    os.makedirs(args.root_dir, exist_ok=True)

    args.device = torch.device(args.device)
    main(args)
