import torch
import torch.nn as nn
import torch.nn.functional as F
from spikingjelly.activation_based import layer, neuron
from functools import partial

# Compute device consumed by the helper functions below ('cuda', 'mps' or 'cpu').
DEV = 'cuda'

# Presumably the default seed for single runs — confirm with callers.
TORCH_SEED = 42

# Presumably one seed per repeated experiment run — confirm with callers.
TORCH_SEEDS = [
    2197414221,
    3983550199,
    2290376441,
    1544162936,
    3859964368,
]

def torch_empty_cache():
    """Release unoccupied cached allocator memory on the module-level ``DEV``.

    Dispatches on the device *type* of ``DEV`` rather than comparing it to a
    literal string, so ``'cuda'``, ``'cuda:1'`` and ``torch.device('cuda:1')``
    are all handled. Does nothing for other device types (e.g. ``'cpu'``).
    """
    dev_type = torch.device(DEV).type
    if dev_type == 'mps':
        torch.mps.empty_cache()
    elif dev_type == 'cuda':
        torch.cuda.empty_cache()

def matinverse_device():
    """Return the device on which matrix inversion should run.

    Returns ``'cpu'`` when the module-level ``DEV`` is an MPS device (of any
    index), otherwise ``DEV`` itself, unchanged. The CPU fallback is presumably
    because some linear-algebra kernels are unavailable on MPS — confirm.
    """
    if torch.device(DEV).type == 'mps':
        return 'cpu'
    return DEV

def torch_synchronize():
    """Wait for queued work on the module-level ``DEV`` accelerator to finish.

    Dispatches on the device *type* of ``DEV`` so indexed devices such as
    ``'cuda:1'`` or ``torch.device('cuda:1')`` still synchronize. Does nothing
    for other device types (e.g. ``'cpu'``).
    """
    dev_type = torch.device(DEV).type
    if dev_type == 'mps':
        torch.mps.synchronize()
    elif dev_type == 'cuda':
        torch.cuda.synchronize()

def find_layers(module, layers=None, name=''):
    """Return the direct children of ``module`` whose exact type is in ``layers``.

    Only immediate children are scanned (no recursion). Matching uses the
    exact type, so subclasses of a listed class do NOT match.

    Args:
        module: module whose immediate children are scanned.
        layers: collection of layer classes to match. Defaults to
            ``[layer.Conv2d, layer.Linear]`` (spikingjelly layers).
        name: unused; kept for backward compatibility.

    Returns:
        dict mapping ``int(child_name)`` to the matching child module.

    Raises:
        ValueError: if a matching child was registered under a name that is
            not an integer string.
    """
    if layers is None:
        # Built per call instead of a shared mutable default argument.
        layers = [layer.Conv2d, layer.Linear]
    return {
        int(child_name): child
        for child_name, child in module.named_children()
        if type(child) in layers
    }

def find_layers_string(module, layers=None, name=''):
    """Return the direct children of ``module`` whose exact type is in ``layers``.

    Child names are kept as strings, so this works for any module regardless
    of how its children are named. Only immediate children are scanned (no
    recursion), and subclasses of a listed class do NOT match.

    Args:
        module: module whose immediate children are scanned.
        layers: collection of layer classes to match. Defaults to
            ``[layer.Conv2d, layer.Linear]`` (spikingjelly layers).
        name: unused; kept for backward compatibility.

    Returns:
        dict mapping the child's registered name to the child module.
    """
    if layers is None:
        # Built per call instead of a shared mutable default argument.
        layers = [layer.Conv2d, layer.Linear]
    return {
        child_name: child
        for child_name, child in module.named_children()
        if type(child) in layers
    }

def find_linear_neuron_pairs(module, name=''):
    """
    Recursively collect sibling (Linear, spiking-neuron) pairs inside a module.

    A pair is recorded whenever a child that is an ``nn.Linear`` (subclasses
    included) is immediately followed, among the children of the same parent,
    by a ``neuron.LIFNode`` or ``neuron.ParametricLIFNode`` (subclasses
    included). The search descends into every child, depth-first.

    Args:
        module (nn.Module): Module to search.
        name (str): Dotted prefix for the names of ``module``'s children;
            empty at the top level.

    Returns:
        list: Tuples ``((linear_name, linear_module), (neuron_name, neuron_module))``
              with fully dotted names, in traversal order.
    """
    found = []
    spiking_types = (neuron.LIFNode, neuron.ParametricLIFNode)
    prev = None  # (full_name, module) of the preceding sibling, if any

    for child_name, child in module.named_children():
        full_name = child_name if not name else f"{name}.{child_name}"

        if prev is not None:
            prev_full_name, prev_module = prev
            if isinstance(prev_module, nn.Linear) and isinstance(child, spiking_types):
                found.append(((prev_full_name, prev_module), (full_name, child)))

        # Descend before moving on so nested pairs keep depth-first order.
        found += find_linear_neuron_pairs(child, full_name)
        prev = (full_name, child)

    return found

