from Entry import Entry, EntryType
from BaseDesign import BaseDesign
from BaseFeedbackEmbedder import BaseFeedbackEmbedder
from Agent import AgentConfig, Agent, AgentSupplierType
from BaseParser import DefaultParser, BaseParser
from utils import write_to_file
from KeywordList import KeywordList
from GPRSimulator import GPRSimulator

import os
import logging
import logging_config
import random
from abc import abstractmethod, ABC
from typing import List
# Seed the module-level RNG once at import time; hint keywords are later drawn with random.sample.
random.seed(42)

class BaseMetaMuse(ABC):
    '''
    Abstract driver that repeatedly creates entries from keyword hints.

    For each entry the keywords are mapped to descriptions by the "Map" agent, the
    descriptions are inserted into the design prompt for the "Design" agent, and the
    resulting design is turned into code and feedback by hooks that subclasses must
    implement (_formulate_code_generation_prompt, _code_generation,
    _set_entry_feedback_embedding). Every entry is appended to record_jsonl_path.
    '''
    # JSONL file (inside the "log" directory next to this module) that receives one appended line per entry.
    record_jsonl_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "log", "record.jsonl")
    # Same path with the file name swapped; optimize() writes the to_dict() snapshot here when it finishes.
    statistics_json_path = record_jsonl_path.replace("record.jsonl", "statistics.json")
    def __init__(
        self,
        agent_supplier: AgentSupplierType,
        tot_llm_call_num: int,
        feedback_embedder: BaseFeedbackEmbedder, # define your own feedback embedder
        map_agent_parser: BaseParser, # answer parser for the agent for property extraction & problem mapping
        design_agent_pareser: BaseParser, # answer parser for the agent for solution formulation
        code_agent_parser: BaseParser, # answer parser for the agent for code generation
        map_prompt_path: str, # path to the prompt; the placeholder for the keyword is [[word]]
        design_prompt_path: str, # the placeholder for the observations is [[hints]]
        code_prompt_path: str,
        entry_type: EntryType, # the type of the entry
        design_type: BaseDesign, # the derived type of your design
        use_gpr: bool
    ):
        '''
        Load the keyword dictionary and prompt templates, and build the agents and GPR helper.

        Parameters:
        - agent_supplier: supplier used for the "Map" and "Design" agents.
        - tot_llm_call_num: non-negative int; optimize() stops once llm_call_counter reaches it.
        - feedback_embedder: stored as-is; its to_dict() is included in to_dict().
        - map_agent_parser / design_agent_pareser / code_agent_parser: answer parsers handed to
          the "Map", "Design" and "Code" agents respectively (the misspelt parameter name is
          kept because callers may pass it by keyword).
        - map_prompt_path: file whose text contains the "[[word]]" placeholder.
        - design_prompt_path: file whose text contains the "[[hints]]" placeholder.
        - code_prompt_path: file holding the code-generation prompt template.
        - entry_type: EntryType given to every Entry created by _create_entry.
        - design_type: called as design_type(design_dict) to build a design from the
          design agent's answer.
        - use_gpr: when truthy, optimize() asks gpr_hint() for the hints of each entry.

        Raises:
        - AssertionError: if tot_llm_call_num is not a non-negative int or entry_type is
          not an EntryType.
        - OSError: if a dictionary or prompt file cannot be opened.
        '''
        assert isinstance(tot_llm_call_num, int) and tot_llm_call_num >= 0
        self.tot_llm_call_num = tot_llm_call_num

        assert isinstance(entry_type, EntryType)
        self.entry_type = entry_type

        self.design_type = design_type

        self.use_gpr = use_gpr

        # counters
        self.entry_counter = 0
        self.llm_call_counter = 0

        # keywords
        self.keyword_map = dict() # keyword -> observation
        # The dictionaries live in "<parent of this module's directory>/en_dict".
        dict_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "en_dict")
        # Stop words are stored as one comma-separated list.
        with open(os.path.join(dict_dir, "en_stopwords.txt"), 'r', encoding="utf-8") as file:
            stop_word_set = {w.strip().lower() for w in file.read().split(",")}
        # The vocabulary has one word per line. Blank lines are skipped so that an empty
        # string can never be sampled as a keyword.
        with open(os.path.join(dict_dir, "en_3000.txt"), 'r', encoding="utf-8") as file:
            candidate_words = (str(line).lower().strip() for line in file)
            self.word_set = {w for w in candidate_words if w and w not in stop_word_set} # dictionary

        # feedback embedder
        self.feedback_embedder: BaseFeedbackEmbedder = feedback_embedder

        # agents
        self.map_agent = Agent(AgentConfig(agent_name="Map", temperature=1.0, answer_parser=map_agent_parser, agent_supplier=agent_supplier))
        self.design_agent = Agent(AgentConfig(agent_name="Design", temperature=1.0, answer_parser=design_agent_pareser, agent_supplier=agent_supplier))
        # NOTE(review): the code agent always uses AgentSupplierType.OPENAI and ignores
        # agent_supplier -- presumably intentional; confirm.
        self.code_agent = Agent(AgentConfig(agent_name="Code", temperature=0.5, trial_num=3, answer_parser=code_agent_parser, agent_supplier=AgentSupplierType.OPENAI))

        # prompts
        with open(map_prompt_path, 'r', encoding="utf-8") as file:
            self.MAP_PROMPT_TEMPL = file.read().strip()
        with open(design_prompt_path, 'r', encoding="utf-8") as file:
            self.DESIGN_PROMPT_TEMPL = file.read().strip()
        with open(code_prompt_path, 'r', encoding="utf-8") as file:
            self.CODE_PROMPT_TEMPL = file.read().strip()

        # gpr (built even when use_gpr is false; only gpr_hint() reads it)
        self.gpr = GPRSimulator(
            embedding_type="sum_obs",
            reduce_feature_dim=None,
            warmup=100,
            window_size=2, # power-of-two
            kernel="dotproduct"
        )
        self.entry_list: List[Entry] = list()

    @abstractmethod
    def _set_entry_feedback_embedding(self, entry: Entry):
        '''
        Set the feedback_embedding of a complete entry.
        '''
        # TODO
        return entry.feedback_embedding

    @abstractmethod
    def _formulate_code_generation_prompt(self, design: BaseDesign):
        '''
        Formulate the code generation prompt using self.CODE_PROMP_TEMPL, given the design.
        '''
        # TODO
        return self.CODE_PROMPT_TEMPL

    @abstractmethod
    def _code_generation(self, code_prompt: str, entry: Entry):
        '''
        Generate and test the code given the code_prompt.
        If the code is valid, store it to some place, and record its path in entry.code;
        Otherwise, entry.code=None
        Return: whether the code is valid or not.
        '''
        ## TODO
        ## Example: 
        # result = self.code_agent.answer(code_prompt)
        # if result != None:  # valid code
        #     entry.code = get_code_path()
        #     assert os.path.exists(entry.code)
        #     return True
        # else:
        #     return False
        return True
        

    def _create_map_from_word(self, word: str):
        '''
        Return:
        - str: the description from the word
        - None: if the word cannot generate description (parser failure)
        '''
        word_key = word.strip().lower()
        if word_key in self.keyword_map:
            return self.keyword_map[word_key]
        map_prompt = self.MAP_PROMPT_TEMPL.replace("[[word]]", word_key)
        descrip = self.map_agent.answer(map_prompt)
        self.keyword_map[word_key] = descrip
        return descrip

    def _create_map_from_keyword_list(self, keywords_list: KeywordList):
        '''
        Return: a formatted string of descriptions
        '''
        descrip_list = []
        for keyword in keywords_list.keyword_list:
            descrip = self._create_map_from_word(keyword)
            if descrip != None:
                descrip_list.append(descrip)
        if len(descrip_list) == 0:
            return "No hints."
        return "\n".join([f"- {d}" for d in descrip_list])

    def _create_entry(self, hint_list=None):
        '''
        Create an entry from a given keyword set (i.e., external stimuli) and record it.

        Parameters:
        - hint_list: list of keyword strings. When None, 4 keywords are sampled from
          self.word_set. When it contains a None element, the entry is recorded
          without hints, design or code.

        Whether or not the pipeline succeeds, the entry is appended to self.entry_list
        and to the JSONL record file, and both entry_counter and llm_call_counter are
        increased by one (one "call" is counted per entry).
        '''
        logging.info(f"Creating Entry {self.entry_counter}...")

        entry = Entry(id=self.entry_counter, entry_type=self.entry_type)
        if hint_list is None:
            # Sample from a sorted list: the iteration order of a set of strings changes
            # between interpreter runs (hash randomization), which would make the draw
            # irreproducible even with a seeded RNG.
            hint_list = random.sample(sorted(self.word_set), 4)
        if all(h is not None for h in hint_list):
            self._fill_entry_from_hints(entry, hint_list)

        # whether succeed or not, add this to the records, and update the entry counter
        self.entry_counter += 1
        self.llm_call_counter += 1
        self.entry_list.append(entry)
        write_to_file(
            dest_path=self.record_jsonl_path,
            contents=entry.to_jsonl() + "\n",
            is_append=True,
            is_json=False
        )

    def _fill_entry_from_hints(self, entry, hint_list):
        '''
        Run hints -> descriptions -> design -> code -> feedback on entry, stopping at the first step that fails.
        '''
        entry.hints = KeywordList(",".join(hint_list))
        hint_description = self._create_map_from_keyword_list(entry.hints)
        design_prompt = self.DESIGN_PROMPT_TEMPL.replace("[[hints]]", hint_description)
        design_dict = self.design_agent.answer(design_prompt)
        if design_dict is None:
            return
        entry.design = self.design_type(design_dict)
        if entry.design is None:
            return
        # set code, miss_ratio_info
        code_prompt = self._formulate_code_generation_prompt(design=entry.design)
        if self._code_generation(code_prompt, entry):
            self._set_entry_feedback_embedding(entry)

    def gpr_hint(self):
        if self.llm_call_counter < self.gpr.warmup:
            logging.info(f"mode = warmup")
            return random.sample(list(self.word_set), 4)
        hint_list_list = [
            random.sample(list(self.word_set), 4)
            for _ in range(self.gpr.window_size)
        ]
        obs_list_list = [
            self._create_map_from_keyword_list(KeywordList(",".join(hint_list)))
            for hint_list in hint_list_list
        ]
        temp_entry_list: List[Entry] = []
        for i, obs in enumerate(obs_list_list):
            t_entry = Entry(
                id=self.entry_counter + i,
                entry_type=EntryType.RSDICT_SF,
            )
            t_entry.hints = KeywordList(",".join(hint_list_list[i]))
            temp_entry_list.append(t_entry)

        data_X = self.gpr.load_data(
            entries=list([e.to_dict() for e in self.entry_list] + [e.to_dict() for e in temp_entry_list]),
            m_hint_obs=self.keyword_map,
        )
        if (self.llm_call_counter % self.gpr.warmup) > (self.gpr.warmup * 0.5):
            mode = "exploit"
        else:
            mode = "explore"

        logging.info(f"mode = {mode}")

        training_X, training_y, predicting_X = self.gpr._convert_to_vectors(
            training_entries=[e.to_dict() for e in self.entry_list],
            predicting_entries=[e.to_dict() for e in temp_entry_list],
            data_X=data_X
        )

        predicting_y, _ = self.gpr._predict(
            training_X=training_X,
            training_y=training_y,
            predicting_X=predicting_X
        )

        assert len(predicting_y) == len(temp_entry_list)

        # select the top1 entry
        top1_idx = self.gpr._top1_idx(
            training_y=training_y,
            predicting_y=predicting_y,
            explore_exploit=mode,
        )
        assert 0 <= top1_idx < len(temp_entry_list)
        top1_hint_list= hint_list_list[top1_idx]
        return top1_hint_list 

    def optimize(self):
        '''
        Create entries until llm_call_counter reaches tot_llm_call_num, then write the statistics.

        Each iteration creates one entry: with hints chosen by gpr_hint() when use_gpr is
        truthy, otherwise with hints sampled inside _create_entry. After the loop the
        to_dict() snapshot is written (overwriting) to statistics_json_path.
        '''
        while self.llm_call_counter < self.tot_llm_call_num:
            if self.use_gpr:
                self._create_entry(hint_list=self.gpr_hint())
            else:
                self._create_entry()
        write_to_file(
            dest_path=self.statistics_json_path,
            contents=self.to_dict(),
            is_append=False,
            is_json=True
        )

    def to_dict(self) -> dict:
        return {
            # user-defined parameters
            "tot_llm_call_num": self.tot_llm_call_num,
            # keyword map
            "keyword_map": self.keyword_map,
            # agents
            "map_agent": self.map_agent.to_dict(),
            "design_agent": self.design_agent.to_dict(),
            "code_agent": self.code_agent.to_dict(),
            # feedback_embedder
            "feedback_embedder": self.feedback_embedder.to_dict(),
            # structure
            "entry_counter": self.entry_counter,
            "llm_call_counter": self.llm_call_counter,
            # static
            "record_jsonl_path": self.record_jsonl_path,
            "statistics_json_path": self.statistics_json_path,
        }