from __future__ import annotations

import importlib
import logging
import math
import os
import pathlib
import random
import time
import traceback
from collections import Counter
from dataclasses import dataclass
from typing import Callable, List, Union
import matplotlib.pyplot as plt

import numpy as np

from vtamp.environments.utils import Action, Environment, State
from vtamp.policies.utils import (
    ContinuousSampler,
    DiscreteSampler,
    Policy,
    Sampler,
    parse_code,
    query_llm,
)
from vtamp.utils import (
    are_files_identical,
    get_log_dir,
    get_previous_log_folder,
    parse_text_prompt,
    read_file,
    save_log,
    write_prompt,
    encode_image_tob64,
)

# Instantiate Action and State once at import time and discard the results.
# NOTE(review): presumably this exists so the two imports count as "used" and
# stay in this module's globals, where the LLM-generated code is exec'd -- confirm.
_, _ = Action(), State()
log = logging.getLogger(__name__)


# Names looked up in this module's globals() after the LLM-generated code has
# been exec'd: the plan generator and the sampling-domain generator.
FUNC_NAME = "gen_plan"
FUNC_DOMAIN = "gen_domain"


def rejection_sample_csp(
    env: Environment,
    initial_state: State,
    plan_gen: Callable[..., List[Action]],
    domains_gen: Callable[[State], dict],
    max_attempts: int = 10000,
) -> tuple:
    """Randomly sample plan parameters and keep the cheapest rollout.

    Each attempt draws one value from every sampler returned by
    ``domains_gen(initial_state)``, grounds a plan with
    ``plan_gen(initial_state, **sampled_values)``, resets ``env``, executes
    the plan step by step and reads ``env.compute_cost()``.  Sampling stops
    as soon as a rollout's cost drops below ``-1``; otherwise it runs for
    ``max_attempts`` attempts and the cheapest plan seen is returned.

    Args:
        env: Environment used for the rollouts; it is reset before every
            attempt.
        initial_state: State passed to both ``domains_gen`` and ``plan_gen``.
        plan_gen: Callable producing the list of actions for one sample.
        domains_gen: Callable returning a mapping from parameter name to an
            object with a ``sample()`` method.
        max_attempts: Maximum number of samples to draw.

    Returns:
        A 5-tuple ``(plan, failure_modes, last_iter, cost_history, best_x)``:

        * ``plan`` -- cheapest plan found, or ``None`` if no rollout was
          recorded (no attempts were made or no cost was comparable).
        * ``failure_modes`` -- ``None`` when a plan is returned, otherwise
          the ``Counter`` of violation modes (nothing populates it here, so
          it is empty).
        * ``last_iter`` -- zero-based index of the last attempt that ran
          (``0`` when no attempt ran).
        * ``cost_history`` -- cost of every attempt, in order.
        * ``best_x`` -- sampled parameter dict that produced ``plan``, or
          ``None`` when ``plan`` is ``None``.
    """
    # NOTE(review): a cost below -1 is treated as "solved"; presumably this
    # follows the environment's cost convention -- confirm.
    solved_cost = -1

    lowest_cost = float("inf")
    cost_history = []
    violation_modes = Counter()
    best_plan = None
    best_x = None
    last_iter = 0

    for i in range(max_attempts):
        last_iter = i
        log.info(f"CSP Sampling iter {i}")
        domains = domains_gen(initial_state)
        input_vec = {name: domain.sample() for name, domain in domains.items()}
        _ = env.reset()
        ground_plan = plan_gen(initial_state, **input_vec)

        # Roll out the entire action plan before scoring it.
        for action in ground_plan:
            env.step(action, vis=False)
        cost = env.compute_cost()
        cost_history.append(cost)

        if cost < lowest_cost:
            lowest_cost = cost
            best_plan = ground_plan
            best_x = input_vec

        # Terminate early if this sample already satisfies the target.
        if cost < solved_cost:
            print(f"Solved problem at iter {i}. Cost {cost}")
            return best_plan, None, i, cost_history, best_x

        print(f"Finished iter {i}. Cost {cost}, (lowest cost {lowest_cost})")

    # Budget exhausted: hand back the best plan seen so the caller can decide
    # whether it is good enough.
    if best_plan is None:
        return None, violation_modes, last_iter, cost_history, best_x
    return best_plan, None, last_iter, cost_history, best_x


