#!/usr/bin/env python3

import argparse
import os

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from sran import SRAN
from utility import PGMdataset, RAVENdataset, ToTensor
from utility.RAVENdataset_utility import figure_configuration_names


# Command-line interface. Options not read directly by this script
# (e.g. --actfun, --feature_width) reach the model through ``SRAN(args)``.
parser = argparse.ArgumentParser(description='our_model')
parser.add_argument('--model', type=str, default='SRAN',
                    help="model architecture to evaluate (only 'SRAN' is supported)")
parser.add_argument("--actfun", type=str, default="relu",
                    help="activation function name, forwarded to the model via args")
parser.add_argument("--feature_width", type=float, default=1.,
                    help="feature width setting, forwarded to the model via args")
parser.add_argument('--dataset', type=str, default='I-RAVEN', choices=['PGM', 'I-RAVEN', 'RAVEN'],
                    help="dataset format to load")
parser.add_argument('--test_figure_configurations', type=int, nargs="+", default=[0, 1, 2, 3, 4, 5, 6],
                    help="RAVEN figure configuration indices to evaluate (ignored for PGM)")
parser.add_argument('--img_size', type=int, default=224,
                    help="input image size passed to the dataset loader")
parser.add_argument('--batch_size', type=int, default=32,
                    help="test batch size")
parser.add_argument('--load_workers', type=int, default=16,
                    help="number of DataLoader worker processes")
# required=True means a default would never be used, so none is given.
parser.add_argument('--resume', type=str, required=True,
                    help="path to the checkpoint file to evaluate")
parser.add_argument('--dataset_path', type=str, default='/media/dsg3/datasets/I-RAVEN',
                    help="root directory of the selected dataset")
parser.add_argument('--save', type=str,
                    help="directory to append test_results.log and test_results.yaml to")
parser.add_argument("--no-cuda", dest="cuda", action="store_false",
                    help="disable CUDA")
parser.add_argument("--verbose", "-v", action="count", default=1,
                    help="increase output verbosity (repeatable); -v also prints the model summary")
parser.add_argument("--quiet", "-q", action="count", default=0,
                    help="decrease output verbosity (repeatable)")
parser.add_argument('--print_freq', type=int, default=20,
                    help="print per-batch accuracy every N batches")

args = parser.parse_args()
# Net verbosity: each -v raises it, each -q lowers it (baseline 1).
verbose = args.verbose - args.quiet

print("Arguments:")
for _arg_name, _arg_value in vars(args).items():
    print(f"{_arg_name} = {_arg_value!r}")
print()

# Optimizer-style settings attached to args only after the echo above, so they
# are not printed. Nothing in this script trains; these presumably exist only
# to satisfy whatever SRAN(args) reads -- TODO confirm against SRAN.
for _hp_name, _hp_value in (
    ("lr", 0.0),
    ("beta1", 0.9),
    ("beta2", 0.999),
    ("epsilon", 1e-8),
    ("weight_decay", 0.0),
    ("meta_beta", 0.0),
):
    setattr(args, _hp_name, _hp_value)


def count_parameters(model, only_trainable=True):
    r"""
    Return how many scalar parameter entries ``model`` and its children hold.

    Arguments:
        model (torch.nn.Module): the module to inspect.
        only_trainable (bool, optional): when ``True`` (the default), count only
            parameters that have ``requires_grad`` set; when ``False``, count
            every parameter.
    Returns:
        int: the summed ``numel()`` of the selected parameters (0 if none).
    """
    selected = (
        param
        for param in model.parameters()
        if param.requires_grad or not only_trainable
    )
    return sum(param.numel() for param in selected)


def test(epoch):
    """Evaluate ``pmodel`` on ``testloader`` and return the overall accuracy.

    Uses the module-level ``model``, ``pmodel``, ``testloader`` and ``args``.
    Per-batch accuracy is printed for the first three batches and for every
    ``args.print_freq``-th batch (when ``print_freq`` is positive).

    Args:
        epoch: checkpoint epoch, used only to label the progress lines.

    Returns:
        float: percentage of samples whose predicted index equals the target,
        or 0.0 when the loader yields no samples.
    """
    model.eval()
    total_correct = 0
    n_samp = 0
    with torch.no_grad():
        for batch_idx, (image, target, _meta_target) in enumerate(testloader):
            # The meta target is not needed for accuracy, so it stays on the host.
            if args.cuda:
                image = image.cuda()
                target = target.cuda()
            output = pmodel(image)
            # output is indexed with [0] and reduced over dim 1, so presumably
            # output[0] is a (batch, n_candidates) score tensor -- TODO confirm
            # against SRAN.forward.
            pred = output[0].max(dim=1).indices
            batch_size = target.size(0)
            correct = pred.eq(target).sum().item()
            total_correct += correct
            n_samp += batch_size
            # The print_freq > 0 check avoids a modulo-by-zero crash.
            if batch_idx <= 2 or (args.print_freq > 0 and batch_idx % args.print_freq == 0):
                print(
                    "Test Epoch:{:4d}: Batch:{:4d}/{:4d}, Acc:{:6.2f}".format(
                        epoch, batch_idx, len(testloader), correct * 100. / batch_size
                    ),
                    flush=True,
                )
    if n_samp == 0:
        # An empty loader would otherwise divide by zero below.
        print("No test samples were evaluated; reporting accuracy 0.", flush=True)
        return 0.0
    accuracy = 100. * total_correct / n_samp
    print("Total Testing Acc: {:.4f}".format(accuracy), flush=True)
    return accuracy


