import json
import os
import argparse
import wandb

import numpy as np
import random
import torch
from torch import nn, optim
from tqdm.auto import tqdm
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig, PretrainedConfig
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader

from pathlib import Path
import sys
# Echo the three nearest ancestor directories of this file (debug aid for
# import-path problems), then put the third one on sys.path -- presumably so
# the project-local imports that follow can be resolved.
for _depth in (0, 1, 2):
    print(Path(__file__).parents[_depth])
path_root = Path(__file__).parents[2]
print(path_root)
sys.path.append(str(path_root))
from PIL import Image
from models.clip_cls import ClipClassifier
from models.gpt2_cls import Gpt2Classifier
import open_clip

sys.path.append('/tmp/exlib/src')
import exlib
from exlib.evaluators.attributions import NNZ, NNZGroup, InsDel, CompSuff, Consistency
from exlib.evaluators.comp_suff import CompSuffText
from exlib.modules.sop import SOP
from exlib.modules.fresh import FRESH
from exlib.evaluators.visualizer import TextVisualizer

# from models.wrapper import SumOfPartsWrapper, get_inverse_sqrt_with_separate_heads_schedule_with_warmup
from models.cnn import CNNModel
from custom_datasets import get_datasets, get_masked_datasets


def get_chained_attr(obj, attr_path):
    """Resolve a dotted attribute path such as ``'a.b.c'`` starting from *obj*.

    Args:
        obj: Object on which the lookup starts.
        attr_path: Dot-separated attribute names, or ``None``.

    Returns:
        ``None`` when *attr_path* is ``None``; otherwise the object reached by
        following every name in the path with ``getattr``.

    Raises:
        AttributeError: If any name along the path does not exist.
    """
    if attr_path is None:
        return None

    # Walk one attribute at a time, re-binding the cursor to each result.
    current = obj
    for name in attr_path.split('.'):
        current = getattr(current, name)
    return current

