# train_utils.py
import torch
import torch.nn as nn
from tqdm import trange

from ..dynamics.sample_dynamics import simulate_model_on_grid

def logmean(a, b, eps=1e-16):
    """
    Elementwise logarithmic mean L(a, b) = (a - b) / (log a - log b).

    Inputs are floored at ``eps`` and broadcast against each other. The limit
    L(a, a) = a is used when the two logs are numerically equal. The previous
    formula returned 0 there, which wrongly switched off the mobility of
    edges whose endpoints carry the same density.

    Args:
        a, b: broadcastable tensors of non-negative values.
        eps: floor applied to both inputs before taking logs.

    Returns:
        Tensor of the broadcast shape with the logarithmic means.
    """
    a = torch.clamp(a, min=eps)
    b = torch.clamp(b, min=eps)
    dlog = torch.log(a) - torch.log(b)
    # Below ~sqrt(machine eps) the quotient suffers cancellation. The midpoint
    # differs from L(a, b) by a relative O(dlog**2 / 12) there, which is below
    # the working precision.
    tol = torch.finfo(dlog.dtype).eps ** 0.5
    near = dlog.abs() < tol
    # Replace the denominator in the "near" branch so torch.where never sees
    # 0/0. This also keeps gradients finite.
    safe_dlog = torch.where(near, torch.ones_like(dlog), dlog)
    return torch.where(near, 0.5 * (a + b), (a - b) / safe_dlog)

@torch.no_grad()
def one_step_forecast(rho, V, beta, omega, dt=1e-2, eps=1e-16):
    """
    Advance a graph density by one explicit Euler step of the Fokker-Planck flow.

    Every ordered pair (i, j) carries the flux
        omega_ij * logmean(rho_i, rho_j) * max(0, (V_j - V_i) + beta * (log rho_j - log rho_i)).
    That flux is counted as inflow to node i and as outflow from node j.

    Args:
        rho: (n,) current density.
        V: (n,) node potential.
        beta: scalar weight of the log-density term.
        omega: (n, n) conductance matrix, expected to be symmetric.
        dt: Euler step size.
        eps: floor applied to rho before taking its log.

    Returns:
        (n,) updated density, clipped at zero and renormalised to unit sum.
    """
    log_rho = torch.log(torch.clamp(rho, min=eps))
    # Entry [i, j] is "value at j minus value at i".
    potential_gap = V.unsqueeze(0) - V.unsqueeze(1)
    entropy_gap = log_rho.unsqueeze(0) - log_rho.unsqueeze(1)
    drive = potential_gap + beta * entropy_gap

    mobility = logmean(rho.unsqueeze(1), rho.unsqueeze(0))
    flux = omega * mobility * drive.clamp_min(0.0)

    net_change = flux.sum(dim=1) - flux.sum(dim=0)
    updated = (rho + dt * net_change).clamp_min(0.0)
    return updated / updated.sum().clamp_min(1e-16)

@torch.no_grad()
def multi_step_forecast(rho0, V, beta, omega, steps=10, dt=1e-2):
    """
    Apply ``one_step_forecast`` ``steps`` times starting from ``rho0``.

    Returns:
        Tensor of shape (steps + 1, n): row 0 is a copy of ``rho0`` and row k
        is the density after k Euler steps of size ``dt``.
    """
    history = [rho0.clone()]
    for _ in range(steps):
        history.append(one_step_forecast(history[-1], V, beta, omega, dt=dt))
    return torch.stack(history, dim=0)

