"""
Evaluate model on test set.

Author: Ian Char
Date: December 31, 2023
"""
import argparse
from collections import defaultdict
import os
import pickle as pkl

import matplotlib.pyplot as plt
import numpy as np
from omegaconf import OmegaConf
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader, TensorDataset

from krt import KRT_PATH
from krt.utils import load_model

###########################################################################
# %% Parse arguments.
###########################################################################
# Command-line interface.  Each entry is (flag, add_argument keyword options);
# the flags are registered in the order listed.
_ARG_SPECS = (
    # Directory that is walked to find trained models (and, when --te_data
    # is not given, whose config.yaml supplies the test data location).
    ('--path', {'type': str, 'required': True}),
    # Test data directory, joined onto KRT_PATH; overrides the config value.
    ('--te_data', {'type': str}),
    # Add a constant offset of 10 to the test inputs.
    ('--offset', {'action': 'store_true'}),
    # Skip normalization of the log likelihoods.
    ('--dont_normalize', {'action': 'store_true'}),
    # Index of the CUDA device to use; the CPU is used when omitted.
    ('--cuda_device', {'type': int}),
    # Number of test sequences per evaluation batch.
    ('--batch_size', {'type': int, 'default': 64}),
    # Smallest and largest (inclusive) context sizes to evaluate.
    ('--min_te_ctx_size', {'type': int, 'default': 1}),
    ('--max_te_ctx_size', {'type': int, 'default': 49}),
    # Show histograms of the log likelihoods after evaluating a model.
    ('--hist', {'action': 'store_true'}),
)
parser = argparse.ArgumentParser()
for _flag, _opts in _ARG_SPECS:
    parser.add_argument(_flag, **_opts)
args = parser.parse_args()
# Run on the requested CUDA device, or on the CPU when none was given.
if args.cuda_device is None:
    device = 'cpu'
else:
    device = f'cuda:{args.cuda_device}'


###########################################################################
# %% Load in the test data.
###########################################################################
print('Loading data...', end='')
if args.te_data:
    data_path = args.te_data
else:
    # No explicit test set given: use the data location stored in the
    # config.yaml found directly under --path.
    cfg = OmegaConf.load(os.path.join(args.path, 'config.yaml'))
    data_path = (cfg['data']['path'] if 'path' in cfg['data']
                 else cfg['data']['te_data_path'])
data_dir = os.path.join(KRT_PATH, data_path)
test_x = torch.load(os.path.join(data_dir, 'te_x_data.pt'))
if args.offset:
    test_x += 10
test_y = torch.load(os.path.join(data_dir, 'te_y_data.pt'))
# File names of the optional normalization data.  They are kept in one place
# so that the existence check and the load cannot drift apart.
CJOINT_FNAME = 'cum_joint_logprob.pt'
MARGINAL_FNAME = 'marginal_logprob.pt'
# Normalization is skipped when requested on the command line or when the
# data set does not ship the cumulative joint log probabilities.
dont_normalize = (
    args.dont_normalize
    or not os.path.isfile(os.path.join(data_dir, CJOINT_FNAME))
)
te_tensors = [test_x, test_y]
if not dont_normalize:
    cjoint_ll = torch.load(os.path.join(data_dir, CJOINT_FNAME))
    marginal_ll = torch.load(os.path.join(data_dir, MARGINAL_FNAME))
    # Batches then carry (x, y, cumulative joint ll, marginal ll).
    te_tensors += [cjoint_ll, marginal_ll]
te_data = DataLoader(
    TensorDataset(*te_tensors),
    batch_size=args.batch_size,
    shuffle=False,
    drop_last=False,
    pin_memory=device != 'cpu',
)
N = test_x.shape[0]
L = test_x.shape[1]
# Context sizes to evaluate; both bounds are inclusive.
ctx_sizes = list(range(args.min_te_ctx_size, args.max_te_ctx_size + 1))
print('Done!')

