import os
import sys
import pathlib
import atexit
import re
import socket

# All vendored checkouts this module depends on live in a sibling ``LID``
# directory next to this file.
_LID_ROOT = pathlib.Path(__file__).parent / 'LID'

VHOME_PATH = _LID_ROOT / 'virtualhome'
VHOME_LID_PATH = _LID_ROOT / 'vh_mdp'
VHOME_DATA_PATH = _LID_ROOT / 'data'
BC_LID_PATH = _LID_ROOT / 'behavior_cloning'
# Linux build of the Unity simulator shipped inside the virtualhome checkout.
EXEC_PATH = (
    VHOME_PATH / 'simulation' / 'unity_simulator' / 'v2.2.5'
    / 'linux_exec.v2.2.5_beta.x86_64'
)
CHECKPOINT_PATH = BC_LID_PATH / 'checkpoints'

# Make the three checkouts importable. This must run before the project
# imports that follow; entries are appended, so already-importable packages
# with the same names take precedence.
sys.path.extend(str(p) for p in (VHOME_PATH, VHOME_LID_PATH, BC_LID_PATH))

from envs.vh_environment import VHEnvironment
from promptrl.envs.virtualhome.default_args import VHomeEnvArgs
import init_path

class PartitionedVHEnvironment(VHEnvironment):
    """``VHEnvironment`` that serves one interleaved slice of its task set.

    Partition ``p`` of ``n`` visits task indices ``p, p + n, p + 2n, ...``.
    Once past the end of the task set, the index wraps modulo its length.
    ``reset``/``step`` use a batch-of-one interface. Every returned value,
    including each value of the info dict, is wrapped in a one-element list.
    """

    # Matches a " (<digits>)" suffix such as "walk to kitchen (11)"; presumably
    # the object id VHEnvironment appends to action strings -- TODO confirm.
    _ACTION_ID_RE = re.compile(r' \(\d+\)')

    def __init__(self,
        num_agents=1,
        max_episode_length=100,
        env_task_set=None,
        observation_types=None,
        use_editor=False,
        base_port=8080,
        port_id=0,
        recording=False,
        output_folder=None,
        file_name_prefix=None,
        executable_args=None,
        seed=123,
        partition=0,
        n_partitions=10,
        flag_stop_early=False
    ):
        """Create the environment for slice ``partition`` of ``n_partitions``.

        All other arguments are forwarded unchanged to
        ``VHEnvironment.__init__``. ``executable_args=None`` is forwarded as
        a fresh empty dict.

        Raises:
            TypeError: if ``env_task_set`` is ``None``.
        """
        if env_task_set is None:
            raise TypeError('PartitionedVHEnvironment requires an env_task_set')
        if executable_args is None:
            # Fresh dict per instance: a shared ``{}`` default would leak any
            # downstream mutation into every later instance.
            executable_args = {}
        self.n_partitions = n_partitions
        self.partition = partition
        # Number of resets this partition has consumed so far.
        self.part_task_id = 0
        # Exact number of distinct task indices in this slice. Plain
        # ``len // n_partitions`` under-counted the first
        # ``len % n_partitions`` partitions.
        self.num_games = len(range(partition, len(env_task_set), n_partitions))
        super().__init__(
            num_agents=num_agents,
            max_episode_length=max_episode_length,
            env_task_set=env_task_set,
            observation_types=observation_types,
            use_editor=use_editor,
            base_port=base_port,
            port_id=port_id,
            recording=recording,
            output_folder=output_folder,
            file_name_prefix=file_name_prefix,
            executable_args=executable_args,
            seed=seed,
            flag_stop_early=flag_stop_early
        )

    def parse_action(self, a):
        """Return action string ``a`` with every `` (<digits>)`` suffix removed."""
        return self._ACTION_ID_RE.sub('', a)

    def _augment_infos(self, infos):
        """Add ``won``/``admissible_commands`` to ``infos`` and batch every value."""
        infos['won'] = infos['success']
        infos['admissible_commands'] = [self.parse_action(a) for a in infos['admissible']]
        return {k: [v] for k, v in infos.items()}

    def reset(self, **kwargs):
        """Reset to this partition's next task.

        Returns ``([obs], infos)``, where every info value is a one-element list.
        """
        # A caller-supplied task_id is discarded: the partition schedule decides.
        kwargs.pop('task_id', None)
        task_id = (self.partition + self.n_partitions * self.part_task_id) % len(self.env_task_set)
        self.part_task_id += 1
        obs, infos = super().reset(task_id=task_id, **kwargs)
        return [obs], self._augment_infos(infos)

    def step(self, action_nls, **kwargs):
        """Apply the single action in ``action_nls``.

        Returns ``([obs], [r], [done], infos)``, where every info value is a
        one-element list.
        """
        assert isinstance(action_nls, list), 'step expects a batch (list) of actions'
        assert len(action_nls) == 1, 'only a batch size of 1 is supported'
        obs, r, done, infos = super().step(action_nl=action_nls[0], **kwargs)
        return [obs], [r], [done], self._augment_infos(infos)

