import os, glob, random, json, wandb
import numpy as np
import matplotlib.pyplot as plt
from scipy import interpolate
from sklearn.metrics import auc
from matplotlib.colors import hsv_to_rgb
import seaborn as sns
import pandas as pd
from pathlib import Path
from plotnine import (
    ggplot, aes, geom_line, geom_point, facet_wrap, geom_bar, geom_abline, xlim, scale_fill_manual,
    geom_text, position_dodge, ylim, labs, theme_bw, theme, element_text, scale_color_manual, coord_flip, geom_boxplot, geom_jitter
)


# Method name -> hex colour. Passed as ``values`` to ``scale_color_manual`` by
# the per-dataset plotting code in this file.
COLORS = dict([
    ('PreferenceVector', '#d62728'),                # red
    ('PromptSteering', '#1f77b4'),                  # blue
    ('PromptSteering_prepend', '#2ca02c'),          # green
    ('PromptSteering_append', '#ff7f0e'),           # orange
    ('PromptSteering_prepend_rewrite', '#9467bd'),  # purple
    ('PromptSteering_append_rewrite', '#8c564b'),   # brown
    ('PromptSteering_original', '#e377c2'),         # pink
    ('PromptSteering_rewrite', '#7f7f7f'),          # gray
])

# One single-character marker code per entry.
MARKERS = list('os^Dv<>p*h')


def plot_aggregated_roc(jsonl_data, write_to_path=None, report_to=[], wandb_name=None):
    """Plot one mean ROC curve per model, averaged over all entries.

    Every entry must provide ``entry["results"]["AUCROCEvaluator"]``, a mapping
    from model name to a dict holding ``roc_curve`` (with ``fpr`` and ``tpr``
    sequences) and ``roc_auc``. Each curve is interpolated onto 100 evenly
    spaced FPR values in [0, 1] so that curves from different entries can be
    averaged point-wise. The legend label carries the mean AUC.

    Args:
        jsonl_data: Iterable of result entries as described above.
        write_to_path: Directory (``str`` or ``Path``) in which
            ``aggregated_roc.png`` is saved. When falsy the plot is printed
            instead.
        report_to: Collection of reporter names. When it contains ``"wandb"``
            the mean curves are logged as a line series under
            ``latent/roc_curve``.
        wandb_name: Unused by this function; kept for interface compatibility.

    Returns:
        None. Nothing is plotted or logged when there are no ROC results.

    Raises:
        KeyError: If an entry lacks ``results`` / ``AUCROCEvaluator`` or a
            model result lacks ``roc_curve`` / ``roc_auc``.
    """
    # Common FPR grid used to bring every curve onto the same x values.
    common_fpr = np.linspace(0, 1, 100)

    tprs = {}
    aucs = {}
    for aggregated_result in jsonl_data:
        metrics = aggregated_result["results"]["AUCROCEvaluator"]
        for model_name, value in metrics.items():
            curve = value["roc_curve"]
            interp_tpr = np.interp(common_fpr, curve["fpr"], curve["tpr"])
            interp_tpr[0] = 0.0  # Ensure TPR starts at 0
            tprs.setdefault(model_name, []).append(interp_tpr)
            aucs.setdefault(model_name, []).append(value["roc_auc"])

    if not tprs:
        # An empty frame has no FPR/TPR/Model columns and cannot be plotted.
        print("plot_aggregated_roc: no AUCROCEvaluator results to plot; skipping.")
        return

    # Compute each model's mean curve and legend label once; both the plot and
    # the wandb report use them.
    mean_tprs = {name: np.mean(curves, axis=0) for name, curves in tprs.items()}
    labels = {name: f"{name} (AUC = {np.mean(aucs[name]):.2f})" for name in tprs}

    df = pd.DataFrame([
        {'FPR': fpr, 'TPR': tpr, 'Model': labels[name]}
        for name in tprs
        for fpr, tpr in zip(common_fpr, mean_tprs[name])
    ])

    p = (
        ggplot(df, aes(x='FPR', y='TPR', color='Model')) +
        geom_line(size=1) +
        geom_abline(slope=1, intercept=0, linetype='dashed', color='gray') +
        theme_bw() +
        labs(x='False Positive Rate (FPR)', y='True Positive Rate (TPR)') +
        theme(
            figure_size=(4, 4),
            legend_title=element_text(size=8),
            legend_text=element_text(size=6),
            axis_title=element_text(size=10),
            axis_text=element_text(size=8),
            plot_title=element_text(size=12),
            legend_position='right'
        )
    )

    # Save or show the plot. Path() lets callers pass a plain string directory.
    if write_to_path:
        out_file = Path(write_to_path) / "aggregated_roc.png"
        p.save(filename=str(out_file), dpi=300, bbox_inches='tight')
    else:
        print(p)

    # Report to wandb when "wandb" is among the requested reporters.
    if report_to is not None and "wandb" in report_to:
        wandb.log({"latent/roc_curve" : wandb.plot.line_series(
            xs=common_fpr.tolist(),
            ys=[mean_tprs[name].tolist() for name in tprs],
            keys=[labels[name] for name in tprs],
            title='Aggregated ROC Curve',
            xname='False Positive Rate (FPR)',
        )})


