import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List, Tuple, Callable

from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.neighbors import NearestNeighbors


@dataclass
class EmpPreprocArtifacts:
    """Fitted preprocessing state and the layout of the processed feature matrix.

    The processed matrix places the scaled numeric columns first, followed by
    one one-hot block per categorical column.
    """

    # Names of the raw columns treated as numeric.
    num_cols: List[str]
    # Names of the remaining raw columns, treated as categorical.
    cat_cols: List[str]
    # Scaler fitted on the numeric columns.
    scaler: StandardScaler
    # One-hot encoder fitted on the string-cast categorical columns.
    ohe: OneHotEncoder
    # Number of numeric columns in the processed matrix.
    num_dim: int
    # Half-open (start, end) processed-column ranges, one per entry of cat_cols.
    cat_slices: List[Tuple[int, int]]
    # Raw feature name -> list of processed-column indices it occupies.
    feat2idx: Dict[str, List[int]]
    # Total number of processed columns (numeric + all one-hot blocks).
    D: int


def fit_empirical_preprocessor(X_train_raw: pd.DataFrame) -> EmpPreprocArtifacts:
    """Fit the scaler / one-hot encoder on the training frame and record the layout.

    Columns selected by ``select_dtypes(include=["number"])`` are treated as
    numeric and standardised; every other column is cast to ``str`` and
    one-hot encoded (categories unseen at fit time are ignored at transform
    time).  The processed layout is: numeric columns first, then one block of
    one-hot columns per categorical column, in ``cat_cols`` order.

    Args:
        X_train_raw: Raw training features.

    Returns:
        An ``EmpPreprocArtifacts`` holding the fitted transformers, the column
        partition, and the mapping from raw feature name to processed-column
        indices.
    """
    num_cols = X_train_raw.select_dtypes(include=["number"]).columns.tolist()
    numeric_names = set(num_cols)  # O(1) membership instead of scanning the list
    cat_cols = [c for c in X_train_raw.columns if c not in numeric_names]

    # Only the fitted state is needed here; the transformed training matrices
    # are produced on demand by transform_empirical, so do not materialise them.
    numeric_values = X_train_raw[num_cols].to_numpy(dtype=np.float32)
    scaler = StandardScaler()
    scaler.fit(numeric_values)
    num_dim = numeric_values.shape[1]

    ohe = OneHotEncoder(sparse_output=False, handle_unknown="ignore")
    ohe.fit(X_train_raw[cat_cols].astype(str))

    # Each categorical column owns a contiguous half-open range placed after
    # the numeric columns; the range width is its number of fitted categories.
    cat_slices: List[Tuple[int, int]] = []
    start = num_dim
    for cats in ohe.categories_:
        end = start + len(cats)
        cat_slices.append((start, end))
        start = end
    D = start

    feat2idx_proc: Dict[str, List[int]] = {col: [j] for j, col in enumerate(num_cols)}
    for col, (s, e) in zip(cat_cols, cat_slices):
        feat2idx_proc[col] = list(range(s, e))

    return EmpPreprocArtifacts(
        num_cols=num_cols,
        cat_cols=cat_cols,
        scaler=scaler,
        ohe=ohe,
        num_dim=num_dim,
        cat_slices=cat_slices,
        feat2idx=feat2idx_proc,
        D=D,
    )


def transform_empirical(X_raw: pd.DataFrame, ep: EmpPreprocArtifacts) -> np.ndarray:
    """Map raw rows into the processed space described by ``ep``.

    Numeric columns are scaled with the fitted scaler, categorical columns are
    string-cast and one-hot encoded with the fitted encoder, and the two blocks
    are joined column-wise (numeric first) as a ``float32`` array.
    """
    numeric_values = X_raw[ep.num_cols].to_numpy(dtype=np.float32)
    scaled_block = ep.scaler.transform(numeric_values)

    categorical_frame = X_raw[ep.cat_cols].astype(str)
    onehot_block = ep.ohe.transform(categorical_frame)

    combined = np.concatenate([scaled_block, onehot_block], axis=1)
    return combined.astype(np.float32)