def parse_args():
    """Build the command-line parser for this evaluation script.

    Despite its name, this function does not parse anything itself: it
    returns the configured ``argparse.ArgumentParser`` and the caller is
    expected to invoke ``.parse_args()`` on the result.

    Returns:
        argparse.ArgumentParser: parser with every option registered.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # ---- models and configs ----
    add('--blackbox-model-name', default='google/vit-base-patch16-224', type=str,
        help='black box model name')
    add('--blackbox-processor-name', default='google/vit-base-patch16-224', type=str,
        help='black box processor name')
    add('--wrapper-config-filepath', type=str,
        default='actions/wrapper/configs/imagenet_vit_wrapper_default.json',
        help='wrapper config file')
    add('--exp-dir', default='exps/imgenet_wrapper', type=str,
        help='exp dir')
    add('--mask-dir', default=None, type=str,
        help='mask dir, if none, generate, else use')
    add('--wrapper-model', default='sop', type=str, choices=['sop', 'fresh'],
        help='sop or fresh')
    add('--model-type', default='image', type=str, choices=['image', 'text'],
        help='image or text model')
    add('--model-type-spec', default='encoder', type=str,
        choices=['encoder', 'clip', 'gpt2'],
        help='what specific model to use')
    add('--projection-layer-name', default=None, type=str,  # e.g. distilbert.embeddings
        help='projection layer if specified, else train from scratch')

    # ---- data ----
    add('--dataset', default='imagenet', type=str,
        choices=['imagenet', 'imagenet_m', 'multirc', 'movies', 'voc', 'cosmogrid'],
        help='which dataset to use')
    add('--train-size', default=-1, type=int,
        help='-1: use all, else randomly choose k per class')
    add('--val-size', default=-1, type=int,
        help='-1: use all, else randomly choose k per class')

    # ---- training ----
    add('--num-epochs', default=1, type=int,
        help='num epochs')
    add('--num-train-reps', default=1, type=int,
        help='number of times to train each head')
    add('--lr', default=1e-4, type=float,
        help='learning rate')
    add('--lr-scheduler-step-size', default=1, type=int,
        help='learning rate scheduler step size (by epoch)')
    add('--lr-scheduler-gamma', default=0.1, type=float,
        help='learning rate scheduler gamma')
    add('--warmup-epochs', default=-1, type=float,
        help='number of epochs to warmup')
    add('--warmup-steps', default=0, type=float,
        help='number of steps to warmup. Use this when warmup epochs is -1')
    add('--scheduler-type', default='inverse_sqrt_heads', type=str,
        choices=['cosine', 'constant', 'inverse_sqrt_heads'],
        help='scheduler type')
    add('--criterion', default='ce', type=str, choices=['ce', 'bce', 'mse'],
        help='which criterion to use, if multi-label then bce')
    add('--batch-size', default=32, type=int,
        help='batch size')
    add('--mask-batch-size', default=16, type=int,
        help='mask batch size')
    add('--project-name', default='attn', type=str,
        help='wandb project name')
    add('--seed', default=42, type=int,
        help='seed')
    add('--shuffle-val', default=False, action='store_true',
        help='if true then shuffle val')
    add('--num-masks-sample-inference', default=None, type=int,
        help='number of masks to retain for mask dropout.')

    # ---- wrapper config overrides ----
    # When one of these is not None it is used instead of the value found in
    # the wrapper config file.
    add('--attn-patch-size', default=None, type=int,
        help='attn patch size, does not have to match black box model patch size')
    add('--attn-stride-size', default=None, type=int,
        help='attn stride size, if smaller than patch size then there is overlap')
    add('--num-heads', default=None, type=int,
        help='hidden dim for the first attention layer')
    add('--num-masks-sample', default=None, type=int,
        help='number of masks to retain for mask dropout.')
    add('--num-masks-max', default=None, type=int,
        help='number of maximum masks to retain.')
    add('--finetune-layers', default=None, type=str,
        help='Which layer to finetune, seperated by comma.')
    add('--ins-del-step-size', default=None, type=int,
        help='ins del step size for insertion deletion')

    add('--aggr-type', default='joint', type=str, choices=['joint', 'independent'],
        help='usually we use joint for classification, but can also try independent')
    add('--proj-hid-size', default=None, type=int,
        help='If specified, use this instead of hidden size for projection')
    add('--mean-center-offset', default=0, type=float,
        help='offset to mean center before projection')
    add('--mean-center-offset2', default=None, type=float,
        help='offset to mean center after projection')
    add('--mean-center-scale', default=0, type=float,
        help='scale factor to scale up the features after mean centering before projection')
    add('--mean-center-scale2', default=0, type=float,
        help='scale factor to scale up the features after mean centering after projection')
    add('--save', default=False, action='store_true',
        help='save the attributions if true, else compute the metrics')

    return parser

def _seed_everything(seed):
    """Seed the torch, NumPy and Python RNGs; a seed of -1 means "do not seed"."""
    if seed == -1:
        return
    # Torch RNG
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NumPy / Python RNG
    np.random.seed(seed)
    random.seed(seed)


def _load_processor(args):
    """Return the input processor (image processor, CLIP transform or tokenizer) for *args*.

    Returns ``None`` for the image/cosmogrid combination, which uses no processor.
    """
    if args.model_type != 'image':
        return AutoTokenizer.from_pretrained(args.blackbox_processor_name)
    if args.model_type_spec == 'clip':
        _, _, processor = open_clip.create_model_and_transforms(
            'ViT-B-16', pretrained='laion2b_s34b_b88k')
        return processor
    if args.dataset == 'cosmogrid':
        return None
    return AutoImageProcessor.from_pretrained(args.blackbox_processor_name)


def _build_config(args):
    """Merge the black-box config, the wrapper JSON config and the CLI overrides."""
    if args.dataset == 'cosmogrid':
        # For cosmogrid the black-box "model name" is a checkpoint that is read
        # with torch.load, so start from an empty config instead of AutoConfig.
        config = PretrainedConfig()
    else:
        config = AutoConfig.from_pretrained(args.blackbox_model_name)

    with open(args.wrapper_config_filepath) as input_file:
        wrapper_config = json.load(input_file)
    config.__dict__.update(wrapper_config)
    config.blackbox_model_name = args.blackbox_model_name
    config.blackbox_processor_name = args.blackbox_processor_name

    # Command-line values, when given, take precedence over the JSON file.
    specifiable_arg_list = ('attn_patch_size',
                            'attn_stride_size',
                            'num_heads',
                            'num_masks_sample',
                            'num_masks_max',
                            'finetune_layers',
                            'ins_del_step_size')
    for specifiable_arg in specifiable_arg_list:
        arg_value = getattr(args, specifiable_arg)
        if arg_value is None:
            continue
        if specifiable_arg == 'finetune_layers':
            arg_value = arg_value.split(',')
        config.__dict__[specifiable_arg] = arg_value
    return config


def _load_blackbox_model(args, config, device):
    """Instantiate the black-box model that the SOP / FRESH wrapper is built around."""
    if args.model_type != 'image':
        return AutoModelForSequenceClassification.from_pretrained(args.blackbox_model_name)
    if args.dataset == 'cosmogrid':
        inner_model = CNNModel(config.num_labels)
        # map_location lets a checkpoint that was saved on GPU load on a CPU-only host.
        state_dict = torch.load(args.blackbox_model_name, map_location=device)
        inner_model.load_state_dict(state_dict=state_dict)
        return inner_model
    if args.model_type_spec == 'clip':
        return ClipClassifier()
    return AutoModelForImageClassification.from_pretrained(args.blackbox_model_name)


def _unpack_batch(batch, args, device):
    """Move one loader batch to *device*.

    Returns:
        ``(inputs, labels, token_type_ids, attention_mask, kwargs)`` where the
        two middle text fields are ``None`` and ``kwargs`` is empty for image
        models; for text models ``kwargs`` carries the same two fields.
    """
    if args.model_type == 'image':
        if args.mask_dir is None:
            inputs, labels = batch
        else:
            # NOTE(review): the masks loaded from --mask-dir are not forwarded
            # to the model anywhere in this script -- confirm whether the
            # wrapper is supposed to receive them.
            inputs, labels, _masks, _masks_i = batch
        if args.dataset != 'cosmogrid':
            inputs, labels = inputs.to(device), labels.to(device)
        else:
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)
        return inputs, labels, None, None, {}

    # Text batch: each field is either already a tensor or a list of tensors
    # that has to be stacked and transposed; the form of 'input_ids' decides
    # for all fields.
    needs_stack = not isinstance(batch['input_ids'], torch.Tensor)

    def to_device(field):
        if needs_stack:
            field = torch.stack(field).transpose(0, 1)
        return field.to(device)

    inputs = to_device(batch['input_ids'])
    token_type_ids = to_device(batch['token_type_ids']) if 'token_type_ids' in batch else None
    attention_mask = to_device(batch['attention_mask'])
    labels = batch['label'].to(device)
    kwargs = {'token_type_ids': token_type_ids,
              'attention_mask': attention_mask}
    return inputs, labels, token_type_ids, attention_mask, kwargs


def main():
    """Evaluate a trained SOP or FRESH wrapper on the validation split.

    Parses the command line, loads the processor / config / dataset / wrapped
    model, then iterates over the validation loader under ``torch.no_grad()``.

    * With ``--save`` each batch's inputs, labels and model outputs are written
      to ``<exp_dir>/attributions/<batch index>.pt`` and no metric file is
      produced.
    * Otherwise accuracy and the number-of-non-zeros score are accumulated,
      plus insertion / blurred insertion / deletion for image models or
      comprehensiveness / sufficiency for text models. The averages (times
      100) are printed and written to
      ``<exp_dir>/eval_results_<wrapper_model>_<val_size>.json``. Metrics that
      are not computed for the selected model type are stored as ``null`` and
      are not printed.
    """
    args = parse_args().parse_args()

    print('\n---argparser---:')
    for arg, value in vars(args).items():
        # Show the type of the parsed value (the key itself is always a str).
        print(arg, value, '\t', type(value))

    os.makedirs(args.exp_dir, exist_ok=True)

    # Set the seed for reproducibility
    _seed_everything(args.seed)

    print(f'Project name {args.project_name}\n')

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    processor = _load_processor(args)
    config = _build_config(args)

    if args.mask_dir is None:
        val_dataset = get_datasets(args.dataset,
                                   processor,
                                   debug=False,
                                   train_size=args.train_size,
                                   val_size=args.val_size,
                                   label2id=config.label2id,
                                   val_only=True,
                                   mode='test')
    else:
        val_dataset = get_masked_datasets(args.dataset,
                                          processor=None,
                                          transform=lambda x: x,
                                          mask_dir=args.mask_dir,
                                          seg_mask_cut_off=args.num_masks_max,
                                          debug=False,
                                          train_size=args.train_size,
                                          val_size=args.val_size,
                                          val_only=True)

    # Seed again right before the loader is created -- presumably so that any
    # shuffling does not depend on randomness consumed while building the dataset.
    _seed_everything(args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                            shuffle=args.shuffle_val)

    inner_model = _load_blackbox_model(args, config, device)
    projection_layer = get_chained_attr(inner_model, args.projection_layer_name)

    is_sop = args.wrapper_model == 'sop'
    is_image = args.model_type == 'image'

    if is_sop:
        model = SOP.from_pretrained(
            args.exp_dir,
            config=config,
            blackbox_model=inner_model,
            model_type=args.model_type,
            projection_layer=projection_layer,
            aggr_type=args.aggr_type,
            pooler=not (args.model_type_spec == 'encoder' and args.dataset != 'cosmogrid'),
            return_tuple=True,
            mean_center_scale=args.mean_center_scale,
            mean_center_scale2=args.mean_center_scale2,
            mean_center_offset=args.mean_center_offset,
            mean_center_offset2=args.mean_center_offset2)
    else:
        model = FRESH.from_pretrained(
            args.exp_dir,
            config=config,
            blackbox_model=inner_model,
            model_type=args.model_type,
            return_tuple=True,
            postprocess_attn=lambda x: x.attentions[-1].mean(dim=1)[:, 0],
            postprocess_logits=lambda x: x.logits)

    model = model.to(device)
    model.eval()

    def get_logits(output):
        return output.logits

    if is_sop:
        nnz_evaluator = NNZGroup(model, postprocess=get_logits).to(device)
    else:
        nnz_evaluator = NNZ().to(device)

    # Only the evaluators that the loop below actually calls are built:
    # insertion / deletion for images, comprehensiveness / sufficiency for text.
    if is_image:
        # Fall back to 224 when the step size is absent from both the wrapper
        # config file and the command line (or is set to a falsy value).
        net_in_size = getattr(config, 'ins_del_step_size', None) or 224
        task_type = 'cls' if args.dataset != 'cosmogrid' else 'reg'

        klen = 11
        ksig = 5
        kern = InsDel.gkern(klen, ksig, config.num_channels).to(device)

        def blur(x):
            return nn.functional.conv2d(x, kern, padding=klen // 2)

        ins_evaluator = InsDel(model,
                               'ins',
                               net_in_size,
                               substrate_fn=torch.zeros_like,
                               postprocess=get_logits,
                               task_type=task_type).to(device)
        ins_blur_evaluator = InsDel(model,
                                    'ins',
                                    net_in_size,
                                    substrate_fn=blur,
                                    postprocess=get_logits,
                                    task_type=task_type).to(device)
        del_evaluator = InsDel(model,
                               'del',
                               net_in_size,
                               substrate_fn=torch.zeros_like,
                               postprocess=get_logits,
                               task_type=task_type).to(device)
    else:  # doesn't make sense to blur the text
        comp_evaluator = CompSuffText(model,
                                      'comp',
                                      postprocess=get_logits).to(device)
        suff_evaluator = CompSuffText(model,
                                      'suff',
                                      postprocess=get_logits).to(device)

    correct = 0
    total = 0
    nnz_scores = 0.
    ins_scores = 0.
    ins_blur_scores = 0.
    del_scores = 0.
    comp_scores = 0.
    suff_scores = 0.
    mse_loss = nn.MSELoss()
    print('Eval..')

    if args.save:
        os.makedirs(os.path.join(args.exp_dir, 'attributions'), exist_ok=True)

    with torch.no_grad():
        for b_i, batch in enumerate(tqdm(val_loader)):
            inputs, labels, token_type_ids, attention_mask, kwargs = \
                _unpack_batch(batch, args, device)
            bsz = inputs.size(0)

            if is_sop:
                outputs = model(inputs,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask,
                                epoch=model.config.num_heads,
                                mask_batch_size=args.mask_batch_size,
                                label=torch.abs(labels) - 1 if args.criterion == 'bce' else None)
            else:  # 'fresh'
                outputs = model(inputs,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask)

            attributions = outputs.attributions
            logits = outputs.logits

            # accuracy
            if args.criterion == 'ce':
                _, predicted = torch.max(logits.data, 1)
                correct += (predicted == labels).sum().item()
            elif args.criterion == 'bce':
                probs = torch.sigmoid(logits)
                predicted = (probs > 0.5).float()
                correct += (predicted[range(bsz), torch.abs(labels) - 1] == (labels > 0)).sum().item()
            else:  # mse
                # NOTE(review): this adds the batch-mean loss, which is later
                # divided by the number of samples -- confirm that is intended.
                predicted = mse_loss(logits, labels)
                correct += predicted.sum().item()
            total += labels.size(0)

            if args.save:
                save_filepath = os.path.join(args.exp_dir, 'attributions', f'{b_i}.pt')
                torch.save({'inputs': inputs,
                            'labels': labels,
                            'outputs': outputs}, save_filepath)
                continue

            if is_sop:
                nnz_scores += nnz_evaluator(inputs, outputs.masks,
                                            kwargs=kwargs, W=outputs.mask_weights,
                                            pred=predicted if args.criterion == 'ce' else None,
                                            normalize=True).sum().item()
            else:
                # get nnz score for aggregated mask
                nnz_score = nnz_evaluator(inputs,
                                          attributions,
                                          normalize=True).to(device)
                nnz_scores += nnz_score.sum().item()

            if is_image:
                ins_score = ins_evaluator(inputs, attributions, kwargs)
                ins_blur_score = ins_blur_evaluator(inputs, attributions, kwargs)
                del_score = del_evaluator(inputs, attributions, kwargs)
                if args.dataset == 'cosmogrid':
                    ins_score = ins_score.mean(-1)
                    ins_blur_score = ins_blur_score.mean(-1)
                    del_score = del_score.mean(-1)
                ins_scores += ins_score.sum().item()
                ins_blur_scores += ins_blur_score.sum().item()
                del_scores += del_score.sum().item()
            else:
                comp_score = comp_evaluator(inputs, attributions, kwargs)
                suff_score = suff_evaluator(inputs, attributions, kwargs)
                comp_scores += comp_score.sum().item()
                suff_scores += suff_score.sum().item()

    if args.save:
        return

    val_acc = 100 * correct / total
    nnz_avg = 100 * nnz_scores / total
    # Metrics that were not evaluated for this model type stay None so they are
    # not mistaken for a genuine score of zero.
    ins_avg = ins_blur_avg = del_avg = None
    comp_avg = suff_avg = None
    if is_image:
        ins_avg = 100 * ins_scores / total
        ins_blur_avg = 100 * ins_blur_scores / total
        del_avg = 100 * del_scores / total
    else:
        comp_avg = 100 * comp_scores / total
        suff_avg = 100 * suff_scores / total

    report = (('Validation Accuracy (High)', val_acc),
              ('Number of Non-zeros (Low)', nnz_avg),
              ('Insertion Score Avg (High)', ins_avg),
              ('Insertion Blur Score Avg (High)', ins_blur_avg),
              ('Deletion Score Avg (Low)', del_avg),
              ('Comprehensiveness Score Avg (High)', comp_avg),
              ('Sufficiency Score Avg (Low)', suff_avg))
    for label, value in report:
        if value is not None:
            print(f'{label} {value:.2f}%')

    results = {
        args.wrapper_model: {
            'dataset': args.dataset,
            'accuracy': val_acc,
            'num_nonzeros': nnz_avg,
            'insertion': ins_avg,
            'insertion_blur': ins_blur_avg,
            'deletion': del_avg,
            'comprehensiveness': comp_avg,
            'sufficiency': suff_avg,
        }
    }

    output_filename = os.path.join(args.exp_dir, f'eval_results_{args.wrapper_model}_{args.val_size}.json')
    with open(output_filename, 'wt') as output_file:
        json.dump(results,
                  output_file,
                  indent=4)
        
# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    main()