from source.language_models.prompt_templates import generation_prompt

class Transformation:
    """Generate new outputs from selected graph nodes by prompting an LLM.

    The transformation collects the ``"information"`` payload of the chosen
    thought and document nodes, builds a prompt with ``generation_prompt`` and
    asks ``LLM.query`` for ``output_size`` results.

    Args:
        query: The query passed to ``generation_prompt``.
        input_thoughts: Node ids of thoughts to include. Id ``0`` is skipped
            (presumably the root/query node -- TODO confirm with graph builder).
        input_documents: Node ids of documents to include.
        output_size: Number of results requested from ``LLM.query``.
        model: Optional model identifier forwarded to ``generation_prompt``.
        temperature, top_p, top_k, repetition_penalty, max_new_tokens:
            Sampling settings forwarded to ``LLM.query``. Keyword-only; the
            defaults are the values this class has always used.
    """

    def __init__(self, query, input_thoughts, input_documents, output_size, model=None,
                 *, temperature=0.95, top_p=0.73, top_k=0,
                 repetition_penalty=1.4, max_new_tokens=100):
        self.input_thoughts = input_thoughts
        self.input_documents = input_documents
        self.output_size = output_size
        self.query = query
        self.model = model

        # Sampling configuration forwarded verbatim to LLM.query.
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.repetition_penalty = repetition_penalty
        self.max_new_tokens = max_new_tokens

    def apply(self, LLM, graph):
        """Build the prompt from ``graph`` and query ``LLM`` with it.

        Args:
            LLM: Object exposing ``query(prompt, n, **sampling_kwargs)`` that
                returns a pair ``(prompt_list, output_list)``.
            graph: Graph whose ``nodes[id]`` mapping holds an ``"information"``
                entry (looks networkx-like -- TODO confirm).

        Returns:
            The ``(prompt_list, output_list)`` pair returned by ``LLM.query``.

        Raises:
            Exception: If the LLM call (or unpacking its result) fails. The
                exception's args carry the prompt for debugging, and the
                original error is chained as ``__cause__``.
        """
        thoughts, documents = self.fetch_content(graph)
        prompt = generation_prompt(self.query, thoughts, documents, self.model)
        try:
            prompt_list, output_list = LLM.query(
                prompt,
                self.output_size,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                repetition_penalty=self.repetition_penalty,
                max_new_tokens=self.max_new_tokens,
            )
        except Exception as err:
            # Surface the offending prompt while keeping the real error as the
            # cause; catching Exception (not bare except) lets Ctrl-C and
            # SystemExit propagate untouched.
            raise Exception("prompt : ", prompt) from err
        return prompt_list, output_list

    def fetch_content(self, graph):
        """Return the ``"information"`` of the input thought and document nodes.

        Thought id ``0`` is excluded; all document ids are included.

        Returns:
            A tuple ``(thoughts, documents)`` of lists, in input order.
        """
        thoughts = [graph.nodes[t]["information"] for t in self.input_thoughts if t != 0]
        documents = [graph.nodes[d]["information"] for d in self.input_documents]
        return thoughts, documents
    

    