import builtins
import os
import io
import time

from config import PERSISTENT_DATA_PATH

# Keep references to the unpatched callables: the wrappers below delegate to
# them, and disable_monkey_patch() restores them.
_real_open = builtins.open
_real_listdir = os.listdir

# Maps a path (exactly as it was passed to open) to the raw bytes of that file.
_file_cache = {}


def _cached_open(path, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None):
    """Drop-in replacement for ``open`` that serves read-only opens from memory.

    The first read-only open of ``path`` reads the whole file into
    ``_file_cache``; later read-only opens of the same ``path`` return a fresh
    in-memory stream over those bytes without touching the disk.

    The cache is bypassed (the real ``open`` is called) when:
      * the mode is not a pure read mode (``w``/``a``/``x`` or any ``+`` mode),
        because writes to an in-memory copy would be silently lost;
      * ``path`` is an integer file descriptor, because descriptor numbers are
        reused and do not identify a file;
      * a custom ``opener`` is supplied.

    Opening a path for writing also drops its cached bytes so that a later
    read does not return stale content.

    Returns:
        ``io.BytesIO`` for binary reads, ``io.TextIOWrapper`` over a
        ``io.BytesIO`` for text reads (decoded with ``encoding`` or UTF-8,
        honouring ``errors`` and ``newline``), or whatever the real ``open``
        returns when the cache is bypassed. The in-memory streams are not
        backed by an OS file descriptor.
    """
    is_fd = isinstance(path, int)
    read_only = 'r' in mode and '+' not in mode

    if is_fd or opener is not None or not read_only:
        if not is_fd and not read_only:
            # The file is about to be modified: forget the old snapshot.
            _file_cache.pop(path, None)
        return _real_open(path, mode, buffering, encoding, errors, newline, closefd, opener)

    if path not in _file_cache:
        with _real_open(path, 'rb') as f:
            _file_cache[path] = f.read()

    # Every caller gets an independent stream positioned at offset 0.
    raw = io.BytesIO(_file_cache[path])
    if 'b' in mode:
        return raw
    # TextIOWrapper applies errors/newline handling the same way the real
    # text-mode open does; UTF-8 stays the default when no encoding is given.
    return io.TextIOWrapper(raw, encoding=encoding or 'utf-8', errors=errors, newline=newline)


def _cached_listdir(path='.'):
    """Replacement for ``os.listdir`` that currently just delegates.

    ``path`` defaults to ``'.'`` like the real ``os.listdir`` so that callers
    using ``os.listdir()`` with no argument keep working while the patch is
    installed. No caching is performed; this is only a hook where a cached
    directory listing could be added.
    """
    return _real_listdir(path)


def enable_monkey_patch():
    """Swap ``builtins.open`` and ``os.listdir`` for the wrappers defined in this module."""
    setattr(builtins, "open", _cached_open)
    setattr(os, "listdir", _cached_listdir)


def disable_monkey_patch():
    """Put the original ``builtins.open`` and ``os.listdir`` back in place."""
    setattr(builtins, "open", _real_open)
    setattr(os, "listdir", _real_listdir)


# Install the patch at import time, ahead of the remaining imports below.
# NOTE(review): presumably so that file reads performed by those packages also
# go through the cache -- confirm.
enable_monkey_patch()
print("Monkey patching enabled")
from multiprocessing import Pool
from functools import partial
import numpy as np
import verbenvs
from verbenvs.envs.alfworld import VerbalizedALFWorld
from tqdm import tqdm
from verbenvs.context.auto import load_context_by_name
from verbenvs.envs.auto import load_verbenv_by_name
import random
import copy
from concurrent.futures import ProcessPoolExecutor
import os
from datetime import datetime
import time


def reset_to_same_env_state(game_file, trajectory):
    """Recreate an environment state by replaying actions on a fresh env.

    A new ``VerbalizedALFWorld`` restricted to ``game_file`` is created and
    reset, then every action in ``trajectory`` is stepped in order. The step
    results are discarded.

    Returns:
        The new environment after the replay.
    """
    replayed_env = VerbalizedALFWorld(
        game_files=[game_file],
        split="eval_out_of_distribution",
        use_planner=True,
    )
    replayed_env.reset()
    for past_action in trajectory:
        replayed_env.step(past_action)
    return replayed_env


def rollout_until_solve(env, max_steps=200):
    """Follow the environment's expert until it reports ``done`` or the step cap is hit.

    Each iteration asks ``env.get_expert_action()`` for an action, steps the
    environment with it, prints the step result, and records the action.

    Returns:
        ``(number_of_steps_taken, list_of_actions_taken)``.
    """
    actions_taken = []
    while len(actions_taken) < max_steps:
        next_action = env.get_expert_action()
        step_result = env.step(next_action)
        obs, reward, finished, truncated, infos = step_result
        print(obs, reward, finished, truncated, infos)
        actions_taken.append(next_action)
        if finished:
            break
    return len(actions_taken), actions_taken


