# model/eamc_utils.py
# -*- coding: utf-8 -*-
from typing import Dict, Tuple
import torch
import torch.nn.functional as F

@torch.no_grad()
def estimate_lengths_from_segments(segments: torch.Tensor) -> torch.Tensor:
    """
    Estimate the effective length of each segment as the number of rows whose
    L2 norm exceeds 1e-8 (i.e. rows that are not numerically all-zero padding).

    Rows are counted wherever they occur inside the segment, not only as a
    leading prefix.

    segments: [B,S,L,D] -> [B,S] (int64 counts)

    Non-floating inputs (bool / integer) are cast to float32 first because
    ``Tensor.norm`` only accepts floating-point or complex tensors; this keeps
    the function usable on the same bool segments that
    ``reconstruct_sequence_from_segments`` accepts.
    """
    if not (segments.is_floating_point() or segments.is_complex()):
        segments = segments.float()
    # [B,S,L,D] -> per-row norms [B,S,L] -> count of non-zero rows [B,S]
    return (segments.norm(p=2, dim=-1) > 1e-8).sum(dim=-1)

@torch.no_grad()
def reconstruct_sequence_from_segments(
    segments: torch.Tensor, lengths: torch.Tensor, T: int
) -> torch.Tensor:
    """
    Rebuild a flat sequence by laying the valid prefix of each segment end to end.

    For every sample, segment ``s`` contributes its first
    ``min(int(lengths[b, s]), L)`` rows, and segments with a non-positive
    length are skipped. Rows beyond ``T`` are dropped. Positions not reached
    by any segment stay zero.

    segments: [B,S,L,D], lengths: [B,S] -> [B,T,D]
    Bool segments produce a float32 output; other dtypes are preserved.
    """
    n_batch, n_seg, seg_cap, feat = segments.shape
    src_is_bool = segments.dtype == torch.bool
    target_dtype = torch.float32 if src_is_bool else segments.dtype
    result = segments.new_zeros((n_batch, T, feat), dtype=target_dtype)

    for row in range(n_batch):
        cursor = 0  # next write position in result[row]
        for col in range(n_seg):
            seg_len = int(lengths[row, col].item())
            if seg_len > 0:
                n_rows = min(seg_len, seg_cap)
                stop = min(T, cursor + n_rows)
                if stop > cursor:
                    chunk = segments[row, col, : stop - cursor, :]
                    result[row, cursor:stop, :] = chunk.float() if src_is_bool else chunk
                    cursor = stop
                # Only checked after a non-empty segment, mirroring the fill order.
                if cursor >= T:
                    break
    return result

@torch.no_grad()
def nrmse(x: torch.Tensor, x_hat: torch.Tensor, eps: float = 1e-8, mode: str = "var") -> torch.Tensor:
    """
    Normalized root-mean-square error, computed per sample and averaged over the batch.

    x, x_hat: [B,T,D]
    The per-sample error is ||x - x_hat|| / (||ref|| + eps), with norms taken
    over all non-batch elements.
    - mode="var": ref = x minus its mean over dim 1 (time).
    - any other mode: ref = x.
    Returns a 0-dim tensor holding the batch mean.
    """
    n = x.size(0)
    reference = x - x.mean(dim=1, keepdim=True) if mode == "var" else x
    residual_norm = (x - x_hat).reshape(n, -1).norm(p=2, dim=1)
    scale_norm = reference.reshape(n, -1).norm(p=2, dim=1)
    ratio = residual_norm / (scale_norm + eps)
    return ratio.mean()

@torch.no_grad()
def detect_event_positions(
    ts_feat: torch.Tensor, tau_scale: float = 1.0, min_gap: int = 2
) -> list:
    """
    Detect proxy events as unusually large jumps between consecutive time steps.

    ts_feat: [B,T,D]. For each sample, the jump at step t is ||x[t] - x[t-1]||_2.
    A step is a candidate when its jump exceeds mean + tau_scale * std, where
    mean and std are taken over that sample's T-1 jumps. Candidates are then
    thinned greedily: the earliest one is kept, and every later one is kept
    only if it is at least ``min_gap`` steps after the last kept candidate.

    Returns a list of B long tensors holding event indices in [1..T-1]
    (the index of the later sample of each jump). When T <= 2, every entry
    is an empty tensor.
    """
    B, T, _ = ts_feat.shape
    device = ts_feat.device
    if T <= 2:
        return [torch.tensor([], device=device, dtype=torch.long) for _ in range(B)]

    step = ts_feat[:, 1:, :] - ts_feat[:, :-1, :]
    jump = step.pow(2).sum(dim=-1).sqrt()  # [B,T-1]
    threshold = jump.mean(dim=1, keepdim=True) + tau_scale * jump.std(dim=1, keepdim=True)
    above = jump > threshold

    result = []
    for mask in above:
        hits = torch.nonzero(mask, as_tuple=False).squeeze(-1)  # positions in [0..T-2]
        if hits.numel() == 0:
            result.append(hits)
            continue
        kept = []
        for pos in hits.tolist():
            if not kept or pos - kept[-1] >= min_gap:
                kept.append(pos)
        # +1 converts a jump position into the index of the sample after the jump.
        result.append(torch.tensor(kept, device=device, dtype=torch.long) + 1)
    return result

