import argparse
import json
import math
import pathlib
import sys

import numpy as np


def get_sensitivity_mean(fin):
    """
    Read a JSON-lines sensitivity log and return the MEAN sensitivity (scalar) for that trial.

    Each non-blank line must be a JSON object whose ``'sensitivity'`` entry is
    indexable; only its first element (``['sensitivity'][0]``) is used.
    Blank lines (e.g. a trailing empty line) are ignored.

    Args:
        fin: Path (or path-like) to the ``.jsonl`` log file.

    Returns:
        The arithmetic mean of the collected values (a NumPy float), or ``0.0``
        when the file contains no records. NOTE(review): callers cannot tell
        this ``0.0`` apart from a genuine mean of zero -- confirm that is intended.

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: If a line is not valid JSON (``json.JSONDecodeError``) or
            the file is not valid UTF-8.
        KeyError, IndexError, TypeError: If a record has no usable
            ``'sensitivity'`` entry.
    """
    values = []
    # JSON text is UTF-8 (RFC 8259); do not depend on the platform's locale encoding.
    with open(fin, encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                # json.loads('') raises, which would discard the whole trial.
                continue
            values.append(json.loads(line)['sensitivity'][0])

    if not values:
        return 0.0
    return np.mean(values)


def _collect_trial_means(base_dir, lang_dir, string_length, dataset, num_trials):
    """Return the mean sensitivity of every readable trial log for one language.

    Trials are numbered from 1 and read from
    ``<base_dir>/models/<lang_dir>/transformer/<string_length>/<trial>/eval/sensitivity-<dataset>.jsonl``.
    Missing logs are skipped silently. Logs that exist but cannot be read or
    parsed are skipped with a warning on stderr rather than being hidden.
    """
    means = []
    for trial in range(1, num_trials + 1):
        log_file = (base_dir / 'models' / lang_dir / 'transformer'
                    / f'{string_length}' / f'{trial}' / 'eval'
                    / f'sensitivity-{dataset}.jsonl')
        try:
            if log_file.exists():
                # Single scalar average for this trial.
                means.append(get_sensitivity_mean(log_file))
        except (OSError, ValueError, KeyError, IndexError, TypeError) as err:
            # Best-effort: one corrupt trial must not abort the whole figure,
            # but it must not vanish without a trace either.
            print(f"Warning: skipping {log_file}: {err!r}", file=sys.stderr)
    return means


def _write_tex(path, histograms, bins, max_sensitivity, tick_step=20):
    """Write *histograms* as overlaid pgfplots ``ybar`` series inside a TikZ picture at *path*.

    Each table row is ``<bin left edge> <count>``. With ``bar shift=0pt`` the
    bars are therefore drawn at each bin's left edge, not at its centre.
    """
    upper_bound = math.ceil(max_sensitivity)
    xtick_str = ",".join(map(str, range(0, upper_bound + tick_step, tick_step)))

    # NOTE(review): ymax=6 / restrict y to domain*=0:8 / xmin=-1.5 are fixed;
    # with many trials per language, bars taller than 6 get clipped -- confirm intended.
    with path.open('w') as fout:
        fout.write(f"""
\\begin{{tikzpicture}}
    \\begin{{axis}}[
        ybar,
        axis lines=left,  
        axis line style={{-}},  
        enlargelimits=false,
        bar shift=0pt,
        ymin=0, 
        ymax=6,
        restrict y to domain*=0:8,
        xmin=-1.5,
        xtick={{{xtick_str}}},
        xlabel={{Average Sensitivity}},
        ylabel style={{align=center}},
        ylabel={{Frequency}},
        xmajorgrids=false,
        bar width=3pt,
        colormap={{muted}}{{color(0cm)=(blue!60!gray); color(1cm)=(red!60!gray)}}
    ]
""")
        for hist_data in histograms:
            # bins has one more edge than hist_data has counts; zip pairs each
            # count with its bin's left edge and stops at the last count.
            table_str = "\n            ".join(
                f"{bin_left} {val}" for bin_left, val in zip(bins, hist_data)
            )

            fout.write(f"""
        \\addplot+[
            fill=blue!60!gray,
            draw=none,
            opacity=0.5, 
            forget plot
        ] table [x index=0, y index=1] {{
            {table_str}
        }};
""")

        fout.write("""
    \\end{axis}
\\end{tikzpicture}
""")


def main():
    """Histogram per-trial mean sensitivities for each language and write a TikZ/pgfplots figure.

    For each language (directory ``<language>-1`` .. ``<language>-N`` when
    ``--num-languages`` > 1, otherwise ``<language>``), the mean sensitivity of
    every available trial is collected, binned into 1000 equal-width bins over
    ``[0, --string-length]``, and drawn as one semi-transparent bar series in
    ``--tex-output``. Prints "No data found." and writes nothing if no
    language yields any trial data.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base-dir', type=pathlib.Path, required=True)
    parser.add_argument('--language', type=pathlib.Path, required=True)
    parser.add_argument('--dataset', default='test')
    parser.add_argument('--num-trials', type=int, required=True)
    parser.add_argument('--string-length', type=int, required=True)
    parser.add_argument('--num-languages', type=int, required=True)
    parser.add_argument('--tex-output', type=pathlib.Path, required=True)
    args = parser.parse_args()

    args.tex_output.parent.mkdir(parents=True, exist_ok=True)

    num_bins = 1000
    # Bins span [0, string_length]; presumably a mean sensitivity cannot exceed
    # the string length -- out-of-range values are reported below.
    max_sensitivity = args.string_length
    bins = np.linspace(0, max_sensitivity, num_bins + 1)

    # One array of per-bin trial counts for each language that had data.
    all_language_histograms = []

    print(f"Processing {args.num_languages} languages...")

    # A non-positive --num-languages still processes one (unsuffixed) language.
    for language_num in range(max(args.num_languages, 1)):
        if args.num_languages > 1:
            lang_dir = f'{args.language}-{language_num + 1}'
        else:
            lang_dir = f'{args.language}'

        current_lang_means = _collect_trial_means(
            args.base_dir, lang_dir, args.string_length, args.dataset, args.num_trials
        )
        if not current_lang_means:
            continue

        counts, _ = np.histogram(current_lang_means, bins=bins, density=False)
        # np.histogram does not count values outside the bin edges; report any
        # such losses instead of dropping them silently.
        dropped = len(current_lang_means) - int(counts.sum())
        if dropped:
            print(f"Warning: {dropped} trial mean(s) for {lang_dir} fell outside "
                  f"[0, {max_sensitivity}] and were not plotted.", file=sys.stderr)
        all_language_histograms.append(counts)

    if not all_language_histograms:
        print("No data found.")
        return

    print(f"Generated {len(all_language_histograms)} language histograms.")

    _write_tex(args.tex_output, all_language_histograms, bins, max_sensitivity)


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not on import.
    main()