import torch
import torch.nn.functional as F
import numpy as np
from typing import Optional

# Selects how per-ensemble features are combined: True -> project onto seeded
# random unit directions, False -> signed average of one feature per family
# (see batched_features_for_ensemble / print_direction_info).
USE_RANDOM_DIRECTIONS = True
# Default number of output dimensions (`proj_dim`) of the random projection.
PROJECTION_DIM = 2
# Highest adjacency power whose trace is requested by
# _compute_all_features_with_config.
MAX_POWER = 6
# Small constant added to denominators (and under the square root in
# pairwise_distance) to avoid division by zero.
EPSILON = 1e-8

# Per-feature on/off switches. FeatureConfig copies this table and applies the
# caller's overrides on top; the insertion order of the enabled keys is the
# column order of the stacked feature tensor.
DEFAULT_FEATURE_FLAGS = {
    "adj_m3_m2": True,
    "adj_m4_m2": True,
    "adj_m5_m3": True,
    "adj_m6_m4": True,
    "adj_m4_m3": True,
    "adj_m6_m2": True,
    "adj_m5_m2": False,
    "adj_m6_m3": False,
    "adj_m5_m4": False,
    "regularity_proxy": True,
    "spectral_spread": False,
    "clustering_proxy": True,
    "triangle_density": False,
    "adj_m3_norm": True,
    "adj_m4_norm": False,
    "lap_m2_m1": False,
    "lap_m3_m2": False,
    "lap_m4_m2": False,
    "lap_m4_m3": False,
    "adj_lap_m2": False,
    "adj_lap_m4": False,
    "adj_m2_norm": True,
    "lap_m2_norm": False,
}

# Every recognised feature name; FeatureConfig rejects override keys not in it.
ALL_POSSIBLE_FEATURES = list(DEFAULT_FEATURE_FLAGS.keys())

# (family_name, member_features) pairs. FeatureConfig keeps, per family, only
# the members that are enabled and drops families left empty. Note that a
# feature may be listed in more than one family (e.g. "adj_m2_norm").
_ALL_FAMILIES_FULL = [
    ("adj_local", ["adj_m3_m2", "adj_m4_m2", "adj_m2_norm", "triangle_density"]),
    ("adj_global", ["adj_m5_m3", "adj_m6_m4", "adj_m6_m2", "adj_m5_m2", "adj_m6_m3", "adj_m5_m4"]),
    ("adj_structure", ["regularity_proxy", "spectral_spread", "clustering_proxy"]),
    ("adj_scale", ["adj_m3_norm", "adj_m4_norm", "adj_m2_norm"]),
    ("lap_shape", ["lap_m2_m1", "lap_m3_m2", "lap_m4_m2", "lap_m4_m3"]),
    ("cross", ["adj_lap_m2", "adj_lap_m4", "adj_m4_m3"]),
    ("scale", ["adj_m2_norm", "lap_m2_norm"]),
]


