import os
import argparse
import importlib
from train.pytorch_wrapper.utils import import_from_package_by_name
from natsort import natsorted
from glob import glob
import torch
from train.pytorch_wrapper.network import Network
from train.pytorch_wrapper.utils import BColors, EpochLogger


def get_arg_string(args):
    """Build the identifier string used in parameter and log file names.

    The identifier has the form ``bc-<model>-<train_strategy>-<dataset>``.
    Each component is the file name of the corresponding path in ``args``,
    stripped of its directory part and its extension.

    Args:
        args: object exposing the path strings ``model``,
            ``train_strategy`` and ``dataset``.

    Returns:
        The assembled identifier string.
    """
    stems = []
    # Attributes are read one at a time, in the same order as the fields
    # appear in the identifier.
    for field in ("model", "train_strategy", "dataset"):
        file_name = os.path.basename(getattr(args, field))
        stem, _ext = os.path.splitext(file_name)
        stems.append(stem.replace("/", "."))
    return "bc-%s-%s-%s" % tuple(stems)


def build_model(args):
    """Create the model described by ``args.model``.

    The path in ``args.model`` is converted into a dotted module name by
    dropping its extension and replacing "/" with ".". The name "Network" is
    then requested from that module through ``import_from_package_by_name``
    and the returned object is called without arguments.

    Args:
        args: object exposing the path string ``model``.

    Returns:
        Whatever calling the imported ``Network`` object yields.
    """
    module_path, _ext = os.path.splitext(args.model)
    module_name = module_path.replace("/", ".")
    print("network:", module_name)
    # Local name deliberately differs from the wrapper class imported at
    # file level, which is also called ``Network``.
    network_factory = import_from_package_by_name("Network", module_name)
    return network_factory()


def compile_train_strategy(args):
    """Create the training strategy described by ``args.train_strategy``.

    The path is converted into a dotted module name (extension dropped, "/"
    replaced by "."). The name "compile_training_strategy" is requested from
    that module through ``import_from_package_by_name`` and the returned
    object is called without arguments.

    Args:
        args: object exposing the path string ``train_strategy``.

    Returns:
        Whatever the imported ``compile_training_strategy`` object returns.
    """
    strategy_path, _ext = os.path.splitext(args.train_strategy)
    strategy_module = strategy_path.replace("/", ".")
    print("train strategy:", strategy_module)
    strategy_factory = import_from_package_by_name("compile_training_strategy", strategy_module)
    return strategy_factory()


def load_eval_hook(args, dataset, dump_dir):
    """Create the evaluation hook described by ``args.evalhook``.

    The path is converted into a dotted module name (extension dropped, "/"
    replaced by "."). The name "compile_eval_hook" is requested from that
    module through ``import_from_package_by_name`` and called with the
    dataset, the dump directory and the identifier from ``get_arg_string``.

    Args:
        args: object exposing ``evalhook`` plus the fields read by
            ``get_arg_string``.
        dataset: passed through unchanged as first argument of the hook factory.
        dump_dir: passed through unchanged as second argument of the hook factory.

    Returns:
        Whatever the imported ``compile_eval_hook`` object returns.
    """
    hook_path, _ext = os.path.splitext(args.evalhook)
    hook_module = hook_path.replace("/", ".")
    print("eval hook:", hook_module)
    hook_factory = import_from_package_by_name("compile_eval_hook", hook_module)
    run_id = get_arg_string(args)
    return hook_factory(dataset, dump_dir, run_id)


def load_dataset(args):
    """Import and return the module that ``args.dataset`` points to.

    The path is converted into a dotted module name by dropping its extension
    and replacing "/" with "."; that module is then imported with
    ``importlib.import_module``.

    Args:
        args: object exposing the path string ``dataset``.

    Returns:
        The imported module object.
    """
    dataset_path, _ext = os.path.splitext(args.dataset)
    dotted_name = dataset_path.replace("/", ".")
    print("dataset:", dotted_name)
    dataset_module = importlib.import_module(dotted_name)
    return dataset_module


