import pandas as pd
import numpy as np
import json

# Run metadata table, loaded once when this module is imported.
# orient="index" makes the top-level JSON keys the row labels; get_param()
# looks rows up by run id and reads "params.<name>" columns from them.
runs_data = pd.read_json("data/runs.json", orient="index")


def normalize_score(score, min_score, max_score, std):
    """Rescale a mean score and its standard deviation by ``max_score - min_score``.

    Args:
        score: Raw mean score to rescale.
        min_score: Lower reference value; scores below it are mapped to 0.
        max_score: Upper reference value.
        std: Standard deviation of the raw score.

    Returns:
        A ``(scaled_mean, scaled_std)`` tuple. ``scaled_mean`` is
        ``max(score - min_score, 0) / (max_score - min_score)`` and
        ``scaled_std`` is ``std`` divided by the same range, capped at 1.0.
        When the range is not strictly positive, ``(0.0, 0.0)`` is returned.
    """
    span = max_score - min_score

    # Degenerate range: nothing meaningful to scale against.
    if not span > 0:
        return 0.0, 0.0

    # Scores below the reference minimum are floored at zero.
    mean_fraction = max(score - min_score, 0) / span
    # A std that is large relative to the range is capped at 1.0.
    std_fraction = min(std / span, 1.0)
    return mean_fraction, std_fraction


def get_param(run_id, param):
    """Return the value stored under ``params.<param>`` for ``run_id``.

    The lookup is done against the module-level ``runs_data`` table: the row
    labelled ``run_id`` is selected first, then the ``params.<param>`` entry
    is read from it.
    """
    column = f"params.{param}"
    run_row = runs_data.loc[run_id]
    return run_row[column]


def get_rows(d, min_data, prost_data, mlp_data, normalize=True, include_mlp=True):
    """Build one ``pd.Series`` per instance for a single evaluation record.

    Args:
        d: Evaluation record. Reads ``"domain"``, ``"run_id"`` and
            ``"instance_returns"``; ``"batch_id"`` and ``"train_instances"``
            are optional and fall back to ``get_param`` when missing or falsy.
        min_data: Mapping from domain to per-instance minimum scores.
        prost_data: Mapping from domain to a 2-D table whose column 0 is used
            as the PROST mean and column 1 as the PROST std.
        mlp_data: Mapping from domain to per-instance MLP returns.
        normalize: When true, the ``score``/``prost``/``mlp`` columns hold the
            normalized means; otherwise the raw means.
        include_mlp: When true, the MLP mean takes part in the per-instance
            maximum and is normalized; otherwise its normalized values are 0.

    Returns:
        A list of ``pd.Series`` rows, one per zipped instance.

    Note:
        The ``*_var`` columns are always the squared *normalized* std, also
        when ``normalize`` is false.
    """
    domain = d["domain"]
    run_id = d["run_id"]

    # Prefer values carried by the record itself; otherwise consult the run table.
    batch_id = d.get("batch_id", None) or get_param(run_id, "batch_id")
    train_idx = d.get("train_instances") or json.loads(get_param(run_id, "instance"))
    # Shift by one so the ids can be matched against the 0-based loop index
    # (presumably the stored instance ids are 1-based — confirm).
    train_idx = np.array(train_idx) - 1

    prost_table = np.array(prost_data[domain])
    prost_means, prost_stds = prost_table[:, 0], prost_table[:, 1]

    # Mean/std are taken over the last axis of the returns.
    gnn_returns = np.array(d["instance_returns"])
    gnn_means, gnn_stds = gnn_returns.mean(-1), gnn_returns.std(-1)

    floors = np.array(min_data[domain])

    mlp_returns = np.array(mlp_data[domain])
    mlp_means, mlp_stds = mlp_returns.mean(-1), mlp_returns.std(-1)

    per_instance = zip(
        gnn_means,
        gnn_stds,
        prost_means,
        prost_stds,
        mlp_means,
        mlp_stds,
        floors,
    )

    rows = []
    for idx, (
        gnn_mean,
        gnn_sd,
        prost_mean,
        prost_sd,
        mlp_mean,
        mlp_sd,
        floor,
    ) in enumerate(per_instance):
        # The best included method defines the top of the normalization range.
        if include_mlp:
            best = max(gnn_mean, mlp_mean, prost_mean)
            mlp_norm, mlp_norm_sd = normalize_score(mlp_mean, floor, best, mlp_sd)
        else:
            best = max(gnn_mean, prost_mean)
            mlp_norm, mlp_norm_sd = 0.0, 0.0

        gnn_norm, gnn_norm_sd = normalize_score(gnn_mean, floor, best, gnn_sd)
        prost_norm, prost_norm_sd = normalize_score(
            prost_mean, floor, best, prost_sd
        )

        # Sanity check: every normalized quantity must lie in [0, 1].
        for value in (
            gnn_norm,
            prost_norm,
            mlp_norm,
            gnn_norm_sd,
            prost_norm_sd,
            mlp_norm_sd,
        ):
            assert 0.0 <= value <= 1.0

        record = {
            "run_id": run_id,
            "batch_id": batch_id,
            "domain": domain,
            "instance": str(idx),
            "score": norm_or_raw(normalize, gnn_norm, gnn_mean)
            if False
            else (gnn_norm if normalize else gnn_mean),
            "prost": prost_norm if normalize else prost_mean,
            "is_train": idx in train_idx,
            "mlp": mlp_norm if normalize else mlp_mean,
            "mlp_var": mlp_norm_sd**2,
            "score_var": gnn_norm_sd**2,
            "prost_var": prost_norm_sd**2,
        }
        rows.append(pd.Series(record))

    return rows


