from abc import abstractmethod

import gymnasium as gym
import stable_baselines3
from stable_baselines3 import PPO
import numpy as np
import torch
import torch.nn as nn
from stable_baselines3 import A2C, PPO

from config import *
from tqdm import tqdm
from together import Together
from Testers.Test import LMFeedbackVerifier
from abc import ABC, abstractmethod
import json, requests


def join_list_of_strings(list_of_strings):
    """Return the given strings joined into one string, one per line.

    Args:
        list_of_strings: Iterable of ``str`` items.

    Returns:
        The items separated by a single newline character (no trailing
        newline); an empty string for an empty iterable.
    """
    separator = "\n"
    return separator.join(list_of_strings)



def alf_obs_extractor(data, condition_list):
    """Return the observation representation for an ALF data item.

    Currently a stub: both arguments are ignored and the result is always
    the empty string.
    """
    empty_representation = ""
    return empty_representation

def construct_current(instruction_obs_dict):
    """Render the current instruction/observation pair as prompt text.

    The text is assembled in a single pass, so instruction or observation
    content that happens to contain the words ``TASK`` or ``OBSERVATION`` is
    emitted verbatim (chained ``str.replace`` calls would have rewritten it).

    Args:
        instruction_obs_dict: Dict with ``"instruction"`` and
            ``"observation"`` entries.

    Returns:
        The rendered text. If the argument is not a dict, a diagnostic is
        printed and the argument is returned unchanged (best-effort
        fallback).

    Raises:
        KeyError: If the dict lacks ``"instruction"`` or ``"observation"``.
    """
    if not isinstance(instruction_obs_dict, dict):
        print("Unexpected type encountered: ", instruction_obs_dict)
        return instruction_obs_dict

    instruction = instruction_obs_dict["instruction"]
    observation = instruction_obs_dict["observation"]
    # Layout (leading newline, trailing four-space line) matches the prompt
    # format this function has always produced.
    return (
        "\n"
        f"You received instruction: {instruction}.\n"
        f"Your observation is: {observation}.\n"
        "    "
    )


def construct_history(instruction_obs_dict, action):
    """Render one past step (instruction, observation, action) as prompt text.

    The text is assembled in a single pass, so instruction, observation or
    action content containing the words ``TASK``, ``OBSERVATION`` or
    ``ACTIONTAKEN`` is emitted verbatim (chained ``str.replace`` calls would
    have rewritten it).

    Args:
        instruction_obs_dict: Dict with ``"instruction"`` and
            ``"observation"`` entries.
        action: The action taken in that state.

    Returns:
        The rendered, four-space-indented step text. If the first argument is
        not a dict, a diagnostic is printed and that argument is returned
        unchanged (best-effort fallback); ``action`` is ignored in that case.

    Raises:
        KeyError: If the dict lacks ``"instruction"`` or ``"observation"``.
    """
    if not isinstance(instruction_obs_dict, dict):
        print("Unexpected type encountered: ", instruction_obs_dict)
        return instruction_obs_dict

    instruction = instruction_obs_dict["instruction"]
    observation = instruction_obs_dict["observation"]
    return (
        "\n"
        f"    You received instruction: {instruction}.\n"
        f"    Your observation is: {observation}.\n"
        f"    You took action: {action}.\n"
    )

def process_history_list(history_list):
    """Render an alternating ``[state, action, state, action, ...]`` history.

    Entries are consumed as (state, action) pairs and each pair is rendered
    with ``construct_history``; a trailing entry without a partner is
    ignored. A rendered step that is identical to the step immediately
    before it is skipped, to save tokens.

    Args:
        history_list: Sequence alternating state dicts and the actions taken.

    Returns:
        The concatenation of the rendered, de-duplicated steps (empty string
        when there is no complete pair).
    """
    rendered_steps = []
    last_step = ""
    states = history_list[0::2]
    actions = history_list[1::2]
    # zip stops at the shorter slice, so an unpaired trailing state is dropped.
    for state, action_taken in zip(states, actions):
        step = construct_history(state, action_taken)
        if step != last_step:
            rendered_steps.append(step)
        last_step = step
    return "".join(rendered_steps)


