import os

from train.common.config import Config
from train.iface import ExternalInterface
from train.iface_debug import ExternalInterfaceDebug
from train.proc.bc_proc import BehaviouralCloningProc, BehaviouralCloningData
from train.proc.env_client_proc import EnvClientProc
from train.proc.env_server_proc import EnvServerProc
from train.proc.rl_proc import ReinforcementLearningProc
from train.proc.subtask_proc import SubtaskIdentifierProc
from train.statement.base_stmt import Statement
from train.statement.eval_stmt import EvaluateAgentStatement
from train.statement.learning_stmt import LearningStatement
from train.subtask_identifier.task_builder import TaskBuilder
from train.task import TaskType
from test import main as test_main
from train import main as train_main
from train.enums import ExecMode, ClientEnvOption

# Path of the meta configuration file that every check in this module passes to Config.
META_CONFIG_FILE = 'configs/experiment/config.meta.json'


def test_config():
    """Build a Config from the meta file and print a few of its entries.

    Nothing is asserted; this is a manual smoke check that the accessed
    attributes exist on the loaded configuration.
    """
    cfg = Config(META_CONFIG_FILE)
    # The `bc` section and two of its entries.
    print(cfg.bc)
    print(cfg.bc.evalhook)
    print(cfg.bc.gpu)
    # Entries from the `rl` and `env` sections, then the `model` entry.
    print(cfg.rl.env)
    print(cfg.env.config_file)
    print(cfg.model)


def test_subtask_identifier():
    """Run SubtaskIdentifierProc with settings taken from the debug section.

    The ``subtask`` settings are overridden in place (through ``__dict__``)
    with the debug environment name and the directories derived from it
    before the process is run and checked for success.
    """
    config = Config(META_CONFIG_FILE)
    debug = config.debug
    overrides = {
        'task': debug.env,
        'rootdir': debug.subtask_rootdir,
        'inputdir': os.path.join(debug.subtask_rootdir, debug.env),
        'outputdir': os.path.join(debug.subtask_outputdir, debug.env),
    }
    # Written straight into the instance dictionary, one key per override.
    config.subtask.__dict__.update(overrides)

    identifier = SubtaskIdentifierProc(config)
    identifier.run()
    assert identifier.success()


def test_task():
    """Check that consensus tasks can be saved and are loaded back in order.

    Steps:
      1. Run the subtask alignment process and take its ``consensus``.
      2. Save every task and check that its directory and task file exist
         under ``config.checkpoint_path``.
      3. Write the task ids, one per line, to the consensus checkpoint file.
      4. Call ``load_consensus()`` (presumably reading the file written in
         step 3 -- confirm against the interface) and compare the ids
         pairwise, including the number of entries.
    """
    config = Config(META_CONFIG_FILE)
    ext_iface = ExternalInterfaceDebug(config)
    proc = ext_iface.create_subtask_alignment()
    proc.run()
    assert proc.success()
    tasks = proc.consensus
    for task in tasks:
        task.save()
        task_dir = os.path.join(config.checkpoint_path, task.id)
        task_file = os.path.join(task_dir, task.task_file)
        # Separate asserts so a failure says which of the two paths is missing.
        assert os.path.exists(task_dir), "missing task directory: %s" % task_dir
        assert os.path.exists(task_file), "missing task file: %s" % task_file
    consensus_file = os.path.join(config.checkpoint_path, config.checkpoint_file)
    with open(consensus_file, 'w') as f:
        # Plain loop: the writes are side effects, not values to collect.
        for task in tasks:
            f.write("%s\n" % task.id)
    expected_ids = [task.id for task in tasks]
    tasks_loaded = list(ext_iface.load_consensus())
    for t1, t2 in zip(tasks, tasks_loaded):
        print(t1.id, t2.id)
        assert t1.id == t2.id
    # zip() stops at the shorter input, so an empty or truncated result
    # would otherwise pass the pairwise comparison above unnoticed.
    assert len(tasks_loaded) == len(expected_ids), (
        "loaded %d tasks, expected %d" % (len(tasks_loaded), len(expected_ids)))


def test_client():
    """Run an EnvClientProc and drive its environment through one no-op step.

    After the process reports success, the environment is reset, stepped once
    with the action space's no-op action and closed. The final outcome is
    printed based on the process' ``_done`` flag rather than asserted.
    """
    config = Config(META_CONFIG_FILE)
    client = EnvClientProc(config, [1])
    client.run()
    assert client.success()

    env = client.env
    env.reset()
    noop = env.action_space.noop()
    print(noop, env)
    env.step(noop)
    env.close()

    verdict = 'Test successfully passed.' if client._done else 'Test failed!'
    print(verdict)