class FeatureConfig:
    """Selection of enabled spectral features plus seeded direction generators.

    Starts from a copy of ``DEFAULT_FEATURE_FLAGS``, applies the caller's
    overrides, and derives the ordered list of enabled features and the
    non-empty feature families. Generated directions are memoised per
    instance so repeated requests with the same arguments return the same
    values.
    """

    def __init__(self, feature_flags: dict[str, bool] | None = None, direction_seed: int = 42):
        """Build the configuration.

        Args:
            feature_flags: Optional ``{feature_name: enabled}`` overrides.
            direction_seed: Seed used when a direction getter is called
                without an explicit ``seed``.

        Raises:
            ValueError: If an override key is not a known feature name.
        """
        self.direction_seed = direction_seed
        self.feature_flags = DEFAULT_FEATURE_FLAGS.copy()
        overrides = {} if feature_flags is None else feature_flags
        for name, enabled in overrides.items():
            if name not in ALL_POSSIBLE_FEATURES:
                raise ValueError(f"Unknown feature: {name}. Valid features: {ALL_POSSIBLE_FEATURES}")
            self.feature_flags[name] = enabled
        self._compute_derived()

    def _compute_derived(self):
        """Recompute everything that depends on ``feature_flags`` and clear caches."""
        flags = self.feature_flags
        # Enabled features, in flag-table order (this is the column order of
        # the stacked feature tensor).
        self.feature_names = [name for name, enabled in flags.items() if enabled]
        self.num_features = len(self.feature_names)
        # Keep only the enabled members of each family; drop empty families.
        self.families = [
            (family_name, active)
            for family_name, members in _ALL_FAMILIES_FULL
            if (active := [member for member in members if flags.get(member, False)])
        ]
        self.num_families = len(self.families)
        self._random_directions_cache = {}
        self._family_directions_cache = {}

    def get_random_directions(self, num_ensembles: int, device: torch.device, seed: int | None = None, proj_dim: int = PROJECTION_DIM) -> torch.Tensor:
        """Return seeded random directions of shape ``(num_ensembles, proj_dim, num_features)``.

        Each direction is drawn from a standard normal and divided by its
        norm over the feature axis (plus ``EPSILON``). The result is cached
        per ``(num_ensembles, seed, proj_dim)`` and moved to ``device`` on
        every call.
        """
        effective_seed = self.direction_seed if seed is None else seed
        key = (num_ensembles, effective_seed, proj_dim)
        cached = self._random_directions_cache.get(key)
        if cached is None:
            generator = torch.Generator().manual_seed(effective_seed)
            raw = torch.randn(num_ensembles, proj_dim, self.num_features, generator=generator)
            cached = raw / (raw.norm(dim=2, keepdim=True) + EPSILON)
            self._random_directions_cache[key] = cached
        return cached.to(device)

    def get_family_directions(self, num_ensembles: int, seed: int | None = None) -> list[dict]:
        """Return, per ensemble, one ``(feature, sign)`` pick for every family.

        The result is a list of ``num_ensembles`` dicts mapping
        ``family_name -> (feature_name, sign)`` with ``sign`` in ``{-1, +1}``,
        cached per ``(num_ensembles, seed)``.
        """
        effective_seed = self.direction_seed if seed is None else seed
        key = (num_ensembles, effective_seed)
        if key in self._family_directions_cache:
            return self._family_directions_cache[key]

        rng = np.random.RandomState(effective_seed)

        def sample_direction() -> dict:
            # Draw order matters for reproducibility: feature index first,
            # then the sign, family by family.
            picks = {}
            for family_name, members in self.families:
                chosen = members[rng.randint(0, len(members))]
                picks[family_name] = (chosen, rng.choice([-1, +1]))
            return picks

        sampled = [sample_direction() for _ in range(num_ensembles)]
        self._family_directions_cache[key] = sampled
        return sampled

    def print_info(self):
        """Print the enabled feature names and the direction seed."""
        print(f"Enabled features ({self.num_features}): {self.feature_names}")
        print(f"Direction seed: {self.direction_seed}")


# Lazily created module-wide FeatureConfig; None until first requested.
_DEFAULT_CONFIG = None


def _get_default_config() -> FeatureConfig:
    """Return the shared default FeatureConfig, creating it on first use.

    The first call builds the config with default flags and prints its
    summary; later calls return the same instance until
    ``reset_default_config`` clears it.
    """
    global _DEFAULT_CONFIG
    if _DEFAULT_CONFIG is not None:
        return _DEFAULT_CONFIG
    created = FeatureConfig(None)
    _DEFAULT_CONFIG = created
    created.print_info()
    return created


def reset_default_config() -> None:
    """Forget the cached default config so the next request rebuilds it."""
    global _DEFAULT_CONFIG
    _DEFAULT_CONFIG = None


def _get_feature_names():
    """Return the enabled feature names of the shared default config."""
    default_config = _get_default_config()
    return default_config.feature_names


def _get_num_features():
    """Return how many features the shared default config has enabled."""
    default_config = _get_default_config()
    return default_config.num_features


# Alias of the default flag table (the same dict object, not a copy).
FEATURE_FLAGS = DEFAULT_FEATURE_FLAGS
# Pattern names accepted by compute_pattern_features: "direction_0" ... "direction_999".
PATTERN_NAMES = [f"direction_{i}" for i in range(1000)]


