import argparse
import importlib.util
import logging
import os
import pprint
import sys
from typing import List, Dict, Any
import tempfile

import numpy as np

import torch
import torch.utils.data
from runners import kpop_trainer
from utils.create_rendering import render_to_path,  rendering_fix_dynamic
from utils.parse_args import parse_optfloat


def setup_logging(log_level=logging.INFO):
    """Send root-logger output to stdout using a timestamped line format.

    ``force=True`` drops any handlers already attached to the root logger, so a
    second call replaces the earlier configuration instead of adding to it.

    Args:
        log_level: threshold level set on the root logger (default ``logging.INFO``).
    """
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        format='%(asctime)s|%(levelname)8s| %(message)s',
        level=log_level,
        handlers=[stdout_handler],
        force=True,
    )


def load_data(model_type: str, data_downsample, data_dirs, render_only: bool, **kwargs):
    """Load the data for ``model_type`` via the matching trainer module.

    Args:
        model_type: dataset/trainer family; only ``"kpop_dataset"`` is supported.
        data_downsample: raw downsampling setting, parsed with ``parse_optfloat``
            using ``default_val=1.0`` (presumably the value used when unset).
        data_dirs: forwarded unchanged to ``kpop_trainer.load_data``.
        render_only: forwarded to the loader as a keyword argument.
        **kwargs: remaining config entries, forwarded to the loader.

    Returns:
        Whatever ``kpop_trainer.load_data`` returns (the caller merges it into
        the config with ``dict.update``, so presumably a mapping).

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    # Fail fast: returning None here would only surface later as a confusing
    # error when the caller merges the result into the config.
    if model_type != "kpop_dataset":
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    data_downsample = parse_optfloat(data_downsample, default_val=1.0)
    return kpop_trainer.load_data(
        data_downsample, data_dirs,
        render_only=render_only, **kwargs)


def init_trainer(model_type: str, **kwargs):
    """Construct the trainer for ``model_type``.

    Args:
        model_type: trainer family; only ``"kpop_dataset"`` is supported.
        **kwargs: passed unchanged to ``kpop_trainer.KpopTrainer``. In ``main``
            this is the config dict after the loaded data has been merged in.

    Returns:
        A ``kpop_trainer.KpopTrainer`` instance.

    Raises:
        ValueError: if ``model_type`` is not supported.
    """
    if model_type != "kpop_dataset":
        raise ValueError(f"Unsupported model_type: {model_type!r}")
    # Debug level only: kwargs include the loaded data, which may be large,
    # and main() already pretty-prints the plain config.
    logging.debug("KpopTrainer kwargs: %s", kwargs)
    # kpop_trainer is imported at module level; no local re-import needed.
    return kpop_trainer.KpopTrainer(**kwargs)

def save_config(config):
    """Write ``config`` to ``<logdir>/<expname>/`` as ``config.py`` and ``config.csv``.

    ``config.py`` contains ``config = `` followed by ``pprint.pformat(config)``.
    ``config.csv`` contains one ``key<TAB>value`` line per entry (tab-separated
    despite the extension). The directory is created if missing, and existing
    files are overwritten.

    Args:
        config: configuration dict; must contain ``'logdir'`` and ``'expname'``.
    """
    log_dir = os.path.join(config['logdir'], config['expname'])
    os.makedirs(log_dir, exist_ok=True)

    # Explicit UTF-8 so non-ASCII values are written identically on every
    # platform instead of depending on the locale's default encoding.
    with open(os.path.join(log_dir, 'config.py'), 'wt', encoding='utf-8') as out:
        out.write('config = ' + pprint.pformat(config))

    with open(os.path.join(log_dir, 'config.csv'), 'w', encoding='utf-8') as out:
        # !s matches the original "%s" formatting exactly.
        out.writelines(f"{key!s}\t{value!s}\n" for key, value in config.items())


def main():
    """CLI entry point: load a config file, apply overrides, then train or render.

    Command-line arguments:
      --config-path  Python file defining a module-level ``config`` dict (required).
      --render-only  Skip training; load a checkpoint and render instead.
      --log-dir      Directory containing the checkpoint to load (required, and
                     must exist, with --render-only).
      --model_name   Checkpoint file name inside --log-dir (required whenever
                     --log-dir is given).
      --seed         Seed for the NumPy and PyTorch RNGs (default 0).
      override ...   Trailing ``key=value`` tokens that replace config entries;
                     values are kept as strings.
    """
    #setup_logging()

    p = argparse.ArgumentParser(description="")
    p.add_argument('--render-only', action='store_true')
    p.add_argument('--config-path', type=str, required=True)
    p.add_argument('--log-dir', type=str, default=None)
    p.add_argument('--model_name', type=str, default=None)

    p.add_argument('--seed', type=int, default=0)
    p.add_argument('override', nargs=argparse.REMAINDER)

    args = p.parse_args()
    render_only = args.render_only

    # Validate flag combinations up front with parser errors instead of an
    # `assert` (stripped under `python -O`) or a late TypeError from os.path.join.
    if render_only and (args.log_dir is None or not os.path.isdir(args.log_dir)):
        p.error("--render-only requires --log-dir to point to an existing directory")
    if args.log_dir is not None and args.model_name is None:
        p.error("--model_name is required when --log-dir is given")

    # Set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Import the config file as a module and take its `config` dict.
    spec = importlib.util.spec_from_file_location(
        os.path.basename(args.config_path), args.config_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None when no loader matches the suffix.
        p.error(f"cannot import config file: {args.config_path}")
    cfg = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(cfg)
    config: Dict[str, Any] = cfg.config

    # Split each override on the first "=" only, so values may contain "=".
    overrides: List[str] = args.override
    overrides_dict: Dict[str, str] = {}
    for ovr in overrides:
        key, sep, value = ovr.partition("=")
        if not sep:
            p.error(f"override must look like key=value, got {ovr!r}")
        overrides_dict[key] = value
    config.update(overrides_dict)
    model_type = "kpop_dataset"

    pprint.pprint(config)
    if not render_only:
        save_config(config)
    data = load_data(model_type, render_only=render_only, **config)
    config.update(data)
    trainer = init_trainer(model_type, **config)

    if args.log_dir is not None:
        checkpoint_path = os.path.join(args.log_dir, args.model_name)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path)
        trainer.load_model(checkpoint, training_needed=not render_only)

    if render_only:
        render_to_path(trainer, extra_name="")
    else:
        trainer.train()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
