import numpy as np
import re

from automatic_prompt_engineer import data, llm, evaluate
from evaluation.instruction_induction import utility

from tqdm import trange, tqdm
import concurrent.futures
import time


# Model handle read by `mp_helper_wrap` inside process-pool workers.
# `exec_accuracy_evaluator` assigns it just before starting the pool when
# parallel evaluation is requested.
# NOTE(review): a module global only reaches workers that inherit the parent's
# memory (the "fork" start method); under "spawn" workers re-import this module
# and would see None here -- confirm the start method used in practice.
EVALUATION_LLM_MODEL = None


def get_query(prompt, eval_template, input_, demo_data, demos_template):
    """Build one few-shot evaluation query for ``prompt`` on a single input.

    ``demos_template`` (placeholders ``[INPUT]`` / ``[OUTPUT]``) is first filled
    with ``demo_data`` to render the demonstrations.  ``eval_template`` is then
    filled with:

    - ``[PROMPT]``    <- ``prompt``
    - ``[full_DEMO]`` <- the rendered demonstrations
    - ``[INPUT]``     <- ``input_``
    - ``[OUTPUT]``    <- an empty string, left for the model to produce
    """
    rendered_demos = demos_template.fill(demo_data)
    return eval_template.fill(
        prompt=prompt,
        input=input_,
        output='',
        full_demo=rendered_demos,
    )


def get_query_for_test(prompt, eval_template, input_, output_):
    """Build a zero-shot test query: no demonstrations and an empty output slot.

    ``output_`` is accepted for call-site symmetry with the data loops but is
    not used when building the query.
    """
    fill_kwargs = dict(prompt=prompt, input=input_, output='', full_demo='')
    return eval_template.fill(**fill_kwargs)


######################################################
def extract_output_based_on_template(result_str, re_pattern=r' *[\n\.\?!\n][\'"\)\]]* *'):
    """Pull the 'Output:' sentence out of a free-form model response.

    ``result_str`` is split into sentences with ``re_pattern``.  By default the
    split happens on line breaks and on '.', '?' or '!', together with any
    trailing quotes or closing brackets and surrounding spaces.

    The first sentence that starts with 'Output:' (ignoring spaces and
    newlines) is selected.  If none does, the whole ``result_str`` is used.
    Every 'Output:' occurrence is then removed from the selection and the
    remainder is returned; leading whitespace is not stripped.

    The sentence list is printed as a debugging aid.
    """
    marker = 'Output:'
    sentences = re.split(re_pattern, result_str)
    print("[Sentence list]: ", sentences)

    selected = next(
        (candidate for candidate in sentences
         if candidate.replace(' ', '').replace('\n', '').startswith(marker)),
        result_str,
    )
    return selected.replace(marker, '')

def extract_output_based_on_template_output_enclosure(result_str, re_pattern=r'<output>(.*?)</output>'):
    """Return the text of every ``<output>...</output>`` enclosure, space-joined.

    With the default pattern, matching is non-greedy, so each enclosure is
    captured separately.  ``re.DOTALL`` lets an enclosure span several lines.
    Returns '' when nothing matches.
    """
    return ' '.join(re.findall(re_pattern, result_str, re.DOTALL))
    
######################################################


