import os
import logging
import re
import openai
import json
import random
from pdfminer.high_level import extract_text
import numpy as np
from utils import generate_powerset, query_problem, turn_into_input_format, add_numerical_list_string_to_list, parse_answer
logging.basicConfig(level=logging.ERROR)


def generate_advice_prompts(args):
    """Build the advice-generation prompts for every advice type and save them.

    Each advice type is handed to `generate_advice_prompt`, which adds a
    `(prompt, count)` entry keyed by the type (and, when a counterfactual path
    is configured, an extra "<type>_cf" entry). The collected dict is written
    to `{scenario}/advice_prompts.json`.

    Types and the count requested for each:
        baseline_irrelevant, baseline_neutral, normal -> args.num_advice
        unorthodox                                    -> args.num_unorthodox_advice

    Does nothing if `{scenario}/advice_prompts.json` already exists.
    """
    output_path = f"{args.scenario}/advice_prompts.json"
    if os.path.exists(output_path):
        print("advice prompts already exist, skipping generation.")
        return

    requested = (
        ("baseline_irrelevant", args.num_advice),
        ("baseline_neutral", args.num_advice),
        ("normal", args.num_advice),
        ("unorthodox", args.num_unorthodox_advice),
    )

    advice_prompts = {}
    for advice_type, count in requested:
        advice_prompts = generate_advice_prompt(args, advice_type, count, advice_prompts)

    with open(output_path, "w") as f:
        json.dump(advice_prompts, f)
    

def generate_advice_prompt(args, advice_type, num_advice, advice_prompts):
    """Build the advice-generation prompt(s) for one advice type and add them to `advice_prompts`.

    The prompt is the scenario setting text followed by the
    `{scenario}/{advice_type}_advice.txt` template. NUM_ADVICE is replaced by
    the number of items to request.

    When `args.counterfactual_path` is set, the requested count is split:
    `num_advice // 2` items go to a counterfactual prompt stored under
    "{advice_type}_cf", and the remaining items go to the regular prompt. The
    counterfactual prompt is setting + counterfactual text (with its
    EXISTING_SCENARIO placeholder filled from
    `args.counterfactual_existing_scenario_path`) + advice template.

    Args:
        args: run configuration (scenario dir and template paths).
        advice_type: advice template prefix, e.g. "normal".
        num_advice: total number of advice items to request.
        advice_prompts: dict being accumulated; mutated in place.

    Returns:
        The same `advice_prompts` dict. Each new value is a
        `(input_prompt, count)` tuple.
    """
    with open(f"{args.scenario}/{args.setting_path}", "r") as f:
        setting_text = f.read()

    with open(f"{args.scenario}/{advice_type}_advice.txt", "r") as f:
        advice_text = f.read()

    prompt = setting_text + advice_text

    if not args.counterfactual_path:
        prompt = prompt.replace("NUM_ADVICE", str(num_advice))
        advice_prompts[advice_type] = (turn_into_input_format("user", prompt), num_advice)
        return advice_prompts

    # Counterfactual scenario gets half (rounded down) of the requested advice.
    num_cf = num_advice // 2
    num_regular = num_advice - num_cf

    with open(f"{args.scenario}/{args.counterfactual_path}", "r") as f:
        cf_text = f.read()

    with open(f"{args.scenario}/{args.counterfactual_existing_scenario_path}", "r") as f:
        existing_scenario_text = f.read()

    # Fill the placeholder BEFORE assembling the prompt, and build the cf prompt
    # from the cf text itself. Previously it was derived from the regular
    # prompt, so the counterfactual text never reached the model.
    cf_text = cf_text.replace("EXISTING_SCENARIO", existing_scenario_text)
    prompt_cf = (setting_text + cf_text + advice_text).replace("NUM_ADVICE", str(num_cf))
    advice_prompts[f"{advice_type}_cf"] = (turn_into_input_format("user", prompt_cf), num_cf)

    prompt = prompt.replace("NUM_ADVICE", str(num_regular))
    advice_prompts[advice_type] = (turn_into_input_format("user", prompt), num_regular)

    return advice_prompts


