from __future__ import annotations

import os
import cma
import json
import pathlib
import logging
import numpy as np
import importlib
import traceback
from typing import Callable, List, Union
import matplotlib.pyplot as plt

# Raise the matplotlib font-manager logger to WARNING so its DEBUG/INFO records
# are not emitted.
logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING)

from vtamp.environments.utils import Action, Environment, State
from vtamp.policies.utils import (
    Policy,
    Sampler,
    parse_code,
    query_llm,
    encode_image_tob64
)
from vtamp.utils import parse_text_prompt, save_log, write_prompt

# Instantiate the imported Action/State types once. Presumably this keeps the
# imports "used" so tooling does not strip them; LLM-generated code is exec'd
# against this module's globals() below and likely references these names.
# TODO(review): confirm intent.
_, _ = Action(), State()
log = logging.getLogger(__name__)

# Names of the functions that the exec'd LLM code is expected to define:
# FUNC_NAME builds the ground plan; FUNC_DOMAIN (optional) returns a dict whose
# values are the initial guesses for the plan parameters that get optimized.
FUNC_NAME = "gen_plan"
FUNC_DOMAIN = "gen_initial_guess"

def reshape_like(template: list, flat_list: list) -> list:
    """Rebuild the nested-list structure of ``template`` from ``flat_list``.

    Every non-list leaf of ``template`` is replaced, in depth-first order, by
    the next value taken from ``flat_list``; the leaf values of ``template``
    themselves are ignored. Surplus values in ``flat_list`` are ignored.

    Args:
        template: Arbitrarily nested list whose shape is reproduced. A non-list
            template is a single leaf.
        flat_list: Values for the leaves. Any iterable works (list, tuple,
            1-D numpy array, generator). It is not modified.

    Returns:
        A new nested list shaped like ``template`` (or a single value when
        ``template`` is not a list).

    Raises:
        IndexError: If ``flat_list`` has fewer values than ``template`` has leaves.
    """
    # A shared iterator gives O(n) consumption; the previous list.pop(0) was O(n^2).
    values = iter(flat_list)

    def fill(node):
        if isinstance(node, list):
            return [fill(child) for child in node]
        try:
            return next(values)
        except StopIteration:
            # Preserve the historical contract (pop on an empty list -> IndexError).
            raise IndexError(
                "flat_list has fewer values than template has leaves"
            ) from None

    return fill(template)

def flatten(nested: list) -> list:
    """Return the leaves of an arbitrarily nested list in depth-first order.

    Only ``list`` instances are descended into; any other object (tuples,
    numpy arrays, strings, ...) is emitted as a single leaf.
    """
    def _walk(node):
        for element in node:
            if not isinstance(element, list):
                yield element
                continue
            yield from _walk(element)

    return list(_walk(nested))


def _hill_climb(cost_fn, x0, max_evals, step_size=0.05, decay=0.9, min_step=1e-4):
    """Minimize ``cost_fn`` by greedy Gaussian-perturbation hill climbing.

    Evaluates ``cost_fn(x0)`` once, then tries ``max_evals`` candidates
    ``x_best + N(0, step_size)``. A candidate is kept only if it strictly
    lowers the cost. Every non-improving step multiplies ``step_size`` by
    ``decay``. The decay is undone when the step would drop below ``min_step``,
    so the search keeps exploring instead of freezing. Uses NumPy's global RNG.

    Returns:
        ``(best_x, best_y)``. ``best_x`` is ``x0`` itself if nothing improved,
        otherwise a 1-D numpy array.
    """
    print(f"Running Probabilistic Hill Climbing with {max_evals} evaluations")
    x_best = x0
    y_best = cost_fn(x0)

    for _ in range(max_evals):
        perturbation = np.random.normal(loc=0.0, scale=step_size, size=len(x_best))
        x_new = np.asarray(x_best, dtype=float) + perturbation
        y_new = cost_fn(x_new)

        if y_new < y_best:
            x_best, y_best = x_new, y_new
            print(f"Improved with random step; New cost: {np.round(y_best, 3)}")
        else:
            step_size *= decay  # decay step size slowly
            print(f"No improvement, reducing step size to {step_size:.5f}")
            if step_size < min_step:
                step_size /= decay
                print("Step size too small, not stopping tho.")

    return x_best, y_best