class ALFVerifier(LMFeedbackVerifier):
    """LM-feedback verifier carrying the ALF World prompt templates.

    The templates use upper-case placeholders.
    ``domain_specific_prompt_process`` fills ``HISTORY1``, ``HISTORY2`` and
    ``POSSIBLELIST``. The ``ACTION``, ``ACTION1`` and ``ACTION2`` placeholders
    are not substituted anywhere in this class (presumably handled by
    ``LMFeedbackVerifier`` -- TODO confirm).
    """

    def __init__(self, env, feedback_type, data_path, condition_list, **kwargs):
        """Initialise the base verifier and define the ALF prompt templates.

        All arguments are forwarded unchanged to ``LMFeedbackVerifier``.
        """
        super().__init__(env, feedback_type, data_path, condition_list, **kwargs)
        self.use_action_map_dict = False
        self.obs_representation_extractor = alf_obs_extractor

        self.base_prompt = "You will be challenged with ALF World, a text-based game."

        # Binary feedback: is the proposed ACTION the best one? (YES / NO)
        self.if_optimal_prompt_cot = """
HISTORY1
Below is your history:
HISTORY2
You can take the following actions:
POSSIBLELIST
Is action ACTION the best action you can take? Please think step by step.
Only give the answer in a new line in JSON format:
{"reasoning": <REASONING>, "feedback": <FEEDBACK>}
Where <FEEDBACK> is one of "YES" or "NO", <REASONING> is a string of your thinking steps.
                """

        # Action advising: pick one action out of POSSIBLELIST.
        self.action_advising_base_prompt_cot = """
HISTORY1
Below is your history:
HISTORY2
You can take the following actions:
POSSIBLELIST
Which action do you choose? Please think step by step.
Only give the answer in a new line in JSON format:
{"reasoning": <REASONING>, "action": <ACTION>}
Where <ACTION> is one of possible actions, <REASONING> is a string of your thinking steps.
                """

        # Preference: choose between ACTION1 and ACTION2 (FIRST / SECOND).
        self.preference_base_prompt_cot = """
HISTORY1
Below is your history:
HISTORY2
You can take the following actions:
POSSIBLELIST
Given ACTION1 or ACTION2, which action is better? Please think step by step.
Only give the answer in a new line in JSON format:
{"reasoning": <REASONING>, "preference": <PREFERENCE>}
Where <PREFERENCE> is one of "FIRST" or "SECOND", <REASONING> is a string of your thinking steps.
        """

        # Optional reasoning scaffold; domain_specific_prompt_process does not
        # currently append it to the prompt.
        self.explicit_thinking_guides = """
You should think about these questions.
What is your current goal?
Based on the rules you know, what subgoals do you need to achieve?
Based on what you see and what you have, what subgoals have you achieved?
If you have not achieved all of the subgoals, based on the rules and doable actions, can you achieve any subgoals now?
If yes, which subgoal do you want to achieve next?
If no, what do you need or where should you go next?
    """.strip()

    def domain_specific_prompt_process(self, data, prompt):
        """Fill the state, history and action-list placeholders of ``prompt``.

        The substitution is the same for every supported feedback type.

        Args:
            data: Mapping with ``"state"`` (passed to ``construct_current``),
                ``"history"`` (alternating state/action sequence) and
                ``"possible_actions"`` (list of strings).
            prompt: Template containing ``HISTORY1``, ``HISTORY2`` and
                ``POSSIBLELIST``.

        Returns:
            ``prompt`` with the three placeholders substituted.

        Raises:
            ValueError: If ``self.feedback_type`` is not one of
                ``"binary_feedback"``, ``"action_advising"`` or
                ``"preference"``.
        """
        supported_types = ("binary_feedback", "action_advising", "preference")
        if self.feedback_type not in supported_types:
            # Fail loudly instead of hitting an unbound local further down.
            raise ValueError(
                "Unsupported feedback_type {!r}; expected one of {}".format(
                    self.feedback_type, ", ".join(supported_types)
                )
            )

        current_text = construct_current(data["state"])
        history = data["history"]
        # A history holding at most the current entry has no completed step
        # to show, so it is rendered as "None".
        if len(history) <= 1:
            history_text = "None\n"
        else:
            history_text = process_history_list(history)
        actions_text = join_list_of_strings(data["possible_actions"])

        return (
            prompt.replace("HISTORY1", current_text)
            .replace("HISTORY2", history_text)
            .replace("POSSIBLELIST", actions_text)
        )

    # def generate_prompt(self, data):
    #     if_optimal_prompt_cot = self.if_optimal_prompt_cot
    #     action_advising_base_prompt_cot = self.action_advising_base_prompt_cot
    #
    #     preference_base_prompt_cot = self.preference_base_prompt_cot
    #
    #
    #     if self.feedback_type == "binary_feedback":
    #         final_string = if_optimal_prompt_cot.replace("HISTORY1", data["history"][0]).replace("HISTORY2", "None\n" if len(data["history"])==1 else join_list_of_strings(data["history"][1:])).replace("POSSIBLEACTION", join_list_of_strings(data["possible_actions"])).replace("ACTION", data["action"])
    #     if self.feedback_type == "action_advising":
    #         final_string = action_advising_base_prompt_cot.replace("HISTORY1", data["history"][0])   \
    #                         .replace("HISTORY2", "None\n" if len(data["history"])==1 else join_list_of_strings(data["history"][1:]))  \
    #                         .replace("POSSIBLEACTION", join_list_of_strings(data["possible_actions"]))
    #     if self.feedback_type == "preference":
    #         final_string = preference_base_prompt_cot.replace("HISTORY1", data["history"][0])   \
    #                         .replace("HISTORY2", "None\n" if len(data["history"])==1 else join_list_of_strings(data["history"][1:]))  \
    #                         .replace("POSSIBLEACTION", join_list_of_strings(data["possible_actions"])) \
    #                         .replace("ACTION1", data["action1"]) \
    #                         .replace("ACTION2", data["action2"])
    #     return final_string


    # # TODO code reuse among these three functions?
    # def verify_binary_feedback(self, response, data):
    #     response_to_numeric_dict = {
    #         "YES": 1,
    #         "NO": -1
    #     }
    #     ret_dict = {"State": data, "Response": response}
    #
    #     expert_feedback = data["feedback"]
    #     try:
    #         response = response.split("{")
    #         response = response[-1]
    #         response = "{" + response
    #         llm = json.loads(response)
    #     except Exception as e:
    #         print(e)
    #         ret_dict["Correct"] = False
    #         ret_dict["JSONCorrect"] = False
    #         return ret_dict
    #     llm_feedback = response_to_numeric_dict[llm["feedback"].upper()]
    #     ret_dict["JSONCorrect"] = True
    #     if llm_feedback == expert_feedback:
    #         ret_dict["Correct"] = True
    #     else:
    #         ret_dict["Correct"] = False
    #
    #     return ret_dict
    #
    # def verify_action_advising(self, response, data):
    #     action_to_number_dict = {
    #         "TURN LEFT": 0,
    #         "TURN RIGHT": 1,
    #         "MOVE FORWARD": 2,
    #         "PICK UP THE KEY": 3,
    #         "UNLOCK THE DOOR": 5
    #     }
    #     ret_dict = {"State": data, "Response": response}
    #
    #     expert_feedback = data["feedback"]
    #     try:
    #         response = response.split("{")
    #         response = response[-1]
    #         response = "{" + response
    #         llm = json.loads(response)
    #     except Exception as e:
    #         print(e)
    #         ret_dict["Correct"] = False
    #         ret_dict["JSONCorrect"] = False
    #         return ret_dict
    #     llm_feedback = llm["action"]
    #     # llm_feedback = action_to_number_dict[llm["action"].upper()]
    #     ret_dict["JSONCorrect"] = True
    #     # print(data)
    #     # print(llm_feedback, expert_feedback)
    #     if llm_feedback in expert_feedback:
    #         ret_dict["Correct"] = True
    #     else:
    #         ret_dict["Correct"] = False
    #
    #     return ret_dict
    #
    #
    # def verify_preference(self, response, data):
    #
    #     preference_to_number_dict = {
    #         "FIRST": 1,
    #         "SECOND": -1
    #     }
    #     ret_dict = {"State": data, "Response": response}
    #
    #     expert_feedback = data["feedback"]
    #     try:
    #         response = response.split("{")
    #         response = response[-1]
    #         response = "{" + response
    #         llm = json.loads(response)
    #     except Exception as e:
    #         print(e)
    #         ret_dict["Correct"] = False
    #         ret_dict["JSONCorrect"] = False
    #         return ret_dict
    #     llm_feedback = preference_to_number_dict[llm["preference"].upper()]
    #     ret_dict["JSONCorrect"] = True
    #     # print(data["state"]["expert_paths"])
    #     if llm_feedback == expert_feedback or expert_feedback == 0:
    #         ret_dict["Correct"] = True
    #     else:
    #         ret_dict["Correct"] = False
    #
    #     return ret_dict



if __name__ == "__main__":
    # Script entry point: run the verifier once per feedback type and data
    # distribution against a local Ollama chat endpoint.
    feedback_types = ["binary_feedback", "action_advising", "preference"]
    distributions = [1]
    # The loop variable is named feedback_type so the builtin `type` is not
    # shadowed at module level.
    for feedback_type in feedback_types:
        for distribution in distributions:
            data_path = PERSISTENT_DATA_PATH + "/ALF/ALF{type}_{distribution}.npy".format(
                type=feedback_type, distribution=distribution)
            # NOTE(review): "ALFWolrd" looks like a typo for "ALFWorld"; left
            # unchanged because the base class may depend on the exact
            # string -- confirm before fixing.
            v = ALFVerifier("ALFWolrd", feedback_type, data_path, [])
            results = v.verify(1, source="ollama", url="http://localhost:11434/api/chat", model="llama3.1:8b-instruct-fp16", max_item_num=2000)
