import os
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
import pandas as pd
import json

# Default dataset root: used as the default `data_dir` of histograms_string_length,
# which looks for one sub-directory per language underneath it.
train_data_dir = "./data/flare_subsampled"
# Root of the ICL result tree. NOTE(review): the functions in this file spell the
# same "./results/ICL" path out in their own defaults rather than reading this
# constant, so changing it here has no effect on them.
results_data_dir = "./results/ICL"
def compute_length_stats(lengths):
    """Summarise a collection of lengths.

    Args:
        lengths: Sized collection of numbers accepted by the NumPy reductions.

    Returns:
        A dict with the keys ``"mean"``, ``"std"``, ``"min"``, ``"max"`` (each
        the result of the corresponding ``np.*`` reduction) and ``"count"``
        (``len(lengths)``), in that order.
    """
    # Table-driven so the key order and the reduction applied stay side by side.
    reducers = (
        ("mean", np.mean),
        ("std", np.std),
        ("min", np.min),
        ("max", np.max),
    )
    summary = {label: reduce_fn(lengths) for label, reduce_fn in reducers}
    summary["count"] = len(lengths)
    return summary

def histograms_string_length(data_dir=train_data_dir, output_dir="figures"):
    """Plot a token-count histogram per language and return summary statistics.

    Every entry ``<lang>`` of ``data_dir`` for which ``<data_dir>/<lang>/data.test``
    is a regular file is treated as one language.  Each non-blank line of that
    file is split on whitespace and the number of resulting tokens is recorded.
    Languages whose file contains no non-blank line produce neither a figure
    nor a statistics entry.

    Args:
        data_dir: Directory holding one sub-directory per language.
        output_dir: Directory the PNG files are written to.  It is created if
            it does not exist yet.

    Returns:
        A list with one dict per language (sorted by language name) holding
        the keys returned by ``compute_length_stats`` plus ``"language"``.
    """
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)

    language_lengths = defaultdict(list)

    # Collect the token count of every non-blank line, grouped by language.
    # Sorted so the order of the returned statistics is reproducible.
    for lang in sorted(os.listdir(data_dir)):
        lang_path = os.path.join(data_dir, lang, "data.test")
        if not os.path.isfile(lang_path):
            continue

        with open(lang_path, "r") as f:
            for line in f:
                tokens = line.split()
                if tokens:  # blank / whitespace-only lines carry no sequence
                    language_lengths[lang].append(len(tokens))

    summaries = []
    for lang, lengths in language_lengths.items():
        # Histogram
        plt.figure()
        plt.hist(lengths, bins=30, edgecolor='black')
        plt.title(f"Token Length Histogram - {lang}")
        plt.xlabel("Token count")
        plt.ylabel("Frequency")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{lang}_test_string_length_histogram.png"))
        plt.close()

        # Stats
        stats = compute_length_stats(lengths)
        stats["language"] = lang
        summaries.append(stats)

    return summaries

def _iter_config_result_files(results_data_dir):
    """Yield one tuple per ``*.json`` file found four directory levels deep.

    The tree is expected to be laid out as
    ``<results_data_dir>/<model>/<prompt>/<encoding>/<language>/<file>.json``.
    Non-directory entries at the four upper levels and non-JSON files at the
    lowest level are ignored.

    Yields:
        ``(model_name, prompt_type, encoding_strategy, language, filename,
        file_path)`` tuples, in sorted order at every level.
    """
    def subdirs(parent):
        # (name, path) of every directory directly below `parent`.
        for name in sorted(os.listdir(parent)):
            path = os.path.join(parent, name)
            if os.path.isdir(path):
                yield name, path

    for model_name, model_path in subdirs(results_data_dir):
        for prompt_type, prompt_path in subdirs(model_path):
            for encoding_strategy, encoding_path in subdirs(prompt_path):
                for language, language_path in subdirs(encoding_path):
                    for filename in sorted(os.listdir(language_path)):
                        if filename.endswith(".json"):
                            yield (
                                model_name,
                                prompt_type,
                                encoding_strategy,
                                language,
                                filename,
                                os.path.join(language_path, filename),
                            )