def plot_metrics(jsonl_data, configs, write_to_path=None, report_to=[], wandb_name=None, mode=None):
    """Plot metric-vs-factor curves per method, one facet per configured metric.

    For every config, the values stored under ``config['metric_name']`` of each
    method in ``entry['results'][config['evaluator_name']]`` are paired with
    that method's ``factor`` values. Values are averaged over entries per
    (method, factor, metric) and, when ``config['use_log_scale']`` is truthy,
    replaced by their ``log10``.

    Args:
        jsonl_data: Iterable of result entries (dicts with a ``results`` key).
        configs: Sequence of dicts with the keys ``evaluator_name``,
            ``metric_name``, ``y_label`` and ``use_log_scale``.
        write_to_path: Directory (``str`` or ``Path``) in which
            ``{mode}_plot.png`` is saved. When falsy the plot is printed.
        report_to: Collection of reporter names. When it contains ``"wandb"``
            one line series per metric is logged under ``{mode}/{metric}``.
        wandb_name: Unused by this function; kept for interface compatibility.
        mode: Prefix used in the output file name and the wandb keys.

    Returns:
        None. Nothing is plotted or logged when no data points are found.
    """
    rows = []
    for config in configs:
        evaluator_name = config['evaluator_name']
        metric_name = config['metric_name']
        y_label = config['y_label']
        use_log_scale = config['use_log_scale']

        for entry in jsonl_data:
            results = entry.get('results', {}).get(evaluator_name, {})
            for method, res in results.items():
                factors = res.get('factor', [])
                values = res.get(metric_name, [])
                # Scalars are treated as one-element lists.
                if not isinstance(factors, list):
                    factors = [factors]
                if not isinstance(values, list):
                    values = [values]
                rows.extend(
                    {
                        'Factor': factor,
                        'Value': value,
                        'Method': method,
                        'Metric': y_label,
                        'UseLogScale': use_log_scale
                    }
                    for factor, value in zip(factors, values)
                )

    if not rows:
        # A column-less frame would make the groupby below raise KeyError.
        print("plot_metrics: no matching results to plot; skipping.")
        return

    # Average over entries; groupby also sorts each method's rows by factor.
    df = pd.DataFrame(rows)
    df = df.groupby(['Method', 'Factor', 'Metric', 'UseLogScale'], as_index=False).mean()

    # Apply log transformation if needed
    df['TransformedValue'] = df.apply(
        lambda row: np.log10(row['Value']) if row['UseLogScale'] else row['Value'],
        axis=1
    )

    p = (
        ggplot(df, aes(x='Factor', y='TransformedValue', color='Method', group='Method')) +
        geom_line() +
        geom_point() +
        theme_bw() +
        labs(x='Factor', y='Value') +
        facet_wrap('~ Metric', scales='free_y', nrow=1) +  # Plots in a row
        theme(
            subplots_adjust={'wspace': 0.1},
            figure_size=(1.5 * len(configs), 3),  # Wider for more plots
            legend_position='right',
            legend_title=element_text(size=4),
            legend_text=element_text(size=6),
            axis_title=element_text(size=6),
            axis_text=element_text(size=6),
            axis_text_x=element_text(rotation=90, hjust=1),  # Rotate x-axis labels
            strip_text=element_text(size=6)
        )
    )

    # Save or show the plot. Path() lets callers pass a plain string directory.
    if write_to_path:
        out_file = Path(write_to_path) / f"{mode}_plot.png"
        p.save(filename=str(out_file), dpi=300, bbox_inches='tight')
    else:
        print(p)

    # Report to wandb when "wandb" is among the requested reporters.
    if report_to is not None and "wandb" in report_to:
        line_series_plots = {}
        for metric in df['Metric'].unique():
            metric_data = df[df['Metric'] == metric]

            # Each method gets its own x values: methods do not necessarily
            # share the same set of factors, so a single shared x list would
            # pair y values with the wrong factors.
            xs, ys, keys = [], [], []
            for method, method_data in metric_data.groupby('Method', sort=False):
                xs.append(method_data['Factor'].tolist())
                ys.append(method_data['TransformedValue'].tolist())
                keys.append(f"{method}")

            line_series_plots[f"{mode}/{metric}"] = wandb.plot.line_series(
                xs=xs,
                ys=ys,
                keys=keys,
                title=f"{metric}",
                xname='Factor'
            )
        wandb.log(line_series_plots)

