
import argparse
import ast
import datetime
import os
import random
from typing import Any, Optional, Tuple, List, Dict

from plansearch.python_utils import log_to_dir, batch_map_on_nested_list, index_nested_list, map_nary_fn_on_nested_list, markdown_codeblock_extract
from plansearch.ObservationNode import ObservationNode
import plansearch.prompts as prompts
from plansearch.Agent import Agent, AgentSupplierType, AgentConfig
from nips2025.cache import CacheConfig
from nips2025.Simulator import SimulatorCache, SimulatorConfig
from nips2025.Parser import CacheCodeParser, DesignJsonParser
from utils import write_to_file

class PlanSearch:
    """PlanSearch-style generation of natural-language solution plans.

    The search runs in layers. In each layer, LLM agents propose observations
    about the task and turn them into a Python list of strings. The tree
    bookkeeping is delegated to ``ObservationNode``, which decides which
    (problem, observation-combination) pairs each layer works on. After the
    last layer, every observation combination gets a natural-language solution
    entry. The design step is still a TODO, so each entry is currently an
    empty placeholder. Generating code from those solutions is intentionally
    not supported here.
    """

    # Default output locations, resolved relative to this source file.
    record_jsonl_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "record.jsonl")
    statistics_json_path = record_jsonl_path.replace("record.jsonl", "statistics.json")

    def __init__(
        self,
        agent_supplier: "AgentSupplierType",
        # log parameters
        experiment_directory: Optional[str] = record_jsonl_path.replace("record.jsonl", "plansearch_log"),
        # structure params
        max_observation_k: int = 2,
        num_observations_to_generate: int = 10,
        num_observation_layers: int = 2,
        use_fix_idea: bool = False,
        use_pseudocode: bool = False,
        use_idea: bool = True,
    ):
        """Configure the search and build the LLM agents.

        Args:
            agent_supplier: Backend handed to every ``Agent`` created here.
            experiment_directory: Directory for the per-step JSON logs. It is
                created if missing. ``None`` falls back to ``log/plansearch_log``
                next to this file.
            max_observation_k: Forwarded to ``ObservationNode``. Presumably the
                largest observation-combination size; confirm there.
            num_observations_to_generate: Substituted for ``[[num_obs]]`` in the
                first-layer prompt and forwarded to ``ObservationNode``.
            num_observation_layers: Number of observation-generation rounds.
            use_fix_idea: Must stay False; the criticism/fix step is unsupported.
            use_pseudocode: Only recorded; code generation is unsupported here.
            use_idea: Must be True for ``generate_solutions`` to succeed.

        Raises:
            ValueError: If ``use_idea`` is False while pseudocode or idea fixing
                is enabled.
        """
        print("PlanSearch: Generate Observations")

        if experiment_directory is None:
            experiment_directory = self.record_jsonl_path.replace("record.jsonl", "plansearch_log")
        # Validate before touching the filesystem. Use an explicit raise:
        # asserts are stripped under `python -O`.
        if not use_idea and (use_pseudocode or use_fix_idea):
            raise ValueError("If not using solution sketch, must disable both pseudocode and fixing idea step")

        os.makedirs(experiment_directory, exist_ok=True)
        self.experiment_directory = experiment_directory

        self.max_observation_k = max_observation_k
        self.num_observations_to_generate = num_observations_to_generate
        self.num_observation_layers = num_observation_layers
        self.use_fix_idea = use_fix_idea
        self.use_pseudocode = use_pseudocode
        self.use_idea = use_idea

        # Agents: all sample at temperature 1.0 from the same backend.
        self.obsgen_agent = Agent(AgentConfig(agent_name="ObsGen", temperature=1.0, agent_supplier=agent_supplier))
        self.obs2list_agent = Agent(AgentConfig(agent_name="Obs2List", temperature=1.0, agent_supplier=agent_supplier))
        self.obsfilter_agent = Agent(AgentConfig(agent_name="ObsFilter", temperature=1.0, agent_supplier=agent_supplier))
        self.obs2python_agent = Agent(AgentConfig(agent_name="Obs2Python", temperature=1.0, agent_supplier=agent_supplier))
        self.design_agent = Agent(AgentConfig(agent_name="Design", answer_parser=DesignJsonParser(), temperature=1.0, agent_supplier=agent_supplier))

        # Template for the design step. It is not used yet (see the TODO in
        # get_nl_solutions_from_obs_combos).
        design_prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "prompt", "solution_formulation_with_hints.txt")
        with open(design_prompt_path, "r", encoding="utf-8") as file:
            self.DESIGN_PROMPT_TEMPL = file.read().strip()

    def _format_observation_combo(self, obs_combo: Tuple[str, ...]):
        """Render an observation combination plus the wording for the combined-observation prompt.

        Returns:
            A 6-tuple ``(observation_str, first_blank, ..., fifth_blank)``
            used to fill ``[[obs_combo]]`` and ``[[first_blank]]`` ..
            ``[[fifth_blank]]`` in ``prompts.COMBINED_OBS_PROMPT_TEMPL``.
        """
        at_least_two = len(obs_combo) >= 2
        has_observation = len(obs_combo) >= 1

        if at_least_two:
            observation_str = "Here are several correct observations which will help in solving the task:\n"
            observation_str += "- " + "\n- ".join(obs_combo) + "\n"
        elif has_observation:
            observation_str = "Here is a correct observation which will help in solving the task:\n"
            observation_str += f"- {obs_combo[0]}\n"
        else:
            observation_str = "No observations are necessary to solve this task.\n"

        first_blank = 'observations' if has_observation else 'task'
        second_blank = 'combine the observations' if at_least_two else 'use the implications'
        third_blank = '[Quotes of the observations]\n\n' if has_observation else ''
        # Non-empty fourth/fifth blanks carry their own leading space.
        # Presumably the template places them directly after a word.
        fourth_blank = ' of joining the observations above' if at_least_two else (' of the observation above' if has_observation else '')
        fifth_blank = ' of the old observations' if has_observation else ''

        return observation_str, first_blank, second_blank, third_blank, fourth_blank, fifth_blank

    def get_prompt(self, iter_num: int, problem: str, observation_combo: Tuple[str, ...]):
        """Build the chat prompt asking for observations.

        Layer 0 uses the first-observation template and requires an empty
        combo. Later layers fill the combined-observation template from
        ``observation_combo``. ``problem`` is currently unused by both
        templates.
        """
        if iter_num == 0:
            assert len(observation_combo) == 0
            return [
                {"role": "system", "content": prompts.SYSTEM_PROMPT_OBSERVATION},
                {"role": "user", "content": prompts.FIRST_OBS_PROMPT_TEMPL.replace("[[num_obs]]", str(self.num_observations_to_generate))},
            ]

        assert isinstance(observation_combo, tuple)
        observation_str, first_blank, second_blank, third_blank, fourth_blank, fifth_blank = self._format_observation_combo(observation_combo)
        return [
            {"role": "system", "content": prompts.SYSTEM_PROMPT_OBSERVATION2},
            {"role": "user", "content": (prompts.COMBINED_OBS_PROMPT_TEMPL
                                            .replace("[[obs_combo]]", observation_str)
                                            .replace("[[first_blank]]", first_blank)
                                            .replace("[[second_blank]]", second_blank)
                                            .replace("[[third_blank]]", third_blank)
                                            .replace("[[fourth_blank]]", fourth_blank)
                                            .replace("[[fifth_blank]]", fifth_blank))},
        ]

    @staticmethod
    def _literal_str_list(code: Any) -> Optional[List[str]]:
        """Parse ``code`` as a Python list literal whose items are all strings.

        Uses ``ast.literal_eval`` rather than ``eval`` because ``code`` comes
        from an LLM response. ``literal_eval`` only accepts literals, so
        nothing in the response is ever executed.

        Returns:
            The parsed list, or ``None`` if ``code`` is not a string holding a
            list-of-strings literal.
        """
        if not isinstance(code, str):
            return None
        try:
            parsed = ast.literal_eval(code.strip())
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
        if isinstance(parsed, list) and all(isinstance(item, str) for item in parsed):
            return parsed
        return None

    @classmethod
    def _parse_python_str_list(cls, response: Any) -> Optional[List[str]]:
        """Extract the markdown code block from an LLM ``response`` and parse it as a list of strings.

        Returns ``None`` (so the caller re-queries) when extraction or
        parsing fails.
        """
        try:
            code = markdown_codeblock_extract(response)
        except Exception:  # best effort: any extraction failure just marks the response unparsable
            return None
        return cls._literal_str_list(code)

    def get_observations_strs(self, problems_obs: List[Tuple[str, Optional[Tuple[str, ...]]]], iter_num: int = 0) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
        """Run one layer of observation generation for a batch of (problem, observation combo) pairs.

        Each stage continues the previous stage's conversation: the prior
        prompt, plus the assistant reply, plus a new user instruction.

        1. ``obsgen_agent`` writes free-form observations.
        2. ``obs2list_agent`` reformats them as a list.
        3. ``obsfilter_agent`` filters them to useful ones. This stage runs
           only when ``iter_num > 0``.
        4. ``obs2python_agent`` emits a Python list literal. The query is
           repeated up to ``MAX_PARSE_TRIES`` times until the reply parses
           as a list of strings.

        Returns:
            One ``(observations, log)`` pair per input, in input order.
            ``observations`` is an empty tuple if parsing never succeeded.
        """
        problems = [problem for problem, _ in problems_obs]
        observation_combos = [obs for _, obs in problems_obs]

        # 1. LLM-generated observations.
        get_observations_prompts = [self.get_prompt(iter_num, problem, obs_combo) for problem, obs_combo in zip(problems, observation_combos)]
        observations_strs = [self.obsgen_agent.answer(p) for p in get_observations_prompts]

        # 2. LLM-parsed lists.
        get_parse_into_list_prompts = [
            obs_prompt + [
                {"role": "assistant", "content": obs_str},
                {"role": "user", "content": prompts.FORMAT_INTO_LIST_PROMPT},
            ]
            for obs_prompt, obs_str in zip(get_observations_prompts, observations_strs)
        ]
        observations_strlists = [self.obs2list_agent.answer(p) for p in get_parse_into_list_prompts]

        # 3. LLM-filtered lists (later layers only).
        if iter_num > 0:
            filter_observations_prompts = [
                list_prompt + [
                    {"role": "assistant", "content": obs_strlist},
                    {"role": "user", "content": prompts.FILTER_TO_USEFUL_LIST_PROMPT},
                ]
                for list_prompt, obs_strlist in zip(get_parse_into_list_prompts, observations_strlists)
            ]
            filtered_obs_strlists = [self.obsfilter_agent.answer(p) for p in filter_observations_prompts]
        else:
            # First layer: no filtering, so stage 4 continues the stage-2 conversation.
            filter_observations_prompts = list(get_parse_into_list_prompts)
            filtered_obs_strlists = list(observations_strlists)

        # 4. LLM-emitted Python lists, re-queried until they parse.
        parse_to_python_prompts = [
            filter_prompt + [
                {"role": "assistant", "content": filtered_obs_strlist},
                {"role": "user", "content": prompts.PARSE_INTO_PYTHON_LIST_PROMPT},
            ]
            for filter_prompt, filtered_obs_strlist in zip(filter_observations_prompts, filtered_obs_strlists)
        ]
        MAX_PARSE_TRIES = 3
        num_items = len(parse_to_python_prompts)
        python_obs_lists: List[Optional[List[str]]] = [None] * num_items
        log_attempted_parses: List[List[Any]] = [[] for _ in range(num_items)]
        pending_idxs = list(range(num_items))
        for _attempt in range(MAX_PARSE_TRIES):
            responses = [self.obs2python_agent.answer(parse_to_python_prompts[i]) for i in pending_idxs]
            for orig_idx, response in zip(pending_idxs, responses):
                log_attempted_parses[orig_idx].append(response)
                python_obs_lists[orig_idx] = self._parse_python_str_list(response)
            pending_idxs = [i for i, parsed in enumerate(python_obs_lists) if parsed is None]
            if not pending_idxs:
                break

        if any(parsed is None for parsed in python_obs_lists):
            print("Warning: Python parsing of observation lists failed.")

        python_obs_tuples = [() if parsed is None else tuple(parsed) for parsed in python_obs_lists]

        logs = [
            {
                "problem_str": problem,
                "observations_str": obs_str,
                "observations_strlist": obs_strlist,
                "filtered_strlist": filtered_strlist,
                "attempted_parses": attempts,
                "python_observation_list": obs_tuple,
            }
            for problem, obs_str, obs_strlist, filtered_strlist, attempts, obs_tuple in zip(
                problems, observations_strs, observations_strlists, filtered_obs_strlists, log_attempted_parses, python_obs_tuples
            )
        ]
        return list(zip(python_obs_tuples, logs))

    def get_nl_solutions_from_obs_combos(self, problem_observation_combos: List[Tuple[str, Tuple[str, ...]]]) -> List[Tuple[List[Tuple[str, str]], Dict[str, Any]]]:
        """Attach a natural-language solution to each (problem, observation combo) pair.

        Returns:
            One ``([(problem, nl_solution)], log)`` pair per input.

        Raises:
            NotImplementedError: If ``use_fix_idea`` is enabled; the
                criticism/fix step is not supported.
        """
        if self.use_fix_idea:
            raise NotImplementedError(f"PlanSearch.use_fix_idea = {self.use_fix_idea}, but we require it to be False.")

        results = []
        for problem, obs_combo in problem_observation_combos:
            # TODO: query self.design_agent with DESIGN_PROMPT_TEMPL, using the
            # formatted observation combo as hints. Until then every combo gets
            # an empty placeholder solution.
            nl_solution = ""
            log = {"problem_str": problem, "observation_combo": obs_combo, "original_solution": nl_solution}
            results.append(([(problem, nl_solution)], log))
        return results

    def get_code_solution_from_nl_solutions(self, problem_nl_solutions: List[Tuple[str, str]]) -> List[Tuple[str, str, Dict[str, Any]]]:
        """Unsupported: code is generated from the NL solutions in a separate step.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Doesn't support nl->code currently. Do it separately.")

    def get_code_solution_direct_from_obs_combos(self, problem_observation_combos: List[Tuple[str, Tuple[str, ...]]]) -> List[Tuple[str, Dict[str, Any]]]:
        """Unsupported: code is never generated directly from observation combos.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Never allow LLM to generate code directly from observation combos.")

    def generate_solutions(self, problems: List[str], *args, **kwargs) -> List[List[List[List[Tuple[str, str]]]]]:
        """Main function for PlanSearch.

        Runs ``num_observation_layers`` rounds of observation generation over
        an ``ObservationNode`` tree per problem and logs each round. It then
        returns the NL-solution entries for every observation combination, as
        a nested list mirroring ``collect_all_problem_obs``. ``*args`` and
        ``**kwargs`` are accepted for interface compatibility and ignored.

        Raises:
            NotImplementedError: If ``use_idea`` is False. This is checked
                before any LLM call.
        """
        if not self.use_idea:
            raise NotImplementedError(f"self.use_idea = {self.use_idea}, but we require it to be True.")

        observation_nodes = [ObservationNode(problem, 0, (), self.num_observations_to_generate, self.max_observation_k) for problem in problems]
        for iter_num in range(self.num_observation_layers):
            highest_level_problem_obs = [node.collect_highest_level_problem_obs() for node in observation_nodes]

            observations_lists_logs = batch_map_on_nested_list(highest_level_problem_obs, lambda li: self.get_observations_strs(li, iter_num))
            observations_lists: List[List[Tuple[str, ...]]] = map_nary_fn_on_nested_list(lambda x: x[0], observations_lists_logs)
            observations_logs: List[List[Dict[str, Any]]] = map_nary_fn_on_nested_list(lambda x: x[1], observations_lists_logs)

            to_log = []
            for node, node_obs_lists, node_obs_logs in zip(observation_nodes, observations_lists, observations_logs):
                node.attribute_new_observations(node_obs_lists, node_obs_logs)
                to_log.append(node.collect_logs())

            log_to_dir(self.experiment_directory, {f"observation_{iter_num}_{datetime.datetime.now().strftime('%m-%dT%H:%M:%S')}.json": to_log})

        all_problem_obs_combos = [node.collect_all_problem_obs() for node in observation_nodes]

        orig_and_fixed_nl_solutions_w_logs = batch_map_on_nested_list(all_problem_obs_combos, self.get_nl_solutions_from_obs_combos)
        nl_solution_logs = map_nary_fn_on_nested_list(lambda x: x[1], orig_and_fixed_nl_solutions_w_logs)
        problem_nl_solutions: List[List[List[List[Tuple[str, str]]]]] = map_nary_fn_on_nested_list(lambda x: x[0], orig_and_fixed_nl_solutions_w_logs)

        log_to_dir(self.experiment_directory, {f"nl_solutions_{datetime.datetime.now().strftime('%m-%dT%H:%M:%S')}.json": nl_solution_logs})

        # Code generation from these NL solutions is done separately.
        return problem_nl_solutions

    def optimize(self):
        """Run ``generate_solutions`` on a single empty problem and write the results.

        Writes this instance's configuration to ``statistics_json_path`` and
        the NL solutions to ``problem_nl_solutions.json`` in the same log
        directory. Both files are overwritten (``is_append=False``).
        """
        problem_nl_solutions = self.generate_solutions(
            problems=[""]
        )
        write_to_file(
            dest_path=self.statistics_json_path,
            contents=self.to_dict(),
            is_append=False,
            is_json=True
        )
        write_to_file(
            dest_path=self.record_jsonl_path.replace("record.jsonl", "problem_nl_solutions.json"),
            contents=problem_nl_solutions,
            is_append=False,
            is_json=True
        )

    def to_dict(self):
        """Return this instance's configuration, including each agent's ``to_dict()``."""
        # Key names form the statistics.json schema and are kept for compatibility.
        return {
            "experiment_directory": self.experiment_directory,
            "max_observation_k": self.max_observation_k,
            "num_observations_to_generate": self.num_observations_to_generate,
            "num_observation_layer": self.num_observation_layers,
            "use_fix_idea": self.use_fix_idea,
            "use_pseudocode": self.use_pseudocode,
            "use_idea": self.use_idea,
            # agents
            "obs_agent": self.obsgen_agent.to_dict(),
            "obs2list_agent": self.obs2list_agent.to_dict(),
            "obs2filter_agent": self.obsfilter_agent.to_dict(),
            "obs2python_agent": self.obs2python_agent.to_dict(),
            "design_agent": self.design_agent.to_dict(),
        }