import os
import numpy as np
# import matplotlib as mpl
# mpl.rcParams['text.usetex'] = True
# mpl.rcParams['text.latex.preamble'] = [r'\usepackage{amsmath}']
import matplotlib.pyplot as plt
from matplotlib import colors
from PIL import Image
from utils.util import (
    get_num_samples_per_test, get_cot_steps, load_config,
    get_multiple_choice_question, construct_save_dir, get_num_cot_steps_matrix
)
from utils.faithfulness import get_faith_x, get_answers_probs, get_faithfulness_matrix
from utils.openaiapi import construct_token_probabilities
from utils.data import titles as dataset_titles
from utils.results import get_ft_results_dict, get_results_dict, get_icl_results_dict

# Display names for the LLMs, keyed by the model identifiers used in the
# results/config code. Figure titles index this dict directly
# (e.g. ``llm_titles[model_name]``), so every model that appears in the
# annotation-offset tables of this module needs an entry here.
llm_titles = {
    'gpt-3.5-turbo-0125': 'GPT-3.5-Turbo-0125',
    'gpt-4-0613': 'GPT-4-0613',
    'llama-3-8b-instruct': 'Llama-3-8B-Instruct',
}

# Human-readable axis label for each metric that can be put on a scatter axis.
axes_names = dict(
    accuracy='Accuracy',
    faithfulness='Faithfulness',
    num_cot_steps='Number of CoT Steps (Mean)',
)

# Multiplicative factor applied to each metric before it is plotted; every
# metric is currently drawn on its raw scale.
axes_scales = dict.fromkeys(('accuracy', 'faithfulness', 'num_cot_steps'), 1)

# Short label drawn next to each run in the plots. A trailing '_c' on a run
# name is stripped before lookup (the '_c' variant shares its base label).
run_name_titles = dict(
    baseline_0='ZS',
    baseline_1='ZS-CoT',
    baseline_2='GTA',
    baseline_3='DU',
    approach_1='DF',
    approach_2='SU',
    approach_3='SF',
)

# Named (dx, dy) annotation offsets, in points, relative to the annotated
# marker: t/b/l/r are top/bottom/left/right; tr/tl/br/bl are the diagonals.
pos = dict(
    tr=(8, 8),
    tl=(-8, 8),
    br=(8, -23),
    bl=(-8, -23),
    t=(0, 10),
    b=(0, -25),
    r=(25, -5),
    l=(-25, -5),
)