def get_data(data, min_data, prost_data, mlp_data, normalize=True, include_mlp=True):
    """Collect the rows of every evaluation record into one flat list.

    Records whose ``"domain"`` is ``TriangleTireworld_MDP_ippc2014`` are
    skipped. All other records are expanded with ``get_rows``, to which the
    remaining arguments are forwarded unchanged.

    Returns:
        A flat list with the rows of all kept records, in input order.
    """
    return [
        row
        for record in data
        if record["domain"] != "TriangleTireworld_MDP_ippc2014"
        for row in get_rows(
            record,
            min_data,
            prost_data,
            mlp_data,
            normalize=normalize,
            include_mlp=include_mlp,
        )
    ]


def _read_jsonl(path):
    """Parse a JSON Lines file into a list of objects, ignoring blank lines."""
    with open(path, encoding="utf-8") as f:
        # Iterate the handle directly instead of materializing readlines();
        # whitespace-only lines are skipped so they do not make json.loads raise.
        return [json.loads(line) for line in f if line.strip()]


def get_all_data(normalize=True, include_mlp=True):
    """Load all evaluation inputs from disk and return them as one DataFrame.

    Reads the GNN evaluation records, the per-domain minimum scores, the PROST
    mean/std table and the MLP evaluation records, expands them into
    per-instance rows via ``get_data`` and assembles a DataFrame.

    Args:
        normalize: Forwarded to ``get_data``; selects normalized or raw means.
        include_mlp: Forwarded to ``get_data``; whether the MLP scores take
            part in the normalization.

    Returns:
        A DataFrame indexed by ``run_id`` and sorted in descending order by
        ``batch_id``, ``run_id``, ``domain`` and ``is_train``.
    """
    data = _read_jsonl("data/eval_results100.jsonl")

    with open("data/mins100.json", encoding="utf-8") as f:
        min_data = json.load(f)

    # NOTE(review): unlike the other inputs this path has no "data/" prefix —
    # confirm that is intentional.
    with open("prost100std.json", encoding="utf-8") as f:
        prost_data = json.load(f)

    # Re-key the MLP records by domain; a later record for the same domain
    # replaces an earlier one.
    mlp_data = {
        entry["domain"]: entry["instance_returns"]
        for entry in _read_jsonl("data/eval_400k_0.jsonl")
    }

    rows = get_data(
        data,
        min_data,
        prost_data,
        mlp_data,
        normalize=normalize,
        include_mlp=include_mlp,
    )

    df = pd.DataFrame(rows)
    df = df.set_index("run_id")
    df = df.sort_values(
        by=["batch_id", "run_id", "domain", "is_train"], ascending=False
    )
    return df
