import re
import sys
import os.path

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
import numpy as np
from tqdm import tqdm
import random
from abc import ABC, abstractmethod
import json
import api_utils as utils
import generator
import prompt_optimization.prompts as prompts


class PromptOptimizer(ABC):
    """Abstract base for prompt optimizers.

    Holds the run options and the evaluation callables supplied at
    construction time; concrete subclasses provide the candidate
    expansion step.
    """

    def __init__(self, args, evaluator_fn, scorer, max_threads=1, bf_eval=None):
        """Store the configuration and evaluation helpers on the instance.

        Args:
            args: Run options; kept as ``self.opt``.
            evaluator_fn: Kept as ``self.evaluator_fn``.
            scorer: Kept as ``self.scorer``.
            max_threads: Kept as ``self.max_threads`` (defaults to 1).
            bf_eval: Kept as ``self.bf_eval`` (defaults to None).
        """
        self.opt, self.evaluator_fn, self.scorer = args, evaluator_fn, scorer
        self.max_threads, self.bf_eval = max_threads, bf_eval

    @abstractmethod
    def expand_candidates(self, prompts, task, gpt4, train_exs, attribute_cache):
        """Produce new candidate prompts; every concrete subclass must override this."""


class ProTeGi(PromptOptimizer):
    """ ProTeGi: Prompt Optimization with Textual Gradients
    """

    def _sample_error_str(self, true_exs, false_exs, preds, task, true_attrs, false_attrs, n=4):
        """ Sample n error strings from the given texts, labels, and preds"""
        # samples a few examples where the model's predictions do not match the true labels
        # then generates a formatted error string that summarizes these mismatches for easier analysis
        error_idxs = []
        for i in range(len(preds)):
            if preds[i] == 0:
                error_idxs.append(i)

        sample_idxs = random.sample(error_idxs, min(len(error_idxs), n))

        sample_trues = [true_exs[i] for i in sample_idxs]
        sample_falses = [false_exs[i] for i in sample_idxs]
        sample_true_attrs = [true_attrs[i] for i in sample_idxs]
        sample_false_attrs = [false_attrs[i] for i in sample_idxs]
        sample_true_paths = [true_exs[i]['img_path'] for i in sample_idxs]
        sample_false_paths = [false_exs[i]['img_path'] for i in sample_idxs]

        error_string = ''
        error_idx = 0
        for i, (t, f, ta, fa) in enumerate(zip(sample_trues, sample_falses, sample_true_attrs, sample_false_attrs)):
            information = f'## Image {error_idx + 1}\nCorrect Description:\n\"{ta.strip()}\"\n\nWrong Description:\n\"{fa.strip()}\"\n\n'
            error_string += information
            error_idx += 1
        return error_string.strip(), sample_true_paths

    def parse_tagged_text(self, text, start_tag, end_tag):
        """ Parse text that is tagged with start and end tags."""
        texts = []
        while True:
            start_index = text.find(start_tag)
            if start_index == -1:
                break
            end_index = text.find(end_tag, start_index)
            if end_index == -1:
                break
            start_index += len(start_tag)
            select = text[start_index:end_index].strip().strip('"').strip('`')
            if select != "` and `" and select != "and":
                texts.append(text[start_index:end_index].strip())
            text = text[end_index + len(end_tag):]
        texts = list(set(texts))
        return texts

    def _get_gradients(self, prompt, error_string, img_paths, num_feedbacks=5, n=1):
        """Ask the gradient model why `prompt` produced the sampled errors.

        The request is the task-specific preamble from the ``prompts`` module
        followed by the current prompt and the formatted error examples. The
        backend is chosen from ``self.opt['gradient_model']``: names containing
        'gpt4o' use ``utils.gpt4o``, names containing 'sglang' use
        ``utils.sglang_model``, anything else uses ``utils.google_gemini``.
        The raw response is appended to the log file ``self.opt['out']``.

        Args:
            prompt: Current task prompt being critiqued.
            error_string: Formatted error examples embedded in the request.
            img_paths: Image paths sent with the request; may be empty.
            num_feedbacks: Accepted for interface compatibility; not used.
            n: Forwarded as ``n`` to ``utils.gpt4o`` only.

        Returns:
            The backend response, unchanged.

        Raises:
            Exception: If ``self.opt['task_name']`` is not a supported task.
        """
        supported_tasks = (
            'iNat_butterfly', 'iNat_grass', 'CUB_cuckoo', 'CUB_vireo',
            'CUB_oriole', 'Stanford_terrier', 'vegfru_1', 'vegfru_2',
        )
        task_name = self.opt['task_name']
        if task_name not in supported_tasks:
            raise Exception(f"Unsupported task: {task_name}")
        # Each supported task has a same-named preamble in the prompts module.
        gradient_prompt = getattr(prompts, task_name)

        gradient_prompt += f"""
        My current prompt is:
        "{prompt}"

        But this prompt generates image descriptions that are too simple, similar and vague, making it difficult to distinguish which description correctly matches the image and leading to the wrong description being chosen for the following examples:
        {error_string}

        Give a reasons why the prompt could have gotten these examples wrong.
        """
        # Drop the source-code indentation from every line of the request.
        gradient_prompt = '\n'.join([line.lstrip() for line in gradient_prompt.split('\n')])

        model_name = self.opt['gradient_model']
        if len(img_paths) > 0:
            if 'gpt4o' in model_name:
                res = utils.gpt4o(gradient_prompt, img_paths, n=n)
            elif 'sglang' in model_name:
                res = utils.sglang_model(gradient_prompt, img_paths=img_paths, model_name=model_name)
            else:
                res = utils.google_gemini(gradient_prompt, img_paths)
        else:
            print('No image!')
            if 'gpt4o' in model_name:
                res = utils.gpt4o(gradient_prompt, n=n)
            elif 'sglang' in model_name:
                # Keep an sglang model on the sglang backend instead of falling
                # through to Gemini; img_paths=None is the same text-only call
                # form generate_synonyms uses.
                res = utils.sglang_model(gradient_prompt, img_paths=None, model_name=model_name)
            else:
                res = utils.google_gemini(gradient_prompt)

        with open(self.opt['out'], 'a') as outf:
            outf.write('feedbacks: ' + json.dumps(res) + '\n')
        return res

    def apply_gradient(self, prompt, feedback_str, error_str, img_paths, steps_per_gradient, n=1):
        """Ask the gradient model for improved prompts that address a feedback.

        The request is the task-specific preamble from the ``prompts`` module
        followed by the current prompt, the error examples, the feedback, and
        an instruction to wrap each new prompt in <START> / <END>. The backend
        is chosen from ``self.opt['gradient_model']`` ('gpt4o', 'sglang', or
        Gemini otherwise). Parsed prompts are appended to the log file
        ``self.opt['out']``; the raw response is logged when none were found.

        Args:
            prompt: Current task prompt to improve.
            feedback_str: Explanation of what is wrong with the prompt.
            error_str: Formatted error examples embedded in the request.
            img_paths: Image paths sent with the request; may be empty.
            steps_per_gradient: Number of improved prompts requested.
            n: Forwarded as ``n`` to ``utils.gpt4o`` only.

        Returns:
            A list of new prompt strings extracted from the response; empty
            when no <START>...<END> span was found.

        Raises:
            Exception: If ``self.opt['task_name']`` is not a supported task.
        """
        supported_tasks = (
            'iNat_butterfly', 'iNat_grass', 'CUB_cuckoo', 'CUB_vireo',
            'CUB_oriole', 'Stanford_terrier', 'vegfru_1', 'vegfru_2',
        )
        task_name = self.opt['task_name']
        if task_name not in supported_tasks:
            raise Exception(f"Unsupported task: {task_name}")
        # Each supported task has a same-named preamble in the prompts module.
        transformation_prompt = getattr(prompts, task_name)

        transformation_prompt += f"""
        My current prompt is:
        "{prompt}"

        But this prompt generates image descriptions that make it difficult to distinguish which description correctly matches the image, leading to the wrong description being chosen for the following examples:
        {error_str}

        Based on these examples the problem with this prompt is that {feedback_str}

        Based on the above information, I wrote {steps_per_gradient} different improved prompts.
        Each prompt is wrapped with <START> and <END>.

        The {steps_per_gradient} new prompts are:
        """
        # Drop the source-code indentation from every line of the request.
        transformation_prompt = '\n'.join([line.lstrip() for line in transformation_prompt.split('\n')])

        model_name = self.opt['gradient_model']
        if len(img_paths) > 0:
            if 'gpt4o' in model_name:
                res = utils.gpt4o(transformation_prompt, img_paths, n=n)
            elif 'sglang' in model_name:
                res = utils.sglang_model(transformation_prompt, img_paths=img_paths, model_name=model_name)
            else:
                res = utils.google_gemini(transformation_prompt, img_paths)
        else:
            print('No image!')
            if 'gpt4o' in model_name:
                res = utils.gpt4o(transformation_prompt, n=n)
            elif 'sglang' in model_name:
                # Keep an sglang model on the sglang backend instead of falling
                # through to Gemini; img_paths=None is the same text-only call
                # form generate_synonyms uses.
                res = utils.sglang_model(transformation_prompt, img_paths=None, model_name=model_name)
            else:
                res = utils.google_gemini(transformation_prompt)

        new_prompts = []
        for r in res:
            new_prompts += self.parse_tagged_text(r, "<START>", "<END>")

        with open(self.opt['out'], 'a') as outf:
            for p in new_prompts:
                outf.write('new prompts: ' + json.dumps(p) + '\n')
            if not new_prompts:
                # Keep the raw response so an unparseable answer can be inspected.
                outf.write('new prompts: None. Dumping res:' + json.dumps(res) + '\n')
        return new_prompts

    def generate_synonyms(self, prompt_section, n=3):
        """Request n paraphrases of a prompt section from the gradient model.

        'gpt4o' models get a single call with ``n=n``; 'sglang' and Gemini
        backends are called n times and the first item of each response is
        kept. Falsy responses are dropped. A leading "## Variation:" header,
        surrounding whitespace, double quotes and backticks are removed from
        each remaining response. Every kept variation is appended to the log
        file ``self.opt['out']``.

        Args:
            prompt_section: Instruction text to paraphrase.
            n: Number of variations to request.

        Returns:
            A list of cleaned variation strings (may hold fewer than n items).
        """
        rewriter_prompt = f"Generate one variation of the following instruction while keeping the semantic meaning.\n\nInput: {prompt_section}\n\nOutput: ## Variation:\n\n"
        model_name = self.opt['gradient_model']
        if 'gpt4o' in model_name:
            new_instructions = utils.gpt4o(rewriter_prompt, n=n)
        elif 'sglang' in model_name:
            new_instructions = [
                utils.sglang_model(rewriter_prompt, img_paths=None, model_name=model_name)[0]
                for _ in range(n)
            ]
        else:
            new_instructions = [utils.google_gemini(rewriter_prompt)[0] for _ in range(n)]

        new_instructions = [re.sub(r"^## Variation:\s*", "", x).strip().strip('"').strip('`')
                            for x in new_instructions if x]

        # Open the log once for the whole batch, and only when there is
        # something to write, rather than once per synonym.
        if new_instructions:
            with open(self.opt['out'], 'a') as outf:
                outf.writelines('synonyms: ' + json.dumps(x) + '\n' for x in new_instructions)
        return new_instructions

    def get_gradients(self, task_section, task, true_exs, false_exs, preds, true_attrs, false_attrs):
        """Collect feedback for a task prompt over several sampled error sets.

        Runs ``self.opt['n_gradients']`` rounds. Each round samples up to
        ``self.opt['errors_per_gradient']`` errors via ``_sample_error_str``
        and passes them to ``_get_gradients``.

        Returns:
            A list of ``(feedback, error_string, img_paths)`` tuples, one per
            feedback item returned in each round.
        """
        collected = []
        rounds = tqdm(range(self.opt['n_gradients']), total=self.opt['n_gradients'], desc='gradients..')
        for _ in rounds:
            err_text, paths = self._sample_error_str(
                true_exs, false_exs, preds, task, true_attrs, false_attrs,
                n=self.opt['errors_per_gradient'],
            )
            feedback_items = self._get_gradients(
                task_section, err_text, paths, self.opt['gradients_per_error'], n=1,
            )
            # Pair every feedback with the errors and images that produced it.
            collected.extend((item, err_text, paths) for item in feedback_items)
        return collected

    def expand_candidates(self, prompts, task, gpt4, train_exs, pred_prompts=None, attribute_cache=None,
                          pred_prompt=None):
        """Expand a list of prompts with gradient-based successors and synonyms.

        For each prompt the task section is evaluated on a random minibatch,
        feedback ("gradients") is gathered on the errors, improved task
        sections are generated from that feedback, and paraphrases are sampled
        for the improved sections plus the original one. Each distinct new
        section is substituted into the prompt. When more than
        ``self.opt['max_expansion_factor']`` candidates result, they are cut
        down: if ``self.opt['reject_on_errors']`` is set, by scoring them with
        ``self.bf_eval`` on sampled error examples and keeping the highest
        scores; otherwise by uniform random sampling.

        Args:
            prompts: Prompts to expand.
            task: Object providing ``compare_evaluate``.
            gpt4: Forwarded to ``self.bf_eval``.
            train_exs: Training examples the minibatch is drawn from.
            pred_prompts: Accepted for interface compatibility; not used.
            attribute_cache: Forwarded to ``task.compare_evaluate``; when not
                None, a temporary attribute cache is generated for the
                candidates before scoring.
            pred_prompt: Accepted for interface compatibility; not used.

        Returns:
            A duplicate-free list holding the new candidates followed by the
            original prompts.
        """
        # Cap the minibatch at the data size; random.sample raises ValueError
        # when asked for more items than the population holds.
        minibatch = random.sample(train_exs, k=min(len(train_exs), self.opt['minibatch_size']))
        max_expansion = self.opt['max_expansion_factor']

        new_prompts = []
        for prompt in tqdm(prompts, desc=f'expanding {len(prompts)} prompts'):
            sections = utils.parse_sectioned_prompt(prompt)
            task_section = sections['task'].strip()

            true_exs, false_exs, preds, true_attrs, false_attrs = task.compare_evaluate(
                prompt, minibatch, attribute_cache, model_name=self.opt['model'])

            # gradient-based successors
            new_task_sections = []
            if self.opt['n_gradients'] > 0:
                gradients = self.get_gradients(task_section, task, true_exs, false_exs, preds,
                                               true_attrs, false_attrs)
                for feedback, error_string, img_paths in tqdm(gradients, desc='applying gradients'):
                    new_task_sections += self.apply_gradient(task_section, feedback, error_string, img_paths,
                                                             self.opt['steps_per_gradient'])

            # generate synonyms
            mc_sampled_task_sections = []
            if self.opt['mc_samples_per_step'] > 0:
                for sect in tqdm(new_task_sections + [task_section], desc='mc samples'):
                    mc_sampled_task_sections += self.generate_synonyms(sect, n=self.opt['mc_samples_per_step'])

            # Combine. Sections are stripped before deduplication so the text
            # that is compared is the text that gets substituted; dict.fromkeys
            # keeps a deterministic first-seen order.
            new_sections = list(dict.fromkeys(
                sect.strip() for sect in new_task_sections + mc_sampled_task_sections))
            tmp_new_prompts = []
            for sect in new_sections:
                candidate = prompt.replace(task_section, sect)
                # Skip substitutions that leave the prompt unchanged.
                if candidate != prompt:
                    tmp_new_prompts.append(candidate)
            if len(tmp_new_prompts) != len(new_sections):
                print(f'Invalid replacement with {len(tmp_new_prompts)} != {len(new_sections)}')
                with open(self.opt['out'], 'a') as outf:
                    outf.write(f'{len(tmp_new_prompts)}, {len(new_sections)}: ' + json.dumps(task_section) + '\n')

            # Filter a little. The check uses the list that is actually sampled
            # from, so random.sample below can never be asked for more
            # candidates than exist.
            if len(tmp_new_prompts) > max_expansion:
                if self.opt['reject_on_errors']:
                    error_exs = [{'id': i, 'img_path': true_exs[i]['img_path']}
                                 for i, pred in enumerate(preds) if pred == 0]
                    error_exs = random.sample(error_exs, min(len(error_exs), 16))
                    # speed up a little: score at most twice the final count
                    tmp_new_prompts = random.sample(tmp_new_prompts,
                                                    min(len(tmp_new_prompts), max_expansion * 2))
                    if attribute_cache is not None:
                        print('Generating temp attr cache:', len(tmp_new_prompts))
                        temp_attr_cache = {}
                        for temp_prompt in tmp_new_prompts:
                            temp_attr_cache[f'{temp_prompt}'] = {}
                            temp_attr_cache = generator.parallel_generate(generator.AttrGredictor(self.opt),
                                                                          temp_prompt, error_exs, temp_attr_cache, 8)
                        error_scores = self.bf_eval(tmp_new_prompts, error_exs, gpt4, self.scorer,
                                                    temp_attr_cache, max_threads=self.max_threads)
                    else:
                        error_scores = self.bf_eval(tmp_new_prompts, error_exs, gpt4, self.scorer,
                                                    max_threads=self.max_threads)
                    # argsort is ascending, so the tail holds the best scores.
                    tmp_new_prompts = [tmp_new_prompts[i]
                                       for i in np.argsort(error_scores)[-max_expansion:]]
                else:
                    tmp_new_prompts = random.sample(tmp_new_prompts, k=max_expansion)

            new_prompts += tmp_new_prompts

        new_prompts += prompts  # add originals
        new_prompts = list(dict.fromkeys(new_prompts))  # dedup, order-preserving

        return new_prompts

    def score_candidates(self, prompts, gpt4=None, train_exs=None, pred_prompts=None, attribute_cache=None):
        """ Score a list of prompts."""
        if len(prompts) == 1:
            return [1.0]

        evals = self.evaluator_fn(prompts, train_exs, gpt4,
                                  scorer=self.scorer,
                                  pred_prompts=pred_prompts,
                                  attribute_cache=attribute_cache,
                                  rounds=self.opt['eval_rounds'],
                                  num_prompts_per_round=self.opt['eval_prompts_per_round'],
                                  samples_per_eval=self.opt['samples_per_eval'],
                                  max_threads=self.max_threads)
        return evals