def get_steps(action, game_file, traj):
    """Count how many expert steps are still needed after taking ``action``.

    The state reached by ``traj`` is rebuilt on a fresh environment, ``action``
    is applied, and the expert is then followed until the episode is done.

    Returns:
        ``(steps, expert_path)`` where ``steps`` is the number of expert steps
        taken *after* ``action`` (``action`` itself is not counted) and
        ``expert_path`` lists those expert actions.
    """
    env2 = reset_to_same_env_state(game_file, traj)
    _, _, done, _, _ = env2.step(action)
    if done:
        # The candidate action already ended the episode: nothing is left to
        # roll out, and stepping a finished environment further would produce
        # a meaningless (possibly maximal) step count.
        return 0, []
    return rollout_until_solve(env2)


def get_expert_steps(env, info, traj):
    """Find the admissible actions that lead to the shortest expert completion.

    Every action in ``info["admissible_actions"]`` is evaluated with
    ``get_steps`` in a process pool, starting from the state reached by
    replaying ``traj`` on ``env._curr_gamefile``.

    The pool size is read from the ``SLURM_CPUS_PER_TASK`` environment
    variable and falls back to 1 when it is unset.

    Returns:
        ``(min_step, best_actions, best_paths)``: the smallest remaining step
        count, the admissible actions that achieve it (in their original
        order), and the expert path that follows each of them.

    Raises:
        ValueError: from ``np.min`` when there are no admissible actions.
    """
    actions = info["admissible_actions"]

    num_cpus = int(os.getenv("SLURM_CPUS_PER_TASK", "1"))
    print(f"Using {num_cpus} CPUs")

    evaluate = partial(get_steps, game_file=env._curr_gamefile, traj=traj)
    # Workers re-install the open/listdir patch via the initializer.
    with ProcessPoolExecutor(max_workers=num_cpus, initializer=enable_monkey_patch) as executor:
        results = list(executor.map(evaluate, actions))

    step_counts = [steps for steps, _ in results]
    expert_paths = [path for _, path in results]
    min_step = np.min(step_counts)

    # Keep every action that ties for the minimum, not just the first one.
    best = [i for i, steps in enumerate(step_counts) if steps == min_step]

    return min_step, [actions[i] for i in best], [expert_paths[i] for i in best]


def _feedback_records(feedback_type, obs, possible_actions, expert_actions, history):
    """Build the dataset records for a single decision point.

    Each record gets its own deep copy of ``history``; ``possible_actions``
    and ``expert_actions`` are stored by reference. An unrecognised
    ``feedback_type`` yields no records.
    """

    def context():
        return {"possible_actions": possible_actions, "history": copy.deepcopy(history),
                "expert_actions": expert_actions}

    if feedback_type == "binary_feedback":
        # One record per admissible action: +1 if it is optimal, -1 otherwise.
        return [{"state": obs, "action": action, "feedback": 1 if action in expert_actions else -1, **context()}
                for action in possible_actions]

    if feedback_type == "preference":
        # One record per ordered pair of distinct actions:
        # 1 if only the first is optimal, -1 if only the second is, 0 on a tie.
        records = []
        for a in possible_actions:
            for b in possible_actions:
                if a == b:
                    continue
                preference = int(a in expert_actions) - int(b in expert_actions)
                records.append({"state": obs, "action1": a, "action2": b, "feedback": preference, **context()})
        return records

    if feedback_type == "action_advising":
        # A single record whose feedback is the full list of optimal actions.
        return [{"state": obs, "feedback": expert_actions, **context()}]

    return []


def get_data_for_one_task(env, type, distribution, cutoff_length=200):
    """Roll out one task and collect feedback records at every step.

    At each step the optimal actions are computed with ``get_expert_steps``,
    records of the requested ``type`` (``"binary_feedback"``, ``"preference"``
    or ``"action_advising"``) are appended, and the environment is advanced:
    with probability ``distribution`` by the first optimal action, otherwise
    by a uniformly random admissible action.

    The rollout stops after ``cutoff_length`` steps, when the environment
    reports ``done``, or when ``None`` appears among the optimal actions.

    Returns:
        The list of record dicts gathered over the rollout.
    """
    traj = []
    ret_list = []
    obs, infos = env.reset()
    # history alternates observations and actions, starting with the first observation.
    history = [obs]

    for _ in range(cutoff_length):
        # Always refresh from the current infos so the action list can never
        # be unbound or left over from a previous step.
        possible_actions = infos["admissible_actions"]

        _min_step, expert_actions, _expert_paths = get_expert_steps(env, info=infos, traj=traj)
        if None in expert_actions:
            break

        # Both random draws happen on every step, whichever action ends up being used.
        random_action = random.choice(possible_actions)
        if np.random.rand() < distribution:
            action_taken = expert_actions[0]
        else:
            action_taken = random_action

        ret_list.extend(_feedback_records(type, obs, possible_actions, expert_actions, history))

        obs, reward, done, truncated, infos = env.step(action_taken)
        history.append(action_taken)
        history.append(obs)
        traj.append(action_taken)
        if done:
            break

    return ret_list


