import re
import sqlite3
import time
from contextlib import closing

import numpy as np

from automatic_prompt_engineer import data, llm, evaluate
from experiments.evaluation.instruction_induction import utility

def execute_query(query, db_path, task, max_retry=10, retry_delay=0.5):
    """Fetch stored outputs for ``query`` from the SQLite database at ``db_path``.

    Rows are read from table ``sm`` for the 'svamp' and 'multiarith' tasks and
    from table ``gsm8k`` for every other task.

    Args:
        query: the full prompt text used as the lookup key.
        db_path: path to the SQLite database file.
        task: task name; selects the table.
        max_retry: number of attempts made while the database reports it is locked.
        retry_delay: seconds to sleep between locked attempts.

    Returns:
        The list of ``(output,)`` rows matching ``query`` (possibly empty), or
        ``[]`` if the database stayed locked for all ``max_retry`` attempts.

    Raises:
        sqlite3.OperationalError: for any operational error other than
            'database is locked'. Other sqlite3 errors propagate unchanged.
    """
    if task in ["svamp", "multiarith"]:
        sql = "SELECT output FROM sm WHERE query=?"
    else:
        sql = "SELECT output FROM gsm8k WHERE query=?"

    for _ in range(max_retry):
        try:
            # ``with sqlite3.connect(...)`` only manages the transaction; it does
            # not close the connection. closing() releases the file handle on
            # every attempt instead of leaving it to the garbage collector.
            with closing(sqlite3.connect(db_path)) as conn:
                return conn.execute(sql, (query,)).fetchall()
        except sqlite3.OperationalError as e:
            if 'database is locked' not in str(e):
                raise
            time.sleep(retry_delay)  # Another writer holds the lock; back off and retry.
    return []

