import argparse
import json
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

from experiment_analysis import get_run_data, load_results
from mi_estimators.mi_hsic import hsic_estimate
from plot_training_curve import load_predictions


def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default,
            and the only form the original script used) ``sys.argv[1:]`` is
            parsed, exactly as ``argparse`` does by default.

    Returns:
        An ``argparse.Namespace`` with ``log_folder`` (str) and
        ``skip_existing`` (bool).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Compute HSIC-based dependence metrics (nHSIC, HSIC_norm) from "
            "saved predictions and write them to additional_info.json"
        )
    )
    parser.add_argument(
        "--log_folder",
        type=str,
        required=True,
        help=(
            "Log folder passed to load_predictions(); additional_info.json "
            "is read from and written to this folder"
        ),
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="Skip calculation if additional_info.json already exists",
    )
    return parser.parse_args(argv)


# Keys every per-sample record in additional_info.json must carry.
_REQUIRED_KEYS = ("nHSIC", "HSIC_norm")
# Only every _SAMPLE_EVERY-th prediction entry per direction is scored.
_SAMPLE_EVERY = 5


def _existing_info_is_complete(info_path):
    """Return True if ``info_path`` already holds a complete set of metrics.

    "Complete" means: the file parses as JSON, is a non-empty dict, and every
    value is a list whose items are dicts containing all ``_REQUIRED_KEYS``.
    A missing, unreadable, or malformed file counts as incomplete so the
    caller recomputes instead of crashing.
    """
    try:
        with open(info_path, "r") as f:
            existing = json.load(f)
    except (OSError, json.JSONDecodeError):
        return False
    # An empty dict would vacuously satisfy all(...) below and block any
    # later recomputation under --skip_existing, so treat it as incomplete.
    if not isinstance(existing, dict) or not existing:
        return False
    return all(
        isinstance(records, list)
        and all(
            isinstance(record, dict) and all(key in record for key in _REQUIRED_KEYS)
            for record in records
        )
        for records in existing.values()
    )


def _hsic_metrics(pred_noise, c):
    """Compute the two HSIC-based scores for one prediction entry.

    Returns a dict with:
      - "nHSIC": HSIC(x, y) / sqrt(HSIC(x, x) * HSIC(y, y)).
      - "HSIC_norm": HSIC of the two signals after each is standardised by
        its own global mean and std (over all elements of the tensor).

    NOTE(review): a constant input has std 0 (and zero self-HSIC is possible),
    which yields NaN/inf here; json.dump will then emit non-standard NaN tokens.
    """
    # Same call order as before, in case hsic_estimate has internal state.
    hsic_xx = hsic_estimate(pred_noise, pred_noise)
    hsic_yy = hsic_estimate(c, c)
    hsic_xy = hsic_estimate(pred_noise, c)
    nHSIC = hsic_xy / torch.sqrt(hsic_xx * hsic_yy)

    pn_norm = (pred_noise - pred_noise.mean()) / pred_noise.std()
    c_norm = (c - c.mean()) / c.std()
    hsic_xy_norm = hsic_estimate(pn_norm, c_norm)

    return {
        "nHSIC": nHSIC.item(),
        "HSIC_norm": hsic_xy_norm.item(),
    }


def main():
    """Compute HSIC metrics for saved predictions and write additional_info.json.

    For every direction returned by ``load_predictions(log_folder)``, every
    ``_SAMPLE_EVERY``-th entry is scored with ``_hsic_metrics``. Element 0 of
    each entry is treated as the predicted noise and element 2 as the signal
    ``c``. Those names come from the original code; what element 1 holds is
    not visible here. The results are written as
    ``{direction: [record, ...]}`` to ``<log_folder>/additional_info.json``.

    With ``--skip_existing``, an existing file that already has every
    required key is left untouched and nothing is computed.
    """
    args = parse_args()
    info_path = os.path.join(args.log_folder, "additional_info.json")

    if args.skip_existing and _existing_info_is_complete(info_path):
        print(
            "Found existing additional_info.json with required keys; skipping calculation."
        )
        return

    predictions = load_predictions(
        args.log_folder,
    )

    additional_info = {}
    for direction, tensors in predictions.items():
        additional_info[direction] = [
            _hsic_metrics(tensor[0], tensor[2])
            for i, tensor in enumerate(tqdm(tensors))
            if i % _SAMPLE_EVERY == 0
        ]

    # Reuse info_path so the read and write locations cannot diverge.
    with open(info_path, "w") as f:
        json.dump(additional_info, f)

    print(additional_info)


# Run the metric computation only when executed as a script, not on import.
if __name__ == "__main__":
    main()