def generate_advice(args):
    """Query the model once per entry of advice_prompts.json and save the advice lists.

    Every key of `{scenario}/advice_prompts.json` names an advice type; its
    value is `(prompt, num_advice)`. The single completion is split into a list
    (via `add_numerical_list_string_to_list` when more than one item is
    expected) and written to `{scenario}/advice/{advice_type}_advice.json`.
    Types whose output file already exists are skipped.

    Raises:
        ValueError: if the number of parsed advice items differs from num_advice.
    """
    with open(f"{args.scenario}/advice_prompts.json", "r") as f:
        advice_prompts = json.load(f)

    advice_dir = f"{args.scenario}/advice"
    if not os.path.exists(advice_dir):
        os.mkdir(advice_dir)

    for advice_type, (prompt, num_advice) in advice_prompts.items():
        out_path = f"{advice_dir}/{advice_type}_advice.json"
        if os.path.exists(out_path):
            print(f"{advice_type} advice already exists, skipping advice generation.")
            continue

        # Only one completion (n=1) is requested; the variety is expected to
        # come from the multiple items inside that one answer.
        response = query_problem(args, None, prompt, temperature=1, top_p=1, n=1, verbose=args.verbose)
        answers = response[0][0]["message"]["content"]

        advice_list = add_numerical_list_string_to_list([], answers) if num_advice > 1 else [answers]

        if len(advice_list) != num_advice:
            raise ValueError(f"Number of advice generated does not match the expected number: {len(advice_list)} vs {num_advice}\n{advice_list}")

        print(f"generated {len(advice_list)} pieces of advice for {advice_type}.")

        with open(out_path, "w") as f:
            json.dump(advice_list, f)


def generate_candidates_prompts(args, path_list):
    """Write a candidate-generation prompt for every small subset of the pooled advice.

    The advice lists in `{scenario}/advice/{name}.json` for each name in
    `path_list` are concatenated. Every subset yielded by `generate_powerset`
    (except its first element), whose size is at most
    `args.max_advice_per_generation`, is passed to `generate_candidates_prompt`.
    Output goes to `candidates_prompts/<last entry of path_list>/<indices>.json`.
    Subsets whose prompt file already exists are skipped.
    """
    pooled_advice = []
    for advice_name in path_list:
        with open(f"{args.scenario}/advice/{advice_name}.json", "r") as f:
            pooled_advice += json.load(f)

    # Drop the first subset; presumably the empty set -- confirm against utils.generate_powerset.
    subsets = generate_powerset(len(pooled_advice))[1:]
    # Prompts are filed under the LAST entry of path_list.
    target = path_list[-1] if path_list else None
    written = 0

    prompts_root = f"{args.scenario}/candidates_prompts"
    if not os.path.exists(prompts_root):
        os.mkdir(prompts_root)

    for combination in subsets:
        if len(combination) > args.max_advice_per_generation:
            continue

        name = str(combination).replace(", ", "_").replace("[", "").replace("]", "")
        chosen = [pooled_advice[idx] for idx in combination]

        if os.path.exists(f"{prompts_root}/{target}/{name}.json"):
            if args.verbose: print(f"candidates prompts for {target} advice set {name} already exists, skipping generation.")
            continue

        generate_candidates_prompt(args, chosen, target, name)
        written += 1

    if written == 0:
        print("all candidates prompts already exist, skipping generation.")
    else:
        print(f"generated {written} candidates prompts for {path_list}.")


