import argparse
import os
from tqdm import tqdm
import torch
from torch import nn
import numpy as np
import json
import pickle
import warnings

from unet import Unet, Unet_MoCap
from diffusion import Diffusion, Diffusion_MoCap
from utils.data import VF_Dataset, scale_data, MoCap_Dataset
from utils.utils import *
from feedforward import FF, FF_MoCap
   
def _mix_with_pattern(samples, pattern, weight):
    """Blend generated samples with the (sample-expanded) archetype pattern.

    Computes ``weight * pattern + (1 - weight) * samples``.

    Args:
        samples: generated tensor whose leading dimension is the batch.
        pattern: pattern tensor already expanded to the same shape as ``samples``.
        weight: python float, or a 1-D tensor with one weight per batch element.

    Returns:
        The blended tensor.
    """
    if torch.is_tensor(weight):
        # one weight per batch element: add trailing singleton dims so it broadcasts
        # over every non-batch dimension, whatever the rank of the data is
        weight = weight.view(-1, *([1] * (pattern.dim() - 1)))
    return weight * pattern + (1 - weight) * samples


def _restore_yaw(mocap, rotation, device):
    """Re-apply the yaw rotation encoded in ``rotation`` to joint positions.

    Args:
        mocap: tensor indexed as (..., T, joint, 3).
        rotation: tensor indexed as (..., T, 2); the angle is
            ``atan2(rotation[..., 0], rotation[..., 1])`` after normalising the pair.
        device: torch device the unwrapped angles are moved to.

    Returns:
        Tensor shaped like ``mocap``: the root offset is subtracted from every joint
        and the result is multiplied by the transposed y-axis rotation matrix.
    """
    # normalise the 2-vector (epsilon avoids division by zero), then read off the angle
    norm = torch.sqrt(rotation[..., 0]**2 + rotation[..., 1]**2 + 1e-8)
    unit = rotation / norm.unsqueeze(-1)
    angles = torch.atan2(unit[..., 0], unit[..., 1]).detach().cpu().numpy()
    # unwrap along the last (time) axis so consecutive angles do not jump by 2*pi
    yaws = torch.from_numpy(np.unwrap(angles, axis=-1)).to(device)
    cos_yaws = torch.cos(yaws)[..., None, None]
    sin_yaws = torch.sin(yaws)[..., None, None]
    zeros, ones = torch.zeros_like(cos_yaws), torch.ones_like(cos_yaws)
    # rotation about the second (y) axis, one 3x3 matrix per frame
    R_yaw = torch.cat([
        torch.cat([ cos_yaws, zeros,  sin_yaws], dim=-1),
        torch.cat([ zeros,    ones,   zeros], dim=-1),
        torch.cat([-sin_yaws, zeros,  cos_yaws], dim=-1)
    ], dim=-2)
    # NOTE(review): the root is half the *difference* of joints 11 and 12; a midpoint
    # would be half the sum -- confirm this matches the dataset preprocessing.
    roots = (mocap[..., 11, :] - mocap[..., 12, :]) / 2
    return (mocap - roots.unsqueeze(-2)) @ R_yaw.transpose(-1, -2)


