import argparse
from utils import get_rating_from_azure_chat_call, int_to_word, get_response_from_azure_chat_call, parse_response, parse_method_response, reparse, separate_zero_shot_cot, fix_key_names, fix_key_names_zero_shot
import json, os, random, string
from tqdm import tqdm

def load_statements():
    """Read the four apartment-trait stimulus files.

    Each category ``<name>`` is read from ``apartments/stimuli/<name>.txt``,
    one statement per line, with surrounding whitespace stripped.

    Returns:
        dict mapping "positive_traits", "negative_traits",
        "neutral_positive_traits" and "neutral_negative_traits" (in that
        order) to the list of stripped lines of the matching file.
    """
    categories = (
        "positive_traits",
        "negative_traits",
        "neutral_positive_traits",
        "neutral_negative_traits",
    )
    loaded = {}
    for category in categories:
        with open(f"apartments/stimuli/{category}.txt", "r") as handle:
            loaded[category] = [line.strip() for line in handle]
    return loaded


def rate_statements(args, statements):
    """Obtain (and cache) a model desirability rating for every trait statement.

    For each category in *statements*, ratings are cached in
    ``apartments/ratings/<category>.json``. When that file exists it is
    loaded instead of querying the model; otherwise every statement is sent to
    ``get_rating_from_azure_chat_call`` and the results are written to it.

    Args:
        args: parsed CLI namespace; uses ``args.model`` and ``args.verbose``.
        statements: mapping of category name -> list of statement strings.

    Returns:
        dict mapping each category to ``{str(statement_index): rating}``.
        Keys are strings both for cached and freshly computed ratings.
    """
    ratings_dir = "apartments/ratings"
    all_ratings = {}
    for trait_type, trait_list in statements.items():
        cache_path = f"{ratings_dir}/{trait_type}.json"
        if os.path.exists(cache_path):
            print(f"Skipping rating for {trait_type} because it already exists")
            with open(cache_path, "r") as f:
                all_ratings[trait_type] = json.load(f)
            continue

        ratings = {}
        if args.verbose:
            print(f"Rating {trait_type}")
        for index in tqdm(range(len(trait_list))):
            statement = trait_list[index]
            prompt = f"The following statement describes a particular apartment. Please rate how this statement affects the desirability of the apartment for the average tenant, from -5 to 5, with 5 being most desirable. Do not include anything other than the rating.\nStatement: {statement}"

            response = get_rating_from_azure_chat_call(args.model, prompt, max_tokens=10)
            if args.verbose:
                print(f"Rating for {trait_type} statement '{statement}': {response}")
            # JSON object keys are always strings, so key by str(index) here too.
            # This keeps freshly computed ratings shaped like cached ones;
            # generate_apartments looks them up with str(trait_index).
            ratings[str(index)] = response

        # The ratings directory may not exist on a fresh checkout.
        os.makedirs(ratings_dir, exist_ok=True)
        with open(cache_path, "w") as f:
            json.dump(ratings, f)

        all_ratings[trait_type] = ratings

    return all_ratings
    

def generate_apartments(args, statements, ratings, num_traits=80):
    """Build ``args.num_apartments`` synthetic apartments from rated statements.

    For every trait slot ``0 .. num_traits - 1``, a category is drawn uniformly
    at random from the keys of *statements*. The apartment takes that
    category's statement at the slot index, together with its rating.

    Args:
        args: namespace; uses ``args.num_apartments``.
        statements: category -> list of statements. Each list must have at
            least ``num_traits`` entries.
        ratings: category -> {index: rating}. Keys may be ``str`` (as loaded
            from JSON) or ``int``.
        num_traits: trait slots per apartment (default 80, the previously
            hard-coded value).

    Returns:
        list of dicts with keys ``traits``, ``trait_ratings``,
        ``overall_rating`` (mean of ``trait_ratings``) and ``index``.

    Raises:
        ValueError: if ``num_traits`` is not positive (the mean is undefined).
    """
    if num_traits <= 0:
        raise ValueError(f"num_traits must be positive, got {num_traits}")

    # Hoisted out of the loops; random.choice draws from the same list as before.
    categories = list(statements.keys())
    apartments = []
    for i in range(args.num_apartments):
        traits = []
        trait_ratings = []
        for trait_index in range(num_traits):
            trait_outcome = random.choice(categories)
            category_ratings = ratings[trait_outcome]
            key = str(trait_index)
            # Cached ratings use str keys; accept int keys as well.
            trait_rating = category_ratings[key] if key in category_ratings else category_ratings[trait_index]
            traits.append(statements[trait_outcome][trait_index])
            trait_ratings.append(trait_rating)

        apartments.append({
            "traits": traits,
            "trait_ratings": trait_ratings,
            "overall_rating": sum(trait_ratings) / len(trait_ratings),
            "index": i,
        })

    return apartments