def exec_accuracy_evaluator_vicuna(prompts, eval_template, eval_data, demos_template, few_shot_data, config):
    """Score prompts by execution accuracy using the model stored in ``config``.

    For each prompt, ``config['num_samples']`` examples are drawn from
    ``eval_data`` via ``data.subsample_data``.  For each example, a fresh set of
    ``config['num_few_shot']`` demonstrations is drawn from ``few_shot_data``.
    All queries are sent in a single ``generate_text(queries, 1)`` call to
    ``config['model']['gpt_config']['model']``.  Each prediction is then scored
    against its reference answer with the task's metric
    (``utility.TASK_TO_METRIC``, falling back to ``utility.default_metric``).

    Returns:
        ``(ExecAccuracyEvaluationResult, scores)`` where ``scores`` is a NumPy
        array of shape ``(len(prompts), config['num_samples'])``.

    Raises:
        ValueError: if the task's metric is not 'f1', 'es', 'contains' or 'em'.
    """
    queries = []
    answers = []
    for prompt in prompts:
        subsampled_data = data.subsample_data(eval_data, config['num_samples'])
        for input_, output_ in zip(*subsampled_data):
            demo_data = data.subsample_data(few_shot_data, config['num_few_shot'])
            queries.append(get_query(prompt, eval_template, input_, demo_data, demos_template))
            answers.append(output_)

    # The model object comes ready-made from the config; nothing is constructed here.
    model = config['model']['gpt_config']['model']
    model_outputs = model.generate_text(queries, 1)

    task = config['task']
    metric = utility.TASK_TO_METRIC.get(task, utility.default_metric)

    print(f'Using metric "{metric}" for task "{task}"...')

    if metric == 'f1':
        score_fn = utility.get_multi_answer_f1
    elif metric == 'es':
        score_fn = utility.get_multi_answer_exact_set
    elif metric == 'contains':
        score_fn = utility.get_multi_answer_contains
    elif metric == 'em':
        score_fn = utility.get_multi_answer_em
    else:
        # Without this branch `score_fn` would be unbound below.
        raise ValueError(f'Unsupported metric "{metric}" for task "{task}"')

    scores = [score_fn(prediction, ans_) for prediction, ans_ in zip(model_outputs, answers)]

    # Reshape the scores so that it is num_prompts x num_samples
    scores = np.array(scores).reshape(len(prompts), config['num_samples'])

    res = ExecAccuracyEvaluationResult(prompts, scores)
    return res, scores



################################################################################################################
################################################################################################################
################################################################################################################


def mp_helper_wrap(query):
    """Process-pool worker: ask the module-level evaluation model for one generation.

    The model is looked up in the module global ``EVALUATION_LLM_MODEL`` at call
    time instead of being passed in.  It is called as
    ``generate_text(query, 1)``.

    NOTE(review): the parallel evaluator unpacks this return value as
    ``(outputs, input_token_num, output_token_num)``.  Confirm that
    ``generate_text`` returns that shape for the models used in parallel mode.
    """
    llm_model = EVALUATION_LLM_MODEL
    return llm_model.generate_text(query, 1)


def _white_box_generate_text(input_query, model, tokenizer, device='cuda', max_length=2000):
    """Generate one completion for ``input_query`` with a white-box model.

    ``model`` / ``tokenizer`` are presumably a Hugging Face ``transformers``
    causal-LM pair (the calls below match that API) -- TODO confirm.
    Returns the first decoded sequence of the generated batch.
    """
    inputs = tokenizer(input_query, return_tensors="pt").to(device)
    generate_ids = model.generate(inputs.input_ids, max_length=max_length,
                                  pad_token_id=tokenizer.eos_token_id)
    return tokenizer.batch_decode(generate_ids, skip_special_tokens=True,
                                  clean_up_tokenization_spaces=False)[0]


