from __future__ import annotations

import os
import cma
import json
import pathlib
import logging
import numpy as np
import importlib
import traceback
from typing import Callable, List, Union
import matplotlib.pyplot as plt

# Raise matplotlib's font-manager logger to WARNING so its DEBUG/INFO records
# are dropped (presumably to keep noisy font-lookup messages out of the logs).
logging.getLogger('matplotlib.font_manager').setLevel(logging.WARNING)

from vtamp.environments.utils import Action, Environment, State
from vtamp.policies.utils import (
    Policy,
    Sampler,
    parse_code,
    query_llm,
)
from vtamp.utils import parse_text_prompt, save_log, write_prompt

# Instantiate Action and State once and discard the results. Presumably this
# only keeps the imports "used" (so linters/formatters do not strip them); the
# names must stay in this module's globals because LLM-generated code is
# exec'd against globals() in DENECK.full_query_bbo.
# NOTE(review): confirm Action()/State() construction has no side effects.
_, _ = Action(), State()
log = logging.getLogger(__name__)

# Name of the function the LLM-generated code must define; it builds the plan.
FUNC_NAME = "gen_plan"
# Name of the optional function in the LLM-generated code that maps the initial
# state to a dict of parameters for bbo_on_motion_plan to optimise. If the
# generated code does not define it, no black-box optimisation is run.
FUNC_DOMAIN = "gen_initial_guess"

def reshape_like(template: list, flat_list: list) -> list:
    """Rebuild the nesting structure of ``template`` using values from ``flat_list``.

    Every non-list leaf of ``template`` (visited depth-first, left to right) is
    replaced by the next value drawn from ``flat_list``. Only ``list`` instances
    are treated as containers; tuples and other objects count as leaves. This is
    the inverse of ``flatten`` for list-nested templates.

    Args:
        template: Arbitrarily nested list whose shape is reproduced.
        flat_list: Values to place into the leaves, in order. Any iterable is
            accepted (list, tuple, 1-D numpy array, ...) and it is never
            modified. Surplus values are ignored.

    Returns:
        A new nested list with the same shape as ``template``.

    Raises:
        IndexError: If ``flat_list`` holds fewer values than ``template`` has
            leaves.
    """
    # A single shared iterator makes the fill O(n); the previous
    # ``list.pop(0)`` approach shifted the whole list on every leaf (O(n^2)).
    values = iter(flat_list)

    def fill(node):
        if isinstance(node, list):
            return [fill(child) for child in node]
        try:
            return next(values)
        except StopIteration:
            # Keep the historical exception type (pop from an empty list) and
            # never let StopIteration escape into the caller's comprehensions.
            raise IndexError(
                "flat_list has fewer values than template has leaves"
            ) from None

    return fill(template)

def flatten(nested: list) -> list:
    """Return the leaves of an arbitrarily nested list as one flat list.

    Only ``list`` instances are descended into; any other element (tuples,
    numpy arrays, scalars, ...) is kept as a single leaf. Leaves come out in
    depth-first, left-to-right order, which is the order ``reshape_like``
    consumes them in.
    """
    def walk(node):
        for element in node:
            if not isinstance(element, list):
                yield element
                continue
            yield from walk(element)

    return list(walk(nested))


def _execute_plan(env, ground_plan, use_komo: bool) -> None:
    """Run ``ground_plan`` in ``env`` without visualisation (KOMO or step by step)."""
    if use_komo:
        env.step_komo(ground_plan, vis=False)
    else:
        for action in ground_plan:
            env.step(action, vis=False)


