import ast, logging
import os, re
from typing import List
from genai.prompt_pattern import PromptPattern

from .config import Config
from .constants import *

# Module-level logger shared by the helpers below.
# NOTE(review): LOGGER_NAME is not defined in this file; presumably it comes
# from the star import of `.constants` — confirm.
logger = logging.getLogger(LOGGER_NAME)

class Prompt:
    """Prompt template built from instructions, few-shot examples and an input template.

    The pieces are read from the prompt file via `_load_prompt_contents(config)`
    and joined with blank lines into a single string that is then parsed with
    `PromptPattern.from_str`.

    Attributes set by `__init__`:
        config: the `Config` passed in (read for `few_shot_k`, `agent_type`,
            `dataset`).
        instructions: stripped text under the 'instructions' key ('' if absent).
        examples: value under the 'examples' key ([] if absent).
        initial: text under the 'input' key; must contain '{{input}}'.
        bop_text: first non-empty piece of `initial` around '{{input}}',
            stripped ('' if the template is only the placeholder).
    """

    INSTRUCTIONS_KW, EXAMPLES_KW, INPUT_KW = 'instructions', 'examples', 'input'

    def __init__(self, config : 'Config'):
        """Load the prompt contents for `config` and assemble the prompt.

        Raises:
            AssertionError: if the 'input' entry is missing, does not contain
                the '{{input}}' placeholder, or the assembled prompt is empty.
        """
        self.config = config

        prompt_contents = _load_prompt_contents(config)

        self.instructions = prompt_contents.get(Prompt.INSTRUCTIONS_KW, '').strip()
        self.examples = prompt_contents.get(Prompt.EXAMPLES_KW, [])
        self.initial = prompt_contents.get(Prompt.INPUT_KW, None)

        # NOTE(review): these asserts are stripped under `python -O`; they are
        # kept so the raised type (AssertionError) stays unchanged for callers.
        # The message names the key that is actually looked up ('input').
        assert self.initial is not None, f"Must specify '{Prompt.INPUT_KW}' text!"
        input_var = '{{' + Prompt.INPUT_KW + '}}'
        assert input_var in self.initial, f'Must specify {input_var} in prompt!'

        self.update_prompt()

        self.bop_text = Prompt._first_segment(self.initial, input_var)

        assert self._prompt_str, f'Empty prompt for agent-type {config.agent_type} and dataset {config.dataset}!'

        self._prompt_obj = PromptPattern.from_str(self._prompt_str)

    @staticmethod
    def _first_segment(template, placeholder):
        """Return the first non-empty piece of `template` split on `placeholder`, stripped.

        Returns '' when every piece is empty (the template is nothing but the
        placeholder) instead of raising IndexError.
        """
        return next((piece for piece in template.split(placeholder) if piece), '').strip()

    def update_prompt(self):
        """Rebuild `self._prompt_str` from instructions, examples and the input template.

        Sections are separated by one blank line. At most `config.few_shot_k`
        examples are used; when that selection is empty no examples section
        (and no stray blank lines) is emitted.
        """
        sections = []
        if self.instructions:
            sections.append(self.instructions)
        if self.config.few_shot_k:
            shots = self.examples[:self.config.few_shot_k]
            if shots:
                sections.append('\n\n'.join(_clean_text_content(ex) for ex in shots))
        sections.append(self.initial.strip())
        self._prompt_str = '\n\n'.join(sections)

    def update_examples(self, new_examples : List):
        """Replace the few-shot examples and rebuild the prompt string and pattern."""
        self.examples = new_examples
        self.update_prompt()
        # Keep the parsed pattern in sync with the rebuilt string.
        self._prompt_obj = PromptPattern.from_str(self._prompt_str)

    def fill_prompt(self, **kw_args):
        """Return the prompt text with each keyword substituted via `PromptPattern.sub`.

        A fresh pattern is parsed from `self._prompt_str` on every call, so the
        stored prompt is never modified by the substitutions.
        """
        filled_prompt = PromptPattern.from_str(self._prompt_str)
        for k, v in kw_args.items():
            filled_prompt = filled_prompt.sub(k, v)
        return str(filled_prompt)

def _clean_text_content(text : str):
    """Strip `text` as a whole and each of its lines, keeping in-line tabs.

    Tabs are swapped for a sentinel before the per-line strip so that tabs at
    the edges of a line survive it, then swapped back. Tabs at the very start
    or end of the whole text are removed by the initial strip. A literal
    '[TAB_SEP]' in the input also comes back as a tab.
    """
    placeholder = '[TAB_SEP]'
    shielded = text.strip().replace('\t', placeholder)
    cleaned_lines = []
    for raw_line in shielded.split('\n'):
        trimmed = raw_line.strip()
        cleaned_lines.append(trimmed.replace(placeholder, '\t'))
    return '\n'.join(cleaned_lines)

def _load_prompt_contents(config : 'Config'):
    """Read the prompt file and return the cleaned entry for the configured agent type.

    The file at `config.prompt_file` is parsed with `ast.literal_eval`, which
    only evaluates Python literals, and indexed by agent type. If 'ablation'
    occurs in `config.agent_type`, the last underscore-separated token is
    dropped before the lookup and a warning is logged.

    Every string in the selected entry is passed through `_clean_text_content`;
    dicts and lists/tuples are processed recursively (tuples come back as
    lists); any other leaf (numbers, None, ...) is returned unchanged.

    Raises:
        KeyError: if the prompt file has no entry for the agent type.
        OSError: if the prompt file cannot be opened.
    """
    def _proc_content(f_content):
        if isinstance(f_content, dict):
            return {k: _proc_content(v) for k, v in f_content.items()}
        if isinstance(f_content, (list, tuple)):
            return [_proc_content(v) for v in f_content]
        if isinstance(f_content, str):
            return _clean_text_content(f_content)
        # Non-text leaves have nothing to clean; pass them through rather than
        # failing inside _clean_text_content.
        return f_content

    agent_type = config.agent_type
    if 'ablation' in agent_type:
        new_agent_type = '_'.join(agent_type.split('_')[:-1])
        logger.warning(f'Agent prompt for {agent_type} is being adjusted to {new_agent_type}!')
        agent_type = new_agent_type

    # Explicit encoding so decoding does not depend on the platform default.
    with open(config.prompt_file, 'r', encoding='utf-8') as f:
        all_prompts = ast.literal_eval(f.read())

    try:
        agent_contents = all_prompts[agent_type]
    except KeyError:
        raise KeyError(
            f'No prompt entry for agent type {agent_type!r} in {config.prompt_file}'
        ) from None
    return _proc_content(agent_contents)
