import os
import re
import math
from collections import defaultdict
from typing import Dict, Iterable, List
import json

import numpy as np
import matplotlib.pyplot as plt

try:
    from scipy.stats import gaussian_kde
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False
    
def extract_paper_fp_from_review_fp(review_filepath):
    """Parse a review file path into its dataset components.

    Two path layouts below ``subset-3743-latest/`` are recognised and tried
    in this order:

    1. Legacy layout:
       ``<conference>/<split>/...<level>/<paper>_<reviewer>.txt``.
       The generating model is not encoded in this layout, so a fixed
       placeholder string is returned in its place.
    2. Current layout:
       ``<conference>/<split>/<paper>/<level>/[<model>/]<fileid>.txt``
       where ``<fileid>`` is either ``<prompt>:::<reviewer>`` or just
       ``<reviewer>``. When there is no ``<model>/`` component the
       generating model is reported as ``"human_review"``.

    ``<level>`` is one of ``level1`` .. ``level4`` or ``reviews``. Whenever
    the prompt is not encoded in the file name it is derived from the level:
    ``"HUMAN"`` for ``reviews`` and ``"<level>@NAV"`` otherwise.

    Args:
        review_filepath: path (string) of a review file.

    Returns:
        Tuple ``(conference, split, level, paper_number, reviewer_number,
        generating_model, prompt)``; every element is a string.

    Raises:
        ValueError: if the path matches neither layout. (Previously this
            surfaced as an AttributeError on ``None`` with no hint about
            which path was at fault.)
    """
    # NOTE(review): `[1-9]*` excludes the digit 0 and may match the empty
    # string, and the `.` before `txt` is unescaped. Both are kept exactly
    # as-is so that paths already parsed keep producing the same result —
    # confirm whether reviewer ids can contain a 0.
    legacy_pattern = r".*subset-3743-latest/(.*)/(train|test|dev)/.*(level[1-4]|reviews)/(.*)_([1-9]*).txt"
    match = re.search(legacy_pattern, review_filepath)

    if match is not None:
        conference, split, level, paper_number, reviewer_number = match.groups()
        generating_model = "OLD PARSER FUNCTION: GENRATING MODEL NOT PARSED"
        prompt = f"{level}@NAV" if level != "reviews" else "HUMAN"
        return conference, split, level, paper_number, reviewer_number, generating_model, prompt

    current_pattern = r".*subset-3743-latest/(.*)/(train|test|dev)/(.*)/(level[1-4]|reviews)/(.*).txt"
    match = re.search(current_pattern, review_filepath)
    if match is None:
        raise ValueError(
            f"Unrecognised review file path layout: {review_filepath!r}"
        )

    conference, split, paper_number, level, tail = match.groups()

    # `tail` is "<model>/<fileid>" for generated reviews, "<fileid>" otherwise.
    if '/' in tail:
        tail_parts = tail.split('/')
        generating_model = tail_parts[0]
        fileid = tail_parts[1]
    else:
        generating_model = "human_review"
        fileid = tail

    # `fileid` is "<prompt>:::<reviewer>" when the prompt id is encoded.
    if ":::" in fileid:
        fileid_parts = fileid.split(":::")
        prompt = fileid_parts[0]
        reviewer_number = fileid_parts[-1]
    else:
        reviewer_number = fileid
        prompt = f"{level}@NAV" if level != "reviews" else "HUMAN"

    return conference, split, level, paper_number, reviewer_number, generating_model, prompt


def _group_feature_values(data, feature, levels, prompts):
    """Collect the finite numeric values of `feature`, grouped as level -> prompt id -> list.

    Returns ``(grouped, n_unknown)`` where ``n_unknown`` counts values whose
    parsed level is not one of `levels` (those values are not plotted).
    """
    grouped = {lv: defaultdict(list) for lv in levels}
    n_unknown = 0

    for path, metrics in data.items():
        # skip entries that do not carry this feature at all
        if not isinstance(metrics, dict) or feature not in metrics:
            continue
        # only numeric values make sense for a distribution
        try:
            val = float(metrics[feature])
        except (TypeError, ValueError, OverflowError):
            continue
        # NaN / +-inf would poison min/max and the shared axis limits
        if not math.isfinite(val):
            continue

        _, _, lvl, _, _, _, pid = extract_paper_fp_from_review_fp(path)

        # keep only prompt ids containing one of the requested substrings;
        # `prompts=None` means "no filtering"
        if prompts is not None and not any(p in pid for p in prompts):
            continue

        # prompt ids containing "NAV_2" are always excluded
        if "NAV_2" in pid:
            continue

        if lvl in grouped:
            grouped[lvl][pid].append(val)
        else:
            n_unknown += 1

    return grouped, n_unknown


