import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import argparse

def load_file_jsonl(path):
    """Read a JSON Lines file and return its decoded records as a list.

    Each non-blank line is parsed independently with ``json.loads``. Blank lines
    (for example a doubled newline at the end of the file) are skipped, because
    ``json.loads`` raises on an empty/whitespace-only string.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file.

    Returns:
        A list holding one decoded JSON value per non-blank line, in file order.
    """
    # Explicit UTF-8 so logs containing non-ASCII text decode the same way
    # regardless of the machine's locale default.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
   
def plot_success_rate_with_partial_scores(df, solver, expert):
    """Plot per-puzzle success rate and partial scores for one solver/expert pair.

    Rows of ``df`` are restricted to the given pair and ordered by ascending
    ``'Average Success Rate'``. Success rate is drawn as bars on the left y-axis.
    ``'Highest Partial Score'`` and ``'Average Partial Score'`` are drawn as lines
    on a twin right y-axis that shares the x-axis. The figure is shown with
    ``plt.show()``.

    Args:
        df: DataFrame with columns 'Solver', 'Expert', 'Puzzle',
            'Average Success Rate', 'Highest Partial Score' and
            'Average Partial Score'.
        solver: Value to match in the 'Solver' column.
        expert: Value to match in the 'Expert' column.
    """
    # Filter first, then sort: only the selected pair's rows need ordering.
    filtered_df = df[(df['Solver'] == solver) & (df['Expert'] == expert)]
    filtered_df = filtered_df.sort_values(by='Average Success Rate')

    fig, ax1 = plt.subplots(figsize=(10, 6))

    # Left axis: success rate as bars.
    ax1.bar(filtered_df['Puzzle'], filtered_df['Average Success Rate'], color='blue', label='Average Success Rate')
    ax1.set_xlabel("Puzzle")
    ax1.set_ylabel("Success Rate", color='blue')
    ax1.tick_params(axis='y', labelcolor='blue')
    ax1.set_ylim(0, 1.1)  # success rate is a fraction; leave a little headroom above 1

    # Right axis (shares x with ax1): partial scores as lines.
    ax2 = ax1.twinx()
    ax2.plot(filtered_df['Puzzle'], filtered_df['Highest Partial Score'], color='green', marker='o', label='Highest Partial Score')
    ax2.plot(filtered_df['Puzzle'], filtered_df['Average Partial Score'], color='red', marker='x', label='Average Partial Score')
    ax2.set_ylabel("Partial Scores", color='green')
    ax2.tick_params(axis='y', labelcolor='green')

    # A single figure-level legend collecting the labelled artists of both axes.
    fig.legend(loc="upper left", bbox_to_anchor=(0.1, 1), bbox_transform=ax1.transAxes)

    # Style the visible x tick labels and the title through ax1 explicitly.
    # After twinx() the pyplot "current axes" is ax2, whose x-axis is hidden,
    # so plt.xticks()/plt.title() would act on ax2 instead.
    plt.setp(ax1.get_xticklabels(), rotation=45, ha='right')
    ax1.set_title(f"Success Rate and Partial Scores for Solver: {solver} and Expert: {expert}")

    # Lay out only once every label and the title exist, so none are clipped.
    fig.tight_layout()
    plt.show()


# Message prefixes the environment uses to report outcomes.
_SUCCESS_PREFIX = "Puzzle successfully finished"
_MISTAKE_PREFIX = "That action seems to have been a mistake."
# Puzzle names (without the "Puzzle" suffix), in the column order of the printed tables.
_TABLE_PUZZLES = ("Button", "Dog", "SimpleWire", "Who", "Led", "Memory", "KeyPad", "Password", "Colour", "Maze")
# Metrics printed as tables; each is read as "Average <metric>" / "Overall Average <metric>".
_TABLE_METRICS = ("Partial Score", "Success Rate", "Conversation Length", "Word Count")


def _subdirectories(path):
    """Return the sorted names of the directories directly inside *path* (files are skipped)."""
    return sorted(name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name)))


def _mean(values):
    """Return the arithmetic mean of *values*, or 0 for an empty sequence."""
    return sum(values) / len(values) if values else 0


def _message_text(message):
    """Return a message's 'value' if it is a string, else '' (missing key or null value)."""
    value = message.get("value")
    return value if isinstance(value, str) else ""


def _is_success(conversation):
    """Return True if either of the last two messages starts with the success prefix."""
    return any(_message_text(m).startswith(_SUCCESS_PREFIX) for m in conversation[-2:])


def _run_score(conversation, success):
    """Return a run's partial score.

    Uses the last message's 'score' when it is present and not None/"None".
    Otherwise it falls back to 1.0 for a successful run and 0.0 for an
    unsuccessful one.
    """
    score = conversation[-1].get("score")
    if score is not None and score != "None":
        return float(score)
    return 1.0 if success else 0.0


def _summarize_run(conversation):
    """Return (mistakes, conversation_length, word_count) for one conversation.

    conversation_length is the number of SOLVER/EXPERT messages integer-divided
    by 2. word_count counts space-separated tokens in those messages.
    """
    mistakes = turns = words = 0
    for message in conversation:
        sender = message.get("from")
        text = _message_text(message)
        if sender in ("SOLVER", "EXPERT"):
            turns += 1
            if text:
                words += len(text.split(" "))
        elif sender == "ENVIRONMENT" and text.startswith(_MISTAKE_PREFIX):
            mistakes += 1
    return mistakes, turns // 2, words


