import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import os
import json, difflib

from experiment_analysis import get_run_data, load_results


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``log_folder`` (list of str),
        ``attr`` (list of str), ``window_size`` (int, default 100),
        ``val`` (bool) and ``only_window`` (bool).
    """
    cli = argparse.ArgumentParser(
        description="Plot given attribute(s) from train/validation logs with moving average"
    )

    # Required options that accept one or more string values.
    multi_valued = (
        ("--log_folder", "One or more log folder paths to compare experiments"),
        (
            "--attr",
            "Logging attribute(s) to extract and plot (e.g. 'hsic_estimate')",
        ),
    )
    for flag, help_text in multi_valued:
        cli.add_argument(flag, nargs="+", type=str, required=True, help=help_text)

    cli.add_argument(
        "--window_size", type=int, default=100, help="Window size for the moving average"
    )

    # Boolean switches: False unless present on the command line.
    switches = (
        ("--val", "If set, plot validation attributes instead of training ones"),
        ("--only_window", "If set, only plot the moving average"),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()


# Result keys of the two recovery directions stored in each run's "result".
_DIRECTIONS = (
    "results_recovering_x_given_y",
    "results_recovering_y_given_x",
)


def _extract_direction_attributes(direction_results):
    """Turn one direction's per-step logging dicts into per-attribute lists.

    Args:
        direction_results: mapping that may contain ``"train_logging_infos"``
            and ``"validation_logging_infos"``, each a list of dicts.

    Returns:
        ``{"train": {key: [...]}, "val": {key: [...]}}``. Keys are the union of
        keys seen in either split. A step that did not log a key contributes
        ``None``, so every list stays aligned with the step index.
    """
    train_infos = direction_results.get("train_logging_infos", [])
    val_infos = direction_results.get("validation_logging_infos", [])
    keys = {k for info in (*train_infos, *val_infos) for k in info}
    return {
        "train": {k: [info.get(k) for info in train_infos] for k in keys},
        "val": {k: [info.get(k) for info in val_infos] for k in keys},
    }


def _load_experiment(folder):
    """Load one run folder into ``{"name", "folder", "attributes"}``."""
    results = get_run_data(folder)["result"]
    return {
        # normpath strips trailing separators portably before taking the basename.
        "name": os.path.basename(os.path.normpath(folder)),
        "folder": folder,
        "attributes": {
            direction: _extract_direction_attributes(results[direction])
            for direction in _DIRECTIONS
        },
    }


def _print_config_diff(experiments):
    """Print a unified diff of ``config.json`` when exactly two runs are given.

    This is best-effort: if a config cannot be read or parsed, a message is
    printed and plotting still proceeds.
    """
    if len(experiments) != 2:
        print("Config diff requires exactly two experiments; skipping diff")
        return

    dumped = []
    for exp in experiments:
        cfg_path = os.path.join(exp["folder"], "config.json")
        try:
            with open(cfg_path, encoding="utf-8") as fh:
                cfg = json.load(fh)
        except (OSError, json.JSONDecodeError) as err:
            print(f"Could not read {cfg_path} ({err}); skipping diff")
            return
        # sort_keys makes the diff insensitive to key ordering, so only real
        # value changes show up.
        dumped.append(json.dumps(cfg, indent=2, sort_keys=True).splitlines())

    first, second = experiments
    print(f"Config diff between {first['name']} and {second['name']}:")
    for line in difflib.unified_diff(
        dumped[0],
        dumped[1],
        fromfile=f"{first['name']}/config.json",
        tofile=f"{second['name']}/config.json",
        lineterm="",
    ):
        print(line)


def _attribute_series(exp, direction, split, attr):
    """Return the logged values of ``attr`` as a float Series (None -> NaN).

    Raises:
        ValueError: if ``attr`` was never logged for this experiment, direction
            and split. This includes an attribute that exists only in the
            other split, which would otherwise yield a list of ``None``s.
    """
    values = exp["attributes"][direction][split].get(attr)
    if values is None or all(v is None for v in values):
        raise ValueError(f"Attribute '{attr}' not found in {exp['name']} logs.")
    return pd.Series(values, dtype=float)


def _plot_attribute(ax, experiments, attr, split, window_size, only_window):
    """Draw raw curves (optional) and moving averages of ``attr`` on ``ax``."""
    for exp in experiments:
        x_y = _attribute_series(exp, "results_recovering_x_given_y", split, attr)
        y_x = _attribute_series(exp, "results_recovering_y_given_x", split, attr)
        # Trailing window (pandas default center=False); the first
        # window_size - 1 points are NaN.
        x_y_ma = x_y.rolling(window=window_size).mean()
        y_x_ma = y_x.rolling(window=window_size).mean()
        # Keep the raw, raw, MA, MA call order: seaborn takes the next colour
        # from the axes' cycle on each call, so reordering would recolour lines.
        if not only_window:
            sns.lineplot(data=x_y, label=f"{exp['name']} x_given_y", ax=ax)
            sns.lineplot(data=y_x, label=f"{exp['name']} y_given_x", ax=ax)
        sns.lineplot(
            data=x_y_ma,
            label=f"{exp['name']} x_given_y MA({window_size})",
            ax=ax,
            linestyle="--",
        )
        sns.lineplot(
            data=y_x_ma,
            label=f"{exp['name']} y_given_x MA({window_size})",
            ax=ax,
            linestyle="--",
        )

    ax.set_ylabel(attr)
    ax.set_title(f"{attr} and its Moving Average (window={window_size})")
    ax.grid(True)
    ax.legend()


def main():
    """Load the requested runs, optionally diff their configs, and plot.

    For every ``--log_folder`` the logging infos of both recovery directions
    are loaded. If exactly two folders are given, their ``config.json`` files
    are diffed on stdout. Then one subplot per ``--attr`` is drawn from the
    train split, or the validation split with ``--val``. Each subplot shows
    the raw values (omitted with ``--only_window``) and their moving average
    over ``--window_size`` steps.

    Raises:
        ValueError: if a requested attribute has no logged values for some
            experiment/direction in the selected split.
    """
    args = parse_args()
    experiments = [_load_experiment(folder) for folder in args.log_folder]
    _print_config_diff(experiments)

    split = "val" if args.val else "train"
    n = len(args.attr)
    # One subplot per attribute for independent y-scales. squeeze=False always
    # returns a 2-D array, so a single attribute needs no special case.
    _, axes = plt.subplots(n, 1, sharex=True, figsize=(12, 4 * n), squeeze=False)
    axes = axes[:, 0]
    for ax, attr in zip(axes, args.attr):
        _plot_attribute(ax, experiments, attr, split, args.window_size, args.only_window)

    axes[-1].set_xlabel("Index")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