def plot_feature_distributions(
    data: Dict[str, Dict],
    features: Iterable[str],
    levels: List[str] = None,
    prompts: List[str] = None,
    output_dir: str = "plots",
    bins: int = 30,
    figsize_per_subplot: tuple = (6, 3),
    alpha: float = 0.5,
):
    """
    Plot the distribution of the given feature(s), grouped by level and prompt id.

    One PNG named ``<feature>_distribution_by_level.png`` is written to
    `output_dir` per feature. Each PNG stacks one subplot per level and draws
    one KDE density curve per prompt id inside it; all subplots of a feature
    share the same x-range.

    Args:
        data: mapping from review file path (string) -> metrics dict.
        features: iterable of feature names (strings) to plot.
        levels: list of levels (strings) in desired subplot order. Defaults to
                ['level1','level2','level3','level4','reviews'].
        prompts: list of substrings; a value is kept only if its prompt id
                contains at least one of them. ``None`` (the default) disables
                this filter; an empty list filters everything out.
        output_dir: directory to save PNG files (created if missing).
        bins: currently unused — the histogram layer is disabled and only KDE
                curves are drawn. Kept so existing callers keep working.
        figsize_per_subplot: width,height per subplot (so total fig size = (w, h*len(levels))).
        alpha: currently unused, for the same reason as `bins`.

    Behavior details:
        - Level and prompt id of every entry come from
          `extract_paper_fp_from_review_fp(path)`; any exception it raises
          for a path propagates to the caller.
        - Values that cannot be converted to a finite float are ignored.
        - Prompt ids containing "NAV_2" are always skipped.
        - Values whose level is not listed in `levels` are not plotted; their
          count is reported in a warning.
        - If a given level has no data for a feature, its subplot shows "no data".
        - A KDE curve is drawn for a prompt id only if scipy is installed and
          there are >=2 samples; a KDE that fails is reported and skipped.
        - Progress and warnings are printed to stdout.
    """

    if levels is None:
        levels = ["level1", "level2", "level3", "level4", "reviews"]

    os.makedirs(output_dir, exist_ok=True)

    if not _HAS_SCIPY:
        print("[warning] scipy is not available; no density curves can be drawn.")

    # color cycle for many prompt ids
    color_cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', None)

    for feature in features:
        # Grouping is rebuilt per feature so only one feature's values are held at a time.
        grouped, n_unknown = _group_feature_values(data, feature, levels, prompts)
        if n_unknown:
            print(f"[warning] feature '{feature}': {n_unknown} value(s) skipped because their level is not in {levels}.")

        # Prepare figure
        n_levels = len(levels)
        fig_w = figsize_per_subplot[0]
        fig_h = figsize_per_subplot[1] * n_levels
        fig, axes = plt.subplots(n_levels, 1, figsize=(fig_w, fig_h), squeeze=False)
        axes = axes.flatten()

        # global x range shared by all subplots (computed from all values)
        all_values = [v for lv in levels for arr in grouped[lv].values() for v in arr]

        # if no values at all for this feature -> warn and write a placeholder figure
        if not all_values:
            print(f"[warning] feature '{feature}' has no numeric data in provided dict. Creating empty placeholder plot.")
            axes[0].text(0.5, 0.5, f"No data for feature '{feature}'", ha="center", va="center",
                         transform=axes[0].transAxes)
            for ax in axes[1:]:
                ax.axis('off')
            fig.suptitle(feature)
            outpath = os.path.join(output_dir, f"{feature}_distribution_by_level.png")
            fig.tight_layout()
            fig.savefig(outpath, dpi=150)
            plt.close(fig)
            continue

        global_min = min(all_values)
        global_max = max(all_values)
        # pad the range: fixed +-0.5 when every value is identical, 5% of the span otherwise
        rng = global_max - global_min
        pad = 0.5 if rng == 0 else 0.05 * rng
        global_min -= pad
        global_max += pad

        # evaluation grid for the KDE curves; identical for every curve of this feature
        xs = np.linspace(global_min, global_max, 200)

        # Plot for each level
        for ax, lvl in zip(axes, levels):
            ax.set_title(f"{lvl} — {feature}")
            # if no prompts in this level
            if not grouped[lvl]:
                # axes coordinates keep the note centred whatever the x-range is
                ax.text(0.5, 0.5, "no data", ha="center", va="center", transform=ax.transAxes)
                ax.set_xlim(global_min, global_max)
                continue

            # one density curve per prompt id, largest groups first
            by_size = sorted(grouped[lvl].items(), key=lambda item: -len(item[1]))
            for j, (pid, values) in enumerate(by_size):
                # a KDE needs scipy and at least two samples
                if not _HAS_SCIPY or len(values) < 2:
                    continue
                # choose color from cycle if available (index counts skipped groups too)
                color = color_cycle[j % len(color_cycle)] if color_cycle else None
                try:
                    ys = gaussian_kde(values)(xs)
                except Exception as exc:
                    # best effort: one bad group must not abort the whole figure
                    print(f"[warning] KDE failed for feature '{feature}', {lvl}/{pid} (n={len(values)}): {exc}")
                    continue
                ax.plot(xs, ys, linewidth=1.2, color=color, label=f"{pid} (n={len(values)})")

            # only build a legend when at least one labelled curve exists
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend(fontsize="small", loc="upper right")
            else:
                ax.text(0.5, 0.5, "no density estimate", ha="center", va="center", transform=ax.transAxes)
            ax.set_xlim(global_min, global_max)
            ax.set_ylabel("Density")

        fig.suptitle(f"Distribution of '{feature}' by level and prompt-id", fontsize=14)
        fig.tight_layout(rect=[0, 0, 1, 0.97])

        outpath = os.path.join(output_dir, f"{feature}_distribution_by_level.png")
        fig.savefig(outpath, dpi=150)
        plt.close(fig)
        print(f"Saved: {outpath}")

    print("All done.")