@torch.no_grad()
def evaluate_by_forecast_tv(
    model, ds_eval, *,
    device="cpu", use_rho_gt=False,
    horizon=5, dt=1e-2, num_snapshots=50, seed=0
):
    """
    Roll out from eval snapshots and compare predicted vs empirical distributions
    using Total Variation distance.

    Start times are drawn uniformly from ``[0, len(ds_eval) - horizon)``, once
    for each of ``num_snapshots`` repetitions. From each start time t:
      - the snapshot density at t is advanced ``horizon`` Euler steps with
        ``multi_step_forecast``, using the conductance
        ``omega_ij = pi_i * K_ij`` built from ``ds_eval.metadata``;
      - step h of the forecast is compared with the snapshot density at t + h.

    Densities are read from "rho_gt" when ``use_rho_gt`` is set and the key is
    present. Otherwise they are read from "rho_hat".

    Returns:
        ``{"TV_h1": ..., ..., "TV_hH": ...}`` averaged over the start times
        (all zeros when ``num_snapshots <= 0``).

    Raises:
        ValueError: if ``num_snapshots > 0`` and ``len(ds_eval) <= horizon``.
    """
    gen = torch.Generator(device="cpu").manual_seed(seed)
    device = torch.device(device)
    n = ds_eval.n
    n_times = len(ds_eval)

    # as_tensor accepts metadata stored either as tensors or as array-likes.
    K = torch.as_tensor(ds_eval.metadata.get("K"), dtype=torch.float64)
    pi = torch.as_tensor(ds_eval.metadata.get("pi"), dtype=torch.float64)
    omega = (K * pi.reshape(n, 1)).to(device)  # omega[i, j] = pi[i] * K[i, j]

    sums = {f"TV_h{h}": 0.0 for h in range(1, horizon + 1)}
    if num_snapshots <= 0:
        return sums
    if n_times <= horizon:
        raise ValueError(
            f"evaluate_by_forecast_tv needs more than horizon={horizon} snapshots, "
            f"got {n_times}"
        )

    def density(snap):
        key = "rho_gt" if (use_rho_gt and "rho_gt" in snap) else "rho_hat"
        return snap[key].to(device)

    # The model is always queried on the same node set, so a single forward
    # pass serves every start time.
    all_nodes = torch.arange(n, dtype=torch.long, device=device)
    V_all, beta_nodes = model(all_nodes)
    beta = beta_nodes.mean()

    for _ in range(num_snapshots):
        t = int(torch.randint(n_times - horizon, (1,), generator=gen, device="cpu").item())
        traj = multi_step_forecast(density(ds_eval[t]), V_all, beta, omega, steps=horizon, dt=dt)  # (H+1, n)
        for h in range(1, horizon + 1):
            sums[f"TV_h{h}"] += tv_distance(density(ds_eval[t + h]), traj[h])

    return {k: v / num_snapshots for k, v in sums.items()}

# @torch.no_grad()
# def tv_distance(p: torch.Tensor, q: torch.Tensor, eps: float = 0.0) -> float:
#     """
#     Total variation distance for discrete distributions:
#       TV = 0.5 * sum |p - q|
#     p, q: (n,) nonnegative, typically sum to 1 (we renormalize just in case).
#     Returns a Python float.
#     """
#     p = p.clamp_min(0); q = q.clamp_min(0)
#     ps = p.sum().clamp_min(1e-16); qs = q.sum().clamp_min(1e-16)
#     p = p / ps; q = q / qs
#     return float(0.5 * (p - q).abs().sum().item())

def _sample_pairs(dataset, batch_size: int, use_rho_gt: bool, gen: torch.Generator):
    """
    Randomly sample (t, x) pairs and gather their training rows:
      - t ~ Uniform{0..T-1}
      - x ~ Uniform over samples_tm[t]

    Args:
        dataset: object exposing ``len()``, ``n``, ``v_mat_seq`` (only its dtype
            is used), ``samples_tm[t]`` and ``dataset[t]``. The last returns a
            dict with "v_mat", "logrho_hat" and optionally "logrho_gt".
        batch_size: number of pairs B.
        use_rho_gt: read "logrho_gt" when the snapshot has it, else "logrho_hat".
        gen: RNG for index sampling. A non-CPU generator is only used to draw
            the seed of a CPU generator.

    Returns:
      t_idx: (B,), x_idx: (B,), v_rows: (B,n), log_rho_rows: (B,n), n
    """
    T, n = len(dataset), dataset.n
    # Index sampling runs on CPU. The CPU generator's seed is drawn from the
    # caller's generator rather than the global RNG, so a seeded run stays
    # reproducible.
    if gen.device.type != "cpu":
        seed = int(torch.randint(2**31 - 1, (1,), generator=gen, device=gen.device).item())
        gen = torch.Generator(device="cpu").manual_seed(seed)

    t_idx = torch.randint(T, (batch_size,), generator=gen, device="cpu")
    x_idx = torch.empty(batch_size, dtype=torch.long)
    v_rows = torch.empty(batch_size, n, dtype=dataset.v_mat_seq.dtype)
    log_rho_rows = torch.empty_like(v_rows)

    # Draw every anchor first, in batch order, so the RNG stream is consumed
    # exactly as before.
    for b in range(batch_size):
        samples_t = dataset.samples_tm[int(t_idx[b])]
        j = torch.randint(samples_t.numel(), (1,), generator=gen, device="cpu").item()
        x_idx[b] = int(samples_t[j])

    # Fetch each distinct snapshot once (many pairs share a t) and keep only
    # one snapshot alive at a time.
    for t in torch.unique(t_idx).tolist():
        snap = dataset[t]  # __getitem__
        log_src = snap["logrho_gt"] if (use_rho_gt and "logrho_gt" in snap) else snap["logrho_hat"]
        for b in (t_idx == t).nonzero(as_tuple=True)[0].tolist():
            v_rows[b] = snap["v_mat"][int(x_idx[b])]
            log_rho_rows[b] = log_src

    return t_idx, x_idx, v_rows, log_rho_rows, n