def plot_metrics_multiple_datasets(data_path, write_to_path=None, report_to=[], wandb_name=None, mode=None, rule=True):
    """Plot per-dataset and dataset-averaged metric curves from a parquet file.

    The parquet file is read into a frame that must contain ``dataset_name``
    and ``factor`` columns plus metric columns named ``<method><suffix>``.
    Methods are discovered from the columns ending in
    ``_LMJudgeEvaluator_relevance_concept_ratings``.

    When ``rule`` is true an ``<method>_RuleEvaluator`` column is added per
    method: the harmonic mean of ``2 - rule_following``, the instruction
    relevance rating and the fluency rating, where a missing column, a NaN or
    a zero makes the whole score 0.

    Methods whose name contains ``PromptSteering`` are averaged over all
    factors and drawn as a flat line; other methods are averaged per factor.

    Args:
        data_path: Path of the parquet file to read.
        write_to_path: Directory (``str`` or ``Path``) receiving
            ``{mode}_per_dataset_plot.png`` and ``{mode}_averaged_metrics.png``.
            When falsy both plots are printed instead.
        report_to: Collection of reporter names. When it contains ``"wandb"``
            and ``write_to_path`` is set, both saved images are logged.
        wandb_name: Unused by this function; kept for interface compatibility.
        mode: Prefix used in output file names and wandb keys.
        rule: Selects the rule-following metric set (default) or the
            LM-judge metric set.

    Returns:
        None. Nothing is plotted when no metric column yields data.
    """
    df = pd.read_parquet(data_path)
    datasets = df['dataset_name'].unique()

    def harmonic_mean(scores):
        # Return 0 if any score is 0 to maintain strict evaluation
        if 0 in scores:
            return 0
        return len(scores) / sum(1/s for s in scores)

    def score_or_zero(row, column, inverted=False):
        # A missing column or a NaN cell contributes 0, which zeroes the
        # harmonic mean for that row.
        if column not in df.columns or pd.isna(row[column]):
            return 0
        return 2 - row[column] if inverted else row[column]

    # Only true suffix matches are accepted, since the method name is obtained
    # by cutting exactly this suffix off the column name.
    concept_suffix = "_LMJudgeEvaluator_relevance_concept_ratings"
    methods = [col[:-len(concept_suffix)] for col in df.columns if col.endswith(concept_suffix)]

    if not rule:
        metrics = ['_LMJudgeEvaluator',
                   '_LMJudgeEvaluator_relevance_concept_ratings',
                   '_LMJudgeEvaluator_relevance_instruction_ratings',
                   '_LMJudgeEvaluator_fluency_ratings']
        metrics_names = ['Overall', 'Relevance Concept', 'Relevance Instruction', 'Fluency']
    else:
        # Add the overall rule score (harmonic mean) for every method.
        for method in methods:
            rule_metric = method + '_RuleEvaluator_rule_following'
            instruction_metric = method + '_LMJudgeEvaluator_relevance_instruction_ratings'
            fluency_metric = method + '_LMJudgeEvaluator_fluency_ratings'
            df[method + '_RuleEvaluator'] = df.apply(
                lambda row: harmonic_mean([
                    score_or_zero(row, rule_metric, inverted=True),
                    score_or_zero(row, instruction_metric),
                    score_or_zero(row, fluency_metric),
                ]),
                axis=1
            )

        metrics = ['_RuleEvaluator',
                   '_RuleEvaluator_rule_following',
                   '_LMJudgeEvaluator_relevance_instruction_ratings',
                   '_LMJudgeEvaluator_fluency_ratings']
        metrics_names = ['Overall', 'Rule Following', 'Relevance Instruction', 'Fluency']

    # One row per (dataset, method, factor, metric).
    plot_data = []
    for dataset in datasets:
        dataset_data = df[df['dataset_name'] == dataset]
        factors = dataset_data['factor'].unique()
        for method in methods:
            flat_line = 'PromptSteering' in method
            for metric, metric_label in zip(metrics, metrics_names):
                column = method + metric
                if column not in dataset_data.columns:
                    continue
                if flat_line:
                    # Same factor-independent average at every factor, so the
                    # method is drawn as a straight line.
                    avg_value = dataset_data[column].mean()
                    values = [(factor, avg_value) for factor in factors]
                else:
                    values = [
                        (factor, dataset_data.loc[dataset_data['factor'] == factor, column].mean())
                        for factor in factors
                    ]
                plot_data.extend(
                    {
                        'Dataset': dataset,
                        'Method': method,
                        'Factor': factor,
                        'Metric': metric_label,
                        'Value': value
                    }
                    for factor, value in values
                )

    if not plot_data:
        # A column-less frame cannot be faceted or grouped below.
        print("plot_metrics_multiple_datasets: no metric columns found to plot; skipping.")
        return

    plot_df = pd.DataFrame(plot_data)

    # Categorical keeps the facets in the order given by metrics_names.
    plot_df['Metric'] = pd.Categorical(plot_df['Metric'], categories=metrics_names, ordered=True)

    num_datasets = len(datasets)
    out_dir = Path(write_to_path) if write_to_path else None

    # Combined plot: one row of four metric facets per dataset.
    combined_p = (
        ggplot(plot_df, aes(x='Factor', y='Value', color='Method', group='Method')) +
        geom_line() +
        geom_point() +
        facet_wrap('~ Dataset + Metric', scales='free_y', ncol=4, nrow=num_datasets) +
        theme_bw() +
        labs(x='Factor', y='Score') +
        scale_color_manual(values=COLORS) +
        theme(
            figure_size=(16, 4 * num_datasets),
            legend_position='right',
            legend_title=element_text(size=8),
            legend_text=element_text(size=6),
            axis_title=element_text(size=10),
            axis_text=element_text(size=8),
            axis_text_x=element_text(rotation=45, hjust=1),
            strip_text=element_text(size=8),
            plot_title=element_text(size=12, hjust=0.5)
        )
    )

    if out_dir is not None:
        combined_p.save(filename=str(out_dir / f"{mode}_per_dataset_plot.png"), dpi=300, bbox_inches='tight')
    else:
        print(combined_p)

    # Averaged plot: mean over datasets of the per-dataset averages.
    avg_df = plot_df.groupby(['Method', 'Factor', 'Metric'], as_index=False)['Value'].mean()
    avg_df['Metric'] = pd.Categorical(avg_df['Metric'], categories=metrics_names, ordered=True)

    avg_p = (
        ggplot(avg_df, aes(x='Factor', y='Value', color='Method', group='Method')) +
        geom_line() +
        geom_point() +
        facet_wrap('~ Metric', scales='free_y', ncol=4) +
        theme_bw() +
        labs(x='Factor', y='Score', title='Average Across All Datasets') +
        scale_color_manual(values=COLORS) +
        theme(
            figure_size=(16, 4),
            legend_position='right',
            legend_title=element_text(size=8),
            legend_text=element_text(size=6),
            axis_title=element_text(size=10),
            axis_text=element_text(size=8),
            axis_text_x=element_text(rotation=45, hjust=1),
            strip_text=element_text(size=8),
            plot_title=element_text(size=12, hjust=0.5)
        )
    )

    # Save or show the averaged plot, mirroring the combined plot above.
    if out_dir is not None:
        avg_p.save(filename=str(out_dir / f"{mode}_averaged_metrics.png"), dpi=300, bbox_inches='tight')
    else:
        print(avg_p)

    # The wandb report uploads the saved image files, so it needs an output dir.
    if report_to is not None and "wandb" in report_to:
        if out_dir is not None:
            wandb.log({
                f"{mode}/averaged_metrics": wandb.Image(str(out_dir / f"{mode}_averaged_metrics.png")),
                f"{mode}/per_dataset_metrics": wandb.Image(str(out_dir / f"{mode}_per_dataset_plot.png"))
            })


