import argparse
import glob
import json
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from causallearn.utils.cit import CIT
from tqdm import tqdm

from mi_estimators.mi_hsic import hsic_estimate
from models import training
from tests.mutual_information_estimators import estimate_mi_binning


# Number of consecutive points averaged for the smoothed overlay drawn on each
# per-attribute subplot; the overlay is only drawn when a series has at least
# this many points.
RUNNING_AVG_WINDOW = 20


def parse_args():
    """Parse the command line of the plotting script.

    Returns:
        argparse.Namespace with:
          * ``log_folder`` (str): run directory to plot; the empty string
            (the default) asks the script to pick the newest run itself.
          * ``only_save`` (bool): when set, write the figure without
            opening an interactive window.
    """
    parser = argparse.ArgumentParser(description="Plot training curve from logs.")
    option_specs = (
        (
            "--log_folder",
            dict(
                type=str,
                help="Path to the folder containing training logs.",
                default="",
            ),
        ),
        (
            "--only_save",
            dict(
                action="store_true",
                help="Just store the plot.",
                default=False,
            ),
        ),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def load_predictions(file_path):
    """Load the per-epoch prediction checkpoints stored in *file_path*.

    Every file matching ``epoch_*_predictions_*.pt`` is loaded onto the CPU,
    in ascending order of the integer epoch number, which is the second
    ``_``-separated field of the file name (``epoch_<num>_predictions_...``).
    Each loaded object is routed by file name:

    * names containing ``x_given_y`` -> ``"results_recovering_x_given_y"``
    * otherwise names containing ``y_given_x`` -> ``"results_recovering_y_given_x"``
    * anything else is loaded but discarded.

    Returns:
        dict mapping each direction key to a list of loaded objects (possibly
        empty), ordered by epoch.
    """
    def epoch_number(path):
        return int(os.path.basename(path).split("_")[1])

    buckets = {
        "results_recovering_x_given_y": [],
        "results_recovering_y_given_x": [],
    }
    pattern = os.path.join(file_path, "epoch_*_predictions_*.pt")

    for checkpoint in sorted(glob.glob(pattern), key=epoch_number):
        loaded = torch.load(checkpoint, map_location=torch.device("cpu"))
        base = os.path.basename(checkpoint)
        if "x_given_y" in base:
            target = "results_recovering_x_given_y"
        elif "y_given_x" in base:
            target = "results_recovering_y_given_x"
        else:
            continue
        buckets[target].append(loaded)

    return buckets


# In: predictions
# Out: dict of additional information
def additional_information(predictions):
    """Build extra per-checkpoint metrics for the most recent predictions.

    For every direction key in *predictions*, the last five stored entries
    are iterated (with a tqdm progress bar) and one metrics dict is produced
    per entry. No metrics are computed at present, so every dict is empty.

    Args:
        predictions: mapping of direction name -> list of loaded prediction
            objects, such as the one returned by ``load_predictions``.

    Returns:
        dict mapping each direction name to a list of (currently empty)
        metric dicts, one per visited entry.
    """
    # TODO: populate the per-entry dicts. An earlier draft computed a
    # normalized HSIC between tensor[0] and tensor[2] via hsic_estimate,
    # both raw and after standardizing each signal.
    return {
        direction: [{} for _entry in tqdm(entries[-5:])]
        for direction, entries in predictions.items()
    }


# Directory scanned for the newest numbered Sacred run when --log_folder is empty.
_SACRED_RESULTS_DIR = "sacred_results_3"
# Keys of run.json["result"] that are plotted, in subplot order.
_DIRECTIONS = ("results_recovering_x_given_y", "results_recovering_y_given_x")


def _resolve_log_folder(log_folder):
    """Return the run directory to plot, without trailing slashes.

    An empty *log_folder* selects the numerically largest run directory
    under ``_SACRED_RESULTS_DIR``.

    Raises:
        ValueError: if auto-detection finds no numeric run directories.
    """
    if log_folder == "":
        candidates = glob.glob(os.path.join(_SACRED_RESULTS_DIR, "*"))
        run_ids = [
            int(os.path.basename(path))
            for path in candidates
            if os.path.basename(path).isdecimal()
        ]
        if not run_ids:
            raise ValueError(
                "No numeric folders found in the specified sacred_results directory."
            )
        largest_number = max(run_ids)
        print(f"Largest number in sacred results: {largest_number}")
        # Use the same directory the run number was taken from.
        log_folder = f"{_SACRED_RESULTS_DIR}/{largest_number}"
    print(f"Log folder: {log_folder}")
    return log_folder.rstrip("/")


def _collect_direction_attributes(results, additional_info):
    """Group logged metric values per direction into train/val series.

    For each direction, the union of keys over the train logging infos,
    the validation logging infos and the optional additional infos is used.
    The train series hold the train infos followed by the additional infos.
    The val series hold the validation infos. Missing keys become ``None``.

    Returns:
        {direction: {"train": {key: [...]}, "val": {key: [...]}}}
    """
    all_attrs = {}
    for direction in _DIRECTIONS:
        train_infos = results[direction].get("train_logging_infos", [])
        val_infos = results[direction].get("validation_logging_infos", [])
        # additional_info.json is optional and may not cover every direction.
        extra_infos = additional_info.get(direction, [])

        keys = set()
        for info in train_infos + val_infos + extra_infos:
            keys.update(info.keys())

        all_attrs[direction] = {
            "train": {
                k: [info.get(k) for info in train_infos + extra_infos] for k in keys
            },
            "val": {k: [info.get(k) for info in val_infos] for k in keys},
        }
    return all_attrs


def _spread_validation_losses(losses, val_losses):
    """Place validation losses on the training-epoch axis, ``None`` elsewhere.

    The result has ``len(losses) - 1`` slots, or none at all when there are
    no training losses. All but the last validation loss go to indices
    ``i * factor`` with ``factor = len(losses) // (len(val_losses) - 1)``.
    Indices past the end are skipped. The last validation loss always fills
    the final slot, including when there is only one validation loss.
    The computed factor is printed as a diagnostic.
    """
    spread = [None] * max(len(losses) - 1, 0)
    if not spread or not val_losses:
        return spread
    if len(val_losses) > 1:
        factor = len(losses) // (len(val_losses) - 1)
        print(
            f"factor: {factor}, len(losses): {len(losses)}, len(val_losses): {len(val_losses)}"
        )
        for i, val_loss in enumerate(val_losses[:-1]):
            idx = i * factor
            if idx < len(spread):
                spread[idx] = val_loss
    spread[-1] = val_losses[-1]
    return spread


def _plot_loss_curve(ax, direction, losses, val_losses):
    """Draw training and (sparse) validation loss for one direction on *ax*."""
    spread_val_losses = _spread_validation_losses(losses, val_losses)
    ax.plot(losses, label="Training Loss")
    ax.plot(
        range(len(spread_val_losses)),
        spread_val_losses,
        label="Validation Loss",
        linestyle="--",
        marker="o",
    )
    ax.set_title(f"Loss Curve for {direction}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    # Fixed y-range for every loss subplot; losses outside 0..10 are clipped.
    ax.set_ylim(0, 10)
    ax.legend()
    ax.grid(True)


def _plot_attribute(ax, attr, train_values, val_values, direction_name):
    """Draw one logged attribute (train, val and running average) on *ax*."""
    train_values = [v for v in train_values if v is not None]
    ax.plot(train_values, label=f"{attr} (Train, {direction_name})", marker="o")

    # Keys that never appear in the validation logs yield an all-None series;
    # plotting it would only add an empty legend entry.
    if any(v is not None for v in val_values):
        if len(val_values) < len(train_values):
            # Stretch the sparser validation series across the training x-range.
            val_x = np.linspace(0, len(train_values) - 1, len(val_values)).astype(int)
        else:
            val_x = range(len(val_values))
        ax.plot(
            val_x,
            val_values,
            label=f"{attr} (Val, {direction_name})",
            linestyle="--",
            marker="s",
        )

    if len(train_values) >= RUNNING_AVG_WINDOW:
        ma = np.convolve(
            train_values,
            np.ones(RUNNING_AVG_WINDOW) / RUNNING_AVG_WINDOW,
            mode="valid",
        )
        # A "valid" convolution drops the first WINDOW-1 points; align on the right.
        x_ma = np.arange(RUNNING_AVG_WINDOW - 1, len(train_values))
        ax.plot(x_ma, ma, label=f"{attr} ({RUNNING_AVG_WINDOW}-avg)", linewidth=3)

    ax.set_title(f"{attr} over Epochs ({direction_name})")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(attr)
    ax.legend()
    ax.grid(True)


def _nested_get(mapping, *keys, default="N/A"):
    """Follow *keys* through nested dicts and return *default* on any miss."""
    node = mapping
    for key in keys:
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node


def _config_summary(config, log_folder):
    """Return ``(text-box contents, output file name)`` for the run config."""
    data = ("data_config", "dictionary")
    hparams = ("training_hyperparameters",)
    depth = _nested_get(config, *data, "depth")
    transform_type = _nested_get(config, *data, "transformation", "type")
    sample_size = _nested_get(config, *data, "X", "length")
    noise = _nested_get(config, *data, "noise_type")
    standardized = _nested_get(config, *data, "standardize")
    train_loss = _nested_get(config, *hparams, "train_loss")
    blocks = _nested_get(config, *hparams, "num_blocks")
    dimension = _nested_get(config, *hparams, "hidden_dim")

    run_name = os.path.basename(log_folder)
    # Sacred run folders are numeric; fall back to the raw name for custom paths.
    experiment_number = int(run_name) if run_name.isdecimal() else run_name

    config_text = (
        f"Model Configuration:\n"
        f"Depth: {depth}\n"
        f"Transform Type: {transform_type}\n"
        f"Standardized: {standardized}\n"
        f"Sample Size: {sample_size}\n"
        f"Noise Type: {noise}\n"
        f"Train Loss: {train_loss}\n"
        f"Experiment Number: {experiment_number}\n"
        f"Architecture: {blocks}|{dimension}D\n"
    )
    file_name = (
        f"{experiment_number}_{transform_type}_{standardized}_{noise}_{train_loss}.png"
    )
    return config_text, file_name


def main():
    """Plot loss curves and logged attributes for one run and save the figure.

    Reads ``run.json``, ``config.json`` and, if present,
    ``additional_info.json`` from the run folder. The PNG is written into
    the same folder.
    """
    args = parse_args()
    args.log_folder = _resolve_log_folder(args.log_folder)

    with open(f"{args.log_folder}/run.json", "r", encoding="utf-8") as f:
        metrics = json.load(f)
    results = metrics["result"]

    # Extra metrics come from additional_info.json when it exists. The raw
    # prediction checkpoints (load_predictions) are not needed for plotting.
    additional_info_path = os.path.join(args.log_folder, "additional_info.json")
    additional_info = {}
    if os.path.exists(additional_info_path):
        with open(additional_info_path, "r", encoding="utf-8") as f:
            additional_info = json.load(f)

    all_direction_attributes = _collect_direction_attributes(results, additional_info)

    # One loss subplot per direction plus one subplot per attribute, 3 per row.
    total_attributes = sum(
        len(attrs["train"]) for attrs in all_direction_attributes.values()
    )
    num_subplots = len(_DIRECTIONS) + total_attributes
    num_cols = min(3, num_subplots)
    num_rows = (num_subplots + 2) // 3
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(8 * num_cols, 6 * num_rows), squeeze=False
    )
    axes = axes.ravel()

    current_ax = 0
    for direction in _DIRECTIONS:
        _plot_loss_curve(
            axes[current_ax],
            direction,
            results[direction]["losses"],
            results[direction]["validation_losses"],
        )
        current_ax += 1

        direction_name = "X|Y" if "x_given_y" in direction else "Y|X"
        attrs = all_direction_attributes[direction]
        for attr in sorted(attrs["train"]):
            _plot_attribute(
                axes[current_ax],
                attr,
                attrs["train"][attr],
                attrs["val"].get(attr, []),
                direction_name,
            )
            current_ax += 1

    # Hide unused subplots in the last row.
    for ax in axes[current_ax:]:
        ax.axis("off")

    with open(f"{args.log_folder}/config.json", "r", encoding="utf-8") as f:
        config = json.load(f)
    config_text, file_name = _config_summary(config, args.log_folder)

    # Configuration box anchored at the bottom-right corner of the figure.
    fig.text(
        0.98,
        0.02,
        config_text,
        fontsize=12,
        horizontalalignment="right",
        verticalalignment="bottom",
        bbox=dict(facecolor="white", alpha=0.8, boxstyle="round,pad=0.5"),
    )

    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.subplots_adjust(bottom=0.15)  # Make room at the bottom for the text
    plt.savefig(f"{args.log_folder}/{file_name}")
    if not args.only_save:
        plt.show()


if __name__ == "__main__":
    main()