def make_empirical_global_utility_fn(
    *,
    model,
    X_train_raw: pd.DataFrame,
    X_eval_raw: pd.DataFrame,
    y_eval: np.ndarray,
    feat_names: List[str],
    n_eval: int = 64,
    seed: int = 0,
    K: int = 200,
    k_nn: int = 500,
) -> Callable[[List[str]], float]:
    """Build a memoised coalition-value function ``v(S)`` estimated from training rows.

    For a feature subset ``S`` the returned function averages, over a fixed
    random subsample of the evaluation set, the probability the model assigns
    to each point's label when only the features in ``S`` are known:

    * ``S`` empty: the background rows are drawn uniformly (with replacement)
      from the training frame.
    * ``S`` equal to all of ``feat_names``: the model is evaluated directly on
      the evaluation rows.
    * otherwise: background rows are drawn (with replacement) from the
      ``k_nn`` training rows nearest to the evaluation point in the processed
      sub-space spanned by ``S``, and the ``S`` columns are overwritten with
      the evaluation point's own values.

    The label probability is read from column 1 of ``model.predict_proba`` for
    label ``1`` and as its complement otherwise.

    A single ``RandomState`` is shared by all calls, so the value obtained for
    a subset depends on the order in which subsets are first requested;
    repeated requests are served from a cache and are therefore stable.

    Args:
        model: Object exposing ``predict_proba`` on raw feature frames.
        X_train_raw: Training features used as the background distribution.
        X_eval_raw: Evaluation features to subsample from.
        y_eval: Labels aligned with ``X_eval_raw``; cast to ``int``.
        feat_names: The complete list of features a coalition can contain.
        n_eval: Maximum number of evaluation rows to use.
        seed: Seed for the shared random state.
        K: Number of background rows drawn per evaluation point.
        k_nn: Size of the nearest-neighbour pool (capped at the training size).

    Returns:
        A callable mapping a list of feature names to a float value.
    """
    rng = np.random.RandomState(seed)

    N = X_eval_raw.shape[0]
    n_eval = min(int(n_eval), int(N))
    idxs = rng.choice(N, size=n_eval, replace=False)
    X_sub = X_eval_raw.iloc[idxs].reset_index(drop=True)
    y_sub = np.asarray(y_eval)[idxs].astype(int)

    ep = fit_empirical_preprocessor(X_train_raw)
    Xtr_proc = transform_empirical(X_train_raw, ep)
    Xev_proc = transform_empirical(X_sub, ep)

    # Compare against the number of *distinct* features so that the
    # "full coalition" shortcut agrees with the de-duplicated cache key.
    n_all_features = len(set(feat_names))

    idx_cache: Dict[Tuple[str, ...], np.ndarray] = {}
    v_cache: Dict[frozenset, float] = {}

    def _sample_unconditional(Ki: int) -> pd.DataFrame:
        # Uniform draw with replacement from the whole training frame.
        jj = rng.randint(0, X_train_raw.shape[0], size=Ki)
        return X_train_raw.iloc[jj].copy()

    def _neighbor_pools(keyS: Tuple[str, ...]) -> np.ndarray:
        # Per-evaluation-row indices of the nearest training rows in the
        # processed sub-space of keyS.  Only the indices are cached: the
        # fitted index is never needed again once they are known, so keeping
        # it would just pin memory for every subset ever queried.
        if keyS not in idx_cache:
            idxS = np.asarray(
                [j for f in keyS for j in ep.feat2idx[f]], dtype=np.int64
            )
            nn = NearestNeighbors(
                n_neighbors=min(k_nn, Xtr_proc.shape[0]),
                algorithm="auto",
                metric="euclidean",
            )
            nn.fit(Xtr_proc[:, idxS])
            inds = nn.kneighbors(Xev_proc[:, idxS], return_distance=False)
            idx_cache[keyS] = inds.astype(np.int64)
        return idx_cache[keyS]

    def v_global(S_feats: List[str]) -> float:
        Sset = frozenset(S_feats)
        if Sset in v_cache:
            return v_cache[Sset]

        # Canonical, duplicate-free ordering of the coalition.
        keyS = tuple(sorted(Sset))

        if len(keyS) == n_all_features:
            # Nothing is marginalised: score the evaluation rows directly.
            proba1 = model.predict_proba(X_sub)[:, 1].astype(np.float64)
            fy = np.where(y_sub == 1, proba1, 1.0 - proba1)
            out = float(np.mean(fy))
            v_cache[Sset] = out
            return out

        neigh_idx = _neighbor_pools(keyS) if keyS else None

        vals = []
        for i in range(n_eval):
            if neigh_idx is None:
                X_samp = _sample_unconditional(K)
            else:
                x_i = X_sub.iloc[i]
                pool = neigh_idx[i]
                pick = pool[rng.randint(0, pool.shape[0], size=K)]
                X_samp = X_train_raw.iloc[pick].copy()
                # Pin the known features to the evaluation point's values.
                for f in keyS:
                    X_samp.loc[:, f] = x_i[f]

            proba1 = model.predict_proba(X_samp)[:, 1].astype(np.float64)
            fy = proba1 if int(y_sub[i]) == 1 else (1.0 - proba1)
            vals.append(float(np.mean(fy)))

        out = float(np.mean(vals))
        v_cache[Sset] = out
        return out

    return v_global