def plot_nshot_metrics(data_path, write_to_path=None, report_to=[], wandb_name=None, mode=None, rule = False):
    """Plot average metric scores against the number of shots.

    The parquet file is read into a frame whose ``dataset_name`` values are
    matched against ``AttackMultiShotIgnoreFollow_<n>``; ``<n>`` becomes the
    x-axis value. Rows whose dataset name does not match are ignored. Method
    names are the part before the first underscore of every column containing
    ``steered_generation``.

    When ``rule`` is true an ``<method>_RuleEvaluator`` column is added per
    method: the harmonic mean of ``2 - rule_following``, the instruction
    relevance rating and the fluency rating, where a missing column, a NaN or
    a zero makes the whole score 0.

    Args:
        data_path: Path of the parquet file to read.
        write_to_path: Directory (``str`` or ``Path``) receiving
            ``{mode}_nshot_metrics.png``. When falsy the plot is printed.
        report_to: Collection of reporter names. When it contains ``"wandb"``
            and ``write_to_path`` is set, the saved image is logged.
        wandb_name: Unused by this function; kept for interface compatibility.
        mode: Prefix used in the output file name and the wandb key.
        rule: Selects the rule-following metric set instead of the LM-judge
            metric set (default).

    Returns:
        None. Nothing is plotted when no metric column yields data.

    Raises:
        ValueError: If no ``dataset_name`` matches the expected pattern.
    """
    df = pd.read_parquet(data_path)

    # Shot count is encoded in the dataset name as AttackMultiShotIgnoreFollow_<n>.
    # expand=False returns a Series; names that do not match give NaN, which
    # cannot be cast to int, so those rows are filtered out first.
    shot_numbers = df['dataset_name'].str.extract(r'AttackMultiShotIgnoreFollow_(\d+)', expand=False)
    matched = shot_numbers.notna()
    if not matched.any():
        raise ValueError(
            "plot_nshot_metrics: no dataset_name matches 'AttackMultiShotIgnoreFollow_<n>'"
        )
    df = df[matched].copy()
    df['n_shots'] = shot_numbers[matched].astype(int)

    def harmonic_mean(scores):
        # Return 0 if any score is 0 to maintain strict evaluation
        if 0 in scores:
            return 0
        return len(scores) / sum(1/s for s in scores)

    def score_or_zero(row, column, inverted=False):
        # A missing column or a NaN cell contributes 0, which zeroes the
        # harmonic mean for that row.
        if column not in df.columns or pd.isna(row[column]):
            return 0
        return 2 - row[column] if inverted else row[column]

    # dict.fromkeys de-duplicates while keeping first-seen order; a repeated
    # method name would otherwise add every data point more than once.
    # NOTE(review): split('_')[0] truncates method names that themselves
    # contain an underscore - confirm against the real column naming.
    methods = list(dict.fromkeys(
        col.split('_')[0] for col in df.columns if "steered_generation" in col
    ))

    if not rule:
        metrics = ['_LMJudgeEvaluator',
                   '_LMJudgeEvaluator_relevance_concept_ratings',
                   '_LMJudgeEvaluator_relevance_instruction_ratings',
                   '_LMJudgeEvaluator_fluency_ratings']
        metrics_names = ['Overall', 'Relevance Concept', 'Relevance Instruction', 'Fluency']
    else:
        # Add the overall rule score (harmonic mean) for every method.
        for method in methods:
            rule_metric = method + '_RuleEvaluator_rule_following'
            instruction_metric = method + '_LMJudgeEvaluator_relevance_instruction_ratings'
            fluency_metric = method + '_LMJudgeEvaluator_fluency_ratings'
            df[method + '_RuleEvaluator'] = df.apply(
                lambda row: harmonic_mean([
                    score_or_zero(row, rule_metric, inverted=True),
                    score_or_zero(row, instruction_metric),
                    score_or_zero(row, fluency_metric),
                ]),
                axis=1
            )

        metrics = ['_RuleEvaluator',
                   '_RuleEvaluator_rule_following',
                   '_LMJudgeEvaluator_relevance_instruction_ratings',
                   '_LMJudgeEvaluator_fluency_ratings']
        metrics_names = ['Overall', 'Rule Following', 'Relevance Instruction', 'Fluency']

    # One row per (method, n_shots, metric) holding the mean score.
    plot_data = []
    shot_values = df['n_shots'].unique()
    for method in methods:
        for n_shots in shot_values:
            shots_data = df[df['n_shots'] == n_shots]
            for metric, metric_label in zip(metrics, metrics_names):
                column = method + metric
                if column in shots_data.columns:
                    plot_data.append({
                        'n_shots': n_shots,
                        'Method': method,
                        'Metric': metric_label,
                        'Value': shots_data[column].mean()
                    })

    if not plot_data:
        # A column-less frame cannot be sorted or faceted below.
        print("plot_nshot_metrics: no metric columns found to plot; skipping.")
        return

    plot_df = pd.DataFrame(plot_data)

    # Categorical keeps the facets in the order given by metrics_names.
    plot_df['Metric'] = pd.Categorical(plot_df['Metric'], categories=metrics_names, ordered=True)

    # Sort by number of shots to ensure proper ordering on x-axis
    plot_df = plot_df.sort_values('n_shots')

    p = (
        ggplot(plot_df, aes(x='n_shots', y='Value', color='Method', group='Method')) +
        geom_line() +
        geom_point() +
        facet_wrap('~ Metric', scales='free_y', ncol=4) +  # 4 metrics in one row
        theme_bw() +
        labs(x='Number of Shots', y='Score', title='Performance by Number of Shots') +
        theme(
            figure_size=(16, 4),
            legend_position='right',
            legend_title=element_text(size=8),
            legend_text=element_text(size=6),
            axis_title=element_text(size=10),
            axis_text=element_text(size=8),
            axis_text_x=element_text(rotation=0),  # No need to rotate numbers
            strip_text=element_text(size=8),
            plot_title=element_text(size=12, hjust=0.5)
        )
    )

    # Save or show the plot. Path() lets callers pass a plain string directory.
    out_file = Path(write_to_path) / f"{mode}_nshot_metrics.png" if write_to_path else None
    if out_file is not None:
        p.save(filename=str(out_file), dpi=300, bbox_inches='tight')
    else:
        print(p)

    # The wandb report uploads the saved image file, so it needs an output dir.
    if report_to is not None and "wandb" in report_to:
        if out_file is not None:
            wandb.log({
                f"{mode}/nshot_metrics": wandb.Image(str(out_file))
            })


