import json
import logging
import os
import pickle
import random
import re
import time

from prompt_optimization_lcp import utils
from prompt_optimization_lcp.prompt_template import (
    generate_prompt_for_hint_generation,
    generate_prompt_for_hint_summarization,
    generation_high_level_prompt,
)

# Configure the root logger at import time so the INFO-level progress messages below are shown.
# NOTE(review): calling basicConfig on import affects every program that imports this module;
# consider moving it to the script entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Default length budget used by summarize_hints. It is measured in space-separated pieces
# (str.split(" ")) and covers the summarization template plus the concatenated sample hints.
PROMPT_LIMIT = 100000


def load_data(base_dir, task_name):
    """Load the samples of a task from ``<base_dir>/<task_name>/<task_name>.json``.

    :param base_dir: folder containing one sub-folder per task.
    :param task_name: task name; used for both the sub-folder and the file name.
    :return: the value stored under the top-level ``"examples"`` key (the list of samples).
    """
    task_loc = os.path.join(base_dir, task_name, f"{task_name}.json")
    # Parse the whole document rather than only its first line, so pretty-printed files
    # load too. Decode as UTF-8 (the JSON standard encoding) instead of the locale default.
    with open(task_loc, "r", encoding="utf-8") as f:
        data = json.load(f)["examples"]
    return data


def data_split(splition_folder, task_name, split="train"):
    """Load the precomputed sample indices of one split of a task.

    The split files are read directly, rather than re-splitting the data, so that the
    partition stays identical to the one used by other works.

    :param splition_folder: folder containing one sub-folder per task with the split files.
    :param task_name: name of the task sub-folder.
    :param split: split name; the file read is ``<splition_folder>/<task_name>/<split>_index.pkl``.
    :return: the unpickled index object of the requested split.
    """
    index_path = os.path.join(splition_folder, task_name, split + "_index.pkl")
    # NOTE(review): pickle.load can execute arbitrary code; only point this at trusted split files.
    with open(index_path, "rb") as index_file:
        return pickle.load(index_file)


def test_train_split(data_folder, task_name):
    """Load a task's samples together with its precomputed train and test index splits.

    :param data_folder: folder holding both the task JSON file and its split files.
    :param task_name: name of the task to load.
    :return: ``(dataset, train_index, test_index)``.
    """
    samples = load_data(data_folder, task_name)
    # Read the train split first and the test split second.
    train_index, test_index = (
        data_split(data_folder, task_name, part) for part in ("train", "test")
    )
    logger.info(
        f"Split the dataset to {len(train_index)} training samples, and {len(test_index)} testing samples."
    )
    return samples, train_index, test_index


def generate_output_folder(project_folder, exp_folder):
    """Create ``<project_folder>/outputs/<exp_folder>`` if needed and return its path.

    :param project_folder: root folder of the project.
    :param exp_folder: name of the experiment sub-folder under ``outputs``.
    :return: output folder path.
    :raises FileExistsError: if a non-directory already exists at that path.
    """
    output_folder = os.path.join(project_folder, "outputs", exp_folder)
    # exist_ok=True avoids the check-then-create race when several runs start at once.
    os.makedirs(output_folder, exist_ok=True)
    return output_folder


def generate_output_filepath(output_folder, task_name, prefix):
    """Return the JSON output path for a task, creating the task folder if needed.

    Only the folder ``<output_folder>/<task_name>`` is created, together with any missing
    parents including ``output_folder`` itself. The output file itself is not created.

    :param output_folder: root folder for outputs.
    :param task_name: name of the per-task sub-folder.
    :param prefix: prefix of the output file name.
    :return: ``<output_folder>/<task_name>/<prefix>_output.json``.
    """
    task_folder = os.path.join(output_folder, task_name)
    # makedirs (not mkdir) so a missing output_folder does not raise FileNotFoundError;
    # exist_ok=True avoids the check-then-create race.
    os.makedirs(task_folder, exist_ok=True)
    return os.path.join(task_folder, f"{prefix}_output.json")


