from GPRSimulator import GPRSimulator
from Entry import Entry, EntryType
from nips2025.Design import Design
from nips2025.cache import CacheConfig
from nips2025.FeedbackEmbedder import FeedbackEmbedder
from Agent import AgentConfig, Agent, AgentSupplierType
from nips2025.Parser import DesignJsonParser, CacheCodeParser, MapParser
from nips2025.Simulator import SimulatorCache, SimulatorConfig
from utils import write_to_file
from KeywordList import KeywordList

import os
import logging
import logging_config
from typing import List
import random
random.seed(42)  # seed the global RNG at import time; RSDictSF.__init__ re-seeds it with the same value

class RSDictSF:
    '''
    Search for cache-policy designs driven by randomly sampled keyword hints ("RSDict").

    One iteration (`_create_entry`) turns a list of keyword hints into:
      hints -> per-keyword descriptions (map agent) -> design (design agent)
      -> code handled by the cache-simulator-backed code parser (code agent)
      -> feedback embedding.
    Every attempt, successful or not, becomes an `Entry` that is kept in
    `self.entry_list` and appended to `record_jsonl_path`. For the first
    `self.gpr.warmup` iterations the hints are sampled uniformly; afterwards
    `GPRSimulator` picks one of `self.gpr.window_size` random candidates
    (see `gpr_hint`).
    '''
    # output paths
    record_jsonl_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "record.jsonl")
    statistics_json_path = record_jsonl_path.replace("record.jsonl", "statistics.json")
    # number of keywords sampled into one hint list
    HINTS_PER_ENTRY = 4

    def __init__(
        self,
        # cache simulator params
        cache_capacity: int,
        cache_trace_path: str,
        # agent params
        agent_supplier: AgentSupplierType,
        # feedback embedder params
        test_folder: str,
        trace_filter,
        # optimize params
        tot_llm_call_num: int,
    ):
        '''
        Load the keyword vocabulary, prompts, agents, cache simulator and GPR model.

        Args:
            cache_capacity: capacity passed to `CacheConfig`; must be >= 1.
            cache_trace_path: trace file used by the cache simulator; must exist.
            agent_supplier: LLM supplier for the map and design agents.
            test_folder: folder handed to `FeedbackEmbedder`; must exist.
            trace_filter: forwarded unchanged to `FeedbackEmbedder`.
            tot_llm_call_num: number of entries `optimize` creates; non-negative int.

        Raises:
            AssertionError: if any of the argument checks fails.
        '''
        print("RSDict")
        random.seed(42)
        here = os.path.dirname(os.path.abspath(__file__))
        project_root = os.path.dirname(here)
        # user-defined parameters: check and store
        assert cache_capacity >= 1
        assert os.path.exists(cache_trace_path)
        assert os.path.exists(test_folder)
        assert isinstance(tot_llm_call_num, int) and tot_llm_call_num >= 0
        self.tot_llm_call_num = tot_llm_call_num
        # structure
        self.entry_counter = 0
        # incremented once per created entry (not once per individual agent call)
        self.llm_call_counter = 0
        # memo: lower-cased keyword -> description from the map agent (None on failure)
        self.keyword_map = dict()
        self.entry_list: List[Entry] = list()
        # keyword vocabulary: the 3000-word list minus comma-separated stop words
        with open(os.path.join(project_root, "en_dict", "en_stopwords.txt"), 'r') as file:
            stop_word_set = {w.strip().lower() for w in file.read().split(",")}
        with open(os.path.join(project_root, "en_dict", "en_3000.txt"), 'r') as file:
            # skip blank lines so an empty string can never be sampled as a hint
            self.word_set = {
                word
                for word in (line.lower().strip() for line in file)
                if word and word not in stop_word_set
            }
        # feedback_embedder
        self.feedback_embedder: FeedbackEmbedder = FeedbackEmbedder(test_folder=test_folder, trace_filter=trace_filter)
        # simulator
        cache_simulator = SimulatorCache(
            SimulatorConfig(
                name="Cache",
                config=CacheConfig(
                    capacity=cache_capacity,
                    consider_obj_size=False,
                    trace_path=cache_trace_path,
                    key_col_id=1,
                    size_col_id=2,
                    has_header=False,
                    delimiter=","
                ),
                system_path=os.path.join(here, "cache"),
                tune_runs=20,
                code_folder=os.path.join(here, "log", "code"),
                tune_int_upper=None
            )
        )
        # agents
        self.map_agent = Agent(AgentConfig(agent_name="Map", temperature=1.0, answer_parser=MapParser(), agent_supplier=agent_supplier))
        self.design_agent = Agent(AgentConfig(agent_name="Design", temperature=1.0, answer_parser=DesignJsonParser(), agent_supplier=agent_supplier))
        # NOTE(review): the code agent always uses OpenAI and ignores `agent_supplier` — confirm this is intended.
        self.code_agent = Agent(AgentConfig(agent_name="Code", temperature=0.5, trial_num=3, answer_parser=CacheCodeParser(unique_simulator=cache_simulator), agent_supplier=AgentSupplierType.OPENAI))
        # prompts
        with open(os.path.join(here, "prompt", "waypoint_reasoning.txt"), 'r') as file:
            self.MAP_PROMPT_TEMPL = file.read().strip()
        with open(os.path.join(here, "prompt", "solution_formulation_with_hints.txt"), 'r') as file:
            self.DESIGN_PROMPT_TEMPL = file.read().strip()
        with open(os.path.join(here, "prompt", "code_generation.txt"), 'r') as file:
            self.CODE_PROMPT_TEMPL = file.read().strip()

        self.gpr = GPRSimulator(
            embedding_type="sum_obs",
            reduce_feature_dim=None,
            warmup=100,
            window_size=2, # power-of-two
            kernel="dotproduct"
        )

    def _set_entry_feedback_embedding(self, entry: Entry):
        '''
        Compute and store `entry.feedback_embedding` from the entry's code.

        `entry.code` is either a path to a generated `.py` file (read from disk) or,
        for SOTA entries, the code itself.

        Return: the embedding produced by `FeedbackEmbedder.get_feedback_embedding`.
        '''
        assert entry.code is not None
        is_sota = not entry.code.endswith(".py")
        if is_sota:
            # sota: the code text is stored directly on the entry
            code = entry.code
        else:
            assert os.path.exists(entry.code)
            with open(entry.code, 'r') as file:
                code = file.read()
        entry.feedback_embedding = self.feedback_embedder.get_feedback_embedding(code, is_sota)
        return entry.feedback_embedding

    def _create_map_from_word(self, word: str):
        '''
        Describe one keyword with the map agent, memoized in `self.keyword_map`.

        The memo key is the stripped, lower-cased word. A None answer is cached as
        well, so a word that failed once is not sent to the agent again.

        Return:
        - str: the description from the word
        - None: if the word cannot generate description (parser failure)
        '''
        word_key = word.strip().lower()
        if word_key in self.keyword_map:
            return self.keyword_map[word_key]
        map_prompt = self.MAP_PROMPT_TEMPL.replace("[[word]]", word_key)
        descrip = self.map_agent.answer(map_prompt)
        self.keyword_map[word_key] = descrip
        return descrip

    def _create_map_from_keyword_list(self, keywords_list: KeywordList):
        '''
        Describe every keyword in `keywords_list` (via `_create_map_from_word`).

        Return: the non-None descriptions as "- <description>" lines joined by
        newlines, or "No hints." when no keyword produced a description.
        '''
        descrip_list = []
        for keyword in keywords_list.keyword_list:
            descrip = self._create_map_from_word(keyword)
            if descrip is not None:
                descrip_list.append(descrip)
        if not descrip_list:
            return "No hints."
        return "\n".join(f"- {d}" for d in descrip_list)

    def _populate_entry(self, entry: Entry, hint_list) -> None:
        '''
        Run the hints -> design -> code -> feedback pipeline for `entry`.

        Stops early, leaving the remaining fields unset, when `hint_list` is None,
        any hint is None, the design agent answers None, or the code agent
        answers None.
        '''
        if hint_list is None or any(h is None for h in hint_list):
            return
        entry.hints = KeywordList(",".join(hint_list))
        hint_description = self._create_map_from_keyword_list(entry.hints)
        design_prompt = self.DESIGN_PROMPT_TEMPL.replace("[[hints]]", hint_description)
        design_dict = self.design_agent.answer(design_prompt)
        if design_dict is None:
            return
        entry.design = Design(design_dict)
        design = entry.design
        code_prompt = (self.CODE_PROMPT_TEMPL
                        .replace("[[design]]", design.to_str())
                        .replace("[[metadata]]", design._format(design.metadata, False))
                        .replace("[[evict]]", design._format(design.evict, False))
                        .replace("[[update_after_hit]]", design._format(design.update_after_hit, False))
                        .replace("[[update_after_insert]]", design._format(design.update_after_insert, False))
                        .replace("[[update_after_evict]]", design._format(design.update_after_evict, False)))
        assert isinstance(self.code_agent.answer_parser, CacheCodeParser)
        # associate the generated code with this entry's id
        self.code_agent.answer_parser.set_code_id(self.entry_counter)
        miss_ratio_info_tuple = self.code_agent.answer(code_prompt)
        if miss_ratio_info_tuple is None:
            return
        # set code (path of the generated file) and its feedback embedding
        entry.code = self.code_agent.answer_parser.simulator.code_path
        assert os.path.exists(entry.code)
        self._set_entry_feedback_embedding(entry)

    def _create_entry(self, hint_list):
        '''
        Build one `Entry` from `hint_list` (a list of keyword strings, or None),
        record it, and advance the counters.

        The entry is filled as far as the pipeline gets (see `_populate_entry`); it is
        appended to `self.entry_list` and to `record_jsonl_path` even when a stage fails.
        Update:
        - self.entry_counter: +1
        - self.llm_call_counter: +1
        '''
        logging.info(f"Creating Entry {self.entry_counter}...")

        entry = Entry(id=self.entry_counter, entry_type=EntryType.RSDICT_SF)
        self._populate_entry(entry, hint_list)

        # whether succeed or not, add this to the records, and update the entry counter
        self.entry_counter += 1
        self.llm_call_counter += 1
        self.entry_list.append(entry)
        write_to_file(
            dest_path=self.record_jsonl_path,
            contents=entry.to_jsonl() + "\n",
            is_append=True,
            is_json=False
        )

    def gpr_hint(self):
        '''
        Choose the next hint list: `HINTS_PER_ENTRY` distinct words from `self.word_set`.

        - warmup (`llm_call_counter < gpr.warmup`): a uniform random sample.
        - otherwise: sample `gpr.window_size` candidate lists, make sure every candidate
          keyword has an entry in `self.keyword_map` (may call the map agent), predict
          each candidate with the GPR trained on the recorded entries, and return the
          candidate chosen by `gpr._top1_idx`. The mode is "exploit" when
          `llm_call_counter % warmup > warmup * 0.5`, otherwise "explore".
        '''
        # Sort before sampling: iteration order of a set of str depends on the per-process
        # hash seed, so sampling from list(set) is not reproducible despite random.seed().
        words = sorted(self.word_set)
        if self.llm_call_counter < self.gpr.warmup:
            logging.info("mode = warmup")
            return random.sample(words, self.HINTS_PER_ENTRY)
        hint_list_list = [
            random.sample(words, self.HINTS_PER_ENTRY)
            for _ in range(self.gpr.window_size)
        ]
        # populate self.keyword_map for every candidate keyword; load_data reads it below
        for hint_list in hint_list_list:
            self._create_map_from_keyword_list(KeywordList(",".join(hint_list)))
        temp_entry_list: List[Entry] = []
        for i, hint_list in enumerate(hint_list_list):
            t_entry = Entry(
                id=self.entry_counter + i,
                entry_type=EntryType.RSDICT_SF,
            )
            t_entry.hints = KeywordList(",".join(hint_list))
            temp_entry_list.append(t_entry)

        data_X = self.gpr.load_data(
            entries=[e.to_dict() for e in self.entry_list] + [e.to_dict() for e in temp_entry_list],
            m_hint_obs=self.keyword_map,
        )
        if (self.llm_call_counter % self.gpr.warmup) > (self.gpr.warmup * 0.5):
            mode = "exploit"
        else:
            mode = "explore"

        logging.info(f"mode = {mode}")

        training_X, training_y, predicting_X = self.gpr._convert_to_vectors(
            training_entries=[e.to_dict() for e in self.entry_list],
            predicting_entries=[e.to_dict() for e in temp_entry_list],
            data_X=data_X
        )

        predicting_y, _ = self.gpr._predict(
            training_X=training_X,
            training_y=training_y,
            predicting_X=predicting_X
        )

        assert len(predicting_y) == len(temp_entry_list)

        # select the top1 entry
        top1_idx = self.gpr._top1_idx(
            training_y=training_y,
            predicting_y=predicting_y,
            explore_exploit=mode,
        )
        assert 0 <= top1_idx < len(temp_entry_list)
        return hint_list_list[top1_idx]

    def optimize(self):
        '''
        Create entries until `llm_call_counter` reaches `tot_llm_call_num`, then
        write `to_dict()` as JSON to `statistics_json_path` (overwriting it).
        '''
        while self.llm_call_counter < self.tot_llm_call_num:
            hint_list = self.gpr_hint()
            self._create_entry(hint_list=hint_list)
        write_to_file(
            dest_path=self.statistics_json_path,
            contents=self.to_dict(),
            is_append=False,
            is_json=True
        )

    def to_dict(self) -> dict:
        '''Return a summary of the run's configuration, agents, memo and counters.'''
        return {
            # user-defined parameters
            "tot_llm_call_num": self.tot_llm_call_num,
            # keyword map
            "keyword_map": self.keyword_map,
            # agents
            "map_agent": self.map_agent.to_dict(),
            "design_agent": self.design_agent.to_dict(),
            "code_agent": self.code_agent.to_dict(),
            # feedback_embedder
            "feedback_embedder": self.feedback_embedder.to_dict(),
            # structure
            "entry_counter": self.entry_counter,
            "llm_call_counter": self.llm_call_counter,
            # static
            "record_jsonl_path": self.record_jsonl_path,
            "statistics_json_path": self.statistics_json_path,
        }