from itertools import product
import string

# Placeholder letters for partition patterns, keyed by the number of indices:
# pattern[k] is the first k letters of "ijklm", so keys run from 1 to 5.
pattern = {k: "ijklm"[:k] for k in range(1, 6)}


def partition(n):
    """
    Enumerate the set partitions of ``n`` positions as labelled strings.

    Each solution is a string of length ``n``; positions carrying the same
    letter belong to the same block. Position ``k`` (1-based) either reuses a
    letter already present in the partial solution, or opens a new block
    labelled ``pattern[k][-1]``. This produces one string per set partition,
    e.g. ``partition(2) == ["ii", "ij"]``.

    Args:
        n (int): number of indices acted on by the same permutation group.

    Returns:
        list[str]: all partition strings; ``[""]`` when ``n == 0``.

    Raises:
        KeyError: if ``pattern`` has no entry for some ``k <= n``.
    """
    res = [""]
    for k in range(1, n + 1):
        letters = pattern[k]
        new_block = letters[-1]
        # Extend every partial solution by joining an existing block (a letter
        # already used) or by opening the new block for position k.
        res = [
            sol + char
            for sol in res
            for char in letters
            if char in sol or char == new_block
        ]
    # Return the running result rather than a loop-local list, so that n == 0
    # yields the single empty partition instead of raising UnboundLocalError.
    return res


# Index patterns per group type, keyed by the number of indices in the group.
# Letters are placeholders from ``pattern`` that callers relabel.
# S: every set partition of 1..5 indices.
S_part = {n: partition(n) for n in range(1, 6)}
# O: every perfect pairing of 2 or 4 indices.
O_part = {2: ["ii"], 4: ["ijij", "ijji", "iijj"]}
# B: the O pairings, plus the all-equal pattern "iiii" for 4 indices.
B_part = {2: list(O_part[2]), 4: O_part[4] + ["iiii"]}
# TODO(bla): derive the O/B partition solutions mathematically instead of
# hard-coding them (only 2 and 4 indices are covered for now).


def make_g(letters: str = string.ascii_lowercase):
    """
    Create a stateful allocator of fresh einsum index labels.

    The returned callable ``fn(k)`` hands out the next ``k`` unused characters
    of ``letters`` as a tuple. All calls on one allocator share a single
    cursor, so labels never repeat within one equation.

    Args:
        letters (str): ordered pool of labels to draw from. Defaults to
            ``string.ascii_lowercase`` (26 labels).

    Returns:
        Callable[[int], tuple[str, ...]]: the label allocator. Calling it
        raises ``ValueError`` if fewer than ``k`` labels remain.
    """
    it = iter(letters)

    def fn(k: int):
        labels = []
        for _ in range(k):
            # Use next() with a default instead of letting StopIteration
            # escape. Inside a generator expression it would surface as an
            # opaque RuntimeError (PEP 479).
            label = next(it, None)
            if label is None:
                raise ValueError(
                    f"ran out of einsum index labels ({len(letters)} available)"
                )
            labels.append(label)
        return tuple(labels)

    return fn


def filter_group(groups, mem, begin=0):
    """
    Bucket equation positions by group type and identity, accumulating in ``mem``.

    Positions are numbered starting at ``begin``, so a second side can be
    appended after the first. Entries of type ``"I"`` are never merged: each
    one gets its own key (repeated identities are suffixed with ``"_o"``),
    so every I entry represents its own free axis.

    Args:
        groups (tuple): pair ``(group_types, identities)`` of parallel
            sequences, one element per tensor axis.
        mem (dict): accumulator, updated in place.
        begin (int): equation position of the first axis in ``groups``.

    Returns:
        dict: ``mem`` itself, mapping ``mem[group_type][identity]`` to the list
        of equation positions.
    """
    for offset, (gtype, ident) in enumerate(zip(*groups)):
        position = offset + begin
        bucket = mem.setdefault(gtype, {})
        if gtype != "I":
            bucket.setdefault(ident, []).append(position)
            continue
        # Free axes stay distinct even when they share an identity.
        while ident in bucket:
            ident += "_o"
        bucket[ident] = [position]
    return mem


