import os
import argparse
import importlib

import torch
from torch.utils.data import Dataset

from train.common.config import Config
from train.pytorch_wrapper.network import Network
from train.pytorch_wrapper.utils import import_from_package_by_name


def parse_args(parent_parsers=None, argv=None):
    """Parse command-line arguments for behavioural-cloning training and evaluation.

    Args:
        parent_parsers: optional list of ``argparse.ArgumentParser`` objects whose
            arguments are merged into this parser (argparse requires such parents
            to be created with ``add_help=False``).
        argv: optional list of argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, which matches the
            previous behaviour.

    Returns:
        argparse.Namespace holding the parsed arguments.
    """
    # An empty parents list is argparse's own default, so this matches the
    # previous "with parents" / "without parents" branches exactly.
    parser = argparse.ArgumentParser(description='Train model.', parents=parent_parsers or [])

    parser.add_argument('--model', help='name of model constructor.', default="models/OrigActs.py")
    parser.add_argument('--param_root', help='path where to store model parameters.', type=str, default='../../params')
    parser.add_argument('--train_strategy', help='path to training strategy.', type=str)

    # parse data set arguments
    data = parser.add_argument_group(title="dataset")
    data.add_argument("--dataset", help="data set type", type=str, default="datasets/ds/ds_s32_orig_minerl_diamond.py")
    data.add_argument("--data", type=str, metavar='PATH',
                      default="/publicdata/mine_rl/dataset_20190712/MineRLObtainDiamond-v0/subtasks/log",
                      help="data path")
    data.add_argument("--trainidx", type=str, default="train.csv", help="name of index for train split")
    data.add_argument("--validx", type=str, default="val.csv", help="name of index for val split")
    data.add_argument("--subset", type=float, default=1, help="fraction of data to use (intended for debugging)")
    data.add_argument("--heading", action='store_true',
                      help="flag to include estimated camera pitch as inputs")
    data.add_argument("--evalset", type=str, default="valid",
                      help="split for evaluation scripts (train, valid)")
    data.add_argument("--evalhook", type=str, default=None, help="extra evaluation steps")

    # testing options
    test = parser.add_argument_group(title="testing arguments")
    test.add_argument("--envseed", type=int, default=None, help="seed for resetting the environment")
    test.add_argument('--verbosity', type=int, default=0, help='set verbosity level of script.')
    test.add_argument("--recordto", type=str, default=None, help="path where to store video recordings.")
    test.add_argument('--checkpoint', type=int, default=None, help='model checkpoint.')
    test.add_argument('--env', type=str, help='environment used for testing.')
    test.add_argument('--trials', type=int, default=1, help='number of evaluation trials.')
    test.add_argument('--maxsteps', type=int, default=1200, help='number of evaluation steps.')
    test.add_argument("--paramfile", type=str, default=None, help="path to model parameters.")

    # training options
    parser.add_argument('--num_workers', help='number of workers for batch generation.', type=int, default=4)
    parser.add_argument('--find_lr', help='run learning rate finder.', action='store_true')
    parser.add_argument('--freeze', help='freeze weights of pre-trained model.', action='store_true')
    parser.add_argument('--continue_train', help='load model state and continue training.', action='store_true')
    parser.add_argument('--trainset_eval', help='run training set evaluation without updates.', action='store_true')

    # argv=None -> argparse reads sys.argv[1:]
    return parser.parse_args(argv)


def get_arg_string(args):
    """Return the run identifier ``bc-<model>-<strategy>-<dataset>``.

    Each component is the file name of the corresponding path argument
    (``args.model``, ``args.train_strategy``, ``args.dataset``) with its
    directory and final extension removed.
    """
    def _stem(path):
        # basename() never contains a "/" separator, so no further
        # path-to-module normalisation is needed here.
        stem, _extension = os.path.splitext(os.path.basename(path))
        return stem

    components = [_stem(path) for path in (args.model, args.train_strategy, args.dataset)]
    return "bc-" + "-".join(components)


def build_model(args):
    """Instantiate the model class defined in the file ``args.model``.

    The file path is turned into a dotted module name (extension stripped,
    ``/`` replaced by ``.``). The module's ``Network`` attribute is then
    called with no arguments and the resulting instance is returned.
    """
    model_module = os.path.splitext(args.model)[0].replace("/", ".")
    print("network:", model_module)
    # Use a distinct local name so the module-level ``Network`` wrapper
    # import is not shadowed inside this function.
    model_cls = import_from_package_by_name("Network", model_module)
    return model_cls()


def compile_train_strategy(args):
    """Build the training strategy described by the file ``args.train_strategy``.

    The path is converted to a dotted module name (extension stripped, ``/``
    replaced by ``.``). The module's ``compile_training_strategy`` factory is
    called without arguments, and its result is returned.
    """
    strategy_module = os.path.splitext(args.train_strategy)[0].replace("/", ".")
    print("train strategy:", strategy_module)
    strategy_factory = import_from_package_by_name("compile_training_strategy", strategy_module)
    return strategy_factory()