@torch.no_grad()
def boundary_indices_from_lengths(lengths: torch.Tensor) -> list:
    """
    Convert per-segment lengths into internal boundary positions.

    lengths: [B,S] -> list of B long tensors, each of size S-1, holding the
    running totals after every segment except the last (the right endpoint
    of each segment other than the final one).
    Example: [10,20,30] -> [10,30]
    Zero-length segments produce repeated positions. Float lengths are
    truncated by ``.long()`` after summation.
    """
    n_rows, _ = lengths.shape
    right_ends = torch.cumsum(lengths, dim=1)[:, :-1].long()
    return [right_ends[i] for i in range(n_rows)]

@torch.no_grad()
def _seg_len_hist(lengths: torch.Tensor, bins: torch.Tensor) -> torch.Tensor:
    """
    Histogram of lengths over right-closed bins.

    lengths: [N], bins: [K] (expected increasing) -> [K] long counts, where
    bin k counts values in (bins[k-1], bins[k]] and bin 0 counts values <= bins[0].
    Values greater than the last edge are not counted. Counts are differences of
    cumulative "<= edge" totals, so non-increasing edges yield non-positive counts.
    """
    at_or_below = [int((lengths <= edge).sum().item()) for edge in bins]
    lower = [0] + at_or_below[:-1]
    counts = [upper - low for low, upper in zip(lower, at_or_below)]
    return torch.tensor(counts, device=lengths.device, dtype=torch.long)

@torch.no_grad()
def compute_irr_bcr(
    ts_feat: torch.Tensor,
    segments_used: torch.Tensor,   # [B,S,L,D]
    config: Dict
) -> Tuple[float, float, Dict]:
    """
    Calculate IRR / BCR and return meta (segment length estimation and histogram, etc.)

    - IRR: clamp(1 - NRMSE(x, x_hat), 0, 1) with mode="var", where x_hat is the
      sequence rebuilt by concatenating the estimated valid rows of every segment.
    - BCR: per sample, the fraction of proxy events (large first-difference jumps)
      that lie within ±bcr_tolerance of a segment boundary. It is averaged over
      samples that have at least one event, and is 0.0 when no sample has events.

    Args:
        ts_feat: original sequence [B,T,D].
        segments_used: segmented representation [B,S,L,D].
        config: reads "bcr_tau_scale" (default 1.0), "bcr_tolerance" (default 2),
            "hist_bins" (explicit histogram edges) and "segment_len" (default L,
            only used to derive default edges when "hist_bins" is absent).

    Returns:
        (irr, bcr, meta), where meta has "lengths_est" [B,S], "seg_len_hist" [K]
        and "hist_bins" [K].
    """
    _, T, _ = ts_feat.shape
    _, _, L, _ = segments_used.shape

    # 1) Effective (non-zero-row) length of each segment.
    lengths_est = estimate_lengths_from_segments(segments_used)  # [B,S]

    # 2) IRR: how well the concatenated segments reproduce the input.
    x_hat = reconstruct_sequence_from_segments(segments_used, lengths_est, T)  # [B,T,D]
    irr_val = float((1.0 - nrmse(ts_feat, x_hat, mode="var")).clamp(0.0, 1.0).item())

    # 3) BCR: share of proxy events that land near a segment boundary.
    tau = float(config.get("bcr_tau_scale", 1.0))
    delta = int(config.get("bcr_tolerance", 2))
    ev_lists = detect_event_positions(ts_feat, tau_scale=tau, min_gap=max(1, delta))
    bd_lists = boundary_indices_from_lengths(lengths_est)

    bcr_vals = []
    for events, bounds in zip(ev_lists, bd_lists):
        if events.numel() == 0:
            continue  # a sample without events says nothing about boundary placement
        if bounds.numel() == 0:
            hit = 0
        else:
            # Events follow ts_feat's device, boundaries follow segments_used's device.
            bounds = bounds.to(events.device)
            dist = (events.view(-1, 1) - bounds.view(1, -1)).abs()
            hit = int((dist.min(dim=1).values <= delta).sum().item())
        # events.numel() > 0 on this path. An epsilon in the denominator would
        # only push a perfect score slightly below 1.0.
        bcr_vals.append(hit / events.numel())
    bcr_val = float(sum(bcr_vals) / len(bcr_vals)) if bcr_vals else 0.0

    # 4) Histogram of estimated segment lengths.
    bins = config.get("hist_bins", None)
    if bins is None:
        seg_len = int(config.get("segment_len", L))
        step = max(1, seg_len // 8)
        # Clamp edges to seg_len so they stay non-decreasing for short segments.
        # Otherwise seg_len=4 gives [1, 2, 4, 6, 4], whose last bin can go negative.
        bins = [min(edge, seg_len) for edge in (step, 2 * step, 4 * step, 6 * step, seg_len)]
    bins_t = torch.as_tensor(bins, device=segments_used.device)
    lengths_flat = lengths_est.reshape(-1)
    seg_len_hist = _seg_len_hist(lengths_flat, bins_t)

    meta = {
        "lengths_est": lengths_est.detach(),
        "seg_len_hist": seg_len_hist.detach(),
        "hist_bins": bins_t.detach(),
    }
    return irr_val, bcr_val, meta