def insert_many(data, db_path, task, max_retry=5, retry_delay=0.5):
    """Write ``(query, output)`` pairs into the SQLite database at ``db_path``.

    Rows go to table ``sm`` for the 'svamp' and 'multiarith' tasks and to table
    ``gsm8k`` otherwise. ``INSERT OR REPLACE`` is used, so a row whose key
    conflicts with an existing one replaces it.

    Args:
        data: iterable of ``(query, output)`` pairs.
        db_path: path to the SQLite database file.
        task: task name; selects the table.
        max_retry: total number of attempts on IntegrityError/OperationalError.
        retry_delay: seconds to sleep between attempts.

    Raises:
        sqlite3.IntegrityError, sqlite3.OperationalError: re-raised from the
            final attempt once all retries are used up.
    """
    if task in ["svamp", "multiarith"]:
        sql = "INSERT OR REPLACE INTO sm (query, output) VALUES (?,?)"
    else:
        sql = "INSERT OR REPLACE INTO gsm8k (query, output) VALUES (?,?)"
    # Materialize once: a generator would be consumed by a failed attempt, and
    # every retry would then silently insert nothing.
    rows = list(data)

    for attempt in range(max_retry):
        try:
            # closing() actually closes the connection; the inner ``with conn``
            # commits on success and rolls back if executemany raises.
            with closing(sqlite3.connect(db_path)) as conn:
                with conn:
                    conn.executemany(sql, rows)
            return
        except (sqlite3.IntegrityError, sqlite3.OperationalError):
            if attempt == max_retry - 1:
                raise  # Out of attempts: surface the last error.
            print(f"Attempt {attempt+1} failed, retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)


def extract_numbers(string):
    """Return every number in ``string``, in order of appearance, as strings.

    Integers and decimals are both matched, each with an optional leading
    ``+``/``-`` sign. A hyphen directly before digits is read as a sign
    (``"10-4"`` -> ``['10', '-4']``).

    >>> extract_numbers("The answer is -5.")
    ['-5']
    >>> extract_numbers("3.5 apples and 12 pears")
    ['3.5', '12']
    """
    # The sign must apply to BOTH integers and decimals. In the previous pattern
    # ``[+-]?\d+\.\d+|\d+`` alternation binds loosest, so integers lost their sign.
    return re.findall(r'[+-]?\d+(?:\.\d+)?', string)


def get_query(prompt, eval_template, input_, output_, demo_data, demos_template):
    """Build a few-shot evaluation query.

    The demonstrations in ``demo_data`` are rendered with ``demos_template``
    and spliced into ``eval_template`` together with ``prompt`` and ``input_``.
    The output slot is left empty so the model completes it; ``output_`` is
    accepted for signature compatibility but is not used.
    """
    rendered_demos = demos_template.fill(demo_data)
    return eval_template.fill(
        prompt=prompt,
        input=input_,
        output='',
        full_demo=rendered_demos,
    )


def get_query_for_test(prompt, eval_template, input_, output_):
    """Build a zero-shot test query: no demonstrations and an empty output slot.

    ``output_`` is accepted to mirror the signature of ``get_query`` but is unused.
    """
    fields = dict(prompt=prompt, input=input_, output='', full_demo='')
    return eval_template.fill(**fields)



def _extract_final_answer(text, task):
    """Pull the final answer out of an answer-extraction completion.

    Numeric tasks (gsm8k, svamp, multiarith) take the first number after
    trailing periods are stripped; 'aqua' takes the first option letter A-E.
    Returns "" when nothing matches.
    """
    if task in ["gsm8k", "svamp", "multiarith"]:
        candidates = extract_numbers(text.rstrip('.'))
    elif task == "aqua":
        candidates = re.findall(r'A|B|C|D|E', text)
    else:
        raise ValueError(f'Unsupported task for exec accuracy evaluation: {task!r}')
    return candidates[0] if candidates else ""


def _answer_matches(pred, gold, task):
    """Return True when ``pred`` equals ``gold`` for ``task``.

    For svamp/multiarith, both sides are compared as floats when both parse
    (so "8" matches "8.0"). Otherwise, and for all other tasks, the raw values
    are compared.
    """
    if task in ["svamp", "multiarith"]:
        try:
            return float(pred) == float(gold)
        except (TypeError, ValueError):
            pass
    return pred == gold


def exec_accuracy_evaluator(prompts, eval_template, eval_data, demos_template, few_shot_data, config):
    """Accuracy of a single prompt on an arithmetic / multiple-choice reasoning task.

    Each subsampled example becomes a few-shot query. Queries already stored in
    the SQLite database 'gsm8k.db' reuse their stored prediction. The remaining
    queries go through two generation passes:
      1. the model completes the query;
      2. the completion plus an answer trigger is sent back so the model states
         a final answer.
    That final answer is parsed and written back to the database.

    Args:
        prompts: sequence with at most one prompt.
        eval_template: template whose ``fill`` builds the query.
        eval_data: evaluation pool passed to ``data.subsample_data``.
        demos_template: template that renders the few-shot demonstrations.
        few_shot_data: demonstration pool passed to ``data.subsample_data``.
        config: mapping with 'task' (gsm8k, svamp, multiarith or aqua),
            'num_samples', 'num_few_shot' and 'model'.

    Returns:
        ``(accuracy, per_sample)``: the fraction of correct answers and a 0/1
        numpy array. Freshly generated samples come first, followed by
        database hits.

    Raises:
        ValueError: if more than one prompt is given, the task is unsupported,
            or no samples were evaluated.
    """
    task = config['task']
    if len(prompts) > 1:
        raise ValueError('Only one prompt is supported for exec accuracy evaluation.')
    if task in ["gsm8k", "svamp", "multiarith"]:
        answer_trigger = "\nTherefore, the answer (arabic numerals) is"
    elif task == "aqua":
        answer_trigger = "\nTherefore, among A through E, the answer is"
    else:
        raise ValueError(f'Unsupported task for exec accuracy evaluation: {task!r}')

    queries, answers = [], []                # not in the database: must be generated
    cached_preds, cached_answers = [], []    # found in the database
    for prompt in prompts:
        subsampled_data = data.subsample_data(eval_data, config['num_samples'])
        for input_, output_ in zip(*subsampled_data):
            demo_data = data.subsample_data(few_shot_data, config['num_few_shot'])
            query = get_query(prompt, eval_template, input_, output_, demo_data, demos_template)
            cached = execute_query(query, 'gsm8k.db', task)
            # NOTE(review): output_ is assumed to be a list whose first element is
            # the gold answer — confirm against data.subsample_data.
            if cached:
                cached_preds.append(cached[0][0])
                cached_answers.append(output_[0])
            else:
                queries.append(query)
                answers.append(output_[0])

    pred_answers = []
    if queries:  # Skip building/calling the model when everything was in the database.
        model = llm.model_from_config(config['model'])
        model_outputs = model.generate_text(queries, 1)
        extract_queries = [query + output + answer_trigger
                           for query, output in zip(queries, model_outputs)]
        extract_outputs = model.generate_text(extract_queries, 1)
        pred_answers = [_extract_final_answer(text, task) for text in extract_outputs]
        insert_many(list(zip(queries, pred_answers)), 'gsm8k.db', task)

    result = [int(_answer_matches(pred, gold, task))
              for pred, gold in zip(pred_answers, answers)]
    result += [int(_answer_matches(pred, gold, task))
               for pred, gold in zip(cached_preds, cached_answers)]
    if not result:
        raise ValueError('No samples were evaluated; check num_samples and eval_data.')
    return sum(result) / len(result), np.array(result)
        

def exec_accuracy_evaluator_vicuna(prompts, eval_template, eval_data, demos_template, few_shot_data, config):
    """Score ``prompts`` with a model object supplied through ``config``.

    For every prompt, ``config['num_samples']`` examples are subsampled from
    ``eval_data``, and each example gets its own ``config['num_few_shot']``
    demonstrations. All queries are generated in one ``generate_text`` call.
    The outputs are scored with the metric that ``utility.TASK_TO_METRIC``
    assigns to ``config['task']``.

    Returns:
        ``(ExecAccuracyEvaluationResult, scores)``, where ``scores`` is an array
        of shape ``(len(prompts), config['num_samples'])``.

    Raises:
        ValueError: if the task's metric has no known scoring function.
    """
    queries = []
    answers = []
    for prompt in prompts:
        subsampled_data = data.subsample_data(eval_data, config['num_samples'])
        for input_, output_ in zip(*subsampled_data):
            demo_data = data.subsample_data(few_shot_data, config['num_few_shot'])
            queries.append(get_query(
                prompt, eval_template, input_, output_, demo_data, demos_template))
            answers.append(output_)

    # NOTE(review): unlike exec_accuracy_evaluator, the model is not built from
    # the config here. config['model']['gpt_config']['model'] is assumed to
    # already be an object exposing ``generate_text`` — confirm with the caller.
    model = config['model']['gpt_config']['model']
    model_outputs = model.generate_text(queries, 1)

    task = config['task']
    metric = utility.TASK_TO_METRIC.get(task, utility.default_metric)

    print(f'Using metric "{metric}" for task "{task}"...')

    if metric == 'f1':
        score_fn = utility.get_multi_answer_f1
    elif metric == 'es':
        score_fn = utility.get_multi_answer_exact_set
    elif metric == 'contains':
        score_fn = utility.get_multi_answer_contains
    elif metric == 'em':
        score_fn = utility.get_multi_answer_em
    else:
        # Without this branch score_fn would be unbound and fail with UnboundLocalError.
        raise ValueError(f'Unsupported metric "{metric}" for task "{task}"')

    scores = [score_fn(prediction, ans_) for prediction, ans_ in zip(model_outputs, answers)]

    # Reshape the scores so that it is num_prompts x num_samples
    scores = np.array(scores).reshape(len(prompts), config['num_samples'])

    res = ExecAccuracyEvaluationResult(prompts, scores)
    return res, scores

class exec_evaluator(object):
    """Evaluate prompts by running them through a locally hosted model.

    The model wrapper is selected once, in ``__init__``, from ``api_model``.
    Both ``evaluate`` and ``test`` build every query up front, generate all
    completions in a single ``generate_text`` call, and score them with the
    metric that ``utility.TASK_TO_METRIC`` assigns to the task.
    """

    # Metric name -> name of the scoring function in ``utility``. Looked up by
    # name so an unknown metric is reported before ``utility`` is touched.
    _METRIC_TO_SCORE_FN = {
        'f1': 'get_multi_answer_f1',
        'es': 'get_multi_answer_exact_set',
        'contains': 'get_multi_answer_contains',
        'em': 'get_multi_answer_em',
    }

    def __init__(self, api_model, config):
        """Instantiate the model wrapper named by ``api_model``.

        Args:
            api_model: 'llama' or 'flan-t5'.
            config: passed unchanged to the model wrapper's constructor.

        Raises:
            ValueError: for any other ``api_model``, instead of leaving
                ``self.model`` unset and failing later with AttributeError.
        """
        if api_model == 'llama':
            self.model = llm.Llama_Forward(config)
        elif api_model == 'flan-t5':
            self.model = llm.Flan_T5(config)
        else:
            raise ValueError(
                f"Unsupported api_model {api_model!r}; expected 'llama' or 'flan-t5'")

    @classmethod
    def _score_fn_for(cls, metric):
        """Return the ``utility`` scoring function registered for ``metric``.

        Raises:
            ValueError: if ``metric`` has no known scoring function.
        """
        fn_name = cls._METRIC_TO_SCORE_FN.get(metric)
        if fn_name is None:
            raise ValueError(f'Unsupported metric "{metric}"')
        return getattr(utility, fn_name)

    def _score_outputs(self, model_outputs, answers, task):
        """Score each model output against its gold answer(s) with ``task``'s metric."""
        metric = utility.TASK_TO_METRIC.get(task, utility.default_metric)
        print(f'Using metric "{metric}" for task "{task}"...')
        score_fn = self._score_fn_for(metric)
        return [score_fn(prediction, ans_) for prediction, ans_ in zip(model_outputs, answers)]

    def evaluate(self, prompts, eval_template, eval_data, demos_template, few_shot_data, config,
                 num_repeats=20):
        """Score ``prompts[0]`` on ``num_repeats`` independent few-shot subsamples.

        Only the first prompt is used. It is repeated ``num_repeats`` times; each
        repetition draws its own subsample of ``config['num_samples']`` examples,
        and each example gets its own ``config['num_few_shot']`` demonstrations.

        Returns:
            ExecAccuracyEvaluationResult whose scores have shape
            ``(num_repeats, config['num_samples'])``.
        """
        prompts = [prompts[0]] * num_repeats
        queries = []
        answers = []
        for prompt in prompts:
            subsampled_data = data.subsample_data(eval_data, config['num_samples'])
            for input_, output_ in zip(*subsampled_data):
                demo_data = data.subsample_data(few_shot_data, config['num_few_shot'])
                queries.append(get_query(
                    prompt, eval_template, input_, output_, demo_data, demos_template))
                answers.append(output_)

        model_outputs = self.model.generate_text(queries, 1)
        scores = self._score_outputs(model_outputs, answers, config['task'])

        # Reshape the scores so that it is num_prompts x num_samples
        scores = np.array(scores).reshape(len(prompts), config['num_samples'])
        return ExecAccuracyEvaluationResult(prompts, scores)

    def test(self, prompts, eval_template, eval_data, config):
        """Zero-shot test-set evaluation of ``prompts[0]``.

        The prompt is repeated ``config['evaluation']['num_samples']`` times.
        Each repetition is scored on its own subsample of that many examples,
        with no demonstrations in the query.

        Returns:
            ExecAccuracyEvaluationResult whose scores have shape
            ``(num_samples, num_samples)``.
        """
        num_samples = config['evaluation']['num_samples']
        # NOTE(review): this issues num_samples * num_samples queries; confirm the
        # repetition count is meant to equal the per-repetition sample count.
        prompts = [prompts[0]] * num_samples
        queries = []
        answers = []
        for prompt in prompts:
            subsampled_data = data.subsample_data(eval_data, num_samples)
            for input_, output_ in zip(*subsampled_data):
                queries.append(get_query_for_test(prompt, eval_template, input_, output_))
                answers.append(output_)

        model_outputs = self.model.generate_text(queries, 1)
        scores = self._score_outputs(model_outputs, answers, config['evaluation']['task'])

        # Reshape the scores so that it is num_prompts x num_samples
        scores = np.array(scores).reshape(len(prompts), num_samples)
        return ExecAccuracyEvaluationResult(prompts, scores)

class ExecAccuracyEvaluationResult(evaluate.EvaluationResult):
    """Per-prompt score matrix produced by the execution-accuracy evaluators.

    ``scores`` holds one row of per-sample scores for each entry of
    ``prompts``. Each row is reduced to a single number by ``_agg_scores``.
    """

    def __init__(self, prompts, scores):
        self.prompts = prompts  # prompts, aligned with the rows of ``scores``
        self.scores = scores  # one row of per-sample scores per prompt

    def _agg_scores(self, method):
        """Reduce each prompt's row of scores with the statistic named by ``method``.

        Supported names: 'mean', 'median', 'std', 'max', 'min' and 'iqm'. Note
        that 'iqm' here is the average of the 25th and 75th percentiles (the
        midhinge), not the mean of the values lying between them.

        Raises:
            ValueError: if ``method`` is not a supported name.
        """
        reducers = (
            ('mean', np.mean),
            ('median', np.median),
            ('std', np.std),
            ('max', np.max),
            ('min', np.min),
            ('iqm', lambda row: np.mean(np.percentile(row, [25, 75]))),
        )
        for name, reduce_row in reducers:
            if method == name:
                return [reduce_row(row) for row in self.scores]
        raise ValueError('Invalid method: {}'.format(method))

    def sorted(self, method='default'):
        """Return ``(prompts, scores)`` ordered from highest to lowest aggregate score.

        ``method='default'`` aggregates with the mean.
        """
        scores = self._agg_scores('mean' if method == 'default' else method)
        ranked_pairs = sorted(zip(scores, self.prompts))
        sorted_prompts = [prompt for _, prompt in ranked_pairs][::-1]
        sorted_scores = sorted(scores)[::-1]
        return sorted_prompts, sorted_scores

    def in_place(self, method='default'):
        """Return ``(prompts, scores)`` in their original order, aggregated per prompt.

        ``method='default'`` aggregates with the mean.
        """
        aggregation = 'mean' if method == 'default' else method
        return self.prompts, self._agg_scores(aggregation)