def _generate_in_process_pool(queries, num_workers=50):
    """Run every query through ``mp_helper_wrap`` in a process pool.

    Outputs are concatenated in the original query order, not completion
    order.  Each worker result is unpacked as
    ``(outputs, input_token_num, output_token_num)``, and ``outputs`` is
    list-extended into the returned list.

    Raises:
        RuntimeError: if a worker raises, returns a malformed result, or
            returns ``None`` as its outputs.  Not-yet-started work is
            cancelled first.
    """
    start_time = time.time()
    outputs_by_index = [None] * len(queries)
    total_input_token_num, total_output_token_num = 0, 0

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        future_to_item = {executor.submit(mp_helper_wrap, query): (i, query)
                          for i, query in enumerate(queries)}
        for future in tqdm(concurrent.futures.as_completed(future_to_item),
                           total=len(future_to_item), desc="Evaluating queries"):
            index, query = future_to_item[future]
            try:
                result, input_token_num, output_token_num = future.result()
                if result is None:
                    # A None entry would otherwise crash when concatenating
                    # below, or misalign predictions with answers.
                    raise ValueError('worker returned None as its output')
            except Exception as e:
                print(f'Item {query} generated an exception: {e}')
                # Don't leave the pool grinding through queued work before
                # the error propagates.
                for pending in future_to_item:
                    pending.cancel()
                raise RuntimeError(f'Parallel evaluation of query #{index} failed') from e
            outputs_by_index[index] = result
            total_input_token_num += input_token_num
            total_output_token_num += output_token_num

    return_str_list = []
    for outputs in outputs_by_index:
        return_str_list.extend(outputs)

    print(f"----- Claude input token num: {total_input_token_num}, output token num: {total_output_token_num} -----")
    print(f"----- This call elapsed time: {time.time() - start_time} -----")

    return return_str_list


def exec_accuracy_evaluator(prompts, eval_template, eval_data, demos_template, few_shot_data, config,
                            eval_LLM=None, eval_LLM_tokenizer=None, sample_seed=None,
                            print_details_flag=True, eval_data_as_demo_flag=False,
                            eval_in_parallel_flag=False):
    """Score each prompt by executing it with ``eval_LLM`` on sampled eval data.

    Args:
        prompts: prompts to evaluate.
        eval_template / demos_template: template objects exposing ``fill``
            (see ``get_query``).
        eval_data / few_shot_data: datasets accepted by ``data.subsample_data``.
        config: dict providing 'num_samples', 'num_few_shot' and 'task'.
        eval_LLM: the model to query.  Without a tokenizer it must expose
            ``generate_text(queries, n)``.  With a tokenizer it is used as a
            white-box model exposing ``generate``.
        eval_LLM_tokenizer: tokenizer for the white-box path, or None.
        sample_seed: forwarded to every ``data.subsample_data`` call.
        print_details_flag: print the chosen metric and a per-query trace.
        eval_data_as_demo_flag: accepted but currently unused.
        eval_in_parallel_flag: on the ``generate_text`` path, fan the queries
            out over a process pool instead of making one batched call.

    Returns:
        ``(ExecAccuracyEvaluationResult, scores)`` where ``scores`` is a NumPy
        array of shape ``(len(prompts), config['num_samples'])``.  Each
        prediction is reduced to its ``<output>...</output>`` content before
        scoring.

    Raises:
        ValueError: if the task's metric is not 'f1', 'es', 'contains' or 'em'.
        RuntimeError: if a parallel worker fails or returns no output.
    """
    global EVALUATION_LLM_MODEL
    model = eval_LLM

    queries = []
    answers = []
    for prompt in prompts:
        subsampled_data = data.subsample_data(eval_data, config['num_samples'], sample_seed=sample_seed)
        for input_, output_ in zip(*subsampled_data):
            demo_data = data.subsample_data(few_shot_data, config['num_few_shot'], sample_seed=sample_seed)
            queries.append(get_query(prompt, eval_template, input_, demo_data, demos_template))
            answers.append(output_)

    if eval_LLM_tokenizer is not None:
        model_outputs = [_white_box_generate_text(input_query, model, eval_LLM_tokenizer)
                         for input_query in queries]
    elif eval_in_parallel_flag:
        # NOTE: this permanently mutates the caller's model object.
        model.print_each_token_usage_flag = False
        # Workers read the model through the module global.  NOTE(review): this
        # only reaches them when the pool forks after this assignment ("fork"
        # start method); under "spawn" workers would see None -- confirm.
        EVALUATION_LLM_MODEL = model
        model_outputs = _generate_in_process_pool(queries, num_workers=25)
    else:
        model_outputs = model.generate_text(queries, 1)

    task = config['task']
    metric = utility.TASK_TO_METRIC.get(task, utility.default_metric)

    if print_details_flag:
        print(f'Using metric "{metric}" for task "{task}"...')

    if metric == 'f1':
        score_fn = utility.get_multi_answer_f1
    elif metric == 'es':
        score_fn = utility.get_multi_answer_exact_set
    elif metric == 'contains':
        score_fn = utility.get_multi_answer_contains
    elif metric == 'em':
        score_fn = utility.get_multi_answer_em
    else:
        # Without this branch `score_fn` would be unbound below.
        raise ValueError(f'Unsupported metric "{metric}" for task "{task}"')

    scores = []
    for query, prediction, ans_ in zip(queries, model_outputs, answers):
        transformed_pred = extract_output_based_on_template_output_enclosure(prediction)
        score = score_fn(transformed_pred, ans_)
        if print_details_flag:
            print(f"[Query]: {query}, [Prediction]: {transformed_pred}, [Answer]: {ans_}. [Score]: {score}".replace("\n", ", "))
        scores.append(score)

    # Reshape the scores so that it is num_prompts x num_samples
    scores = np.array(scores).reshape(len(prompts), config['num_samples'])

    res = ExecAccuracyEvaluationResult(prompts, scores)
    return res, scores