# Short mode name -> (VHomeEnvArgs.subset value, mode id).
# make_vhome_loader also adds the mode id to the starting port.
# NOTE: 'InDistributation' is intentionally misspelled. Per the original note,
# the upstream code uses this spelling, so it must not be "fixed" here.
MODE_LOOKUP = dict(
    id=('InDistributation', 0),
    novel_tasks=('NovelTasks', 1),
    novel_scenes=('NovelScenes', 2),
)
def _find_free_port(port_no):
    """Return the first port >= ``port_no`` that nothing accepts connections on at 127.0.0.1."""
    while True:
        # connect_ex() == 0 means a process is already listening there, so the
        # port is taken. The ``with`` block closes every probe socket.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            in_use = sock.connect_ex(('127.0.0.1', port_no)) == 0
        if not in_use:
            return port_no
        print(f'Port {port_no} in use, retrying..')
        port_no += 1


def make_vhome_loader(mode, partition=0, n_partitions=10):
    """Build a zero-argument factory for a ``PartitionedVHEnvironment``.

    Args:
        mode: key of ``MODE_LOOKUP`` (``'id'``, ``'novel_tasks'`` or
            ``'novel_scenes'``), selecting the evaluation subset.
        partition: index of the interleaved task slice to serve, or ``None``
            to serve the whole task set as a single partition.
        n_partitions: number of slices when ``partition`` is not ``None``.
            Defaults to 10, the previously hard-coded value.

    Returns:
        A callable that resolves paths and data through ``init_path``, builds
        the environment, registers its ``close`` with ``atexit`` and returns it.

    Raises:
        KeyError: if ``mode`` is not a key of ``MODE_LOOKUP``.
    """
    base_args = VHomeEnvArgs()
    base_args.exec_file = str(EXEC_PATH)
    base_args.data_dir = str(VHOME_DATA_PATH)
    base_args.pretrained_model_dir = str(CHECKPOINT_PATH)
    base_args.subset, mode_id = MODE_LOOKUP[mode]

    # Offset by the mode id so loaders for different modes start probing on
    # different ports. The probe runs now, not when the env is built, so
    # another process can still claim the port in between.
    start_port = int(os.environ.get('PORT', '8904')) + mode_id
    base_args.base_port = _find_free_port(start_port)

    def _initializer():
        args = init_path.get_logger_path(base_args)
        args = init_path.initialize_path(args)
        args = init_path.load_data_info(args)
        executable_args = {
            'file_name': args.exec_file,
            'x_display': args.display,
            'no_graphics': not args.graphics
        }
        env_args = {
            'num_agents': args.n_agent,
            'max_episode_length': args.max_episode_length,
            'port_id': 0,
            'env_task_set': args.env_task_set,
            # NOTE(review): always two entries regardless of n_agent;
            # presumably one per agent -- confirm.
            'observation_types': [args.obs_type, args.obs_type],
            'use_editor': args.use_editor,
            'executable_args': executable_args,
            'base_port': args.base_port,
            'seed': args.seed,
            'flag_stop_early': False,
        }
        if partition is None:
            env_args['partition'] = 0
            env_args['n_partitions'] = 1
        else:
            env_args['partition'] = partition
            env_args['n_partitions'] = n_partitions

        env = PartitionedVHEnvironment(**env_args)
        print(f'VHome Env {mode} loaded. {env.num_games} tasks. ({len(args.env_task_set)} total)')
        atexit.register(env.close)
        return env

    return _initializer