def generate_candidates_prompt(args, advice_to_use, path, name):
    """Build and save the (user, system) prompt pair for one subset of advice.

    The advice strings are wrapped in double quotes unless they already start
    with one. The result is substituted for every "ADVICE" token of a
    "remember advice" template:

    * 0 items: no template; the advice sentence is empty.
    * 1 item: `args.remember_advice_path`; ADVICE becomes the quoted advice.
    * n>1 items: the plural template (the path with its last 4 characters,
      presumably ".txt", replaced by "s.txt"); ADVICE becomes a numbered list
      "1. ...\\n2. ...\\n".

    The resulting sentence replaces ADVICESENTENCE in the second-person
    scenario. It is saved with the second-person system prompt to
    `{scenario}/candidates_prompts/{path}/{name}.json`.

    Args:
        args: run configuration (scenario dir and template paths).
        advice_to_use: advice strings to include.
        path: sub-directory name under candidates_prompts.
        name: file stem; also stored as "advice_indices".
    """
    if not advice_to_use:
        remember_advice_text = ""
    else:
        quoted = [a if a.startswith("\"") else "\"" + a + "\"" for a in advice_to_use]
        if len(quoted) == 1:
            template_path = f"{args.scenario}/{args.remember_advice_path}"
            replacement = quoted[0]
        else:
            template_path = f"{args.scenario}/{args.remember_advice_path[:-4]}s.txt"
            # Build the numbered list once and substitute it in a single pass.
            # The old incremental replace(..., 1) re-substituted any "ADVICE"
            # token that appeared inside an already inserted advice string.
            replacement = "".join(f"{n}. {a}\n" for n, a in enumerate(quoted, start=1))
        with open(template_path, "r") as f:
            remember_advice_text = f.read().replace("ADVICE", replacement)

    with open(f"{args.scenario}/{args.second_person_scenario_path}", "r") as f:
        second_person_scenario_text = f.read().replace("ADVICESENTENCE", remember_advice_text)

    with open(f"{args.scenario}/{args.second_person_system_path}", "r") as f:
        second_person_system_text = f.read()

    user_prompt = turn_into_input_format("user", second_person_scenario_text)
    system_prompt = turn_into_input_format("system", second_person_system_text)

    out_dir = f"{args.scenario}/candidates_prompts/{path}"
    os.makedirs(out_dir, exist_ok=True)

    with open(f"{out_dir}/{name}.json", "w") as f:
        json.dump({"user_prompt": user_prompt,
                   "system_prompt": system_prompt,
                   "advice_indices": name,
                   "advice_list": advice_to_use}, f)


def _candidate_records(path, prompt_dict, texts):
    """Wrap each generated text for one advice set into a candidate record dict."""
    return [
        {"advice_type": path,
         "advice_indices": prompt_dict["advice_indices"],
         "advice_list": prompt_dict["advice_list"],
         "generation_index": index,
         "text": text}
        for index, text in enumerate(texts)
    ]


def generate_candidates(args, path):
    """Generate (or load cached) candidate responses for every prompt file under `path`.

    Reads each non-hidden JSON file in `{scenario}/candidates_prompts/{path}`.
    If `{scenario}/candidates/{path}/{advice_indices}.json` exists, its texts
    are reused. Otherwise the model is queried with
    n=args.num_generations, and the stripped contents are cached to that file.

    Returns:
        A list of dicts with keys advice_type, advice_indices, advice_list,
        generation_index and text, one per generated response. Prompt files are
        visited in sorted filename order so the result order is deterministic.
    """
    prompts_dir = f"{args.scenario}/candidates_prompts/{path}"
    candidates_dir = f"{args.scenario}/candidates/{path}"

    prompts_list = []
    for file in sorted(os.listdir(prompts_dir)):
        if file.startswith("."):
            continue
        with open(f"{prompts_dir}/{file}", "r") as f:
            prompts_list.append(json.load(f))

    # makedirs: a plain os.mkdir failed when the parent `candidates` dir was missing.
    os.makedirs(candidates_dir, exist_ok=True)

    all_generated = True
    generated_candidates = []

    for d in prompts_list:
        out_path = f"{candidates_dir}/{d['advice_indices']}.json"
        if os.path.exists(out_path):
            if args.verbose: print(f"candidates for advice set {d['advice_indices']} already exists, skipping generation.")
            with open(out_path, "r") as f:
                answers = json.load(f)
        else:
            answers = query_problem(args, d["system_prompt"], d["user_prompt"], temperature=1, top_p=1, n=args.num_generations, verbose=args.verbose)
            answers = [answer["message"]["content"].strip() for answer in answers[0]]
            with open(out_path, "w") as f:
                json.dump(answers, f)
            all_generated = False

        generated_candidates.extend(_candidate_records(path, d, answers))

    if all_generated:
        print("all responses already exist, skipping generation.")
    else:
        print(f"generated {len(generated_candidates)} responses for {path}.")

    return generated_candidates