def _hill_climb(cost_fn: Callable, x0, max_evals: int, step_size: float = 0.05):
    """Probabilistic hill climbing with Gaussian perturbations.

    Evaluates ``x0`` once, then performs ``max_evals`` perturbation steps. A
    step is accepted only if it strictly lowers the cost. Each rejection shrinks
    the step size by 0.9. If that would push it below 1e-4, the shrink is undone
    so the search never stalls.

    Returns:
        ``(best_x, best_y)``. ``best_x`` is ``x0`` itself if nothing improved,
        otherwise a numpy array.
    """
    print(f"Running Probabilistic Hill Climbing with {max_evals} evaluations")
    x_best = x0
    y_best = cost_fn(x0)

    for _ in range(max_evals):
        perturbation = np.random.normal(loc=0.0, scale=step_size, size=len(x_best))
        x_new = x_best + perturbation  # list + ndarray -> ndarray
        y_new = cost_fn(x_new)

        if y_new < y_best:
            x_best, y_best = x_new, y_new
            print(f"Improved with random step; New cost: {np.round(y_best, 3)}")
        else:
            step_size *= 0.9  # decay step size slowly
            print(f"No improvement, reducing step size to {step_size:.5f}")
            if step_size < 1e-4:
                step_size *= 1 / 0.9
                print("Step size too small, not stopping tho.")

    return x_best, y_best


def _cma(cost_fn: Callable, x0, max_evals: int, sigma0: float = 0.01):
    """Minimise ``cost_fn`` with CMA-ES (pycma ``fmin2``); return ``(best_x, best_y)``."""
    print(f"Running CMA-ES with {max_evals} evaluations")
    bbo_options = {
        "maxfevals": max_evals,
        "ftarget": 0,
        "CMA_active": True,
    }
    best_x, es = cma.fmin2(cost_fn, x0=x0, sigma0=sigma0, options=bbo_options)
    return best_x, es.result.fbest


def bbo_on_motion_plan(
    env,
    initial_state,
    plan_gen: Callable[..., List],
    domains_gen: Callable,
    use_komo: bool = False,
    max_evals: int = 1000,
    optimizer: str = "hill_climb",  # or "cma"
) -> tuple[List, str, list, object, List[float]]:
    """Optimise the continuous parameters of an LLM-generated plan by black-box search.

    ``domains_gen(initial_state)`` returns a dict whose values (scalars or nested
    lists) are the initial parameter guess. The values are flattened into one
    vector and optimised. For each candidate vector, the vector is reshaped
    back and the plan is built as ``plan_gen(initial_state, *values)``. The
    plan is then executed in ``env`` and scored with ``env.compute_cost()``.

    Args:
        env: Simulation twin. Needs ``reset``, ``step``/``step_komo``,
            ``compute_cost`` and a ``path`` attribute.
        initial_state: State passed to ``domains_gen`` and ``plan_gen``.
        plan_gen: Builds a ground plan from the state and the parameter values.
        domains_gen: Returns the dict of initial parameter values.
        use_komo: Execute plans with ``env.step_komo`` instead of per-action steps.
        max_evals: Evaluation budget for the optimiser.
        optimizer: ``"hill_climb"`` or ``"cma"``.

    Returns:
        A 5-tuple:
        - the best plan;
        - a failure message (currently always ``""``);
        - the best flat parameter vector, as a list;
        - ``env.path`` recorded at the latest best-cost evaluation (``None`` if
          nothing was evaluated);
        - the list of all evaluated costs.

    Raises:
        ValueError: If ``optimizer`` is not recognised.
    """
    # Validate before touching the environment.
    if optimizer not in ("hill_climb", "cma"):
        raise ValueError(f"Unknown optimizer: {optimizer}. Use 'hill_climb' or 'cma'.")

    failure_message = ""

    domains = domains_gen(initial_state)
    initial_x = list(domains.values())
    flat_init = flatten(initial_x)

    cost_history = []
    best_qs = []
    best_cost = float("inf")

    def compute_cost(input_vec) -> float:
        nonlocal best_cost
        input_vec_reshaped = reshape_like(initial_x, list(input_vec))
        env.reset()
        ground_plan = plan_gen(initial_state, *input_vec_reshaped)
        _execute_plan(env, ground_plan, use_komo)
        cost = env.compute_cost()
        cost_history.append(cost)

        # Record the trajectory for every new best (ties included); the last
        # entry therefore belongs to the best cost seen so far.
        if cost <= best_cost:
            best_cost = cost
            best_qs.append(env.path)

        print(f"Input vec {np.round(input_vec, 3)}; Cost {np.round(cost, 3)}, Iter: {len(cost_history)}")
        return cost

    # Init plan eval -- execute it the same way compute_cost does, so KOMO
    # plans are not pushed through env.step.
    print("Rendering initial plan")
    env.reset()
    _execute_plan(env, plan_gen(initial_state, *initial_x), use_komo)

    if optimizer == "hill_climb":
        best_x, best_y = _hill_climb(compute_cost, flat_init, max_evals)
    else:
        best_x, best_y = _cma(compute_cost, flat_init, max_evals)

    # Convert best_x to a list if it's a numpy array
    if isinstance(best_x, np.ndarray):
        best_x = best_x.tolist()

    # Generate final plan
    ground_plan = plan_gen(initial_state, *reshape_like(initial_x, best_x))

    print("Optimization completed")
    print(f"Total evaluations: {len(cost_history)}")
    print(f"Best solution: {best_x}")
    print(f"Best cost: {best_y}")

    best_q = best_qs[-1] if best_qs else None
    return ground_plan, failure_message, best_x, best_q, cost_history


