from Entry import Entry, EntryType
from nips2025.Design import Design
from nips2025.cache import CacheConfig
from nips2025.FeedbackEmbedder import FeedbackEmbedder
from Agent import AgentConfig, Agent, AgentSupplierType
from nips2025.Parser import DesignJsonParser, CacheCodeParser
from nips2025.Simulator import SimulatorCache, SimulatorConfig
from utils import write_to_file

import os
import logging
import logging_config
import random
# Seed Python's global `random` generator with a fixed value at import time
# (presumably so that runs are reproducible).
random.seed(42)

class RSDesign:
    """Driver that repeatedly asks LLM agents for a cache design and its code.

    Each iteration (see ``_create_entry``) asks the design agent for a design,
    asks the code agent to implement it, computes a feedback embedding for the
    resulting code and appends the entry to ``record_jsonl_path``.  When the
    call budget is exhausted, ``optimize`` writes a summary to
    ``statistics_json_path``.

    NOTE(review): "RS" presumably stands for random search (every entry is
    created from the same fixed design prompt) -- confirm.
    """

    # Directory containing this source file; prompts, the cache system folder
    # and the log outputs are all resolved relative to it.
    _BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    # output paths
    record_jsonl_path = os.path.join(_BASE_DIR, "log", "record.jsonl")
    # Built from the record file's directory instead of ``str.replace`` on the
    # whole path, which would also rewrite a directory name that happens to
    # contain "record.jsonl".
    statistics_json_path = os.path.join(os.path.dirname(record_jsonl_path), "statistics.json")

    def __init__(
        self,
        # cache simulator params
        cache_capacity: int,
        cache_trace_path: str,
        # agent params
        agent_supplier: AgentSupplierType,
        # feedback embedder params
        test_folder: str,
        trace_filter,
        # optimize params
        tot_llm_call_num: int,
    ):
        """Validate the parameters and build the simulator, agents and prompts.

        Args:
            cache_capacity: Capacity passed to the cache simulator; must be >= 1.
            cache_trace_path: Existing path of the trace used by the simulator.
            agent_supplier: Supplier used by the design agent.
            test_folder: Existing folder passed to the feedback embedder.
            trace_filter: Passed through unchanged to the feedback embedder.
            tot_llm_call_num: Non-negative budget compared against
                ``llm_call_counter`` in ``optimize``.

        Raises:
            AssertionError: If any of the checks on the parameters fails.
        """
        print("RSDesign")
        # user-defined parameters: check and store
        assert cache_capacity >= 1, f"cache_capacity must be >= 1, got {cache_capacity!r}"
        assert os.path.exists(cache_trace_path), f"cache_trace_path does not exist: {cache_trace_path!r}"
        assert os.path.exists(test_folder), f"test_folder does not exist: {test_folder!r}"
        assert isinstance(tot_llm_call_num, int) and tot_llm_call_num >= 0, (
            f"tot_llm_call_num must be a non-negative int, got {tot_llm_call_num!r}"
        )
        self.tot_llm_call_num = tot_llm_call_num
        # structure
        self.entry_counter = 0
        self.llm_call_counter = 0
        # feedback_embedder
        self.feedback_embedder: FeedbackEmbedder = FeedbackEmbedder(test_folder=test_folder, trace_filter=trace_filter)
        # simulator
        cache_config = CacheConfig(
            capacity=cache_capacity,
            consider_obj_size=False,
            trace_path=cache_trace_path,
            key_col_id=1,
            size_col_id=2,
            has_header=False,
            delimiter=","
        )
        cache_simulator = SimulatorCache(
            SimulatorConfig(
                name="Cache",
                config=cache_config,
                system_path=os.path.join(self._BASE_DIR, "cache"),
                tune_runs=20,
                code_folder=os.path.join(self._BASE_DIR, "log", "code"),
                tune_int_upper=None
            )
        )
        # agents
        self.design_agent = Agent(AgentConfig(agent_name="Design", temperature=1.0, answer_parser=DesignJsonParser(), agent_supplier=agent_supplier))
        # NOTE(review): the code agent always uses AgentSupplierType.OPENAI and
        # ignores `agent_supplier` -- confirm this is intentional.
        self.code_agent = Agent(AgentConfig(agent_name="Code", temperature=0.5, trial_num=3, answer_parser=CacheCodeParser(unique_simulator=cache_simulator), agent_supplier=AgentSupplierType.OPENAI))
        # prompts
        self.DESIGN_PROMPT_TEMPL = self._load_prompt("solution_formulation_without_hints.txt")
        self.CODE_PROMPT_TEMPL = self._load_prompt("code_generation.txt")

    def _load_prompt(self, filename: str) -> str:
        """Return the stripped text of ``prompt/<filename>`` next to this file."""
        with open(os.path.join(self._BASE_DIR, "prompt", filename), 'r') as file:
            return file.read().strip()

    def _set_entry_feedback_embedding(self, entry: Entry):
        """Compute, store and return the feedback embedding of ``entry``.

        ``entry.code`` is treated as a path to read when it ends with ".py";
        otherwise it is used directly as the code text and flagged as "sota".

        Raises:
            AssertionError: If ``entry.code`` is None, or if it is a ".py"
                path that does not exist.
        """
        assert entry.code is not None
        if entry.code.endswith(".py"):
            assert os.path.exists(entry.code)
            with open(entry.code, 'r') as file:
                code = file.read()
            is_sota = False
        else:
            # sota: the attribute already holds the code itself
            code = entry.code
            is_sota = True
        entry.feedback_embedding = self.feedback_embedder.get_feedback_embedding(code, is_sota)
        return entry.feedback_embedding

    def _build_code_prompt(self, design) -> str:
        """Fill the code-generation template with the parts of ``design``."""
        prompt = self.CODE_PROMPT_TEMPL.replace("[[design]]", design.to_str())
        # Same substitution order as the placeholders are listed here.
        sections = (
            ("[[metadata]]", design.metadata),
            ("[[evict]]", design.evict),
            ("[[update_after_hit]]", design.update_after_hit),
            ("[[update_after_insert]]", design.update_after_insert),
            ("[[update_after_evict]]", design.update_after_evict),
        )
        for placeholder, section in sections:
            prompt = prompt.replace(placeholder, design._format(section, False))
        return prompt

    def _create_entry(self):
        '''
        Create one entry: ask the design agent for a design, ask the code agent
        to implement it, and append the entry to `record_jsonl_path`.
        The entry is recorded whether or not the design/code steps succeed.
        Update:
        - self.entry_counter: +1
        - self.llm_call_counter: +1
        '''
        # NOTE(review): llm_call_counter grows by one per entry, not per
        # underlying agent call -- confirm this is the intended budget unit.
        logging.info(f"Creating Entry {self.entry_counter}...")
        entry = Entry(id=self.entry_counter, entry_type=EntryType.RS)

        design_dict = self.design_agent.answer(self.DESIGN_PROMPT_TEMPL)
        if design_dict is not None:
            entry.design = Design(design_dict)
        # set code and feedback embedding
        if entry.design is not None:
            code_prompt = self._build_code_prompt(entry.design)
            assert isinstance(self.code_agent.answer_parser, CacheCodeParser)
            self.code_agent.answer_parser.set_code_id(self.entry_counter)
            # Only used here as a success flag (None means the code agent failed).
            miss_ratio_info_tuple = self.code_agent.answer(code_prompt)
            if miss_ratio_info_tuple is not None:
                entry.code = self.code_agent.answer_parser.simulator.code_path
                assert os.path.exists(entry.code)
                self._set_entry_feedback_embedding(entry)

        # whether succeed or not, add this to the records, and update the entry counter
        self.entry_counter += 1
        self.llm_call_counter += 1
        write_to_file(
            dest_path=self.record_jsonl_path,
            contents=entry.to_jsonl() + "\n",
            is_append=True,
            is_json=False
        )

    def optimize(self):
        """Create entries until the call budget is used up, then write statistics."""
        while self.llm_call_counter < self.tot_llm_call_num:
            self._create_entry()
        write_to_file(
            dest_path=self.statistics_json_path,
            contents=self.to_dict(),
            is_append=False,
            is_json=True
        )

    def to_dict(self) -> dict:
        """Return a summary of the parameters, agents, counters and output paths."""
        return {
            # user-defined parameters
            "tot_llm_call_num": self.tot_llm_call_num,
            # agents
            "design_agent": self.design_agent.to_dict(),
            "code_agent": self.code_agent.to_dict(),
            # feedback embedder (the "signatary" key is kept as-is so the
            # output format does not change)
            "signatary": self.feedback_embedder.to_dict(),
            # structure
            "entry_counter": self.entry_counter,
            "llm_call_counter": self.llm_call_counter,
            # static
            "record_jsonl_path": self.record_jsonl_path,
            "statistics_json_path": self.statistics_json_path,
        }