from typing import Dict, List, Tuple
from torch.utils.data import DataLoader
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import kendalltau

def knn_predict(Z_lib: np.ndarray,
                Z_query: np.ndarray,
                values_lib: np.ndarray,
                k: int = 20) -> np.ndarray:
    """Predict one value per query point as a distance-weighted neighbour mean.

    For every row of ``Z_query`` the ``k`` closest rows of ``Z_lib``
    (Euclidean distance) are selected and their ``values_lib`` entries are
    averaged with weights ``exp(-d / d_min)``, where ``d_min`` is the
    distance to the closest selected neighbour of that query.

    Args:
        Z_lib: Library points, indexed as ``(n_lib, dim)``.
        Z_query: Query points, indexed as ``(n_query, dim)``.
        values_lib: Target values aligned with the rows of ``Z_lib``.
        k: Number of neighbours; if it exceeds the library size every
            library point is used.

    Returns:
        Array with one weighted-average prediction per query row.
    """
    # Dense pairwise distances, shape (n_query, n_lib). The epsilon keeps
    # the square root strictly positive even for exact matches.
    diffs = Z_query[:, None, :] - Z_lib[None, :, :]
    dist = np.sqrt(np.square(diffs).sum(axis=-1) + 1e-12)

    # Partitioning at position `pivot` guarantees the first `pivot` columns
    # hold the smallest distances (in arbitrary order).
    pivot = min(k, Z_lib.shape[0] - 1)
    nearest = np.argpartition(dist, kth=pivot, axis=1)[:, :k]

    nn_dist = np.take_along_axis(dist, nearest, axis=1)
    nn_vals = values_lib[nearest]

    # Scale distances by the closest neighbour so weights are unit-free.
    scale = np.maximum(nn_dist.min(axis=1, keepdims=True), 1e-12)
    weights = np.exp(-nn_dist / scale)
    norm = (weights.sum(axis=1, keepdims=True) + 1e-12).squeeze(1)
    return (weights * nn_vals).sum(axis=1) / norm

def knn_local_linear_predict(Z_lib: np.ndarray,
                             Z_query: np.ndarray,
                             values_lib: np.ndarray,
                             k: int = 30,
                             ridge: float = 1e-2) -> np.ndarray:
    """Predict one value per query point with a local ridge-regularised linear fit.

    For every row of ``Z_query`` the ``k`` closest rows of ``Z_lib``
    (Euclidean distance) are selected, an affine model
    ``value ~ [1, z] @ w`` is fitted to them by ridge regression, and the
    model is evaluated at the query point. The ridge penalty is applied to
    all coefficients, including the intercept.

    Args:
        Z_lib: Library points, indexed as ``(n_lib, dim)``.
        Z_query: Query points, indexed as ``(n_query, dim)``.
        values_lib: Target values aligned with the rows of ``Z_lib``.
        k: Number of neighbours used for each local fit.
        ridge: Coefficient of the identity matrix added to the normal
            equations.

    Returns:
        Array holding one prediction per query row.
    """
    # Dense pairwise distances, shape (n_query, n_lib).
    diffs = Z_query[:, None, :] - Z_lib[None, :, :]
    dist = np.sqrt(np.square(diffs).sum(axis=-1) + 1e-12)

    pivot = min(k, Z_lib.shape[0] - 1)
    neighbours = np.argpartition(dist, kth=pivot, axis=1)[:, :k]

    def _fit_and_eval(query_point, nbr_idx):
        # Design matrix with a leading column of ones for the intercept.
        design = np.hstack([np.ones((len(nbr_idx), 1)), Z_lib[nbr_idx]])
        gram = design.T @ design + ridge * np.eye(design.shape[1])
        coef = np.linalg.solve(gram, design.T @ values_lib[nbr_idx])
        return np.hstack([np.ones(1), query_point]) @ coef

    return np.array(
        [_fit_and_eval(point, nbrs) for point, nbrs in zip(Z_query, neighbours)]
    )