def import_constants_from_class(cls):
    """Copy names exported by ``cls``'s defining module into this module's globals.

    The module is located through ``cls.__module__``.  Every name listed in
    the module's ``__all__`` is copied.  When the module defines no
    ``__all__``, the function falls back to its public all-uppercase
    attributes (the usual spelling of constants) instead of failing.

    Each copied name is printed together with its value.

    Args:
        cls: Class whose defining module supplies the names.
    """
    module = importlib.import_module(cls.__module__)

    exported = getattr(module, "__all__", None)
    if exported is None:
        # No explicit export list: take public UPPER_CASE names only.
        exported = [
            name
            for name in vars(module)
            if name.isupper() and not name.startswith("_")
        ]

    namespace = globals()
    for attribute_name in exported:
        # Bind the attribute in this module's global namespace.
        namespace[attribute_name] = getattr(module, attribute_name)
        print(f"Imported {attribute_name}: {namespace[attribute_name]}")


class Proc3s(Policy):
    """Policy that asks an LLM for a parameterised plan and samples its parameters.

    The LLM response is expected to contain code defining the functions named
    by ``FUNC_NAME`` and ``FUNC_DOMAIN``.  That code is exec'd, its parameters
    are sampled and rolled out in the ``twin`` environment via
    ``rejection_sample_csp``, and the resulting plan is replayed one action
    per ``get_action`` call.  When the best plan is still too costly, or the
    generated code raises, feedback is appended to the conversation and the
    LLM is queried again, up to ``max_feedbacks`` extra times.
    """

    def __init__(
        self,
        cost_thresh=1,
        twin=None,
        max_feedbacks=0,
        seed=0,
        max_csp_samples=10000,
        use_cache=False,
        **kwargs,
    ):
        """Store the configuration and load the environment-specific prompt.

        Args:
            cost_thresh: A plan is accepted when its cost in ``twin`` is at
                most this value.
            twin: Environment used to roll out and score candidate plans.
                Its class name selects the prompt file
                ``prompt_<ClassName>.txt`` next to this module, and the names
                exported by its module are imported into this module's
                globals.
            max_feedbacks: Number of additional LLM queries allowed after the
                first one.
            seed: Seed forwarded to ``query_llm``.
            max_csp_samples: Sample budget forwarded to
                ``rejection_sample_csp``.
            use_cache: When true, reuse the LLM output of the previous log
                folder if its input file is identical to the current one.
            **kwargs: Accepted and ignored.
        """
        self.cost_thresh = cost_thresh
        self.twin = twin
        self.seed = seed
        self.max_feedbacks = max_feedbacks
        self.max_csp_samples = max_csp_samples

        self.use_cache = use_cache

        # Make the environment module's exported names visible to the
        # LLM-generated code, which is exec'd in this module's globals.
        import_constants_from_class(twin.__class__)

        # Get environment specific prompt
        prompt_fn = "prompt_{}".format(twin.__class__.__name__)
        prompt_path = os.path.join(
            pathlib.Path(__file__).parent, "{}.txt".format(prompt_fn)
        )

        self.prompt = parse_text_prompt(prompt_path)

        # Remaining actions of the current plan; None until a plan was found.
        self.plan = None

    def get_action(self, belief, goal: str, llm_out: str = None):
        """Return the next action of the plan, planning first if necessary.

        Args:
            belief: Current state estimate, forwarded to the planner.
            goal: Natural-language goal description.
            llm_out: Optional path to a file whose content is used in place
                of a live LLM query.

        Returns:
            ``(action, statistics)``.  ``action`` is ``None`` when planning
            failed or the stored plan is exhausted.  ``statistics`` is only
            populated on the call that performs planning.
        """
        statistics = {}
        if self.plan is None:
            # No plan yet, we need to come up with one
            ground_plan, statistics = self.full_query_csp(belief, goal, llm_out)
            if ground_plan is None or len(ground_plan) == 0:
                return None, statistics

            log.info("Found plan: {}".format(ground_plan))
            self.plan = ground_plan[1:]
            return ground_plan[0], statistics

        if len(self.plan) > 0:
            next_action, self.plan = self.plan[0], self.plan[1:]
            return next_action, statistics

        return None, statistics

    def _fetch_llm_response(self, chat_history, feedback_round, input_fn, output_fn, llm_out):
        """Return ``(llm_response, llm_query_time)`` for the current round.

        The response comes from the previous run's log folder (cache hit),
        from the file ``llm_out`` when given, or from a live ``query_llm``
        call.  Only the live query reports a non-zero query time.
        """
        # Check if the inputs match
        parent_log_folder = os.path.join(get_log_dir(), "..")
        previous_folder = get_previous_log_folder(parent_log_folder)
        cache_hit = (
            self.use_cache
            and os.path.isfile(os.path.join(previous_folder, output_fn))
            and are_files_identical(
                os.path.join(previous_folder, input_fn),
                os.path.join(get_log_dir(), input_fn),
            )
        )
        if cache_hit:
            log.info("Loading cached LLM response")
            return read_file(os.path.join(previous_folder, output_fn)), 0

        log.info("Querying LLM")
        if llm_out is None:
            llm_response, llm_query_time = query_llm(chat_history, seed=self.seed)
            return llm_response, llm_query_time

        # Canned response: read it with a context manager so the handle is
        # closed deterministically.
        with open(llm_out, "r") as response_file:
            llm_response = response_file.read()
        if feedback_round:
            llm_response = f"Hmm, it seems like that didn't work. Here's the same thing lmao.\n{llm_response}"
        return llm_response, 0

    def _build_cost_feedback(self, feedback_round, best_x, cost, state):
        """Build the user message reporting that the best plan is too costly.

        Reads the previously saved ``feedback.png`` and attaches it to the
        message together with the best parameters, their cost and the final
        state.
        """
        # Feedback for plan cost being too high (csp_final_cost >= cost_threshold)
        log.warning(f"Feedback {feedback_round}: Plan cost {cost:.4f} >= {self.cost_thresh}")

        img = encode_image_tob64("feedback.png")

        feedback_str = (f"The best parameters that were found based on your solution are {best_x} "
                        f"and have a cost of {cost}, which is above the target cost of <= {self.cost_thresh} "
                        f"Please revise your solution accordingly."
                        f"The final state after running the best solution is: {state}.")

        # NOTE(review): the image is saved as PNG but labelled image/jpeg
        # below; whether encode_image_tob64 re-encodes is not visible here --
        # confirm the MIME type.
        return {
            "role": "user",
            "content": [
                {"type": "input_text", "text": f"{feedback_str} Image of final state after attempt {feedback_round}"},
                {
                    "type": "input_image",
                    "image_url": f"data:image/jpeg;base64,{img}",
                },
            ],
        }

    def full_query_csp(self, belief, task, llm_out):
        """Query the LLM for plan code and sample it until a plan is cheap enough.

        Every round writes the prompt to ``llm_input_<round>.txt``, obtains a
        response, saves it to ``llm_output_<round>.txt``, exec's the code it
        contains and runs ``rejection_sample_csp`` on the resulting
        functions.  The best plan is replayed in ``self.twin``; if its cost
        is at most ``self.cost_thresh`` it is returned immediately.
        Otherwise a feedback message (cost feedback with an image, the
        sampler's failure summary, or the traceback of a code error) is
        appended to the conversation for the next round.

        Args:
            belief: Current state estimate; included in the prompt and passed
                to the sampler as the initial state.
            task: Goal description included in the prompt.
            llm_out: Optional path to a file used instead of querying the LLM.

        Returns:
            ``(ground_plan, statistics)``.  On early success the plan is the
            accepted one.  After all rounds are used up, the plan of the last
            round is returned (``None`` if that round failed), the
            conversation is written to ``chat_log.json`` and stored under
            ``statistics["chat_history"]``.
        """
        _ = self.twin.reset()
        content = "Goal: {}".format(task)
        content = "State: {}\n".format(str(belief)) + content
        chat_history = self.prompt + [{"role": "user", "content": content}]
        statistics = {
            "csp_samples": 0,
            "csp_solve_time": 0,
            "llm_query_time": 0,
        }

        # Defined up front so the final return is valid even if no round runs.
        ground_plan = None

        for feedback_round in range(self.max_feedbacks + 1):
            statistics["num_feedbacks"] = feedback_round
            input_fn = f"llm_input_{feedback_round}.txt"
            output_fn = f"llm_output_{feedback_round}.txt"
            write_prompt(input_fn, chat_history)

            llm_response, llm_query_time = self._fetch_llm_response(
                chat_history, feedback_round, input_fn, output_fn, llm_out
            )
            statistics["llm_query_time"] += llm_query_time

            chat_history.append({"role": "assistant", "content": llm_response})
            save_log(output_fn, llm_response)

            ground_plan = None

            try:
                llm_code = parse_code(llm_response)
                # SECURITY: this executes LLM-generated code with full access
                # to this module's globals. Only run against trusted models
                # or inside a sandbox.
                exec(llm_code, globals())
                func = globals()[FUNC_NAME]
                domain = globals()[FUNC_DOMAIN]
                st = time.time()
                (
                    ground_plan,
                    failure_message,
                    csp_samples,
                    cost_history,
                    best_x,
                ) = rejection_sample_csp(
                    self.twin,
                    belief,
                    func,
                    domain,
                    max_attempts=self.max_csp_samples,
                )
                statistics[f"cost_history_{feedback_round}"] = cost_history
                statistics["csp_samples"] += csp_samples
                statistics["csp_solve_time"] += time.time() - st
                # NOTE(review): best_x is a dict of sampled values, so list()
                # keeps only the parameter names -- confirm that is intended.
                statistics["best_x"] = list(best_x) if best_x is not None else []

                # Evaluate the generated plan
                self.twin.reset()

                if ground_plan is None:
                    # The sampler found nothing usable: report its most
                    # common failure modes, if it recorded any.
                    common_failures = (
                        failure_message.most_common(2) if failure_message else []
                    )
                    failure_response = "".join(
                        f"{count} occurences: {fm}\n" for fm, count in common_failures
                    )
                    if not failure_response:
                        # Avoid sending an empty feedback turn.
                        failure_response = "No feasible plan was found.\n"

                    save_log(f"feedback_output_{feedback_round}.txt", failure_response)
                    chat_history.append({"role": "user", "content": failure_response})
                    continue

                for action in ground_plan:
                    self.twin.step(action, vis=False)
                cost = self.twin.compute_cost()
                state = self.twin.getState()
                img = self.twin.render(False)
                self.twin.reset()
                plt.imsave("feedback.png", img)

                if cost <= self.cost_thresh:
                    return ground_plan, statistics

                chat_history.append(
                    self._build_cost_feedback(feedback_round, best_x, cost, state)
                )

            except Exception as e:
                # Get the traceback as a string
                print(e)
                error_message = traceback.format_exc()
                log.info("Code error: " + str(error_message))

                # Feed the error back to the LLM so the next round can fix
                # it; without this the conversation would carry no hint of
                # what went wrong.
                save_log(f"feedback_output_{feedback_round}.txt", error_message)
                chat_history.append({"role": "user", "content": error_message})

        import json

        with open("chat_log.json", "w") as chat_log_file:
            json.dump(chat_history, chat_log_file)
        statistics["chat_history"] = chat_history

        return ground_plan, statistics