def train_one_epoch(
    model,
    dataset,
    optimizer,
    *,
    device="cpu",
    batch_size=64,
    steps_per_epoch=1000,
    use_rho_gt=False,
    grad_clip=None,
    seed=0,
    show_progress=True,
    reg_beta='None',
    beta0 = 0.0,
    lam_beta = 0.0
):
    """
    Run one training epoch for models with the ``V_all, beta_nodes = model(all_nodes)`` interface.

    Each step samples ``batch_size`` (t, x) pairs with ``_sample_pairs`` and fits

        pred[b, j] = (V[j] - V[x_b]) + beta * (log rho_t[j] - log rho_t[x_b])

    to ``v_mat[x_b]`` by MSE. Here ``beta`` is the mean of the per-node betas.
    With ``reg_beta == 'ridge'`` the penalty ``lam_beta * (beta - beta0)**2`` is
    added. Any other value (the default 'None', or ``None`` as used by
    ``evaluate``) disables it. Previously any other value left ``loss``
    undefined and crashed.

    Returns:
        Mean loss over ``steps_per_epoch`` steps (0.0 when ``steps_per_epoch <= 0``).
    """
    model.train()
    mse = nn.MSELoss()
    gen = torch.Generator().manual_seed(seed)
    all_nodes = torch.arange(dataset.n, dtype=torch.long, device=device)
    batch_rows = torch.arange(batch_size, device=device)  # row selector for log_rho[b, x_b]

    total_loss = 0.0
    it = trange(steps_per_epoch, desc="[TRAIN] steps", leave=False, disable=not show_progress)
    for _ in it:
        _, x_idx, v_rows, log_rho_rows, _ = _sample_pairs(dataset, batch_size, use_rho_gt, gen)
        x_idx = x_idx.to(device)
        v_rows = v_rows.to(device)
        log_rho = log_rho_rows.to(device)

        optimizer.zero_grad(set_to_none=True)
        V_all, beta_nodes = model(all_nodes)          # (n,), (n,)
        beta = beta_nodes.mean()                      # global scalar
        V_x = V_all[x_idx].unsqueeze(1)               # (B,1)
        log_rho_x = log_rho[batch_rows, x_idx].unsqueeze(1)            # (B,1)
        pred_rows = (V_all.unsqueeze(0) - V_x) + beta * (log_rho - log_rho_x)  # (B,n)

        loss = mse(pred_rows, v_rows)
        if reg_beta == 'ridge':
            # Build a fresh tensor each step; never rebind the beta0 argument.
            beta0_t = torch.tensor(beta0, dtype=beta.dtype, device=device)
            loss = loss + lam_beta * (beta - beta0_t) ** 2

        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        total_loss += float(loss.item())
        it.set_postfix(loss=loss.item())

    return total_loss / max(steps_per_epoch, 1)

