# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import multiprocessing as mp
import threading
import time

import numpy as np
import libtorchbeast
from torchbeast_dmlab import dmlab_wrappers

from torchbeast_dmlab import atari_wrappers


# yapf: disable
# Command-line interface of the environment-server launcher.
parser = argparse.ArgumentParser(description="Remote Environment Server")

# Where the servers listen: server i is given the address "<pipes_basename>.<i>".
parser.add_argument(
    "--pipes_basename",
    default="unix:/tmp/polybeast",
    help=(
        "Basename for the pipes for inter-process communication. "
        "Has to be of the type unix:/some/path."
    ),
)
# How many server processes to launch.
parser.add_argument(
    "--num_servers",
    type=int,
    default=4,
    metavar="N",
    help="Number of environment servers.",
)
# Which level to serve.
parser.add_argument(
    "--env",
    type=str,
    default="psychlab_arbitrary_visuomotor_mapping",
    help="DMlab environment.",
)
# Seed forwarded to environment creation.
parser.add_argument(
    "--seed",
    type=int,
    default=1,
    metavar="N",
    help="seed.",
)
# Boolean switch; False unless given on the command line.
parser.add_argument(
    "--allow_oov",
    action="store_true",
    help=(
        "Allow action space larger than the env specific one."
        " All out-of-vocab action will be mapped to NoOp."
    ),
)
# yapf: enable


class Env:
    """Mock environment that emits constant uint8 frames of shape (4, 84, 84).

    ``serve`` uses this class instead of a real level when the environment
    name is "Mock".
    """

    def reset(self):
        """Print "reset called" and return a frame filled with ones."""
        print("reset called")
        return np.full((4, 84, 84), 1, dtype=np.uint8)

    def step(self, action):
        """Return a fixed transition; ``action`` is ignored.

        The result is a 4-tuple ``(frame, 0.0, False, {})`` where ``frame`` is
        all zeros -- presumably (observation, reward, done, info); confirm
        against the server's expectations. Only the first three entries are
        mandatory.
        """
        blank_frame = np.full((4, 84, 84), 0, dtype=np.uint8)
        transition = (blank_frame, 0.0, False, {})
        return transition


def create_env(env_name, seed=1, lock=threading.Lock()):
    """Build the DMLab-30 level ``env_name`` while holding ``lock``.

    Args:
        env_name: Level name; it is appended to 'contributed/dmlab30/'.
        seed: Seed forwarded to ``dmlab_wrappers.create_env_dmlab``.
        lock: Lock held for the duration of the construction. The default is
            created once, when this function is defined, so every call in the
            process that omits ``lock`` shares the same lock and constructions
            are serialized. (Presumably construction is not thread-safe --
            confirm before changing this default.)

    Returns:
        Whatever ``dmlab_wrappers.create_env_dmlab`` returns.
    """
    level = 'contributed/dmlab30/' + env_name
    # Fixed render size and log level used for every level.
    settings = dict(width=96, height=72, logLevel='WARN')
    with lock:
        env = dmlab_wrappers.create_env_dmlab(level, settings, seed)
    return env


def create_test_env(env_name, seed=1):
    """Build the DMLab-30 level ``env_name`` without taking any lock.

    Args:
        env_name: Level name; it is appended to 'contributed/dmlab30/'.
        seed: Seed forwarded to ``dmlab_wrappers.create_env_dmlab``.

    Returns:
        Whatever ``dmlab_wrappers.create_env_dmlab`` returns.
    """
    # Fixed render size and log level used for every level.
    settings = dict(width=96, height=72, logLevel='WARN')
    level = 'contributed/dmlab30/' + env_name
    return dmlab_wrappers.create_env_dmlab(level, settings, seed)


def serve(env_name, server_address, seed=1):
    """Run a ``libtorchbeast.Server`` for ``env_name`` at ``server_address``.

    Args:
        env_name: "Mock" selects the in-file ``Env`` stub; any other value is
            passed to ``create_env``.
        server_address: Address handed to ``libtorchbeast.Server``.
        seed: Seed passed to ``create_env`` (unused for the mock).
    """
    if env_name == "Mock":
        env_factory = Env
    else:
        # Zero-argument factory that captures the level name and seed.
        def env_factory():
            return create_env(env_name, seed)

    env_server = libtorchbeast.Server(env_factory, server_address=server_address)
    env_server.run()


def main(flags):
    """Start the environment server processes and wait for an interrupt.

    One daemon process is started per server; server ``i`` runs ``serve`` with
    the address ``f"{flags.pipes_basename}.{i}"``, the level ``flags.env`` and
    the seed ``flags.seed``. The calling process then sleeps in a loop until a
    ``KeyboardInterrupt`` arrives, at which point the function returns.

    Args:
        flags: Parsed arguments providing ``pipes_basename``, ``num_servers``,
            ``env`` and ``seed``.

    Raises:
        ValueError: If ``flags.pipes_basename`` does not start with "unix:".
            (``ValueError`` is a subclass of ``Exception``, so handlers
            written for the previous bare ``Exception`` still catch it.)
    """
    if not flags.pipes_basename.startswith("unix:"):
        raise ValueError("--pipes_basename has to be of the form unix:/some/path.")

    # Handles are collected here but not otherwise used by this function.
    processes = []
    for i in range(flags.num_servers):
        p = mp.Process(
            target=serve,
            args=(flags.env, f"{flags.pipes_basename}.{i}", flags.seed),
            daemon=True,
        )
        p.start()
        processes.append(p)

    try:
        # We are only here to listen to the interrupt.
        while True:
            time.sleep(10)
    except KeyboardInterrupt:
        # An interrupt is the expected way to stop; return quietly.
        pass


if __name__ == "__main__":
    # Script entry point: read the command line, report the selected level,
    # then hand over to main().
    flags = parser.parse_args()
    banner = f"Env: {flags.env}"
    print(banner)
    main(flags)