def compare(response1, response2, args, i, j, k, l, evaluator):
    """Ask `evaluator` to compare two responses and return its mean parsed verdict.

    Results are cached at
    `{scenario}/{eval_path}/{evaluator}/compare_results/{i}_{j}_{k}_{l}.json`.
    A cached file is reused instead of querying again.

    On a cache miss, the evaluator's eval_system.txt / eval_user.txt prompts
    are loaded. The two responses are put into SCENARIO1/SCENARIO2 in a random
    order (`flipped` records whether response2 was shown first). The model is
    queried `args.num_evals` times, and each answer is parsed with
    `parse_answer`.

    Returns:
        Mean of the parsed answers, expressed relative to the unflipped order
        (SCENARIO1 = response1): answers from a flipped query are mapped to
        `1 - x`. Presumably parse_answer yields values in [0, 1]; confirm in utils.

    Raises:
        ValueError: if a cached result holds a different number of responses
            than `args.num_evals`.
    """
    eval_dir = f"{args.scenario}/{args.eval_path}/{evaluator}"
    results_dir = f"{eval_dir}/compare_results"
    result_path = f"{results_dir}/{i}_{j}_{k}_{l}.json"

    if os.path.exists(result_path):
        if args.verbose: print(f"\ncompare results for {evaluator} {i}_{j}_{k}_{l} already exists, skipping generation.")
        with open(result_path, "r") as f:
            save_dict = json.load(f)
        if len(save_dict["responses"]) != args.num_evals:
            raise ValueError("number of answers does not match the number of evaluations.")
    else:
        print(f"{i}_{j}_{k}_{l}\n")

        with open(f"{eval_dir}/eval_system.txt", "r") as f:
            system_text = f.read()
        with open(f"{eval_dir}/eval_user.txt", "r") as f:
            user_text = f.read()

        # Randomize presentation order (presumably to counter evaluator position bias).
        flipped = random.random() >= 0.5
        first, second = (response2, response1) if flipped else (response1, response2)
        user_text = user_text.replace("SCENARIO1", first).replace("SCENARIO2", second)

        user_prompt = turn_into_input_format("user", user_text)
        system_prompt = turn_into_input_format("system", system_text)

        answers = query_problem(args, system_prompt, user_prompt, temperature=1, top_p=1, n=args.num_evals, verbose=args.verbose)
        answers = [answer["message"]["content"] for answer in answers[0]]

        save_dict = {
            "indices": [i, j, k, l],
            "responses": answers,
            "flipped": flipped,
            "system_prompt": system_prompt,
            "user_prompt": user_prompt,
            "parsed_answers": [parse_answer(args, answer) for answer in answers],
        }

        os.makedirs(results_dir, exist_ok=True)
        with open(result_path, "w") as f:
            json.dump(save_dict, f)

    results = save_dict["parsed_answers"]
    if save_dict["flipped"]:
        results = [(1 - result) for result in results]

    return np.mean(results)


def generate_stakeholder_prompts(args, evaluator):
    """Build the prompts that will ask the model to write an evaluator's eval prompts.

    Both the system-prompt-generation and user-prompt-generation templates have
    STAKEHOLDER replaced by `evaluator` and SCENARIO replaced by the setting
    text. They are saved as `[system_prompt, user_prompt]` to
    `{scenario}/{eval_path}/{evaluator}/stakeholder_prompts.json`.
    Does nothing if that file already exists.
    """
    output_path = f"{args.scenario}/{args.eval_path}/{evaluator}/stakeholder_prompts.json"
    if os.path.exists(output_path):
        print(f"stakeholder prompts for {evaluator} already exist, skipping generation.")
        return

    with open(f"{args.scenario}/{args.setting_path}", "r") as f:
        setting_text = f.read()

    def fill(template_path):
        # STAKEHOLDER is substituted before SCENARIO, so the setting text is inserted verbatim.
        with open(f"{args.scenario}/{template_path}", "r") as f:
            return f.read().replace("STAKEHOLDER", evaluator).replace("SCENARIO", setting_text)

    system_prompt = turn_into_input_format("system", fill(args.stakeholder_systempromptgen_path))
    user_prompt = turn_into_input_format("user", fill(args.stakeholder_userpromptgen_path))

    with open(output_path, "w") as f:
        json.dump([system_prompt, user_prompt], f)