################################################################################################################
################################################################################################################
################################################################################################################

class exec_evaluator(object):
    """Execution-accuracy evaluator backed by a model built from ``config``.

    ``api_model`` selects the wrapper: 'llama' -> ``llm.Llama_Forward``,
    'flan-t5' -> ``llm.Flan_T5``.  Any other value raises ``ValueError``.
    """

    def __init__(self, api_model, config):
        # Build the LLM up front and fail fast on an unknown name: without a
        # model, neither evaluate() nor test() can run.
        if api_model == 'llama':
            self.model = llm.Llama_Forward(config)
        elif api_model == 'flan-t5':
            self.model = llm.Flan_T5(config)
        else:
            raise ValueError(
                f"Unsupported api_model {api_model!r}; expected 'llama' or 'flan-t5'")

    @staticmethod
    def _score_fn_for_metric(metric):
        """Map a metric name to its ``utility`` scoring function.

        Raises:
            ValueError: if ``metric`` is not 'f1', 'es', 'contains' or 'em'.
        """
        if metric == 'f1':
            return utility.get_multi_answer_f1
        if metric == 'es':
            return utility.get_multi_answer_exact_set
        if metric == 'contains':
            return utility.get_multi_answer_contains
        if metric == 'em':
            return utility.get_multi_answer_em
        raise ValueError(f'Unsupported metric "{metric}"')

    def evaluate(self, prompts, eval_template, eval_data, demos_template, few_shot_data, config,
                 num_repeats=20):
        """Score ``prompts[0]`` on ``num_repeats`` independently drawn subsamples.

        NOTE(review): only the first prompt is evaluated.  It is replicated
        ``num_repeats`` times (20 by default), and each replica gets its own
        ``data.subsample_data`` draw of ``config['num_samples']`` examples plus
        per-example few-shot demos.

        Returns:
            ExecAccuracyEvaluationResult with scores of shape
            ``(num_repeats, config['num_samples'])``.

        Raises:
            ValueError: if the task's metric is unsupported.
        """
        prompts = [prompts[0]] * num_repeats
        queries = []
        answers = []
        for prompt in prompts:
            subsampled_data = data.subsample_data(eval_data, config['num_samples'])
            for input_, output_ in zip(*subsampled_data):
                demo_data = data.subsample_data(few_shot_data, config['num_few_shot'])
                queries.append(get_query(prompt, eval_template, input_, demo_data, demos_template))
                answers.append(output_)

        model_outputs = self.model.generate_text(queries, 1)

        task = config['task']
        metric = utility.TASK_TO_METRIC.get(task, utility.default_metric)
        print(f'Using metric "{metric}" for task "{task}"...')
        score_fn = self._score_fn_for_metric(metric)

        scores = [score_fn(prediction, ans_) for prediction, ans_ in zip(model_outputs, answers)]

        # Reshape the scores so that it is num_prompts x num_samples
        scores = np.array(scores).reshape(len(prompts), config['num_samples'])
        return ExecAccuracyEvaluationResult(prompts, scores)

    def test(self, prompts, eval_template, eval_data, config):
        """Zero-shot test of ``prompts[0]`` using ``config['evaluation']`` settings.

        The first prompt is replicated ``num_samples`` times.  Each replica is
        scored on its own draw of ``num_samples`` examples, with no
        demonstrations (see ``get_query_for_test``).

        Returns:
            ExecAccuracyEvaluationResult with scores of shape
            ``(num_samples, num_samples)``.

        Raises:
            ValueError: if the task's metric is unsupported.
        """
        num_samples = config['evaluation']['num_samples']
        prompts = [prompts[0]] * num_samples
        queries = []
        answers = []
        for prompt in prompts:
            subsampled_data = data.subsample_data(eval_data, num_samples)
            for input_, output_ in zip(*subsampled_data):
                queries.append(get_query_for_test(prompt, eval_template, input_, output_))
                answers.append(output_)

        model_outputs = self.model.generate_text(queries, 1)

        task = config['evaluation']['task']
        metric = utility.TASK_TO_METRIC.get(task, utility.default_metric)
        print(f'Using metric "{metric}" for task "{task}"...')
        score_fn = self._score_fn_for_metric(metric)

        scores = [score_fn(prediction, ans_) for prediction, ans_ in zip(model_outputs, answers)]

        # Reshape the scores so that it is num_prompts x num_samples
        scores = np.array(scores).reshape(len(prompts), num_samples)
        return ExecAccuracyEvaluationResult(prompts, scores)