def __getattr__(name: str):
    """Lazily resolve ``ALL_FEATURE_NAMES`` and ``NUM_ALL_FEATURES`` (PEP 562).

    These two names used to be bound to ``property(...)`` objects at module
    level. ``property`` is a descriptor that only takes effect as a class
    attribute, so importers received a bare ``property`` object rather than
    a list / int. A module-level ``__getattr__`` gives the intended
    behaviour: the value is computed from the default config at the moment
    the attribute is looked up on the module (including
    ``from <module> import ALL_FEATURE_NAMES``, which captures a snapshot).

    NOTE: module ``__getattr__`` is consulted for attribute access on the
    module object only; a bare reference to these names from code inside
    this module is not routed through it.

    Raises:
        AttributeError: For any other missing attribute, as usual.
    """
    if name == "ALL_FEATURE_NAMES":
        return _get_feature_names()
    if name == "NUM_ALL_FEATURES":
        return _get_num_features()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def _compute_adjacency_moments(adj: torch.Tensor, max_k: int = 6) -> dict[int, torch.Tensor]:
    B, N, _ = adj.shape
    moments = {}
    moments[1] = torch.zeros(B, device=adj.device)
    A2 = torch.bmm(adj, adj)
    moments[2] = torch.diagonal(A2, dim1=1, dim2=2).sum(dim=1)
    Ak = A2
    for k in range(3, max_k + 1):
        Ak = torch.bmm(Ak, adj)
        moments[k] = torch.diagonal(Ak, dim1=1, dim2=2).sum(dim=1)
    return moments


def _compute_laplacian_moments(adj: torch.Tensor, max_k: int = 4) -> dict[int, torch.Tensor]:
    B, N, _ = adj.shape
    degrees = adj.sum(dim=2)
    L = torch.diag_embed(degrees) - adj
    moments = {}
    moments[1] = degrees.sum(dim=1)
    L2 = torch.bmm(L, L)
    moments[2] = torch.diagonal(L2, dim1=1, dim2=2).sum(dim=1)
    Lk = L2
    for k in range(3, max_k + 1):
        Lk = torch.bmm(Lk, L)
        moments[k] = torch.diagonal(Lk, dim1=1, dim2=2).sum(dim=1)
    return moments


def _safe_ratio(num: torch.Tensor, denom: torch.Tensor, max_val: float = 100.0) -> torch.Tensor:
    """Divide ``num`` by the magnitude of ``denom`` without blowing up.

    The denominator's sign is discarded (its absolute value plus ``EPSILON``
    is used). Entries whose denominator magnitude is not above
    ``100 * EPSILON`` are forced to zero, and the result is clamped to
    ``[-max_val, max_val]``.
    """
    magnitude = denom.abs()
    keep = (magnitude > EPSILON * 100).float()
    quotient = num / (magnitude + EPSILON)
    return (quotient * keep).clamp(-max_val, max_val)


def _normalize_to_01(x: torch.Tensor, expected_max: float) -> torch.Tensor:
    """Scale ``x`` by ``expected_max`` (plus ``EPSILON``) and clamp to ``[0, 1]``."""
    scaled = x / (expected_max + EPSILON)
    return scaled.clamp(0, 1)


