from pathlib import Path

import numpy as np
np.set_printoptions(precision=3, linewidth=120)

from code import cli
from code.defaults import MACHINE_VARIABLES
from code.log import default_log as log
from code.checkpoint import CheckpointableData, Checkpointer
from code.config import BaseConfig, Require
from code.torch_util import device
from code.shared import get_env
from code.smbpo import SMBPO


# Machine-specific root directory, read from the project's MACHINE_VARIABLES mapping.
# NOTE(review): not referenced anywhere else in the visible part of this file.
ROOT_DIR = Path(MACHINE_VARIABLES['root-dir'])
# Number of epochs between checkpoint saves in main(); also the stride of the
# candidate-epoch list that main() hands to Checkpointer.load_latest on resume.
SAVE_PERIOD = 5


class Config(BaseConfig):
    """Top-level settings for this training script.

    A default instance is passed to ``cli.main`` together with ``main`` in the
    ``__main__`` guard of this file.
    """
    # Passed to get_env() inside main() to construct the environment.
    # NOTE(review): Require(str) presumably marks a mandatory string option --
    # confirm against code.config.
    domain = Require(str)
    # NOTE(review): not read by main() in this file; confirm where (or whether)
    # it is consumed.
    algorithm = Require(str)
    # NOTE(review): not read by main() in this file; presumably applied
    # elsewhere (e.g. by the CLI helper) -- confirm.
    seed = 1
    # main() keeps training until alg.epochs_completed reaches this value.
    epochs = 1000
    # Passed as the first argument to the SMBPO constructor in main().
    alg_args = SMBPO.Config()


def main(cfg):
    """Train an SMBPO agent, resuming from checkpoints in ``log.dir`` if present.

    The function builds the algorithm and its data container, tries to restore
    both from disk, and then runs epochs until ``cfg.epochs`` have completed.
    Every ``SAVE_PERIOD`` epochs the solver is written to ``ckpt_<epoch>.pt``
    and the data to ``data.pt``.

    Args:
        cfg: Configuration object. This function reads ``cfg.domain``,
            ``cfg.alg_args`` and ``cfg.epochs``; ``cfg.seed`` and
            ``cfg.algorithm`` are not read here.

    Raises:
        AssertionError: If the restored solver's ``epochs_completed`` does not
            match the epoch of the checkpoint that was loaded, or if no solver
            checkpoint was loaded but ``epochs_completed`` is non-zero.
    """
    def env_factory():
        # Looked up at call time, exactly like the lambda this replaces.
        return get_env(cfg.domain)

    data = CheckpointableData()
    alg = SMBPO(cfg.alg_args, env_factory, data)
    alg.to(device)
    checkpointer = Checkpointer(alg, log.dir, 'ckpt_{}.pt')
    data_checkpointer = Checkpointer(data, log.dir, 'data.pt')

    # Check if existing run
    if data_checkpointer.try_load():
        log.message('Data load succeeded')
        # The upper bound is inclusive: the training loop below saves a
        # checkpoint at epoch == cfg.epochs whenever that is a multiple of
        # SAVE_PERIOD, so a finished run must be able to find it. Excluding it
        # would pair the final data.pt with an older solver checkpoint.
        candidate_epochs = list(range(0, cfg.epochs + 1, SAVE_PERIOD))
        loaded_epoch = checkpointer.load_latest(candidate_epochs)
        if isinstance(loaded_epoch, int):
            assert loaded_epoch == alg.epochs_completed
            log.message('Solver load succeeded')
        else:
            assert alg.epochs_completed == 0
            log.message('Solver load failed')
    else:
        log.message('Data load failed')

    if alg.epochs_completed == 0:
        alg.setup()

        # So that we can compare to the performance of randomly initialized policy
        alg.evaluate()

    while alg.epochs_completed < cfg.epochs:
        log.message(f'Beginning epoch {alg.epochs_completed+1}')
        alg.epoch()
        alg.evaluate()

        if alg.epochs_completed % SAVE_PERIOD == 0:
            # Solver first, then data; both are written on the same epochs.
            checkpointer.save(alg.epochs_completed)
            data_checkpointer.save()


# Script entry point: hand a default Config and the main() function to the
# project's CLI helper.
# NOTE(review): presumably cli.main applies command-line overrides to the
# config and then calls main(cfg) -- confirm in code.cli.
if __name__ == '__main__':
    cli.main(Config(), main)