def train_one_epoch_v2(
    model,
    dataset,
    optimizer,
    *,
    device="cpu",
    batch_size=64,
    steps_per_epoch=1000,
    use_rho_gt=False,
    grad_clip=None,
    seed=0,
    show_progress=True,
    reg_beta='None',
    beta0 = 0.0,
    lam_beta = 0.0
):
    """
    Run one training epoch for models with the anchor-row interface.

    Model forward is assumed to be:
        gradV_pred, beta_pred = model(x_idx)
    where gradV_pred is (B, n), one predicted row per anchor x. beta_pred is
    either a scalar () or (B,); a (B,) value is averaged to one global scalar.

    Each step fits ``gradV_pred + beta * (logρ - logρ_x)`` to the sampled
    ``v_mat`` rows by MSE. With ``reg_beta == 'ridge'`` the term
    ``lam_beta * (beta - beta0)**2`` is added.

    Returns:
        Mean loss over ``steps_per_epoch`` steps.
    """
    model.train()
    criterion = nn.MSELoss()
    rng = torch.Generator(device="cpu").manual_seed(seed)  # index RNG stays on CPU

    running = 0.0
    progress = trange(steps_per_epoch, desc="[TRAIN] steps", leave=False, disable=not show_progress)
    for _ in progress:
        _, anchors, targets, log_rho_cpu, _ = _sample_pairs(dataset, batch_size, use_rho_gt, rng)
        anchors = anchors.to(device)
        targets = targets.to(device)
        log_rho = log_rho_cpu.to(device)

        optimizer.zero_grad(set_to_none=True)

        grad_rows, beta_raw = model(anchors)
        beta = beta_raw.mean() if beta_raw.ndim > 0 else beta_raw

        positions = torch.arange(anchors.shape[0], device=device)
        anchor_log_rho = log_rho[positions, anchors].unsqueeze(1)  # (B,1)
        prediction = grad_rows + beta * (log_rho - anchor_log_rho)  # (B,n)

        loss = criterion(prediction, targets)
        if reg_beta == 'ridge':
            target_beta = torch.tensor(beta0, dtype=beta.dtype, device=device)
            loss = loss + lam_beta * (beta - target_beta).pow(2)

        loss.backward()
        if grad_clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        step_loss = loss.item()
        running += float(step_loss)
        progress.set_postfix(loss=step_loss)

    return running / max(steps_per_epoch, 1)

@torch.no_grad()
def evaluate(
    model,
    dataset,
    *,
    device="cpu",
    batch_size=64,
    steps=200,
    use_rho_gt=False,
    seed=0,
    show_progress=True,
    reg_beta=None,
    beta0 = 0.0,
    lam_beta = 0.0
):
    """
    Average validation loss for models with the ``V_all, beta_nodes = model(all_nodes)`` interface.

    Each step samples a batch with ``_sample_pairs`` and scores
    ``MSE((V[j] - V[x]) + beta * (logρ[j] - logρ[x]), v_mat[x])``. When
    ``reg_beta == 'ridge'`` the term ``lam_beta * (beta - beta0)**2`` is added.

    Returns:
        Mean loss over ``steps`` batches (0.0 when ``steps <= 0``).
    """
    model.eval()
    mse = nn.MSELoss()
    gen = torch.Generator().manual_seed(seed + 1234)
    all_nodes = torch.arange(dataset.n, dtype=torch.long, device=device)

    # The model is always queried on the same node set without gradients, so
    # its outputs and the ridge penalty do not change between batches
    # (assuming a deterministic model in eval mode). Compute them once.
    V_all, beta_nodes = model(all_nodes)
    beta = beta_nodes.mean()
    V_row = V_all.unsqueeze(0)  # (1, n)
    if reg_beta == 'ridge':
        beta0_t = torch.tensor(beta0, dtype=beta.dtype, device=device)
        loss_reg = lam_beta * (beta - beta0_t) ** 2
    else:
        loss_reg = None
    batch_rows = torch.arange(batch_size, device=device)

    total = 0.0
    it = trange(steps, desc="[EVAL] steps", leave=False, disable=not show_progress)
    for _ in it:
        _, x_idx, v_rows, log_rho_rows, _ = _sample_pairs(dataset, batch_size, use_rho_gt, gen)
        x_idx = x_idx.to(device)
        v_rows = v_rows.to(device)
        log_rho = log_rho_rows.to(device)

        V_x = V_all[x_idx].unsqueeze(1)                      # (B,1)
        log_rho_x = log_rho[batch_rows, x_idx].unsqueeze(1)  # (B,1)
        pred_rows = (V_row - V_x) + beta * (log_rho - log_rho_x)

        loss = mse(pred_rows, v_rows)
        if loss_reg is not None:
            loss = loss + loss_reg

        total += float(loss.item())
        it.set_postfix(loss=loss.item())

    return total / max(steps, 1)