def make_lambda_from_map(weight_map: Dict[str, float], feat_names: List[str], feat2idx: Dict[str, int]) -> np.ndarray:
    """Return a float64 weight vector of length ``len(feat_names)``.

    Every entry starts at 1.0; for each name in ``weight_map`` that is also a
    key of ``feat2idx``, the entry at ``feat2idx[name]`` is set to the mapped
    weight.  Names absent from ``feat2idx`` are ignored.
    """
    weights = np.ones(len(feat_names), dtype=np.float64)
    known_entries = (
        (feat2idx[name], value)
        for name, value in weight_map.items()
        if name in feat2idx
    )
    for position, value in known_entries:
        weights[position] = float(value)
    return weights


def lambda_uniform(n_features: int) -> np.ndarray:
    """Return a float64 vector of ``n_features`` ones (equal weight everywhere)."""
    uniform_weights = np.full(n_features, 1.0, dtype=np.float64)
    return uniform_weights


def lambda_block_wise(
    base: float,
    block_features: List[str],
    exponents: List[float],
    all_features: List[str],
    feat2idx: Dict[str, int],
) -> np.ndarray:
    """Build a weight vector where block features get ``base ** exponent``.

    Each feature of ``all_features`` that appears in ``block_features``
    receives ``base`` raised to the exponent at the same position; every other
    feature receives 1.0.  If a feature is listed more than once in
    ``block_features`` the first occurrence determines its exponent.

    Args:
        base: Base of the power weights.
        block_features: Features that receive a non-default weight.
        exponents: Exponents aligned one-to-one with ``block_features``.
        all_features: All features; determines the vector length.
        feat2idx: Feature name -> position in the returned vector.

    Returns:
        The weight vector produced by ``make_lambda_from_map``.

    Raises:
        ValueError: If ``block_features`` and ``exponents`` differ in length.
    """
    if len(block_features) != len(exponents):
        raise ValueError(f"block_features ({len(block_features)}) and exponents ({len(exponents)}) must have same length")

    # One-pass lookup table instead of a list scan per feature; setdefault
    # keeps the first occurrence, matching list.index semantics.
    exponent_of: Dict[str, float] = {}
    for name, exp in zip(block_features, exponents):
        exponent_of.setdefault(name, exp)

    weight_map = {
        f: (base ** float(exponent_of[f])) if f in exponent_of else 1.0
        for f in all_features
    }

    return make_lambda_from_map(weight_map, all_features, feat2idx)


def are_comparable(u_name: str, v_name: str, P: dict, name_to_idx: dict) -> bool:
    """Tell whether two features are ordered relative to each other in ``P``.

    Returns True when one of the two nodes can be reached from the other by
    following ``P["succs"]`` edges (a node is considered reachable from
    itself), and False otherwise.

    Args:
        u_name: Name of the first feature.
        v_name: Name of the second feature.
        P: Mapping whose ``"succs"`` entry gives, per node index, the indices
            of its direct successors.
        name_to_idx: Feature name -> node index.
    """
    first = name_to_idx[u_name]
    second = name_to_idx[v_name]
    succs = P["succs"]

    # Fast path: a direct edge in either direction.
    if second in succs[first]:
        return True
    if first in succs[second]:
        return True

    # Depth-first search, first from u towards v, then from v towards u.
    for source, target in ((first, second), (second, first)):
        seen = set()
        pending = [source]
        while pending:
            node = pending.pop()
            if node == target:
                return True
            if node not in seen:
                seen.add(node)
                pending.extend(nxt for nxt in succs[node] if nxt not in seen)

    return False