def generate_misclassification_histograms_by_config(
    results_data_dir="./results/ICL",
    output_dir="figures/misclassifications_by_config"
):
    """Save one misclassified-length histogram per result file.

    For every JSON file under
    ``<results_data_dir>/<model>/<prompt>/<encoding>/<language>/`` the lengths
    of the misclassified sequences are obtained from
    ``extract_misclassified_lengths`` so that this function applies the same
    definition of "misclassified" as the per-language plots in this file.
    Files without any misclassified sequence produce no figure.

    Args:
        results_data_dir: Root of the result tree.
        output_dir: Directory the PNG files are written to; created if missing.

    Returns:
        None.  Figures are written to ``output_dir`` and each saved path is
        printed.
    """
    os.makedirs(output_dir, exist_ok=True)

    for (
        model_name,
        prompt_type,
        encoding_strategy,
        language,
        filename,
        file_path,
    ) in _iter_config_result_files(results_data_dir):
        # Best effort per file: a single unreadable or malformed result must
        # not abort the whole sweep, so the failure is reported and skipped.
        try:
            misclassified_lengths = extract_misclassified_lengths(file_path)
            if not misclassified_lengths:
                continue

            # Build output file name
            base_name = os.path.splitext(filename)[0]
            safe_file_name = f"{model_name}_{prompt_type}_{encoding_strategy}_{language}_{base_name}.png"
            save_path = os.path.join(output_dir, safe_file_name)

            # Plot histogram
            plt.figure()
            plt.hist(misclassified_lengths, bins=30, edgecolor='black')
            plt.title(f"Misclassified Lengths\n{model_name}, {prompt_type}, {encoding_strategy}, {language}")
            plt.xlabel("Sequence Length")
            plt.ylabel("Misclassification Count")
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(save_path)
            plt.close()

            print(f"Saved: {save_path}")

        except Exception as e:
            print(f"Error reading {file_path}: {e}")

def collect_best_configs(results_dir="./results/ICL"):
    """Find the most accurate result file for every (model, language) pair.

    The tree is walked as
    ``<results_dir>/<model>/<prompt>/<encoding>/<language>/<file>.json``.
    Each JSON file's top-level ``"accuracy"`` value is compared against the
    best one seen so far for its ``(model, language)`` key; prompt and
    encoding are not part of the key.  Files without an ``"accuracy"`` field
    are treated as accuracy ``-1`` and therefore never selected.  Files that
    cannot be read, parsed or compared are reported and skipped.

    Args:
        results_dir: Root of the result tree.

    Returns:
        A ``defaultdict`` mapping ``(model, language)`` to
        ``{"accuracy": <best accuracy>, "file_path": <path of that file>}``.
        Only pairs for which a file was actually selected are present.  On
        equal accuracy the file visited first (sorted order) is kept.
    """
    best_config = defaultdict(lambda: {"accuracy": -1, "file_path": ""})

    def subdirs(parent):
        # (name, path) of every directory directly below `parent`; sorted so
        # that ties between equally accurate files resolve reproducibly.
        for name in sorted(os.listdir(parent)):
            path = os.path.join(parent, name)
            if os.path.isdir(path):
                yield name, path

    for model, model_path in subdirs(results_dir):
        for _prompt, prompt_path in subdirs(model_path):
            for _encoding, encoding_path in subdirs(prompt_path):
                for language, lang_path in subdirs(encoding_path):
                    key = (model, language)

                    for fname in sorted(os.listdir(lang_path)):
                        if not fname.endswith(".json"):
                            continue
                        fpath = os.path.join(lang_path, fname)

                        try:
                            with open(fpath, "r", encoding="utf-8") as f:
                                data = json.load(f)
                            acc = data.get("accuracy", -1)

                            # Look the key up with .get(): indexing the
                            # defaultdict would insert a placeholder entry with
                            # an empty file_path even when nothing is selected.
                            current = best_config.get(key)
                            best_so_far = current["accuracy"] if current else -1

                            if acc > best_so_far:
                                best_config[key] = {
                                    "accuracy": acc,
                                    "file_path": fpath
                                }
                        except (OSError, ValueError, TypeError, AttributeError) as e:
                            # OSError: unreadable file; ValueError: bad JSON or
                            # encoding; AttributeError: top level is not a dict;
                            # TypeError: accuracy is not comparable to a number.
                            print(f"Failed reading {fpath}: {e}")

    return best_config