def save_output_file(output_file, dict_list):
    """Write ``dict_list`` to ``output_file`` as JSON, replacing any existing content.

    :param output_file: destination path.
    :param dict_list: JSON-serialisable object to store.
    :return: None.
    """
    with open(output_file, mode="w") as handle:
        json.dump(dict_list, handle)
    logger.info("Write Generated Prompts Successfully!")
    return None


def run_inference_step(
    task_name,
    is_boolean,
    dataset,
    data_index,
    meta_prompt,
    model_id,
    temperature=1.0,
    batch_size=8,
):
    """Score ``meta_prompt`` with ``model_id`` on the samples selected by ``data_index``.

    :param task_name: name of the task; currently not used by this function.
    :param is_boolean: whether the task is a Boolean task; forwarded to ``utils.group_batch``,
        which words the instruction slightly differently for such tasks.
    :param dataset: the samples, indexed by the entries of ``data_index``.
    :param data_index: indices of the samples to evaluate.
    :param meta_prompt: the prompt currently being evaluated.
    :param model_id: identifier of the LLM to query.
    :param temperature: sampling temperature passed to the LLM.
    :param batch_size: batch size forwarded to ``utils.group_batch``.
    :return: ``(wrong_samples, true_samples, avg_acc)``. These are the indices predicted
        incorrectly, the indices predicted correctly, and the fraction predicted correctly.
        An empty ``data_index`` yields ``([], [], 0.0)`` without calling the LLM.
    """
    # len() rather than truthiness: the index may be a numpy array loaded from a pickle.
    num_examples = len(data_index)
    if num_examples == 0:
        # Previously this crashed with ZeroDivisionError after making empty LLM calls.
        logger.warning("run_inference_step called with an empty data_index; accuracy is 0.0.")
        return [], [], 0.0
    prompts_grouped_by_batch_size, true_answers = utils.group_batch(
        dataset,
        data_index,
        meta_prompt,
        batch_size,
        is_boolean,
    )
    raw_outputs = utils.parallel_call_llm(
        prompts_grouped_by_batch_size=prompts_grouped_by_batch_size,
        model_id=model_id,
        temperature=temperature,
    )
    outputs = [utils.extract_output(raw_output, "Ans") for raw_output in raw_outputs]
    # zip stops at the shorter sequence, like the previous two-iterable map().
    # NOTE(review): assumes group_batch/parallel_call_llm keep the order of data_index.
    scores = [
        _get_accuracy(true_answer=true_answer, pred_answer=pred_answer)
        for true_answer, pred_answer in zip(true_answers, outputs)
    ]
    true_samples = [data_index[i] for i, score in enumerate(scores) if score == 1.0]
    wrong_samples = [data_index[i] for i, score in enumerate(scores) if score != 1.0]
    avg_acc = len(true_samples) / num_examples
    logger.info(f"Current Accuracy is: {avg_acc}")
    return wrong_samples, true_samples, avg_acc


def _get_accuracy(true_answer, pred_answer):
    """Compare true_answer to pred_answer.
    This function can extract the option in upper/lower case, w/w.o parenthesis.
    :return: 1.0 if correct, else 0.0.
    """
    clean_answer = None
    pattern = r"^([a-zA-Z]|\([a-zA-Z]\))[\s:)]*"
    match = re.search(pattern, pred_answer)
    if match:
        option = match.group(1)
        if option.startswith("("):
            option = option[1:-1]  # Remove parentheses
        clean_answer = "(" + option.upper() + ")"
    return 1.0 if clean_answer == true_answer else 0.0


def random_sampling(wrong_samples, true_samples, k=20):
    """Randomly pick up to ``k`` samples, preferring incorrectly predicted ones.

    :param wrong_samples: indices of incorrectly predicted samples; used whenever non-empty.
    :param true_samples: indices of correctly predicted samples, used only when
        ``wrong_samples`` is empty. ``None`` is treated as an empty collection.
    :param k: maximum number of samples to return.
    :return: a list of at most ``k`` distinct elements drawn without replacement.
    """
    if wrong_samples:
        pool = wrong_samples
    else:
        # Fall back to correctly predicted samples. None (e.g. the default true_samples of
        # run_hint_generation) previously crashed in len(None).
        pool = [] if true_samples is None else true_samples
    # random.sample requires a sequence (sets are rejected since Python 3.11).
    pool = list(pool)
    return random.sample(pool, min(k, len(pool)))


