import os
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import argparse

def flatten(list_of_lists):
    """Concatenate the inner iterables of *list_of_lists* into a single new list.

    Exactly one level of nesting is removed; elements of the inner iterables
    are copied in order and are not flattened further.
    """
    flat = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat

def plot_distribution(train_values, val_values, title, ylabel, output_path, bins=100):
    """Overlay normalized histograms of train and val values and save the figure.

    Both series are binned with one shared set of edges computed from the
    combined data. With two independent ``hist(..., bins=N)`` calls, each
    series would get edges spanning only its own range, so the overlaid bars
    would not be comparable.

    Args:
        train_values: Sequence of numbers plotted as the "train" series.
        val_values: Sequence of numbers plotted as the "val" series.
        title: Figure title.
        ylabel: Y-axis label.
        output_path: Destination file; matplotlib infers the format from it.
        bins: Number of equal-width bins spanning the combined finite range.
    """
    train = np.asarray(train_values, dtype=float)
    val = np.asarray(val_values, dtype=float)
    # NaN/inf would make automatic range detection raise; drop them.
    train = train[np.isfinite(train)]
    val = val[np.isfinite(val)]
    edges = np.histogram_bin_edges(np.concatenate([train, val]), bins=bins)

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        for values, label, color in ((train, "train", "blue"), (val, "val", "orange")):
            # density=True on an empty series divides by zero; skip it instead.
            if values.size:
                ax.hist(values, bins=edges, alpha=0.5, label=label, color=color, density=True)
        ax.set_title(title)
        ax.set_xlabel("Value")
        ax.set_ylabel(ylabel)
        # Avoid matplotlib's "no artists with labels" warning when both are empty.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        fig.tight_layout()
        fig.savefig(output_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

def process_model_dataset(model_dir, dataset_dir, charts_dir, model_short_name):
    """Plot train-vs-val distributions of the "similarity" and "probability" metrics.

    Reads ``train_metrics.json`` and ``val_metrics.json`` from *dataset_dir*.
    Each metric value is flattened one level with ``flatten``; a missing key
    yields an empty series. One PDF per metric is written to
    ``<charts_dir>/<model_dir.name>/<dataset_dir.name>/``, and that directory
    is created if needed. A dataset missing either JSON file is skipped with
    a printed message.

    Args:
        model_dir: Model directory (Path or str); only its name is used.
        dataset_dir: Dataset directory (Path or str) holding the JSON files.
        charts_dir: Root output directory for charts.
        model_short_name: Model name shown in the chart titles.
    """
    model_dir = Path(model_dir)
    dataset_dir = Path(dataset_dir)
    train_file = dataset_dir / "train_metrics.json"
    val_file = dataset_dir / "val_metrics.json"
    if not (train_file.exists() and val_file.exists()):
        print(f"Skipping {dataset_dir}: train_metrics.json or val_metrics.json not found")
        return

    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(train_file, "r", encoding="utf-8") as f:
        train_metrics = json.load(f)
    with open(val_file, "r", encoding="utf-8") as f:
        val_metrics = json.load(f)

    # (JSON key, label used in the chart title); the key also names the output file.
    metrics = (("similarity", "Similarity"), ("probability", "Probability"))
    # Flatten everything before touching the filesystem, so malformed metric
    # data fails before any output directory is created.
    series = {
        key: (flatten(train_metrics.get(key, [])), flatten(val_metrics.get(key, [])))
        for key, _ in metrics
    }

    charts_out = Path(charts_dir) / model_dir.name / dataset_dir.name
    charts_out.mkdir(parents=True, exist_ok=True)

    for key, label in metrics:
        train_values, val_values = series[key]
        plot_distribution(
            train_values,
            val_values,
            f"{model_short_name}; {dataset_dir.name}; {label} distribution",
            "Density",
            charts_out / f"{key}_distribution.pdf",
        )

def main():
    """Parse CLI arguments and render distribution charts for one model.

    The model's data is looked up under
    ``<raw_values_dir>/<model_name with "/" replaced by "_">``. Every
    subdirectory found there is passed to ``process_model_dataset``; plain
    files are ignored. A missing model directory only prints a message.
    """
    parser = argparse.ArgumentParser()
    for flag in ("--raw_values_dir", "--charts_dir", "--model_name"):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()

    model_dir = Path(args.raw_values_dir) / args.model_name.replace("/", "_")
    if model_dir.is_dir():
        # Titles use only the last path component of e.g. "org/model".
        short_name = Path(args.model_name).name
        for entry in model_dir.iterdir():
            if entry.is_dir():
                process_model_dataset(model_dir, entry, args.charts_dir, short_name)
    else:
        print(f"Model directory not found: {model_dir}")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
