"""
Create a multitask dataset from a GP.

Author: Ian Char
Date: October 29, 2023
"""
import argparse
import math
import os
import pickle as pkl
import random

import git
import torch
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
from tqdm import tqdm

###########################################################################
# %% Parse the arguments.
###########################################################################
parser = argparse.ArgumentParser()
parser.add_argument('--save_dir', type=str, required=True,
                    help='Directory to create for the outputs; must not exist yet.')
parser.add_argument('--x_dim', type=int, default=1,
                    help='Dimensionality of the input points.')
parser.add_argument('--kernel_type', type=str, default='rbf',
                    help="GP kernel; only 'rbf' is implemented.")
parser.add_argument('--lengthscale_range', type=str, default='1.0,1.0',
                    help='"low,high": per-dimension lengthscales are drawn '
                         'uniformly from this range for every function.')
parser.add_argument('--num_functions', type=int, default=1000,
                    help='Number of training functions to sample (0 to skip).')
parser.add_argument('--num_test_functions', type=int, default=0,
                    help='Number of test functions to sample, saved together '
                         'with their hyperparameters and log-probabilities.')
parser.add_argument('--points_per_function', type=int, default=50,
                    help='Number of (x, y) points observed per function.')
parser.add_argument('--x_bounds', type=str, default='-2.0,2.0',
                    help='"low,high": inputs are drawn uniformly from this range.')
parser.add_argument('--scale_range', type=str, default='1.0,1.0',
                    help='"low,high": output scale of each function is drawn '
                         'uniformly from this range.')
parser.add_argument('--noise', type=float, default=2e-2,
                    help='Observation noise standard deviation.')
parser.add_argument('--seed', type=int, default=0,
                    help='Seed for both the python and torch RNGs.')
args = parser.parse_args()
# Validate before touching the filesystem: os.makedirs below refuses to reuse
# an existing directory, so failing after creating save_dir would leave a
# stale, partially-written directory that blocks the corrected rerun.
if args.kernel_type != 'rbf':
    raise NotImplementedError(f'Have not implemented {args.kernel_type} kernel.')
lscale_low, lscale_high = [float(ll) for ll in args.lengthscale_range.split(',')]
scale_low, scale_high = [float(n) for n in args.scale_range.split(',')]
xlow, xhigh = [float(x) for x in args.x_bounds.split(',')]
xdiam = xhigh - xlow
# Intentionally no exist_ok: never silently overwrite a previous dataset.
os.makedirs(args.save_dir)
with open(os.path.join(args.save_dir, 'args.pkl'), 'wb') as f:
    pkl.dump(args, f)
# Record the code version that produced this dataset.
with open(os.path.join(args.save_dir, 'version.txt'), 'w') as f:
    f.write(str(git.Repo(search_parent_directories=True).head.object.hexsha))
random.seed(args.seed)
torch.manual_seed(args.seed)

###########################################################################
# %% Generate the data.
###########################################################################
def _sample_gp_function():
    """Draw one random function from the RBF GP prior with fresh hyperparameters.

    For each call, inputs are drawn uniformly in [xlow, xhigh), per-dimension
    lengthscales uniformly in [lscale_low, lscale_high), and the output scale
    uniformly in [scale_low, scale_high]. The RNG calls happen in the same order
    as the original inline code, so seeded runs consume the RNG streams
    identically.

    Returns:
        x_sample: (points_per_function, x_dim) sampled inputs.
        y_dist: MultivariateNormal over the noisy outputs at x_sample, with
            covariance scale**2 * cov + noise**2 * I.
        cov: (points_per_function, points_per_function) unit-amplitude RBF
            kernel matrix (no output scale, no noise).
        lscales: (x_dim,) lengthscales used for this function.
        scale: output scale (python float) used for this function.
    """
    x_sample = torch.rand(args.points_per_function, args.x_dim) * xdiam + xlow
    lscales = torch.rand(args.x_dim) * (lscale_high - lscale_low) + lscale_low
    scale = random.uniform(scale_low, scale_high)
    # Broadcasted pairwise differences scaled by lengthscale, shape (n, n, x_dim).
    # Equivalent to exp(-||(x_i - x_j) / l||^2 / 2) per entry, without a Python
    # double loop over all n^2 pairs.
    diffs = (x_sample.unsqueeze(1) - x_sample.unsqueeze(0)) / lscales
    cov = (-0.5 * diffs.pow(2).sum(-1)).exp()
    y_dist = MultivariateNormal(
        torch.zeros(args.points_per_function),
        scale ** 2 * cov + torch.eye(args.points_per_function) * args.noise ** 2,
    )
    return x_sample, y_dist, cov, lscales, scale


if args.num_functions:
    # Training set: only the inputs and noisy outputs are stored.
    data_x, data_y = [], []
    for _ in tqdm(range(args.num_functions)):
        x_sample, y_dist, _, _, _ = _sample_gp_function()
        data_x.append(x_sample)
        data_y.append(y_dist.sample().unsqueeze(-1))
    torch.save(torch.stack(data_x), os.path.join(args.save_dir, 'x_data.pt'))
    torch.save(torch.stack(data_y), os.path.join(args.save_dir, 'y_data.pt'))
if args.num_test_functions:
    # Test set: additionally stores the hyperparameters and reference
    # log-probabilities of each sampled function.
    data_x, data_y, te_lscales, te_scales, cum_joint_logprob, marginal_logprob =\
        [[] for _ in range(6)]
    for _ in tqdm(range(args.num_test_functions)):
        x_sample, y_dist, cov, lscales, scale = _sample_gp_function()
        y_sample = y_dist.sample()
        data_x.append(x_sample)
        data_y.append(y_sample.unsqueeze(-1))
        # Entry i-1 is the joint log-density log p(y_1, ..., y_i) of the first
        # i points under the true GP, for i = 1..points_per_function.
        cum_joint_logprob.append(torch.Tensor([
            MultivariateNormal(
                torch.zeros(i),
                scale ** 2 * cov[:i, :i] + torch.eye(i) * args.noise ** 2,
            ).log_prob(y_sample[:i]).item()
            for i in range(1, args.points_per_function + 1)]))
        # Per-point marginal log-density. The RBF kernel's diagonal is 1, so
        # each y_i ~ N(0, scale**2 + noise**2) marginally.
        marginal_logprob.append(Normal(0, math.sqrt(scale ** 2 + args.noise ** 2))
                                .log_prob(y_sample))
        te_lscales.append(lscales)
        te_scales.append(scale)
    torch.save(torch.stack(data_x), os.path.join(args.save_dir, 'te_x_data.pt'))
    torch.save(torch.stack(data_y), os.path.join(args.save_dir, 'te_y_data.pt'))
    torch.save(torch.stack(te_lscales), os.path.join(args.save_dir, 'te_lscales.pt'))
    torch.save(torch.stack(cum_joint_logprob),
               os.path.join(args.save_dir, 'cum_joint_logprob.pt'))
    torch.save(torch.stack(marginal_logprob),
               os.path.join(args.save_dir, 'marginal_logprob.pt'))
    torch.save(torch.Tensor(te_scales), os.path.join(args.save_dir, 'te_scales.pt'))