@torch.no_grad()
def epoch_diagnostics(model, dataset, *, device="cpu", use_rho_gt=False, seed=0, k_preview=5):
    """
    Compare the predicted and ground-truth v_row for one random (t, x) pair.

    The time t is drawn uniformly and the anchor x is drawn from the observed
    samples ``dataset.samples_tm[t]``, both using a CPU generator seeded with
    ``seed + 777``. The prediction is
    ``(V - V[x]) + beta * (logρ - logρ[x])`` with ``beta`` the mean per-node beta.

    Returns:
        dict with the keys "t", "x", "l2_error", "l2_pred", "l2_gt", "beta",
        and head/tail previews of at most ``k_preview`` entries. Each tail is
        None when the row has no more than ``k_preview`` entries.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed + 777)
    t = int(torch.randint(len(dataset), (1,), generator=rng, device="cpu").item())
    snap = dataset[t]
    candidates = dataset.samples_tm[t]
    pick = torch.randint(candidates.numel(), (1,), generator=rng, device="cpu").item()
    x = int(candidates[pick])

    target = snap["v_mat"][x].to(device)
    log_key = "logrho_gt" if (use_rho_gt and "logrho_gt" in snap) else "logrho_hat"
    log_rho = snap[log_key].to(device)

    V_all, beta_nodes = model(torch.arange(dataset.n, dtype=torch.long, device=device))
    beta = beta_nodes.mean()
    prediction = (V_all - V_all[x]) + beta * (log_rho - log_rho[x])

    def preview(vec):
        # Leading/trailing entries as Python lists, to keep logs short.
        k = min(k_preview, vec.numel())
        values = vec.detach().cpu().numpy()
        tail = values[-k:].tolist() if vec.numel() > k else None
        return values[:k].tolist(), tail

    pred_head, pred_tail = preview(prediction)
    gt_head, gt_tail = preview(target)

    return {
        "t": t, "x": x,
        "l2_error": torch.linalg.norm(prediction - target).item(),
        "l2_pred": torch.linalg.norm(prediction).item(),
        "l2_gt": torch.linalg.norm(target).item(),
        "pred_head": pred_head,
        "pred_tail": pred_tail,
        "gt_head": gt_head,
        "gt_tail": gt_tail,
        "beta": float(beta.item()),
    }

@torch.no_grad()
def evaluate_v2(
    model,
    dataset,
    *,
    device="cpu",
    batch_size=64,
    steps=200,
    use_rho_gt=False,
    seed=0,
    show_progress=True,
    reg_beta=None,
    beta0=0.0,
    lam_beta=0.0,
):
    """
    Average validation loss for models with the ``gradV_pred, beta_pred = model(x_idx)`` interface.

    Each step draws a batch with ``_sample_pairs`` and scores
    ``MSE(gradV_pred + beta * (logρ - logρ_x), v_rows)``. The term
    ``lam_beta * (beta - beta0)**2`` is added when ``reg_beta == 'ridge'``.
    A per-sample beta of shape (B,) is averaged to one scalar.

    Returns:
        Mean loss over ``steps`` batches (0.0 when ``steps <= 0``).
    """
    model.eval()
    criterion = nn.MSELoss()
    rng = torch.Generator(device="cpu").manual_seed(seed + 1234)

    accumulated = 0.0
    progress = trange(steps, desc="[EVAL] steps", leave=False, disable=not show_progress)
    for _ in progress:
        _, anchors, targets, log_rho_cpu, _ = _sample_pairs(dataset, batch_size, use_rho_gt, rng)
        anchors = anchors.to(device)        # (B,)
        targets = targets.to(device)        # (B,n)
        log_rho = log_rho_cpu.to(device)    # (B,n)

        grad_rows, beta_raw = model(anchors)
        if beta_raw.ndim > 0:
            beta = beta_raw.mean()
        else:
            beta = beta_raw

        positions = torch.arange(anchors.shape[0], device=device)
        anchor_log_rho = log_rho[positions, anchors].unsqueeze(1)
        prediction = grad_rows + beta * (log_rho - anchor_log_rho)

        loss = criterion(prediction, targets)
        if reg_beta == 'ridge':
            target_beta = torch.tensor(beta0, dtype=beta.dtype, device=device)
            loss = loss + lam_beta * (beta - target_beta).pow(2)

        step_loss = loss.item()
        accumulated += float(step_loss)
        progress.set_postfix(loss=step_loss)

    return accumulated / max(steps, 1)


@torch.no_grad()
def epoch_diagnostics_v2(
    model,
    dataset,
    *,
    device="cpu",
    use_rho_gt=False,
    seed=0,
    k_preview=5,
):
    """
    Compare predicted and ground-truth v_row for one random (t, x), using the anchor-row interface.

    The time t is drawn uniformly and the anchor x is drawn from
    ``dataset.samples_tm[t]``, both using a CPU generator seeded with
    ``seed + 777``. The model is called once on ``[x]``, and the prediction
    ``gradV_pred(x) + beta * (logρ - logρ[x])`` is compared with ``v_mat[x]``.

    Returns:
        dict with the keys "t", "x", "l2_error", "l2_pred", "l2_gt", "beta",
        and head/tail previews of at most ``k_preview`` entries. Each tail is
        None when the row has no more than ``k_preview`` entries.
    """
    rng = torch.Generator(device="cpu").manual_seed(seed + 777)
    t = int(torch.randint(len(dataset), (1,), generator=rng, device="cpu").item())
    snap = dataset[t]
    candidates = dataset.samples_tm[t]
    x = int(candidates[torch.randint(candidates.numel(), (1,), generator=rng, device="cpu").item()])

    target = snap["v_mat"][x].to(device)
    log_key = "logrho_gt" if (use_rho_gt and "logrho_gt" in snap) else "logrho_hat"
    log_rho = snap[log_key].to(device)

    grad_row, beta_raw = model(torch.tensor([x], dtype=torch.long, device=device))
    grad_row = grad_row.squeeze(0)  # (1,n) -> (n,)
    beta = beta_raw.mean() if beta_raw.ndim > 0 else beta_raw
    prediction = grad_row + beta * (log_rho - log_rho[x])

    def preview(vec):
        # Leading/trailing entries as Python lists, to keep logs short.
        k = min(k_preview, vec.numel())
        values = vec.detach().cpu().numpy()
        tail = values[-k:].tolist() if vec.numel() > k else None
        return values[:k].tolist(), tail

    pred_head, pred_tail = preview(prediction)
    gt_head, gt_tail = preview(target)

    return {
        "t": t, "x": x,
        "l2_error": torch.linalg.norm(prediction - target).item(),
        "l2_pred": torch.linalg.norm(prediction).item(),
        "l2_gt": torch.linalg.norm(target).item(),
        "pred_head": pred_head,
        "pred_tail": pred_tail,
        "gt_head": gt_head,
        "gt_tail": gt_tail,
        "beta": float(beta.item()),
    }



def _renormalize_pair(P, Q, eps):
    """Cast P and Q to float64, clip negatives to 0 and divide each by (its sum + eps)."""
    def prep(values):
        dist = torch.as_tensor(values, dtype=torch.float64).clamp_min(0)
        return dist / (dist.sum() + eps)
    return prep(P), prep(Q)


@torch.no_grad()
def tv_distance(P, Q, eps: float = 1e-12) -> float:
    """Total variation distance 0.5 * sum |p - q| between the renormalised P and Q."""
    p, q = _renormalize_pair(P, Q, eps)
    return (0.5 * (p - q).abs().sum()).item()

@torch.no_grad()
def l2_distance(P, Q, eps: float = 1e-12) -> float:
    """Euclidean norm ||p - q||_2 between the renormalised P and Q."""
    p, q = _renormalize_pair(P, Q, eps)
    return torch.linalg.vector_norm(p - q, ord=2).item()

@torch.no_grad()
def hellinger_distance(P, Q, eps: float = 1e-12) -> float:
    """Hellinger distance sqrt(0.5 * sum (sqrt(p + eps) - sqrt(q + eps))**2) between the renormalised P and Q."""
    p, q = _renormalize_pair(P, Q, eps)
    root_gap = torch.sqrt(p + eps) - torch.sqrt(q + eps)
    return torch.sqrt(0.5 * (root_gap * root_gap).sum()).item()



def _forecast_empirical_prob(samples, n_states, dtype):
    """Return the empirical distribution of the integer states in ``samples`` over ``n_states`` bins."""
    return torch.bincount(samples, minlength=n_states).to(dtype) / samples.shape[0]


def _forecast_summary(values):
    """Return ``{'mean', 'max', 'min'}`` of a list of floats; every entry is NaN when the list is empty."""
    if not values:
        nan = float("nan")
        return {'mean': nan, 'max': nan, 'min': nan}
    return {'mean': sum(values) / len(values), 'max': max(values), 'min': min(values)}


@torch.no_grad()
def get_forecast_metrics(model, ds_val, metrics = ('tv', 'l2', 'hellinger', 'random_check', 'gradV'), device='cpu', dtype=torch.float64):
    """
    Simulate ``model`` from the first observed sample set and compare the result with ``ds_val``.

    Sample paths come from ``simulate_model_on_grid`` on ``T + 1`` equally
    spaced times in [0, 1], started from ``ds_val.samples_tm[0]``. For each
    t in 1..T-1, the empirical state distribution of the observed samples
    (``ds_val.samples_tm[t]``) is compared with that of the simulated samples.

    Args:
        model: model passed to ``simulate_model_on_grid``. The 'gradV' metric
            also calls ``model.get_potential()``.
        ds_val: dataset exposing ``T``, ``dtype``, ``samples_tm`` (integer
            states indexed by time) and ``metadata``. The metadata holds 'K'
            and 'pi', plus 'V' and 'beta' for 'gradV'.
        metrics: metric names to compute.
            - 'tv', 'l2' and 'hellinger' add summary entries.
            - 'random_check' and 'gradV' only print diagnostics for 5 random times.
            - Unknown names are ignored.
        device: device for K, pi, the time grid and the observed samples.
        dtype: dtype of the empirical distributions.

    Returns:
        dict with any of the keys 'TV', 'L2' and 'Hellinger'. Each maps to
        ``{'mean', 'max', 'min'}`` over t = 1..T-1 (NaN when T <= 1).
    """
    def as_dataset_tensor(value):
        # Metadata may hold tensors or array-likes.
        if isinstance(value, torch.Tensor):
            return value.to(dtype=ds_val.dtype, device=device)
        return torch.tensor(value, dtype=ds_val.dtype, device=device)

    T = ds_val.T
    K = as_dataset_tensor(ds_val.metadata.get('K', None))
    pi = as_dataset_tensor(ds_val.metadata.get('pi', None))
    n_states = K.size(0)

    times = torch.linspace(0.0, 1.0, T + 1, device=device)
    states = ds_val.samples_tm.to(device)
    initial_states = states[0]

    # NOTE(review): the simulation always runs in float64, independent of `dtype`; confirm intended.
    S = simulate_model_on_grid(model, K, pi, None, times, initial_states, eval_point="left", dtype=torch.float64, seed=None, gt_prob=False)

    # Compute the empirical distributions once and share them across all metrics.
    steps = range(1, T)
    probs_gt = {t: _forecast_empirical_prob(states[t], n_states, dtype) for t in steps}
    probs_fc = {t: _forecast_empirical_prob(S[t], n_states, dtype) for t in steps}

    distances = {
        'tv': ('TV', tv_distance),
        'l2': ('L2', l2_distance),
        'hellinger': ('Hellinger', hellinger_distance),
    }
    metric_dict = {}
    for metric in metrics:
        if metric in distances:
            key, distance = distances[metric]
            metric_dict[key] = _forecast_summary([distance(probs_gt[t], probs_fc[t]) for t in steps])

        elif metric == 'random_check':
            for _ in range(5):
                t = torch.randint(1, T, (1,)).item()
                print(f"Probabilities at time {t}: \n GT samples prob {probs_gt[t]} \n FC samples prob {probs_fc[t]}" )

        elif metric == 'gradV':
            # Presumably gradV is (n, n) and beta a scalar tensor; confirm against the model.
            gradV, beta = model.get_potential()
            gradV = gradV.to('cpu')
            beta = beta.to('cpu')
            V_gt = torch.as_tensor(ds_val.metadata.get('V', None)).float()
            beta_gt = ds_val.metadata.get('beta', None)
            # [n,n] with [i,j] = V[j] - V[i]. This matches the training targets,
            # where row x, column j holds V[j] - V[x]. The previous code
            # computed V[i] - V[j], the opposite sign.
            gradV_gt = (V_gt.unsqueeze(0) - V_gt.unsqueeze(1)).to('cpu')

            for _ in range(5):
                t = torch.randint(1, T, (1,)).item()
                log_prob = torch.log(probs_gt[t].to('cpu'))
                print(f"Vectors at time {t}: \n GT vector {gradV_gt + beta_gt * log_prob} \n FC vector {gradV + beta * log_prob}" )

    return metric_dict
