import argparse
import os
import pickle
import random
import time
from os import path as osp
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tensorboardX import SummaryWriter
from termcolor import colored
from torch.utils.data import DataLoader

import utils_training.optimize as optimize
from utils_training.evaluation import Evaluator
from utils_training.utils import parse_list, log_args, load_checkpoint, save_checkpoint, boolean_string
from data import download
from models.glbt import GLBT


# Prefix that nn.DataParallel adds to every parameter name in its state_dict.
DATAPARALLEL_PREFIX = 'module.'


def parse_args():
    """Build the GLBT test-script argument parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the experiment, data, model and seed options.
    """
    parser = argparse.ArgumentParser(description='GLBT Test Script')
    # Paths
    parser.add_argument('--name_exp', type=str,
                        default=time.strftime('%Y_%m_%d_%H_%M'),
                        help='name of the experiment to save')
    parser.add_argument('--model_type', type=int, default=0, help='choose model type')
    parser.add_argument('--snapshots', type=str, default='./eval')
    parser.add_argument('--pretrained', dest='pretrained',
                        help='path to pre-trained model')
    parser.add_argument('--batch-size', type=int, default=1,
                        help='training batch size')
    parser.add_argument('--n_threads', type=int, default=1,
                        help='number of parallel threads for dataloaders')
    parser.add_argument('--seed', type=int, default=2021,
                        help='Pseudo-RNG seed')

    # Dataset / evaluation options
    parser.add_argument('--datapath', type=str, default='../Datasets')
    parser.add_argument('--benchmark', type=str, choices=['pfpascal', 'spair', 'pfwillow'])
    parser.add_argument('--thres', type=str, default='auto', choices=['auto', 'img', 'bbox'])
    parser.add_argument('--alpha', type=float, default=0.1)

    # Model / feature options
    parser.add_argument('--feature-size', type=int, default=32)
    parser.add_argument('--feature-proj-dim', type=int, default=128)
    parser.add_argument('--augmentation', type=boolean_string, nargs='?', const=True, default=True)
    return parser.parse_args()


def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and current CUDA device) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from keys.

    The checkpoint is presumably saved from an ``nn.DataParallel``-wrapped
    model, whose parameter names carry a ``'module.'`` prefix that the bare
    model does not expect. Only that exact prefix is removed, and only once.
    Keys without it are kept unchanged; the previous unconditional ``k[7:]``
    slice silently corrupted such keys.

    Args:
        state_dict: mapping of parameter name -> tensor.

    Returns:
        OrderedDict with the same values, in the same order, under the
        de-prefixed keys.
    """
    n = len(DATAPARALLEL_PREFIX)
    return OrderedDict(
        (key[n:] if key.startswith(DATAPARALLEL_PREFIX) else key, value)
        for key, value in state_dict.items()
    )


def load_model(pretrained):
    """Build a GLBT model and load ``<pretrained>/best_model.pt`` into it.

    Args:
        pretrained: directory containing ``best_model.pt``.

    Returns:
        The GLBT model with weights loaded (still on CPU).

    Raises:
        NotImplementedError: if no pretrained directory was given; this script
            only evaluates existing checkpoints.
    """
    if not pretrained:
        raise NotImplementedError(
            'Testing requires --pretrained pointing to a directory with best_model.pt')
    model = GLBT()
    # Map to CPU so checkpoints saved on GPU also load on CPU-only hosts; the
    # model is moved to the target device by the caller afterwards.
    checkpoint = torch.load(osp.join(pretrained, 'best_model.pt'), map_location='cpu')
    model.load_state_dict(strip_module_prefix(checkpoint))
    return model


def main():
    """Evaluate a pretrained GLBT checkpoint on the chosen benchmark's test split."""
    args = parse_args()
    seed_everything(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Initialize Evaluator
    Evaluator.initialize(args.benchmark, args.alpha)
    # Dataloader
    download.download_dataset(args.datapath, args.benchmark)
    test_dataset = download.load_dataset(args.benchmark, args.datapath, args.thres, device,
                                         'test', args.augmentation, args.feature_size)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args.batch_size,
                                 num_workers=args.n_threads,
                                 shuffle=False)

    # Model: weights are loaded into the bare model, then wrapped for multi-GPU use.
    model = nn.DataParallel(load_model(args.pretrained))
    model = model.to(device)

    test_started = time.time()
    val_loss_grid, val_mean_pck = optimize.validate_epoch(model,
                                                          test_dataloader,
                                                          device,
                                                          epoch=0)
    print(colored('==> ', 'blue') + 'Test average grid loss :',
          val_loss_grid)
    print('mean PCK is {}'.format(val_mean_pck))

    print(args.seed, 'Test took:', time.time() - test_started, 'seconds')


if __name__ == "__main__":
    main()