def extract_misclassified_lengths(file_path):
    """Return the token lengths of the misclassified sequences in a result file.

    The file must be a JSON object; its ``"raw_results"`` list (empty if the
    key is missing) is scanned entry by entry:

    * entries whose ``"sequence"`` is not a string are ignored;
    * if the entry has a ``"correct"`` field, it alone decides: the entry
      counts as misclassified only when that field is exactly ``False``;
    * otherwise ``"true_label"`` is compared with ``"model_output"`` (or, when
      that is absent/``None``, ``"predicted_label"``) after converting both to
      ``int``.  Entries missing either value, or whose values cannot be
      converted, are skipped (the latter with a printed notice).

    Args:
        file_path: Path of the JSON result file.

    Returns:
        A list with ``len(sequence.split())`` for every misclassified entry,
        in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    misclassified_lengths = []

    for entry in data.get("raw_results", []):
        seq = entry.get("sequence")
        if not isinstance(seq, str):
            continue

        # First try to use the 'correct' field if available
        if "correct" in entry:
            if entry["correct"] is False:
                misclassified_lengths.append(len(seq.split()))
            continue

        true = entry.get("true_label")
        # Explicit None check instead of `a or b`: a prediction of 0 is a
        # valid label and must not be discarded for being falsy.
        pred = entry.get("model_output")
        if pred is None:
            pred = entry.get("predicted_label")

        if true is None or pred is None:
            continue  # Skip incomplete entries

        try:
            is_error = int(true) != int(pred)
        except (TypeError, ValueError, OverflowError) as e:
            print(f"Bad entry skipped in {file_path}: {e}")
            continue

        if is_error:
            misclassified_lengths.append(len(seq.split()))

    return misclassified_lengths


def plot_misclass_histograms(
    best_config,
    output_dir="./figures/misclassifications_by_language_best_config",
    models=("chatgpt", "deepseek-chat"),
    bin_width=20,
):
    """Plot, per language, overlaid histograms of misclassified sequence lengths.

    For every language appearing in ``best_config`` the result file selected
    for each model in ``models`` is passed to ``extract_misclassified_lengths``
    and the resulting lengths are drawn as semi-transparent histograms sharing
    the same bins.  Languages for which none of the models has a usable entry
    are skipped.

    Args:
        best_config: Mapping ``(model, language) -> {"file_path": ..., ...}``.
        output_dir: Directory the PNG files are written to; created if missing.
        models: Model names (first component of the ``best_config`` keys) to
            include, in plotting order.
        bin_width: Width of the histogram bins; must be positive.

    Returns:
        None.  One ``<language>_misclassification_hist.png`` is written per
        plotted language and its path is printed.
    """
    os.makedirs(output_dir, exist_ok=True)

    languages = {lang for (_, lang) in best_config}
    for language in sorted(languages):
        lengths_by_model = {}

        for model in models:
            # .get() never inserts into a defaultdict, unlike indexing.
            selected = best_config.get((model, language))
            file_path = selected["file_path"] if selected else ""
            if not file_path:
                continue  # no file was ever chosen for this pair
            lengths_by_model[model] = extract_misclassified_lengths(file_path)

        if not lengths_by_model:
            continue

        # Plot
        # default=0 covers the case where no selected file has any error.
        longest = max(
            (n for lengths in lengths_by_model.values() for n in lengths),
            default=0,
        )
        # At least one full bin, so the edge array always has two entries.
        bins = np.arange(0, max(longest, 1) + bin_width, bin_width)

        plt.figure()
        for model, lengths in lengths_by_model.items():
            plt.hist(lengths, bins=bins, alpha=0.6, label=f"{model} ({len(lengths)} errors)", edgecolor="black")

        plt.title(f"Misclassified Sequence Lengths - {language}")
        plt.xlabel("Sequence Length")
        plt.ylabel("Count")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        out_path = os.path.join(output_dir, f"{language}_misclassification_hist.png")
        plt.savefig(out_path)
        plt.close()
        print(f"Saved: {out_path}")

def count_invalid_predictions(root_dir="./results/ICL"):
    """Count predictions equal to ``-1`` in every JSON file below ``root_dir``.

    Each ``*.json`` file found by ``os.walk`` is loaded and its
    ``"raw_results"`` entries are inspected: the entry's ``"model_output"``
    (or ``"predicted_label"`` when that key is absent) is converted to a
    string, stripped, and compared with ``"-1"``.  Files that fail to load or
    process are reported and left out of the scanned-file count.  A summary
    is printed at the end.

    Args:
        root_dir: Directory to scan recursively.

    Returns:
        A tuple ``(total_invalid, invalid_per_file)`` where
        ``invalid_per_file`` maps the path of each file containing at least
        one invalid prediction to its count.
    """
    invalid_per_file = {}
    total_files = 0
    total_invalid = 0

    for folder, _subfolders, names in os.walk(root_dir):
        for name in names:
            if name.endswith(".json"):
                file_path = os.path.join(folder, name)
                try:
                    with open(file_path, "r") as handle:
                        payload = json.load(handle)
                    n_invalid = sum(
                        1
                        for item in payload.get("raw_results", [])
                        if str(
                            item.get("model_output", item.get("predicted_label"))
                        ).strip() == "-1"
                    )
                    if n_invalid > 0:
                        invalid_per_file[file_path] = n_invalid
                        total_invalid += n_invalid
                    # Only files processed without error count as scanned.
                    total_files += 1
                except Exception as e:
                    print(f"Error processing {file_path}: {e}")

    print(f"\nScanned {total_files} JSON files.")
    print(f"Total invalid predictions (-1): {total_invalid}\n")
    if invalid_per_file:
        print("Breakdown per file:")
        for path, cnt in invalid_per_file.items():
            print(f"  {path} → {cnt} invalid entries")

    return total_invalid, invalid_per_file

def main():
    """Run the analysis steps that are currently enabled.

    Only the per-language token-length histograms are generated; the
    statistics table export and the misclassification analyses are kept
    below as disabled (commented-out) steps.
    """
    # histograms_string_length saves its PNGs under "figures/" and savefig
    # does not create missing directories, so make sure it exists first.
    os.makedirs("figures", exist_ok=True)

    # `stats` feeds the (currently disabled) table export below.
    stats = histograms_string_length()
    # df = pd.DataFrame(stats).sort_values(by="mean", ascending=False)

    # # Round numeric columns
    # df[["mean", "std", "min", "max"]] = df[["mean", "std", "min", "max"]].round(2)

    # # Reorder columns to put 'language' first
    # df = df[["language", "mean", "std", "min", "max", "count"]]

    # # Print to console
    # print(df.to_string(index=False))

    # # Save LaTeX table
    # caption = "This table reports summary statistics (mean, standard deviation, min, max, count) of the symbol lengths of training examples for each formal language in the subsampled FlaRe dataset."
    # with open("./tables/train_length_statistics.txt", "w") as f:
    #     f.write(df.style.to_latex(
    #         caption=caption,
    #         label="tab:token_lengths",
    #         hrules=True
    #     ))
    #
    # # Generate misclassification histograms
    # #generate_misclassification_histograms_by_config()
    # best_config = collect_best_configs()
    # plot_misclass_histograms(best_config)
    # count_invalid_predictions()

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
