import argparse
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from torchvision.utils import save_image

from code.composed_model import create_vi_model
from code.exp_utils import compute_variance, load_base_model, get_eval_data, ForwardOp
from code.utils import print_model_info
from code.vi import train_cond_model


def run_for_index(args, index, data_index, data, base_model, forward_op):
    """Fit a conditional VI model for one evaluation example and dump its artifacts.

    Artifacts go to ``<args.root_dir>/index=<index:05d>``. If that directory
    already exists and is non-empty the example is skipped, so an interrupted
    sweep can be re-run without redoing finished indices.

    Files written to the output directory:
        samples_step=<step>_temp=<temp>.png  target, its forward-op visualization
                                             and up to 30 samples
        samples.pt                           raw ``steps_to_samples`` dict
        variances.pt                         ``{step: {temp: compute_variance(samples)}}``
        variances.png                        heatmap grid (rows: steps, cols: temps)
        stats.pt                             training stats (with ``data_index`` attached)
        cond_model_state_dict.pt             trained conditional model weights

    Args:
        args: Parsed command-line arguments (see the ``__main__`` section).
        index: Position of this example within the sweep; only used to name
            the output directory.
        data_index: Row of ``data`` used as the target ``x_star``.
        data: Evaluation data from ``get_eval_data``; sliced and moved to
            ``args.device``.
        base_model: Pretrained base model handed to ``train_cond_model``.
        forward_op: ``ForwardOp`` passed to training; its ``visualize`` output
            is shown next to the samples.
    """
    cur_dir = os.path.join(args.root_dir, f'index={index:05d}')
    if os.path.exists(cur_dir) and len(os.listdir(cur_dir)) > 0:
        print(f'Skipping index {index} -- nonempty directory: {cur_dir}')
        return
    os.makedirs(cur_dir, exist_ok=True)

    # Target example as a batch of one.
    # NOTE(review): division by 256 (not 255) presumably mirrors the base model's
    # preprocessing of integer pixel data -- confirm before changing.
    x_star = data[data_index:data_index + 1].to(args.device)
    x_star = x_star.float() / 256.

    # Create model and train
    cond_model = create_vi_model(args.data, args.cond_model, seed=args.seed).to(args.device)
    print_model_info(cond_model)

    cond_model, steps_to_samples, stats = train_cond_model(
        base_model=base_model, cond_model=cond_model, x_star=x_star,
        forward_op=forward_op, device=args.device, sigma=args.sigma,
        batch_size=args.batch_size, learning_rate=args.learning_rate,
        num_steps=args.num_steps, temps=args.temps, use_ema=args.ema,
        sample_steps=args.sample_steps
    )

    # Dump samples and their variances. variances.pt belongs in this example's
    # directory; writing it relative to the CWD let every index overwrite it.
    steps_to_variances = _save_samples_and_variances(cur_dir, x_star, forward_op, steps_to_samples)
    torch.save(steps_to_samples, os.path.join(cur_dir, 'samples.pt'))
    torch.save(steps_to_variances, os.path.join(cur_dir, 'variances.pt'))

    _plot_variances(steps_to_variances, args.temps, os.path.join(cur_dir, 'variances.png'))

    stats.data_index = data_index
    torch.save(stats, os.path.join(cur_dir, 'stats.pt'))
    torch.save(cond_model.state_dict(), os.path.join(cur_dir, 'cond_model_state_dict.pt'))


def _save_samples_and_variances(cur_dir, x_star, forward_op, steps_to_samples):
    """Save one sample-grid PNG per (step, temp) and return the sample variances.

    Each grid stacks the target, ``forward_op.visualize`` of the target, and the
    first 30 samples, 8 images per row.

    Returns:
        dict mapping ``step -> {temp: compute_variance(samples)}``.
    """
    x_star_cpu = x_star.cpu()
    steps_to_variances = {}
    for step, temps_to_samples in steps_to_samples.items():
        steps_to_variances[step] = {}
        for temp, samples in temps_to_samples.items():
            # .cpu() is a no-op for CPU samples and avoids a device mismatch in torch.cat otherwise.
            grid = torch.cat([x_star_cpu, forward_op.visualize(x_star_cpu), samples[:30].cpu()], dim=0)
            # No ``range=`` argument: it only takes effect with ``normalize=True`` (not used
            # here), and newer torchvision releases no longer accept that keyword.
            save_image(grid, os.path.join(cur_dir, f'samples_step={step}_temp={temp}.png'),
                       nrow=8, pad_value=1)
            steps_to_variances[step][temp] = compute_variance(samples)
    return steps_to_variances


