import torch
import torch.nn as nn
from typing import Any, Callable, Literal
from symo.group import Eq, I, B, S, O
from symo.experiments.models import MLP, Activation
from symo.factory2 import GroupsSpec, groups_spec
from symo.utils import AtomicGroup


def group_config(
    mlp: MLP,
    hid_group: type | None = None,
    same: bool = False,
    inout_group: type | None = None,
    transpose: bool = False,
) -> tuple[tuple[str, type | tuple], ...]:
    """Assign a symmetry group to every layer weight and bias of ``mlp``.

    Each ``layers.{l}.weight`` is mapped to a pair of groups (output side,
    input side), and each ``layers.{l}.bias`` to the output-side group of its
    layer. The MLP's own input and output features use ``inout_group``
    (default ``I``). Hidden features use ``hid_group`` or, when that is
    ``None``, the group associated with ``mlp.activation``: ``B`` for tanh,
    ``S`` for anything else.

    Args:
        mlp: Model exposing ``layers``, ``activation`` and 2-D weights named
            ``layers.{l}.weight`` (plus optional ``layers.{l}.bias``). The
            names may carry the ``_orig_mod.`` prefix that
            ``torch.compile``-wrapped modules use.
        hid_group: Group class for hidden features. It overrides the
            activation-derived default.
        same: If True, hidden groups are keyed by width only (``L{dim}``),
            so equally-sized hidden layers share a group. Otherwise they are
            keyed per layer (``L{layer}-{dim}``).
        inout_group: Group class for the MLP's input/output features.
        transpose: If True, weight pairs are ``(in, out)`` rather than
            ``(out, in)``.

    Returns:
        ``(param_name, group)`` pairs in ``mlp.named_parameters()`` order.
        Parameters that are not layer weights/biases are omitted.

    Raises:
        ValueError: If ``mlp.layers`` is empty.
    """
    act_groups = {
        "relu": S,
        "tanh": B,
        "gelu": S,
        torch.relu: S,
        torch.tanh: B,
        torch.nn.functional.gelu: S,
    }

    # The list keeps registration order so the result can follow it.
    shapes_ord = [(name, param.shape) for name, param in mlp.named_parameters()]
    shapes = dict(shapes_ord)

    prefix = "_orig_mod." if hasattr(mlp, "_orig_mod") else ""
    act_group_key = act_groups.get(mlp.activation, S)
    hidden = act_group_key if hid_group is None else hid_group
    inout_group = I if inout_group is None else inout_group

    num_layers = len(mlp.layers)
    if num_layers < 1:
        raise ValueError("group_config requires an MLP with at least one layer")
    last = num_layers - 1

    def hidg(layer: int, dim: int):
        """Group for the hidden features between layers ``layer - 1`` and ``layer``."""
        name = f"L{dim}" if same else f"L{layer}-{dim}"
        return hidden[name, dim]

    def in_group(layer: int, dim: int):
        # Layer 0 reads the MLP input; every later layer reads hidden features.
        return inout_group["input", dim] if layer == 0 else hidg(layer, dim)

    def out_group(layer: int, dim: int):
        # The last layer writes the MLP output; earlier ones feed hidden layer+1.
        # With a single layer, layer 0 is both first and last, so it maps
        # input -> output instead of to hidden groups.
        if layer == last:
            return inout_group["output", dim]
        return hidg(layer + 1, dim)

    group_spec = {}
    for layer in range(num_layers):
        weight_key = f"{prefix}layers.{layer}.weight"
        bias_key = f"{prefix}layers.{layer}.bias"
        outd, ind = shapes[weight_key]  # weight is (out_features, in_features)
        ing = in_group(layer, ind)
        outg = out_group(layer, outd)
        group_spec[weight_key] = (ing, outg) if transpose else (outg, ing)
        if bias_key in shapes:
            group_spec[bias_key] = outg

    # Reorder to match the original parameter order.
    return tuple((k, group_spec[k]) for k, _ in shapes_ord if k in group_spec)


