import argparse
import json
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from aion_eval.benchmarks.desiddpayne.dataset import DESIDDPayneDatasetModule
from astropy.table import Table
from sklearn.metrics import r2_score
from tqdm.auto import tqdm


def make_plot(pred_path, dm, limit_train_size, clip=5, save_plots=True):
    """Score one prediction file against the validation ground truth.

    Reads the prediction table at ``pred_path``. For each of the first 16
    entries of ``dm.output_fields``, it compares the prediction column with the
    matching column of ``dm.val_data``, truncated to its first
    ``limit_train_size`` rows. It can also draw a 4x4 grid of
    prediction-vs-truth hexbin panels.

    Args:
        pred_path: Path to a table readable by ``astropy.table.Table.read``
            with one column per output field name.
        dm: A set-up data module exposing ``output_fields`` (indexable, at
            least 16 entries) plus ``val_data``, ``mean`` and ``std`` (all
            indexed by field name).
        limit_train_size: Number of leading ``val_data`` rows to compare
            against. NOTE(review): assumes the prediction table is aligned
            row-for-row with this truncated slice -- confirm.
        clip: If truthy, drop samples whose ground truth lies ``clip``
            standard deviations or more from the field mean before scoring.
            ``0``/``None`` disables clipping.
        save_plots: Whether to build the figure.

    Returns:
        ``(fig, r2s, sigmas)``: the matplotlib figure (``None`` when
        ``save_plots`` is false), and two dicts mapping each field name to its
        R^2 score and to the standard deviation of the residual ``gt - pred``.

    Raises:
        IndexError: if ``dm.output_fields`` has fewer than 16 entries.
    """
    preds = Table.read(pred_path)
    # File name without directory or extension(s), used as the figure title.
    run_name = os.path.basename(pred_path).split(".")[0]

    r2s = {}
    sigmas = {}

    fig = ax = None
    if save_plots:
        fig, ax = plt.subplots(4, 4, figsize=(15, 15))

    for ix in range(16):
        # Row-major position of this field in the 4x4 panel grid.
        i, j = divmod(ix, 4)
        prop = dm.output_fields[ix]
        gt = dm.val_data[prop][:limit_train_size]
        pred = preds[prop].data
        if clip:
            # Keep only samples whose ground truth is within `clip` sigma of the mean.
            mask = np.where(np.abs(gt - dm.mean[prop]) < clip * dm.std[prop])
            gt = gt[mask]
            pred = pred[mask]
        r2 = float(r2_score(gt, pred))
        sigma = float(np.std(gt - pred))
        r2s[prop] = r2
        sigmas[prop] = sigma
        if save_plots:
            panel = ax[i, j]
            panel.hexbin(
                gt, pred, gridsize=60, mincnt=1, norm=matplotlib.colors.LogNorm()
            )
            # y = x reference line: two endpoints draw the same line as every sample.
            lo, hi = np.min(gt), np.max(gt)
            panel.plot([lo, hi], [lo, hi], "r--")
            # Raw f-string: a bare "\s" in a normal string is an invalid escape.
            panel.set_title(rf"{prop}, $\sigma={sigma:.3f}$, $R^2={r2:.3f}$")

    if fig is not None:
        fig.suptitle(run_name, y=0.92)

    return fig, r2s, sigmas


def _parse_train_size(path, default=10000000):
    """Return the training-set size encoded in a prediction file's name.

    Only the file name is inspected, so a ``_spec_`` in a parent directory
    cannot derail parsing. The size is the integer right after the first
    ``_spec_``, ending at the next ``.`` or ``_``. For example,
    ``model_spec_1000_seed0.fits`` gives ``1000``. Returns ``default`` (after
    printing a notice) when the name does not follow that pattern.
    """
    filename = os.path.basename(path)
    try:
        return int(filename.split("_spec_")[1].split(".")[0].split("_")[0])
    except (IndexError, ValueError) as e:
        # IndexError: no "_spec_" marker; ValueError: text after it is not an int.
        print(f"couldnt find train size, using {default}", e)
        return default


def main(args):
    """Score every ``.fits`` prediction file in ``args.results_dir``.

    For each file (in sorted order), compute per-field R^2 and residual sigma
    with ``make_plot``. Write one PNG per file when ``args.save_plots`` is
    set. Finally, dump all metrics keyed by run name to
    ``<args.output_dir>/results.json``.

    Args:
        args: Parsed CLI namespace with ``results_dir``, ``data_dir``,
            ``output_dir``, ``save_plots`` and ``clip`` attributes.
    """
    dm = DESIDDPayneDatasetModule(
        data_dir=args.data_dir, num_workers=0, input_fields=[]
    )
    dm.setup(None)

    base_dir = args.results_dir
    paths = sorted(
        os.path.join(base_dir, f) for f in os.listdir(base_dir) if f.endswith(".fits")
    )

    os.makedirs(args.output_dir, exist_ok=True)

    results = {}
    for path in tqdm(paths):
        limit_train_size = _parse_train_size(path)
        fig, r2s, sigmas = make_plot(
            path, dm, limit_train_size, clip=args.clip, save_plots=args.save_plots
        )
        # File name without directory or extension(s); used as the results key.
        run_name = os.path.basename(path).split(".")[0]
        if args.save_plots:
            fig.savefig(os.path.join(args.output_dir, f"{run_name}.png"))
            # Free the figure; one is created per file.
            plt.close(fig)
        results[run_name] = {"r2s": r2s, "sigmas": sigmas}

    with open(os.path.join(args.output_dir, "results.json"), "w", encoding="utf-8") as f:
        json.dump(results, f)


if __name__ == "__main__":
    # Command-line entry point: three required directory paths plus the
    # optional plotting switch and the sigma-clipping threshold (0 disables it).
    cli = argparse.ArgumentParser()
    for flag in ("--results_dir", "--data_dir", "--output_dir"):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument("--save_plots", action="store_true", default=False)
    cli.add_argument("--clip", type=float, default=5)

    args = cli.parse_args()
    main(args)
