""" test_model.py
    Test models

    Collaboratively developed
    by Avi Schwarzschild, Eitan Borgnia,
    Arpit Bansal, and Zeyad Emam.

    Developed for DeepThinking project
    October 2021
"""

import logging
import os
import sys
from collections import OrderedDict

import json

import hydra
import torch
import matplotlib.pyplot as plt
from omegaconf import DictConfig, OmegaConf

import deepthinking as dt

# Ignore statements for pylint:
#     Too many branches (R0912), Too many statements (R0915), No member (E1101),
#     Not callable (E1102), Invalid name (C0103), No exception (W0702),
#     Too many local variables (R0914), Missing docstring (C0116, C0115).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115


def _merge_training_args(cfg: DictConfig) -> None:
    """Copy selected hyperparameters from the training run's config into ``cfg``.

    The training run's Hydra config is read from
    ``<cfg.problem.model.model_path>/.hydra/config.yaml`` and the listed
    ``(section, key)`` entries are written into ``cfg.problem`` in place, so the
    network rebuilt for testing matches the checkpoint.
    """
    training_args = OmegaConf.load(os.path.join(cfg.problem.model.model_path,
                                                ".hydra/config.yaml"))
    # ("model", "model") is not copied, so the architecture name used for
    # testing is the one from the test config, not the training config.
    cfg_keys_to_load = [("hyp", "alpha"),
                        ("hyp", "epochs"),
                        ("hyp", "lr"),
                        ("hyp", "lr_factor"),
                        ("model", "max_iters"),
                        ("hyp", "optimizer"),
                        ("hyp", "train_mode"),
                        ("model", "width")]
    for section, key in cfg_keys_to_load:
        cfg["problem"][section][key] = training_args["problem"][section][key]


def _get_test_iterations(model_cfg: DictConfig) -> list:
    """Return the iteration counts at which the model is evaluated.

    Models whose name contains "feedforward" are evaluated only at
    ``max_iters``; all others at every count from ``test_iterations["low"]``
    to ``test_iterations["high"]``, inclusive.
    """
    if "feedforward" in model_cfg.model:
        return [model_cfg.max_iters]
    return list(range(model_cfg.test_iterations["low"],
                      model_cfg.test_iterations["high"] + 1))


@hydra.main(config_path="config", config_name="test_model_config")
def main(cfg: DictConfig):
    """Evaluate a trained checkpoint on the test loader and log the results.

    Steps:
      1. Merge selected hyperparameters from the training run's saved Hydra
         config (``<model_path>/.hydra/config.yaml``) into ``cfg.problem``.
      2. Build the data loaders and load ``<model_path>/model_best.pth``.
      3. Run ``dt.test_stop_condition`` on the test loader and report test
         accuracy, ssh accuracy and the number of iterations it returns.

    Note: ``cfg.problem.model.model_path`` is mutated in place to point at the
    checkpoint file.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.backends.cudnn.benchmark = True
    log = logging.getLogger()
    log.info("\n_________________________________________________\n")
    log.info("test_model.py main() running.")
    log.info(OmegaConf.to_yaml(cfg))

    # Must run before model_path is rewritten to the .pth file below, since
    # the training config lives alongside it in the run directory.
    _merge_training_args(cfg)

    # Default the save period only after the training epochs have been
    # loaded, so it matches the trained model's epoch count rather than the
    # (now overwritten) value from the test config.
    if cfg.problem.hyp.save_period is None:
        cfg.problem.hyp.save_period = cfg.problem.hyp.epochs

    # NOTE(review): self-assignment. With OmegaConf, reading a node resolves any
    # interpolation, so this only replaces an interpolated value with its
    # resolved value; otherwise it is a no-op. Confirm whether it was meant
    # to copy train_data from the training config instead.
    cfg.problem.train_data = cfg.problem.train_data

    log.info(OmegaConf.to_yaml(cfg))

    ####################################################
    #               Dataset and Network
    loaders = dt.utils.get_dataloaders(cfg.problem)

    cfg.problem.model.model_path = os.path.join(cfg.problem.model.model_path, "model_best.pth")
    net, _, _ = dt.utils.load_model_from_checkpoint(cfg.problem.name,
                                                    cfg.problem.model,
                                                    device)
    pytorch_total_params = sum(p.numel() for p in net.parameters())
    log.info(f"This {cfg.problem.model.model} has {pytorch_total_params/1e6:0.3f} million parameters.")
    ####################################################

    ####################################################
    #        Test
    log.info("==> Starting testing...")
    test_iterations = _get_test_iterations(cfg.problem.model)
    test_acc, ssh_acc, num_iter = dt.test_stop_condition(net,
                                                         [loaders["test"]],
                                                         cfg.test_stop_condition.mode,
                                                         test_iterations,
                                                         cfg.problem.name,
                                                         device,
                                                         cfg.test_stop_condition.threshold)
    print(f"{dt.utils.now()} Testing accuracy (hard data): {test_acc}")
    print(f"{dt.utils.now()} Testing ssh accuracy (hard data): {ssh_acc}")

    log.info(f"{dt.utils.now()} Testing accuracy (hard data): {test_acc}")
    log.info(f"{dt.utils.now()} Testing ssh accuracy (hard data): {ssh_acc}")
    log.info(f"{dt.utils.now()} Testing number of iteration (hard data): {num_iter}")
    ####################################################


if __name__ == "__main__":
    # Give every invocation a unique run id (its launch timestamp) by adding
    # a Hydra "+key=value" override before Hydra parses sys.argv in main().
    import datetime
    sys.argv.append("+run_id=" + str(datetime.datetime.now()))
    main()
