import sys
from pathlib import Path

# project root = parent of "scripts"
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import torch
from mdu.randomness import set_all_seeds
import numpy as np
import torch.nn as nn
from mdu.unc.constants import UncertaintyType
from mdu.unc.risk_metrics.constants import GName, RiskType, ApproximationType
from mdu.data.load_dataset import get_dataset
from mdu.data.constants import DatasetName
from mdu.eval.toy_exp import eval_unc_decomp
import pandas as pd
from mdu.vis.toy_plots import plot_data_and_test_point
from mdu.unc.constants import VectorQuantileModel

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
seed = 42
# Seed first so that everything executed afterwards in this script is
# reproducible.
set_all_seeds(seed)

dataset_name = DatasetName.BLOBS
n_classes = 2
# Prefer the first GPU, but fall back to the CPU so the script still runs on
# machines without a usable CUDA device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
n_members = 50  # presumably the ensemble size -- confirm against eval_unc_decomp
input_dim = 2  # the toy inputs are 2-D points
n_epochs = 50
batch_size = 64
lambda_ = 0.0  # forwarded to eval_unc_decomp as `lambda_`; meaning defined there
calib_ratio = 0.2  # forwarded as `calib_ratio` (presumably the calibration split fraction)
val_ratio = 0.2  # forwarded as `val_ratio` (presumably the validation split fraction)
lr = 1e-1
criterion = nn.CrossEntropyLoss()

# Uncertainty measures evaluated in this experiment. Both entries are of type
# UncertaintyType.RISK, use the Brier score as `g_name`, and set T = 1.0:
#   1. RiskType.BAYES_RISK  with gt_approx = OUTER
#   2. RiskType.EXCESS_RISK with pred_approx = OUTER and gt_approx = INNER
# The number of entries is also used as the feature dimension of the CPFLOW
# model configuration.
UNCERTAINTY_MEASURES = [
    dict(
        type=UncertaintyType.RISK,
        kwargs=dict(
            g_name=GName.BRIER_SCORE,
            risk_type=RiskType.BAYES_RISK,
            gt_approx=ApproximationType.OUTER,
            T=1.0,
        ),
    ),
    dict(
        type=UncertaintyType.RISK,
        kwargs=dict(
            g_name=GName.BRIER_SCORE,
            risk_type=RiskType.EXCESS_RISK,
            pred_approx=ApproximationType.OUTER,
            gt_approx=ApproximationType.INNER,
            T=1.0,
        ),
    ),
]

# Pick the keyword arguments forwarded to get_dataset for the chosen dataset.
if dataset_name == DatasetName.MOONS:
    dataset_params = dict(n_samples=4000, noise=0.1)
elif dataset_name == DatasetName.BLOBS:
    # One centre per class, evenly spaced on a circle around the origin.
    # endpoint=False keeps the last angle strictly below 2*pi so that it does
    # not coincide with the first one.
    radius = 2.0
    angles = np.linspace(0, 2 * np.pi, n_classes, endpoint=False)
    # Rows are (x, y) coordinates: shape (n_classes, 2).
    centers = radius * np.column_stack((np.cos(angles), np.sin(angles)))

    dataset_params = dict(n_samples=4000, cluster_std=1.0, centers=centers)
else:
    raise ValueError(f"Invalid dataset: {dataset_name}")


# Generate the toy dataset with the dataset-specific parameters.
X, y = get_dataset(dataset_name, **dataset_params)

# Shift every input by +3 in each coordinate. This builds a new array instead
# of modifying the object returned by get_dataset in place.
X = X + 3.0

# Per-feature mean of the shifted inputs; used as the single test point for
# the evaluation and in the final plot.
mean_point = X.mean(axis=0)


# Hyperparameters that are only read by the CPFLOW configuration below.
hidden_dim_vqm = 10
n_epochs_vqm = 10
lr_vqm = 1e-4

# Vector-quantile model to evaluate. Each supported model gets its own pair of
# dictionaries: `train_kwargs` and `multidim_params`, both handed to
# eval_unc_decomp.
MULTIDIM_MODEL = VectorQuantileModel.ENTROPIC_OT

if MULTIDIM_MODEL == VectorQuantileModel.ENTROPIC_OT:
    multidim_params = dict(
        target="exp",
        standardize=False,
        fit_mse_params=False,
        eps=0.1,
        max_iters=100,
        tol=1e-6,
        random_state=seed,
    )
    train_kwargs = dict(batch_size=batch_size, device=device)
elif MULTIDIM_MODEL == VectorQuantileModel.OTCP:
    multidim_params = dict(positive=True)
    train_kwargs = dict(batch_size=batch_size, device=device)
elif MULTIDIM_MODEL == VectorQuantileModel.CPFLOW:
    multidim_params = dict(
        # One feature per configured uncertainty measure.
        feature_dimension=len(UNCERTAINTY_MEASURES),
        hidden_dim=hidden_dim_vqm,
        num_hidden_layers=10,
        nblocks=4,
        zero_softplus=False,
        softplus_type="softplus",
        symm_act_first=False,
    )
    train_kwargs = dict(
        lr=lr_vqm,
        num_epochs=n_epochs_vqm,
        batch_size=batch_size,
        device=device,
    )
else:
    raise ValueError(f"Invalid multidim model: {MULTIDIM_MODEL}")


# Run the uncertainty-decomposition experiment at the mean test point.
# `res` is consumed below as an iterable of dict-like records (one summary
# row is built per record).
res = eval_unc_decomp(
    MULTIDIM_MODEL,
    train_kwargs,
    multidim_params,
    # data and evaluation point
    X=X,
    y=y,
    test_point=mean_point,
    calib_ratio=calib_ratio,
    val_ratio=val_ratio,
    # what to measure
    uncertainty_measures=UNCERTAINTY_MEASURES,
    # model / training settings
    input_dim=input_dim,
    n_members=n_members,
    n_epochs=n_epochs,
    batch_size=batch_size,
    lr=lr,
    lambda_=lambda_,
    criterion=criterion,
    device=device,
)

# Gather, across all result records, every key that refers to an uncertainty
# quantity, and sort them so the summary columns have a deterministic order.
# NOTE: the match is a substring test (`in`), not a prefix test, so a marker
# may appear anywhere inside the key.
uncertainty_keys = sorted(
    {
        key
        for record in res
        for key in record.keys()
        if any(
            marker in key
            for marker in (
                "aleatoric_",
                "epistemic_",
                "additive_total",
                "multidim_scores",
            )
        )
    }
)

# One summary row per result record: bookkeeping columns first, followed by
# every collected uncertainty key. A record that lacks a given uncertainty key
# gets NaN in that column.
df_results = pd.DataFrame(
    [
        {
            # The only bookkeeping column whose name differs from its source key.
            "samples_per_class": record["n_samples_per_class"],
            **{
                column: record[column]
                for column in (
                    "total_samples",
                    "avg_val_acc",
                    "val_size",
                    "calib_size",
                )
            },
            **{key: record.get(key, float("nan")) for key in uncertainty_keys},
        }
        for record in res
    ]
)

# Print the whole table without the index, floats rendered with 4 decimals.
print("Summary of Results at Midpoint (all available uncertainty metrics):")
with pd.option_context("display.max_columns", None):
    print(df_results.to_string(float_format="%.4f", index=False))

# Visualise the data set together with the evaluated test point.
plot_data_and_test_point(X, y, mean_point)