def is_predecessor(u_name: str, v_name: str, P: dict, name_to_idx: dict) -> bool:
    """Tell whether ``v_name`` can be reached from ``u_name`` in ``P``.

    Follows ``P["succs"]`` edges starting at ``u_name``'s node; a node is
    considered reachable from itself.

    Args:
        u_name: Name of the candidate predecessor.
        v_name: Name of the candidate successor.
        P: Mapping whose ``"succs"`` entry gives, per node index, the indices
            of its direct successors.
        name_to_idx: Feature name -> node index.
    """
    source = name_to_idx[u_name]
    target = name_to_idx[v_name]
    succs = P["succs"]

    # Fast path: v is a direct successor of u.
    if target in succs[source]:
        return True

    # Otherwise search the successor graph depth-first from u.
    seen = set()
    pending = [source]
    while pending:
        node = pending.pop()
        if node == target:
            return True
        if node not in seen:
            seen.add(node)
            pending.extend(nxt for nxt in succs[node] if nxt not in seen)

    return False


def is_successor(u_name: str, v_name: str, P: dict, name_to_idx: dict) -> bool:
    """Tell whether ``u_name`` can be reached from ``v_name`` in ``P``.

    This is ``is_predecessor`` with the two feature names swapped.
    """
    v_reaches_u = is_predecessor(v_name, u_name, P, name_to_idx)
    return v_reaches_u


def swap_delay(perm: list, feature: str, P: dict, name_to_idx: dict) -> list:
    """Move ``feature`` as late as possible in a copy of ``perm``.

    The feature is swapped with its right-hand neighbour repeatedly until it
    reaches the end or the neighbour is reachable from it in ``P`` (see
    ``is_predecessor``), in which case it must stay in front of that neighbour.

    Args:
        perm: Ordering of feature names; not modified.
        feature: The feature to push right; must occur in ``perm``.
        P: Precedence structure passed through to ``is_predecessor``.
        name_to_idx: Feature name -> node index.

    Returns:
        A new list with ``feature`` moved to its latest admissible position.

    Raises:
        ValueError: If ``feature`` is not in ``perm``.
    """
    perm = perm.copy()
    feat_idx = perm.index(feature)
    last_idx = len(perm) - 1

    while feat_idx < last_idx:
        right_feature = perm[feat_idx + 1]

        # Being a predecessor already implies the pair is comparable, so a
        # separate comparability test would only repeat the graph search.
        if is_predecessor(feature, right_feature, P, name_to_idx):
            break

        perm[feat_idx], perm[feat_idx + 1] = right_feature, feature
        feat_idx += 1

    return perm


def swap_advance(perm: list, feature: str, P: dict, name_to_idx: dict) -> list:
    """Move ``feature`` as early as possible in a copy of ``perm``.

    The feature is swapped with its left-hand neighbour repeatedly until it
    reaches the front or it is reachable from that neighbour in ``P`` (see
    ``is_successor``), in which case it must stay behind that neighbour.

    Args:
        perm: Ordering of feature names; not modified.
        feature: The feature to pull left; must occur in ``perm``.
        P: Precedence structure passed through to ``is_successor``.
        name_to_idx: Feature name -> node index.

    Returns:
        A new list with ``feature`` moved to its earliest admissible position.

    Raises:
        ValueError: If ``feature`` is not in ``perm``.
    """
    perm = perm.copy()
    feat_idx = perm.index(feature)

    while feat_idx > 0:
        left_feature = perm[feat_idx - 1]

        # Being a successor already implies the pair is comparable, so a
        # separate comparability test would only repeat the graph search.
        if is_successor(feature, left_feature, P, name_to_idx):
            break

        perm[feat_idx], perm[feat_idx - 1] = left_feature, feature
        feat_idx -= 1

    return perm


def compute_utility_trajectory(perm: list, v_global_fn: Callable[[List[str]], float], feat_names: List[str]) -> List[float]:
    """Evaluate the utility of every prefix of ``perm``.

    Args:
        perm: Ordering of feature names.
        v_global_fn: Function mapping a list of feature names to a value.
        feat_names: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        A list of ``len(perm) + 1`` values: the utility of the empty set
        followed by the utility of each successively longer prefix.
    """
    utilities = [v_global_fn([])]

    prefix: List[str] = []
    for feat in perm:
        prefix.append(feat)
        # Hand out a fresh list each time so a callback that keeps or mutates
        # its argument cannot observe or disturb later prefixes.
        utilities.append(v_global_fn(list(prefix)))

    return utilities