def _compute_all_features_with_config(adj: torch.Tensor, config: FeatureConfig) -> torch.Tensor:
    """Compute every known feature and stack the ones enabled in ``config``.

    Args:
        adj: Batched adjacency matrices of shape ``(B, N, N)``.
        config: Supplies ``feature_names``, the ordered list of columns to
            return.

    Returns:
        Tensor of shape ``(B, len(config.feature_names))``. Every feature is
        a moment ratio or a scaled moment clamped into ``[0, 1]``.
    """
    _, n, _ = adj.shape
    a = _compute_adjacency_moments(adj, max_k=MAX_POWER)
    lap = _compute_laplacian_moments(adj, max_k=4)

    def scaled_ratio(num, denom, expected_max):
        # Safe ratio, then rescaled by its expected maximum into [0, 1].
        return _normalize_to_01(_safe_ratio(num, denom), expected_max)

    def halved_ratio(num, denom):
        # Safe ratio clamped to [0, 2] and mapped onto [0, 1].
        return _safe_ratio(num, denom).clamp(0, 2) / 2

    table = {
        # Ratios between adjacency moments.
        "adj_m3_m2": scaled_ratio(a[3], a[2], n),
        "adj_m4_m2": scaled_ratio(a[4], a[2], n * n),
        "adj_m5_m3": scaled_ratio(a[5], a[3], n * n),
        "adj_m6_m4": scaled_ratio(a[6], a[4], n * n),
        "adj_m4_m3": scaled_ratio(a[4], a[3], n),
        "adj_m6_m2": scaled_ratio(a[6], a[2], n ** 4),
        "adj_m5_m2": scaled_ratio(a[5], a[2], n ** 3),
        "adj_m6_m3": scaled_ratio(a[6], a[3], n ** 3),
        "adj_m5_m4": scaled_ratio(a[5], a[4], n),
        # Structural proxies built from adjacency moments.
        "regularity_proxy": halved_ratio(a[2] ** 2, n * a[4]),
        "spectral_spread": scaled_ratio(a[4], a[2] ** 2 + EPSILON, n),
        "clustering_proxy": halved_ratio(a[3] ** 2, a[6]),
        "triangle_density": scaled_ratio(a[3], (a[2] + EPSILON) ** 1.5, n ** 0.5),
        # Size-normalised raw adjacency moments.
        "adj_m3_norm": (a[3] / (n ** 3 + EPSILON)).clamp(0, 1),
        "adj_m4_norm": (a[4] / (n ** 4 + EPSILON)).clamp(0, 1),
        # Ratios between Laplacian moments.
        "lap_m2_m1": scaled_ratio(lap[2], lap[1], n),
        "lap_m3_m2": scaled_ratio(lap[3], lap[2], n * n),
        "lap_m4_m2": scaled_ratio(lap[4], lap[2], n * n * n),
        "lap_m4_m3": scaled_ratio(lap[4], lap[3], n),
        # Adjacency-vs-Laplacian ratios.
        "adj_lap_m2": halved_ratio(a[2], lap[2]),
        "adj_lap_m4": halved_ratio(a[4], lap[4]),
        # Second moments scaled by N * (N - 1) and N ** 3 respectively.
        "adj_m2_norm": (a[2] / (n * (n - 1) + EPSILON)).clamp(0, 1),
        "lap_m2_norm": (lap[2] / (n * n * n + EPSILON)).clamp(0, 1),
    }

    return torch.stack([table[name] for name in config.feature_names], dim=1)


def _compute_all_features(adj: torch.Tensor) -> torch.Tensor:
    """Compute the enabled features of ``adj`` using the shared default config."""
    default_config = _get_default_config()
    return _compute_all_features_with_config(adj, default_config)


def batched_features_for_ensemble(
    batch: torch.Tensor,
    ensemble_id: int,
    num_ensembles: int = 36,
    config: FeatureConfig | None = None,
    proj_dim: int = PROJECTION_DIM,
) -> torch.Tensor:
    """Project the spectral features of ``batch`` onto one ensemble's direction(s).

    Args:
        batch: Batched square matrices of shape ``(B, N, N)``; they are
            symmetrised and their diagonal is zeroed before features are
            computed (done by ``batched_features``).
        ensemble_id: Which direction set to use; taken modulo the number of
            generated direction sets.
        num_ensembles: How many direction sets to generate.
        config: Feature configuration; the shared default is used when None.
        proj_dim: Output dimension of the random projection (ignored in
            family mode).

    Returns:
        ``(B, proj_dim)`` when ``USE_RANDOM_DIRECTIONS`` is true, otherwise
        ``(B, 1)`` holding the signed average of one feature per family.
    """
    if config is None:
        config = _get_default_config()
    # Same preprocessing + feature computation as batched_features; reuse it
    # instead of duplicating the symmetrise / zero-diagonal steps.
    all_feats = batched_features(batch, config)

    if USE_RANDOM_DIRECTIONS:
        directions = config.get_random_directions(num_ensembles, batch.device, proj_dim=proj_dim)
        # Directions are generated in torch's default dtype; torch.mm needs
        # both operands in the same dtype, so follow the features (e.g. for
        # float64 input).
        ens_directions = directions[ensemble_id % len(directions)].to(dtype=all_feats.dtype)
        return torch.mm(all_feats, ens_directions.T)

    directions = config.get_family_directions(num_ensembles)
    direction = directions[ensemble_id % len(directions)]
    combined = torch.zeros(all_feats.shape[0], device=all_feats.device, dtype=all_feats.dtype)
    for feature_name, sign in direction.values():
        feat_idx = config.feature_names.index(feature_name)
        # The sign comes from a NumPy RNG; use a plain int for the tensor math.
        combined = combined + int(sign) * all_feats[:, feat_idx]
    combined = combined / config.num_families
    return combined.unsqueeze(1)


