from Entry import Entry, EntryType
from nips2025.Design import Design
from nips2025.cache import CacheConfig
from nips2025.FeedbackEmbedder import FeedbackEmbedder
from Agent import AgentConfig, Agent, AgentSupplierType
from nips2025.Parser import DesignJsonParser, CacheCodeParser
from nips2025.Simulator import SimulatorCache, SimulatorConfig
from utils import write_to_file

import os
import logging
import logging_config
import random
from typing import List
import json

def parse_obs_combo_item(item):
    """Flatten a nested structure of solution records into observation combos.

    Args:
        item: either a dict carrying an ``"observation_combo"`` key, or a list
            whose elements are (recursively) such dicts or lists.

    Returns:
        list: the ``"observation_combo"`` values, collected depth-first in
        left-to-right order.

    Raises:
        AssertionError: if a leaf is neither a list nor a dict.
    """
    if not isinstance(item, list):
        # Leaf record: wrap its single combo so callers can always concatenate.
        assert isinstance(item, dict)
        return [item["observation_combo"]]
    return [combo for child in item for combo in parse_obs_combo_item(child)]

def load_observation_combos(nl_solution_json_path):
    '''
    Load the flattened list of observation combos from a JSON file.

    The returned list keeps the order of the records in the file (nested
    lists are flattened depth-first) and keeps duplicates: each element
    corresponds to one design to be generated. The total number of combos
    and the number of distinct combos are printed.

    Args:
        nl_solution_json_path: path to a JSON file holding a dict with an
            "observation_combo" key, or an arbitrarily nested list of them.

    Returns:
        list: the observation combos, in file order.
    '''
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
    with open(nl_solution_json_path, 'r', encoding='utf-8') as file:
        raw_list = json.load(file)

    obs_combo_list = parse_obs_combo_item(raw_list)

    print(f"There will be {len(obs_combo_list)} designs!")
    # A combo may be null (entry creation skips those); count it as one
    # distinct value instead of crashing on tuple(None).
    distinct_combos = {None if c is None else tuple(c) for c in obs_combo_list}
    print(f"There will be {len(distinct_combos)} different designs!")
    return obs_combo_list

