import argparse
import json
import numpy as np
import pathlib
import math


def get_sensitivity(fin):
    """
    Yield the first ``sensitivity`` entry of every record in a JSON-lines file.

    Each non-blank line of ``fin`` must be a JSON object with a
    ``'sensitivity'`` list; only element ``[0]`` of that list is yielded.
    Blank lines (e.g. trailing empty lines) are skipped rather than failing.

    Args:
        fin: Path (str or path-like) of the JSONL file to read.

    Yields:
        The first element of each record's ``'sensitivity'`` list.

    Raises:
        OSError: If ``fin`` cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError, IndexError, TypeError: If a record has no usable
            ``'sensitivity'`` list.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale.
    with open(fin, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            yield json.loads(line)['sensitivity'][0]


def _collect_log_histograms(args, bins):
    """Return one log-count histogram (list of floats) per readable trial file.

    Walks every (language, trial) directory implied by ``args`` and, for each
    ``eval/sensitivity-<dataset>.jsonl`` that exists and yields at least one
    value, bins the values with ``bins`` and log-transforms the counts.
    Missing files are skipped silently; unreadable/malformed ones are reported
    and skipped so a single bad trial does not abort the whole figure.
    """
    histograms = []
    # A non-positive --num-languages still processes the single base language.
    for language_num in range(max(args.num_languages, 1)):
        # Construct path for this language (handling 'language-1' vs 'language')
        if args.num_languages > 1:
            # e.g., random-language-1
            lang_dir = f'{args.language}-{language_num + 1}'
        else:
            lang_dir = f'{args.language}'

        for trial in range(args.num_trials):
            log_file = (
                args.base_dir / 'models' / lang_dir / 'transformer'
                / f'{args.string_length}' / f'{trial + 1}' / 'eval'
                / f'sensitivity-{args.dataset}.jsonl'
            )
            try:
                if not log_file.exists():
                    continue
                # Raw sensitivities for ONE trial.
                trial_sensitivities = list(get_sensitivity(log_file))
                if not trial_sensitivities:
                    continue
                # np.histogram ignores values outside [bins[0], bins[-1]].
                counts, _ = np.histogram(trial_sensitivities, bins=bins, density=False)
            except (OSError, ValueError, KeyError, IndexError, TypeError) as err:
                # Best effort, but never silent: report which trial was dropped.
                print(f"WARNING: skipping {log_file}: {err!r}")
                continue

            # NOTE: empty bins are plotted as 0.0, which coincides with log(1);
            # with ymin=0 neither case draws a visible bar.
            histograms.append([math.log(c) if c > 0 else 0.0 for c in counts])
    return histograms


def _cap_duplicates(histograms, max_copies):
    """Return ``histograms`` in order, keeping at most ``max_copies`` of each identical one."""
    kept = []
    seen = {}
    for hist in histograms:
        # Lists are unhashable; tuples can be dictionary keys.
        key = tuple(hist)
        count = seen.get(key, 0)
        if count < max_copies:
            kept.append(hist)
            seen[key] = count + 1
    return kept


def _write_tex(tex_output, histograms, bins, xtick_str):
    """Write a PGFPlots ``tikzpicture`` overlaying every histogram as translucent bars."""
    # Exact (positional, 6-decimal) left bin edges. Truncating them with int()
    # collapsed adjacent bins onto the same/wrong x whenever the bin width was
    # not a whole number; integer edges still print exactly as before ("0", "1").
    edge_labels = [
        np.format_float_positional(edge, precision=6, trim='-')
        for edge in bins[:-1]
    ]

    with tex_output.open('w', encoding='utf-8') as fout:
        fout.write(f"""
\\begin{{tikzpicture}}
    \\begin{{axis}}[
        ybar,
        axis lines=left,  
        axis line style={{-}},  
        enlargelimits=false,
        bar shift=0pt,
        ymin=0, 
        xmin=-1.5,
        xtick={{{xtick_str}}},
        xlabel={{Sensitivity}},
        ylabel style={{align=center}},
        ylabel={{Log Frequency}},
        xmajorgrids=false,
        bar width=3pt,
        colormap={{muted}}{{color(0cm)=(blue!60!gray); color(1cm)=(red!60!gray)}}
    ]
""")

        # One translucent bar series per (capped) histogram.
        for hist_data in histograms:
            table_str = "\n            ".join(
                f"{edge} {val}" for edge, val in zip(edge_labels, hist_data)
            )
            fout.write(f"""
        \\addplot+[
            fill=blue!60!gray,
            draw=none,
            opacity=0.25, 
            forget plot
        ] table [x index=0, y index=1] {{
            {table_str}
        }};
""")

        fout.write("""
    \\end{axis}
\\end{tikzpicture}
""")


def main(argv=None):
    """Build a PGFPlots figure of per-trial log sensitivity histograms.

    Reads ``<base-dir>/models/<language>[-k]/transformer/<string-length>/<trial>/
    eval/sensitivity-<dataset>.jsonl`` for every language/trial, histograms the
    values into 100 equal bins over ``[0, string_length]``, caps identical
    histograms at 10 copies, and writes the overlay to ``--tex-output``.

    Args:
        argv: Optional argument list; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--language', type=pathlib.Path, required=True)
    parser.add_argument('--dataset', default='test',
                        help="Dataset name; reads eval/sensitivity-<dataset>.jsonl")
    parser.add_argument('--num-trials', type=int, required=True)
    parser.add_argument('--string-length', type=int, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--tex-output', type=pathlib.Path, required=True)
    args = parser.parse_args(argv)
    if args.string_length <= 0:
        # np.histogram requires strictly increasing bin edges.
        parser.error('--string-length must be a positive integer')

    num_bins = 100
    max_duplicates = 10
    max_sensitivity = args.string_length
    bins = np.linspace(0, max_sensitivity, num_bins + 1)

    args.tex_output.parent.mkdir(parents=True, exist_ok=True)

    print(f"Processing {args.num_languages} languages x {args.num_trials} trials...")
    print(f"Reading dataset file: sensitivity-{args.dataset}.jsonl")

    all_individual_histograms = _collect_log_histograms(args, bins)
    if not all_individual_histograms:
        print(
            "ERROR: No data found. Checked paths like: "
            f".../models/{{lang}}/transformer/{{len}}/{{trial}}/eval/sensitivity-{args.dataset}.jsonl")
        return

    print(f"Total collected histograms: {len(all_individual_histograms)}")

    # Keep ALL unique histograms, but cap identical duplicates.
    final_histograms_to_plot = _cap_duplicates(all_individual_histograms, max_duplicates)
    print(f"Plotting {len(final_histograms_to_plot)} histograms "
          f"(duplicates capped at {max_duplicates}).")

    # Tick labels every 20 units from 0 past the maximum sensitivity.
    tick_step = 20
    upper_bound = math.ceil(max_sensitivity)
    xtick_str = ",".join(map(str, range(0, upper_bound + tick_step, tick_step)))

    _write_tex(args.tex_output, final_histograms_to_plot, bins, xtick_str)

    print(f"Successfully wrote LaTeX to {args.tex_output}")


# Entry point: run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()