from generation import generate_advice, generate_advice_prompts, \
    generate_candidates_prompts, generate_candidates, \
    compare, generate_stakeholders
import argparse
import json
import numpy as np
import os
from baselines import baseline_ask, baseline_ask_with_reasoning
from eval import best_response_baseline_compare
from utils import save_as_csv, load_responses
from get_best import get_best_advice, get_best_response
import time
import random
from move_reasoning_to_pairwise import ablation_baseline_reasoning


def parse_args():
    """Build the command-line parser for the pipeline and parse the arguments.

    Returns:
        argparse.Namespace: the parsed options. ``verbose`` is converted from
        its string form to a bool: it is ``False`` only when the value is the
        literal string "False" (the default) and ``True`` for anything else.
    """
    parser = argparse.ArgumentParser()

    def add_options(kind, *flags, **extra):
        # register every flag with the same type/extra settings, in order
        for flag in flags:
            parser.add_argument(flag, type=kind, **extra)

    # construct the scenario
    add_options(
        str,
        "--scenario",  # name of the scenario -- barista, dating_app, etc.
        "--setting_path",  # setting of the scenario
    )

    # paths to the baseline advice generation prompts
    add_options(
        str,
        "--baseline_ask_path",
        "--baseline_ask_with_reasoning_path",
        "--reasoning_prefix_path",  # needed for reasoning prompt parsing
        "--baseline_bad_advice_path",
        "--baseline_irrelevant_advice_path",
        "--baseline_neutral_advice_path",
    )

    # paths to the framework advice generation prompts
    add_options(str, "--unorthodox_advice_path", "--normal_advice_path")

    # hyperparameters for advice generation
    add_options(int, "--num_advice", default=3)  # number of advice options to generate
    add_options(int, "--num_unorthodox_advice", default=1)  # number of unorthodox advice options to generate

    # paths to the response generation prompts
    add_options(
        str,
        "--remember_advice_path",  # text fill-in for remembering a piece of advice
        "--second_person_scenario_path",
        "--second_person_system_path",
    )

    # hyperparameters for response generation
    add_options(int, "--max_advice_per_generation", default=2)  # max number of advice used
    add_options(int, "--num_generations", default=5)  # number of responses generated per advice set

    # paths to the stakeholder generation prompts
    add_options(
        str,
        "--stakeholders_path",
        "--stakeholder_systempromptgen_path",
        "--stakeholder_userprompttemplate_path",
        "--stakeholder_userpromptgen_path",
    )

    # hyperparameters for stakeholder generation
    add_options(int, "--num_stakeholders")

    # paths to the evaluation prompts
    add_options(str, "--eval_path")  # this is the name of the folder that contains the evaluators

    # hyperparameters for evaluation
    add_options(int, "--num_evals", default=3)  # number of evals for each pair of responses

    # paths to the counterfactual prompts
    add_options(
        str,
        "--counterfactual_path",
        "--counterfactual_existing_scenario_path",
        "--counterfactual_existing_outcome_path",
        default="",
    )

    # other general arguments
    add_options(str, "--credential_path")  # path to the OpenAI API key
    add_options(str, "--verbose", default="False")

    parsed = parser.parse_args()
    # only the exact string "False" turns verbosity off
    parsed.verbose = parsed.verbose != "False"
    return parsed


def get_comparison_matrices(args):
    """Generate the pairwise-comparison (MWU) matrix for every evaluator.

    Every sub-folder of ``{args.scenario}/{args.eval_path}`` is treated as one
    evaluator. Hidden entries (names starting with ".") and ``weights.json``
    are ignored, as is any other entry that is not a folder. Evaluators whose
    folder already contains ``MWU_matrix.npy`` are skipped; for all others
    ``get_comparison_matrix`` is called with the "unorthodox_advice" type.

    Args:
        args: parsed options; ``scenario`` and ``eval_path`` are read here and
            the whole object is forwarded to ``get_comparison_matrix``.

    Raises:
        Exception: if the eval folder does not exist.
    """
    eval_root = f"{args.scenario}/{args.eval_path}"
    if not os.path.exists(eval_root):
        raise Exception(f"Eval folder {args.scenario}/{args.eval_path} does not exist.")

    for evaluator in os.listdir(eval_root):
        if evaluator[0] == "." or evaluator == "weights.json":  # .DS_Store, weights
            continue

        evaluator_dir = f"{eval_root}/{evaluator}"
        if not os.path.isdir(evaluator_dir):
            # a stray file cannot hold a matrix; listing it would raise
            print(f"Skipping {evaluator}: not an evaluator folder.")
            continue

        # os.listdir returns bare names, so the existence check has to be made
        # against the full path instead of searching the listing for it
        if os.path.isfile(f"{evaluator_dir}/MWU_matrix.npy"):
            print(f"MWU matrix for {evaluator} already exists, skipping generation.")
            continue

        print(f"Generating MWU matrix for evaluator {evaluator}...")
        get_comparison_matrix(args, evaluator, ["unorthodox_advice"])
        