def generate_dataset(args, apartments):
    """Sample ``args.n`` apartment-choice problems, or load them if already saved.

    The dataset is stored at ``<args.dataset_dir>.json``. If that file exists it
    is loaded and returned unchanged. Otherwise, for each problem,
    ``args.num_apartments_per_problem`` apartments are drawn with
    ``random.sample``. Draws are repeated until the gap between the highest and
    second-highest ``overall_rating`` lies strictly between
    ``args.min_rating_difference`` and ``args.max_rating_difference``.

    Args:
        args: CLI namespace (dataset_dir, n, num_apartments_per_problem,
            min/max_rating_difference, verbose).
        apartments: list of apartment dicts with an ``overall_rating`` key.

    Returns:
        list of problem dicts with ``index``, ``apartments``,
        ``best_apartment_index`` (0-based position in ``apartments``) and
        ``apartment_scores``.

    Raises:
        ValueError: if fewer than two apartments per problem are requested
            (there is no second-best rating to compare against).
    """
    dataset_path = f"{args.dataset_dir}.json"
    if os.path.exists(dataset_path):
        print(f"Dataset already exists, loading from {dataset_path}")
        with open(dataset_path, "r") as f:
            return json.load(f)

    per_problem = args.num_apartments_per_problem
    if per_problem < 2:
        raise ValueError(f"num_apartments_per_problem must be at least 2, got {per_problem}")

    dataset = []
    if args.verbose:
        print("Generating dataset")
    for i in tqdm(range(args.n)):
        # Rejection sampling: redraw until there is a clear (but not too clear)
        # best apartment.
        # NOTE(review): this never terminates if no subset of `apartments`
        # satisfies the gap window.
        while True:
            sample = random.sample(apartments, per_problem)
            overall_ratings = [apartment["overall_rating"] for apartment in sample]
            second_highest_rating, highest_rating = sorted(overall_ratings)[-2:]
            gap = highest_rating - second_highest_rating
            if args.min_rating_difference < gap < args.max_rating_difference:
                break
        best_apartment_index = overall_ratings.index(highest_rating)

        problem = {
            "index": i,
            "apartments": sample,
            "best_apartment_index": best_apartment_index,
            "apartment_scores": overall_ratings,
        }
        dataset.append(problem)

        if args.verbose:
            print(f"Generated problem {i}")
            print(f"Apartment scores: {overall_ratings}")

    # Make sure the parent directory of the dataset file exists before writing.
    parent_dir = os.path.dirname(dataset_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(dataset_path, "w") as f:
        json.dump(dataset, f)

    return dataset


def construct_problem_prompts(args, problem):
    """Build the (zero-shot, chain-of-thought) prompt pair for one problem.

    Both prompts are exactly the ``"zero_shot"`` and ``"cot"`` prompts of
    ``construct_method_prompt``. This function used to duplicate that text
    verbatim; delegating keeps the two code paths from drifting apart.

    Args:
        args: CLI namespace; uses ``args.num_apartments_per_problem``.
        problem: dict with an ``"apartments"`` list of rated apartments.

    Returns:
        tuple ``(zero_shot_prompt, cot_prompt)``.
    """
    zero_shot_prompt = construct_method_prompt(args, problem, "zero_shot")
    cot_prompt = construct_method_prompt(args, problem, "cot")
    return zero_shot_prompt, cot_prompt


def construct_method_prompt(args, problem, prompt_method):
    """Build the evaluation prompt for one problem under a prompting method.

    Methods starting with ``"no_ratings_"`` are delegated to
    ``construct_no_ratings_prompt`` with that prefix removed. For every other
    method, the prompt lists each trait of each apartment together with the
    tenant's desirability rating for it, then appends a method-specific
    instruction.

    Args:
        args: CLI namespace; uses ``args.num_apartments_per_problem``.
        problem: dict with an ``"apartments"`` list. Each apartment has
            parallel ``"traits"`` and ``"trait_ratings"`` lists.
        prompt_method: ``"distractor_nback"``, ``"very_carefully_think"``,
            ``"zero_shot"``, ``"cot"``, ``"cot_time_limit"``, or a
            ``"no_ratings_"``-prefixed method.

    Returns:
        The prompt string.

    Raises:
        ValueError: if ``prompt_method`` is not recognized.
    """
    no_ratings_prefix = "no_ratings_"
    if prompt_method.startswith(no_ratings_prefix):
        return construct_no_ratings_prompt(args, problem, prompt_method[len(no_ratings_prefix):])

    sentences = []
    for apartment_index, apartment in enumerate(problem["apartments"]):
        for trait_index, trait in enumerate(apartment["traits"]):
            trait_rating = apartment["trait_ratings"][trait_index]
            rating_sentence = f"On a scale of -5 to 5, with 5 being most desirable, the tenant thinks that an apartment with this description would have a desirability of {trait_rating}."
            sentences.append(f"Apartment {apartment_index + 1}: {trait} {rating_sentence}")

    # NOTE(review): `{sentences}` interpolates the Python list repr (brackets,
    # quotes, commas), not one statement per line. Left unchanged so prompts
    # match those behind already-cached results.
    base_prompt = f"""You are an AI assistant designed to evaluate the desirability of {int_to_word(args.num_apartments_per_problem)} apartments for a potential tenant. You will be given a list of statements about the apartment candidates and how much the tenant likes or dislikes an apartment with the quality described by the statement. Your task is to determine which apartment is the most desirable based on the given criteria.

    The statements are as follows:
    {sentences}

    Which apartment is most desirable to the tenant?"""

    if prompt_method == "distractor_nback":
        # 2-back distractor task over a random letter sequence. The random
        # letters are only drawn for this method, so other methods consume no
        # randomness.
        random_letters = random.sample(string.ascii_lowercase, 10)
        list_of_letters = random.choices(random_letters, k=160)
        prompt = base_prompt + f"\n\nNext, consider the following list of letters: {list_of_letters}. Before answering the question about the apartment, please return in a list format indicating whether each letter (starting from the third letter) was the same as the letter two places before it in the list. Respond with a list of true/false values. Then, answer the question about the apartment using an number between 1 and {args.num_apartments_per_problem}. Respond with 'Apartment ' followed by the number, do not include anything else."

    elif prompt_method == "very_carefully_think":
        # This is the phrase from the paper; participants were given three minutes.
        # Must be an f-string (the apartment count was previously sent literally
        # as "{args.num_apartments_per_problem}") and needs a leading space to
        # separate it from the question.
        prompt = base_prompt + f" For three minutes, very carefully think about what you think of each of the {args.num_apartments_per_problem} apartments. Then, answer the question about the apartment using an number between 1 and {args.num_apartments_per_problem}. Respond with 'Apartment ' followed by the number, do not include anything else."

    elif prompt_method == "zero_shot":
        prompt = base_prompt + " Respond with only the number of the apartment, do not include anything else."

    elif prompt_method == "cot":
        prompt = base_prompt + " Let's think step by step."

    elif prompt_method == "cot_time_limit":
        prompt = base_prompt + " You have three minutes to think about the problem. Let's think step by step."

    else:
        raise ValueError(f"Prompt method {prompt_method} not recognized")

    return prompt


def construct_no_ratings_prompt(args, problem, prompt_method):
    """Build a prompt that lists the apartments' traits without any ratings.

    The model has to judge desirability for the average tenant on its own.

    Args:
        args: CLI namespace; uses ``args.num_apartments_per_problem``.
        problem: dict with an ``"apartments"`` list, each with a ``"traits"`` list.
        prompt_method: ``"zero_shot"`` or ``"cot"`` (prefix already stripped).

    Returns:
        The prompt string.

    Raises:
        ValueError: for any other ``prompt_method``.
    """
    sentences = [
        f"Apartment {apartment_number}: {trait}"
        for apartment_number, apartment in enumerate(problem["apartments"], start=1)
        for trait in apartment["traits"]
    ]

    base_prompt = f"""You are an AI assistant designed to evaluate the desirability of {int_to_word(args.num_apartments_per_problem)} apartments for the average tenant. You will be given a list of statements about the apartment candidates. Your task is to determine which apartment is the most desirable based on the given criteria.

    The statements are as follows:
    {sentences}

    Which apartment is most desirable to the tenant?"""

    instruction_by_method = {
        "zero_shot": " Respond with only the number of the apartment, do not include anything else.",
        "cot": " Let's think step by step.",
    }
    if prompt_method not in instruction_by_method:
        raise ValueError(f"Prompt method {prompt_method} not recognized")
    return base_prompt + instruction_by_method[prompt_method]


def evaluate_model(args, dataset):
    """Query the model with paired zero-shot and chain-of-thought prompts.

    Problems ``dataset[args.start_index:args.end_index]`` are evaluated; the end
    is clamped to ``len(dataset)``. Both prompts from
    ``construct_problem_prompts`` are sent to ``args.model`` and their parsed
    answers are stored. Each result is cached as
    ``<output_dir>/<model>/zero_shot_cot/<problem index>.json``; problems with
    an existing cache file are loaded instead of re-queried.

    Returns:
        list of result dicts, one per evaluated problem.
    """
    result_dir = f"{args.output_dir}/{args.model}/zero_shot_cot"
    # Without this the first write fails with FileNotFoundError when the
    # directory is missing (evaluate_new_prompt creates its directory too).
    os.makedirs(result_dir, exist_ok=True)

    results = []
    end_index = min(args.end_index, len(dataset))
    for problem_index in tqdm(range(args.start_index, end_index)):
        problem = dataset[problem_index]
        path = f"{result_dir}/{problem['index']}.json"

        if os.path.exists(path):
            print(f"Skipping problem {problem['index']} because it already exists")
            with open(path, 'r') as f:
                results.append(json.load(f))
            continue

        zero_shot_prompt, cot_prompt = construct_problem_prompts(args, problem)

        zero_shot_response = get_response_from_azure_chat_call(args.model, zero_shot_prompt, max_tokens=100)
        cot_response = get_response_from_azure_chat_call(args.model, cot_prompt, max_tokens=2500)
        parsed_zero_shot_response = parse_response(args, zero_shot_response, "zero_shot")
        parsed_cot_response = parse_response(args, cot_response, "cot")

        result = {
            "index": problem["index"],
            "apartments": problem["apartments"],
            "zero_shot_prompt": zero_shot_prompt,
            "cot_prompt": cot_prompt,
            "zero_shot_response": zero_shot_response,
            "cot_response": cot_response,
            "parsed_zero_shot_response": parsed_zero_shot_response,
            "parsed_cot_response": parsed_cot_response,
            "best_apartment_index": problem["best_apartment_index"],
        }
        with open(path, "w") as f:
            json.dump(result, f)

        results.append(result)

    return results


def evaluate_new_prompt(args, dataset, prompt_method):
    """Evaluate one prompting method on a slice of the dataset.

    Problems ``dataset[args.start_index:args.end_index]`` are evaluated; the end
    is clamped to ``len(dataset)``. Each is prompted via
    ``construct_method_prompt`` and the answer is parsed with
    ``parse_method_response``. Results are cached as
    ``<output_dir>/<model>/<prompt_method>/<problem index>.json``; problems
    with an existing cache file are loaded instead of re-queried.

    With ``args.verbose``, up to five correctly answered problems are printed
    for inspection afterwards.

    Returns:
        list of result dicts, one per evaluated problem.
    """
    print(f"Evaluating {prompt_method}")
    result_dir = f"{args.output_dir}/{args.model}/{prompt_method}"
    results = []
    end_index = min(args.end_index, len(dataset))
    for problem_index in tqdm(range(args.start_index, end_index)):
        problem = dataset[problem_index]
        path = f"{result_dir}/{problem['index']}.json"

        if os.path.exists(path):
            print(f"Skipping problem {problem['index']} because it already exists")
            with open(path, 'r') as f:
                results.append(json.load(f))
            continue

        prompt = construct_method_prompt(args, problem, prompt_method)

        response = get_response_from_azure_chat_call(args.model, prompt, max_tokens=2000)
        parsed_response = parse_method_response(args, response, prompt_method)

        if "apartment_scores" in problem:
            apartment_scores = problem["apartment_scores"]
        else:
            # NOTE(review): this fallback sums trait ratings, whereas
            # generate_dataset stores per-apartment means. The ranking agrees
            # only when all apartments have the same number of traits.
            apartment_scores = [sum(apartment["trait_ratings"]) for apartment in problem["apartments"]]

        result = {
            "index": problem["index"],
            "apartments": problem["apartments"],
            "prompt": prompt,
            "response": response,
            "parsed_response": parsed_response,
            "best_apartment_index": problem["best_apartment_index"],
            "apartment_scores": apartment_scores,
        }

        os.makedirs(result_dir, exist_ok=True)
        with open(path, "w") as f:
            json.dump(result, f)

        results.append(result)

    if args.verbose:
        # Debug aid: dump up to five correctly answered problems.
        shown = 0
        for result in results:
            if shown >= 5:
                break
            if result["parsed_response"] != result["best_apartment_index"]:
                continue
            shown += 1
            print([sum(apartment["trait_ratings"]) for apartment in result["apartments"]])
            print(result["parsed_response"])
            print(result["best_apartment_index"])
            print(result["response"])
            print("\n")

    return results


def analyze_method_performance(results, prompt_method):
    """Print the accuracy of a prompting method and return its correct count.

    Args:
        results: result dicts with ``parsed_response`` and ``best_apartment_index``.
        prompt_method: label used in the printed line.

    Returns:
        int: number of results whose ``parsed_response`` equals
        ``best_apartment_index``.
    """
    correct = sum(
        1 for result in results
        if result["parsed_response"] == result["best_apartment_index"]
    )

    if results:
        print(f"{prompt_method} accuracy: {correct / len(results)}")
    else:
        # An empty slice (e.g. start_index >= end_index) previously raised
        # ZeroDivisionError.
        print(f"{prompt_method} accuracy: n/a (no results)")

    return correct


def main(args):
    """Run the experiment pipeline end to end and report accuracy.

    Steps:
    1. Load the trait statements.
    2. Rate them (cached under ``apartments/ratings``).
    3. Build synthetic apartments.
    4. Sample or load the problem dataset.
    5. Evaluate ``args.method`` on it and print the accuracy.

    Returns:
        int: number of correctly answered problems, as counted by
        ``analyze_method_performance``.
    """
    # Individual statements describing an apartment.
    statements = load_statements()

    # Ratings of each statement for the average tenant.
    ratings = rate_statements(args, statements)

    apartments = generate_apartments(args, statements, ratings)
    dataset = generate_dataset(args, apartments)

    # evaluate_model(args, dataset) would instead run the paired zero-shot/CoT
    # evaluation; the CLI evaluates the single method chosen by --method.
    results = evaluate_new_prompt(args, dataset, args.method)

    return analyze_method_performance(results, args.method)

# Ideas for future work:
# - Use personas, with GPT-generated importance weights for each persona's preferences.
#   This introduces subjective weighting, which can be stated in the prompt either as a
#   number or as a verbal "importance weight".
# - Use these weights to derive the ground truth, or ask each persona to give a rough
#   weighting value.
# - Make sure the options are separated widely enough that rating noise cannot make
#   their scores overlap.



def parse_args():
    """Parse the command-line options of the apartment-choice experiment.

    Returns:
        argparse.Namespace containing:
        - dataset and output locations;
        - the evaluated index range;
        - the rating-gap window used when sampling problems;
        - the prompting method and model name;
        - generation sizes and the verbosity flags.
    """
    prompt_methods = [
        "cot_time_limit",
        "distractor_nback",
        "very_carefully_think",
        "zero_shot",
        "cot",
        "no_ratings_zero_shot",
        "no_ratings_cot",
    ]
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--start_index", {"type": int, "default": 0}),
        ("--end_index", {"type": int, "default": 100}),
        ("--dataset_dir", {"type": str, "default": "apartments/dataset_0.1_0.3"}),
        ("--output_dir", {"type": str, "default": "apartments/results_0.1_0.3"}),
        ("--min_rating_difference", {"type": float, "default": 0.1}),
        ("--max_rating_difference", {"type": float, "default": 0.3}),
        ("--method", {"type": str, "default": "cot", "choices": prompt_methods}),
        ("--model", {"type": str, "default": "gpt-4o"}),
        ("--verbose", {"action": "store_true"}),
        ("--use_existing_results", {"action": "store_true"}),
        ("--num_apartments", {"type": int, "default": 100}),
        ("--num_apartments_per_problem", {"type": int, "default": 4}),
        ("--n", {"type": int, "default": 100}),
    )
    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

if __name__ == "__main__":
    # Script entry point: parse CLI options and run the full pipeline.
    args = parse_args()
    main(args)
    # Helpers imported from utils (presumably one-off result-fixing scripts);
    # uncomment to run them.
    # reparse(args)
    # separate_zero_shot_cot()
    # fix_key_names_zero_shot()