"""
Evaluate the model on the TNP repo data.
"""
import argparse
import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import iqr
import torch
from tqdm import tqdm

from krt.models.edges import RBF, Matern, Periodic
from krt.utils import load_model

###########################################################################
# %% Parse arguments.
###########################################################################
parser = argparse.ArgumentParser()
parser.add_argument('--path', type=str, required=True)
parser.add_argument('--cuda_device', type=int)
parser.add_argument('--evalset_path', type=str, default='evalsets/gp')
parser.add_argument('--offset', type=float, default=0)
parser.add_argument('--plot', action='store_true')
parser.add_argument('--data_lim', type=int, default=None)
args = parser.parse_args()
device = f'cuda:{args.cuda_device}' if args.cuda_device is not None else 'cpu'


###########################################################################
# %% Load in the data.
###########################################################################
print('Loading data...', end='')
if '2d' in args.path:
    dpath = f'{args.evalset_path}/dim_2d.tar'
elif '4d' in args.path:
    dpath = f'{args.evalset_path}/dim_4d.tar'
else:
    dpath = f'{args.evalset_path}/dim_1d.tar'
data = torch.load(dpath, map_location=device)
if args.data_lim:
    data = data[:args.data_lim]
print('Done!')

###########################################################################
# %% Traverse file system looking for models.
###########################################################################
for path, _, fnames in os.walk(args.path):
    if '.hydra' in path or 'config.yaml' not in fnames:
        continue
    print(f'Evaluating model at {path}')
    ###########################################################################
    # %% Load in the model.
    ###########################################################################
    print('\tLoading model...', end='')
    model = load_model(path, map_location=device)
    model = model.to(device)
    print('Done!')

    ###########################################################################
    # %% Calculating log likelihood.
    ###########################################################################
    print('\tEvaluating model...')
    lls = []
    model.eval()
    for batch in tqdm(data):
        with torch.no_grad():
            batch_lls = model.seq_ll(batch.xc + args.offset, batch.yc,
                                     batch.xt + args.offset, batch.yt,
                                     use_training_range_for_min_max=False)
        lls += batch_lls.tolist()
    print(f'\t{np.mean(lls):0.3f} +- {np.std(lls):0.3f}\n'
          f'\tMedian: {np.median(lls)}\n'
          f'\tIQR: {iqr(lls)}\n')
    fname = f'offset{args.offset:0.1f}_eval.txt'
    with open(os.path.join(path, fname),
              'w') as f:
        f.write(f'{np.mean(lls)},{np.std(lls)},{np.median(lls)},{iqr(lls)}\n')

    ###########################################################################
    # %% Show histogram.
    ###########################################################################
    if args.plot:
        print(f'\tFive worst test idxs: {np.argsort(lls)[:5]}')
        plt.hist(lls, alpha=0.6, color='cornflowerblue', bins=50)
        plt.axvline(np.mean(lls), color='black', ls='--')
        plt.title(f'(Average, Median) LL: ({np.mean(lls):0.3f}, {np.median(lls):0.3f})')
        plt.show()