def get_comparison_matrix(args, path, types):
    """Compare every pair of candidate responses and save the result matrix.

    Responses are loaded from ``{args.scenario}/candidates/{type}`` for each
    entry of ``types`` and concatenated. The mappings returned alongside them
    are written to ``mappings.json`` and the finished matrix to
    ``MWU_matrix.npy``, both inside ``{args.scenario}/{args.eval_path}/{path}``.

    Args:
        args: parsed options; ``scenario``, ``eval_path`` and
            ``num_generations`` are read here and the object is also passed
            on to ``compare``.
        path: name of the evaluator folder; also passed on to ``compare``.
        types: advice type names whose candidate responses are compared.
    """
    # load the responses (always from the candidates folder; the older
    # barista-specific unsorted loading is no longer used)
    all_responses = []
    entry_count = 0
    mapping_per_type = []
    for advice_type in types:
        loaded, mapping = load_responses(f"{args.scenario}/candidates/{advice_type}")
        all_responses += loaded
        entry_count += len(loaded)
        mapping_per_type.append(mapping)

    evaluator_dir = f"{args.scenario}/{args.eval_path}/{path}"
    with open(f"{evaluator_dir}/mappings.json", "w") as handle:
        json.dump(mapping_per_type, handle)

    generations = args.num_generations
    # axis 0 = first choice in comparison
    # axis 1 = second choice in comparison
    # axis 2 = generation number for first choice
    # axis 3 = generation number for second choice
    # 1: first is better than second, 0: second is better than first, 0.5: tie
    matrix = np.zeros((entry_count, entry_count, generations, generations))

    for first in range(entry_count):
        for first_gen in range(generations):
            first_response = all_responses[first][first_gen]
            # start at `first` so that every pair is compared only once
            for second in range(first, entry_count):
                for second_gen in range(generations):
                    # within one entry, only compare a generation with later ones
                    if first != second or first_gen < second_gen:
                        second_response = all_responses[second][second_gen]
                        outcome = compare(
                            first_response, second_response, args,
                            first, second, first_gen, second_gen, path,
                        )

                        matrix[first, second, first_gen, second_gen] = outcome
                        # the mirrored cell holds the complementary score
                        matrix[second, first, second_gen, first_gen] = (
                            1 - matrix[first, second, first_gen, second_gen]
                        )

    np.save(f"{evaluator_dir}/MWU_matrix.npy", matrix)


def main(args):
    """Run the currently enabled stages of the pipeline.

    Only the pairwise comparisons (step 3) and the selection of the best
    advice and response (step 4) are active. The generation stages
    (steps 1 - 2.5) and the baseline evaluation (step 5) are kept as
    commented-out code so that they can be switched back on by hand.

    Args:
        args: parsed options as returned by ``parse_args``.
    """
    start_time = time.time()
    random.seed(0)  # seed the `random` module with a fixed value

    # --- disabled: step 1.1: generate prompts for getting advice, saved at scenario/advice_prompts.json
    # generate_advice_prompts(args)

    # --- disabled: step 1.2: generate advice, saved at scenario/advice/{advice_type}.json
    # generate_advice(args)

    # --- disabled: step 2.1: generate baseline non-advice-based messages
    # baseline_ask_generations = baseline_ask(args)
    # baseline_ask_with_reasoning_generations = baseline_ask_with_reasoning(args)
    # save_as_csv(baseline_ask_with_reasoning_generations, f"{args.scenario}/ablation_generations.csv")
    # ablation_baseline_reasoning(args)

    # --- disabled: step 2.2: generate baseline advice-based messages
    # generate_candidates_prompts(args, ["baseline_irrelevant_advice"])
    # generate_candidates_prompts(args, ["baseline_neutral_advice"])

    # baseline_irrelevant_advice_generations = generate_candidates(args, "baseline_irrelevant_advice")
    # baseline_neutral_advice_generations = generate_candidates(args, "baseline_neutral_advice")
    # print(f"Baselines complete - time elapsed: {time.time() - start_time} seconds\n")

    # --- disabled: step 2.3: generate the message candidates for helpful advice
    # generate_candidates_prompts(args, ["normal_advice", "unorthodox_advice"])
    # helpful_advice_generations = generate_candidates(args, "unorthodox_advice")
    # print(f"Framework generations complete - time elapsed: {time.time() - start_time} seconds\n")

    # print(f"baseline ask: {len(baseline_ask_generations)} entries\n")
    # print(f"baseline ask with reasoning: {len(baseline_ask_with_reasoning_generations)} entries\n")
    # print(f"baseline irrelevant advice: {len(baseline_irrelevant_advice_generations)} entries\n")
    # print(f"baseline neutral advice: {len(baseline_neutral_advice_generations)} entries\n")
    # print(f"normal and unorthodox advice: {len(helpful_advice_generations)} entries\n")

    # all_generations = baseline_ask_generations + \
    #                   baseline_ask_with_reasoning_generations + \
    #                   baseline_irrelevant_advice_generations + \
    #                   baseline_neutral_advice_generations + \
    #                   helpful_advice_generations

    # save_as_csv(all_generations, f"{args.scenario}/generations.csv")
    # print(f"Responses generated and saved - time elapsed: {time.time() - start_time} seconds\n")

    # --- disabled: step 2.5 (only if necessary for the setting): generate the stakeholders
    # if args.stakeholders_path != "":
    #     generate_stakeholders(args)
    # print(f"Stakeholders generated - time elapsed: {time.time() - start_time} seconds\n")

    # step 3 (active): conduct pairwise comparisons
    get_comparison_matrices(args)
    # print(f"Pairwise comparisons complete - time elapsed: {time.time() - start_time} seconds\n")

    # step 4 (active): make conclusions based on the comparisons
    best_advice = get_best_advice(args)
    best_response = get_best_response(args)
    elapsed = time.time() - start_time
    print(f"Best advice and response found - time elapsed: {elapsed} seconds\n")

    # --- disabled: step 5: evaluate how well the concept performs
    # compare the best response to the baseline responses
    # best_response_baseline_compare(args, best_response)
    # print(f"Best response compared to baselines - time elapsed: {time.time() - start_time} seconds\n")

# Script entry point: parse the command-line options and run the pipeline.
if __name__ == "__main__":
    args = parse_args()
    main(args)