def ccm_convergence_trend(
    lib_sizes: List[int],
    scores: np.ndarray,
    frac_head: float = 0.3,
    frac_tail: float = 0.3,
    clip_below_zero: bool = True,
) -> float:
    """Quantify how much a skill curve rises with library size.

    The result is ``(tail_mean - head_mean) * |tau|`` where ``head_mean`` and
    ``tail_mean`` are the mean scores over the smallest and largest library
    sizes, and ``tau`` is Kendall's rank correlation between library size
    and score. The sign therefore comes from the head/tail difference and
    the magnitude is damped when the curve is not monotone.

    Args:
        lib_sizes: Library sizes, one per score. They do not need to be
            sorted; points are ordered by library size internally.
        scores: Skill values aligned with ``lib_sizes``; NaN entries are
            dropped together with their library size.
        frac_head: Fraction of the valid points treated as the head
            (at least one point is always used).
        frac_tail: Fraction of the valid points treated as the tail
            (at least one point is always used).
        clip_below_zero: If true, negative scores are replaced by 0 before
            any statistic is computed.

    Returns:
        The trend value, or ``0.0`` when fewer than four valid points
        remain or Kendall's tau is undefined.
    """
    sc = np.asarray(scores, dtype=np.float64)
    ls = np.asarray(lib_sizes, dtype=np.float64)

    valid = ~np.isnan(sc)
    if valid.sum() < 4:
        return 0.0
    sc, ls = sc[valid], ls[valid]

    # "Head" and "tail" only make sense along increasing library size, so
    # order the points explicitly instead of trusting the caller's order.
    # A stable sort keeps the original order of duplicated sizes.
    order = np.argsort(ls, kind="stable")
    sc, ls = sc[order], ls[order]

    if clip_below_zero:
        sc = np.maximum(sc, 0.0)

    n_points = sc.size
    n_head = max(1, int(n_points * frac_head))
    n_tail = max(1, int(n_points * frac_tail))

    head_mean = float(sc[:n_head].mean())
    tail_mean = float(sc[-n_tail:].mean())
    magnitude = tail_mean - head_mean

    # kendalltau yields NaN for constant input; treat that as "no trend".
    tau, _ = kendalltau(ls, sc)
    consistency = 0.0 if np.isnan(tau) else abs(float(tau))

    return magnitude * consistency

def ccm_score_with_trend(
    lib_sizes: List[int],
    scores: np.ndarray,
    auc_weight: float = 0,
    trend_weight: float = 0,
    clip_below_zero_for_auc: bool = False,
    normalize_auc: bool = True,
) -> float:
    """Blend the area under a skill curve with its convergence trend.

    Args:
        lib_sizes: Library sizes, one per score.
        scores: Skill values aligned with ``lib_sizes``.
        auc_weight: Multiplier applied to the trapezoidal AUC term.
        trend_weight: Multiplier applied to the convergence-trend term.
        clip_below_zero_for_auc: Forwarded to ``auc_over_libsizes`` as
            ``clip_below_zero``.
        normalize_auc: Forwarded to ``auc_over_libsizes`` as ``normalize``.

    Returns:
        ``auc_weight * auc + trend_weight * trend``, or NaN when the AUC
        itself is NaN. With the default weights (both 0) the result is 0.
    """
    area = auc_over_libsizes(
        lib_sizes,
        scores,
        clip_below_zero=clip_below_zero_for_auc,
        normalize=normalize_auc,
        method="trapz",
    )

    # Without a usable AUC the combined score is undefined.
    if np.isnan(area):
        return np.nan

    # The trend term always uses 30% head / 30% tail and clips negatives.
    trend_term = ccm_convergence_trend(
        lib_sizes,
        scores,
        frac_head=0.3,
        frac_tail=0.3,
        clip_below_zero=True,
    )

    weighted_area = auc_weight * area
    weighted_trend = trend_weight * trend_term
    return weighted_area + weighted_trend