def _exp_kernel_values(lags: torch.Tensor,
                       tau: float,
                       device,
                       dtype,
                       mode: str,
                       normalize: bool) -> torch.Tensor:
    """
    Evaluate a decaying exponential at the non-negative lags ``lags``.

    Shared by the symmetric and causal kernel builders so the three
    discretizations live in one place:
      - "exact": h = g * a**lag,                 a = exp(-1/tau), g = 1 - a
      - "euler": h = g * min(|1 - 1/tau|, 0.999)**lag,             g = 1/tau
      - "plain": h = g * exp(-lag/tau),                            g = 1/tau
    ``g`` is 1 when ``normalize`` is False.

    Raises:
        ValueError: if ``tau`` is not strictly positive or ``mode`` is unknown.
    """
    tau = float(tau)
    if not tau > 0.0:  # also rejects NaN; tau <= 0 would divide by zero or make the kernel grow
        raise ValueError(f"tau must be > 0, got {tau}")

    if mode == "exact":
        a = torch.exp(torch.tensor(-1.0 / tau, device=device, dtype=dtype))
        gain = (1.0 - a) if normalize else 1.0
        return gain * (a ** lags)
    if mode == "euler":
        a = 1.0 - 1.0 / tau
        # |a| exceeds 1 for tau < 0.5; the clamp keeps the geometric ratio below 1.
        a_mag = torch.tensor(abs(a), device=device, dtype=dtype).clamp(max=0.999)
        gain = (1.0 / tau) if normalize else 1.0
        return gain * (a_mag ** lags)
    if mode == "plain":
        h = torch.exp(-lags / tau)
        return h / tau if normalize else h
    raise ValueError("mode must be 'exact', 'euler', or 'plain'")


def _laplace_kernel_1d(d: int,
                       tau: float,
                       device,
                       dtype,
                       mode: str = "exact",
                       normalize: bool = True) -> torch.Tensor:
    """
    Symmetric (Laplace) kernel of length 2d-1 centered at 0:
      k = -d+1, ..., -1, 0, 1, ..., d-1
      h[k] ~ a^{|k|}  or  exp(-|k|/tau)
    Returns shape: (1, 1, 2d-1)

    Raises:
        ValueError: if ``tau`` is not strictly positive or ``mode`` is unknown.
    """
    k = torch.arange(-(d - 1), d, device=device, dtype=dtype)  # length 2d-1
    h = _exp_kernel_values(k.abs(), tau, device, dtype, mode, normalize)
    return h.view(1, 1, -1)  # (1,1,2d-1)


def _causal_exp_kernel_1d(d: int,
                          tau: float,
                          device,
                          dtype,
                          mode: str = "exact",
                          normalize: bool = True) -> torch.Tensor:
    """
    Causal exponential kernel h[k], k=0..d-1 (no negative lags).
    Returns shape: (1, 1, d)

    Raises:
        ValueError: if ``tau`` is not strictly positive or ``mode`` is unknown.
    """
    k = torch.arange(d, device=device, dtype=dtype)  # 0..d-1
    h = _exp_kernel_values(k, tau, device, dtype, mode, normalize)
    return h.view(1, 1, -1)  # (1,1,d)

