import json
import numpy as np
from scipy.stats import bootstrap
import matplotlib.pyplot as plt
from pathlib import Path
import logging
import re

logger = logging.getLogger(__name__)

def _coerce_to_float(value):
    """Convert one raw data item to a float, or return None if it is unusable.

    Booleans become 0.0/1.0; Python and NumPy ints/floats are converted directly;
    strings are parsed with ``float()``. ``None``, NaN, unparseable strings, numbers
    too large to represent as a float, and any other type yield ``None``.
    """
    if value is None:
        return None
    # bool is a subclass of int, so test it first; it is counted as 0/1.
    if isinstance(value, bool):
        return float(value)
    if isinstance(value, (int, float, np.integer, np.floating)):
        try:
            converted = float(value)
        except OverflowError:
            # Arbitrary-precision Python ints can exceed the float range.
            return None
    elif isinstance(value, str):
        try:
            converted = float(value)
        except ValueError:
            logger.debug(f"Could not convert string to float in calculate_mean_stderr: {value}")
            return None
    else:
        return None
    return None if np.isnan(converted) else converted


def calculate_mean_stderr(data_list: list, n_resamples: int = 1000, random_state=None):
    """Return the mean of the usable values in *data_list* and its bootstrap standard error.

    Items are filtered and converted by ``_coerce_to_float``: None/NaN and
    non-numeric entries are dropped, booleans count as 0/1, and numeric strings
    are parsed.

    Args:
        data_list: Sequence of raw values (may be empty or None).
        n_resamples: Number of bootstrap resamples passed to ``scipy.stats.bootstrap``.
        random_state: Optional seed / random generator forwarded to
            ``scipy.stats.bootstrap`` for reproducible standard errors. When None,
            bootstrap's default randomness is used (same as before this parameter existed).

    Returns:
        ``(mean, stderr)``. Both are NaN when no usable value remains; ``stderr`` is
        NaN when fewer than two values remain or the bootstrap call fails.
    """
    values = [v for v in map(_coerce_to_float, data_list or []) if v is not None]
    if not values:
        return np.nan, np.nan

    mean_val = np.mean(values)

    if len(values) < 2:
        logger.warning(f"Not enough data points ({len(values)}) to calculate bootstrap standard error for mean {mean_val}. Returning NaN for stderr.")
        return mean_val, np.nan

    # Only forward random_state when given, so the default call matches scipy's defaults exactly.
    extra_kwargs = {} if random_state is None else {"random_state": random_state}
    try:
        res = bootstrap(
            (np.asarray(values),),
            np.mean,
            n_resamples=n_resamples,
            method='basic',
            confidence_level=0.95,
            **extra_kwargs,
        )
        stderr_val = res.standard_error
    except Exception as e:
        # Best-effort: a failed bootstrap should not abort the whole analysis.
        logger.error(f"Error during bootstrap calculation: {e}. Returning NaN for stderr.")
        stderr_val = np.nan

    return mean_val, stderr_val

# Method-name patterns recognised by plot_metric_on_axes.
_TOP_K_PATTERN = re.compile(r"top_k_(\d+)")
# Alpha must be a well-formed decimal (e.g. "0.5", "1", ".5") so float() cannot fail.
_PRIMAL_PATTERN = re.compile(r"primal_alpha_(\d+(?:\.\d*)?|\.\d+)_k_(\d+)")


def _parse_method_name(method_name: str):
    """Map a method key to ``(series_label, k)``, or return None if it is not recognised.

    ``top_k_<k>`` maps to the "Top-K" series; ``primal_alpha_<alpha>_k_<k>`` maps to
    one "Primal (alpha=<alpha>)" series per alpha value.
    """
    top_k_match = _TOP_K_PATTERN.fullmatch(method_name)
    if top_k_match:
        return "Top-K", int(top_k_match.group(1))
    primal_match = _PRIMAL_PATTERN.fullmatch(method_name)
    if primal_match:
        alpha_val = float(primal_match.group(1))
        return f"Primal (alpha={alpha_val})", int(primal_match.group(2))
    return None