def ccm_scores(Zx: np.ndarray,
               Zy: np.ndarray,
               X: np.ndarray,
               Y: np.ndarray,
               lib_sizes: List[int],
               k: int = 20,
               n_folds: int = 10,
               use_local_linear: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """Cross-validated cross-mapping skill for a range of library sizes.

    All inputs are truncated to their common length and split into
    ``n_folds`` contiguous blocks (the last block absorbs the remainder).
    Each block in turn is held out; the remaining samples form a pool from
    which the first ``L`` entries are used as the library for every ``L`` in
    ``lib_sizes``. Skill is the Pearson correlation between predicted and
    held-out values.

    Args:
        Zx: Embeddings associated with ``X``.
        Zy: Embeddings associated with ``Y``.
        X: Values predicted from ``Zy``.
        Y: Values predicted from ``Zx``.
        lib_sizes: Library sizes to evaluate.
        k: Neighbour count handed to the predictor.
        n_folds: Number of contiguous hold-out blocks.
        use_local_linear: Use ``knn_local_linear_predict`` when true,
            otherwise ``knn_predict``.

    Returns:
        ``(avg_y2x, avg_x2y)``: per-library-size skill averaged over folds
        (NaN folds ignored). ``avg_y2x`` is the skill of predicting ``X``
        from ``Zy``; ``avg_x2y`` of predicting ``Y`` from ``Zx``.
    """
    n_samples = min(Zx.shape[0], Zy.shape[0], len(X), len(Y))
    Zx, Zy = Zx[:n_samples], Zy[:n_samples]
    X, Y = X[:n_samples], Y[:n_samples]

    def _pearson(u, v):
        # Fewer than two samples: correlation is undefined, report 0.
        if len(u) < 2:
            return 0.0
        du = u - u.mean()
        dv = v - v.mean()
        norm = np.sqrt((du ** 2).sum() * (dv ** 2).sum()) + 1e-12
        return float((du * dv).sum() / norm)

    def _cross_map(emb_lib, emb_test, target_lib):
        # Single dispatch point for the configured predictor.
        if use_local_linear:
            return knn_local_linear_predict(emb_lib, emb_test, target_lib, k=k)
        return knn_predict(emb_lib, emb_test, target_lib, k=k)

    n_sizes = len(lib_sizes)
    skill_y2x = np.zeros((n_folds, n_sizes))
    skill_x2y = np.zeros((n_folds, n_sizes))

    block_len = n_samples // n_folds

    for fold in range(n_folds):
        lo = fold * block_len
        # The final fold runs to the end so no sample is left unused.
        hi = n_samples if fold == n_folds - 1 else lo + block_len

        held_out = np.arange(lo, hi)
        pool = np.concatenate([np.arange(0, lo), np.arange(hi, n_samples)])

        Zx_held, Zy_held = Zx[held_out], Zy[held_out]
        X_held, Y_held = X[held_out], Y[held_out]

        Zx_pool, Zy_pool = Zx[pool], Zy[pool]
        X_pool, Y_pool = X[pool], Y[pool]

        n_pool = len(pool)

        for col, raw_size in enumerate(lib_sizes):
            size = int(raw_size)

            # Skip sizes the pool cannot provide or that leave too few
            # points beyond the k neighbours.
            if not (k + 5 <= size <= n_pool):
                skill_y2x[fold, col] = np.nan
                skill_x2y[fold, col] = np.nan
                continue

            x_hat = _cross_map(Zy_pool[:size], Zy_held, X_pool[:size])
            y_hat = _cross_map(Zx_pool[:size], Zx_held, Y_pool[:size])

            skill_y2x[fold, col] = _pearson(x_hat, X_held)
            skill_x2y[fold, col] = _pearson(y_hat, Y_held)

    return np.nanmean(skill_y2x, axis=0), np.nanmean(skill_x2y, axis=0)


def auc_over_libsizes(lib_sizes: List[int],
                      scores: np.ndarray,
                      clip_below_zero: bool = True,
                      normalize: bool = True,
                      method: str = "trapz") -> float:
    """Summarise a skill-versus-library-size curve as a single number.

    Args:
        lib_sizes: Library sizes, one per score. They do not need to be
            sorted; points are ordered by library size before integrating.
        scores: Skill values aligned with ``lib_sizes``; NaN entries are
            dropped together with their library size.
        clip_below_zero: If true, negative scores are replaced by 0.
        normalize: For ``method="trapz"``, divide the area by the span of
            the valid library sizes so the result is on the score scale.
        method: ``"trapz"`` integrates with the trapezoidal rule; any other
            value returns the plain mean of the valid scores.

    Returns:
        The (optionally normalised) area or mean as a Python float, or NaN
        when fewer than two valid scores are available.
    """
    Ls = np.asarray(lib_sizes, dtype=np.float64)
    sc = np.asarray(scores, dtype=np.float64)

    valid = ~np.isnan(sc)
    if valid.sum() < 2:
        return np.nan
    Ls, sc = Ls[valid], sc[valid]

    if clip_below_zero:
        sc = np.maximum(sc, 0.0)

    if method != "trapz":
        return float(sc.mean())

    # Integrate along increasing library size; with unsorted abscissae the
    # trapezoidal rule would yield a signed (possibly negative) area.
    order = np.argsort(Ls, kind="stable")
    Ls, sc = Ls[order], sc[order]

    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # prefer the new name and fall back for older NumPy releases.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    area = integrate(sc, Ls)

    if normalize:
        # Epsilon guards against a zero span (all library sizes equal).
        span = Ls[-1] - Ls[0]
        area = area / (span + 1e-12)
    return float(area)

def ccm_quantify_auc(lib_sizes: List[int],
                     s_y2x: np.ndarray,
                     s_x2y: np.ndarray,
                     clip_below_zero: bool = False,
                     normalize_auc: bool = True) -> Tuple[float, float]:
    """Compute the trapezoidal AUC of both cross-mapping skill curves.

    Args:
        lib_sizes: Library sizes shared by both curves.
        s_y2x: First skill curve.
        s_x2y: Second skill curve.
        clip_below_zero: Forwarded to ``auc_over_libsizes``.
        normalize_auc: Forwarded to ``auc_over_libsizes`` as ``normalize``.

    Returns:
        ``(auc_y2x, auc_x2y)`` in the same order as the input curves.
    """
    auc_y2x, auc_x2y = (
        auc_over_libsizes(lib_sizes, curve, clip_below_zero, normalize_auc, method="trapz")
        for curve in (s_y2x, s_x2y)
    )
    return auc_y2x, auc_x2y

def build_pairwise_convergence_matrix(
    model: nn.Module,
    test_loader: DataLoader,
    cfg,
    device: str = "cuda",
    k_ccm: int = 20,
    n_lib_points: int = 10,
    n_folds: int = 10,
):
    """Build a matrix of cross-mapping convergence scores for every series pair.

    The model is run over ``test_loader`` to collect, per series, its
    ``individual_embeddings`` and the matching ``aligned_values``. For every
    ordered pair ``(i, j)`` with ``i != j`` the skill curve of predicting
    series ``i`` values from series ``j`` embeddings is computed with
    ``ccm_scores`` and condensed by ``ccm_score_with_trend``
    (``auc_weight=1.0``, ``trend_weight=0.5``).

    Args:
        model: Network called with a dict holding ``global_windows``,
            ``individual_windows``, ``series_ids`` and ``time_indices``; its
            output must contain ``"individual_embeddings"``.
            ``model.embedding_dim`` is read only for series without samples.
        test_loader: Loader whose dataset exposes ``sequence_columns`` and
            whose batches expose the attributes listed above plus
            ``aligned_values``.
        cfg: Configuration object; only ``cfg.use_local_linear`` is read.
        device: Device the model and batch tensors are moved to.
        k_ccm: Neighbour count passed to ``ccm_scores``.
        n_lib_points: Number of library sizes on the evaluation grid.
        n_folds: Number of folds passed to ``ccm_scores``.

    Returns:
        ``(S_mat, cols)`` where ``S_mat[i, j]`` is the score described above
        (0 on the diagonal, NaN for pairs with too little data) and ``cols``
        is the list of series names indexing both axes.
    """
    model.to(device)
    model.eval()

    dataset = test_loader.dataset
    cols: List[str] = dataset.sequence_columns
    P = len(cols)

    # Per-series chunks gathered batch by batch, concatenated afterwards.
    Z_torch: Dict[str, List[torch.Tensor]] = {c: [] for c in cols}
    S_torch: Dict[str, List[torch.Tensor]] = {c: [] for c in cols}

    with torch.no_grad():
        print("Extracting embeddings and aligned values...")
        for batch in test_loader:
            global_windows = batch.global_windows.to(device, non_blocking=True)
            individual_windows = batch.individual_windows.to(device, non_blocking=True)
            series_ids = batch.series_ids.to(device, non_blocking=True)
            aligned_values = batch.aligned_values.to(device, non_blocking=True)

            batch_data = {
                "global_windows": global_windows,
                "individual_windows": individual_windows,
                "series_ids": series_ids,
                "time_indices": batch.time_indices.to(device, non_blocking=True),
            }

            outputs = model(batch_data)
            individual_embeddings = outputs["individual_embeddings"]

            # Route every sample of the batch to the series it belongs to.
            for series_idx, col_name in enumerate(cols):
                mask = (series_ids == series_idx)
                if not mask.any():
                    continue
                Z_torch[col_name].append(individual_embeddings[mask])
                S_torch[col_name].append(aligned_values[mask])

    Z: Dict[str, np.ndarray] = {}
    S_aligned: Dict[str, np.ndarray] = {}
    for col_name in cols:
        if len(Z_torch[col_name]) > 0:
            z_cat = torch.cat(Z_torch[col_name], dim=0)
            s_cat = torch.cat(S_torch[col_name], dim=0)
            Z[col_name] = z_cat.cpu().numpy()
            S_aligned[col_name] = s_cat.cpu().numpy()
        else:
            # Series never seen in the loader: keep empty placeholders so
            # the pair loop below skips it through the length check.
            Z[col_name] = np.empty((0, model.embedding_dim), dtype=np.float32)
            S_aligned[col_name] = np.empty((0,), dtype=np.float32)

    # Zero-initialised, so the diagonal is already 0.
    S_mat = np.zeros((P, P), dtype=np.float64)

    # ccm_scores evaluates both directions of a pair in one call. The curve
    # needed for (j, i) is exactly the second curve returned for (i, j), so
    # it is cached here instead of being recomputed when the loop gets there.
    pending_curves: Dict[Tuple[int, int], np.ndarray] = {}

    for i, ci in enumerate(cols):
        print("processing", ci )
        for j, cj in enumerate(cols):
            if i == j:
                continue
            Zi, Xi = Z[ci], S_aligned[ci]
            Zj, Yj = Z[cj], S_aligned[cj]

            T_pair = min(len(Zi), len(Zj), len(Xi), len(Yj))

            if T_pair < 80:
                print("Too short, skip:", ci, cj, T_pair)
                S_mat[i, j] = np.nan
                continue

            max_train_size = int(T_pair * (n_folds - 1) / n_folds)

            if max_train_size < k_ccm + 10:
                S_mat[i, j] = np.nan
                continue

            # NOTE(review): when max_train_size is below max(50, k_ccm + 5)
            # this grid runs downward and most sizes exceed what a fold can
            # supply, which typically ends in a NaN score.
            Li = np.linspace(max(50, k_ccm + 5), max_train_size, num=n_lib_points, dtype=int).tolist()

            s_j2i = pending_curves.pop((i, j), None)
            if s_j2i is None:
                s_j2i, s_i2j = ccm_scores(
                    Zi, Zj, Xi, Yj,
                    lib_sizes=Li,
                    k=k_ccm,
                    n_folds=n_folds,
                    use_local_linear=cfg.use_local_linear
                )
                # T_pair and Li are symmetric in (i, j), so this curve is
                # what the (j, i) iteration would compute as its first output.
                pending_curves[(j, i)] = s_i2j

            score_j2i = ccm_score_with_trend(
                Li,
                s_j2i,
                auc_weight=1.0,
                trend_weight=0.5,
                clip_below_zero_for_auc=True,
                normalize_auc=True,
            )
            S_mat[i, j] = score_j2i
            print(f"S_mat {i}->{j} score: ", s_j2i)

    return S_mat, cols