def query_stakeholder_prompts(args, evaluator):
    """Query the model for an evaluator's system prompt and question, and save them.

    Reads `[system_prompt, user_prompt]` from stakeholder_prompts.json and
    sends each to the model. The generated system prompt goes to
    eval_system.txt. The generated question is substituted for
    STAKEHOLDER_QUESTION in the user-prompt template and written to
    eval_user.txt. Both files live in `{scenario}/{eval_path}/{evaluator}/`.
    Does nothing if both output files already exist.

    Raises:
        ValueError: if stakeholder_prompts.json has not been generated yet.
    """
    eval_dir = f"{args.scenario}/{args.eval_path}/{evaluator}"
    system_out = f"{eval_dir}/eval_system.txt"
    user_out = f"{eval_dir}/eval_user.txt"
    prompts_path = f"{eval_dir}/stakeholder_prompts.json"

    if os.path.exists(system_out) and os.path.exists(user_out):
        print(f"eval prompts for {evaluator} already exist, skipping generation.")
        return

    if not os.path.exists(prompts_path):
        raise ValueError(f"stakeholder prompts for {evaluator} do not exist, please generate them first.")

    with open(prompts_path, "r") as f:
        prompts = json.load(f)
    system_prompt, user_question_prompt = prompts[0], prompts[1]

    def ask(prompt):
        reply = query_problem(args, None, prompt, temperature=1, top_p=1, n=1, verbose=True)
        return reply[0][0]["message"]["content"].strip()

    generated_system_prompt = ask(system_prompt)
    generated_user_question = ask(user_question_prompt)

    with open(system_out, "w") as f:
        f.write(generated_system_prompt)

    with open(f"{args.scenario}/{args.stakeholder_userprompttemplate_path}", "r") as f:
        template_text = f.read()

    # NOTE: counterfactual placeholders (EXISTING_SCENARIO / EXISTING_RESULT) are
    # not substituted here; counterfactuals are currently unsupported in this step.
    with open(user_out, "w") as f:
        f.write(template_text.replace("STAKEHOLDER_QUESTION", generated_user_question))


def generate_stakeholders(args):
    """Run both stakeholder stages in order: build the query prompt, then query it."""
    for stage in (generate_stakeholder_query_prompt, query_stakeholder_query_prompt):
        stage(args)


def generate_stakeholder_query_prompt(args):
    """Build and save the prompt that asks the model to list the stakeholders.

    The prompt is the setting text followed by the stakeholders template, with
    NUM_STAKEHOLDERS replaced by `args.num_stakeholders`. It is saved to
    `{scenario}/stakeholder_prompts/stakeholder_query_prompt.json`; the
    directory is created if needed. Does nothing if the file already exists.
    """
    prompts_dir = f"{args.scenario}/stakeholder_prompts"
    query_path = f"{prompts_dir}/stakeholder_query_prompt.json"

    if not os.path.exists(prompts_dir):
        os.mkdir(prompts_dir)

    if os.path.exists(query_path):
        print("stakeholder query prompt already exists, skipping generation.")
        return

    with open(f"{args.scenario}/{args.setting_path}", "r") as f:
        setting_text = f.read()

    with open(f"{args.scenario}/{args.stakeholders_path}", "r") as f:
        stakeholders_text = f.read().replace("NUM_STAKEHOLDERS", str(args.num_stakeholders))

    # Format before opening the output so a failure leaves no empty file behind.
    formatted = turn_into_input_format("user", setting_text + stakeholders_text)

    with open(query_path, "w") as f:
        json.dump(formatted, f)


