import os
import json
import numpy as np


def get_advice_from_combination_index(args, index, evaluator):
    """Decode an advice-combination index into advice indices and advice texts.

    The candidate advice list is ``normal_advice.json`` followed by
    ``unorthodox_advice.json`` (both in ``{args.scenario}/advice``). The
    evaluator's ``mappings.json`` (first element) maps ``str(index)`` to a
    filename such as ``"0_3_7.json"``. The underscore-separated stem lists the
    positions, in the combined advice list, that make up the combination.

    Args:
        args: object with ``scenario`` and ``eval_path`` attributes.
        index: combination index (int or NumPy integer) to look up.
        evaluator: evaluator sub-directory whose ``mappings.json`` is used.

    Returns:
        ``(combination, advice_texts)``: the list of advice indices and the
        corresponding advice entries. Both lists are empty when the mapped
        filename has an empty stem (e.g. ``".json"``).
    """
    advice_dir = f"{args.scenario}/advice"
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(f"{advice_dir}/normal_advice.json", "r", encoding="utf-8") as f:
        advice_options = json.load(f)
    with open(f"{advice_dir}/unorthodox_advice.json", "r", encoding="utf-8") as f:
        advice_options += json.load(f)

    mapping_path = f"{args.scenario}/{args.eval_path}/{evaluator}/mappings.json"
    with open(mapping_path, "r", encoding="utf-8") as f:
        mapping = json.load(f)[0]

    filename = mapping[str(index)]
    # Strip only a real ".json" suffix. os.path.splitext is unsuitable here:
    # it treats ".json" as a dotfile with no extension, which would break the
    # empty-combination case below.
    suffix = ".json"
    stem = filename[:-len(suffix)] if filename.endswith(suffix) else filename

    if stem == "":
        return [], []

    combination = [int(part) for part in stem.split("_")]
    advice_texts = [advice_options[i] for i in combination]
    return combination, advice_texts


def get_best_advice(args):
    """Select the advice combination with the highest (weighted) MWU score.

    Every non-hidden sub-directory of ``{args.scenario}/{args.eval_path}`` is
    treated as one evaluator and must contain ``MWU_matrix.npy`` and
    ``mappings.json``. An evaluator's score for each combination is the sum
    of its MWU matrix over axes 1, 2 and 3.

    Per-evaluator scores are combined in one of two ways:

    * with the weights in an optional ``weights.json`` (a list of
      ``{"evaluator": ..., "weight": ...}`` entries, weights in percent);
    * with equal weights ``1 / num_evaluators`` when that file is absent.

    Args:
        args: object with ``scenario`` and ``eval_path`` attributes.

    Returns:
        The index (NumPy integer) of the best combination.

    Raises:
        Exception: if no evaluator directory is found, or if
            ``weights.json`` exists but has no entry for an evaluator.
    """
    print("\nGetting best advice...")
    eval_dir = f"{args.scenario}/{args.eval_path}"

    weights = None  # parsed weights.json (list of dicts), if present
    evaluators = []
    advice_scores_across_stakeholders = []

    # Sorted so evaluator order (and printed output) doesn't depend on os.listdir.
    for entry in sorted(os.listdir(eval_dir)):
        if entry.startswith("."):  # e.g. .DS_Store
            continue
        if entry == "weights.json":
            with open(f"{eval_dir}/{entry}", "r", encoding="utf-8") as f:
                weights = json.load(f)
            continue
        if not os.path.isdir(f"{eval_dir}/{entry}"):
            continue  # stray files are not evaluators

        print(f"\nFor evaluator {entry}:")
        evaluators.append(entry)
        MWU_matrix = np.load(f"{eval_dir}/{entry}/MWU_matrix.npy")

        # MWU_matrix[[3, 6, 8, 9], :, :, :] = 0
        # MWU_matrix[:, [3, 6, 8, 9], :, :] = 0

        # Another option is to sum over axes (2, 3) and then do a simple 3-way
        # sort. If one option beats both others it also wins this full sum, and
        # in a cyclic tie (1>2, 2>3, 3>1) we want the best-performing sum anyway.
        advice_scores = np.sum(MWU_matrix, axis=(1, 2, 3))
        print(advice_scores)
        advice_scores_across_stakeholders.append(advice_scores)

        combination, texts = get_advice_from_combination_index(
            args, np.argmax(advice_scores), entry
        )
        print(f"Best combination: {combination}\n{texts}")

    if not evaluators:
        raise Exception(f"No evaluator directories found in {eval_dir}.")

    # shape = (num_advice_combinations,)
    scores = np.zeros(advice_scores_across_stakeholders[0].shape)
    for evaluator, advice_scores in zip(evaluators, advice_scores_across_stakeholders):
        if weights is None:
            weight = 1 / len(evaluators)
        else:
            weight_entry = next(
                (d for d in weights if d["evaluator"] == evaluator), None
            )
            if weight_entry is None:
                raise Exception(f"Weight for evaluator {evaluator} not found.")
            # float, not int: fractional percentages (e.g. "33.3") must not be
            # truncated or rejected.
            weight = float(weight_entry["weight"]) / 100
        scores += advice_scores * weight

    best_overall_advice = np.argmax(scores)
    # Assumes every evaluator's mappings.json uses the same combination order
    # (the weighted sum above already relies on this), so any evaluator works.
    combination, texts = get_advice_from_combination_index(
        args, best_overall_advice, evaluators[0]
    )

    print(f"Best overall combination: {combination}\n{texts}")
    return best_overall_advice  # index of best combination