def plot_accuracy_bars(jsonl_data, evaluator_name, write_to_path=None, report_to=[], wandb_name=None):
    """Draw a bar chart of the mean ``macro_avg_accuracy`` of every method.

    For each entry, ``entry['results'][evaluator_name]`` maps a method name to
    its result dict. The ``macro_avg_accuracy`` values of a method are averaged
    over all entries that report one; a method that never reports it is drawn
    with accuracy 0. Entries without results for ``evaluator_name`` are skipped.

    Args:
        jsonl_data: Iterable of result entries (dicts with a ``results`` key).
        evaluator_name: Key of the evaluator whose results are plotted.
        write_to_path: Directory (``str`` or ``Path``) receiving
            ``macro_avg_accuracy_incl_hard_neg.png``. When falsy the plot is
            printed instead.
        report_to: Collection of reporter names. When it contains ``"wandb"``
            and ``write_to_path`` is set, the saved image is logged under
            ``latent/macro_avg_accuracy_incl_hard_neg``.
        wandb_name: Unused by this function; kept for interface compatibility.

    Returns:
        None. Nothing is plotted or logged when no method is found.
    """
    # Collect the reported accuracies per method over all entries.
    accuracies = {}
    for entry in jsonl_data:
        results = entry.get('results', {}).get(evaluator_name, {})
        for method, res in results.items():
            values = accuracies.setdefault(method, [])
            if 'macro_avg_accuracy' in res:
                values.append(res['macro_avg_accuracy'])

    if not accuracies:
        # An empty frame has no Method/Accuracy columns and cannot be plotted.
        print("plot_accuracy_bars: no results to plot; skipping.")
        return

    df = pd.DataFrame([
        {'Method': method, 'Accuracy': np.mean(accuracies[method]) if accuracies[method] else 0}
        for method in sorted(accuracies)
    ])

    p = (
        ggplot(df, aes(x='Method', y='Accuracy', fill='Method')) +
        geom_bar(stat='identity', width=0.7) +
        geom_text(
            aes(label='round(Accuracy, 2)'),
            va='bottom',
            size=8,
            format_string='{:.2f}'
        ) +
        ylim(0, 1) +  # Set y-axis limits from 0 to 1
        theme_bw() +
        labs(x='Method', y='Accuracy') +
        theme(
            figure_size=(5, 2),
            legend_position='none',  # Remove legend since 'fill' corresponds to 'Method'
            axis_title=element_text(size=5),
            axis_text=element_text(size=5),
            plot_title=element_text(size=5)
        )
    )

    # Save or show the plot. Path() lets callers pass a plain string directory.
    out_file = Path(write_to_path) / "macro_avg_accuracy_incl_hard_neg.png" if write_to_path else None
    if out_file is not None:
        p.save(filename=str(out_file), dpi=300, bbox_inches='tight')
    else:
        print(p)

    # The wandb report uploads the saved image file; without an output
    # directory there is no file to upload.
    if report_to is not None and "wandb" in report_to and out_file is not None:
        wandb.log({"latent/macro_avg_accuracy_incl_hard_neg": wandb.Image(str(out_file))})


