import argparse
import os
import json
from tqdm import tqdm
import numpy as np
import torch
from torch import nn
import pickle
import matplotlib.pyplot as plt
import warnings
from feedforward import FF, FF_MoCap
from utils.data import VF_Dataset, scale_data, MoCap_Dataset

# Dataset names accepted by --tr_dataset / --te_dataset.
DATASETS = ['uwhvf', 'break', 'house', 'ballet_jazz', 'street_jazz', 'krump',
            'la_hip_hop', 'lock', 'middle_hip_hop', 'pop', 'wack']

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# The previous default ('scheie') was not one of the choices and had no data range,
# so running without these flags always failed; default to a supported dataset instead.
parser.add_argument('--tr_dataset', type=str, choices=DATASETS, default='uwhvf', help='Dataset.')
parser.add_argument('--te_dataset', type=str, choices=DATASETS, default='uwhvf', help='Dataset.')
# Required: the checkpoint is loaded and its file name prefixes every saved figure.
parser.add_argument('--model_pth', type=str, required=True, help='Model path.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--gpu', type=int, default=0, help='GPU.')
parser.add_argument('--print_samples', action='store_true', help='Use this arg to print random samples of predictions and targets.')

# hyperparameters (must match the values the checkpoint was trained with)
parser.add_argument('--batch_size', type=int, default=128, help='Batch size.')
parser.add_argument('--time_dim', type=int, default=16, help='Dimension of time embedding.')
parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimension of the model.')

# conditioning inputs
parser.add_argument('--n_frames', type=int, default=3, help='Number of frames required for prediction.')
parser.add_argument('--n_horizon', type=int, default=1, help='Number of frames to predict.')

args = parser.parse_args()

# (min, max) of the raw values of each test dataset; scale_data maps this range onto [0, 1].
DATA_RANGES = {
    'uwhvf': (-37.69, 50.00),
    'break': (-217.49, 286.76),
    'street_jazz': (-102.57, 101.90),
    'ballet_jazz': (-308.86, 327.77),
    'middle_hip_hop': (-91.20, 92.65),
    'house': (-78.08, 81.47),
    'krump': (-85.32, 110.89),
    'la_hip_hop': (-77.41, 100.26),
    'lock': (-109.02, 95.57),
    'pop': (-132.20, 103.33),
    'wack': (-283.42, 413.92),
}


def data_range_for(dataset):
    """Return the (min, max) scaling range for *dataset*.

    Raises:
        NotImplementedError: if no range is defined for *dataset*.
    """
    try:
        return DATA_RANGES[dataset]
    except KeyError:
        raise NotImplementedError(f'No data range defined for dataset {dataset!r}') from None


def model_stem(model_pth):
    """Return the checkpoint file name without its directories and extension.

    Unlike ``path.split('.')[0]``, this is not confused by leading ``./`` or
    dots inside directory names.
    """
    return os.path.splitext(os.path.basename(model_pth))[0]


def aggregate_uncertainty(recon_err, how='mean'):
    """Collapse per-frame reconstruction errors into one uncertainty score per sample.

    Args:
        recon_err: 2-D array of shape (B, n_frames).
        how: one of 'mean', 'last', 'max', 'min'.

    Returns:
        A list of B scores.

    Raises:
        NotImplementedError: for an unknown *how*.
    """
    if how == 'last':
        return list(recon_err[:, -1])
    reducer = {'mean': np.mean, 'max': np.max, 'min': np.min}.get(how)
    if reducer is None:
        raise NotImplementedError(f'Unknown uncertainty aggregation: {how}')
    return list(reducer(recon_err, axis=1))