class ExecAccuracyEvaluationResult(evaluate.EvaluationResult):
    """Execution-accuracy scores for a list of prompts.

    ``scores`` is iterated row by row, one row per prompt: row ``i`` holds the
    per-sample scores of ``prompts[i]``.  The evaluators in this module build
    it with ``reshape(len(prompts), num_samples)``.
    """

    # (name, reducer) pairs, checked in order against the requested method.
    # NOTE(review): 'iqm' is computed as the mean of the 25th and 75th
    # percentiles (the midhinge), not the mean of the values between them.
    _AGGREGATORS = (
        ('mean', np.mean),
        ('median', np.median),
        ('std', np.std),
        ('max', np.max),
        ('min', np.min),
        ('iqm', lambda row: np.mean(np.percentile(row, [25, 75]))),
    )

    def __init__(self, prompts, scores):
        self.prompts = prompts
        self.scores = scores

    def _agg_scores(self, method):
        """Reduce each prompt's row of scores to a single statistic.

        Raises:
            ValueError: if ``method`` is not one of the supported names.
        """
        for name, reducer in self._AGGREGATORS:
            if method == name:
                return [reducer(row) for row in self.scores]
        raise ValueError('Invalid method: {}'.format(method))

    def sorted(self, method='default'):
        """Return ``(prompts, scores)`` ordered from highest to lowest aggregate.

        'default' aggregates with the mean.  Ties on score are broken by
        comparing the prompts themselves, larger first.
        """
        aggregated = self._agg_scores('mean' if method == 'default' else method)
        ranked_prompts = [prompt for _, prompt in sorted(zip(aggregated, self.prompts))][::-1]
        ranked_scores = sorted(aggregated)[::-1]
        return ranked_prompts, ranked_scores

    def in_place(self, method='default'):
        """Return the prompts in original order alongside their aggregate scores."""
        return self.prompts, self._agg_scores('mean' if method == 'default' else method)