def _plot_variances(steps_to_variances, temps, path):
    """Save a grid of variance heatmaps (rows: sorted steps, cols: sorted temps) to ``path``."""
    steps = sorted(steps_to_variances)
    temps = sorted(temps)
    if not steps or not temps:
        # plt.subplots rejects a zero-sized grid; nothing to plot.
        return
    n_steps, n_temps = len(steps), len(temps)
    # squeeze=False keeps ``axes`` 2-D even with a single step or a single temperature.
    fig, axes = plt.subplots(n_steps, n_temps, figsize=(4 * n_temps, 4 * n_steps), squeeze=False)
    for row, step in enumerate(steps):
        for col, temp in enumerate(temps):
            ax = axes[row, col]
            v = steps_to_variances[step][temp]
            sns.heatmap(v.detach().cpu().squeeze().numpy(), vmin=0.0, vmax=1.0, cbar=False, ax=ax)
            ax.set(xlabel=None, ylabel=None, xticks=[], yticks=[], title=f'Step: {step}, Temp: {temp}', aspect=1.0)
    fig.tight_layout(pad=2.5)
    fig.savefig(path)
    plt.close(fig)


def main(args):
    """Run the experiment over a reproducible slice of evaluation examples.

    Steps:
        1. Load the base model and the evaluation data.
        2. Draw ``args.n`` distinct data indices (all of them when ``--n`` is
           omitted) with NumPy seed 0.
        3. Keep the slice ``args.index_range[0]:args.index_range[1]`` of that draw.
        4. Build the forward operator.
        5. Call ``run_for_index`` on each selected example.

    Args:
        args: Parsed command-line arguments; ``args.device`` is where the base
            model and the MAR mask are placed.
    """
    # Create & load base model
    print('Loading base model...')
    base_model = load_base_model(args.data, args.base_ckpt)
    base_model.to(args.device)
    print_model_info(base_model)

    # Load data. The index draw uses a fixed seed. With the same data and --n,
    # separate invocations see the same ordering, so each --index_range selects
    # a reproducible slice of it.
    data = get_eval_data(args.data, data_root=args.data_root)
    np.random.seed(0)
    if args.n is None:
        print(f'Argument --n not given; Using full data length {len(data)}')
        size = len(data)
    else:
        size = args.n
    data_indices = np.random.choice(len(data), size, replace=False)
    data_indices = data_indices[args.index_range[0]:args.index_range[1]]
    print(f'Using indices:\n{data_indices}\n')

    if args.forward_op == 'mar':
        # Random 0/1 mask over the image shape, seeded with args.seed.
        torch.manual_seed(args.seed)
        mask = torch.randint(2, base_model.image_shape, device=args.device).bool()
        assert mask.shape == base_model.image_shape
        n_nonzero = int(mask.sum().item())
        frac_nonzero = n_nonzero / np.prod(base_model.image_shape)
        # Report the fraction as a real percentage (it was printed as "0.5000 %").
        print(f'MAR mask has {n_nonzero} nonzero items: {frac_nonzero:.2%}')
        forward_op = ForwardOp(args.forward_op, mask=mask)
    else:
        forward_op = ForwardOp(args.forward_op)

    for i, data_index in enumerate(data_indices):
        print(f'Running for index: {data_index} [{i+1} / {len(data_indices)}]')
        run_for_index(args, i + args.index_range[0], data_index, data, base_model, forward_op)
        print()


def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``. This converter fixes that while still accepting
    the existing ``--ema True`` spelling.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Basic stuff
    parser.add_argument('--root_dir', type=str, required=True)
    parser.add_argument('--device', type=str, default='cuda:0')

    # Dataset related
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--n', type=int, default=None)
    parser.add_argument('--index_range', type=int, nargs=2, required=True)
    parser.add_argument('--data_root', type=str, default=None)

    # Model config
    parser.add_argument('--base_ckpt', type=str, required=True)
    parser.add_argument('--cond_model', type=str, default='realnvp')
    # NOTE(review): --n_mix and --dist are only recorded in the args JSON; no code
    # in this file reads them.
    parser.add_argument('--n_mix', type=int, default=10)

    # Task parameters
    parser.add_argument('--sigma', type=float, default=0.1)
    parser.add_argument('--forward_op', type=str, default='bottom_half')
    parser.add_argument('--temps', type=float, nargs='+', default=[0.7, 0.8, 0.9, 1.0])
    parser.add_argument('--dist', type=str, default='l2', choices=['l2', 'lpips_vgg'])

    # Training related
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=1e-3)
    parser.add_argument('--num_steps', type=int, default=2000)
    parser.add_argument('--sample_steps', type=int, nargs='+', required=True)
    parser.add_argument('--ema', type=str2bool, default=False)

    # Misc
    parser.add_argument('--seed', type=int, default=1234)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()

    # Record the exact arguments for this index range next to its outputs.
    os.makedirs(args.root_dir, exist_ok=True)
    with open(os.path.join(args.root_dir, f'args_{args.index_range[0]}-{args.index_range[1]}.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)

    args.device = torch.device(args.device)
    main(args)