def _paired_differences(original_values, method_values, metric_name: str) -> list:
    """Pair the two sequences index-by-index and return one difference per position.

    ``perplexity`` uses ``original - method``; ``repetition`` uses ``method - original``.
    Any other metric name, or a position where either value is None/NaN, yields NaN.
    Booleans are compared as 0/1.
    """
    differences = []
    for orig_val, method_val in zip(original_values, method_values):
        if isinstance(orig_val, bool):
            orig_val = int(orig_val)
        if isinstance(method_val, bool):
            method_val = int(method_val)

        unusable = (
            orig_val is None
            or method_val is None
            or (isinstance(orig_val, float) and np.isnan(orig_val))
            or (isinstance(method_val, float) and np.isnan(method_val))
        )
        if unusable:
            differences.append(np.nan)
        elif metric_name == "perplexity":
            differences.append(orig_val - method_val)
        elif metric_name == "repetition":
            differences.append(method_val - orig_val)
        else:
            differences.append(np.nan)
    return differences


def plot_metric_on_axes(ax: plt.Axes, metric_data: dict, metric_name: str, model_name_simple: str):
    """Plot the mean paired difference of each method against "original" as a function of k.

    Args:
        ax: Axes to draw on.
        metric_data: Mapping of method name -> list of values. It must contain an
            ``"original"`` entry; other entries are paired with it index-by-index.
            Recognised method names are ``top_k_<k>`` and
            ``primal_alpha_<alpha>_k_<k>``; any other name is skipped.
        metric_name: ``"perplexity"`` (difference = original - method) or
            ``"repetition"`` (difference = method - original).
        model_name_simple: Model name, used in log messages and in the error title.

    Each series (Top-K, and one Primal series per alpha) is drawn as a line of mean
    differences over k, with a shaded band of +/- one bootstrap standard error where
    that error is available.
    """
    logger.info(f"Plotting {metric_name} differences for model {model_name_simple} on provided axes...")

    original_values = metric_data.get("original")
    if not original_values:
        logger.warning(f"No 'original' data found for {metric_name}. Cannot compute differences.")
        ax.text(0.5, 0.5, f"No 'original' data for {metric_name}", ha='center', va='center')
        ax.set_title(f"Mean {metric_name.capitalize()} vs. k ({model_name_simple}) - Error: No Original Data")
        return

    plot_series_data = {}
    for method_name, method_values_list in metric_data.items():
        if method_name == "original":
            continue

        # Parse first so unrecognised methods are not bootstrapped for nothing.
        parsed = _parse_method_name(method_name)
        if parsed is None:
            logger.debug(f"Unrecognized method name '{method_name}' for {metric_name}. Skipping method.")
            continue
        series_label, k_val = parsed

        if not isinstance(method_values_list, (list, tuple)):
            logger.warning(f"Values for '{method_name}' in {metric_name} are not a list. Skipping method.")
            continue
        if len(method_values_list) != len(original_values):
            logger.warning(f"Length mismatch between '{method_name}' ({len(method_values_list)}) and 'original' ({len(original_values)}) for {metric_name}. Skipping method.")
            continue

        differences_list = _paired_differences(original_values, method_values_list, metric_name)
        mean_of_diff, stderr_of_diff_mean = calculate_mean_stderr(differences_list)
        if np.isnan(mean_of_diff):
            logger.warning(f"Skipping method '{method_name}' for {metric_name} differences due to insufficient data or NaN mean of differences.")
            continue

        series = plot_series_data.setdefault(series_label, {'k_values': [], 'means': [], 'stderrs': []})
        series['k_values'].append(k_val)
        series['means'].append(mean_of_diff)
        series['stderrs'].append(stderr_of_diff_mean)

    if not plot_series_data:
        logger.info(f"No data found to plot {metric_name} differences.")
        ax.text(0.5, 0.5, f"No data for {metric_name} differences", ha='center', va='center')
        ax.set_title(f"{metric_name.capitalize()} Difference vs. k")
        return

    # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib 3.9; pyplot.get_cmap remains.
    colors = plt.get_cmap('tab10', len(plot_series_data))
    markers = ['o', 's', '^', 'D', 'v', '<', '>']

    # Every series holds at least one point, because series are only created when appending.
    for plot_idx, (series_label, data) in enumerate(plot_series_data.items()):
        order = np.argsort(data['k_values'])
        k_sorted = np.asarray(data['k_values'])[order]
        means_sorted = np.asarray(data['means'], dtype=float)[order]
        stderrs_sorted = np.asarray(data['stderrs'], dtype=float)[order]

        current_color = colors(plot_idx % colors.N)
        current_marker = markers[plot_idx % len(markers)]
        ax.plot(k_sorted, means_sorted, marker=current_marker, linestyle='-', label=series_label, color=current_color)

        # Shade +/- 1 standard error only at k values where the bootstrap produced one.
        has_stderr = ~np.isnan(stderrs_sorted)
        if has_stderr.any():
            band_means = means_sorted[has_stderr]
            band_errs = stderrs_sorted[has_stderr]
            ax.fill_between(k_sorted[has_stderr],
                            band_means - band_errs,
                            band_means + band_errs,
                            alpha=0.2, color=current_color)

    ax.set_xlabel("k")
    if metric_name == "perplexity":
        ax.set_ylabel("Perplexity Diff")
    elif metric_name == "repetition":
        ax.set_ylabel("Repetition Diff (%)")
    else:
        ax.set_ylabel(f"Mean {metric_name.capitalize()} Difference")

    ax.legend()
    ax.grid(True)

