from __future__ import annotations

import importlib
import logging
import os
import pathlib

from vtamp.environments.utils import Action, State
from vtamp.policies.utils import (
    Policy,
    guassian_rejection_sample,
    parse_code,
    query_llm,
)
from vtamp.utils import parse_text_prompt, save_log, write_prompt

# Instantiate once so these imports are referenced (presumably to keep linters
# from flagging them as unused). Keeping Action/State in module globals also
# presumably makes them available to the LLM-generated code, which LLM3
# exec's against this module's globals().
_, _ = Action(), State()
log = logging.getLogger(__name__)

# Name of the variable the LLM-generated code binds its plan to; LLM3 looks up
# this same name ("gen_plan") after exec'ing the code.
FUNC_NAME = "gen_plan"
# NOTE(review): not read anywhere in this module — confirm it is still used.
FUNC_DOMAIN = "gen_domain"
# LLM model identifier.
# NOTE(review): not passed to query_llm in this module — confirm whether it is
# still consumed elsewhere.
ENGINE = "gpt-4-0125-preview"


def import_constants_from_class(cls):
    """Copy the public constants of ``cls``'s defining module into this module.

    The values are written into this module's ``globals()``. LLM-generated code
    is exec'd against these same globals, so it can refer to environment
    constants by their bare names.

    If the defining module declares ``__all__``, exactly those names are copied.
    Otherwise every public, all-uppercase attribute (treated as a constant) is
    copied. Previously a module without ``__all__`` raised ``AttributeError``.

    Args:
        cls: Any class; only its ``__module__`` is used.

    Returns:
        list[str]: The names that were injected, in the order they were copied.
    """
    module = importlib.import_module(cls.__module__)

    names = getattr(module, "__all__", None)
    if names is None:
        # Fall back to the naming convention for constants: UPPER_CASE and public.
        names = [
            name
            for name in dir(module)
            if name.isupper() and not name.startswith("_")
        ]

    namespace = globals()
    # Same logger object as the module-level `log` (getLogger is keyed by name).
    logger = logging.getLogger(__name__)
    for name in names:
        value = getattr(module, name)
        namespace[name] = value
        logger.info("Imported %s: %s", name, value)
    return list(names)