if __name__ == '__main__':
    # setup
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # 'uwhvf' is loaded with VF_Dataset; every other dataset uses MoCap_Dataset.
    is_vf = args.te_dataset == 'uwhvf'

    # save directories
    fig_dir = f'analysis/pattern/{args.tr_dataset}_to_{args.te_dataset}'
    os.makedirs(fig_dir, exist_ok=True)

    # dataset
    data_range = data_range_for(args.te_dataset)
    if is_vf:
        data_dir = f'datasets/{args.te_dataset}_seqp_{args.n_frames}'
    else:
        data_dir = f'datasets/aistpp_seqp_aa/{args.te_dataset}/h{args.n_frames}_H{args.n_horizon}'
    with open(os.path.join(data_dir, 'test.json')) as f:
        test_data = json.load(f)

    if is_vf:
        test_dataset = VF_Dataset(test_data, representation='aa', normalize=False, data_range=data_range)
        # Only the top-left 9x9 of the mask is used; it selects the 52 field locations.
        field_mask = test_dataset.get_field_mask()[:9, :9]
    else:
        test_dataset = MoCap_Dataset(test_data, representation='aa', normalize=False, data_range=data_range)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False)

    # load archetype info
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted archetype files.
    with open(f"archetypes/{args.tr_dataset}_aa_object.pkl", 'rb') as f:
        aa = pickle.load(f)

    # test setup (the two AA object types expose the archetype count under different names)
    if is_vf:
        n_archetypes = aa.n_archetypes_
        model = FF(n_frames=args.n_frames,
                   n_archetypes=n_archetypes,
                   hidden_dim=args.hidden_dim,
                   out_dim=n_archetypes,
                   dim=args.time_dim,
                   ).to(device)
    else:
        n_archetypes = aa.n_archetypes
        model = FF_MoCap(n_frames=args.n_frames,
                         n_horizon=args.n_horizon,
                         n_archetypes=n_archetypes,
                         hidden_dim=args.hidden_dim,
                         out_dim=n_archetypes * args.n_horizon,
                         ).to(device)
    # map_location lets a checkpoint saved on one device be evaluated on another (e.g. CPU-only host).
    model.load_state_dict(torch.load(args.model_pth, map_location=device))
    model_name = model_stem(args.model_pth)
    mse = nn.MSELoss(reduction='none')
    mae = nn.L1Loss(reduction='none')

    aauq_agg = 'mean'  # mean, last, max, min

    n_out = args.n_horizon * n_archetypes
    # Per-sample averaging axes: VF targets are (B, 52), MoCap targets are (B, n_horizon, 51).
    vf_axes = 1 if is_vf else (1, 2)

    # test loop
    model.eval()
    mse_aa, mae_aa, mae_vf = [], [], []
    aa_uqs = []
    base_mae_vf = []
    with torch.no_grad():
        for inputs, x_f, sids in tqdm(test_loader, total=len(test_loader)):
            B = x_f.shape[0]

            # grab history and horizon samples, needed for OOD testing and uncertainty metric
            stats = [test_dataset.get_stats(sid) for sid in sids]
            if is_vf:
                history = np.array([s['baseline_td'] for s in stats])                  # (B, n_frames, 9, 9)
                history = history[:, :, field_mask].reshape(-1, args.n_frames, 52)      # (B, n_frames, 52)
                horizon = np.array([s['future_td'] for s in stats])                    # (B, 9, 9)
                horizon = horizon[:, field_mask].reshape(-1, 52)                        # (B, 52)
            else:
                history = np.array([s['history_kp_norm'] for s in stats]).reshape(-1, args.n_frames, 17 * 3)
                horizon = np.array([s['horizon_kp_norm'] for s in stats]).reshape(-1, args.n_horizon, 17 * 3)

            history = scale_data(history, data_range[0], data_range[1], 0, 1)
            horizon = scale_data(horizon, data_range[0], data_range[1], 0, 1)

            # construct input: take AA coefficients from the dataloader when train/test datasets match,
            # otherwise project the raw history/horizon onto the training archetypes
            if args.tr_dataset == args.te_dataset:
                history_aa = inputs[0]                                                  # (B, n_frames, 1, n_archetypes)
            else:
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore", category=FutureWarning)  # suppress warning about BaseEstimator._validate_data deprecation
                    if is_vf:
                        history_aa = torch.from_numpy(aa.transform(history.reshape(-1, 52)))
                        x_f = torch.from_numpy(aa.transform(horizon))
                    else:
                        history_aa = torch.from_numpy(aa.transform(history.reshape(-1, 17 * 3)))
                        x_f = torch.from_numpy(aa.transform(horizon.reshape(-1, 17 * 3)))
            # Explicit reshape (not squeeze) keeps the batch axis when B == 1 and aligns
            # x_f row-for-row with pred_alphas in both branches.
            x_f = x_f.reshape(B, -1).to(device).float()                                 # (B, n_horizon * n_archetypes)

            if is_vf:
                history_aa = history_aa.reshape(B, args.n_frames, -1).to(device).float()  # (B, n_frames, n_archetypes)
                start_age = inputs[1].reshape(-1, 1).to(device).float()                    # (B, 1)
                age_deltas = torch.unsqueeze(inputs[2], -1).to(device).float()
                pred_alphas = model(history_aa, start_age, age_deltas)
            else:
                history_aa = history_aa.reshape(-1, args.n_frames * n_archetypes).to(device).float()
                pred_alphas = model(history_aa)

            # MSE / MAE losses in AA space, averaged per sample
            pred_flat = pred_alphas.reshape(-1, n_out)
            mse_aa.extend(mse(pred_flat, x_f).mean(dim=1).cpu().numpy())
            mae_aa.extend(mae(pred_flat, x_f).mean(dim=1).cpu().numpy())

            # MAE loss in VF space (both sides rescaled from [0, 1] to [0, 100])
            pred_vf = pred_alphas.detach().cpu().numpy() @ aa.archetypes_
            future_vf_rescaled = scale_data(horizon, 0, 1, 0, 100)
            pred_vf_rescaled = scale_data(pred_vf, 0, 1, 0, 100)
            mae_vf.extend(np.mean(np.abs(future_vf_rescaled - pred_vf_rescaled), axis=vf_axes))

            # AA-based uncertainty: how badly the history's AA coefficients reconstruct the history
            reconstruction = history_aa.detach().cpu().numpy().reshape(-1, args.n_frames, n_archetypes) @ aa.archetypes_
            recon_err = np.mean(np.abs(history - reconstruction), axis=2)               # (B, n_frames)
            aa_uqs.extend(aggregate_uncertainty(recon_err, aauq_agg))

            # baseline MAE: reconstruct the ground-truth horizon from its own AA coefficients
            with warnings.catch_warnings():
                warnings.filterwarnings("ignore", category=FutureWarning)  # suppress warning about BaseEstimator._validate_data deprecation
                if is_vf:
                    future_recon = aa.transform(horizon) @ aa.archetypes_
                else:
                    future_coeffs = aa.transform(horizon.reshape(-1, 17 * 3))
                    future_recon = future_coeffs.reshape(-1, args.n_horizon, n_archetypes) @ aa.archetypes_
            future_vf_recon = scale_data(future_recon, 0, 1, 0, 100)
            base_mae_vf.extend(np.mean(np.abs(future_vf_rescaled - future_vf_recon), axis=vf_axes))

    # print random samples (from the last batch) if desired
    if args.print_samples:
        print('Random samples:')
        for i in np.random.randint(0, pred_alphas.shape[0], 5):
            print(pred_alphas[i], pred_alphas[i].sum())
            print(x_f[i], x_f[i].sum())
            print()

    # print stats: exact means over all test samples (a mean of batch means over-weights a short last batch)
    print(f'\t\t\tMSE in AA space {np.mean(mse_aa)}')
    print(f'\t\t\tMAE in AA space {np.mean(mae_aa)}')
    print(f'\t\t\tMAE in VF space {np.mean(mae_vf)}')
    print(f'\t\t\tAAUQ {np.mean(aa_uqs)}')

    # plot uncertainty vs. loss in different spaces for different loss functions.
    # The MSE plot previously reused the MAE-in-AA file name and was overwritten, so it gets
    # its own tag; the MAE plots keep their existing file names.
    loss_plots = [
        (mse_aa, 'MSE', 'AA', 'MSE_AA'),
        (mae_aa, 'MAE', 'AA', 'AA'),
        (mae_vf, 'MAE', 'VF', 'VF'),
    ]
    for vals, loss_fn, space, tag in loss_plots:
        for plot_kind in ('hist', 'scatter'):
            if plot_kind == 'hist':
                plt.hist2d(aa_uqs, vals, bins=100)
            else:
                plt.scatter(aa_uqs, vals, alpha=0.1)
            plt.xlabel('AA-based uncertainty')
            plt.ylabel(f'{loss_fn} in {space} space')
            plt.grid()
            plt.savefig(f'{fig_dir}/{model_name}_{plot_kind}_{tag}_{aauq_agg}.png')
            plt.close()

    # in VF space, compare MAE between VF reconstructed from predicted AA coefficients and VF reconstructed from actual AA coefficients
    fig, ax = plt.subplots(2, 2)
    fig.set_figwidth(8)
    fig.set_figheight(8)
    cm = plt.get_cmap('tab10')
    alpha = 0.2
    n_bins = 100

    # row 0: baseline (ground-truth coefficients); row 1: predicted coefficients
    for row, (errs, color) in enumerate([(base_mae_vf, cm(0)), (mae_vf, cm(1))]):
        ax[row, 0].scatter(aa_uqs, errs, alpha=alpha, color=color, marker='.')
        ax[row, 1].hist2d(aa_uqs, errs, bins=n_bins)

    # give both rows of each column the union of their axis limits so panels are comparable
    for col in range(2):
        xlims = np.array([ax[r, col].get_xlim() for r in range(2)])
        ylims = np.array([ax[r, col].get_ylim() for r in range(2)])
        for r in range(2):
            ax[r, col].set_xlim((xlims[:, 0].min(), xlims[:, 1].max()))
            ax[r, col].set_ylim((ylims[:, 0].min(), ylims[:, 1].max()))

    for i in range(2):
        ax[i, 0].grid()
        ax[1, i].set_xlabel('AA-based uncertainty')
    ax[0, 0].set_ylabel('MAE (recon. from gt AA coeff.s)')
    ax[1, 0].set_ylabel('MAE (recon. from pred AA coeff.s)')
    plt.tight_layout()
    plt.savefig(f'{fig_dir}/{model_name}_futureVFbaseline_{aauq_agg}.png', dpi=300)
    plt.close()