import random
from typing import List, Dict

import copy
from verbenvs.envs.alfworld import VerbalizedALFWorld
from tqdm import tqdm
import numpy as np

from ALFWorldSearch import get_expert_steps
from DataGenerators.DataGenerator import DataGenerator





class AlfWorldDataGenerator(DataGenerator):
    """Generate expert-feedback datasets from ALFWorld rollouts.

    Each rollout follows the expert (``get_expert_steps``) with probability
    ``distribution`` at every step, and otherwise takes a uniformly random
    admissible action. Every visited state is labelled with the feedback kind
    selected by ``type`` (see ``SAMPLE_TYPES``).
    """

    # Feedback kinds understood by ``sample_data``.
    SAMPLE_TYPES = ("binary_feedback", "preference", "action_advising")

    def __init__(self, env, type="binary_feedback", distribution=1):
        """
        :param env: initial environment handle; ``sample_data`` replaces it with a
            fresh ``VerbalizedALFWorld`` via ``initialise_env``.
        :param type: feedback kind to generate, one of ``SAMPLE_TYPES``.
        :param distribution: probability of taking the expert action at each step
            (otherwise a random admissible action is taken).
        """
        self.env = env
        self.type = type
        self.distribution = distribution
        self.state_array = None
        # Actions taken so far in the current episode; passed to the expert search.
        self.traj = []

    def initialise_env(self):
        """Create the ALFWorld environment on the out-of-distribution eval split."""
        self.env = VerbalizedALFWorld()
        # Alternative split used previously: "valid_unseen".
        self.env.set_split("eval_out_of_distribution", use_planner=True)
        self.env.expert_type = "planner"

    def sample_array_of_states(self, number) -> np.ndarray:
        """Not implemented for ALFWorld; returns None."""
        pass

    def sample_array_of_trajectories(self, number) -> List[Dict]:
        """Not implemented for ALFWorld; returns None."""
        pass

    def get_expert_actions(self, state) -> List:
        """Return the expert's actions for ``state`` given the current episode trajectory.

        :param state: the ``infos`` object returned by the environment (passed to
            ``get_expert_steps`` as ``info``).
        :return: the expert-action list from ``get_expert_steps``. It may contain
            ``None``, which ``sample_data`` treats as "no usable expert plan"
            (presumably the search failed; confirm against ``get_expert_steps``).
        """
        # get_expert_steps returns (min_step, expert_actions, expert_paths); only actions are used.
        _, expert_actions, _ = get_expert_steps(self.env, info=state, traj=self.traj)
        return expert_actions

    def get_expert_value(self, state) -> float:
        """Not implemented for ALFWorld; returns None."""
        pass

    def get_expert_qvalue(self, state, action) -> float:
        """Return 1 if ``action`` is one of the expert's actions in ``state``, else -1."""
        expert_actions = self.get_expert_actions(state)
        return 1 if action in expert_actions else -1

    def get_expert_binary_feedback(self, state, action) -> int:
        """
        This function returns the binary feedback that the expert would give for a given state-action pair.
        :param state: environment ``infos`` for the state
        :param action: action to judge
        :return: 1, positive, -1 negative
        TODO: indifferent feedback?
        """
        return self.get_expert_qvalue(state, action)

    def get_expert_preference(self, state, action1, action2) -> int:
        """Return 1 if ``action1`` is preferred, -1 if ``action2`` is, 0 if indifferent.

        With the binary Q-values above, an action is preferred exactly when it is
        an expert action and the other is not.
        """
        # Query the (expensive) expert search once instead of once per action.
        expert_actions = self.get_expert_actions(state)
        return int(action1 in expert_actions) - int(action2 in expert_actions)

    def get_expert_action_advising(self, state) -> List:
        """
        This function returns the actions that the expert would advise in a given state.
        TODO: Do we want stochastic or deterministic advising?
        :param state: environment ``infos`` for the state (as passed to ``get_expert_actions``)
        :return: list of expert actions
        """
        return self.get_expert_actions(state)

    def _feedback_records(self, obs, possible_actions, expert_action, history) -> List[Dict]:
        """Build the dataset records for one visited state according to ``self.type``."""
        if self.type == "binary_feedback":
            return [
                {"state": obs, "action": j, "feedback": 1 if j in expert_action else -1,
                 "possible_actions": possible_actions, "history": copy.deepcopy(history),
                 "expert_actions": expert_action}
                for j in possible_actions
            ]
        if self.type == "preference":
            records = []
            for a in possible_actions:
                for b in possible_actions:
                    if a == b:
                        continue
                    # 1: only a is expert, -1: only b is expert, 0: both or neither.
                    preference = int(a in expert_action) - int(b in expert_action)
                    records.append({"state": obs, "action1": a, "action2": b, "feedback": preference,
                                    "possible_actions": possible_actions,
                                    "history": copy.deepcopy(history), "expert_actions": expert_action})
            return records
        if self.type == "action_advising":
            return [{"state": obs, "feedback": expert_action, "possible_actions": possible_actions,
                     "history": copy.deepcopy(history), "expert_actions": expert_action}]
        return []

    def sample_data(self, cutoff_length=100, rounds=134) -> List[Dict]:
        """Roll out ``rounds`` episodes and collect labelled feedback records.

        :param cutoff_length: maximum number of steps per episode.
        :param rounds: number of episodes (environment resets).
        :return: list of record dicts; their keys depend on ``self.type``.
        :raises ValueError: if ``self.type`` is not in ``SAMPLE_TYPES``.
        """
        if self.type not in self.SAMPLE_TYPES:
            # Fail fast instead of running every rollout and returning nothing.
            raise ValueError(
                f"Unknown feedback type {self.type!r}; expected one of {self.SAMPLE_TYPES}"
            )
        ret_list = []
        self.initialise_env()
        skip_count = 0
        bar = tqdm(range(rounds))
        for i in bar:
            history = []
            self.traj = []
            obs, infos = self.env.reset()
            history.append(obs)
            traj_length = 0

            for k in range(cutoff_length):
                # NOTE(review): possible_actions is only refreshed when infos is a dict;
                # otherwise the previous step's list is reused (NameError on step 0).
                if isinstance(infos, dict):
                    possible_actions = infos["admissible_actions"]
                expert_action = self.get_expert_actions(infos)
                if None in expert_action:
                    # No usable expert plan for this state: abandon the episode.
                    skip_count += 1
                    break
                # Drawn on every step (even when unused) so the RNG stream is unchanged.
                random_action = random.choice(possible_actions)
                if np.random.rand() < self.distribution:
                    action_taken = expert_action[0]
                else:
                    action_taken = random_action

                ret_list.extend(self._feedback_records(obs, possible_actions, expert_action, history))

                # NOTE(review): `truncated` is ignored; episodes end only on `done` or the cutoff.
                obs, reward, done, truncated, infos = self.env.step(action_taken)
                history.append(action_taken)
                history.append(obs)
                self.traj.append(action_taken)
                # Number of steps actually taken so far (k is 0-based).
                traj_length = k + 1
                if done:
                    break
            bar.set_description(
                f"Skip Count: {skip_count}, Traj Length: {traj_length}, Count: {i}")

        return ret_list