###########################################################################
# %% Traverse file system looking for models.
###########################################################################
for path, _, fnames in os.walk(args.path):
    # A model directory is any directory holding a config.yaml, except for
    # the .hydra bookkeeping directories.
    if '.hydra' in path or 'config.yaml' not in fnames:
        continue
    print(f'Evaluating model at {path}')
    ###########################################################################
    # %% Load in the model.
    ###########################################################################
    print('Loading model...', end='')
    model = load_model(path, map_location=device)
    model = model.to(device)
    print('Done!')

    ###########################################################################
    # %% Do evaluations over different condition set sizes.
    ###########################################################################
    print('Evaluating...')
    evals = {}
    # Per context size: list of per-batch log likelihood arrays.
    log_stats = defaultdict(list)
    for cs in tqdm(ctx_sizes):
        total_ll = 0.0
        total_min = 0.0
        total_max = 0.0
        for batch in te_data:
            xi, yi = batch[0].to(device), batch[1].to(device)
            # The first `cs` points form the context, the rest the targets.
            xc, xt = xi[:, :cs], xi[:, cs:]
            yc, yt = yi[:, :cs], yi[:, cs:]
            ll = model.seq_ll(xc, yc, xt, yt)
            log_stats[cs].append(ll.cpu().numpy())
            if not dont_normalize:
                ci, mi = batch[2].to(device), batch[3].to(device)
                # Lower reference: marginal log probs summed over targets.
                total_min += mi[:, cs:].sum().item()
                # Upper reference: joint log prob of the targets given the
                # context, taken as a difference of cumulative values.
                total_max += (ci[:, -1] - ci[:, cs - 1]).sum().item()
            total_ll += ll.sum().item()
        if not dont_normalize:
            # Rescale so the lower reference maps to 0 and the upper to 100.
            total_ll = (total_ll - total_min) / (total_max - total_min) * 100
        else:
            # Average log likelihood per target point.
            total_ll /= (N * (L - cs))
        evals[f'Normalized_LL_Ctx{cs}'] = total_ll
    # Average over context sizes, computed once before it is itself stored
    # in `evals`.
    avg_score = np.mean([evals[f'Normalized_LL_Ctx{cs}'] for cs in ctx_sizes])
    evals['Normalized_LL'] = avg_score
    print('=' * 20)
    for cs in ctx_sizes:
        score = evals[f'Normalized_LL_Ctx{cs}']
        print(f'Ctx Size {cs}: {score:0.2f}')
    print(f'Average: {avg_score:0.2f}')
    print('=' * 20)
    with open(os.path.join(path, 'test_stats.pkl'), 'wb') as f:
        pkl.dump(evals, f)

    ###########################################################################
    # %% Possibly show histogram.
    ###########################################################################
    if args.hist:
        # The reference ('GP') histograms need the cumulative joint log
        # probabilities, which are only loaded when normalizing.
        have_ref = not dont_normalize
        avg_model, avg_gp = [], []
        for k, v in log_stats.items():
            vv = np.concatenate(v)
            avg_model.append(vv)
            plt.hist(vv, bins=50, alpha=0.4, color='cornflowerblue')
            plt.axvline(np.mean(vv), color='cornflowerblue', ls='--',
                        label=f'Model = {np.mean(vv):0.2f}')
            if have_ref:
                tv = cjoint_ll[:, -1].numpy() - cjoint_ll[:, k - 1].numpy()
                avg_gp.append(tv)
                plt.hist(tv, bins=50, alpha=0.4, color='orange')
                plt.axvline(np.mean(tv), color='orange', ls='--',
                            label=f'GP = {np.mean(tv):0.2f}')
            plt.legend()
            plt.title(f'Ctx: {k}')
            plt.show()
        # Final figure: values averaged over all context sizes.
        avg_model = np.mean(np.array(avg_model), axis=0)
        plt.hist(avg_model, bins=50, alpha=0.4, color='cornflowerblue')
        plt.axvline(np.mean(avg_model), color='cornflowerblue', ls='--',
                    label=f'Model = {np.mean(avg_model):0.2f}')
        if have_ref:
            avg_gp = np.mean(np.array(avg_gp), axis=0)
            plt.hist(avg_gp, bins=50, alpha=0.4, color='orange')
            plt.axvline(np.mean(avg_gp), color='orange', ls='--',
                        label=f'GP = {np.mean(avg_gp):0.2f}')
        plt.legend()
        plt.title('Average')
        plt.show()