def parse_args():
    """Parse the command line arguments of the evaluation script.

    Exactly one of ``--param_root`` / ``--checkpoint`` must be given.
    ``--evalhook``, ``--model``, ``--train_strategy`` and ``--dataset`` are
    mandatory as well, because ``main`` derives module names and file names
    from each of them. Marking them as required lets argparse report a
    missing option with a usage message instead of the run failing later on
    a ``None`` path.

    Returns:
        argparse.Namespace: the parsed arguments.

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            an option value is invalid.
    """
    parser = argparse.ArgumentParser(description='Eval Model')
    # One evaluation mode has to be selected; main() has no default mode.
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument('--param_root', help='path where to store model parameters.', type=str)
    mode.add_argument('--checkpoint', help='single checkpoint for evaluation', type=str)
    parser.add_argument("--outdir", type=str, default=None,
                        help="optional path to outdir if results should not be written to checkpoint dir")
    parser.add_argument("--evalhook", type=str, required=True, help="path to eval hook")
    parser.add_argument('--model', type=str, required=True, help='name of model constructor.')
    parser.add_argument('--train_strategy', type=str, required=True, help='path to training strategy.')
    parser.add_argument("--dataset", type=str, required=True, help="data set type")
    parser.add_argument("--port", help="port of env server", type=int)
    return parser.parse_args()


def _collect_param_files(args, arg_str):
    """Return ``(param_files, dump_dir)`` for the evaluation mode chosen in ``args``."""
    if args.param_root:
        # directory mode: every file matching params_<arg_str>* is evaluated
        param_file_template = os.path.join(args.param_root, "params_%s*" % arg_str)
        print(param_file_template)
        return natsorted(glob(param_file_template)), args.param_root
    if args.checkpoint:
        # single checkpoint mode: results go next to the checkpoint by default
        return [args.checkpoint], os.path.dirname(args.checkpoint)
    raise ValueError("Invalid mode")


def main(args):
    """Evaluate stored model parameters with the configured eval hook.

    Depending on ``args`` either all parameter files matching
    ``params_<arg_str>*`` inside ``args.param_root`` (naturally sorted) or the
    single file ``args.checkpoint`` are evaluated. For every file the
    parameters are loaded, the eval hook is called on the wrapped model with
    gradient tracking disabled, and each returned value is logged under the
    key ``va_<name>``. The log is dumped after each file to
    ``log_<arg_str>_evalhook.npy`` in the dump directory, which is
    ``args.outdir`` if given and otherwise the parameter / checkpoint
    directory.

    Args:
        args: parsed command line arguments as produced by ``parse_args``.

    Raises:
        ValueError: if neither ``args.param_root`` nor ``args.checkpoint``
            is set.
    """
    arg_str = get_arg_string(args)
    param_files, dump_dir = _collect_param_files(args, arg_str)

    if args.outdir:
        dump_dir = args.outdir
        # exist_ok avoids the race of a separate exists() check before creating
        os.makedirs(dump_dir, exist_ok=True)

    if not param_files:
        # without this the script would finish silently without evaluating anything
        print("Warning: no parameter files found, nothing to evaluate")

    dataset = load_dataset(args)
    eval_hook = load_eval_hook(args, dataset, dump_dir)
    train_strategy = compile_train_strategy(args)
    net = Network(build_model(args))
    # NOTE(review): relies on a private method of the Network wrapper
    net._to_device(train_strategy)

    log_file = os.path.join(dump_dir, "log_%s_evalhook.npy" % arg_str)

    # initialize epoch logger
    logger = EpochLogger()
    # init color printer
    col = BColors()

    # eval
    for param_file in param_files:
        print("Evaluating %s" % param_file)
        net.load(param_file)
        # NOTE(review): only gradient tracking is disabled here; whether the
        # model is put into eval mode is up to the wrapper / hook -- confirm
        with torch.no_grad():
            eval_dict = eval_hook(net.net)
            for k, v in eval_dict.items():
                logger.append("va_%s" % k, v)
                print(col.print_colored("va_%s: %.3f" % (k, v), col.OKBLUE))

        # summarize logged data and persist the log after every parameter file
        logger.summarize_epoch()
        logger.dump(log_file)


if __name__ == "__main__":
    # Script entry point: parse the command line and run the evaluation.
    main(parse_args())