def group_config_v2(
    mlp: MLP,
    hid_group: AtomicGroup | None = None,
    same: bool = False,
    inout_group: AtomicGroup | None = None,
    named: bool = True,
) -> GroupsSpec:
    """Build a string-named group spec for the layer weights/biases of ``mlp``.

    Every feature axis gets a group name of the form ``"{group}_{axis}"``:

    - The MLP input and output axes are ``"input"`` and ``"output"`` under
      ``inout_group`` (default ``"I"``).
    - Hidden axes are ``"L{layer}-{dim}"`` (or ``"L{dim}"`` when ``same``)
      under ``hid_group``. When ``hid_group`` is ``None``, the group comes
      from ``mlp.activation``: ``"B"`` for tanh, ``"S"`` otherwise.

    Each weight contributes ``(out_group, in_group)`` and each bias
    ``(out_group,)``.

    Args:
        mlp: Model exposing ``layers``, ``activation`` and 2-D weights named
            ``layers.{l}.weight`` (plus optional ``layers.{l}.bias``). The
            names may carry the ``_orig_mod.`` prefix that
            ``torch.compile``-wrapped modules use.
        hid_group: Group for hidden axes. It overrides the activation-derived
            default.
        same: If True, hidden axes are keyed by width only, so
            equally-sized hidden layers share a group.
        inout_group: Group for the MLP's input/output axes.
        named: Currently unused; kept for interface compatibility.

    Returns:
        ``groups_spec(groups, dims)``. ``groups`` is a tuple of group-name
        tuples in ``mlp.named_parameters()`` order; parameters that are not
        layer weights/biases are skipped. ``dims`` maps each axis name to
        its size.

    Raises:
        ValueError: If ``mlp.layers`` is empty.
    """

    act_groups: dict[Any, AtomicGroup] = {
        "relu": "S",
        "tanh": "B",
        "gelu": "S",
        torch.relu: "S",
        torch.tanh: "B",
        torch.nn.functional.gelu: "S",
    }

    # The list keeps registration order so the result can follow it.
    shapes_ord = [(name, param.shape) for name, param in mlp.named_parameters()]
    shapes = dict(shapes_ord)

    prefix = "_orig_mod." if hasattr(mlp, "_orig_mod") else ""
    act_group_key = act_groups.get(mlp.activation, "S")

    inout_group = "I" if inout_group is None else inout_group
    hid_group = act_group_key if hid_group is None else hid_group

    num_layers = len(mlp.layers)
    if num_layers < 1:
        raise ValueError("group_config_v2 requires an MLP with at least one layer")
    last = num_layers - 1

    def hidg(layer: int, dim: int) -> tuple[str, str]:
        """Return ``(group_name, axis_name)`` for hidden features entering ``layer``."""
        name = f"L{dim}" if same else f"L{layer}-{dim}"
        return f"{hid_group}_{name}", name

    def in_group(layer: int, dim: int) -> tuple[str, str]:
        # Layer 0 reads the MLP input; every later layer reads hidden features.
        if layer == 0:
            return f"{inout_group}_input", "input"
        return hidg(layer, dim)

    def out_group(layer: int, dim: int) -> tuple[str, str]:
        # The last layer writes the MLP output; earlier ones feed hidden layer+1.
        # With a single layer, layer 0 is both first and last (input -> output).
        if layer == last:
            return f"{inout_group}_output", "output"
        return hidg(layer + 1, dim)

    group_spec: dict[str, tuple[str, ...]] = {}
    dims: dict[str, Any] = {}
    for layer in range(num_layers):
        weight_key = f"{prefix}layers.{layer}.weight"
        bias_key = f"{prefix}layers.{layer}.bias"
        outd, ind = shapes[weight_key]  # weight is (out_features, in_features)
        ing, ing_name = in_group(layer, ind)
        outg, outg_name = out_group(layer, outd)
        dims[ing_name] = ind
        dims[outg_name] = outd
        group_spec[weight_key] = (outg, ing)
        if bias_key in shapes:
            group_spec[bias_key] = (outg,)

    # Emit groups positionally, in the original parameter order.
    ordered = tuple(group_spec[k] for k, _ in shapes_ord if k in group_spec)
    return groups_spec(ordered, dims)