def _collect_results(result_folder):
    """Aggregate per-puzzle and overall statistics for every result sub-folder.

    Returns:
        (per_puzzle, overall), both keyed by (solver, expert).
        per_puzzle maps each key to {puzzle_dir_name: {"Average ...": value, ...}}.
        overall maps each key to {"Overall Average ...": value, ...}.
        Folders that map to the same key overwrite earlier ones.
    """
    per_puzzle = {}
    overall = {}
    for pair_dir in _subdirectories(result_folder):
        solver = pair_dir.split("_")[0]
        # NOTE(review): the expert is assumed to be the same model as the solver.
        # A commented-out earlier variant read it from a later "_"-separated field
        # of the folder name; confirm the naming scheme.
        expert = solver
        pair_path = os.path.join(result_folder, pair_dir)

        puzzles = {}
        completed = mistakes_total = partial_total = length_total = words_total = 0
        evaluated_runs = 0
        for puzzle in _subdirectories(pair_path):
            puzzle_path = os.path.join(pair_path, puzzle)
            runs = _subdirectories(puzzle_path)
            # Runs too short to evaluate still count in this denominator, so they
            # score as failures in the per-puzzle success/partial-score averages.
            total_runs = max(1, len(runs))

            successes = 0
            partial_sum = 0
            highest_partial = 0
            mistakes_per_run, lengths_per_run, words_per_run = [], [], []
            for run in runs:
                conversation = load_file_jsonl(os.path.join(puzzle_path, run, "conversation.jsonl"))
                if len(conversation) < 2:
                    continue
                success = _is_success(conversation)
                score = _run_score(conversation, success)
                mistakes, length, words = _summarize_run(conversation)

                successes += success
                partial_sum += score
                highest_partial = max(highest_partial, score)
                mistakes_per_run.append(mistakes)
                lengths_per_run.append(length)
                words_per_run.append(words)

            puzzles[puzzle] = {
                "Average Success Rate": successes / total_runs,
                "Average Word Count": _mean(words_per_run),
                "Average Mistakes": _mean(mistakes_per_run),
                "Average Conversation Length": _mean(lengths_per_run),
                "Highest Partial Score": highest_partial,
                "Average Partial Score": partial_sum / total_runs,
            }

            completed += successes
            partial_total += partial_sum
            mistakes_total += sum(mistakes_per_run)
            length_total += sum(lengths_per_run)
            words_total += sum(words_per_run)
            evaluated_runs += len(mistakes_per_run)

        # NOTE: unlike the per-puzzle rates, overall averages divide only by the
        # runs that were actually evaluated (conversation length >= 2).
        denominator = max(1, evaluated_runs)
        per_puzzle[(solver, expert)] = puzzles
        overall[(solver, expert)] = {
            "Overall Average Success Rate": completed / denominator,
            "Overall Average Mistakes": mistakes_total / denominator,
            "Overall Average Partial Score": partial_total / denominator,
            "Overall Average Conversation Length": length_total / denominator,
            "Overall Average Word Count": words_total / denominator,
        }
    return per_puzzle, overall


def _format_stat(value):
    """Format a metric value with two decimals."""
    return f"{round(value, 2):.2f}"


def _print_metric_tables(per_puzzle, overall):
    """Print one table per metric: a row per pair, sorted ascending by its overall value."""
    columns = ["Agents", *_TABLE_PUZZLES, "overall"]
    for metric in _TABLE_METRICS:
        rows = []
        for (solver, expert), puzzles in per_puzzle.items():
            # [:-5] drops the last five characters of each model name (kept from
            # the original; presumably a fixed suffix — TODO confirm).
            row = {"Agents": f"{solver[:-5]} & {expert[:-5]}"}
            for puzzle in _TABLE_PUZZLES:
                stats = puzzles.get(f"{puzzle}Puzzle")
                row[puzzle] = "?" if stats is None else _format_stat(stats["Average " + metric])
            row["overall"] = _format_stat(overall[(solver, expert)]["Overall Average " + metric])
            rows.append(row)

        table = pd.DataFrame(rows, columns=columns)
        # Cells hold formatted strings; sort on their numeric value so that
        # "12.00" ranks after "9.00".
        table = table.sort_values(by="overall", key=lambda col: pd.to_numeric(col, errors="coerce"))
        print(f"====={metric} Results=====")
        print(table)


def main(args):
    """Summarize evaluation runs under ``args.result_folder`` and print metric tables.

    Conversations are read from
    ``<result_folder>/<pair_dir>/<puzzle>/<run>/conversation.jsonl``. The first
    "_"-separated field of ``<pair_dir>`` is used as the solver name. Each
    conversation is a list of message dicts with "from", "value" and an optional
    "score" on the last message.

    Args:
        args: Parsed CLI arguments; only ``args.result_folder`` is read.
    """
    per_puzzle, overall = _collect_results(args.result_folder)
    _print_metric_tables(per_puzzle, overall)


if __name__ == "__main__":
    # Command-line entry point: read where the results live, then build the report.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--result_folder",
        type=str,
        default='./outputs',
        help="Path to folder which contains results",
    )
    main(cli.parse_args())