def select_samples(flipping_strategy, curr_wrong_samples, curr_true_samples, prev_wrong_samples):
    """Pick the samples whose prediction outcome matches a flip pattern between two runs.

    :param flipping_strategy: which flip pattern to select.
            CorrectToWrong: right in the previous run, wrong in the current one.
            WrongToWrong: wrong in both runs.
            WrongToCorrect: wrong in the previous run, right in the current one.
    :param curr_wrong_samples: samples predicted incorrectly by the current model/language.
    :param curr_true_samples: samples predicted correctly by the current model/language.
    :param prev_wrong_samples: samples predicted incorrectly by the previous model/language.
    :return: list of the selected samples (order follows set iteration).
    :raises NotImplementedError: for an unknown ``flipping_strategy``.
    """
    # Each rule is evaluated lazily, so only the inputs the chosen rule needs are touched.
    rules = (
        ("CorrectToWrong", lambda: set(curr_wrong_samples) - set(prev_wrong_samples)),
        ("WrongToWrong", lambda: set(curr_wrong_samples) & set(prev_wrong_samples)),
        ("WrongToCorrect", lambda: set(curr_true_samples) & set(prev_wrong_samples)),
    )
    for rule_name, compute in rules:
        if flipping_strategy == rule_name:
            return list(compute())
    raise NotImplementedError("Please check your flipping strategy!")


def run_hint_generation(
    dataset, model_id, wrong_samples, true_samples=None, temperature=1.0, hint_nums=3
):
    """Ask the LLM for a hint on a few randomly chosen samples.

    Samples are drawn by ``random_sampling``, which prefers ``wrong_samples``. For each
    sample, the hint-generation template is filled with the sample's input and target.

    :param dataset: the samples; each entry is read via its ``"input"`` and ``"target"`` keys.
    :param model_id: identifier of the LLM used to write hints.
    :param wrong_samples: indices of incorrectly predicted samples.
    :param true_samples: indices of correctly predicted samples (fallback pool).
    :param temperature: sampling temperature passed to the LLM.
    :param hint_nums: the maximum number of samples to generate hints for.
    :return: dict mapping sample index to its hint; empty-string hints are dropped.
    """
    chosen_ids = random_sampling(wrong_samples, true_samples, hint_nums)
    template = generate_prompt_for_hint_generation()
    hints_by_sample = {}
    for sample_id in chosen_ids:
        sample = dataset[sample_id]
        target = sample["target"]
        filled_prompt = template.replace("<INPUT>", sample["input"]).replace("<OUTPUT>", target)
        reply = utils.call_llm_func(filled_prompt, model_id=model_id, temperature=temperature)
        hint_text = utils.extract_output(reply, "hint")
        if hint_text != "":
            hints_by_sample[sample_id] = hint_text
    return hints_by_sample


def summarize_hints(hints, dataset, model_id, temperature=1.0, prompt_limit=PROMPT_LIMIT):
    """Condense per-sample hints into one summary via the LLM.

    Each hint is rendered together with its sample's input and expected output. The
    rendered blocks are concatenated until the length budget ``prompt_limit`` would be
    exceeded; an error is then logged and the remaining hints are skipped.

    :param hints: dict mapping sample index to the generated hint.
    :param dataset: the samples; each entry is read via its ``"input"`` and ``"target"`` keys.
    :param model_id: the LLM used to summarize hints.
    :param temperature: sampling temperature passed to the LLM.
    :param prompt_limit: budget counted in pieces of ``str.split(" ")``.
    :return: the hints summary extracted from the LLM reply.
    """
    sample_hint = "Given input: <INPUT>\nAnd its expected output: <OUTPUT>.\nAnd the reason for the expected output: <HINT>\n"
    template = generate_prompt_for_hint_summarization()
    # len(s.split(" ")) == s.count(" ") + 1, so the budget can be tracked by counting spaces.
    template_pieces = template.count(" ") + 1
    rendered_blocks = []
    spaces_so_far = 0  # spaces in the concatenation of rendered_blocks
    for sample_id in hints:
        sample = dataset[sample_id]
        target = sample["target"]
        rendered = (
            sample_hint.replace("<INPUT>", sample["input"])
            .replace("<OUTPUT>", target)
            .replace("<HINT>", hints[sample_id])
        )
        remaining = prompt_limit - template_pieces - (rendered.count(" ") + 1)
        if spaces_so_far + 1 >= remaining:
            logger.error("The length of sample hints is greater than maximum!")
            break
        rendered_blocks.append(rendered)
        spaces_so_far += rendered.count(" ")
    prompt = template.replace("<HINTS>", "".join(rendered_blocks))
    reply = utils.call_llm_func(prompt, model_id=model_id, temperature=temperature)
    return utils.extract_output(reply, "hint")