def van_rossum_distance_torch(
    s1: torch.Tensor,
    s2: torch.Tensor,
    tau: float,
    reduction: str = "none",
    mode: str = "euler",       # "exact" | "euler" | "plain"
    normalize: bool = True,    # include natural gain (recommended)
    kernel_type: str = "causal"  # "symmetric" (Laplace) | "causal"
) -> torch.Tensor:
    """
    Batched van Rossum-like distance.

    Both spike trains are filtered with an exponential kernel and the
    distance is the squared difference of the filtered traces summed over
    time (no square root and no time-step factor are applied).

    Args:
        s1, s2: (d, b) spike trains (time length d, batch b).
        tau:    time constant (>0).
        reduction: 'none' -> (b,), 'sum' -> scalar, 'mean' -> scalar.
            Any other value currently behaves like 'none'.
        mode:   kernel discretization.
        normalize: include gain to keep magnitudes comparable across tau.
        kernel_type:
            - "symmetric": Laplace kernel (even), length 2d-1, same-length conv.
            - "causal": causal exponential kernel (k>=0), length d, causal conv.

    Returns:
        distances: (b,) if reduction='none', else scalar (float32).

    Raises:
        ValueError: if s1/s2 are not 2-D tensors of identical shape, or
            kernel_type is unknown.
    """
    # Explicit check instead of `assert` (stripped under -O): a mismatched s2
    # such as (d, 1) would otherwise broadcast silently against s1.
    if s1.ndim != 2 or s2.ndim != 2 or s1.shape != s2.shape:
        raise ValueError("s1, s2 must both be (d, b)")
    d = s1.shape[0]
    device = s1.device
    dtype = torch.float32

    # Filtering is linear, so filter(s1) - filter(s2) == filter(s1 - s2):
    # a single convolution, and no cancellation between two large traces.
    x = (s1.to(dtype) - s2.to(dtype)).transpose(0, 1).unsqueeze(1)  # (b,1,d)

    if kernel_type == "symmetric":
        # Symmetric Laplace kernel (conv == correlation, no flip)
        kernel = _laplace_kernel_1d(d, tau, device=device, dtype=dtype, mode=mode, normalize=normalize)  # (1,1,2d-1)
        filtered = F.conv1d(x, kernel, padding=kernel.shape[-1] // 2)  # (b,1,d)
    elif kernel_type == "causal":
        k_causal = _causal_exp_kernel_1d(d, tau, device=device, dtype=dtype, mode=mode, normalize=normalize)  # (1,1,d)
        # conv1d is cross-correlation, so flip the kernel for a true convolution.
        # Left padding only: output[t] depends on inputs at times <= t and the
        # result already has length d (nothing to discard).
        filtered = F.conv1d(F.pad(x, (d - 1, 0)), k_causal.flip(-1))  # (b,1,d)
    else:
        raise ValueError("kernel_type must be 'symmetric' or 'causal'")

    dist = filtered.squeeze(1).pow(2).sum(dim=-1)  # (b,)

    if reduction == "sum":
        return dist.sum()
    if reduction == "mean":
        return dist.mean()
    return dist

def per_out_weight_stats(
    model,
    include_bias=True,
    to_cpu=True,
    precision=4,
    return_dict=False,
    max_out=None,
):
    """
    Print (or return) per-output mean, variance, min and max of every parameter.

    For each tensor in ``model.named_parameters()`` the first dimension is
    treated as the output dimension (out-features / out-channels); statistics
    are taken over all remaining elements of that output's row. 1-D tensors
    (e.g. biases) give one value per output, and 0-D (scalar) parameters are
    reported as a single output. Variance is the population variance
    (``unbiased=False``).

    Args:
        model: module whose ``named_parameters()`` are inspected.
        include_bias: if False, skip every 1-D parameter.
        to_cpu: currently has no effect; kept for backward compatibility
            (statistics are converted to Python floats via ``.tolist()``).
        precision: number of decimals shown in the printed table.
        return_dict: if True, print nothing and return
            ``{param_name: [{"mean", "var", "min", "max"}, ...]}``.
        max_out: if given, print at most this many outputs per parameter
            (ignored when ``return_dict`` is True).

    Returns:
        The stats dict when ``return_dict`` is True, otherwise None.
    """
    header_fmt = "{:<40s} {:>7s} {:>11s} {:>11s} {:>11s} {:>11s}"
    row_fmt = "{:<40s} {:>7d} {:>11s} {:>11s} {:>11s} {:>11s}"
    if not return_dict:
        header = header_fmt.format("parameter", "out_idx", "mean", "variance", "min", "max")
        print(header)
        print("-" * len(header))  # rule matches the header width

    out = {}
    with torch.no_grad():
        for name, p in model.named_parameters():
            if not include_bias and p.dim() == 1:
                continue

            # One row per output. 0-D parameters become a single 1x1 row, and
            # reshape (unlike view) also accepts non-contiguous tensors such as
            # channels_last conv weights.
            n_rows = p.shape[0] if p.dim() > 0 else 1
            flat = p.float().reshape(n_rows, -1)  # (O, *)

            stats = list(zip(
                flat.mean(1).tolist(),
                flat.var(1, unbiased=False).tolist(),
                flat.min(1).values.tolist(),
                flat.max(1).values.tolist(),
            ))

            if return_dict:
                out[name] = [
                    {"mean": mean, "var": var, "min": lo, "max": hi}
                    for mean, var, lo, hi in stats
                ]
                continue

            # Never index past the available outputs when max_out is large.
            shown = len(stats) if max_out is None else min(max_out, len(stats))
            for i in range(shown):
                mean, var, lo, hi = stats[i]
                print(
                    row_fmt.format(
                        name,
                        i,
                        f"{mean:.{precision}f}",
                        f"{var:.{precision}f}",
                        f"{lo:.{precision}f}",
                        f"{hi:.{precision}f}",
                    )
                )
            if max_out is not None and len(stats) > max_out:
                print(f"... ({len(stats) - max_out} more outputs hidden)")

    return out if return_dict else None