def load_eval_hook(args, dataset, dump_dir):
    """Create the evaluation hook defined in the file ``args.evalhook``.

    Args:
        args: parsed arguments; ``args.evalhook`` names the hook module file and
            the remaining fields feed ``get_arg_string``.
        dataset: passed straight to the hook factory (``train_model`` passes the
            dataset module returned by ``load_data``).
        dump_dir: output directory handed to the hook factory.

    Returns:
        Whatever the module's ``compile_eval_hook(dataset, dump_dir, run_name)``
        returns, where ``run_name`` is ``get_arg_string(args)``.
    """
    hook_module = os.path.splitext(args.evalhook)[0].replace("/", ".")
    print("eval hook:", hook_module)
    hook_factory = import_from_package_by_name("compile_eval_hook", hook_module)
    run_name = get_arg_string(args)
    return hook_factory(dataset, dump_dir, run_name)


def get_batch_sampler(args, train_dataset):
    """Return an optional custom sampler for the training data loader.

    The training-strategy module named by ``args.train_strategy`` may define
    ``compile_batch_sampler(train_dataset)`` returning ``(sampler, shuffle)``.
    When that hook is unavailable, ``(None, False)`` is returned.

    Args:
        args: parsed arguments; only ``args.train_strategy`` is read.
        train_dataset: dataset passed to the strategy's sampler factory.

    Returns:
        Tuple ``(sampler, shuffle)`` suitable for the matching DataLoader arguments.
    """
    module = os.path.splitext(args.train_strategy)[0].replace("/", ".")

    # The hook is optional, so a failed lookup falls back to the default.
    # Catch ``Exception`` instead of using a bare ``except:`` so that
    # KeyboardInterrupt and SystemExit still propagate.
    try:
        compile_batch_sampler = import_from_package_by_name("compile_batch_sampler", module)
    except Exception:
        compile_batch_sampler = None

    if compile_batch_sampler is None:
        return None, False

    print("train strategy: using custom batch sampler")
    sampler, shuffle = compile_batch_sampler(train_dataset)
    return sampler, shuffle


def load_data(args):
    """Load the train and validation splits via the dataset module ``args.dataset``.

    The dataset file path is converted to a dotted module name. That module's
    ``compile_dataset(root=..., index_file=..., subset=...)`` is called once
    per split, using ``args.data``, ``args.trainidx``/``args.validx`` and
    ``args.subset``.

    Returns:
        Tuple ``(train_data, valid_data, info, dataset_module)``. ``info`` is
        currently always an empty dict, and ``dataset_module`` is the imported
        dataset module itself.

    Raises:
        ValueError: if the training split is empty.
    """
    dataset_module_name = os.path.splitext(args.dataset)[0].replace("/", ".")
    print("dataset:", dataset_module_name)
    compile_dataset = import_from_package_by_name("compile_dataset", dataset_module_name)
    dataset_module = importlib.import_module(dataset_module_name)

    def _build_split(index_file):
        return compile_dataset(root=args.data, index_file=index_file, subset=args.subset)

    train_data = _build_split(args.trainidx)
    valid_data = _build_split(args.validx)

    if len(train_data) == 0:
        raise ValueError("No training data!")

    # An empty validation split is deliberately tolerated: training then runs
    # without a validation loader.
    return train_data, valid_data, {}, dataset_module