if __name__ == "__main__":
    import os
    import sys

    # Usage: python <this script> <feedback_type> <distribution>
    #   feedback_type: binary_feedback | preference | action_advising
    #   distribution:  probability of following the expert at each step, e.g. 0, 0.5, 1
    # Invoke once per (feedback_type, distribution) pair to produce each dataset.
    print("Arguments:", sys.argv)
    if len(sys.argv) < 3:
        sys.exit(f"usage: {sys.argv[0]} <feedback_type> <distribution>")

    os.environ["VERBENVS_DATA"] = "/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/verbenvs"
    # ALFWorld reads its data root from ALFWORLD_DATA (previously misspelled as
    # ALFWOLRD_DATA, so it was never set). setdefault keeps any value from the shell.
    os.environ.setdefault("ALFWORLD_DATA", "/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/")

    mgdg = AlfWorldDataGenerator(env="AlfWorld", type=sys.argv[1], distribution=float(sys.argv[2]))
    data = mgdg.sample_data(cutoff_length=80, rounds=134)
    # The list of dicts is stored as an object array; load with np.load(..., allow_pickle=True).
    np.save(f"/h/PLACEHOLDER_FOR_ANOYNOMITYli/scratch/dataset/ALF{sys.argv[1]}_{sys.argv[2]}.npy", data)