def batched_features(batch: torch.Tensor, config: FeatureConfig | None = None) -> torch.Tensor:
    """Return the enabled spectral features for a batch of square matrices.

    Each matrix in ``batch`` (shape ``(B, N, N)``) is averaged with its
    transpose and has its diagonal zeroed before the features are computed.
    The shared default config is used when ``config`` is None.
    """
    cfg = _get_default_config() if config is None else config
    _, num_nodes, _ = batch.shape
    symmetric = (batch + batch.transpose(1, 2)) / 2
    # 1 everywhere except on the diagonal, broadcast over the batch axis.
    off_diagonal = 1.0 - torch.eye(num_nodes, device=batch.device).unsqueeze(0)
    return _compute_all_features_with_config(symmetric * off_diagonal, cfg)


def compute_pattern_features(batch: torch.Tensor, pattern_name: str) -> torch.Tensor:
    """Compute ensemble features for a pattern name such as ``"direction_7"``.

    The text after the last underscore is parsed as the ensemble id; a name
    without any underscore maps to ensemble 0. Features are computed with 36
    ensembles and the default config.

    Raises:
        ValueError: If the text after the last underscore is not an integer.
    """
    _, separator, suffix = pattern_name.rpartition("_")
    ensemble_id = int(suffix) if separator else 0
    return batched_features_for_ensemble(batch, ensemble_id, 36)


def pairwise_distance(features: torch.Tensor) -> torch.Tensor:
    """Return the matrix of Euclidean distances between the rows of ``features``.

    Uses ``|x|^2 + |y|^2 - 2 x.y``; negative values caused by round-off are
    clipped to zero and ``EPSILON`` is added under the square root, so the
    diagonal is ``sqrt(EPSILON)`` rather than exactly zero.
    """
    gram = torch.mm(features, features.t())
    row_sq_norms = (features ** 2).sum(dim=1, keepdim=True)
    sq_dist = F.relu(row_sq_norms + row_sq_norms.t() - 2 * gram)
    return torch.sqrt(sq_dist + EPSILON)


def metric_distance_diversity(
    graphs: torch.Tensor,
    orca_path: str,
    metric: str = "gcd",
) -> tuple[float, np.ndarray]:
    """Return the mean pairwise distance between graphs and the distance matrix.

    Args:
        graphs: Batch of adjacency matrices, shape ``(B, N, N)``; a torch
            tensor (converted to NumPy) or an array.
        orca_path: Forwarded to ``_compute_gcd_distances`` when
            ``metric == "gcd"``.
        metric: One of ``"gcd"``, ``"netlsd_heat"`` or ``"netlsd_wave"``.

    Returns:
        ``(average, distances)`` where ``distances`` is ``(B, B)`` and
        ``average`` is the mean over the ``B * (B - 1)`` off-diagonal
        entries. With fewer than two graphs there are no pairs, so the
        average is ``0.0`` and the matrix is all zeros.

    Raises:
        ValueError: If ``metric`` is not recognised.
    """
    if metric not in ("gcd", "netlsd_heat", "netlsd_wave"):
        raise ValueError(f"Unknown metric: {metric}")
    if isinstance(graphs, torch.Tensor):
        # detach() so tensors that track gradients can be converted too.
        graphs = graphs.detach().cpu().numpy()
    B = len(graphs)
    if B < 2:
        # No pairs to average over; avoids dividing by B * (B - 1) == 0.
        return 0.0, np.zeros((B, B))
    if metric == "gcd":
        distances = _compute_gcd_distances(graphs, orca_path)
    elif metric == "netlsd_heat":
        distances = _compute_netlsd_distances(graphs, "heat")
    else:
        distances = _compute_netlsd_distances(graphs, "wave")
    off_diagonal = 1 - np.eye(B)
    avg_dist = (distances * off_diagonal).sum() / (B * (B - 1))
    return float(avg_dist), distances


