#!/usr/bin/env python3
import argparse
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from experiment_analysis import load_results, hsic_estimate
import json

import npeet.entropy_estimators as ee
import xicorpy


def load_predictions(exp):
    """Load an experiment's saved prediction artifacts, grouped by direction and split.

    Each name in ``exp["artifacts"]`` is checked for a direction marker
    (``"x_given_y"`` or ``"y_given_x"``) and a data-split marker
    (``"full_data"`` or ``"test_data"``). Artifacts with a direction marker are
    loaded with ``torch.load`` from ``exp["path"]`` and appended, in artifact
    order, to the list under a key such as ``"y_given_x_full_data"``. A
    directional artifact without a split marker is stored under the bare
    direction key (e.g. ``"x_given_y"``). Artifacts without a direction marker
    are skipped and never loaded.

    Args:
        exp: Experiment record with at least ``"path"`` (directory holding the
            artifacts) and ``"artifacts"`` (file names relative to that path).

    Returns:
        dict mapping grouping key -> list of loaded objects.
    """
    predictions = {}
    for artifact in exp["artifacts"]:
        if "x_given_y" in artifact:
            key = "x_given_y"
        elif "y_given_x" in artifact:
            key = "y_given_x"
        else:
            # Not a directional prediction (e.g. a config or log file). Nothing
            # downstream reads it, and torch.load may not even be able to parse it.
            continue
        if "full_data" in artifact:
            key += "_full_data"
        elif "test_data" in artifact:
            key += "_test_data"
        # Load onto the CPU: calculate_mis hands these tensors to NumPy-based
        # estimators (npeet, xicorpy), which cannot consume CUDA tensors. This
        # also allows GPU-saved artifacts to load on CPU-only machines.
        data = torch.load(f"{exp['path']}/{artifact}", map_location="cpu")
        predictions.setdefault(key, []).append(data)
    return predictions


def calculate_mis(tensors):
    """Score the dependence between predicted noise and ``c`` for each record.

    Every record must provide ``"c"`` and ``"pred_noise"`` tensors. For each
    record the following scores are computed, in this order:

    * ``hsic_1``, ``hsic_05``, ``hsic_2``: ``hsic_estimate`` with its default
      sigma scale, ``sigma_scale=0.5`` and ``sigma_scale=2``;
    * ``npeet_3``, ``npeet_5``, ``npeet_10``: ``ee.mi`` with ``k`` = 3, 5, 10;
    * ``codec``: ``xicorpy.compute_conditional_dependence`` cast to ``float``.

    Args:
        tensors: Iterable of prediction records (progress shown via tqdm).

    Returns:
        list of dicts (one per record) mapping score name -> Python float.
    """
    # (output name, extra keyword arguments forwarded to hsic_estimate)
    hsic_variants = (
        ("hsic_1", {}),
        ("hsic_05", {"sigma_scale": 0.5}),
        ("hsic_2", {"sigma_scale": 2}),
    )
    # (output name, neighbour count k forwarded to ee.mi)
    npeet_variants = (("npeet_3", 3), ("npeet_5", 5), ("npeet_10", 10))

    results = []
    for record in tqdm(tensors):
        cond = record["c"]
        noise = record["pred_noise"]
        # HSIC takes c as a column; NPEET/CODEC take the noise with axis 1 dropped.
        cond_column = cond.unsqueeze(1)
        noise_flat = noise.squeeze(1)

        row = {}
        for name, extra in hsic_variants:
            row[name] = hsic_estimate(noise, cond_column, **extra).item()
        for name, k in npeet_variants:
            row[name] = ee.mi(noise_flat, cond, k=k).item()
        row["codec"] = float(
            xicorpy.compute_conditional_dependence(noise_flat, cond)
        )
        results.append(row)
    return results


def _direction_flags(mis_y, mis_x, suffix):
    """Return ``vote_*``/``mean_*`` flags comparing paired per-run scores of both directions.

    Args:
        mis_y: Non-empty list of score dicts for the y_given_x direction.
        mis_x: List of score dicts for the x_given_y direction, paired
            index-by-index with ``mis_y`` (same length).
        suffix: Data-split name appended to every output key.

    Returns:
        dict with, for each score name ``key`` in ``mis_y[0]``:
        ``vote_{key}_{suffix}`` = 1 if the y_given_x score is lower in strictly
        more than half of the paired runs, else 0; and
        ``mean_{key}_{suffix}`` = 1 if the NaN-ignoring mean of the y_given_x
        scores is lower than that of the x_given_y scores, else 0.
    """
    flags = {}
    for key in mis_y[0]:
        arr_y = np.array([m[key] for m in mis_y])
        arr_x = np.array([m[key] for m in mis_x])
        # A comparison involving NaN is False, so such runs count against the vote.
        flags[f"vote_{key}_{suffix}"] = int(np.sum(arr_y < arr_x) > (len(arr_y) / 2))
        flags[f"mean_{key}_{suffix}"] = int(np.nanmean(arr_y) < np.nanmean(arr_x))
    return flags


def main():
    """Command-line entry point: compute MI-based direction flags for all experiments.

    For every experiment returned by ``load_results(--log_folder)``, loads the
    saved predictions and scores both directions with ``calculate_mis``. This
    is done separately for the ``full_data`` and ``test_data`` splits; a split
    is skipped when either direction has no artifacts. The recorded flags are
    1 when the y_given_x direction has the lower score. The list of outcome
    dicts is written as JSON to ``--output`` (".json" is appended if missing).

    Raises:
        ValueError: if, for one split of an experiment, the two directions have
            different numbers of prediction artifacts, so runs cannot be paired.
    """
    parser = argparse.ArgumentParser(
        description="Calculate MI estimates from saved diffusion predictions"
    )
    parser.add_argument(
        "--log_folder",
        type=str,
        required=True,
        help="Path to results folder (passed to load_results)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="mi_outcomes.json",
        help="Filename for storing pure JSON output",
    )
    args = parser.parse_args()

    experiments = load_results(args.log_folder)
    outcomes = []
    for exp in tqdm(experiments, desc="Experiments"):
        preds = load_predictions(exp)
        data_dict = exp["config"]["data_config"]["dictionary"]
        outcome = {
            "noise": data_dict["X"]["type"],
            "transform": data_dict["transformation"]["type"],
        }
        for suffix in ("full_data", "test_data"):
            y_preds = preds.get(f"y_given_x_{suffix}", [])
            x_preds = preds.get(f"x_given_y_{suffix}", [])
            if not y_preds or not x_preds:
                # A comparison needs both directions.
                continue
            # The vote compares run i of one direction with run i of the other.
            # Unequal counts would otherwise crash in NumPy broadcasting after
            # the expensive MI work, or silently broadcast when one side has a
            # single run. NOTE(review): the pairing also assumes exp["artifacts"]
            # lists both directions' runs in the same order -- confirm.
            if len(y_preds) != len(x_preds):
                raise ValueError(
                    f"{exp['path']}: {len(y_preds)} y_given_x vs "
                    f"{len(x_preds)} x_given_y artifacts for {suffix}; "
                    "runs cannot be paired"
                )
            mis_y = calculate_mis(y_preds)
            mis_x = calculate_mis(x_preds)
            outcome.update(_direction_flags(mis_y, mis_x, suffix))
        outcomes.append(outcome)

    # Save pure JSON output
    out_file = args.output if args.output.endswith(".json") else args.output + ".json"
    with open(out_file, "w", encoding="utf-8") as f:
        json.dump(outcomes, f, indent=2)
    print(f"Saved MI outcomes to {out_file}")


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