def import_constants_from_class(cls):
    """Copy every name listed in the ``__all__`` of ``cls``'s module into this module's globals.

    The copied names become visible to code later exec'd against this
    module's globals. Each imported name and its value is printed. The
    defining module must declare ``__all__``; otherwise ``AttributeError``
    propagates.
    """
    source_module = importlib.import_module(cls.__module__)
    namespace = globals()
    for name in source_module.__all__:
        namespace[name] = getattr(source_module, name)
        print(f"Imported {name}: {namespace[name]}")


class DENECK(Policy):
    """Policy that asks an LLM for plan code and tunes its parameters by BBO.

    The LLM is prompted with the belief and goal and must answer with code that
    defines ``FUNC_NAME`` (``gen_plan``) and, optionally, ``FUNC_DOMAIN``
    (``gen_initial_guess``).
    - If the domain function is present, its values are optimised with
      ``bbo_on_motion_plan`` against ``twin``.
    - Otherwise the plan is simply ``gen_plan(belief)``.
    """

    def __init__(
        self,
        twin=None,
        max_feedbacks=0,
        seed=0,
        max_evals=1000,
        use_cache=True,
        use_komo=False,
        optim="cma",
        **kwargs,
    ):
        """Configure the policy and load the prompt for ``twin``'s class.

        Args:
            twin: Simulation environment used for BBO. Its defining module must
                declare ``__all__``; those names are copied into this module's
                globals so LLM-generated code can use them.
            max_feedbacks: Stored only; not used by the code in this class.
            seed: Seed forwarded to ``query_llm``.
            max_evals: Evaluation budget for ``bbo_on_motion_plan``.
            use_cache: Stored only; not used by the code in this class.
            use_komo: Execute plans with KOMO and return the whole plan at once.
            optim: Optimiser name passed to ``bbo_on_motion_plan``.
        """
        self.twin = twin
        self.seed = seed
        self.max_feedbacks = max_feedbacks
        self.max_evals = max_evals
        self.use_komo = use_komo
        self.use_cache = use_cache
        self.optim = optim

        import_constants_from_class(twin.__class__)

        prompt_fn = "prompt_{}".format(twin.__class__.__name__)
        prompt_path = os.path.join(pathlib.Path(__file__).parent, "{}.txt".format(prompt_fn))
        self.prompt = parse_text_prompt(prompt_path)
        self.plan = None

    def get_action(self, belief, goal: str, llm_out: Union[str, None] = None):
        """Return the next action (or the whole plan when ``use_komo``) and statistics.

        The first call queries the LLM and optimises the plan. Later calls pop
        from the cached plan. ``(None, statistics)`` is returned on failure or
        once the plan is exhausted.
        """
        statistics = {}
        if self.plan is None:
            ground_plan, statistics = self.full_query_bbo(belief, goal, llm_out)
            if ground_plan is None:
                return None, statistics
            if self.use_komo:
                self.plan = ground_plan
                return self.plan, statistics
            if len(ground_plan) == 0:
                # Nothing to execute; remember that so we do not re-query.
                self.plan = ground_plan
                return None, statistics
            self.plan = ground_plan[1:]
            return ground_plan[0], statistics
        if self.use_komo:
            return self.plan, statistics
        if len(self.plan) > 0:
            next_action = self.plan[0]
            self.plan = self.plan[1:]
            return next_action, statistics
        return None, statistics

    def full_query_bbo(self, belief, task, llm_out: Union[str, None] = None):
        """Query the LLM (or replay ``llm_out``), run its code and optimise the plan.

        Returns:
            ``(ground_plan, statistics)``. ``ground_plan`` is ``None`` when the
            generated code fails. In that case the traceback is printed and
            stored under ``statistics["error"]``.
        """
        import time

        self.twin.reset()
        content = "initial={}\nGoal: {}".format(str(belief), task)
        chat_history = self.prompt + [{"role": "user", "content": content}]
        statistics = {
            "bbo_evals": 0,
            "bbo_solve_time": 0,
            "llm_query_time": 0,
            "num_bbo_evals": 0,
        }

        input_fn = "llm_input.txt"
        output_fn = "llm_output.txt"
        write_prompt(input_fn, chat_history)

        if llm_out is None:
            llm_response, llm_query_time = query_llm(chat_history, seed=self.seed)
        else:
            # Replay a previously saved LLM response instead of querying.
            with open(llm_out, 'r') as f:
                llm_response = f.read()
            llm_query_time = 0

        statistics["llm_query_time"] += llm_query_time

        chat_history.append({"role": "assistant", "content": llm_response})
        save_log(output_fn, llm_response)

        ground_plan = None
        try:
            llm_code = parse_code(llm_response)
            module_globals = globals()
            # Drop definitions left behind by a previous response, so a stale
            # gen_plan / gen_initial_guess cannot be picked up by mistake.
            module_globals.pop(FUNC_NAME, None)
            module_globals.pop(FUNC_DOMAIN, None)
            # SECURITY: this executes LLM-generated code with full module
            # privileges; only run it in a trusted sandbox.
            exec(llm_code, module_globals)
            func = module_globals[FUNC_NAME]
            if FUNC_DOMAIN in module_globals:
                domain = module_globals[FUNC_DOMAIN]
                start = time.perf_counter()
                ground_plan, _failure_message, best_x, best_qs, cost_hist = bbo_on_motion_plan(
                    self.twin,
                    belief,
                    func,
                    domain,
                    max_evals=self.max_evals,
                    use_komo=self.use_komo,
                    optimizer=self.optim,
                )
                statistics["bbo_solve_time"] = time.perf_counter() - start
                statistics["num_bbo_evals"] = len(cost_hist)
                statistics["cost_history"] = cost_hist
                # "bext_x" (sic) is kept for existing consumers of these stats.
                statistics["bext_x"] = best_x
                statistics["best_x"] = best_x
                statistics["best_qs"] = (
                    best_qs.tolist() if isinstance(best_qs, np.ndarray) else best_qs
                )
            else:
                log.info("No variables given to optimize. Continuing without BBO.")
                ground_plan = func(belief)

        except Exception:
            # Report and signal failure to the caller (get_action handles a
            # None plan) instead of terminating the whole process.
            error_message = traceback.format_exc()
            print(error_message)
            statistics["error"] = error_message
            ground_plan = None

        return ground_plan, statistics