def run_analysis_and_plot(metrics_file_path_str: str, model_name_simple: str):
    """Load a metrics JSON file and save a perplexity/repetition difference plot next to it.

    The JSON top level must be an object. Its ``"perplexity_data"`` and
    ``"repetition_data"`` entries are each passed to ``plot_metric_on_axes`` on one
    of two side-by-side subplots. A missing or empty entry gets a labelled
    placeholder panel instead. If at least one entry had data, the figure is saved as
    ``Perplexity_Repetition_comparison_plot_<model_name_simple>.pdf`` in the metrics
    file's directory.

    Read, parse and save failures are logged and the function returns. No exception
    is raised for these failures.
    """
    metrics_file_path = Path(metrics_file_path_str)
    # The output directory is the file's parent, so it already exists once the file does.
    output_dir = metrics_file_path.parent

    if not metrics_file_path.exists():
        logger.error(f"Metrics file not found: {metrics_file_path}")
        return

    try:
        with open(metrics_file_path, 'r', encoding='utf-8') as f:
            all_metrics_data = json.load(f)
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON from {metrics_file_path}: {e}")
        return
    except Exception as e:
        logger.error(f"Error reading metrics file {metrics_file_path}: {e}")
        return

    if not isinstance(all_metrics_data, dict):
        logger.error(f"Expected a JSON object at the top level of {metrics_file_path}, got {type(all_metrics_data).__name__}.")
        return

    # One entry per subplot, left to right:
    # (JSON key, metric name, placeholder title, placeholder y-label).
    panels = (
        ("perplexity_data", "perplexity", "Perplexity Difference vs. k", "Mean PPL Diff (Original - Method)"),
        ("repetition_data", "repetition", "Repetition Difference vs. k", "Mean Repetition Diff (Method - Original)"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(22, 8))
    try:
        any_data_plotted = False
        for ax, (data_key, metric_name, empty_title, empty_ylabel) in zip(axes, panels):
            metric_data = all_metrics_data.get(data_key)
            if metric_data:
                plot_metric_on_axes(ax, metric_data, metric_name, model_name_simple)
                any_data_plotted = True
                continue
            logger.warning(f"No {metric_name} data found or data is empty in {metrics_file_path} for model {model_name_simple}.")
            ax.text(0.5, 0.5, f"No {metric_name} data", ha='center', va='center')
            ax.set_title(empty_title)
            ax.set_xlabel("k value / k_max value")
            ax.set_ylabel(empty_ylabel)

        fig.suptitle(f"Results for {model_name_simple}", fontsize=16)
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])

        if any_data_plotted:
            plot_path = output_dir / f"Perplexity_Repetition_comparison_plot_{model_name_simple}.pdf"
            try:
                # Save this figure explicitly rather than pyplot's "current" figure.
                fig.savefig(plot_path)
                logger.info(f"Saved combined metric differences plot to {plot_path}")
            except Exception as e:
                logger.error(f"Error saving combined plot {plot_path}: {e}")
        else:
            logger.info(f"No data was available to plot differences for model {model_name_simple}. No PDF generated.")
    finally:
        # Always release the figure, even if plotting raised.
        plt.close(fig)

    logger.info(f"Analysis and plotting of differences finished for model {model_name_simple}. Plots, if generated, are in {output_dir}")