def test_server():
    """Create an EnvServerProc for ``config.env.num_env``, wait on it and check success."""
    config = Config(META_CONFIG_FILE)
    num_env = config.env.num_env
    server = EnvServerProc(config, num_env)
    server.wait()
    assert server.success()


def test_rl():
    """Run ReinforcementLearningProc on a dummy task in a client environment.

    An EnvClientProc is run first; its environment is then handed to the
    reinforcement-learning process together with the dummy task data.
    """
    config = Config(META_CONFIG_FILE)
    # NOTE(review): other callers in this file unpack extract_dummy_tasks()
    # as `tasks, _`; here the whole return value is forwarded unchanged --
    # confirm that this is what ReinforcementLearningProc expects.
    task = TaskBuilder(config).extract_dummy_tasks('S')

    env_proc = EnvClientProc(config, [5])
    env_proc.run()
    assert env_proc.success()

    rl_proc = ReinforcementLearningProc(config, env_proc.env, task)
    rl_proc.run()
    assert rl_proc.success()


def test_bc():
    """Run behavioural cloning once for every task of a short dummy consensus.

    For each task a fresh set of environments is requested from the debug
    interface and waited on, then a cloning process is created with the
    ``General`` data type, run and checked for success.
    """
    config = Config(META_CONFIG_FILE)
    consensus = 'SAQ'  # [log, cobblestone, iron_ore]
    tasks, _ = TaskBuilder(config).extract_dummy_tasks(consensus)
    ext_iface = ExternalInterfaceDebug(config)
    num_envs = 2
    for task in tasks:
        env_proc = ext_iface.make_env(num_envs, ClientEnvOption.Normal)
        env_proc.wait()
        bc_proc = ext_iface.clone_behaviour(
            task,
            data_type=BehaviouralCloningData.General,
            env=env_proc.env,
            n_workers=env_proc.num_envs,
        )
        bc_proc.run()
        assert bc_proc.success()


def test_eval_with_subtasks():
    """Evaluate the tasks produced by a real subtask alignment run.

    Uses ExternalInterface (not the debug variant): paths are created, the
    subtask input directory is pointed at the debug environment, the
    alignment process is run and its ``consensus`` is evaluated in a single
    environment on the first device of the device list.
    """
    config = Config(META_CONFIG_FILE)
    ext_iface = ExternalInterface(config)
    ext_iface.create_paths()
    input_dir = os.path.join(config.debug.subtask_rootdir, config.debug.env)
    config.subtask.__dict__['inputdir'] = input_dir

    alignment = ext_iface.create_subtask_alignment()
    alignment.run()
    assert alignment.success()
    tasks = alignment.consensus

    device_list = ext_iface.get_device_list()
    env_proc = ext_iface.make_env(num_envs=1, option=ClientEnvOption.Normal)
    env_proc.run()
    assert env_proc.success()

    eval_proc = ext_iface.evaluate_tasks(env_proc.env, tasks, device_list[0])
    eval_proc.run()
    assert eval_proc.success()


def test_eval():
    """Evaluate the dummy tasks of the consensus selected in the debug section.

    Every task of type ``TaskType.Imitation`` gets ``imitation_ready`` set to
    False before the tasks are evaluated in a single environment on the first
    device of the device list.
    """
    config = Config(META_CONFIG_FILE)
    ext_iface = ExternalInterfaceDebug(config)
    selection = config.debug.subtask_consensus_selection
    consensus = config.debug.subtask_consensus_demos[selection]
    tasks, _ = TaskBuilder(config).extract_dummy_tasks(consensus)
    for task in tasks:
        if task.task_type != TaskType.Imitation:
            continue
        task.imitation_ready = False

    device_list = ext_iface.get_device_list()
    env_proc = ext_iface.make_env(num_envs=1, option=ClientEnvOption.Normal)
    env_proc.run()
    assert env_proc.success()

    eval_proc = ext_iface.evaluate_tasks(env_proc.env, tasks, device_list[0])
    eval_proc.run()
    assert eval_proc.success()


