"""
Generate an evaluation test set in the same style as
https://github.com/tung-nd/TNP-pytorch/blob/master/regression/gp.py
"""
import argparse
import math
import random

from attrdict import AttrDict
import numpy as np
from tqdm import tqdm
import torch

from krt.data_module.infinite_gp_data import RBFGPIterator


###########################################################################
# %% Parse arguments.
###########################################################################
parser = argparse.ArgumentParser()
parser.add_argument('--save_path', type=str, required=True)
parser.add_argument('--dim_x', type=int, required=True)
parser.add_argument('--num_batches', type=int, default=3000)
parser.add_argument('--functions_per_batch', type=int, default=16)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--min_points_per_function', type=int, default=6)
parser.add_argument('--max_points_per_function', type=int, default=50)
parser.add_argument('--min_ctx_size', type=int, default=3,
                    help='Smallest number of context points in a batch.')
parser.add_argument('--min_trg_size', type=int, default=3,
                    help='Smallest number of target points in a batch.')
parser.add_argument('--lengthscale_range', type=str, default='0.1,0.6',
                    help='Comma-separated floats, e.g. "0.1,0.6".')
parser.add_argument('--scale_range', type=str, default='0.1,1.0',
                    help='Comma-separated floats, e.g. "0.1,1.0".')
parser.add_argument('--scale_range_with_dim', action='store_true',
                    help='Multiply the lengthscale range by sqrt(dim_x).')
parser.add_argument('--noise', type=float, default=2e-2)
args = parser.parse_args()

# The context/target split later draws
# random.randint(min_ctx_size, num_points - min_trg_size), which raises
# ValueError when the range is empty.  Reject impossible settings now rather
# than crashing part-way through generation.
# NOTE(review): assumes the iterator can yield as few as
# min_points_per_function points per function -- confirm against
# RBFGPIterator.
if args.min_ctx_size + args.min_trg_size > args.min_points_per_function:
    parser.error(
        '--min_ctx_size ({}) + --min_trg_size ({}) must not exceed '
        '--min_points_per_function ({})'.format(
            args.min_ctx_size,
            args.min_trg_size,
            args.min_points_per_function,
        )
    )

lengthscales = np.array(
    [float(value) for value in args.lengthscale_range.split(',')]
)
# Despite the flag's name, it is the lengthscale range (not scale_range)
# that gets rescaled with the input dimension.
if args.scale_range_with_dim:
    lengthscales *= math.sqrt(args.dim_x)
scales = np.array([float(value) for value in args.scale_range.split(',')])

###########################################################################
# %% Collect the data.
###########################################################################
# Seed every random number generator before the iterator is constructed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)

iterator_kwargs = dict(
    dim_x=args.dim_x,
    batches_per_epoch=args.num_batches,
    batch_size=args.functions_per_batch,
    min_points_per_function=args.min_points_per_function,
    max_points_per_function=args.max_points_per_function,
    lengthscale_range=lengthscales,
    scale_range=scales,
    x_bounds=(-2.0, 2.0),
    noise=args.noise,
)
gp_batches = RBFGPIterator(**iterator_kwargs)

progress = tqdm(total=args.num_batches)
dataset = []
for inputs, outputs in gp_batches:
    # One split position is drawn per batch, so every function in the batch
    # gets the same number of context points.  The split is taken along the
    # second axis: the first `split` entries become the context set and the
    # remainder becomes the target set.
    largest_ctx = inputs.shape[1] - args.min_trg_size
    split = random.randint(args.min_ctx_size, largest_ctx)
    ctx_part = slice(None, split)
    trg_part = slice(split, None)
    dataset.append(AttrDict({
        'x': inputs,
        'y': outputs,
        'xt': inputs[:, trg_part],
        'yt': outputs[:, trg_part],
        'xc': inputs[:, ctx_part],
        'yc': outputs[:, ctx_part],
    }))
    progress.update(1)
progress.close()

# Write the full list of batches to disk in a single file.
torch.save(dataset, args.save_path)