def _compute_gcd_distances(graphs: np.ndarray, orca_path: str) -> np.ndarray:
    B, N, _ = graphs.shape
    gcds = []
    for i in range(B):
        adj = graphs[i]
        degrees = adj.sum(axis=1)
        gcd = np.histogram(degrees, bins=N, range=(0, N))[0].astype(float)
        gcd = gcd / (gcd.sum() + 1e-8)
        gcds.append(gcd)
    gcds = np.array(gcds)
    distances = np.zeros((B, B))
    for i in range(B):
        for j in range(i + 1, B):
            d = np.abs(gcds[i] - gcds[j]).sum()
            distances[i, j] = d
            distances[j, i] = d
    return distances


def _compute_netlsd_distances(graphs: np.ndarray, kernel: str) -> np.ndarray:
    B = len(graphs)
    signatures = []
    for i in range(B):
        adj = graphs[i]
        degrees = adj.sum(axis=1)
        L = np.diag(degrees) - adj
        eigvals = np.linalg.eigvalsh(L)
        eigvals = np.maximum(eigvals, 0)
        if kernel == "heat":
            sig = np.exp(-eigvals)
        else:
            sig = np.cos(np.sqrt(eigvals))
        signatures.append(sig)
    signatures = np.array(signatures)
    distances = np.zeros((B, B))
    for i in range(B):
        for j in range(i + 1, B):
            d = np.sqrt(((signatures[i] - signatures[j]) ** 2).sum())
            distances[i, j] = d
            distances[j, i] = d
    return distances


def print_direction_info(config: FeatureConfig, num_ensembles: int, device: torch.device):
    """Print a short summary of the directions generated for ``num_ensembles``.

    In random-direction mode the three largest-magnitude weights of the
    first projection vector are listed for up to five ensembles; in family
    mode the chosen feature and sign per family are listed instead. Any
    remaining ensembles are only counted.
    """
    if USE_RANDOM_DIRECTIONS:
        directions = config.get_random_directions(num_ensembles, device)
        print(f"\nGenerated {num_ensembles} RANDOM directions in {config.num_features}D spectral moment space:")
        print(f"  Projection dimension: {directions.shape[1]}")
        print(f"  Enabled features: {config.feature_names}")
        top_k = min(3, config.num_features)
        for i in range(min(5, num_ensembles)):
            dir_vec = directions[i, 0]
            # Indices of the largest |weight| entries, strongest first.
            strongest = dir_vec.abs().argsort(descending=True)[:top_k]
            summary = ", ".join(
                f"{config.feature_names[j]}:{dir_vec[j]:+.2f}" for j in strongest
            )
            print(f"  Ensemble {i} (dir 0): top = {summary}")
    else:
        directions = config.get_family_directions(num_ensembles)
        print(f"\nGenerated {num_ensembles} orthogonal family directions:")
        for i, picks in enumerate(directions[:5]):
            summary = ", ".join(
                f"{fn}:{feat}({'+' if s > 0 else '-'})" for fn, (feat, s) in picks.items()
            )
            print(f"  Ensemble {i}: {summary}")

    # Trailer shared by both modes.
    if num_ensembles > 5:
        print(f"  ... and {num_ensembles - 5} more")
    print()