def prepare_env_and_get_traj(rounds, type, distribution, cutoff_length=200, path="/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/"):
    """Collect and save the dataset for the task reached after ``rounds`` resets.

    A fresh planner-backed environment is created and reset ``rounds`` times,
    then ``get_data_for_one_task`` is run on it. Start/finish timestamps and
    the elapsed seconds are printed, and the records are written with
    ``np.save`` to ``path + "/ALF{type}_{distribution}_{rounds}.npy"``.

    Returns:
        The list of records produced by ``get_data_for_one_task``.
    """
    stamp_format = '%Y-%m-%d-%H:%M:%S'

    started_at = time.time()
    started_txt = datetime.fromtimestamp(started_at).strftime(stamp_format)
    print("Start on round ", rounds, started_txt)

    env = VerbalizedALFWorld(split="eval_out_of_distribution", use_planner=True)
    env.expert_type = "planner"
    # Burn through `rounds` resets before collecting data.
    for _ in range(rounds):
        env.reset()
    records = get_data_for_one_task(env, type, distribution, cutoff_length)

    finished_at = time.time()
    finished_txt = datetime.fromtimestamp(finished_at).strftime(stamp_format)
    print("Finish on round ", rounds, finished_txt)
    print("Time taken: ", (finished_at - started_at))

    output_file = path + f"/ALF{type}_{distribution}_{rounds}.npy"
    np.save(output_file, records)
    return records


def get_all_data(type, distribution, rounds=134, cutoff_length=80, start=0, path="/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/", num_workers=5):
    """Collect datasets for rounds ``start`` .. ``rounds - 1`` in parallel.

    Each round is handled by ``prepare_env_and_get_traj`` in a worker process
    (which also saves that round's records under ``path``).

    Args:
        type: Feedback type forwarded to ``prepare_env_and_get_traj``.
        distribution: Probability of following the expert, forwarded as-is.
        rounds: Exclusive upper bound of the round indices.
        cutoff_length: Maximum rollout length per task.
        start: First round index.
        path: Output directory forwarded to ``prepare_env_and_get_traj``.
        num_workers: Number of worker processes in the pool (default 5, the
            value that used to be hard-coded).

    Returns:
        All records from all rounds, concatenated in round order.
    """
    round_list = list(range(start, rounds))

    print(f"Using {num_workers} CPUs")
    run_round = partial(prepare_env_and_get_traj, type=type, distribution=distribution,
                        cutoff_length=cutoff_length, path=path)
    # Workers re-install the open/listdir patch via the initializer.
    with ProcessPoolExecutor(max_workers=num_workers, initializer=enable_monkey_patch) as executor:
        per_round = list(executor.map(run_round, round_list))

    # Flatten the per-round lists into one list of records.
    return [record for round_records in per_round for record in round_records]


if __name__ == "__main__":
    # Usage: <script> <feedback_type> <distribution>
    #   feedback_type: "binary_feedback", "preference" or "action_advising"
    #   distribution:  float, probability of following the expert action

    # Cluster paths (alternative to the local paths below):
    # os.environ["VERBENVS_DATA"] = "/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/verbenvs"
    # os.environ["ALFWOLRD_DATA"] = "/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/"
    #
    # NOTE(review): these variables are set after the verbenvs imports at the
    # top of the file, so they only take effect if they are read lazily --
    # confirm.
    # NOTE(review): "ALFWOLRD_DATA" looks like a misspelling of
    # "ALFWORLD_DATA"; confirm which name the environment loader actually
    # reads before changing it.
    os.environ["VERBENVS_DATA"] = "/Users/PLACEHOLDER_FOR_ANOYNOMITYli/verbenvs/"
    os.environ["ALFWOLRD_DATA"] = "/Users/PLACEHOLDER_FOR_ANOYNOMITYli/"

    # Debugging aid (disabled): looks for tasks with several optimal actions.
    # import tqdm
    #
    # env = VerbalizedALFWorld()
    # env.set_split("eval_out_of_distribution", use_planner=True)
    #
    # for counter in tqdm.tqdm(range(134)):
    #     env.reset()
    #     path1 = []
    #     for i in range(2):
    #         action = env.get_expert_action()
    #         obs, reward, done, truncated, infos = env.step(action)
    #         path1.append(action)
    #     # print(path1)
    #     # print(obs)
    #     min_step, expert_action, expert_path = get_expert_steps(env, infos, path1)
    #     if len(expert_action) != 1:
    #         print("Multiple Optimal Actions!")
    #         print(env._curr_gamefile)
    #         print(env.get_expert_action())
    #         print(min_step)
    #         print(expert_action)
    #         print(expert_path)

    import sys

    print("Arguments:", sys.argv)
    if len(sys.argv) < 3:
        # Fail with a readable message instead of an IndexError.
        raise SystemExit(f"Usage: {sys.argv[0]} <feedback_type> <distribution>")

    # Named feedback_type rather than `type` so the builtin is not rebound at
    # module level.
    feedback_type = sys.argv[1]
    dist = float(sys.argv[2])
    data = get_all_data(feedback_type, dist, start=113, rounds=114, cutoff_length=80, path=PERSISTENT_DATA_PATH + "/")
    # np.save(f"/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/ALF{sys.argv[1]}_{sys.argv[2]}.npy", data)