def _parse_stakeholder_line(line):
    """Parse one line of the stakeholder-list answer; return None if it is not a numbered item.

    A numbered item looks like "3. Some Stakeholder (0.2) ...". The evaluator
    name is the text before the first "(", lower-cased, with spaces turned into
    "_" and other non-alphanumerics removed. The weight is the first
    parenthesised text with every character except digits and "." removed.

    Raises:
        ValueError: if a numbered line has no parenthesised weight or yields an empty name.
    """
    # Strip the list marker of any width ("1.", "12)", "3:"), so multi-digit
    # numbers no longer leak into the name.
    match = re.match(r"\s*\d+\s*[.):-]?\s*(.*)", line)
    if match is None:
        return None
    full_line = match.group(1)

    evaluator = re.sub(r"\(.*", "", full_line).lower().strip().replace(" ", "_")
    evaluator = re.sub(r"[^a-zA-Z0-9_]+", "", evaluator)
    if not evaluator:
        raise ValueError(f"could not parse a stakeholder name from line: {line!r}")

    weight_match = re.search(r"\((.*?)\)", full_line)
    if weight_match is None:
        raise ValueError(f"no parenthesised weight found in stakeholder line: {line!r}")
    weight = re.sub(r"[^0-9.]", "", weight_match.group(1))

    return {"evaluator": evaluator, "weight": weight, "justification": full_line}


def query_stakeholder_query_prompt(args):
    """Ask the model for the stakeholder list and create one evaluator per stakeholder.

    Skips everything if `{scenario}/evaluators` already holds at least
    `args.num_stakeholders` evaluator sub-directories. Otherwise it queries the
    model with the saved stakeholder query prompt and parses the numbered lines
    of the answer. For each new stakeholder it creates a directory and generates
    its eval prompts. All parsed records are written to
    `{scenario}/evaluators/weights.json`.

    Raises:
        ValueError: if the stakeholder query prompt has not been generated, or
            a numbered answer line cannot be parsed.
    """
    evaluators_dir = f"{args.scenario}/evaluators"
    os.makedirs(evaluators_dir, exist_ok=True)

    # Count only evaluator sub-directories. weights.json (written below) and
    # hidden entries used to be counted too, so a partial run could be skipped.
    existing = [
        entry for entry in os.listdir(evaluators_dir)
        if not entry.startswith(".") and os.path.isdir(os.path.join(evaluators_dir, entry))
    ]
    if len(existing) >= args.num_stakeholders:
        print("evaluators already exist, skipping generation.")
        return

    prompt_path = f"{args.scenario}/stakeholder_prompts/stakeholder_query_prompt.json"
    if not os.path.exists(prompt_path):
        raise ValueError("stakeholder query prompt does not exist, please generate it first.")

    with open(prompt_path, "r") as f:
        prompt = json.load(f)

    answer = query_problem(args, None, prompt, temperature=1, top_p=1, n=1, verbose=True)[0][0]["message"]["content"]

    # Parse everything first so a malformed line does not leave a half-built evaluator set.
    weights = [record for record in map(_parse_stakeholder_line, answer.splitlines()) if record is not None]

    for record in weights:
        evaluator = record["evaluator"]
        evaluator_dir = f"{evaluators_dir}/{evaluator}"
        if os.path.exists(evaluator_dir):
            print(f"{evaluator} already exists, skipping.")
            continue
        print(f"Creating evaluator {evaluator}...")
        os.mkdir(evaluator_dir)
        # NOTE(review): dirs are created under evaluators/, while the helpers below
        # write under args.eval_path -- presumably eval_path == "evaluators"; confirm.
        generate_stakeholder_prompts(args, evaluator)
        query_stakeholder_prompts(args, evaluator)

    with open(f"{evaluators_dir}/weights.json", "w") as f:
        json.dump(weights, f)