def test_eval_with_collecting_learning():
    """Evaluate dummy tasks in Eval mode with two tasks forced to Learning.

    The configuration's ``mode`` is set to ``ExecMode.Eval``. The two tasks
    whose ids are listed below are switched to ``TaskType.Learning`` with
    ``imitation_ready`` set to False before the evaluation process is run.
    """
    config = Config(META_CONFIG_FILE)
    config.__dict__['mode'] = ExecMode.Eval
    ext_iface = ExternalInterfaceDebug(config)
    selection = config.debug.subtask_consensus_selection
    consensus = config.debug.subtask_consensus_demos[selection]
    tasks, _ = TaskBuilder(config).extract_dummy_tasks(consensus)

    forced_learning_ids = ("wooden_pickaxe-crafting_table", 'cobblestone-stone_pickaxe')
    for task in tasks:
        if task.id in forced_learning_ids:
            task.task_type = TaskType.Learning
            task.imitation_ready = False

    device_list = ext_iface.get_device_list()
    env_proc = ext_iface.make_env(num_envs=1, option=ClientEnvOption.Normal)
    env_proc.run()
    assert env_proc.success()

    eval_proc = ext_iface.evaluate_tasks(env_proc.env, tasks, device_list[0])
    eval_proc.run()
    assert eval_proc.success()


def test_record_eval():
    """Evaluate a fixed dummy consensus in an environment made with the Record option.

    The environment is requested with ``ClientEnvOption.Record`` and the seed
    list ``[342]``; the tasks are evaluated on the first device of the list.
    """
    config = Config(META_CONFIG_FILE)
    ext_iface = ExternalInterfaceDebug(config)
    builder = TaskBuilder(config)
    tasks, _ = builder.extract_dummy_tasks('SSSSSSPPPPLVVNLAAA')
    device_list = ext_iface.get_device_list()

    env_proc = ext_iface.make_env(
        num_envs=1, option=ClientEnvOption.Record, seed_list=[342])
    env_proc.run()
    assert env_proc.success()

    eval_proc = ext_iface.evaluate_tasks(env_proc.env, tasks, device_list[0])
    eval_proc.run()
    assert eval_proc.success()


def test_eval_stmt():
    """Execute an EvaluateAgentStatement built around the last consensus task.

    The consensus comes from the debug interface; the base Statement receives
    the last task together with the full task list and is then wrapped in an
    EvaluateAgentStatement, which is executed and checked for success.
    """
    config = Config(META_CONFIG_FILE)
    ext_iface = ExternalInterfaceDebug(config)
    tasks = ext_iface.get_consensus()
    base_stmt = Statement(config, ext_iface, tasks[-1], tasks)
    eval_stmt = EvaluateAgentStatement(base_stmt)
    eval_stmt.exec()
    assert eval_stmt.success()


def test_training():
    """Run a LearningStatement for every learning task of a dummy consensus.

    The configuration's ``mode`` is set to ``ExecMode.Train``, the dummy
    tasks are saved as the consensus checkpoint, and the tasks are then
    walked in order. Each task is appended to the running history; for tasks
    of type ``TaskType.Learning`` a Statement wrapped in a LearningStatement
    is executed with the history so far and must report success.
    """
    config = Config(META_CONFIG_FILE)
    config.__dict__['mode'] = ExecMode.Train

    consensus = 'SSSSPPVVLAAAQQQQ'
    # extract_dummy_tasks() is unpacked as a pair by the other callers in
    # this file; only the task list is needed here. Keeping the whole return
    # value would make the loop below iterate over the pair, not the tasks.
    tasks, _ = TaskBuilder(config).extract_dummy_tasks(consensus)

    ext_iface = ExternalInterfaceDebug(config)
    # Save to have files for further calls
    consensus_file = os.path.join(config.checkpoint_path, config.checkpoint_file)
    ext_iface.save_consensus(tasks, consensus_file)

    history = []
    for task in tasks:
        # Every task joins the history, including the ones that are skipped.
        history.append(task)
        if task.task_type != TaskType.Learning:
            continue
        stmt = Statement(config, ext_iface, task, task_history=history)
        stmt = LearningStatement(stmt)
        stmt.exec()
        assert stmt.success()


if __name__ == '__main__':
    # Manual switchboard: uncomment the checks you want to run.
    # Only train_main() (imported from `train`) is active by default;
    # test_main() is the entry point imported from `test`.
    #test_config()
    #test_task()
    #test_subtask_identifier()
    #test_server()
    #test_client()
    #test_bc()
    #test_rl()
    #test_eval()
    #test_eval_with_collecting_learning()
    #test_eval_with_subtasks()
    #test_record_eval()
    #test_eval_stmt()
    #test_training()
    #test_main()
    train_main()