# Hand-tuned annotation offsets for scatter plots drawn with method='icl',
# nested as model -> dataset -> run_name -> (dx, dy) in points (the plot uses
# textcoords="offset points"). scatter_plot() looks entries up with chained
# .get() calls and falls back to (0, 10) -- the same as pos['t'] -- for any
# missing model, dataset or run. 'baseline_0' is never listed because its
# "ZS" label always uses a fixed (20, 0) offset. Keys ending in '_c' are the
# runs whose labels get a superscript "c".
icl_xytexts = {
    'gpt-3.5-turbo-0125': {
        'aqua': { # put SU^c at (-20, 5)
            'baseline_1': (-45, 0),  # ZS-CoT
            'baseline_2': (-30, 5),  # GTA
            'baseline_3': (-15, 5),  # DU
            'baseline_3_c': (25, 5), # DU^c
            'approach_1': (-10, 10),  # DF
            'approach_1_c': (25, 5), # DF^c
            'approach_2': (-8, 8),  # SU
            'approach_2_c': (-20, 5),  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['t'],  # SF^c
        },
        'logiqa': {
            'baseline_1': (-20, -25),  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': (5, 10),  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': (-10, -25),  # DF^c
            'approach_2': pos['t'],  # SU
            'approach_2_c': (15, 10),  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': (0, -25),  # SF^c
        },
        'truthfulqa': {
            'baseline_1': pos['t'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': pos['tr'],  # DU^c
            'approach_1': pos['bl'],  # DF
            'approach_1_c': pos['t'],  # DF^c
            'approach_2': pos['r'],  # SU
            'approach_2_c': pos['r'],  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['l'],  # SF^c
        }
    },
    'gpt-4-0613': {
        'aqua': {
            'baseline_1': pos['t'],  # ZS-CoT
            'baseline_2': pos['r'],  # GTA
            'baseline_3': pos['l'],  # DU
            'baseline_3_c': pos['t'],  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': pos['tl'],  # DF^c
            'approach_2': pos['t'],  # SU
            'approach_2_c': pos['r'],  # SU^c
            'approach_3': pos['l'],  # SF
            'approach_3_c': pos['tl'],  # SF^c
        },
        'logiqa': {
            'baseline_1': pos['t'],  # ZS-CoT
            'baseline_2': pos['r'],  # GTA
            'baseline_3': pos['l'],  # DU
            'baseline_3_c': pos['l'],  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': pos['l'],  # DF^c
            'approach_2': pos['t'],  # SU
            'approach_2_c': pos['r'],  # SU^c
            'approach_3': pos['l'],  # SF
            'approach_3_c': pos['tl'],  # SF^c
        },
        'truthfulqa': {
            'baseline_1': (-10, 10),  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': pos['t'],  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': pos['tl'],  # DF^c
            'approach_2': pos['t'],  # SU
            'approach_2_c': pos['r'],  # SU^c
            'approach_3': pos['l'],  # SF
            'approach_3_c': pos['tl'],  # SF^c
        }
    },
    'llama-3-8b-instruct': {
        'aqua': {
            'baseline_1': pos['t'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': pos['t'],  # DU^c
            'approach_1': pos['l'],  # DF
            'approach_1_c': pos['t'],  # DF^c
            'approach_2': pos['t'],  # SU
            'approach_2_c': pos['tr'],  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['tl'],  # SF^c
        },
        'logiqa': {
            'baseline_1': (45, -5),  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': pos['t'],  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': pos['t'],  # DF^c
            'approach_2': pos['t'],  # SU
            'approach_2_c': pos['t'],  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['t'],  # SF^c
        },
        'truthfulqa': {
            'baseline_1': pos['b'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['r'],  # DU
            'baseline_3_c': pos['t'],  # DU^c
            'approach_1': pos['tl'],  # DF
            'approach_1_c': pos['t'],  # DF^c
            'approach_2': pos['tr'],  # SU
            'approach_2_c': pos['t'],  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['b'],  # SF^c
        }
    }
}

# Hand-tuned annotation offsets for scatter plots drawn with method='ft',
# nested as model -> dataset -> run_name -> (dx, dy) in points (the plot uses
# textcoords="offset points"). scatter_plot() looks entries up with chained
# .get() calls and falls back to (0, 10) -- the same as pos['t'] -- for any
# missing model, dataset or run; e.g. 'gpt-4-0613' has no entries here, so
# all of its labels use that default. 'baseline_0' is never listed because
# its "ZS" label always uses a fixed (20, 0) offset.
ft_xytexts = {
    'gpt-3.5-turbo-0125': {
        'aqua': {
            'baseline_1': (40, 5),  # ZS-CoT
            'baseline_2': (-20, 10),  # GTA
            'baseline_3': pos['l'],  # DU
            'baseline_3_c': pos['l'],  # DU^c
            'approach_1': pos['l'],  # DF
            'approach_1_c': (25, -10),  # DF^c
            'approach_2': pos['tr'],  # SU
            'approach_2_c': pos['l'],  # SU^c
            'approach_3': pos['l'],  # SF
            'approach_3_c': (20, -20),  # SF^c
        },
        'logiqa': {
            'baseline_1': pos['t'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['tr'],  # DU
            'baseline_3_c': pos['t'],  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': pos['t'],  # DF^c
            'approach_2': pos['tl'],  # SU
            'approach_2_c': pos['l'],  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['t'],  # SF^c
        },
        'truthfulqa': {
            'baseline_1': pos['t'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': (25, -15),  # DU^c
            'approach_1': pos['l'],  # DF
            'approach_1_c': pos['l'],  # DF^c
            'approach_2': pos['l'],  # SU
            'approach_2_c': pos['r'],  # SU^c
            'approach_3': pos['r'],  # SF
            'approach_3_c': pos['l'],  # SF^c
        }
    },
    'llama-3-8b-instruct': {
        'aqua': {
            'baseline_1': pos['b'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': pos['t'],  # DU^c
            'approach_1': pos['tl'],  # DF
            'approach_1_c': pos['t'],  # DF^c
            'approach_2': pos['t'],  # SU
            'approach_2_c': (25, -10),  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['t'],  # SF^c
        },
        'logiqa': {
            'baseline_1': pos['b'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['t'],  # DU
            'baseline_3_c': pos['br'],  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': pos['b'],  # DF^c
            'approach_2': pos['tr'],  # SU
            'approach_2_c': pos['b'],  # SU^c
            'approach_3': pos['t'],  # SF
            'approach_3_c': pos['t'],  # SF^c
        },
        'truthfulqa': {
            'baseline_1': pos['b'],  # ZS-CoT
            'baseline_2': pos['t'],  # GTA
            'baseline_3': pos['tl'],  # DU
            'baseline_3_c': pos['tl'],  # DU^c
            'approach_1': pos['t'],  # DF
            'approach_1_c': pos['tr'],  # DF^c
            'approach_2': pos['r'],  # SU
            'approach_2_c': pos['r'],  # SU^c
            'approach_3': pos['tl'],  # SF
            'approach_3_c': pos['t'],  # SF^c
        }
    }
}

# Where scatter_plot() places the "ZS" label beside its vertical line, as a
# fraction of the way from the lowest to the highest plotted y-value, keyed
# model -> dataset. Any combination not listed uses 0.5 (halfway).
# Used when method='icl':
icl_zs_height_scale = {'llama-3-8b-instruct': {'logiqa': 0.25}}

# Used when method='ft':
ft_zs_height_scale = {'llama-3-8b-instruct': {'truthfulqa': 0.4}}

def format_ax(ax, axes, pad_factor=0.08):
    """Placeholder axis formatter; currently a no-op.

    ax: matplotlib axis, returned unchanged
    axes: pair of metric names (unused)
    pad_factor: float (unused)
    returns: the same ``ax`` object
    """
    # The extra arguments are accepted for interface compatibility only.
    del axes, pad_factor
    return ax

def get_color_marker_annotation(run_name, ft=True):
    """Return the (color, marker, annotation) used to draw one run in a scatter plot.

    run_name: string, e.g. 'baseline_1' or 'approach_2_c'
    ft: bool, accepted for compatibility but currently unused
    returns: tuple (color, marker, annotation). 'baseline_0'..'baseline_2' are
        red circles and every other run is a blue triangle. The annotation is
        the short run title, with a superscript "c" appended when run_name
        ends in '_c'.
    """
    is_early_baseline = run_name in ('baseline_0', 'baseline_1', 'baseline_2')
    color, marker = ('red', 'o') if is_early_baseline else ('blue', '^')

    is_c_variant = run_name.endswith('_c')
    base_name = run_name[:-2] if is_c_variant else run_name
    label = run_name_titles[base_name]
    if is_c_variant:
        label += r'$^\mathrm{c}$'
    return color, marker, label

def scatter_plot(ax, dataset_name, model_name, run_names, use_parsed, hard_faith, best_checkpoints=None,
                 axes=None, method='ft', verbose=True):
    """Scatter-plot each run's (axes[0], axes[1]) result for one dataset and model.

    The zero-shot run ('baseline_0') is drawn as a dashed red vertical line at
    its axes[0] value and labelled "ZS". 'baseline_1' and every run in
    ``run_names`` are drawn as markers with a short text annotation.

    ax: matplotlib axis to draw on (also returned)
    dataset_name: string, dataset key
    model_name: string, model key
    run_names: list of run names plotted in addition to 'baseline_0'/'baseline_1'
    use_parsed: passed to the results loaders (forced to False for 'baseline_0')
    hard_faith: passed to the results loaders
    best_checkpoints: passed to get_ft_results_dict when method == 'ft'
    axes: pair of metric names (x, y); defaults to ['accuracy', 'faithfulness']
    method: 'ft' or 'icl'; selects the results loader and annotation tables
    verbose: if True (default), print the intermediate values for debugging
    returns: ax
    raises: ValueError if method is not 'ft' or 'icl'
    """
    # Fail fast, before any results are loaded.
    if method not in ('ft', 'icl'):
        raise ValueError(f"Invalid method: {method}. Please use 'ft' or 'icl'.")
    if axes is None:
        axes = ['accuracy', 'faithfulness']

    def log(*args):
        if verbose:
            print(*args)

    all_run_names = ['baseline_0', 'baseline_1'] + list(run_names)
    zs_results = get_results_dict([dataset_name], model_name, hard_faith=hard_faith,
                                  use_parsed=False, cot=False, axes=axes)
    zs_cot_results = get_results_dict([dataset_name], model_name, hard_faith=hard_faith,
                                      use_parsed=use_parsed, cot=True, axes=axes)
    if method == 'ft':
        results_dict = get_ft_results_dict(dataset_name, model_name, run_names,
                                           use_parsed=use_parsed, hard_faith=hard_faith,
                                           best_checks=best_checkpoints, best_checks_dict=True,
                                           axes=axes)
        height_scales, method_xytexts = ft_zs_height_scale, ft_xytexts
    else:
        results_dict = get_icl_results_dict(dataset_name, model_name, run_names,
                                            use_parsed=use_parsed, hard_faith=hard_faith,
                                            axes=axes)
        height_scales, method_xytexts = icl_zs_height_scale, icl_xytexts
    zs_val1, zs_val2 = zs_results[dataset_name]
    zs_cot_val1, zs_cot_val2 = zs_cot_results[dataset_name]
    results_dict = {'baseline_0': (zs_val1, zs_val2), 'baseline_1': (zs_cot_val1, zs_cot_val2)} | results_dict
    log(dataset_name, model_name)
    log(results_dict)

    x_scale, y_scale = axes_scales[axes[0]], axes_scales[axes[1]]
    # y-range of all runs in the same (scaled) units as the plotted markers, so
    # the "ZS" label placement stays consistent if a y-scale other than 1 is used.
    y_values = [results_dict[run_name][1] * y_scale for run_name in all_run_names]
    max_val2, min_val2 = max(y_values), min(y_values)
    zs_height = height_scales.get(model_name, {}).get(dataset_name, 0.5)
    run_offsets = method_xytexts.get(model_name, {}).get(dataset_name, {})
    ft = method == 'ft'

    for run_name in all_run_names:
        val1 = results_dict[run_name][0] * x_scale
        val2 = results_dict[run_name][1] * y_scale
        if run_name == 'baseline_0':
            ax.axvline(val1, color='red', linestyle='--', label='ZS')
            annotation = 'ZS'
            xytext = (20, 0)
            # The ZS line spans the whole y-axis, so put its label a fixed
            # fraction of the way up the plotted y-range.
            val2 = (max_val2 - min_val2) * zs_height + min_val2
        else:
            color, marker, annotation = get_color_marker_annotation(run_name, ft=ft)
            ax.scatter(val1, val2, color=color, marker=marker, s=100)
            log(color, marker, annotation)
            xytext = run_offsets.get(run_name, (0, 10))
            if xytext != (0, 10):
                log(xytext)
        ax.annotate(annotation, (val1, val2), textcoords="offset points", ha='center', fontsize=19,
                    xytext=xytext, fontname='sans-serif')

    # Set axes labels
    ax.set_xlabel(axes_names[axes[0]])
    ax.set_ylabel(axes_names[axes[1]])

    # Widen both axis limits by a fraction of their current range (presumably
    # to leave room for annotations near the edges).
    pad_factor = 0.08
    xlims, ylims = ax.get_xlim(), ax.get_ylim()
    x_range, y_range = xlims[1] - xlims[0], ylims[1] - ylims[0]
    log(xlims, ylims, x_range, y_range)
    ax.set_xlim(xlims[0] - pad_factor*x_range, xlims[1] + pad_factor*x_range)
    ax.set_ylim(ylims[0] - pad_factor*y_range, ylims[1] + pad_factor*y_range)
    log()
    return ax

def plot_num_cot_steps_boxplot(dataset_names, model_name, config, soft=True):
    """Box-plot faithfulness grouped by number of CoT steps, one panel per dataset.

    Each panel title shows the dataset title and the Pearson correlation
    (np.corrcoef) between per-sample faithfulness and number of CoT steps.

    dataset_names: list of dataset keys; one subplot is drawn per dataset
    model_name: string, model key
    config: config mapping consumed by construct_save_dir. A shallow copy is
        updated per dataset, so the caller's config is left unchanged.
    soft: selects index 0 (True) or 1 (False) of get_faithfulness_matrix's result
    returns: None (shows the figure)
    """
    import copy

    n_datasets = len(dataset_names)
    # squeeze=False keeps an indexable row of axes for any number of datasets.
    _, axs = plt.subplots(1, n_datasets, figsize=(5 * n_datasets, 5), squeeze=False)
    faith_idx = 0 if soft else 1
    for ax, dataset_name in zip(axs[0], dataset_names):
        run_config = copy.copy(config)
        run_config['dataset'], run_config['llm'] = dataset_name, model_name
        results_dir = construct_save_dir(run_config, save_config=False) + 'responses/'
        n_cot_matrix = get_num_cot_steps_matrix(results_dir).flatten()
        faith = get_faithfulness_matrix(results_dir)[faith_idx].flatten()
        step_counts = np.unique(n_cot_matrix)
        ax.boxplot([faith[n_cot_matrix == n] for n in step_counts])
        # boxplot puts the boxes at 1..k; label them with the real step counts.
        ax.set_xticklabels([str(n) for n in step_counts])
        ax.set_xlabel('Number of CoT Steps')

        # Get correlation between faithfulness and number of CoT steps
        corr = np.corrcoef(faith, n_cot_matrix)[0, 1]
        ax.set_title(f'{dataset_titles[dataset_name]} ({corr:.2f})')
    plt.tight_layout()
    plt.suptitle(model_name.title(), fontsize=18, y=1.03)
    plt.show()

def _draw_text_block(fig, x, y, text, facecolor, fontsize=14, ha='center', gap=0.0):
    """Draw ``text`` in a rounded box whose top is at (x, y); return the y below it.

    The text goes on pyplot's current axes. The returned value is ``y`` minus
    the box's rendered height (as a fraction of the figure height) minus ``gap``.
    """
    t = plt.text(x, y, text, ha=ha, va='top', fontsize=fontsize, wrap=True,
                 bbox=dict(boxstyle="round,pad=0.2", edgecolor='black', facecolor=facecolor))
    t_height = t.get_window_extent(renderer=fig.canvas.get_renderer()).height
    return y - ((t_height / (fig.get_figheight() * fig.dpi)) + gap)


def create_sample_visualization(save_dir, test_idx, sample_idx, block_offset=0.02):
    """Render one sampled response as a vertical stack of text boxes and save it as a PNG.

    The stack shows:
    - a header (LLM, dataset, split, test and sample index);
    - the multiple-choice question;
    - either the error message plus the raw response (when the sample has an
      'error' key), or the first parsed intermediate answer, then each CoT
      step followed by the next parsed answer, then the hard faithfulness.

    The image is written to
    ``<save_dir>/response_images/response_<test_idx>_<sample_idx>.png`` and
    trimmed with trim_whitespace().

    save_dir: string, directory holding ``response_<test_idx>.json``; the header
        expects it to look like 'results/<dataset>/<llm>/<split>_...'
    test_idx: integer, index of the test instance
    sample_idx: integer, index of the sample within that test instance
    block_offset: float, vertical gap between boxes as a fraction of figure height
    returns: None. Failures while drawing or saving are printed, not raised;
        loading the response JSON happens first and can still raise.
    """
    response = load_config(os.path.join(save_dir, f'response_{test_idx}.json'))
    question = response['question']
    options = response['options']
    sample = response[f'sample_{sample_idx}']
    prompt_text = get_multiple_choice_question(question, options)
    # Bold the section headings via mathtext.
    prompt_text = (prompt_text.replace('Question:', r'$\bf{Question:}$')
                   .replace('Choices:', r'$\bf{Choices:}$').rstrip())

    fig = plt.figure(figsize=(10, 25), dpi=100)
    plt.gca().axis('off')
    try:
        # HEADER: LLM | DATASET (split, test index, sample index).
        # removeprefix, not lstrip: lstrip('results/') strips a *character set*
        # and would eat the start of dataset names such as 'truthfulqa'/'logiqa'.
        dataset, llm, split_str = save_dir.removeprefix('results/').rstrip('/').split('/')
        split_str = split_str.split('_')[0]
        bold_llm = r"$\bf{{{}}}$ | ".format(llm)
        bold_dataset = r"$\bf{{{}}}$ ".format(dataset)
        index = f'({split_str.title()} Index {test_idx}, Sample Idx {sample_idx})'
        height = 1
        height = _draw_text_block(fig, 0.5, height, bold_llm + bold_dataset + index,
                                  make_alpha_color('lightcoral', 0.1), gap=block_offset)

        # QUESTION
        height = _draw_text_block(fig, 0.5, height, prompt_text,
                                  make_alpha_color('beige', 0.1), gap=block_offset)

        if 'error' in sample:
            # Only the error message and the raw response are shown.
            error_msg, raw_response = sample['error'], 'Step 1: ' + sample['response']
            height = _draw_text_block(fig, 0.5, height, error_msg,
                                      make_alpha_color('lightcoral', 0.6), gap=block_offset)
            _draw_text_block(fig, 0.5, height, raw_response, make_alpha_color('lightgreen', 0.2))
        else:
            final_answers = sample['parsed_intermediate_answers']
            answer_color = make_alpha_color('yellowgreen', 0.9)
            step_color = make_alpha_color('cyan', 0.1)
            height = _draw_text_block(fig, 1, height, f'Final Answer: {final_answers[0]}',
                                      answer_color, fontsize=16, ha='right', gap=block_offset)
            for i, cot_step in enumerate(get_cot_steps(sample, None)):
                height = _draw_text_block(fig, 0.5, height, cot_step, step_color, gap=block_offset)
                height = _draw_text_block(fig, 1, height, f'Final Answer: {final_answers[i+1]}',
                                          answer_color, fontsize=16, ha='right', gap=block_offset)

            # FAITHFULNESS, colored from red (0) to green (1)
            faithfulness = sample['hard_faithfulness']
            faith_color = interpolate_colors('lightcoral', 'lightgreen', faithfulness)
            _draw_text_block(fig, 0.5, height, f'Faithfulness: {faithfulness:.2f}',
                             make_alpha_color(faith_color, 0.9), fontsize=16)

        image_dir = os.path.join(save_dir, 'response_images')
        os.makedirs(image_dir, exist_ok=True)
        file_path = os.path.join(image_dir, f'response_{test_idx}_{sample_idx}.png')
        plt.savefig(file_path)
        trim_whitespace(file_path)
    except Exception as e:
        # Deliberately best-effort: report the failure instead of raising.
        print(e)
        print(f'Error in creating the visualization at test index {test_idx} and sample index {sample_idx}')
    finally:
        plt.close(fig)

def interpolate_colors(color1, color2, value):
    """Return the RGBA color a fraction ``value`` of the way from color1 to color2.

    The blend comes from a 100-entry linear colormap between the two colors;
    value is presumably in [0, 1].
    """
    gradient = colors.LinearSegmentedColormap.from_list('gradient', [color1, color2], N=100)
    return gradient(value)

def trim_whitespace(image_path):
    """Crop the white border from an image file, overwriting it in place.

    The crop box is the bounding box of all rows/columns that contain a pixel
    darker than 255 in some channel (a white background is assumed), widened
    by one pixel at the top and left and clamped at the image edge. JPEG files
    are re-saved with quality=100. An image that is entirely white is left
    untouched.

    image_path: string, path of the image to crop and overwrite
    returns: None
    """
    with Image.open(image_path) as img:
        img_array = np.array(img)

        # A row/column is "non-empty" if any of its values is below 255.
        non_empty_columns = np.where(img_array.min(axis=0) < 255)[0]
        non_empty_rows = np.where(img_array.min(axis=1) < 255)[0]
        if non_empty_columns.size == 0 or non_empty_rows.size == 0:
            return  # only background: nothing to crop

        # PIL boxes are (left, upper, right, lower) with right/lower exclusive;
        # clamp the -1 margin so content touching the edge does not yield a
        # negative (padded) box.
        crop_box = (max(int(non_empty_columns.min()) - 1, 0),
                    max(int(non_empty_rows.min()) - 1, 0),
                    int(non_empty_columns.max()) + 1,
                    int(non_empty_rows.max()) + 1)
        img_cropped = img.crop(crop_box)

    # Save once, after the source file is closed, so the JPEG quality setting
    # is not overwritten by a second default-quality save.
    if os.path.splitext(image_path)[1].lower() in ('.jpg', '.jpeg'):
        img_cropped.save(image_path, 'JPEG', quality=100)
    else:
        img_cropped.save(image_path)

def make_alpha_color(color, alpha):
    """Return ``color`` as an (r, g, b, a) tuple with its alpha replaced by ``alpha``."""
    red, green, blue, _ = colors.to_rgba(color)
    return (red, green, blue, alpha)

def plot_logprobs(logprobs_content, ax=None):
    """Bar-plot the probability of each token in an answer.

    logprobs_content: converted by construct_token_probabilities into (tokens, probs)
    ax: matplotlib axis to draw on and return; if None, a new 10x5 figure is
        created and shown
    returns: ax when one was given, otherwise None
    """
    tokens, probs = construct_token_probabilities(logprobs_content)
    standalone = ax is None
    if standalone:
        plt.figure(figsize=(10, 5))
        target = plt.gca()
    else:
        target = ax

    target.bar(tokens, probs)
    target.set_ylabel('Probability')
    target.set_xlabel('Token')
    target.set_title('Probability of each token in the answer')

    if standalone:
        plt.show()
        return None
    return ax
    
def plot_faithfulness_hist(faithfulness, ax=None, bins=50, title='Histogram of Faithfulness'):
    """
    Plot histogram of faithfulness
    faithfulness: np.ndarray, array of faithfulness values (will flatten if not 1D)
    ax: matplotlib axis, axis to plot on and return, otherwise will create a new figure
    returns: matplotlib axis if ax is not None, otherwise None
    """
    # Flatten if not 1D
    faithfulness = faithfulness.flatten() if faithfulness.ndim > 1 else faithfulness
    n_points = len(faithfulness)

    # Plot histogram
    if ax is None:
        plt.figure(figsize=(10, 5))
        plt.hist(faithfulness, bins=bins if n_points > 100 else 20)
        plt.ylabel('Frequency')
        plt.xlabel('Faithfulness')
        plt.title(title)
        plt.show()
    else:
        ax.hist(faithfulness, bins=bins if n_points > 100 else 20)
        ax.set_ylabel('Frequency')
        ax.set_xlabel('Faithfulness')
        ax.set_title(title)
        return ax
    
def plot_faithfulness_x(faithfulness_x, ax=None, ylabel='Faithfulness', fill_aoc_color_alpha=False):
    """
    Plot faithfulness values at each step
    faithfulness_x: 1D np.ndarray, array of faithfulness values at each step
    ax: matplotlib axis, axis to plot on and return, otherwise will create a new figure
    ylabel: string, label for y-axis
    fill_aoc_color_alpha: False, or a (color, alpha) pair; with a truthy color the
        area between the curve and y=1 is filled
    returns: matplotlib axis if ax is not None, otherwise None
    """
    steps = np.arange(len(faithfulness_x))
    if fill_aoc_color_alpha:
        fill_color, fill_alpha = fill_aoc_color_alpha
    else:
        fill_color, fill_alpha = None, None

    target = ax
    if target is None:
        plt.figure(figsize=(10, 5))
        target = plt.gca()

    target.plot(steps, faithfulness_x)
    if fill_color:
        # Shade the area over the curve, up to y = 1.
        target.fill_between(steps, faithfulness_x, np.ones(len(steps)), color=fill_color, alpha=fill_alpha)
    target.set_ylim(-0.05, 1.05)
    target.set_ylabel(ylabel)
    target.set_xlabel('Step')
    target.set_title('Faithfulness at each step')

    if ax is None:
        plt.show()
        return None
    return ax
    
def plot_faithfulness_x_all(results_dir, test_idx):
    """
    Plot faithfulness values at each step for all samples at a given test instance
    Each subplot overlays the hard (blue fill) and soft (orange fill) curves that
    get_faith_x returns for one sample.
    results_dir: string, path to the results directory
    test_idx: integer, index of the test instance
    returns: None
    """
    n_samples = get_num_samples_per_test(results_dir)
    # squeeze=False keeps an indexable array of axes even for a single sample.
    _, axs = plt.subplots(1, n_samples, figsize=(50, 4), dpi=200, squeeze=False)
    axs = axs[0]
    for sample_idx, sample_ax in enumerate(axs):
        answers_probs = get_answers_probs(results_dir, test_idx=test_idx, sample_idx=sample_idx)
        soft_faith_x, hard_faith_x = get_faith_x(answers_probs)
        plot_faithfulness_x(hard_faith_x, ax=sample_ax, ylabel='', fill_aoc_color_alpha=('blue', 0.3))
        plot_faithfulness_x(soft_faith_x, ax=sample_ax, ylabel='', fill_aoc_color_alpha=('orange', 0.2))
        # Set after plotting: plot_faithfulness_x overwrites the title.
        sample_ax.set_title(f"Sample {sample_idx}", fontsize=17)
    axs[0].set_ylabel('Probability of Final Answer', fontsize=15)
    plt.suptitle(f'Test instance {test_idx}', fontsize=24, y=1.05)
    plt.show()

def plot_finetuning_examples(dataset_name, model_name, run_tables, top_ps=np.arange(10, 110, 10), points=None, results_dict=None):
    """Plot trainset accuracy/faithfulness vs. trainset size for two runs side by side.

    dataset_name, model_name: keys into dataset_titles / llm_titles for the title
    run_tables: pair of tables, one per subplot (see plot_finetuning_examples_single)
    top_ps: trainset sizes (%), one per table row
    points: optional pair of per-subplot lists of (top_p, run_name) points to
        highlight; None highlights nothing
    results_dict: optional fine-tuning results passed to each subplot
    returns: None (shows the figure)
    """
    if points is None:
        points = (None, None)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    ax1, ax1_2, min_0, max_0 = plot_finetuning_examples_single(run_tables[0], top_ps, points[0], ax1, results_dict)
    ax2, ax2_2, min_1, max_1 = plot_finetuning_examples_single(run_tables[1], top_ps, points[1], ax2, results_dict)

    # Apply one shared y-range to all four y-axes.
    min_all, max_all = min(min_0, min_1), max(max_0, max_1)
    # NOTE(review): the lower limit is min(min_all - 3, -3), i.e. never above -3,
    # whereas the upper limit is capped at 103. If a symmetric clamp to
    # [-3, 103] was intended this should be max(min_all - 3, -3) -- confirm.
    y_limits = (min(min_all - 3, -3), min(max_all + 3, 103))
    for ax in (ax1, ax2, ax1_2, ax2_2):
        ax.set_ylim(*y_limits)
    fig.suptitle(f'Finetuning Trainset Properties ({dataset_titles[dataset_name]}, {llm_titles[model_name]})', fontsize=20)
    fig.tight_layout()
    plt.show()

def plot_finetuning_examples_single(run_table, top_ps, points, ax1, results_dict=None):
    """Plot trainset accuracy (left axis) and faithfulness (twin right axis) vs. trainset size.

    run_table: table with 'Temperature', 'Examples Acc.' and 'Examples Faith'
        columns; the two 'Examples' columns hold strings ending in '%'
    top_ps: trainset sizes (%), one per row of run_table
    points: optional list of (top_p, run_name) pairs to highlight and label
    ax1: matplotlib axis for accuracy; a twin axis is created for faithfulness
    results_dict: optional mapping from run key (run_name lower-cased, spaces
        replaced by '_') to (accs, faiths); the first entry of each, times 100,
        is drawn as a square marker at the matching top_p
    returns: (ax1, ax1_2, min_both, max_both), where min_both/max_both span the
        y-limits of both axes after everything has been drawn
    """
    temperature = run_table['Temperature'][0]
    suffix = ' (Most Faithful Sample of 10)' if temperature != 0.0 else ' (No Sampling)'

    # Plot Accuracy
    color = 'tab:blue'
    ax1.set_xlabel('Size of Trainset (%)')
    ax1.set_ylabel('Trainset Accuracy (%)', color=color)
    accs = [float(acc.strip('%')) for acc in run_table['Examples Acc.'].tolist()]
    ax1.plot(top_ps, accs, color=color, label='Accuracy (%)')
    ax1.tick_params(axis='y', labelcolor=color)

    # Plot Faithfulness on a twin y-axis
    ax1_2 = ax1.twinx()
    color = 'tab:red'
    ax1_2.set_ylabel('Trainset Faithfulness (%)', color=color)
    faiths = [float(value.strip('%')) for value in run_table['Examples Faith'].tolist()]
    ax1_2.plot(top_ps, faiths, color=color)
    ax1_2.tick_params(axis='y', labelcolor=color)

    if points is not None:
        top_ps_list = np.asarray(top_ps).tolist()
        for top_p, run_name in points:
            top_p_idx = top_ps_list.index(top_p)
            point_faith = faiths[top_p_idx]
            point_acc = accs[top_p_idx]
            ha = 'left' if top_p < 50 else 'right'
            ax1.plot(top_p, point_acc, 'bo', alpha=0.4)
            ax1_2.plot(top_p, point_faith, 'ro', alpha=0.4)
            ax1_2.text(top_p, (point_faith + point_acc) / 2, run_name,
                       alpha=0.8, fontsize=10, ha=ha, va='center')

            # Get results from results_dict for run_name
            if results_dict is not None:
                ft_accs, ft_faiths = results_dict[run_name.replace(' ', '_').lower()]
                ax1.plot(top_p, ft_accs[0]*100, 'bs', alpha=0.4)
                ax1_2.plot(top_p, ft_faiths[0]*100, 'rs', alpha=0.4)

    # Read the limits only after every marker is drawn, so the fine-tuned
    # result squares are inside the range returned to the caller.
    min_faith, max_faith = ax1_2.get_ylim()
    min_acc, max_acc = ax1.get_ylim()
    min_both, max_both = min(min_faith, min_acc), max(max_faith, max_acc)

    # Legend entry for the highlighted points (empty line, legend only).
    ax1_2.plot([], [], 'o', color='black', alpha=0.4, label='Examples Selected (Before Fine-Tuning)')
    ax1_2.legend(loc='lower left')

    ax1.set_title(f'Temperature {temperature}' + suffix)
    return ax1, ax1_2, min_both, max_both

def plot_results_dict(ft_results_dict, title=None, save_dir=None, method_descriptions=None, best_checkpoints=None, hard_faith=False):
    """Plot accuracy and faithfulness per checkpoint, one subplot per run.

    Runs whose name ends in '_c' get no subplot of their own. Instead they are
    drawn as dashed lines on the subplot of the same run without the suffix.
    The input dict is not modified.

    ft_results_dict: dict mapping run name -> (accs, faiths), per-checkpoint
        values that are multiplied by 100 for plotting
    title: optional figure title; also used as the file name '<title>.png'
    save_dir: optional directory to save the figure into (created if needed)
    method_descriptions: optional dict run name -> text appended to the subplot title
    best_checkpoints: optional dict run name -> checkpoint index to mark with squares
    hard_faith: only selects the 'Hard '/'Soft ' prefix of the faithfulness labels
    returns: None (shows the figure)
    """
    # Separate the '_c' runs without mutating the caller's dict.
    ft_results_dict_c = {name: res for name, res in ft_results_dict.items() if name.endswith('_c')}
    main_results = {name: res for name, res in ft_results_dict.items() if not name.endswith('_c')}

    n_results = len(main_results)
    plt.style.use('default')
    # squeeze=False so a single run still gives an indexable row of axes.
    _, axs = plt.subplots(nrows=1, ncols=n_results, figsize=(5*n_results, 5), squeeze=False)
    axs = axs[0]
    min_y, max_y = 100, 0
    prefix = 'Hard ' if hard_faith else 'Soft '

    for ax, (run_name, (accs, faiths)) in zip(axs, main_results.items()):
        checkpoints = np.arange(len(accs))
        accs, faiths = np.array(accs), np.array(faiths)
        ax.plot(checkpoints, accs*100, label='Accuracy', color='b', alpha=0.8)
        ax.plot(checkpoints, faiths*100, label=f'{prefix}Faithfulness', color='r', alpha=0.8)

        run_name_c = run_name + '_c'
        if run_name_c in ft_results_dict_c:
            accs_c, faiths_c = map(np.array, ft_results_dict_c[run_name_c])
            ax.plot(checkpoints, accs_c*100, label='Accuracy (Correct Only)', color='b', linestyle='--', alpha=0.8)
            ax.plot(checkpoints, faiths_c*100, label=f'{prefix}Faithfulness (Correct Only)', color='r', linestyle='--', alpha=0.8)
            if best_checkpoints:
                best_check = best_checkpoints[run_name_c]
                ax.plot(best_check, accs_c[best_check]*100, 'bs', alpha=0.4)
                ax.plot(best_check, faiths_c[best_check]*100, 'rs', alpha=0.4)
            # Include the dashed curves in the shared y-range so they are not clipped.
            min_y = min(min_y, min(accs_c.min(), faiths_c.min())*100)
            max_y = max(max_y, max(accs_c.max(), faiths_c.max())*100)

        if method_descriptions:
            ax.set_title(run_name.replace('_', ' ').title() + f' ({method_descriptions[run_name]})')
        else:
            ax.set_title(run_name)
        ax.set_xlabel('Checkpoint')
        ax.set_xticks(checkpoints)
        ax.set_ylabel('Percentage')
        if best_checkpoints:
            best_check = best_checkpoints[run_name]
            ax.plot(best_check, accs[best_check]*100, 'bs', alpha=0.4)
            ax.plot(best_check, faiths[best_check]*100, 'rs', alpha=0.4)
            ax.plot([], [], 's', color='black', alpha=0.4, label='Selected Checkpoint')  # for legend
        ax.legend()
        min_y = min(min_y, min(accs.min(), faiths.min())*100)
        max_y = max(max_y, max(accs.max(), faiths.max())*100)

    # One shared y-range across all subplots.
    for ax in axs:
        ax.set_ylim(min_y-3, max_y+3)

    if title:
        plt.suptitle(title, fontsize=20)
    # Lay out before saving so the saved file matches the displayed figure.
    plt.tight_layout()
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
        plt.savefig(os.path.join(save_dir, f'{title}.png'))
    plt.show()