def plot_win_rates(jsonl_data, write_to_path=None, report_to=[], wandb_name=None):
    """Draw stacked horizontal win/tie/loss bars for every compared method.

    ``entry['results']['WinRateEvaluator']`` maps a method name to a dict with
    ``win_rate``, ``loss_rate``, ``tie_rate`` and ``baseline_model``. Rates are
    converted to percentages and averaged over all entries; an entry that
    lacks a method counts as 0 for all three rates of that method. Methods are
    ordered by descending mean win rate and the win percentage is printed on
    each bar.

    When several baseline names occur, the alphabetically first one is used.
    A method whose name equals the baseline is not drawn.

    Args:
        jsonl_data: Iterable of result entries (dicts with a ``results`` key).
        write_to_path: Directory (``str`` or ``Path``) receiving
            ``winrate_plot.png``. When falsy the plot is printed instead.
        report_to: Collection of reporter names. When it contains ``"wandb"``
            and ``write_to_path`` is set, the saved image is logged under
            ``steering/winrate_plot``.
        wandb_name: Unused by this function; kept for interface compatibility.

    Returns:
        None. Nothing is plotted or logged when there is no method to draw.
    """
    # Collect method names and the baseline(s) they were compared against.
    methods = set()
    baseline_models = set()
    for entry in jsonl_data:
        winrate_results = entry.get('results', {}).get('WinRateEvaluator', {})
        for method_name, res in winrate_results.items():
            methods.add(method_name)
            baseline_models.add(res.get('baseline_model', 'Unknown'))

    if not methods:
        # Without any result there is no baseline to pick and nothing to draw.
        print("plot_win_rates: no WinRateEvaluator results to plot; skipping.")
        return

    # All methods are assumed to share one baseline; otherwise take the first.
    baseline_model = sorted(baseline_models)[0]

    compared_methods = sorted(m for m in methods if m != baseline_model)
    if not compared_methods:
        print("plot_win_rates: no non-baseline methods to plot; skipping.")
        return

    # Per-method percentages, one value per entry.
    win_rates = {method: [] for method in compared_methods}
    loss_rates = {method: [] for method in compared_methods}
    tie_rates = {method: [] for method in compared_methods}
    for entry in jsonl_data:
        winrate_results = entry.get('results', {}).get('WinRateEvaluator', {})
        for method in compared_methods:
            # A method missing from this entry counts as zero rates.
            res = winrate_results.get(method, {})
            win_rates[method].append(res.get('win_rate', 0) * 100)
            loss_rates[method].append(res.get('loss_rate', 0) * 100)
            tie_rates[method].append(res.get('tie_rate', 0) * 100)

    win_means = {method: np.mean(vals) for method, vals in win_rates.items()}
    loss_means = {method: np.mean(vals) for method, vals in loss_rates.items()}
    tie_means = {method: np.mean(vals) for method, vals in tie_rates.items()}

    # Highest mean win rate first; ties keep alphabetical order.
    sorted_methods = sorted(compared_methods, key=lambda m: win_means[m], reverse=True)

    data = []
    for method in sorted_methods:
        data.append({'Method': method, 'Outcome': 'Loss', 'Percentage': loss_means[method]})
        data.append({'Method': method, 'Outcome': 'Tie', 'Percentage': tie_means[method]})
        data.append({'Method': method, 'Outcome': 'Win', 'Percentage': win_means[method]})

    df = pd.DataFrame(data)

    # Set the order of Outcome to control stacking order
    df['Outcome'] = pd.Categorical(df['Outcome'], categories=['Loss', 'Tie', 'Win'], ordered=True)
    # coord_flip draws the first category at the bottom, so reverse the list
    # to get the best method at the top.
    df['Method'] = pd.Categorical(df['Method'], categories=sorted_methods[::-1], ordered=True)

    df = df.sort_values(['Method', 'Outcome'])
    df['Percentage'] = df['Percentage'].astype(float)

    # Cumulative percentage before each outcome, per method. observed=True
    # restricts the grouping to the categories actually present.
    df['cum_percentage'] = df.groupby('Method', observed=True)['Percentage'].cumsum()
    df['cum_percentage_shifted'] = (
        df.groupby('Method', observed=True)['cum_percentage'].shift(1).fillna(0)
    )

    # Label data for the 'Win' segment: position derived from the cumulative
    # loss+tie percentage, text is the formatted win percentage.
    df_win = df[df['Outcome'] == 'Win'].copy()
    df_win['text_position'] = 100.0 - df_win['cum_percentage_shifted'].astype(float)
    df_win['win_percentage_label'] = df_win['Percentage'].map(lambda x: f"{x:.1f}%")

    p = (
        ggplot(df, aes(x='Method', y='Percentage', fill='Outcome')) +
        geom_bar(stat='identity', position='stack', width=0.8) +
        # Add the geom_text layer to include win rate numbers
        geom_text(
            data=df_win,
            mapping=aes(
                x='Method',
                y='text_position',
                label='win_percentage_label'
            ),
            ha='right',
            va='center',
            size=6,
            color='black',
            nudge_y=18
        ) +
        coord_flip() +  # Flip coordinates for horizontal bars
        theme_bw() +
        labs(
            y='Percentage (%)',
            x=''
        ) +
        theme(
            axis_text_x=element_text(size=6),
            axis_text_y=element_text(size=6),
            axis_title=element_text(size=6),
            legend_title=element_text(size=6),
            legend_text=element_text(size=6),
            figure_size=(3, len(sorted_methods) * 0.3 + 0.3)
        ) +
        scale_fill_manual(
            values={'Win': '#a6cee3', 'Tie': '#bdbdbd', 'Loss': '#fbb4ae'},
            guide='legend',
            name='Outcome'
        )
    )

    # Save or show the plot. Path() lets callers pass a plain string directory.
    out_file = Path(write_to_path) / "winrate_plot.png" if write_to_path else None
    if out_file is not None:
        p.save(filename=str(out_file), dpi=300, bbox_inches='tight')
    else:
        print(p)

    # The wandb report uploads the saved image file; without an output
    # directory there is no file to upload.
    if report_to is not None and "wandb" in report_to and out_file is not None:
        wandb.log({"steering/winrate_plot": wandb.Image(str(out_file))})