def bbo_on_motion_plan(
    env,
    initial_state,
    plan_gen: Callable[..., List],
    domains_gen: Callable,
    use_komo: bool = False,
    max_evals: int = 1000,
    optimizer: str = "hill_climb",  # or "cma"
) -> tuple[List, str, List, object, List[float]]:
    """Tune the free parameters of a generated plan by black-box optimization.

    ``domains_gen(initial_state)`` returns a dict whose values are the initial
    guesses for the parameters of ``plan_gen``. The values (possibly nested
    lists) are flattened into one vector and optimized against
    ``env.compute_cost()``. Before each ``plan_gen(initial_state, *params)``
    call they are reshaped back to the original structure.

    Args:
        env: Simulation twin. Used via ``reset``, ``step`` / ``step_komo``,
            ``compute_cost`` and its ``path`` attribute.
        initial_state: First argument passed to both generators.
        plan_gen: Builds a ground plan from ``initial_state`` and the parameters.
        domains_gen: Builds the ``{name: initial_value}`` parameter dict.
        use_komo: Execute the whole plan with ``env.step_komo`` instead of
            stepping action by action.
        max_evals: Cost-function evaluation budget for the optimizer.
        optimizer: ``"hill_climb"`` or ``"cma"``.

    Returns:
        ``(ground_plan, failure_message, best_x, best_q, cost_history)``.
        ``best_x`` is the flat best parameter list. ``best_q`` is ``env.path``
        captured after the most recent evaluation that matched the lowest cost
        seen so far, or ``None`` if nothing was evaluated. ``cost_history``
        lists every evaluated cost.

    Raises:
        ValueError: If ``optimizer`` is not a supported name (checked up front).
    """
    # Fail fast, before any simulation work is spent.
    if optimizer not in ("hill_climb", "cma"):
        raise ValueError(f"Unknown optimizer: {optimizer}. Use 'hill_climb' or 'cma'.")

    failure_message = ""

    domains = domains_gen(initial_state)
    initial_x = list(domains.values())
    flat_init = flatten(initial_x)

    evals_infeasible = []
    evals_feasible = []
    cost_history = []
    best_qs = []

    def run_plan(ground_plan) -> None:
        # Execute a plan on the (already reset) env, identically for every caller.
        if use_komo:
            env.step_komo(ground_plan, vis=False)
        else:
            for action in ground_plan:
                env.step(action, vis=False)

    def compute_cost(input_vec) -> float:
        input_vec_reshaped = reshape_like(initial_x, list(input_vec))
        env.reset()
        run_plan(plan_gen(initial_state, *input_vec_reshaped))
        cost = env.compute_cost()
        cost_history.append(cost)

        # cost is already in cost_history, so this holds exactly when it ties the best so far.
        if min(cost_history) >= cost:
            best_qs.append(env.path)

        # NOTE(review): 100 is assumed to be the infeasibility penalty level of
        # env.compute_cost - confirm against the environment.
        if cost >= 100:
            evals_infeasible.append(input_vec)
        else:
            evals_feasible.append(input_vec)
        print(f"Input vec {np.round(input_vec, 3)}; Cost {np.round(cost, 3)}, Iter: {len(evals_feasible)+len(evals_infeasible)}")
        return cost

    # Run the initial guess once so plan_gen errors surface before optimizing.
    print("Rendering initial plan")
    env.reset()
    run_plan(plan_gen(initial_state, *initial_x))

    if optimizer == "hill_climb":
        best_x, best_y = _hill_climb(compute_cost, flat_init, max_evals)
    else:
        print(f"Running CMA-ES with {max_evals} evaluations")
        bbo_options = {
            "maxfevals": max_evals,
            "ftarget": 0,
            "CMA_active": True,
        }
        best_x, es = cma.fmin2(compute_cost, x0=flat_init, sigma0=0.3, options=bbo_options)
        best_y = es.result.fbest

    # Convert best_x to a list if it's a numpy array
    if isinstance(best_x, np.ndarray):
        best_x = best_x.tolist()

    # Generate final plan
    ground_plan = plan_gen(initial_state, *reshape_like(initial_x, best_x))

    # Print summary information
    print("Optimization completed")
    print(f"Total evaluations: {len(cost_history)}")
    print(f"Best solution: {best_x}")
    print(f"Best cost: {best_y}")

    best_q = best_qs[-1] if best_qs else None
    return ground_plan, failure_message, best_x, best_q, cost_history


def import_constants_from_class(cls):
    """Copy the public names of the module that defines ``cls`` into this module.

    Every name listed in that module's ``__all__`` is bound in this module's
    global namespace (overwriting existing bindings) and echoed to stdout.
    Raises ``AttributeError`` if the defining module has no ``__all__``.
    """
    owner = importlib.import_module(cls.__module__)
    namespace = globals()
    for public_name in owner.__all__:
        namespace[public_name] = getattr(owner, public_name)
        print(f"Imported {public_name}: {namespace[public_name]}")


