from GPRSimulator import GPRSimulator
from Entry import Entry, EntryType
from nips2025.Design import Design
from nips2025.cache import CacheConfig
from nips2025.FeedbackEmbedder import FeedbackEmbedder
from Agent import AgentConfig, Agent, AgentSupplierType
from nips2025.Parser import DesignJsonParser, CacheCodeParser
from nips2025.Simulator import SimulatorCache, SimulatorConfig
from utils import write_to_file
from KeywordList import KeywordList

import os
import logging
import logging_config
from typing import List
import random
random.seed(42)

class RSDictSFNoWR:
    """Hint-driven search loop that turns random dictionary words into cache policies.

    Each iteration picks a list of ``HINT_NUM`` vocabulary words ("hints"),
    asks the design agent to formulate a design from them, asks the code agent
    to turn that design into code (parsed by a ``CacheCodeParser`` bound to a
    cache simulator), computes a feedback embedding for the generated code and
    appends the resulting ``Entry`` to ``record_jsonl_path``.

    During the first ``self.gpr.warmup`` LLM calls the hints are purely random;
    afterwards ``self.gpr`` ranks ``self.gpr.window_size`` random candidates and
    the top-ranked one is used.
    """

    # output paths
    record_jsonl_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "record.jsonl")
    statistics_json_path = os.path.join(os.path.dirname(record_jsonl_path), "statistics.json")
    # number of dictionary words sampled for every candidate hint list
    HINT_NUM = 4

    def __init__(
        self,
        # cache simulator params
        cache_capacity: int,
        cache_trace_path: str,
        # agent params
        agent_supplier: AgentSupplierType,
        # feedback embedder params
        test_folder: str,
        trace_filter,
        # optimize params
        tot_llm_call_num: int,
    ):
        """Validate the parameters and build every collaborator of the search loop.

        Args:
            cache_capacity: Capacity handed to the cache simulator; must be >= 1.
            cache_trace_path: Existing trace file handed to the cache simulator.
            agent_supplier: Supplier used by the design agent. The code agent
                always uses ``AgentSupplierType.OPENAI``.
            test_folder: Existing folder forwarded to ``FeedbackEmbedder``.
            trace_filter: Forwarded unchanged to ``FeedbackEmbedder``.
            tot_llm_call_num: Non-negative budget of LLM calls for ``optimize``.

        Note:
            Re-seeds the global ``random`` module with 42.
        """
        print("RSDict")
        random.seed(42)
        # user-defined parameters: check and store
        assert cache_capacity >= 1, "cache_capacity must be >= 1"
        assert os.path.exists(cache_trace_path), f"cache trace not found: {cache_trace_path}"
        assert os.path.exists(test_folder), f"test folder not found: {test_folder}"
        assert isinstance(tot_llm_call_num, int) and tot_llm_call_num >= 0, \
            "tot_llm_call_num must be a non-negative int"
        self.tot_llm_call_num = tot_llm_call_num
        # structure
        self.entry_counter = 0
        self.llm_call_counter = 0
        self.entry_list: List[Entry] = list()
        # vocabulary: the dictionaries live in "<parent of this module's dir>/en_dict"
        module_dir = os.path.dirname(os.path.abspath(__file__))
        dict_dir = os.path.join(os.path.dirname(module_dir), "en_dict")
        self.word_set = self._load_word_set(
            stopword_path=os.path.join(dict_dir, "en_stopwords.txt"),
            vocabulary_path=os.path.join(dict_dir, "en_3000.txt"),
        )
        # feedback_embedder
        self.feedback_embedder: FeedbackEmbedder = FeedbackEmbedder(test_folder=test_folder, trace_filter=trace_filter)
        # simulator
        cache_simulator = SimulatorCache(
            SimulatorConfig(
                name="Cache",
                config=CacheConfig(
                    capacity=cache_capacity,
                    consider_obj_size=False,
                    trace_path=cache_trace_path,
                    key_col_id=1,
                    size_col_id=2,
                    has_header=False,
                    delimiter=","
                ),
                system_path=os.path.join(module_dir, "cache"),
                tune_runs=20,
                code_folder=os.path.join(module_dir, "log", "code"),
                tune_int_upper=None
            )
        )
        # agents
        self.design_agent = Agent(AgentConfig(agent_name="Design", temperature=1.0, answer_parser=DesignJsonParser(), agent_supplier=agent_supplier))
        self.code_agent = Agent(AgentConfig(agent_name="Code", temperature=0.5, trial_num=3, answer_parser=CacheCodeParser(unique_simulator=cache_simulator), agent_supplier=AgentSupplierType.OPENAI))
        # prompts
        prompt_dir = os.path.join(module_dir, "prompt")
        self.DESIGN_PROMPT_TEMPL = self._read_text(os.path.join(prompt_dir, "solution_formulation_with_hints.txt"))
        self.CODE_PROMPT_TEMPL = self._read_text(os.path.join(prompt_dir, "code_generation.txt"))

        self.gpr = GPRSimulator(
            embedding_type="sum_hint",
            reduce_feature_dim=None,
            warmup=100,
            window_size=2, # power-of-two
            kernel="dotproduct"
        )

    @staticmethod
    def _read_text(path: str) -> str:
        """Return the content of the text file at `path`, stripped of surrounding whitespace."""
        with open(path, 'r') as file:
            return file.read().strip()

    @staticmethod
    def _load_word_set(stopword_path: str, vocabulary_path: str) -> set:
        """Build the hint vocabulary: lower-cased vocabulary words minus stop words.

        Args:
            stopword_path: File holding comma-separated stop words.
            vocabulary_path: File holding one candidate word per line.

        Returns:
            The set of non-empty, lower-cased vocabulary words that are not
            stop words.
        """
        with open(stopword_path, 'r') as file:
            stop_word_set = {w.strip().lower() for w in file.read().split(",")}
        with open(vocabulary_path, 'r') as file:
            candidates = (line.lower().strip() for line in file)
            # Blank lines would otherwise put "" into the vocabulary and it
            # could later be sampled as an (empty) hint.
            return {word for word in candidates if word and word not in stop_word_set}

    def _set_entry_feedback_embedding(self, entry: Entry):
        """Compute, store and return the feedback embedding of `entry`.

        ``entry.code`` is either a path ending in ".py" (the file is read) or,
        otherwise, is passed on as-is with the "sota" flag set.
        """
        assert entry.code is not None
        is_sota = not entry.code.endswith(".py")
        if is_sota:
            # sota: entry.code is used directly instead of being read from disk
            code = entry.code
        else:
            assert os.path.exists(entry.code)
            with open(entry.code, 'r') as file:
                code = file.read()
        entry.feedback_embedding = self.feedback_embedder.get_feedback_embedding(code, is_sota)
        return entry.feedback_embedding

    def _populate_entry(self, entry: Entry, hint_list) -> None:
        """Fill `entry` with hints, design, code and feedback embedding, stopping at the first failed stage."""
        if hint_list is None or any(h is None for h in hint_list):
            return
        entry.hints = KeywordList(",".join(hint_list))

        design_prompt = self.DESIGN_PROMPT_TEMPL.replace("[[hints]]", entry.hints.to_str())
        design_dict = self.design_agent.answer(design_prompt)
        if design_dict is None:
            return
        entry.design = Design(design_dict)
        if entry.design is None:
            return
        design = entry.design

        # set code, miss_ratio_info
        code_prompt = self.CODE_PROMPT_TEMPL.replace("[[design]]", design.to_str())
        for section in ("metadata", "evict", "update_after_hit", "update_after_insert", "update_after_evict"):
            code_prompt = code_prompt.replace(f"[[{section}]]", design._format(getattr(design, section), False))

        assert isinstance(self.code_agent.answer_parser, CacheCodeParser)
        self.code_agent.answer_parser.set_code_id(self.entry_counter)
        miss_ratio_info_tuple = self.code_agent.answer(code_prompt)
        if miss_ratio_info_tuple is None:
            return
        entry.code = self.code_agent.answer_parser.simulator.code_path
        assert os.path.exists(entry.code)
        self._set_entry_feedback_embedding(entry)

    def _create_entry(self, hint_list):
        '''
        Create an entry from the hint words in `hint_list` and record it.

        `hint_list` is a list of keyword strings; when it is None or contains
        None, the entry is recorded without hints, design or code.
        The entry is appended to `self.entry_list` and to the record file
        whether or not design/code generation succeeded.
        Update:
        - self.entry_counter: +1
        - self.llm_call_counter: +1
        '''
        logging.info("Creating Entry %s...", self.entry_counter)

        entry = Entry(id=self.entry_counter, entry_type=EntryType.RSDICT_SF_NOWR)
        self._populate_entry(entry, hint_list)

        # whether succeed or not, add this to the records, and update the entry counter
        self.entry_counter += 1
        self.llm_call_counter += 1
        self.entry_list.append(entry)
        write_to_file(
            dest_path=self.record_jsonl_path,
            contents=entry.to_jsonl() + "\n",
            is_append=True,
            is_json=False
        )

    def gpr_hint(self) -> List[str]:
        """Return the hint list to try next.

        While fewer than ``self.gpr.warmup`` LLM calls have been made, a random
        sample of ``HINT_NUM`` vocabulary words is returned. Afterwards
        ``self.gpr.window_size`` random candidates are drawn, scored by
        ``self.gpr`` against the entries recorded so far, and the candidate
        chosen by ``self.gpr._top1_idx`` is returned.
        """
        # Sample from a sorted list: iterating a set of str follows a
        # per-process hash order, which would defeat the fixed random seed.
        vocabulary = sorted(self.word_set)

        if self.llm_call_counter < self.gpr.warmup:
            logging.info("mode = warmup")
            return random.sample(vocabulary, self.HINT_NUM)

        hint_list_list = [
            random.sample(vocabulary, self.HINT_NUM)
            for _ in range(self.gpr.window_size)
        ]

        # wrap every candidate in a temporary entry so it can be vectorized like a recorded one
        temp_entry_list: List[Entry] = []
        for i, hint_list in enumerate(hint_list_list):
            t_entry = Entry(
                id=self.entry_counter + i,
                entry_type=EntryType.RSDICT_SF_NOWR,
            )
            t_entry.hints = KeywordList(",".join(hint_list))
            temp_entry_list.append(t_entry)

        data_X = self.gpr.load_data(
            entries=[e.to_dict() for e in self.entry_list] + [e.to_dict() for e in temp_entry_list],
            m_hint_obs={h: None for e in self.entry_list + temp_entry_list for h in e.hints.keyword_list},
        )
        # second half of every warmup-sized cycle exploits, first half explores
        if (self.llm_call_counter % self.gpr.warmup) > (self.gpr.warmup * 0.5):
            mode = "exploit"
        else:
            mode = "explore"

        logging.info("mode = %s", mode)

        training_X, training_y, predicting_X = self.gpr._convert_to_vectors(
            training_entries=[e.to_dict() for e in self.entry_list],
            predicting_entries=[e.to_dict() for e in temp_entry_list],
            data_X=data_X
        )

        predicting_y, _ = self.gpr._predict(
            training_X=training_X,
            training_y=training_y,
            predicting_X=predicting_X
        )

        assert len(predicting_y) == len(temp_entry_list)

        # select the top1 entry
        top1_idx = self.gpr._top1_idx(
            training_y=training_y,
            predicting_y=predicting_y,
            explore_exploit=mode,
        )
        assert 0 <= top1_idx < len(temp_entry_list)
        return hint_list_list[top1_idx]

    def optimize(self):
        """Create entries until the LLM-call budget is spent, then write the run statistics."""
        while self.llm_call_counter < self.tot_llm_call_num:
            hint_list = self.gpr_hint()
            self._create_entry(hint_list=hint_list)
        write_to_file(
            dest_path=self.statistics_json_path,
            contents=self.to_dict(),
            is_append=False,
            is_json=True
        )

    def to_dict(self) -> dict:
        """Return a summary of the run: parameters, agents, embedder, counters and output paths."""
        return {
            # user-defined parameters
            "tot_llm_call_num": self.tot_llm_call_num,
            # agents
            "design_agent": self.design_agent.to_dict(),
            "code_agent": self.code_agent.to_dict(),
            # feedback_embedder
            "feedback_embedder": self.feedback_embedder.to_dict(),
            # structure
            "entry_counter": self.entry_counter,
            "llm_call_counter": self.llm_call_counter,
            # static
            "record_jsonl_path": self.record_jsonl_path,
            "statistics_json_path": self.statistics_json_path,
        }