# Fail fast on datasets this script cannot load.
if args.dataset != 'PGM' and "RAVEN" not in args.dataset:
    raise ValueError("Unsupported dataset: {}".format(args.dataset))

_to_tensor = transforms.Compose([ToTensor()])
if args.dataset == 'PGM':
    testset = PGMdataset(args.dataset_path, "test", args.img_size, transform=_to_tensor)
    # PGM has no figure configurations to report.
    test_config_name = ""
else:
    _configs = args.test_figure_configurations
    testset = RAVENdataset(args.dataset_path, "test", _configs, args.img_size, transform=_to_tensor)
    # With fewer than 7 indices, report the selected names; otherwise label the run "(full)".
    test_config_name = (
        "(full)"
        if len(_configs) >= 7
        else [figure_configuration_names[i] for i in _configs]
    )

testloader = DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.load_workers)

print(f"Test: {len(testset)}")
print('Image size:', args.img_size)

if args.model == 'SRAN':
    model = SRAN(args)
else:
    raise ValueError("Unsupported model: {}".format(args.model))

# DataParallel keeps a reference to `model`, so weights loaded into `model`
# later are the ones `pmodel` runs with.
pmodel = torch.nn.DataParallel(model)
# Honour --no-cuda: move the model to the GPU (and enable cuDNN autotuning)
# only when CUDA is requested. test() moves its inputs under the same flag.
if args.cuda:
    torch.backends.cudnn.benchmark = True
    pmodel = pmodel.cuda()

# Load the checkpoint to evaluate (mandatory for this script).
if not args.resume:
    raise ValueError("Resume path must be given")

if not os.path.isfile(args.resume):
    raise OSError("=> no checkpoint found at '{}'".format(args.resume))

print("=> loading checkpoint '{}'".format(args.resume))
# NOTE: torch.load unpickles the file; only point --resume at trusted checkpoints.
# Without CUDA, tensors saved on the GPU are remapped to the CPU so loading works.
checkpoint = torch.load(args.resume, map_location=None if args.cuda else "cpu")
epoch = checkpoint["epoch"]
# Weights go into the unwrapped `model`, so checkpoint keys are expected
# without DataParallel's "module." prefix.
model.load_state_dict(checkpoint["state_dict"])

print(
    "=> loaded checkpoint '{}' (epoch {})".format(
        args.resume, epoch
    )
)

if verbose >= 2:
    # Extra diagnostics, printed only when the net verbosity is at least 2.
    print("Model architecture:")
    print(model)
    print()
    n_params_total = count_parameters(model, only_trainable=False)
    print("Number of model parameters (total):     {}".format(n_params_total))
    n_params_trainable = count_parameters(model, only_trainable=True)
    print("Number of model parameters (trainable): {}".format(n_params_trainable))
    # Per-submodule breakdown for whichever of these attributes the model defines.
    _candidate_names = ("conv1", "conv2", "conv3", "h1", "h2", "h3")
    for _sub_name in (name for name in _candidate_names if hasattr(model, name)):
        _submodule = getattr(model, _sub_name)
        _sub_total = count_parameters(_submodule, only_trainable=False)
        _sub_trainable = count_parameters(_submodule, only_trainable=True)
        print(
            "model.{} has {} parameters, of which {} are trainable".format(
                _sub_name, _sub_total, _sub_trainable
            )
        )
    print()

print(f"Processing test dataset {args.dataset} {test_config_name}")
avg_test_acc = test(epoch)

# One summary line, shared by stdout and the plain-text results log.
summary_line = "Epoch {}, Dataset: {} {}, Testing Acc: {:.4f}\n".format(
    epoch, args.dataset_path, test_config_name, avg_test_acc
)
print(summary_line)

if args.save:
    os.makedirs(args.save, exist_ok=True)
    with open(os.path.join(args.save, 'test_results.log'), 'a') as log_file:
        log_file.write(summary_line)
    # A single selected configuration is written as its bare name. Otherwise the
    # value is used as-is: empty for PGM, "(full)", or a list of names.
    if len(test_config_name) == 1:
        yaml_key_suffix = test_config_name[0]
    else:
        yaml_key_suffix = test_config_name
    with open(os.path.join(args.save, "test_results.yaml"), "a") as yaml_file:
        yaml_file.write("{}{}: {:.4f}\n".format(args.dataset, yaml_key_suffix, avg_test_acc))