class PlanSearchDesignAndCode:
    '''
    Generate one design, and an implementation of it, per observation combo.

    For every observation combo loaded from ``nl_solution_json_path`` the
    design agent is prompted with the combo as hints. If it yields a design,
    the code agent is prompted to implement it. Its answer parser is a
    ``CacheCodeParser`` built around a cache simulator (presumably it
    evaluates the generated code with that simulator). If that also
    succeeds, a feedback embedding is computed for the code. One JSON record
    per combo is appended to ``record_jsonl_path``, whether or not
    generation succeeded. Run statistics are written to
    ``statistics_json_path`` when :meth:`optimize` finishes.
    '''
    # output paths (under "log/" next to this module)
    record_jsonl_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "record.jsonl")
    statistics_json_path = record_jsonl_path.replace("record.jsonl", "statistics.json")

    def __init__(
        self,
        nl_solution_json_path: str,  # obs_combo_list
        # cache simulator params
        cache_capacity: int,
        cache_trace_path: str,
        # agent params
        agent_supplier: AgentSupplierType,
        # feedback embedding params
        test_folder: str,
        trace_filter,
        code_agent_supplier: AgentSupplierType = AgentSupplierType.OPENAI,
    ):
        '''
        Validate inputs and build the simulator, agents, embedder and prompts.

        Args:
            nl_solution_json_path: JSON file the observation combos are read from.
            cache_capacity: cache capacity for the simulator; must be >= 1.
            cache_trace_path: existing trace file for the simulator.
            agent_supplier: LLM supplier for the design agent.
            test_folder: existing folder passed to ``FeedbackEmbedder``.
            trace_filter: passed through unchanged to ``FeedbackEmbedder``.
            code_agent_supplier: LLM supplier for the code agent (defaults to
                OpenAI, which was previously hard-coded).
        '''
        print("PlanSearch: Generate Solutions")
        random.seed(42)  # fixed seed for reproducible runs
        # user-defined parameters: check and store
        assert cache_capacity >= 1, f"cache_capacity must be >= 1, got {cache_capacity}"
        assert os.path.exists(cache_trace_path), f"cache trace not found: {cache_trace_path}"
        assert os.path.exists(test_folder), f"test folder not found: {test_folder}"

        base_dir = os.path.dirname(os.path.abspath(__file__))

        # structure
        self.entry_counter = 0
        self.llm_call_counter = 0

        # obs_combo_list: one design (one entry) per combo
        self.obs_combo_counter = 0
        self.nl_solution_json_path = nl_solution_json_path
        self.obs_combo_list = load_observation_combos(nl_solution_json_path)
        self.tot_llm_call_num = len(self.obs_combo_list)

        # feedback_embedder
        self.feedback_embedder: FeedbackEmbedder = FeedbackEmbedder(test_folder=test_folder, trace_filter=trace_filter)

        # simulator (handed to the code parser below)
        cache_simulator = SimulatorCache(
            SimulatorConfig(
                name="Cache",
                config=CacheConfig(
                    capacity=cache_capacity,
                    consider_obj_size=False,
                    trace_path=cache_trace_path,
                    key_col_id=1,
                    size_col_id=2,
                    has_header=False,
                    delimiter=","
                ),
                system_path=os.path.join(base_dir, "cache"),
                tune_runs=20,
                code_folder=os.path.join(base_dir, "log", "code"),
                tune_int_upper=None
            )
        )

        # agents
        self.design_agent = Agent(AgentConfig(
            agent_name="Design",
            temperature=1.0,
            answer_parser=DesignJsonParser(),
            agent_supplier=agent_supplier,
        ))
        self.code_agent = Agent(AgentConfig(
            agent_name="Code",
            temperature=0.5,
            trial_num=3,
            answer_parser=CacheCodeParser(unique_simulator=cache_simulator),
            agent_supplier=code_agent_supplier,
        ))

        # prompts
        prompt_dir = os.path.join(base_dir, "prompt")
        # NOTE: attribute name spelling kept as-is for backward compatibility.
        self.MAP_PROPMT_TEMPL = self._read_prompt(os.path.join(prompt_dir, "waypoint_reasoning.txt"))
        self.DESIGN_PROMPT_TEMPL = self._read_prompt(os.path.join(prompt_dir, "solution_formulation_with_hints.txt"))
        self.CODE_PROMPT_TEMPL = self._read_prompt(os.path.join(prompt_dir, "code_generation.txt"))

    @staticmethod
    def _read_prompt(path: str) -> str:
        '''Return the text of the prompt template at ``path`` with surrounding whitespace stripped.'''
        with open(path, 'r', encoding='utf-8') as file:
            return file.read().strip()

    def _set_entry_feedback_embedding(self, entry: Entry):
        '''
        Compute ``entry.feedback_embedding`` from ``entry.code``, store it and return it.

        ``entry.code`` ending in ``.py`` is treated as a path to generated code,
        which is read from disk. Otherwise it is taken to be the source text
        itself and flagged as SOTA.
        '''
        assert entry.code is not None
        if entry.code.endswith(".py"):
            assert os.path.exists(entry.code)
            with open(entry.code, 'r', encoding='utf-8') as file:
                code = file.read()
            is_sota = False
        else:
            # sota
            code = entry.code
            is_sota = True
        entry.feedback_embedding = self.feedback_embedder.get_feedback_embedding(code, is_sota)
        return entry.feedback_embedding

    def _format_obs_combo(self, obs_combo: List[str]):
        '''Render the combo as a markdown bullet list, or "No hints." when it is empty.'''
        if not obs_combo:
            return "No hints."
        return "\n".join([f"- {d}" for d in obs_combo])

    def _build_code_prompt(self, design: Design) -> str:
        '''
        Fill the code-generation template with the parts of ``design``.

        Placeholders are substituted in the listed order, so a later
        placeholder that appears inside an earlier substitution is also
        replaced.
        '''
        replacements = [
            ("[[design]]", design.to_str()),
            ("[[metadata]]", design._format(design.metadata, False)),
            ("[[evict]]", design._format(design.evict, False)),
            ("[[update_after_hit]]", design._format(design.update_after_hit, False)),
            ("[[update_after_insert]]", design._format(design.update_after_insert, False)),
            ("[[update_after_evict]]", design._format(design.update_after_evict, False)),
        ]
        prompt = self.CODE_PROMPT_TEMPL
        for placeholder, text in replacements:
            prompt = prompt.replace(placeholder, text)
        return prompt

    def _generate_design_and_code(self, entry: Entry, obs_combo: List[str]) -> None:
        '''
        Fill ``entry.design``, ``entry.code`` and ``entry.feedback_embedding`` in place.

        Returns early, leaving the remaining fields untouched, as soon as the
        design agent or the code agent answers ``None``.
        '''
        design_prompt = self.DESIGN_PROMPT_TEMPL.replace("[[hints]]", self._format_obs_combo(obs_combo))
        design_dict = self.design_agent.answer(design_prompt)
        if design_dict is None:
            return
        entry.design = Design(design_dict)

        # set code, miss_ratio_info
        code_prompt = self._build_code_prompt(entry.design)
        assert isinstance(self.code_agent.answer_parser, CacheCodeParser)
        # entry_counter has not been incremented yet, so this is the id of `entry`.
        self.code_agent.answer_parser.set_code_id(self.entry_counter)
        miss_ratio_info_tuple = self.code_agent.answer(code_prompt)
        if miss_ratio_info_tuple is None:
            return
        entry.code = self.code_agent.answer_parser.simulator.code_path
        assert os.path.exists(entry.code)
        self._set_entry_feedback_embedding(entry)

    def _create_entry(self):
        '''
        Create an entry from the next observation combo and append it to the record file.

        A ``None`` combo produces an entry without design or code. The entry is
        recorded (with its ``obs_combo``) whether or not generation succeeded.
        Update:
        - self.obs_combo_counter: +1
        - self.entry_counter: +1
        - self.llm_call_counter: +1
        '''
        logging.info("Creating Entry %s...", self.entry_counter)
        entry = Entry(id=self.entry_counter, entry_type=EntryType.PLANSEARCH)
        obs_combo = self.obs_combo_list[self.obs_combo_counter]
        self.obs_combo_counter += 1
        if obs_combo is not None:
            self._generate_design_and_code(entry, obs_combo)

        # whether succeed or not, add this to the records, and update the entry counter
        self.entry_counter += 1
        self.llm_call_counter += 1
        entry_dict = entry.to_dict()
        entry_dict['obs_combo'] = obs_combo
        write_to_file(
            dest_path=self.record_jsonl_path,
            contents=json.dumps(entry_dict) + "\n",
            is_append=True,
            is_json=False
        )

    def optimize(self):
        '''Create one entry per observation combo, then write the run statistics file.'''
        while self.llm_call_counter < self.tot_llm_call_num:
            self._create_entry()
        write_to_file(
            dest_path=self.statistics_json_path,
            contents=self.to_dict(),
            is_append=False,
            is_json=True
        )

    def to_dict(self) -> dict:
        '''Return a summary of the configuration and progress counters for the statistics file.'''
        return {
            # user-defined parameters
            "tot_llm_call_num": self.tot_llm_call_num,
            # obs_combo_list
            "nl_solution_json_path": self.nl_solution_json_path,
            "obs_combo_list": self.obs_combo_list,
            "obs_combo_counter": self.obs_combo_counter,
            # agents
            "design_agent": self.design_agent.to_dict(),
            "code_agent": self.code_agent.to_dict(),
            # feedback_embedder
            "feedback_embedder": self.feedback_embedder.to_dict(),
            # structure
            "entry_counter": self.entry_counter,
            "llm_call_counter": self.llm_call_counter,
            # static
            "record_jsonl_path": self.record_jsonl_path,
            "statistics_json_path": self.statistics_json_path,
        }