def plot_best_factor_scores(data_path, write_to_path=None, report_to=[], wandb_name=None, mode=None, rule=True, method="PreferenceVector"):
    """Box-plot held-out scores obtained with a factor chosen on a train split.

    The parquet file is read into a frame that must contain ``concept_id`` and
    ``factor`` columns plus the metric columns of ``method``. For every
    concept the rows are split in half with a fixed seed; the factor of the
    train row with the highest overall score is selected, and the test rows
    with that factor are averaged per metric. Concepts with an empty or
    all-NaN train split, or without test rows for the selected factor, are
    skipped.

    With ``rule=True`` the overall score is ``<method>_RuleEvaluator``, which
    is added here as the harmonic mean of ``2 - rule_following``, the
    instruction relevance rating and the fluency rating (a missing column, a
    NaN or a zero makes the score 0); the reported "Rule Following" value is
    ``2 - mean``. With ``rule=False`` the overall score is
    ``<method>_LMJudgeEvaluator`` and the LM-judge ratings are reported.

    Args:
        data_path: Path of the parquet file to read.
        write_to_path: Directory (``str`` or ``Path``) receiving
            ``{mode}_best_factor_boxplot.png``. When falsy the plot is printed.
        report_to: Collection of reporter names. When it contains ``"wandb"``
            and ``write_to_path`` is set, the saved image is logged.
        wandb_name: Unused by this function; kept for interface compatibility.
        mode: Prefix used in the output file name and the wandb key.
        rule: Selects the rule-following metric set (default) or the
            LM-judge metric set.
        method: Column-name prefix of the method whose scores are analysed.

    Returns:
        The plotnine plot object, or None when no concept produced a score.

    Raises:
        KeyError: If a metric column required for ``method`` is missing.
    """
    df = pd.read_parquet(data_path)

    def harmonic_mean(scores):
        # Return 0 if any score is 0 to maintain strict evaluation
        if 0 in scores:
            return 0
        return len(scores) / sum(1/s for s in scores)

    def score_or_zero(row, column, inverted=False):
        # A missing column or a NaN cell contributes 0, which zeroes the
        # harmonic mean for that row.
        if column not in df.columns or pd.isna(row[column]):
            return 0
        return 2 - row[column] if inverted else row[column]

    concept_suffix = "_LMJudgeEvaluator_relevance_concept_ratings"
    methods = [col[:-len(concept_suffix)] for col in df.columns if col.endswith(concept_suffix)]

    # metric_specs entries: (label in the plot, source column, report 2 - mean).
    # The first entry is the overall score used to select the factor.
    if not rule:
        metric_specs = [
            ('Overall', method + '_LMJudgeEvaluator', False),
            ('Relevance Concept', method + '_LMJudgeEvaluator_relevance_concept_ratings', False),
            ('Relevance Instruction', method + '_LMJudgeEvaluator_relevance_instruction_ratings', False),
            ('Fluency', method + '_LMJudgeEvaluator_fluency_ratings', False),
        ]
    else:
        # Add the overall rule score (harmonic mean) for every method.
        for name in methods:
            rule_metric = name + '_RuleEvaluator_rule_following'
            instruction_metric = name + '_LMJudgeEvaluator_relevance_instruction_ratings'
            fluency_metric = name + '_LMJudgeEvaluator_fluency_ratings'
            df[name + '_RuleEvaluator'] = df.apply(
                lambda row: harmonic_mean([
                    score_or_zero(row, rule_metric, inverted=True),
                    score_or_zero(row, instruction_metric),
                    score_or_zero(row, fluency_metric),
                ]),
                axis=1
            )

        metric_specs = [
            ('Overall', method + '_RuleEvaluator', False),
            ('Rule Following', method + '_RuleEvaluator_rule_following', True),
            ('Relevance', method + '_LMJudgeEvaluator_relevance_instruction_ratings', False),
            ('Fluency', method + '_LMJudgeEvaluator_fluency_ratings', False),
        ]

    # Fail early with one clear message instead of deep inside the loop.
    missing = [column for _, column, _ in metric_specs if column not in df.columns]
    if missing:
        raise KeyError(f"plot_best_factor_scores: missing columns for method {method!r}: {missing}")

    overall_column = metric_specs[0][1]
    metric_labels = [label for label, _, _ in metric_specs]

    best_scores = []
    for concept in df['concept_id'].unique():
        concept_data = df[df['concept_id'] == concept]
        indices = concept_data.index.values

        # A local RandomState seeded with 41 yields the same draw as seeding
        # the global generator, without clobbering the caller's global RNG.
        rng = np.random.RandomState(41)
        train_indices = rng.choice(indices, size=len(indices)//2, replace=False)
        test_indices = indices[~np.isin(indices, train_indices)]

        train_data = concept_data.loc[train_indices]
        test_data = concept_data.loc[test_indices]

        # idxmax raises on an empty or all-NaN series (e.g. a concept with a
        # single row has an empty train split), so such concepts are skipped.
        train_scores = train_data[overall_column]
        if train_scores.dropna().empty:
            continue
        best_factor = train_data.loc[train_scores.idxmax(), 'factor']

        # Evaluate the selected factor on the held-out half only.
        test_factor_data = test_data[test_data['factor'] == best_factor]
        if len(test_factor_data) == 0:
            continue

        metrics_data = {'Concept': f'Concept {concept}', 'Factor': best_factor}
        for label, column, inverted in metric_specs:
            mean_value = test_factor_data[column].mean()
            metrics_data[label] = 2 - mean_value if inverted else mean_value
        best_scores.append(metrics_data)

    if not best_scores:
        # A column-less frame cannot be counted, melted or plotted below.
        print("plot_best_factor_scores: no concept produced test scores; skipping.")
        return None

    best_scores_df = pd.DataFrame(best_scores)

    # Print the selected factors and their frequencies
    factor_counts = best_scores_df['Factor'].value_counts()
    print("\nSelected Factors Distribution:")
    print(factor_counts)

    # Long format: one row per (concept, metric).
    plot_df = pd.melt(best_scores_df,
                      id_vars=['Concept', 'Factor'],
                      value_vars=metric_labels,
                      var_name='Metric',
                      value_name='Score')

    p = (
        ggplot(plot_df, aes(x='Metric', y='Score')) +
        geom_boxplot(fill='lightblue', alpha=0.5) +
        geom_jitter(color='blue', alpha=0.3, width=0.2) +
        theme_bw() +
        labs(
            title='Distribution of Test Set Scores Using Factors Selected on Train Set',
            x='Metric',
            y='Score'
        ) +
        theme(
            figure_size=(10, 6),
            axis_text_x=element_text(rotation=45, hjust=1),
            plot_title=element_text(size=12, hjust=0.5),
            axis_title=element_text(size=10),
            axis_text=element_text(size=8)
        )
    )

    # Save or show the plot. Path() lets callers pass a plain string directory.
    out_file = Path(write_to_path) / f"{mode}_best_factor_boxplot.png" if write_to_path else None
    if out_file is not None:
        p.save(filename=str(out_file), dpi=300, bbox_inches='tight')
    else:
        print(p)

    # The wandb report uploads the saved image file, so it needs an output dir.
    if report_to is not None and "wandb" in report_to and out_file is not None:
        wandb.log({
            f"{mode}/best_factor_boxplot": wandb.Image(str(out_file))
        })

    # Print summary statistics
    print("\nSummary Statistics for Test Set Scores:")
    summary_stats = plot_df.groupby('Metric')['Score'].agg(['mean', 'std', 'min', 'max'])
    print(summary_stats)

    return p