class DENECK(Policy):
    """LLM planning policy whose plan parameters are refined by black-box optimization.

    The LLM is prompted (with a prompt file chosen by the twin's class name) to
    write ``gen_plan`` and, optionally, ``gen_initial_guess``. When the latter
    exists, the plan parameters are optimized on ``twin`` via
    ``bbo_on_motion_plan``. The resulting plan is re-simulated and accepted
    if its cost is ``<= cost_thresh``. Otherwise text + image feedback is
    appended to the chat and the LLM is queried again, up to
    ``max_feedbacks`` extra times.
    """

    def __init__(
        self,
        cost_thresh=1,
        twin=None,
        max_feedbacks=0,
        seed=0,
        max_evals=1000,
        use_cache=True,
        use_komo=False,
        optim="cma",
        **kwargs,
    ):
        """Store the configuration and load the twin-specific prompt.

        Args:
            cost_thresh: Maximum simulated plan cost accepted as success.
            twin: Simulator used for optimization and plan evaluation. Needed in
                practice: its class selects the prompt file and the module whose
                ``__all__`` names are injected into this module's globals.
            max_feedbacks: Number of extra LLM queries after a failed attempt.
            seed: Seed forwarded to ``query_llm``.
            max_evals: Evaluation budget passed to ``bbo_on_motion_plan``.
            use_cache: Stored only; not read anywhere in this class.
            use_komo: Execute/return whole plans via KOMO instead of per action.
            optim: Optimizer name passed to ``bbo_on_motion_plan``.
            **kwargs: Accepted and ignored.
        """
        self.twin = twin
        self.seed = seed
        self.cost_thresh = cost_thresh
        self.max_feedbacks = max_feedbacks
        self.max_evals = max_evals
        self.use_komo = use_komo
        self.use_cache = use_cache
        self.optim = optim

        # Make the twin module's public constants visible to exec'd LLM code.
        import_constants_from_class(twin.__class__)

        prompt_fn = "prompt_{}".format(twin.__class__.__name__)
        prompt_path = os.path.join(pathlib.Path(__file__).parent, "{}.txt".format(prompt_fn))
        self.prompt = parse_text_prompt(prompt_path)
        self.plan = None

    def get_action(self, belief, goal: str, llm_out: str = None):
        """Return ``(action, statistics)`` for the current step.

        The first call plans via ``full_query_bbo`` and caches the plan. In
        KOMO mode the whole plan is returned (on every call); otherwise actions
        are popped one at a time. ``(None, statistics)`` means no plan was
        found or the cached plan is exhausted.
        """
        statistics = {}
        if self.plan is None:
            ground_plan, statistics = self.full_query_bbo(belief, goal, llm_out)
            if ground_plan is None or not len(ground_plan):
                return None, statistics
            elif self.use_komo:
                self.plan = ground_plan
                return self.plan, statistics
            else:
                self.plan = ground_plan[1:]
                return ground_plan[0], statistics
        elif self.use_komo:
            return self.plan, statistics
        elif len(self.plan) > 0:
            next_action = self.plan[0]
            self.plan = self.plan[1:]
            return next_action, statistics
        return None, statistics

    def _simulate_plan_cost(self, ground_plan):
        """Simulate ``ground_plan`` on the twin and return ``(cost, final_state)``.

        Also writes a render of the final state to ``feedback.png`` and resets
        the twin afterwards.

        Raises:
            ValueError: If ``ground_plan`` is ``None``.
        """
        if ground_plan is None:
            raise ValueError(f"{FUNC_NAME} returned no plan")
        self.twin.reset()
        # Execute the same way bbo_on_motion_plan does, so the costs agree.
        if self.use_komo:
            self.twin.step_komo(ground_plan, vis=False)
        else:
            for action in ground_plan:
                self.twin.step(action, vis=False)
        cost = self.twin.compute_cost()
        state = self.twin.getState()
        img = self.twin.render(False)
        self.twin.reset()
        plt.imsave("feedback.png", img)
        return cost, state

    def _cost_feedback_message(self, attempts, best_x, cost, state):
        """Build the user message (text + image) reporting a too-costly plan.

        ``best_x`` is ``None`` when no parameters were optimized. The image is
        read from ``feedback.png``.
        """
        img = encode_image_tob64("feedback.png")

        if best_x is None:
            opening = f"The plan generated by your solution has a cost of {cost}, "
            final_subject = "your solution"
        else:
            opening = (f"The best parameters that were found based on your solution are {best_x} "
                       f"and have a cost of {cost}, ")
            final_subject = "the best solution"

        feedback_str = (f"{opening}which is above the target cost of <= {self.cost_thresh}. "
                        f"Please revise your solution accordingly. "
                        f"The final state after running {final_subject} is: {state}.")

        return {
            "role": "user",
            "content": [
                {"type": "input_text", "text": f"{feedback_str} Image of final state after attempt {attempts}"},
                {
                    "type": "input_image",
                    # NOTE(review): feedback.png is written by plt.imsave as PNG;
                    # unless encode_image_tob64 re-encodes to JPEG this should be image/png.
                    "image_url": f"data:image/jpeg;base64,{img}",
                },
            ],
        }

    def full_query_bbo(self, belief, task, llm_out: str = None):
        """Query the LLM for plan code, optimize its parameters and iterate on feedback.

        Args:
            belief: Current state; rendered into the prompt and passed to the
                generated functions.
            task: Goal description inserted into the prompt.
            llm_out: Optional path to a file whose contents replace the LLM
                response (replay of a saved response).

        Returns:
            ``(ground_plan, statistics)``. ``ground_plan`` is the first plan
            whose simulated cost is ``<= self.cost_thresh``. If no attempt
            succeeds it is the most recently generated plan (or ``None``).
        """
        self.twin.reset()
        content = "initial={}\nGoal: {}".format(str(belief), task)
        chat_history = self.prompt + [{"role": "user", "content": content}]
        statistics = {
            "bbo_evals": 0,
            "bbo_solve_time": 0,
            "llm_query_time": 0,
            "num_bbo_evals": 0,
        }
        ground_plan = None

        # One initial attempt plus up to ``max_feedbacks`` feedback-driven retries.
        for attempts in range(self.max_feedbacks + 1):
            input_fn = f"llm_input_{attempts}.txt"
            output_fn = f"llm_output_{attempts}.txt"
            write_prompt(input_fn, chat_history)

            if llm_out is None:
                llm_response, llm_query_time = query_llm(chat_history, seed=self.seed)
            else:
                with open(llm_out, "r") as f:
                    llm_response = f.read()
                llm_query_time = 0

            statistics["llm_query_time"] += llm_query_time
            write_prompt("llm_input.txt", chat_history)
            chat_history.append({"role": "assistant", "content": llm_response})
            save_log(output_fn, llm_response)

            # Reset per attempt: the no-domain branch produces no optimized parameters.
            best_x = None
            try:
                llm_code = parse_code(llm_response)
                # Forget functions from a previous attempt so a response that omits
                # one fails loudly instead of silently reusing stale code.
                globals().pop(FUNC_NAME, None)
                globals().pop(FUNC_DOMAIN, None)
                # NOTE(review): executes LLM-generated code unsandboxed in this module.
                exec(llm_code, globals())
                komo_generator = globals()[FUNC_NAME]
                if FUNC_DOMAIN in globals():
                    domain = globals()[FUNC_DOMAIN]
                    ground_plan, _, best_x, best_qs, cost_hist = bbo_on_motion_plan(
                        self.twin,
                        belief,
                        komo_generator,
                        domain,
                        max_evals=self.max_evals,
                        use_komo=self.use_komo,
                        optimizer=self.optim,
                    )
                    statistics[f"cost_history_{attempts}"] = cost_hist
                    statistics["best_x"] = best_x
                    # Misspelled key kept for consumers that already read it.
                    statistics["bext_x"] = best_x
                    statistics["best_qs"] = (
                        best_qs.tolist() if hasattr(best_qs, "tolist") else best_qs
                    )
                else:
                    log.info("No variables provided to optimize. Continuing without ES.")
                    ground_plan = komo_generator(belief)

                cost, state = self._simulate_plan_cost(ground_plan)
                if cost <= self.cost_thresh:
                    return ground_plan, statistics

                log.warning(f"Feedback {attempts}: Plan cost {cost:.4f} > {self.cost_thresh}")
                chat_history.append(self._cost_feedback_message(attempts, best_x, cost, state))

            except Exception:
                # Report the failure to the LLM and retry rather than exiting the process.
                error_message = traceback.format_exc()
                print(f"Attempt {attempts}: Code execution error:\n{error_message}")
                chat_history.append({
                    "role": "user",
                    "content": f"That didn't work! Error: {error_message}",
                })

        with open("chat_log.json", "w") as f:
            json.dump(chat_history, f)
        statistics["chat_history"] = chat_history

        return ground_plan, statistics