def test_model(diffusion, dynamics_model, aa, archetypes, dataset, gpu, w, gamma, batch_size=128, guidance_type='constant'):
    """Sample predictions from a pattern-guided diffusion model and report the MAE.

    Besides its parameters, this function reads the module-level ``args`` namespace
    (``te_dataset``, ``tr_dataset``, ``n_samples``, ``n_frames``, ``n_horizon``,
    ``aauq_type``, ``guide_mix``, ``mix_scale``, ``mix_error``) and, for the motion
    capture branch, the module-level ``data_range``; both are set by the script entry
    point.

    Args:
        diffusion: diffusion wrapper exposing ``.model`` and ``.sample(...)``.
        dynamics_model: pattern prediction model mapping archetype coefficients of
            the history to predicted coefficients.
        aa: fitted archetypal-analysis object; only ``aa.transform`` is used, and
            only when the train and test datasets differ.
        archetypes: tensor used as the right operand of ``coeffs @ archetypes``.
        dataset: dataset to evaluate; must provide ``get_field_mask()`` for 'uwhvf'.
        gpu: CUDA device index (CPU is used when CUDA is unavailable).
        w: guidance scale (constant guidance) or maximum guidance scale.
        gamma: error scale used by the 'sigmoid' and 'relu' guidance types.
        batch_size: evaluation batch size.
        guidance_type: 'constant', 'log', 'sigmoid' or 'relu'.

    Returns:
        Tuple of numpy arrays ``(patterns, predictions, trues, baselines, unrot_gts,
        unrot_preds, unrot_base)``; when ``guidance_type != 'constant'`` the tuple is
        extended with ``(uqs, guidances)``. The ``unrot_*`` arrays are only filled
        for non-'uwhvf' datasets.

    Raises:
        NotImplementedError: for an unknown ``args.aauq_type`` or ``guidance_type``.
    """
    is_vf = args.te_dataset in ['uwhvf']
    same_dataset = args.tr_dataset == args.te_dataset
    n_samples = args.n_samples
    # `archetypes` is the right operand of `coeffs @ archetypes`, so its first
    # dimension is the number of archetypes (no reliance on a module-level global)
    n_archetypes = archetypes.shape[0]

    # gpu
    device = torch.device("cuda:" + str(gpu) if torch.cuda.is_available() else "cpu")
    diffusion.model.to(device)
    diffusion.model.eval()
    # use the model that was passed in rather than a global of the calling script
    pattern_model = dynamics_model
    pattern_model.to(device)
    pattern_model.eval()

    # metrics
    reduced_criterion = nn.L1Loss(reduction='mean')

    # data
    if is_vf:
        field_mask = torch.from_numpy(dataset.get_field_mask().astype(int)).to(device)
    dataloader = torch.utils.data.DataLoader(dataset,
                                batch_size=batch_size,
                                shuffle=False)

    w = torch.tensor(w).to(device)
    zero_const = torch.FloatTensor([0]).to(device)

    # test loop
    patterns, predictions, trues, baselines, uqs, guidances = [], [], [], [], [], []
    unrot_gts, unrot_preds, unrot_base = [], [], []
    total_loss = [0]*n_samples
    with torch.no_grad():
        for inputs, targets, _ in tqdm(dataloader, total=len(dataloader)):
            x_b = inputs[0].to(device).float()
            B = x_b.shape[0]

            if is_vf:
                start_age = inputs[1].reshape(-1,1).to(device).float()
                deltas = torch.unsqueeze(inputs[2],-1).to(device).float()
                x_f = targets.to(device).float()

                # get baseline_aa, either by pulling from dataloader or calculating AA coefficients based on AA object
                if same_dataset:
                    # NOTE(review): squeeze() also drops the batch dim when B == 1 -- confirm
                    # the pattern model tolerates that for a final batch of size one.
                    baseline_aa = torch.squeeze(inputs[3]).to(device).float() # B, args.n_frames, n_archetypes
                else:
                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", category=FutureWarning) # suppress warning about BaseEstimator._validate_data deprecation
                        baseline_vf = x_b[:,:,field_mask>0].detach().cpu().numpy().reshape(-1,args.n_frames,52) # B, args.n_frames, 52
                        baseline_coeffs = torch.from_numpy(aa.transform(baseline_vf.reshape(-1,52)))                # one row of coefficients per (sample, frame)
                        baseline_aa = baseline_coeffs.reshape(B, args.n_frames, -1).to(device).float()              # (B, n_frames, n_archetypes)
            else:
                rot_history = inputs[1].to(device).float()
                pos_horizon = targets[0].to(device).float()
                rot_horizon = targets[1].to(device).float()

                # get baseline_aa, either by pulling from dataloader or calculating AA coefficients based on AA object
                if same_dataset:
                    baseline_aa = torch.squeeze(inputs[2]).to(device).float() # B, args.n_frames, n_archetypes
                else:
                    with warnings.catch_warnings():
                        warnings.filterwarnings("ignore", category=FutureWarning) # suppress warning about BaseEstimator._validate_data deprecation
                        history_coeffs = torch.from_numpy(aa.transform(x_b.reshape(-1,17*3).cpu().detach().numpy()))
                        baseline_aa = history_coeffs.reshape(-1, args.n_frames, n_archetypes)

            # create archetype conditioning
            if is_vf:
                pred_aa = pattern_model(baseline_aa, start_age, deltas)
                recon = pred_aa @ archetypes

                # write the reconstruction into the masked field locations of channel 0
                aa_pattern = torch.zeros_like(x_f)
                aa_pattern[:,0,(field_mask > 0)] = recon[:,:]
                null_pattern = torch.zeros_like(x_f)
            else:
                baseline_aa = baseline_aa.reshape(-1, args.n_frames*n_archetypes).to(device).float()
                pred_aa = pattern_model(baseline_aa)
                aa_pattern = pred_aa @ archetypes
                null_pattern = torch.zeros_like(pos_horizon)

            # per-sample guidance scale and (optional) output-mixing weight
            if guidance_type == 'constant':
                guidance = torch.full(torch.Size([B]), w).to(device)
                mix_weight = args.mix_scale if (args.guide_mix and args.mix_scale > 0) else None
            else:
                # UQ for guidance scale: for better reconstruction (smaller error), we want more guidance (larger guidance scale)
                if is_vf:
                    relevant_x_b_points = x_b[:,:,field_mask>0]
                else:
                    relevant_x_b_points = x_b.reshape(-1, args.n_frames, 17*3)
                baseline_coeffs_3d = baseline_aa.reshape(-1, args.n_frames, n_archetypes)
                if args.aauq_type == 'last':
                    recon_recent = baseline_coeffs_3d[:,-1,:] @ archetypes
                    aauq = torch.mean(torch.abs(recon_recent - relevant_x_b_points[:,-1,:]), dim=1) # reconstructed baseline at most recent step vs. true baseline
                elif args.aauq_type == 'mean':
                    recon_recent = baseline_coeffs_3d @ archetypes
                    aauq = torch.mean(torch.mean(torch.abs(recon_recent - relevant_x_b_points), dim=2),dim=1) # reconstruction error averaged over all baseline steps
                else:
                    raise NotImplementedError

                if guidance_type == 'log':
                    uq_guidance = torch.max(-torch.log(w*aauq), zero_const.expand_as(aauq)) # zero crossing (zero guidance) is at 1/w aauq, min value 0
                    uq_mixing = torch.max(-torch.log(args.mix_scale*aauq), zero_const.expand_as(aauq)) # zero crossing (zero mixing) is at 1/mix_scale aauq, min value 0
                elif guidance_type == 'sigmoid':
                    uq_guidance = w / (1 + torch.exp(aauq - gamma/2))
                    uq_mixing = args.mix_scale / (1 + torch.exp(aauq - args.mix_error/2))
                elif guidance_type == 'relu':
                    uq_guidance = torch.max(-(w/gamma)*aauq + w, zero_const.expand_as(aauq))
                    uq_mixing = torch.max(-(args.mix_scale/args.mix_error)*aauq + args.mix_scale, zero_const.expand_as(aauq))
                else:
                    raise NotImplementedError

                guidance = uq_guidance
                mix_weight = uq_mixing if args.guide_mix else None

            # generate images
            if is_vf:
                generated_images = diffusion.sample(n_samples, x_b, start_age, deltas, aa_pattern, null_pattern, guidance) # B, 1, 10, 10
                if mix_weight is not None:
                    pattern_exp = aa_pattern.repeat_interleave(n_samples, dim=0).view(B, n_samples, x_f.shape[1], x_f.shape[2], x_f.shape[3])
                    generated_images = _mix_with_pattern(generated_images, pattern_exp, mix_weight)
            else:
                generated_images = diffusion.sample(n_samples, x_b, rot_history, aa_pattern, null_pattern, guidance)
                generated_mocap = generated_images[:, :, :, :-2] # N x S x H x 17*3
                generated_rotation = generated_images[:, :, :, -2:] # N x S x H x 2
                if mix_weight is not None:
                    pattern_exp = aa_pattern.repeat_interleave(n_samples, dim=0).view(B, n_samples, args.n_horizon, 17*3)
                    # The blend must land in generated_mocap: that is what the metrics below
                    # consume. (It used to be stored in generated_images, which is not read
                    # again in this branch, so mixing silently had no effect.)
                    generated_mocap = _mix_with_pattern(generated_mocap, pattern_exp, mix_weight)

            # calc metrics
            if is_vf:
                for sample in range(n_samples):
                    total_loss[sample] += reduced_criterion(generated_images[:,sample,:,field_mask>0].view(B,-1), x_f[:,:,field_mask>0].view(B,-1)).detach().cpu()

                preds_array = generated_images.detach().cpu().numpy()
                trues_array = x_f.detach().cpu().numpy()
                inputs_array = x_b.detach().cpu().numpy()

            else:
                # map normalised values back to the data range / [-1, 1] for rotations
                generated_mocap = scale_data(generated_mocap, 0, 1, data_range[0], data_range[1]).reshape(-1, n_samples, args.n_horizon, 17, 3)
                generated_rotation = scale_data(generated_rotation, 0, 1, -1, 1)
                horizon_mocap = scale_data(pos_horizon, 0, 1, data_range[0], data_range[1]).reshape(-1, args.n_horizon, 17, 3)
                horizon_rotation = scale_data(rot_horizon, 0, 1, -1, 1)
                history_mocap = scale_data(x_b, 0, 1, data_range[0], data_range[1]).reshape(-1, args.n_frames, 17, 3)
                history_rotation = scale_data(rot_history, 0, 1, -1, 1)

                # add rotation back into generations, groundtruths and histories
                gen_rotated_mocap = _restore_yaw(generated_mocap, generated_rotation, device)
                gt_rotated_mocap = _restore_yaw(horizon_mocap, horizon_rotation, device)
                hist_rotated_mocap = _restore_yaw(history_mocap, history_rotation, device)

                # for MAE
                gen_rotated_mocap_normed_mae = scale_data(gen_rotated_mocap, data_range[0], data_range[1], 0, 100)
                gt_rotated_mocap_normed_mae = scale_data(gt_rotated_mocap, data_range[0], data_range[1], 0, 100)
                for sample in range(n_samples):
                    total_loss[sample] += reduced_criterion(gen_rotated_mocap_normed_mae[:,sample,:,:,:], gt_rotated_mocap_normed_mae).detach().cpu()

                preds_array = gen_rotated_mocap.detach().cpu().numpy()
                trues_array = gt_rotated_mocap.detach().cpu().numpy()
                inputs_array = hist_rotated_mocap.detach().cpu().numpy()
                unrot_gts.extend(horizon_mocap.detach().cpu().numpy())
                unrot_preds.extend(generated_mocap.detach().cpu().numpy())
                unrot_base.extend(history_mocap.detach().cpu().numpy())

            patterns.extend(aa_pattern.detach().cpu().numpy())
            predictions.extend(preds_array)
            trues.extend(trues_array)
            baselines.extend(inputs_array)
            if guidance_type != 'constant':
                # uqs / guidances are only returned for dynamic guidance, so nothing
                # needs to be recorded for the constant case
                uqs.extend(aauq.detach().cpu().numpy())
                guidances.extend(uq_guidance.detach().cpu().numpy())

    # the reported MAE is the mean of the per-batch means
    for sample in range(n_samples):
        total_loss[sample] /= len(dataloader)
        print(f'Sample {sample} MAE: {total_loss[sample]}')
    print(total_loss)
    print(f'Mean MAE: {np.mean(total_loss)}, Std MAE: {np.std(total_loss)}')

    if guidance_type == 'constant':
        return np.array(patterns), np.array(predictions), np.array(trues), np.array(baselines), np.array(unrot_gts), np.array(unrot_preds), np.array(unrot_base)
    else:
        return np.array(patterns), np.array(predictions), np.array(trues), np.array(baselines), np.array(unrot_gts), np.array(unrot_preds), np.array(unrot_base), np.array(uqs), np.array(guidances)

