import os
import json
import copy
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Any, List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.preprocessing import StandardScaler
from scipy.interpolate import BSpline

try:
    from .evaluate import hat_intervals_from_selected_basis
except ImportError:
    from evaluate import hat_intervals_from_selected_basis


# ============================================================
# B-spline utilities
# ============================================================
def bspline_basis_functions(knots: np.ndarray, degree: int, num_basis: int, x: np.ndarray) -> np.ndarray:
    """
    Evaluate every B-spline basis function of a knot vector at the points x.

    Column i of the result is the spline whose coefficient vector is the i-th
    unit vector, so the returned matrix has shape (len(x), num_basis).
    The splines are built with extrapolate=False, hence points outside the
    base interval of the knot vector evaluate to NaN.
    """
    points = np.asarray(x, dtype=float)
    design = np.zeros((len(points), num_basis), dtype=float)

    # Each row of the identity matrix is the coefficient vector that isolates one basis function.
    for col, unit_coeffs in enumerate(np.eye(num_basis, dtype=float)):
        design[:, col] = BSpline(knots, unit_coeffs, degree, extrapolate=False)(points)

    return design


# ============================================================
# Model
# ============================================================
class MyNetReLU(nn.Module):
    """
    Fully connected ReLU regressor with a single scalar output.

    Layout: l1 (feature_num -> hidden_dim), then depth - 1 square layers
    (hidden_dim -> hidden_dim) kept in hiddens, then out (hidden_dim -> 1).
    A ReLU follows every layer except out.
    """

    def __init__(self, feature_num: int, hidden_dim: int = 64, depth: int = 3):
        super().__init__()
        if depth < 1:
            raise ValueError("depth must be >= 1")

        self.feature_num, self.hidden_dim, self.depth = int(feature_num), int(hidden_dim), int(depth)
        width = self.hidden_dim

        # Layers are instantiated input-to-output, which fixes the order in
        # which their parameters are initialised and registered.
        self.l1 = nn.Linear(self.feature_num, width)
        square_layers = []
        for _ in range(1, self.depth):
            square_layers.append(nn.Linear(width, width))
        self.hiddens = nn.ModuleList(square_layers)
        self.out = nn.Linear(width, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply l1 and every hidden layer with ReLU, then the linear output layer."""
        activation = x
        for layer in (self.l1, *self.hiddens):
            activation = F.relu(layer(activation))
        return self.out(activation)


# ============================================================
# Autograd helpers (used by Laplace evidence)
# ============================================================
def gradient(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph: bool = False) -> torch.Tensor:
    """
    Differentiate outputs w.r.t. inputs and return one flat 1-D tensor.

    inputs may be a single tensor or an iterable of tensors.  Inputs that do
    not participate in the graph contribute zeros of their own shape, and the
    per-input gradients are concatenated in the order the inputs were given.
    """
    wrt = [inputs] if torch.is_tensor(inputs) else list(inputs)

    raw = torch.autograd.grad(
        outputs,
        wrt,
        grad_outputs,
        allow_unused=True,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )

    pieces = []
    for grad, reference in zip(raw, wrt):
        if grad is None:
            # autograd reports unused inputs as None; treat them as zero gradient.
            grad = torch.zeros_like(reference)
        pieces.append(grad.contiguous().view(-1))
    return torch.cat(pieces)


def hessian(output, inputs, out: Optional[torch.Tensor] = None, allow_unused: bool = False, create_graph: bool = False) -> torch.Tensor:
    """
    Assemble the dense Hessian of a scalar output w.r.t. the tensors in inputs.

    inputs may be a single tensor or an iterable of tensors; rows/columns follow
    the flattened inputs in the given order.  Only the upper triangle is
    differentiated: each row is then mirrored into the matching column.  Values
    are *added* into out, which is allocated as zeros (n x n) when not supplied.
    """
    params = [inputs] if torch.is_tensor(inputs) else list(inputs)
    total = sum(p.numel() for p in params)
    if out is None:
        out = output.new_zeros(total, total)

    flat_pos = 0  # row/column index in the flattened parameter vector
    for start, param in enumerate(params):
        # First-order gradient, kept differentiable so it can be differentiated again.
        (first_order,) = torch.autograd.grad(output, param, create_graph=True, allow_unused=allow_unused)
        if first_order is None:
            first_order = torch.zeros_like(param)
        first_order = first_order.contiguous().view(-1)

        # Entries to the left of the diagonal were already filled by mirroring,
        # so only this parameter and the ones after it are differentiated.
        remaining = params[start:]
        for offset in range(param.numel()):
            component = first_order[offset]
            if component.requires_grad:
                upper = gradient(component, remaining, retain_graph=True, create_graph=create_graph)[offset:]
            else:
                upper = component.new_zeros(sum(t.numel() for t in remaining) - offset)

            out[flat_pos, flat_pos:].add_(upper.type_as(out))
            if flat_pos + 1 < total:
                # Mirror the off-diagonal part of the row into the column below the diagonal.
                out[flat_pos + 1:, flat_pos].add_(upper[1:].type_as(out))
            del upper, component
            flat_pos += 1

        del first_order

    return out


# ============================================================
# Config
# ============================================================
@dataclass
class TrainConfig:
    """Hyper-parameters read by train_model() and evaluate_model()."""

    # Learning rate handed to torch.optim.SGD.
    step_lr: float = 0.005
    # Scales the data term as mse / (2 * sigma) + 0.5 * log(sigma); also used as the
    # variance of the Gaussian prior on every parameter except l1.weight.
    sigma: float = 1.0
    # Mixture weight of the prior_sigma_1 component in the two-component prior on
    # first-layer columns (the other component gets 1 - lambda_n).
    lambda_n: float = 1e-5
    # The two prior scales enter the formulas as variances (terms of the form 1 / (2 * s)).
    prior_sigma_0: float = 1e-5
    prior_sigma_1: float = 1e-3
    # Upper bound on the number of SGD iterations.
    max_loop: int = 80001
    # Early-stopping budget in iterations; the counter advances by show_information
    # at each evaluation that does not improve the validation loss.
    patience: int = 3000
    # Mini-batch size; values <= 0 or larger than the training set mean full batch.
    subn: int = 500
    # Device on which the network and helper tensors are created.
    device: torch.device = torch.device("cpu")
    # Evaluate, log and check early stopping every this many iterations.
    show_information: int = 100

    # Architecture of MyNetReLU: hidden width and number of hidden layers.
    hidden_dim: int = 64
    depth: int = 3


# ============================================================
# Core training
# ============================================================
def _compute_threshold_constants(lambda_n: float, prior_sigma_0: float, prior_sigma_1: float, w: int) -> Dict[str, float]:
    lam = float(lambda_n)
    s0 = float(prior_sigma_0)
    s1 = float(prior_sigma_1)
    w = int(w)

    # Use stable log(1-lam) when lam is tiny.
    log_lam = np.log(lam)
    log_1ml = np.log1p(-lam)

    c1 = (log_lam - log_1ml) + 0.5 * w * (np.log(s0) - np.log(s1))
    c2 = 0.5 / s0 - 0.5 / s1

    if c2 == 0:
        raise ValueError("Invalid prior sigmas: prior_sigma_0 equals prior_sigma_1 leads to c2=0.")

    threshold = -c1 / c2
    return {"c1": float(c1), "c2": float(c2), "threshold": float(threshold)}


def train_model(
    x_train: torch.Tensor, y_train: torch.Tensor,
    x_val: torch.Tensor, y_val: torch.Tensor,
    x_test: torch.Tensor, y_test: torch.Tensor,
    config: TrainConfig = TrainConfig(),
):
    """
    Fit a MyNetReLU with SGD, adding prior gradients by hand after backprop.

    The first-layer weight matrix receives a column-wise two-component
    (spike-and-slab style) prior gradient; every other parameter receives a
    Gaussian prior gradient with variance config.sigma.  Every
    config.show_information iterations the full train/val/test sets are scored,
    the current column selection is recorded, and early stopping is checked
    against config.patience.  The weights with the lowest validation loss are
    restored before returning.

    Returns: (net, history, trained_pi_bool, threshold_float) where history
    holds the lists "train_loss", "val_loss", "test_mse", "test_mae",
    "test_rmse" and "pi_path", and trained_pi_bool is the column selection at
    the best validation checkpoint.
    """
    history = {key: [] for key in ("train_loss", "val_loss", "test_mse", "test_mae", "test_rmse", "pi_path")}

    device = config.device
    net = MyNetReLU(x_train.shape[1], hidden_dim=config.hidden_dim, depth=config.depth).to(device)
    loss_func = nn.MSELoss()
    optim = torch.optim.SGD(net.parameters(), lr=config.step_lr)
    sigma = torch.tensor([config.sigma], dtype=torch.float32, device=device)

    column_len = int(net.l1.weight.shape[0])
    NTrain = int(x_train.shape[0])
    TotalP = int(x_train.shape[1])

    consts = _compute_threshold_constants(config.lambda_n, config.prior_sigma_0, config.prior_sigma_1, column_len)
    c1, c2, threshold = (consts[key] for key in ("c1", "c2", "threshold"))

    print("threshold:", threshold)

    max_loop = int(config.max_loop)
    show_information = int(config.show_information)
    patience_iter = int(config.patience)

    # Out-of-range batch sizes fall back to full-batch training.
    subn = int(config.subn)
    if not 0 < subn <= NTrain:
        subn = NTrain

    best_val_loss = float("inf")
    best_model_state = None
    iters_no_improve = 0

    pi_snapshots = np.zeros((max_loop // show_information + 1, TotalP), dtype=float)
    trained_pi = np.zeros(TotalP, dtype=bool)

    for it in range(max_loop):
        if subn < NTrain:
            batch = np.random.choice(NTrain, size=subn, replace=False)
        else:
            batch = np.arange(NTrain)

        net.zero_grad(set_to_none=True)

        data_term = loss_func(net(x_train[batch]), y_train[batch])
        objective = data_term / (2.0 * sigma) + 0.5 * sigma.log()
        objective.backward()

        # Add the prior gradients on top of the data gradients.
        with torch.no_grad():
            for name, p in net.named_parameters():
                if p.grad is None:
                    continue

                if name != "l1.weight":
                    # Gaussian prior with variance sigma.
                    p.grad.sub_(p.mul(-1.0 / config.sigma).div(NTrain))
                    continue

                sq_norm = torch.norm(p, p=2, dim=0).pow(2)  # one value per input column
                # mix = 1 / (1 + exp(c1 + c2 * ||w_j||^2))
                mix = 1.0 / (1.0 + torch.exp(sq_norm.mul(c2).add(c1)))
                pull = p.div(-config.prior_sigma_0).mul(mix) + p.div(-config.prior_sigma_1).mul(1.0 - mix)
                p.grad.sub_(pull.div(NTrain))

        optim.step()

        if it % show_information != 0:
            continue

        slot = it // show_information
        with torch.no_grad():
            tr_loss = loss_func(net(x_train), y_train).item()
            va_loss = loss_func(net(x_val), y_val).item()

            te_pred = net(x_test)
            mse = loss_func(te_pred, y_test).item()
            mae = torch.mean(torch.abs(te_pred - y_test)).item()
            rmse = float(np.sqrt(mse))

            scores = (
                ("train_loss", tr_loss),
                ("val_loss", va_loss),
                ("test_mse", mse),
                ("test_mae", mae),
                ("test_rmse", rmse),
            )
            for key, value in scores:
                history[key].append(float(value))

            # A column counts as selected when its squared norm exceeds the threshold.
            col_norm2 = torch.norm(net.l1.weight.detach(), p=2, dim=0).pow(2).cpu().numpy()
            pi = (col_norm2 > threshold).astype(float)
            pi_snapshots[slot] = pi
            history["pi_path"].append(pi)

            print(f"iter={it:>6d}  train={tr_loss:.6g}  val={va_loss:.6g}  test_mse={mse:.6g}  test_mae={mae:.6g}")

        if va_loss < best_val_loss:
            best_val_loss, iters_no_improve = va_loss, 0
            best_model_state = copy.deepcopy(net.state_dict())
            trained_pi = pi_snapshots[slot].astype(bool)
        else:
            iters_no_improve += show_information

        if iters_no_improve >= patience_iter:
            print("Early stopping triggered.")
            break

    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    return net, history, trained_pi, float(threshold)


# ============================================================
# Laplace approximation of the log-evidence
# ============================================================
def evaluate_model(net: nn.Module, final_pi: np.ndarray, x_train: torch.Tensor, y_train: torch.Tensor, config: TrainConfig = TrainConfig()) -> float:
    """
    Laplace approximation of the log-evidence, restricted to selected first-layer columns.

    The per-sample objective is the scaled training loss plus the negative
    log-priors divided by the number of training samples.  Its Hessian is
    approximated by the dense loss Hessian plus a diagonal prior term, and only
    the rows/columns of selected first-layer columns (final_pi) and of all
    other parameters are kept.

    Args:
        net: trained network exposing l1.weight as its first parameter.
        final_pi: per-input-column selection (truthy = selected), length = number of input features.
        x_train, y_train: training tensors on config.device.
        config: supplies sigma, lambda_n, prior_sigma_0, prior_sigma_1 and device.

    Returns:
        The approximate log-evidence as a Python float (also printed).
    """
    net.eval()
    device = config.device
    sigma = torch.tensor([config.sigma], dtype=torch.float32, device=device)
    loss_func = nn.MSELoss()

    yhat = net(x_train)
    loss = loss_func(yhat, y_train)
    loss = loss.div(2.0 * sigma).add(sigma.log().mul(0.5))

    w = int(net.l1.weight.shape[0])
    NTrain = int(x_train.shape[0])

    selected = np.asarray(final_pi).astype(bool)
    col_mask = torch.from_numpy(selected).to(device)

    lambda_t = torch.tensor(config.lambda_n, device=device, dtype=torch.float32)
    s1_t = torch.tensor(config.prior_sigma_1, device=device, dtype=torch.float32)
    s0_t = torch.tensor(config.prior_sigma_0, device=device, dtype=torch.float32)

    prior1 = torch.tensor(0.0, device=device)
    prior2 = torch.tensor(0.0, device=device)

    for name, p in net.named_parameters():
        if name == "l1.weight":
            # Log of the two-component column prior, combined with logsumexp for stability.
            col2 = torch.norm(p, p=2, dim=0).pow(2)
            a = torch.log(lambda_t) - col2.div(2.0 * s1_t) - 0.5 * (w * torch.log(s1_t))
            b = torch.log(1.0 - lambda_t) - col2.div(2.0 * s0_t) - 0.5 * (w * torch.log(s0_t))
            log_mix = torch.logsumexp(torch.stack([a, b]), dim=0)
            prior1 = prior1 - log_mix[col_mask].sum()
        else:
            # Gaussian log-density (without the 2*pi constant), evaluated directly
            # in log space: log(exp(.)) would underflow to -inf for large weights.
            log_density = (p ** 2).div(-2.0 * sigma) - 0.5 * sigma.log()
            prior2 = prior2 - log_density.sum()

    other_params = [p for n, p in net.named_parameters() if n != "l1.weight"]
    D_other = sum(p.numel() for p in other_params)

    prior1 = prior1.div(NTrain)
    prior2 = prior2.div(NTrain)
    obj = loss + prior1 + prior2

    # Loss Hessian over all parameters, in net.parameters() order (l1.weight first).
    H_loss = hessian(loss, list(net.parameters()), allow_unused=True, create_graph=False)

    # Prior Hessian diagonal.  The curvature of a zero-mean Gaussian negative
    # log-density is the reciprocal of its variance, so both parts use 1 / variance
    # (scaled by 1 / NTrain like the rest of the objective).
    lambda_h1 = (1.0 / config.prior_sigma_1) / NTrain
    d1 = net.l1.weight.numel()
    prior_diag_part1 = torch.full((d1,), float(lambda_h1), device=device)

    lambda_h2 = (1.0 / config.sigma) / NTrain
    prior_diag_part2 = torch.full((D_other,), float(lambda_h2), device=device)

    prior_diag = torch.cat([prior_diag_part1, prior_diag_part2], dim=0)

    # l1.weight is flattened row by row, so the per-column mask repeats once per row.
    part1_mask = torch.from_numpy(np.tile(selected, w)).to(device)
    part2_mask = torch.ones(D_other, dtype=torch.bool, device=device)
    mask_all = torch.cat([part1_mask, part2_mask], dim=0)
    idx = mask_all.nonzero(as_tuple=False).squeeze(1)

    # Move everything to the CPU before the eigen-decomposition and release device memory.
    idx_cpu = idx.detach().cpu()
    H_loss_cpu = H_loss.detach().cpu()
    prior_diag_cpu = prior_diag.detach().cpu()

    del H_loss
    del prior_diag
    torch.cuda.empty_cache()

    H_sel = H_loss_cpu[idx_cpu][:, idx_cpu] + torch.diag(prior_diag_cpu[idx_cpu])
    d = int(H_sel.shape[0])

    # H_sel is symmetric (hessian() mirrors every row into its column and the prior
    # term is diagonal), so the symmetric solver applies and returns real eigenvalues.
    eigvals = torch.linalg.eigvalsh(H_sel)
    hessian_part = 0.5 * d * np.log(2.0 * np.pi) - 0.5 * d * np.log(NTrain) - 0.5 * eigvals.abs().log().sum().item()

    neglogpost = obj.mul(NTrain).item()
    log_evidence = -neglogpost + hessian_part
    print("Laplace approx. log-evidence:", log_evidence)
    return float(log_evidence)


# ============================================================
def train_one(
    X: np.ndarray,
    y: np.ndarray,
    *,
    seed: Optional[int] = None,
    bundle_dir: Path,
    num_repeats: int = 5,
    return_hist: bool = False,
    train_cfg: Optional[TrainConfig] = None,
    **cfg,
) -> Dict[str, Any]:
    """
    End-to-end run: smooth X, project onto a B-spline basis, cross-validate, select, ensemble.

    Steps:
      1. Smooth each row of X with a least-squares B-spline fit (fixed degree 4, 54 basis
         functions) and evaluate the fit on a 1000-point grid on [0, 1].
      2. Hold out the last 10% of the samples as the test set.
      3. For every candidate projection_num, integrate the smoothed curves against a
         B-spline basis and run num_repeats-fold cross-validation; each fold trains a
         network with train_model() and scores it with evaluate_model().
      4. Keep the projection_num with the highest mean log-evidence; inside that group
         the fold with the lowest validation loss supplies the selected columns, which
         are turned into intervals by hat_intervals_from_selected_basis().
      5. Average the test predictions of all folds of the winning group (original y scale).

    Args:
        X: array of shape (N, TotalP); columns are treated as equally spaced on [0, 1].
        y: responses, indexable by sample.
        seed: base seed; None is treated as 0.  Seed 0 searches projection_num over
            [55, 60, 70, 80] and writes the winner to my_para.json in bundle_dir.parent;
            any other seed reads that file (falling back to 70 if it is missing).
        bundle_dir: directory whose parent holds my_para.json.
        num_repeats: number of cross-validation folds (at least 2).
        return_hist: if True, attach the training history of the best fold.
        train_cfg: optional training configuration; it is deep-copied and its device is
            overridden.  When omitted, defaults are used and cfg may supply hidden_dim,
            depth and show_information.

    Returns:
        Dict with keys metric (ensemble test MSE), metric_name, seed, repeats and best_run.

    Raises:
        ValueError: if num_repeats < 2 or the 10% hold-out would be empty.
        RuntimeError: if the smoothing basis does not fit X (TotalP <= 54).
    """
    if num_repeats < 2:
        # With a single fold there is nothing left to train on.
        raise ValueError("num_repeats must be >= 2 for cross-validation.")

    N, TotalP = X.shape
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    t_obs = np.linspace(0.0, 1.0, TotalP)
    fine_t = np.linspace(0.0, 1.0, 1000)

    # ---- Denoise X via B-spline fit (using the best parameters in our setting) ----
    K_fit_list = [54]
    degree_fit_list = [4]

    best_fit = None
    best_fit_cache = None

    for degree_fit in degree_fit_list:
        for K_fit in K_fit_list:
            if K_fit >= TotalP:
                continue
            inner = K_fit - degree_fit + 1
            if inner <= 1:
                continue

            # Clamped knot vector: boundary knots repeated so the basis spans [min, max].
            knots_fit = np.concatenate([
                np.full(degree_fit, t_obs.min()),
                np.linspace(t_obs.min(), t_obs.max(), inner),
                np.full(degree_fit, t_obs.max()),
            ])

            Phi = bspline_basis_functions(knots_fit, degree_fit, K_fit, t_obs)
            coeffs = np.linalg.lstsq(Phi, X.T, rcond=None)[0].T
            X_fit = coeffs @ Phi.T

            rss_total = np.sum((X - X_fit) ** 2)
            rss_mean = rss_total / float(N)

            # GCV-style score with trace(S) taken as the number of basis functions.
            trS = float(K_fit)
            denom = max(TotalP - trS, 1e-8)
            gcv = rss_mean / (denom ** 2)

            if (best_fit is None) or (gcv < best_fit["gcv"]):
                B_fine = bspline_basis_functions(knots_fit, degree_fit, K_fit, fine_t)
                best_fit = {"K": K_fit, "degree": degree_fit, "gcv": float(gcv), "rss_mean": float(rss_mean)}
                best_fit_cache = {"coeffs": coeffs, "B_fine": B_fine, "knots": knots_fit}

    if best_fit_cache is None:
        raise RuntimeError("No valid (degree, K) pair for denoising.")

    x_denoised = best_fit_cache["coeffs"] @ best_fit_cache["B_fine"].T  # (N, Tf)

    # ---- Split CV/test: the last 10% of the rows are the test set ----
    N_test = int(0.10 * N)
    if N_test < 1:
        raise ValueError("train_one() needs at least 10 samples so the 10% hold-out test set is non-empty.")
    N_cv = N - N_test
    idx_cv = np.arange(N_cv)
    idx_test = np.arange(N_cv, N)

    y_cv = y[idx_cv]
    y_test_data = y[idx_test]

    base_seed = 0 if seed is None else int(seed)

    scenario_dir = Path(bundle_dir).parent
    para_path = scenario_dir / "my_para.json"
    if base_seed == 0:
        projection_num_list = [55, 60, 70, 80]
    elif para_path.exists():
        with open(para_path, "r", encoding="utf-8") as f:
            cfg_saved = json.load(f)
        projection_num_list = [int(cfg_saved["projection_num"])]
    else:
        print("Using default projection num: 70")
        projection_num_list = [70]

    degree_projection = 4
    if train_cfg is None:
        train_cfg = TrainConfig(
            step_lr=0.001,
            sigma=1.0,
            lambda_n=1e-5,
            prior_sigma_0=1e-5,
            prior_sigma_1=2e-3,
            max_loop=80001,
            patience=3000,
            subn=64,
            device=device,
            hidden_dim=cfg.get("hidden_dim", 64),
            depth=cfg.get("depth", 3),
            show_information=cfg.get("show_information", 100),
        )
    else:
        # Work on a copy so the caller's configuration is never mutated.
        train_cfg = copy.deepcopy(train_cfg)
        train_cfg.device = device

    # Quadrature weights on the fine grid; they do not depend on the projection basis.
    wts = np.gradient(fine_t)

    candidates: List[Dict[str, Any]] = []

    for projection_num in projection_num_list:
        knots = np.concatenate((
            np.zeros(degree_projection),
            np.linspace(0.0, 1.0, projection_num - degree_projection + 1),
            np.ones(degree_projection),
        ))
        B_fine_pro = bspline_basis_functions(knots, degree_projection, projection_num, fine_t)

        # Numerical integral of each smoothed curve against each basis function.
        W = B_fine_pro * wts[:, None]
        projection = x_denoised @ W

        x_cv = projection[idx_cv]
        x_test_data = projection[idx_test]

        # Same fold assignment for every projection_num (seeded by base_seed).
        indices = np.arange(N_cv)
        np.random.seed(base_seed)
        np.random.shuffle(indices)
        folds = np.array_split(indices, num_repeats)

        for r in range(num_repeats):
            val_idx = folds[r]
            train_idx = np.concatenate([folds[i] for i in range(num_repeats) if i != r])

            run_seed = base_seed * 1000 + r
            np.random.seed(run_seed)
            torch.manual_seed(run_seed)

            x_train = x_cv[train_idx]
            x_val = x_cv[val_idx]
            x_test = x_test_data

            y_train = y_cv[train_idx]
            y_val = y_cv[val_idx]
            y_test = y_test_data

            # One global mean/std (over all features) estimated on the training fold only.
            proj_mean = x_train.mean()
            proj_std = x_train.std()
            if proj_std == 0:
                # Constant features: scale by 1 (as StandardScaler does) instead of dividing by zero.
                proj_std = 1.0
            x_train = (x_train - proj_mean) / proj_std
            x_val = (x_val - proj_mean) / proj_std
            x_test = (x_test - proj_mean) / proj_std

            scaler_y = StandardScaler()
            y_train_std = scaler_y.fit_transform(y_train.reshape(-1, 1))
            y_val_std = scaler_y.transform(y_val.reshape(-1, 1))
            y_test_std = scaler_y.transform(y_test.reshape(-1, 1))
            sigma_y = float(scaler_y.scale_[0])
            mu_y = float(scaler_y.mean_[0])

            x_train_t = torch.tensor(x_train, dtype=torch.float32, device=device)
            x_val_t = torch.tensor(x_val, dtype=torch.float32, device=device)
            x_test_t = torch.tensor(x_test, dtype=torch.float32, device=device)

            y_train_t = torch.tensor(y_train_std, dtype=torch.float32, device=device)
            y_val_t = torch.tensor(y_val_std, dtype=torch.float32, device=device)
            y_test_t = torch.tensor(y_test_std, dtype=torch.float32, device=device)

            net1, hist1, final_pi, threshold = train_model(x_train_t, y_train_t, x_val_t, y_val_t, x_test_t, y_test_t, train_cfg)

            # Test metrics recorded at the checkpoint with the lowest validation loss.
            best_idx = int(np.argmin(hist1["val_loss"]))
            mse_test_std = float(hist1["test_mse"][best_idx])
            mae_test_std = float(hist1["test_mae"][best_idx])
            rmse_test_std = float(hist1["test_rmse"][best_idx])

            # Undo the y standardisation.
            mse_test1 = mse_test_std * (sigma_y ** 2)
            mae_test1 = mae_test_std * sigma_y
            rmse_test1 = rmse_test_std * sigma_y

            net1.eval()
            with torch.no_grad():
                yhat_test_std = net1(x_test_t).detach().cpu().numpy().ravel()
            yhat_test_orig = yhat_test_std * sigma_y + mu_y

            log_evd1 = evaluate_model(net1, final_pi, x_train_t, y_train_t, train_cfg)

            # Fraction of logged checkpoints at which each column was selected.
            pi_path = np.asarray(hist1["pi_path"], dtype=float)  # (M, K)
            pip = pi_path.mean(axis=0)

            sel_cols = np.where(final_pi.astype(bool))[0]

            col_norm2 = torch.norm(net1.l1.weight.detach(), p=2, dim=0).pow(2).cpu().numpy()

            pack = dict(
                projection_num=int(projection_num),
                degree_projection=int(degree_projection),
                repeat_id=int(r),
                seed=int(run_seed),
                mse_test_std=float(mse_test_std),
                mse_test1=float(mse_test1),
                mae_test1=float(mae_test1),
                rmse_test1=float(rmse_test1),
                column_norm2=col_norm2,
                log_evd1=float(log_evd1),
                val_loss1=float(min(hist1["val_loss"])),
                threshold=float(threshold),
                sel_cols=sel_cols,
                pip=pip,
                yhat_test_orig=yhat_test_orig.tolist(),
            )
            if return_hist:
                pack["hist1"] = hist1

            candidates.append(pack)

    if len(candidates) == 0:
        raise RuntimeError("No valid candidates in train_one().")

    # ---- Model selection: best mean log-evidence across folds ----
    grouped = defaultdict(list)
    for c in candidates:
        grouped[(c["projection_num"], c["degree_projection"])].append(c)

    summary = []
    for (proj, deg), group in grouped.items():
        avg_log_evd1 = float(np.mean([g["log_evd1"] for g in group]))
        summary.append({"projection_num": proj, "degree_projection": deg, "avg_log_evd1": avg_log_evd1, "group": group})

    summary_sorted = sorted(summary, key=lambda s: s["avg_log_evd1"], reverse=True)
    best_summary = summary_sorted[0]
    best_group = best_summary["group"]
    best_pack = min(best_group, key=lambda c: c["val_loss1"])

    if base_seed == 0:
        # Persist the selected projection_num so runs with other seeds can reuse it.
        with open(para_path, "w", encoding="utf-8") as f:
            json.dump({"projection_num": int(best_pack["projection_num"])}, f, ensure_ascii=False, indent=2)

    intervals_hat = hat_intervals_from_selected_basis(best_pack["sel_cols"], best_pack["projection_num"], best_pack["degree_projection"])

    # ---- Ensemble predictions on original scale ----
    y_test_orig = y_test_data.reshape(-1).astype(float)
    yhat_mat = np.vstack([np.asarray(g["yhat_test_orig"], dtype=float) for g in best_group])
    yhat_ens = yhat_mat.mean(axis=0)

    mse_ens = float(np.mean((yhat_ens - y_test_orig) ** 2))
    rmse_ens = float(np.sqrt(mse_ens))
    mae_ens = float(np.mean(np.abs(yhat_ens - y_test_orig)))

    run_info = dict(
        projection_num=int(best_pack["projection_num"]),
        degree_projection=int(best_pack["degree_projection"]),
        intervals_hat=[[float(a), float(b)] for a, b in intervals_hat],
        mse_test_ens=mse_ens,
        rmse_test_ens=rmse_ens,
        mae_test_ens=mae_ens,
        log_evidence1=float(best_pack["log_evd1"]),
        sel_cols=np.asarray(best_pack["sel_cols"], dtype=int).tolist(),
        column_norm2=np.asarray(best_pack["column_norm2"], dtype=float).tolist(),
    )
    if return_hist and "hist1" in best_pack:
        run_info["hist1"] = best_pack["hist1"]

    return dict(
        metric=mse_ens,
        metric_name="mse_test_ens",
        seed=seed,
        repeats=num_repeats,
        best_run=run_info,
    )