# -------------------------
# Example usage:
# -------------------------
# Suppose `mydict` is your dictionary (path->metrics). To plot 'WordCount' and 'AvgWordLength':
#
# plot_feature_distributions(mydict, features=['WordCount', 'AvgWordLength'], output_dir='plots', bins=25)
#
# The function will create `plots/WordCount_distribution_by_level.png` and
# `plots/AvgWordLength_distribution_by_level.png`.
#
# Notes / possible tweaks:
# - If you want the subplots side by side in a single row instead of stacked vertically, change the `plt.subplots` call accordingly.
# - If you want separate files per level instead of subplots, that's easy to change.
# - If you prefer colored outlines instead of filled histograms, set histtype='step' and alpha appropriately.

# Driver section: load the extracted features, keep the reviews of interest
# and plot them.
# NOTE(review): this runs at import time; it is left at module level so that
# running/importing the file behaves as before.
with open("/ai-involvement-in-peer-reviews/linguistic_features/extracted_features/subset-3743-25-10-25/linguistic_features_all_generations.json", "r", encoding="utf-8") as fin:
    data = json.load(fin)

models_of_interest = ["gpt-5"]

# Keep only entries generated by a model of interest, plus the human reviews.
# Element 5 of the parsed path tuple is the generating model.
selected_models = set(models_of_interest) | {"human_review"}
final_dict = {
    key: val
    for key, val in data.items()
    if extract_paper_fp_from_review_fp(key)[5] in selected_models
}

# Features and prompt ids to plot are hand-picked. The feature/prompt lists
# that used to be accumulated while filtering were overwritten by these
# assignments before ever being read, so that accumulation was dropped.
features = ['StopWordPercentage', 'TotalSyllables', 'HapaxLegomenonRate', 'RTTR', 'BigramUniqueness']
prompts = ['level1@NAV', 'level2@NAV', 'level3@NAV', 'level4@NAV', 'HUMAN']

plot_feature_distributions(
    final_dict,
    features=features,
    prompts=prompts,
    output_dir=f"plots/{':'.join(models_of_interest)}",
    bins=50
)