if __name__ == '__main__':
    # Script entry point: parse options, load the test split and the trained models,
    # run test_model() (or reload cached results) and optionally plot examples.
    # NOTE: `args`, `data_range`, `n_archetypes` and `pattern_model` are deliberately
    # module-level names here because test_model() reads some of them as globals.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # NOTE(review): the default 'scheie' is not one of the listed choices for either
    # dataset option (argparse does not validate defaults) -- confirm whether these
    # options should simply be passed explicitly.
    parser.add_argument('--tr_dataset', type=str, choices=['uwhvf', 'break', 'house', 'ballet_jazz', 'street_jazz', 'krump', 'la_hip_hop', 'lock', 'middle_hip_hop', 'pop', 'wack'], default='scheie', help='Dataset.')
    parser.add_argument('--te_dataset', type=str, choices=['uwhvf', 'break', 'house', 'ballet_jazz', 'street_jazz', 'krump', 'la_hip_hop', 'lock', 'middle_hip_hop', 'pop', 'wack'], default='scheie', help='Dataset.')
    parser.add_argument('--split', type=str, choices=['val', 'test'], default='test', help='Dataset split.')
    parser.add_argument('--representation', type=str, choices=['hvf','td','kp_norm','kp_raw'], default='td', help='Representation of data.')
    parser.add_argument('--model_pth', type=str, help='Path of trained model.')
    parser.add_argument('--save_dir', type=str, default='./saved_pgdm_models', help='Directory to save training results in.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--gpu', type=int, default=0, help='GPU.')
    parser.add_argument('--cached', action='store_true', help='Use cached results.')
    parser.add_argument('--plot', action='store_true', help='Plot images.')
    parser.add_argument('--n_samples', type=int, default=5, help='Number of samples.')

    # details
    parser.add_argument('--target', type=str, choices=['x0', 'noise'], default='noise', help='Which target to use when training.')
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size.')

    # conditioning inputs
    parser.add_argument('--n_frames', type=int, default=3, help='Number of frames required for prediction.')
    parser.add_argument('--n_horizon', type=int, default=1, help='Number of frames to predict.')
    parser.add_argument('--unet_dim', type=int, default=20, help='Dimensions for UNet')
    parser.add_argument('--time_dim', type=int, default=16, help='Dimension of time embedding.')
    parser.add_argument('--hidden_dim', type=int, default=64, help='Hidden dimensions of pattern prediction model.')
    parser.add_argument('--pattern_model_pth', type=str, help='Pattern prediction model.')
    parser.add_argument('--guidance_scale', type=float, default=1, help='Guidance scale for pattern guidance (w) or dynamic pattern guidance (w_bar).')
    parser.add_argument('--error_scale', type=float, default=1, help='Error scale for dynamic pattern guidance (gamma).')
    parser.add_argument('--guidance_type', type=str, choices=['constant', 'log', 'sigmoid', 'relu'], default='relu', help='Guidance type: constant or AAUQ scaling type.')
    parser.add_argument('--aauq_type', choices=['mean','last'], default='mean', help='Method for calculating AAUQ.')

    parser.add_argument('--guide_mix', action='store_true', help='Use both guidance and output mixing.')
    parser.add_argument('--mix_scale', type=float, default=1, help='Mixing scale.')
    parser.add_argument('--mix_error', type=float, default=1, help='Mixing error scale.')

    args = parser.parse_args()

    # validate with parser.error rather than assert: asserts disappear under `python -O`
    if not args.guidance_scale >= 0:
        parser.error('Guidance scale must be non-negative.')
    if not 0 <= args.mix_scale <= 1:
        parser.error('Mixing scale must be in [0,1].')
    if args.model_pth is None:
        # the path is needed below even with --cached (it determines the figure directory)
        parser.error('--model_pth is required.')
    if not args.cached and args.pattern_model_pth is None:
        parser.error('--pattern_model_pth is required unless --cached is given.')
    # the mixing error scale is tied to the guidance error scale; any value given
    # for --mix_error on the command line is overwritten here
    args.mix_error = args.error_scale

    # set up save paths
    args.save_dir = f'{args.save_dir}_{args.tr_dataset}/{args.te_dataset}'
    os.makedirs(args.save_dir, exist_ok=True)

    # more setup
    device = torch.device(f"cuda:{str(args.gpu)}" if torch.cuda.is_available() else "cpu")
    img_size = (1,10,10)
    beta_minmax = [1e-4, 2e-2]

    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only load
    # archetype objects that come from a trusted source.
    with open(f"archetypes/{args.tr_dataset}_aa_object.pkl", 'rb') as f:
        aa = pickle.load(f)

    # seeds
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # dataset: (min, max) value range per test dataset
    data_ranges = {
        'uwhvf': (-37.69, 50.00),
        'scheie': (-38.00, 40.00),
        'break': (-217.49, 286.76),
        'street_jazz': (-102.57, 101.90),
        'ballet_jazz': (-308.86, 327.77),
        'middle_hip_hop': (-91.20, 92.65),
        'house': (-78.08, 81.47),
        'krump': (-85.32, 110.89),
        'la_hip_hop': (-77.41, 100.26),
        'lock': (-109.02, 95.57),
        'pop': (-132.20, 103.33),
        'wack': (-283.42, 413.92),
    }
    if args.te_dataset not in data_ranges:
        raise NotImplementedError
    data_range = data_ranges[args.te_dataset]

    is_vf = args.te_dataset in ['uwhvf']
    if is_vf:
        data_dir = f'datasets/{args.te_dataset}_seqp_{args.n_frames}'
    else:
        data_dir = f'datasets/aistpp_seqp_aa/{args.te_dataset}/h{args.n_frames}_H{args.n_horizon}'
    with open(os.path.join(data_dir, f'{args.split}.json')) as f:
        test_data = json.load(f)

    if is_vf:
        test_dataset = VF_Dataset(test_data, representation=args.representation, normalize=True, data_range=data_range)
        padding_mask = test_dataset.get_padding_mask()
    else:
        test_dataset = MoCap_Dataset(test_data, representation=args.representation, normalize=True, data_range=data_range)

    # figure / results directory encodes the guidance configuration
    model_name = os.path.basename(args.model_pth).split('.')[0]
    fig_root = f'{os.path.dirname(args.model_pth)}/figs/{model_name}/{args.te_dataset}'
    if args.guidance_type == 'constant':
        base_fig_dir = f'{fig_root}/{args.guidance_scale}'
    elif args.guidance_type in ['sigmoid', 'relu']:
        base_fig_dir = f'{fig_root}/{args.guidance_type}_{args.guidance_scale}_{args.error_scale}'
    else:
        base_fig_dir = f'{fig_root}/{args.guidance_type}_{args.guidance_scale}'
    if args.guide_mix:
        base_fig_dir = f'{base_fig_dir}_{args.mix_scale}_{args.mix_error}'
    os.makedirs(base_fig_dir, exist_ok=True)
    results_pth = f'{base_fig_dir}/results.pkl'

    if not args.cached:
        # Create pattern prediction model
        n_archetypes = aa.n_archetypes_
        archetypes = torch.FloatTensor(aa.archetypes_).to(device)
        if is_vf:
            pattern_model = FF(n_frames=args.n_frames,
                    n_archetypes = n_archetypes,
                    hidden_dim = args.hidden_dim,
                    out_dim = n_archetypes,
                    dim=args.time_dim,
                    ).to(device)
            # map_location lets checkpoints saved on a GPU load on a CPU-only host
            pattern_model.load_state_dict(torch.load(args.pattern_model_pth, map_location=device))

            model = Unet(dim=args.unet_dim,
                    channels=1,
                    out_dim=1,
                    dim_mults=(1, 2,),
                    resnet_block_groups=4,
                    n_frames=args.n_frames,
                    ).to(device)

            diffusion = Diffusion(model, image_resolution=img_size, n_times=1000,
                                beta_minmax=beta_minmax, device=device, target=args.target).to(device)
        else:
            pattern_model = FF_MoCap(n_frames=args.n_frames,
                n_horizon=args.n_horizon,
                n_archetypes = n_archetypes,
                hidden_dim = args.hidden_dim,
                out_dim = n_archetypes*args.n_horizon,
                ).to(device)
            pattern_model.load_state_dict(torch.load(args.pattern_model_pth, map_location=device))

            model = Unet_MoCap(dim=args.unet_dim,
                    channels=1,
                    dim_mults=(1, 2,),
                    resnet_block_groups=4,
                    n_frames=args.n_frames,
                    out_dim=args.n_horizon,
                    ).to(device)

            diffusion = Diffusion_MoCap(model, image_resolution=img_size, n_times=1000,
                                beta_minmax=beta_minmax, device=device, target=args.target).to(device)

        diffusion.model.load_state_dict(torch.load(args.model_pth, map_location=device))
        diffusion.model.eval()

        # test loop
        print('Testing...')
        # honour --batch_size (previously a hard-coded 128, which is also its default)
        outputs = test_model(diffusion, pattern_model, aa, archetypes, test_dataset, args.gpu, args.guidance_scale, args.error_scale, args.batch_size, args.guidance_type)
        patterns, sampled, gts, histories, unrot_gts, unrot_sampled, unrot_hist = outputs[:7]
        # test_model returns two extra arrays for the dynamic guidance types
        uqs, guidances = outputs[7:] if args.guidance_type != 'constant' else (None, None)

        results = {
            'model_pth':args.model_pth,
            'data_dir':data_dir,
            'seed':args.seed,
            'patterns':patterns,
            'inputs':histories,
            'groundtruths':gts,
            'predictions':sampled,
            'unrot_groundtruths':unrot_gts,
            'unrot_predictions':unrot_sampled,
            'unrot_inputs':unrot_hist,
        }
        if uqs is not None:
            # persisted so that `--cached --plot` can still annotate the figures
            results['uqs'] = uqs
            results['guidances'] = guidances
        with open(results_pth, 'wb') as f:
            pickle.dump(results, f)
    else:
        # SECURITY NOTE: only reload results files written by this script.
        with open(results_pth, 'rb') as f:
            results = pickle.load(f)
        patterns = results['patterns']
        histories = results['inputs']
        gts = results['groundtruths']
        sampled = results['predictions']
        unrot_gts = results['unrot_groundtruths']
        unrot_sampled = results['unrot_predictions']
        unrot_hist = results['unrot_inputs']
        # absent from results files written before these keys were stored
        uqs = results.get('uqs')
        guidances = results.get('guidances')

    if args.plot and is_vf:
        batch_size = 10
        n_batches = int(np.ceil(len(test_dataset)/batch_size))
        gts = scale_data(gts, data_range[0], data_range[1], 0, 1)
        histories = scale_data(histories, data_range[0], data_range[1], 0, 1)
        sampled = scale_data(sampled, data_range[0], data_range[1], 0, 1)
        annotate = args.guidance_type != 'constant' and uqs is not None and guidances is not None
        # the final batch is not plotted (range stops at n_batches - 1)
        for b in range(n_batches - 1):
            idx_l = b*batch_size
            idx_u = idx_l + batch_size
            batch_patterns = torch.from_numpy(patterns[idx_l:idx_u,:,:,:])
            batch_sampled = torch.from_numpy(sampled[idx_l:idx_u,:,:,:])
            batch_gts = torch.from_numpy(gts[idx_l:idx_u,:,:,:])
            save_pth = f'{base_fig_dir}/{b}.png'
            if annotate:
                subtitles = [f'AAUQ: {uq:.2f}, G:{g:.2f}' for uq, g in zip(uqs[idx_l:idx_u], guidances[idx_l:idx_u])]
                plot_gt_pred_batch(batch_gts, batch_patterns, batch_sampled, padding_mask, save_pth, sub_titles=subtitles)
            else:
                plot_gt_pred_batch(batch_gts, batch_patterns, batch_sampled, padding_mask, save_pth)
    if args.plot and not is_vf:
        # never ask for more distinct examples than exist (replace=False would raise)
        n_examples = min(20, gts.shape[0])
        patterns = patterns.reshape(-1,args.n_horizon,17,3)
        idx = np.random.choice(gts.shape[0], size=n_examples, replace=False)
        for b in idx:
            batch_patterns = torch.from_numpy(patterns[b,:,:])
            # only the first generated sample of each example is plotted
            batch_sampled = torch.from_numpy(sampled[b,0,:,:,:])
            batch_gts = torch.from_numpy(gts[b,:,:,:])
            batch_histories = torch.from_numpy(histories[b,:,:,:])
            batch_unrot_gts = torch.from_numpy(unrot_gts[b,:,:,:])
            batch_unrot_sampled = torch.from_numpy(unrot_sampled[b,0,:,:,:])
            batch_unrot_hist = torch.from_numpy(unrot_hist[b,:,:,:])
            save_pth = f'{base_fig_dir}/{b}_rot.png'
            plot_mocap_gt_pred(batch_histories, batch_gts, batch_patterns, batch_sampled, save_pth)
            save_pth = f'{base_fig_dir}/{b}_unrot.png'
            plot_mocap_gt_pred(batch_unrot_hist, batch_unrot_gts, batch_patterns, batch_unrot_sampled, save_pth)