def expand(d, n):
    """
    Yield every complete index string formed by combining candidate fillings.

    Args:
        d (dict): maps a tuple of equation positions to an iterable of
            candidate strings. The i-th character of a candidate is written at
            the i-th position of its key.
        n (int): total number of positions in the equation.

    Yields:
        str: one string per element of the Cartesian product of the candidate
        iterables, in ``itertools.product`` order. Positions not covered by any
        key contribute nothing, so callers should cover all ``n`` positions.
    """
    locations = list(d)
    for choice in product(*(d[key] for key in locations)):
        slots = [""] * n
        for positions, labels in zip(locations, choice):
            for pos, label in zip(positions, labels):
                slots[pos] = label
        yield "".join(slots)


def _split_spec(spec):
    """Split each ``"<GroupType>_<Identity>[_<Dimension>]"`` entry into its first two fields.

    Returns a pair of tuples ``(group_types, identities)``. Both are empty for an
    empty ``spec``, so an index-free (scalar) side is supported.
    """
    group_types, identities = [], []
    for entry in spec:
        fields = entry.split("_")
        if len(fields) < 2:
            raise ValueError(
                f"malformed group spec {entry!r}; "
                "expected '<GroupType>_<Identity>[_<Dimension>]'"
            )
        # Only the type and identity drive the contraction; a trailing
        # dimension name (if present) is ignored here.
        group_types.append(fields[0])
        identities.append(fields[1])
    return tuple(group_types), tuple(identities)


def invariance_from_spec(lhs, rhs):
    """
    Generate invariant equation solutions (in einsum notation) from group-structured
    tensor specifications.

    Each element of `lhs` and `rhs` must be a string of the form:
        "<GroupType>_<Identity>_<Dimension>"   (the "_<Dimension>" part is optional)

    After splitting by "_":
        * GroupType:  "O", "B", "S" or "I".
        * Identity:   A label identifying which indices belong to the same group.
        * Dimension:  A dimension name; not used by this function.

    The function constructs all valid invariant contractions implied by the
    group structure.

    Args:
        lhs (list[str]):
            Left-hand side tensor specification, one entry per axis. May be empty.
        rhs (list[str]):
            Right-hand side tensor specification in the same format as `lhs`.
            May be empty (scalar output).

    Returns:
        * `solutions` (list[list[str, str, str]]):
                A list of contraction solutions.
                Each element is a 3-item list:
                `[lhs_einsum, rhs_einsum, factor]`. Here `factor` concatenates
                the free-axis labels assigned to "I" entries.

    Raises:
        ValueError: if an entry is malformed or uses an unsupported group type.

    Notes:
        - If an "O" or "B" group appears with odd multiplicity, the equation
          has no invariant solutions and the function returns an empty list.
        - The contraction patterns rely on module-level names:
          `filter_group`, `make_g`, `pattern`, `O_part`, `B_part`, `S_part`
          and `expand`.
    """
    ls, rs = len(lhs), len(rhs)

    mem = filter_group(_split_spec(lhs), {})
    mem = filter_group(_split_spec(rhs), mem, begin=ls)

    # Any other type would leave its positions unfilled, and expand() would
    # then silently emit misaligned einsum strings.
    unknown = sorted(set(mem) - {"O", "B", "S", "I"})
    if unknown:
        raise ValueError(f"unsupported group type(s): {', '.join(unknown)}")

    g = make_g()
    in_eq = {}
    factor_eq = ""

    # Labels are allocated in a fixed order (O, B, S, then I).
    for group_type, parts in (("O", O_part), ("B", B_part)):
        for v in mem.get(group_type, {}).values():
            num = len(v)
            if num % 2 != 0:
                return []
            translater = str.maketrans(dict(zip(pattern[num], g(num))))
            in_eq[tuple(v)] = [sol.translate(translater) for sol in parts[num]]

    for v in mem.get("S", {}).values():
        num = len(v)
        translater = str.maketrans(dict(zip(pattern[num], g(num))))
        in_eq[tuple(v)] = [sol.translate(translater) for sol in S_part[num]]

    for v in mem.get("I", {}).values():
        free_indx = "".join(g(len(v)))
        # Wrap the label in a list so expand() sees one candidate, rather than
        # iterating the string's characters as alternatives.
        in_eq[tuple(v)] = [free_indx]
        factor_eq += free_indx

    return [[eq[:ls], eq[ls:], factor_eq] for eq in expand(in_eq, ls + rs)]
