from __future__ import annotations

import importlib
import logging
import math
import os
import pathlib
import random

import numpy as np

from vtamp.environments.utils import Action, State
from vtamp.policies.utils import (
    Policy,
    guassian_rejection_sample,
    parse_code,
    query_llm,
)
from vtamp.utils import parse_text_prompt, save_log, write_prompt

_, _ = Action(), State()
log = logging.getLogger(__name__)

FUNC_NAME = "gen_plan"
FUNC_DOMAIN = "gen_domain"
ENGINE = "gpt-4-0125-preview"  # "gpt-4-turbo-2024-04-09" #"gpt-4-0125-preview"  #'gpt-3.5-turbo-instruct'


def import_constants_from_class(cls):
    # Get the module name from the class
    module_name = cls.__module__

    # Dynamically import the module
    module = importlib.import_module(module_name)

    # Import all uppercase attributes (assuming these are constants)
    for attribute_name in module.__all__:
        # Importing the attribute into the global namespace
        globals()[attribute_name] = getattr(module, attribute_name)
        print(f"Imported {attribute_name}: {globals()[attribute_name]}")


class CaP(Policy):
    def __init__(
        self,
        twin=None,
        max_feedbacks=0,
        seed=0,
        gaussian_blur=False,
        max_csp_samples=0,
        **kwargs,
    ):
        self.twin = twin
        self.seed = seed
        self.max_feedbacks = max_feedbacks
        self.gaussian_blur = gaussian_blur
        self.max_csp_samples = max_csp_samples

        import_constants_from_class(twin.__class__)

        # Get environment specific prompt
        prompt_fn = "prompt_{}".format(twin.__class__.__name__)
        prompt_path = os.path.join(
            pathlib.Path(__file__).parent, "{}.txt".format(prompt_fn)
        )

        self.prompt = parse_text_prompt(prompt_path)

        self.plan = None

    def get_action(self, belief, goal: str):
        if self.plan is None:
            # No plan yet, we need to come up with one
            content = "Goal: {}".format(goal)
            content = "initial={}\n".format(str(belief)) + content
            chat_history = self.prompt + [{"role": "user", "content": content}]
            llm_response = query_llm(chat_history, seed=0)
            write_prompt(f"llm_input_{iter}.txt", chat_history)
            chat_history.append({"role": "assistant", "content": llm_response})
            save_log(f"llm_output_{iter}.txt", llm_response)
            llm_code = parse_code(llm_response)
            exec(llm_code, globals())
            ground_plan_code = globals()["gen_plan"]
            ground_plan = ground_plan_code(belief)
            if self.gaussian_blur:
                blurred_plan = guassian_rejection_sample(
                    self.twin, ground_plan, max_attempts=self.max_csp_samples
                )
                if blurred_plan is not None:
                    self.plan = blurred_plan[1:]
                    return blurred_plan[0]

            self.plan = ground_plan[1:]
            return ground_plan[0]

        elif len(self.plan) > 0:
            next_action = self.plan[0]
            self.plan = self.plan[1:]
            return next_action

        return None