def train(args):
    """Run the iterative prompt-optimisation loop and save the prompt produced at each step.

    For each of ``args.steps`` steps:
      1. Score the current prompt on ``args.train_index`` with ``args.model_id``. In
         ``"adaptation"`` mode, also score it with ``args.prev_model_id`` and keep only the
         samples selected by ``select_samples``.
      2. Build ``args.num_candidates`` candidate prompts by generating and summarizing hints,
         and score each candidate.
      3. Rank all candidates by score and ask the LLM for a new prompt. That prompt becomes
         the current prompt for the next step.

    :param args: all related arguments are included in arguments.py.
    :return: None; results are written to the file from ``generate_output_filepath``.
    """
    started_at = time.time()
    output_file = generate_output_filepath(args.output_folder, args.task_name, args.prefix)

    def score(prompt, model):
        # Evaluate ``prompt`` with ``model`` on the training split.
        return run_inference_step(
            task_name=args.task_name,
            is_boolean=args.is_boolean,
            dataset=args.dataset,
            data_index=args.train_index,
            meta_prompt=prompt,
            model_id=model,
            temperature=args.temperature,
            batch_size=args.batch_size,
        )

    meta_prompt = args.initial_prompt
    step_records = []
    prompt_history = []
    for step in range(args.steps):
        logger.info(f"Current Step is: {step}")
        # Step 1: score the current prompt and pick the samples to learn from.
        wrong_samples, true_samples, current_score = score(meta_prompt, args.model_id)
        if args.optimization_strategy == "adaptation":
            prev_wrong_samples, _, _ = score(meta_prompt, args.prev_model_id)
            focus_samples = select_samples(
                args.flipping_strategy,
                wrong_samples,
                true_samples,
                prev_wrong_samples,
            )
        else:
            focus_samples = wrong_samples
        prompt_history.append([meta_prompt, current_score])
        # NOTE(review): this aliases prompt_history (no copy). Every candidate appended below
        # therefore stays in the history shown to later steps, and the history grows each step.
        candidates = prompt_history
        for _ in range(args.num_candidates):
            hints = run_hint_generation(
                dataset=args.dataset,
                model_id=args.model_id,
                wrong_samples=focus_samples,
                true_samples=true_samples,
                temperature=args.temperature,
                hint_nums=args.num_hints,
            )
            candidate_prompt = summarize_hints(
                hints=hints,
                dataset=args.dataset,
                model_id=args.model_id,
                temperature=args.temperature,
                prompt_limit=args.prompt_limit,
            )
            _, _, candidate_score = score(candidate_prompt, args.model_id)
            candidates.append([candidate_prompt, candidate_score])
        # Step 2: ask the LLM for a new prompt given the ranked candidates.
        ranked = sorted(candidates, key=lambda entry: entry[1], reverse=True)
        generation_prompt = generation_high_level_prompt(
            args.dataset_name, args.task_name, ranked
        )
        raw_prompt = utils.call_llm_func(
            input_str=generation_prompt,
            model_id=args.model_id,
            temperature=args.temperature,
        )
        meta_prompt = utils.extract_output(raw_prompt, "prompt")
        # Each record is itself a JSON string, so the saved file is a JSON list of strings.
        step_records.append(json.dumps({"step_id": step, "prompt": meta_prompt}))
    save_output_file(output_file, step_records)
    logger.info(f"Time Duration: {time.time()-started_at}")
    return None