def get_best_response(args):
    """Select the single best generated response across all evaluators.

    Every non-hidden sub-directory of ``{args.scenario}/{args.eval_path}`` is
    one evaluator with ``MWU_matrix.npy`` and ``mappings.json``. Each
    evaluator scores every (advice combination, generation) pair by summing
    its MWU matrix over axes 1 and 3. That yields a grid of shape
    ``(num_advice_combinations, num_generations)``.

    Grids are combined across evaluators with the percentage weights in an
    optional ``weights.json`` (list of ``{"evaluator", "weight"}`` entries),
    or with equal weights if that file is absent. Response texts are read
    from ``{args.scenario}/candidates/unorthodox_advice/<mapped filename>``.

    Side effects:

    * writes the flattened weighted scores (advice-major order) to
      ``{args.scenario}/best_response_scores.json``;
    * writes the weighted MWU matrix to
      ``{args.scenario}/MWU_matrix_weighted.npy``.

    Args:
        args: object with ``scenario``, ``eval_path`` and ``verbose``
            attributes.

    Returns:
        The text of the best overall response.

    Raises:
        Exception: if no evaluator directory is found, or if
            ``weights.json`` exists but has no entry for an evaluator.
    """
    print("\nGetting best response...")
    eval_dir = f"{args.scenario}/{args.eval_path}"
    candidates_dir = f"{args.scenario}/candidates/unorthodox_advice"

    def lookup_response(mapping, score_grid):
        """Return (text, filename, generation index) at the argmax of score_grid."""
        # Decode using the grid's own shape rather than a separately supplied
        # generation count, so the (advice, generation) split always matches.
        advice_index, generation_index = np.unravel_index(
            np.argmax(score_grid), score_grid.shape
        )
        filename = mapping[str(int(advice_index))]
        with open(f"{candidates_dir}/{filename}", "r", encoding="utf-8") as f:
            text = json.load(f)[int(generation_index)]
        return text, filename, int(generation_index)

    weights = None  # parsed weights.json (list of dicts), if present
    evaluators = []
    MWU_matrices = {}
    mappings = {}
    response_scores_across_stakeholders = []  # one 2-D score grid per evaluator

    # Sorted so evaluator order (and printed output) doesn't depend on os.listdir.
    for entry in sorted(os.listdir(eval_dir)):
        if entry.startswith("."):  # e.g. .DS_Store
            continue
        if entry == "weights.json":
            with open(f"{eval_dir}/{entry}", "r", encoding="utf-8") as f:
                weights = json.load(f)
            continue
        if not os.path.isdir(f"{eval_dir}/{entry}"):
            continue  # stray files are not evaluators

        print(f"\nFor evaluator {entry}:")
        evaluators.append(entry)
        MWU_matrix = np.load(f"{eval_dir}/{entry}/MWU_matrix.npy")

        # augmenting code to ignore non-3
        # first, zero out dim 0 and dim 1 index 3, 6, 8, 9
        # MWU_matrix[[3, 6, 8, 9], :, :, :] = 0
        # MWU_matrix[:, [3, 6, 8, 9], :, :] = 0

        MWU_matrices[entry] = MWU_matrix

        # (num_advice_combinations, num_generations)
        response_scores = np.sum(MWU_matrix, axis=(1, 3))
        if args.verbose:
            print(response_scores)
        response_scores_across_stakeholders.append(response_scores)

        with open(f"{eval_dir}/{entry}/mappings.json", "r", encoding="utf-8") as f:
            mappings[entry] = json.load(f)[0]

        best_response_text, _, _ = lookup_response(mappings[entry], response_scores)
        print(f"Best response: {best_response_text}")

    if not evaluators:
        raise Exception(f"No evaluator directories found in {eval_dir}.")

    scores = np.zeros(response_scores_across_stakeholders[0].shape)
    MWU_matrix_weighted = np.zeros(MWU_matrices[evaluators[0]].shape)

    for evaluator, response_scores in zip(evaluators, response_scores_across_stakeholders):
        if weights is None:
            weight = 1 / len(evaluators)
        else:
            weight_entry = next(
                (d for d in weights if d["evaluator"] == evaluator), None
            )
            if weight_entry is None:
                raise Exception(f"Weight for evaluator {evaluator} not found.")
            # float, not int: fractional percentages (e.g. "33.3") must not be
            # truncated or rejected.
            weight = float(weight_entry["weight"]) / 100

        scores += response_scores * weight
        MWU_matrix_weighted += MWU_matrices[evaluator] * weight

    # Flattened (num_advice_combinations * num_generations,) view, advice-major.
    flat_scores = scores.flatten()
    print(f"Response scores across stakeholders: {flat_scores}")

    with open(f"{args.scenario}/best_response_scores.json", "w", encoding="utf-8") as f:
        json.dump(flat_scores.tolist(), f)

    # Assumes every evaluator's mappings.json uses the same combination order
    # (the weighted sum above already relies on this), so use the first one
    # explicitly instead of whichever mapping the loop loaded last.
    best_response_text, filename, best_response_index = lookup_response(
        mappings[evaluators[0]], scores
    )

    print(f"Best overall response: advice {filename} {best_response_index}, {best_response_text}")

    # export the weighted MWU matrix
    np.save(f"{args.scenario}/MWU_matrix_weighted.npy", MWU_matrix_weighted)

    return best_response_text