class BehaviouralCloning:
    """Behavioural-cloning trainer.

    Wires together the dataset, training strategy, model and optional
    evaluation hook selected via ``args``. The optimisation loop itself is
    delegated to the ``Network`` wrapper.

    Attributes:
        net: ``Network`` wrapper created by the most recent ``train_model``
            call (None until then).
        best_model: ``self.net.best_model`` captured after the most recent
            ``fit`` (None until then; not updated by a learning-rate search).
    """

    def __init__(self):
        self.net = None
        self.best_model = None

    @staticmethod
    def _build_loaders(args, train_strategy, train_dataset, valid_dataset):
        """Create the train, validation and train-set-evaluation data loaders.

        Returns ``(train_loader, valid_loader, trval_loader)``. ``valid_loader``
        is None for an empty validation split; ``trval_loader`` is None unless
        ``args.trainset_eval`` is set.
        """
        # The training-strategy module may supply its own sampler. The default
        # (None, False) means the loader iterates in order without shuffling.
        sampler, shuffle = get_batch_sampler(args, train_dataset)

        # Forward the sampler to the loader. Previously it was computed and then
        # silently dropped, so a custom sampler never took effect.
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=train_strategy.tr_batch_size, shuffle=shuffle,
            sampler=sampler, drop_last=False, num_workers=args.num_workers)

        trval_loader = None
        if args.trainset_eval:
            # evaluation pass over the training set without parameter updates
            trval_loader = torch.utils.data.DataLoader(
                train_dataset, batch_size=train_strategy.va_batch_size,
                shuffle=False, drop_last=False, num_workers=args.num_workers)

        valid_loader = None
        if len(valid_dataset) > 0:
            valid_loader = torch.utils.data.DataLoader(
                valid_dataset, batch_size=train_strategy.va_batch_size,
                shuffle=False, drop_last=False, num_workers=args.num_workers)

        return train_loader, valid_loader, trval_loader

    @staticmethod
    def _resolve_eval_hook(args, eval_hook, dataset_spec):
        """Return the evaluation hook to use during fitting (possibly None).

        A hook passed in by the caller (e.g. a meta-controller) receives the
        dataset module's ``ACTION_SPACE`` if it has none yet. Otherwise a hook
        is loaded from ``args.evalhook`` when that option is set.
        """
        if eval_hook is None:
            if args.evalhook:
                eval_hook = load_eval_hook(args, dataset_spec, args.param_root)
        elif eval_hook.action_space is None:
            eval_hook.set_action_space(dataset_spec.ACTION_SPACE)
        return eval_hook

    def train_model(self, args, model=None, eval_hook=None):
        """Train a model as configured by ``args``.

        Args:
            args: parsed arguments (see ``parse_args``) or any object exposing the
                same attribute names (e.g. a ``Config``).
            model: optional pre-built model. When None, it is built from
                ``args.model``, and action statistics are copied from the training
                dataset if both sides support it.
            eval_hook: optional evaluation hook. When None, one is loaded from
                ``args.evalhook`` if that option is set.

        Returns:
            ``self.best_model``: the network's best model after ``fit``, or the
            previous value (initially None) when only the LR finder ran.
        """
        # create dump folder if it does not exist
        os.makedirs(args.param_root, exist_ok=True)

        # load data
        print("Loading data ...")
        train_dataset, valid_dataset, _info, dataset_spec = load_data(args)

        # init training strategy and data loaders
        print("Initializing training strategy ...")
        train_strategy = compile_train_strategy(args)
        train_loader, valid_loader, trval_loader = self._build_loaders(
            args, train_strategy, train_dataset, valid_dataset)

        # evaluation hook for evaluating the model in the environment
        eval_hook = self._resolve_eval_hook(args, eval_hook, dataset_spec)

        # init model
        if model is None:
            model = build_model(args)
            if hasattr(model, "set_action_statistics") and hasattr(train_dataset, "action_statistics"):
                model.set_action_statistics(train_dataset.action_statistics)

        train_strategy.register_optimizer(model)

        if args.freeze:
            print("Freezing pre-trained model parameters ...")
            model.freeze(True)

        # init wrapper network
        self.net = Network(model)

        # prepare dump and logging path
        arg_str = get_arg_string(args)
        log_file = os.path.join(args.param_root, "log_%s.npy" % arg_str)
        dump_file = os.path.join(args.param_root, "params_%s.npy" % arg_str)
        print("dump_file:", dump_file)

        # continue training from previous state
        if args.continue_train:
            print("Loading previous model parameters ...")
            self.net.load(dump_file)

        # an explicit checkpoint from the strategy is loaded after (and thus
        # overrides) the continue_train state
        if train_strategy.load_checkpoint:
            print("Loading checkpoint: %s" % train_strategy.load_checkpoint)
            self.net.load(train_strategy.load_checkpoint)

        # fit model (or only run the learning-rate finder)
        if args.find_lr:
            print("Trying to find learning rate ...")
            self.net.find_lr(train_loader, train_strategy, start_lr=1e-7, end_lr=10, num_it=100, plot=True)
        else:
            print("Starting training ...")
            self.net.fit(train_loader, valid_loader, train_strategy, trval_loader,
                         dump_file=dump_file, log_file=log_file, eval_hook=eval_hook)
            self.best_model = self.net.best_model
        return self.best_model


def train_model_with_config(config_file: str):
    """Train a behavioural-cloning model using arguments loaded from ``config_file``.

    Args:
        config_file: path handed to ``Config``, whose attributes stand in for
            the command-line arguments produced by ``parse_args``.

    Returns:
        The best model reported by ``BehaviouralCloning.train_model``.
        Previously this result was discarded and None was returned.
    """
    args = Config(config_file)
    bc = BehaviouralCloning()
    return bc.train_model(args)


if __name__ == "__main__":
    # Command-line entry point: parse arguments first, then run one training session.
    cli_args = parse_args()
    trainer = BehaviouralCloning()
    trainer.train_model(cli_args)
