import argparse
import json
import math
import os
import time
from types import SimpleNamespace

import numpy as np
import torch
from torchvision.utils import save_image

from code.utils import print_model_info

from code.exp_utils import load_base_model, get_eval_data, ForwardOp
from code.mcmc import TransitionKernel, AuxDensity, run_plmcmc_singlesample


@torch.no_grad()
def main(args):
    """Run PL-MCMC over a random subset of an evaluation set and save results.

    For each sample index in ``range(args.starting_sample_index,
    args.num_samples)`` and each batch of the selected data indices, the clean
    images are observed through ``ForwardOp``, ``run_plmcmc_singlesample`` is
    run from a fresh prior draw, and the outputs are written to
    ``<args.root_dir>/samples_<sample>_indices_<low>-<high>/``:

    * ``stats.pt``   -- ``SimpleNamespace`` with the run configuration plus
      the wall-clock time and batch size of every batch run so far for this
      sample index by this invocation.
    * ``samples.pt`` -- whatever ``run_plmcmc_singlesample`` returned.
    * ``samples_step=<key>.png`` -- for every key of the returned samples, a
      grid of up to 16 ground-truth images followed by the matching samples
      (``k`` images per row).

    A batch is skipped when its output directory already exists and is
    non-empty, so re-running the same command resumes unfinished work.
    NOTE(review): a run that dies after ``stats.pt`` is written but before
    ``samples.pt`` leaves a non-empty directory that will then be skipped.

    Args:
        args: Parsed command-line namespace from the ``__main__`` block, with
            ``args.device`` set to a ``torch.device`` before the call.
    """
    bs = args.batch_size

    base_model = load_base_model(args.base_model, args.base_ckpt).to(args.device)
    print_model_info(base_model)

    if args.forward_op == 'mar':
        # Random binary mask over every image element (each element is set
        # with probability 1/2). NOTE(review): --seed is only applied here, to
        # make the mask reproducible; the MCMC chains below are not seeded.
        torch.manual_seed(args.seed)
        mask = torch.randint(2, base_model.image_shape, device=args.device).bool()
        assert mask.shape == base_model.image_shape
        # .item() avoids printing the tensor repr; `%` formats a true percentage
        # (the ratio itself is a fraction in [0, 1]).
        n_nonzero = int(mask.sum().item())
        frac_nonzero = n_nonzero / np.prod(base_model.image_shape)
        print(f'MAR mask has {n_nonzero} nonzero items: {frac_nonzero:.2%}')
        forward_op = ForwardOp(args.forward_op, mask=mask, keep_shape=True)
    else:
        forward_op = ForwardOp(args.forward_op, keep_shape=True)

    # Load data and draw a fixed random subset of indices. The global NumPy
    # seed is reset right before the draw, so separate invocations with the
    # same data and --n select the same indices (needed for --index_range
    # sharding to be consistent).
    data = get_eval_data(args.data, data_root=args.data_root)
    np.random.seed(0)
    if args.n is None:
        print(f'Argument --n not given; Using full data length {len(data)}')
        size = len(data)
    else:
        size = args.n
    indices = np.random.choice(len(data), size, replace=False)
    # This invocation only processes a contiguous slice of the shuffled indices.
    indices = indices[args.index_range[0]:args.index_range[1]]
    n_batches = math.ceil(len(indices) / bs)
    print(f'Using indices:\n{indices}\n')

    for sample_idx in range(args.starting_sample_index, args.num_samples):
        stats = SimpleNamespace(
            indices=indices.tolist(),
            num_samples=args.num_samples,
            sample_steps=args.sample_steps,
            n_batches=n_batches,

            total_time={},
            batch_size={},
        )

        for i in range(n_batches):
            # Output directory is named by the position within the shuffled
            # index list (offset by the start of --index_range).
            batch_indices = indices[i * bs : (i + 1) * bs]
            index_low = i * bs + args.index_range[0]
            index_high = index_low + len(batch_indices)
            save_dir = os.path.join(args.root_dir,
                                    f'samples_{sample_idx:03d}_indices_{index_low:05d}-{index_high:05d}')
            if os.path.exists(save_dir) and len(os.listdir(save_dir)) > 0:
                print(f'Skipping sample {sample_idx+1} / {args.num_samples} '
                      f'batch {i+1} / {n_batches}; directory nonempty {save_dir}\n')
                continue

            os.makedirs(save_dir, exist_ok=True)

            # Prep stuff
            # NOTE(review): scales by 256 (not 255), presumably to match the
            # base model's dequantization convention -- confirm.
            x_star = data[batch_indices].float().to(args.device) / 256.
            x_obs, _ = forward_op.observe(x_star)
            g = TransitionKernel('perturb', args.sigma_p)
            q = AuxDensity(x_obs, args.sigma_a)
            z_init = base_model.sample_prior(len(x_star), temp=1.0, device=args.device)
            print(f'Index_range {args.index_range[0]}-{args.index_range[1]} '
                  f'sample {sample_idx+1:03d}/{args.num_samples:03d} batch {i+1:03d}/{n_batches:03d} : '
                  f'{save_dir}')

            # Run PL-MCMC
            start_time = time.time()
            samples = run_plmcmc_singlesample(
                base_model=base_model, x_obs=x_obs, z_init=z_init, forward_op=forward_op,
                g=g, q=q, num_steps=args.num_steps, sample_steps=args.sample_steps)
            time_spent = time.time() - start_time
            stats.total_time[i] = time_spent
            stats.batch_size[i] = len(x_star)
            print(f'sample {sample_idx+1:02d}/{args.num_samples:02d} '
                  f'batch {i+1:03d}/{n_batches:03d} '
                  f'took {time_spent:.2f} sec')

            # Save results
            torch.save(stats, os.path.join(save_dir, 'stats.pt'))
            torch.save(samples, os.path.join(save_dir, 'samples.pt'))

            k = min(16, len(x_star))  # max examples in image dumps
            x_star_cpu = x_star[:k].cpu()
            for sample_step, sample in samples.items():
                # .cpu() so torch.cat works even if samples live on the GPU.
                x = sample[:k].clamp(0, 1).cpu()
                # No `range=` kwarg: make_grid only uses the value range when
                # normalize=True, and newer torchvision removed `range`.
                save_image(torch.cat([x_star_cpu, x], dim=0),
                           os.path.join(save_dir, f'samples_step={sample_step}.png'),
                           nrow=k, pad_value=1)

            print()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    # Basic stuff
    parser.add_argument('--root_dir', type=str, required=True,
                        help='Output directory for the args json and per-batch sample directories')

    # Dataset related
    parser.add_argument('--data', type=str, required=True)
    parser.add_argument('--data_root', type=str, default=None)
    parser.add_argument('--index_range', type=int, nargs=2, required=True,
                        help='START END: slice [START:END] of the shuffled index list handled by this run')
    parser.add_argument('--n', type=int,
                        help='Size of the random subset drawn from the data (full data length if omitted)')

    # Model config
    parser.add_argument('--base_ckpt', type=str, required=True)
    parser.add_argument('--base_model', type=str, required=True)

    # Task parameters
    parser.add_argument('--forward_op', type=str, default='bottom_half')

    # MCMC parameters
    parser.add_argument('--sigma_p', type=float, default=0.05,
                        help="Scale passed to the 'perturb' TransitionKernel")
    parser.add_argument('--sigma_a', type=float, default=0.001,
                        help='Scale passed to AuxDensity')
    parser.add_argument('--num_steps', type=int, required=True)
    parser.add_argument('--num_samples', type=int, required=True)

    # Misc
    parser.add_argument('--seed', type=int, default=1234,
                        help="Torch seed; only applied when building the 'mar' mask")
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--sample_steps', type=int, nargs='+', required=True)
    parser.add_argument('--starting_sample_index', type=int, default=0)
    # Previously hard-coded to cuda:0; the default keeps that behavior.
    parser.add_argument('--device', type=str, default='cuda:0',
                        help='Torch device to run on')

    args = parser.parse_args()

    # exist_ok avoids the check-then-create race when several shards
    # (different --index_range) start at the same time.
    os.makedirs(args.root_dir, exist_ok=True)
    with open(os.path.join(args.root_dir, f'args_{args.index_range[0]}-{args.index_range[1]}.json'), 'w') as f:
        json.dump(vars(args), f, indent=2)

    # Converted after dumping so the json stays serializable.
    args.device = torch.device(args.device)
    main(args)