class LLM3(Policy):
    """Policy that asks an LLM for a complete plan and then replays it.

    On the first ``get_action`` call the LLM is prompted with the current
    belief and goal. Its reply is parsed into code and executed in this
    module's namespace. That code must bind a non-empty list of actions to
    ``gen_plan``. The plan is rolled out on ``twin``. If a step reports
    constraint violations and feedback rounds remain, one of two things
    happens. Either the plan is repaired with Gaussian rejection sampling
    (when ``gaussian_blur`` is set and sampling succeeds), or the violation is
    sent back to the LLM for another attempt. Later calls hand out the cached
    plan one action at a time.
    """

    # Variable the exec'd LLM code must bind its plan to (same value as the
    # module-level FUNC_NAME).
    _PLAN_VAR = "gen_plan"

    def __init__(
        self,
        twin=None,
        max_feedbacks=0,
        gaussian_blur=False,
        seed=0,
        max_csp_samples=0,
        **kwargs,
    ):
        """Configure the policy for a given simulation twin.

        Args:
            twin: Environment instance used to simulate candidate plans. It is
                required despite the ``None`` default. Its class name selects
                the prompt file ``prompt_<ClassName>.txt`` next to this module.
                The constants of its class's module are injected into this
                module's globals for the LLM-generated code.
            max_feedbacks: Number of extra LLM attempts allowed after the
                first one.
            gaussian_blur: If True, try to repair a violating plan with
                ``guassian_rejection_sample`` before asking the LLM again.
            seed: Seed forwarded to ``query_llm``.
            max_csp_samples: ``max_attempts`` passed to the rejection sampler.
            **kwargs: Accepted and ignored.

        Raises:
            ValueError: If ``twin`` is None.
        """
        if twin is None:
            raise ValueError("LLM3 requires a `twin` environment to simulate plans")

        self.twin = twin
        self.max_csp_samples = max_csp_samples
        self.gaussian_blur = gaussian_blur
        self.seed = seed
        self.max_feedbacks = max_feedbacks

        # Expose the environment module's constants to the exec'd LLM code.
        import_constants_from_class(twin.__class__)

        # Environment-specific prompt: prompt_<TwinClassName>.txt beside this file.
        prompt_path = os.path.join(
            pathlib.Path(__file__).parent,
            "prompt_{}.txt".format(twin.__class__.__name__),
        )
        self.prompt = parse_text_prompt(prompt_path)

        # Remaining actions to hand out; None until a plan has been accepted.
        self.plan = None

    def get_action(self, belief, goal: str):
        """Return the next action to execute, or None when there is none.

        While no plan is cached, this queries the LLM (with feedback rounds)
        for a new plan and returns its first action. Otherwise it pops the
        next action from the cached plan.

        Args:
            belief: Current belief/state; rendered with ``str()`` into the
                prompt.
            goal: Goal description inserted into the prompt.

        Returns:
            The next action, or None if the cached plan is exhausted or no
            usable plan was obtained. In the latter case the cache stays
            None, so the next call queries the LLM again.
        """
        if self.plan is not None:
            if len(self.plan) == 0:
                return None
            next_action = self.plan[0]
            self.plan = self.plan[1:]
            return next_action
        return self._generate_and_start_plan(belief, goal)

    def _generate_and_start_plan(self, belief, goal):
        """Run the query/simulate/feedback loop; cache the plan, return its head."""
        content = "State: {}\n".format(str(belief)) + "Goal: {}".format(goal)
        # `+` builds a new list, so the appends below never mutate self.prompt.
        chat_history = self.prompt + [{"role": "user", "content": content}]

        for attempt in range(self.max_feedbacks + 1):
            can_retry = attempt < self.max_feedbacks
            ground_plan = self._query_plan(chat_history, attempt)

            # Validate before simulating: a missing/None plan used to crash
            # the rollout, and an empty one silently wasted an attempt.
            if not (isinstance(ground_plan, list) and len(ground_plan) > 0):
                if can_retry:
                    self._send_feedback(
                        chat_history,
                        attempt,
                        "{} must be a non-empty list of actions.\n".format(
                            self._PLAN_VAR
                        ),
                    )
                continue

            violation_str = self._first_violation(ground_plan)
            if violation_str and can_retry:
                if self.gaussian_blur:
                    blurred_plan = guassian_rejection_sample(
                        self.twin, ground_plan, max_attempts=self.max_csp_samples
                    )
                    if blurred_plan is not None:
                        self.plan = blurred_plan[1:]
                        return blurred_plan[0]
                self._send_feedback(chat_history, attempt, violation_str)
                continue

            # Either the rollout was clean, or this was the last attempt. In
            # the latter case the plan is executed despite its violations.
            self.plan = ground_plan[1:]
            return ground_plan[0]

        return None

    def _query_plan(self, chat_history, attempt):
        """Ask the LLM for plan code, execute it, and return the bound plan.

        Appends the assistant reply to ``chat_history`` in place and logs the
        exchange. Returns whatever the executed code bound to ``gen_plan``,
        or None if it bound nothing.
        """
        llm_response = query_llm(chat_history, seed=self.seed)
        write_prompt(f"llm_input_{attempt}.txt", chat_history)
        chat_history.append({"role": "assistant", "content": llm_response})
        save_log(f"llm_output_{attempt}.txt", llm_response)
        llm_code = parse_code(llm_response)

        namespace = globals()
        # Drop any plan left over from an earlier exec. Otherwise a response
        # that fails to define one would silently reuse the stale plan.
        namespace.pop(self._PLAN_VAR, None)
        # SECURITY: this executes model-generated code with this module's
        # globals (so it can see the injected environment constants). It can
        # do anything the process can; only run it in a trusted or sandboxed
        # setting.
        exec(llm_code, namespace)
        return namespace.get(self._PLAN_VAR)

    def _first_violation(self, plan):
        """Roll ``plan`` out on the twin from reset and describe its first violation.

        Returns an empty string when no step reports a constraint violation.
        """
        self.twin.reset()
        for step_idx, action in enumerate(plan):
            _, _, _, info = self.twin.step(action)
            violations = info["constraint_violations"]
            if len(violations) > 0:
                return "Step {}, Action {}, Violations: {}\n".format(
                    step_idx, action.name, violations
                )
        return ""

    def _send_feedback(self, chat_history, attempt, message):
        """Log a failure message and append it to ``chat_history`` as a user turn."""
        failure_response = "Failure: {}".format(str(message))
        save_log(f"feedback_output_{attempt}.txt", failure_response)
        chat_history.append({"role": "user", "content": failure_response})
