import torch as T
import torch.nn
from itertools import permutations
from typing import Union, Optional, Literal
import pyperplan.stacktrace


class _NLMShapeError(ValueError, AssertionError):
    """
    Raised when the inputs given to NLM have inconsistent shapes or arities.

    It derives from ValueError because it reports bad input, and from
    AssertionError so that callers written against the earlier assert-based
    checks keep catching it.
    """


class NLM(torch.nn.Module):
    """
    Gets a list of tensors and returns a list of tensors.
    Each tensor in the list contains all predicates of a particular arity.
    Their dimensions are

      [B, F_0]                  where
      [B, O, F_1]
      [B, O, O, F_2]              B : batch size,
      [B, O, O, O, F_3] ...       O : number of objects,
                                  Fi: number of predicates of arity i.

    out_features is a list [F_0,F_1,...F_Ao], where Ao is the maximum arity.

    For every output arity i, forward() concatenates along the feature axis
      - the arity i-1 input broadcast over a new object axis ("expand"),
      - the arity i input itself, and
      - the arity i+1 input reduced over its first object axis ("reduce"),
    then (for arity >= 2) concatenates every permutation of the object axes,
    and finally applies one linear layer per arity followed by the activation.

    Arguments:
      out_features : number of output predicates for each arity, starting at arity 0.
      mode         : "fuzzy" reduces with min; any other value reduces with a product.
      activation   : name of a function in torch.nn.functional.
                     An unknown name silently falls back to the identity function.
      bias         : whether the per-arity linear layers have a bias term.
    """

    def __init__(self,
                 out_features:list[int],
                 mode:Literal["fuzzy","probabilistic"]="fuzzy",
                 activation:str="relu",
                 bias:bool=True,
                 ):
        super().__init__()

        # Named functions are used instead of lambdas because lambdas stored
        # as attributes cannot be pickled.
        if mode == "fuzzy":
            self.forall = self._fuzzy_forall
        else:
            self.forall = T.prod

        self.out_features = out_features
        # Number of output arity levels. This is the maximum output arity
        # plus one, because arities start at 0.
        self.out_arity = len(out_features)

        self.activation = getattr(torch.nn.functional, activation, self._identity)
        # Input widths depend on the neighbouring arities and on the number of
        # permutations, so the layers infer them on the first forward pass.
        self.layers = torch.nn.ModuleList([
            torch.nn.LazyLinear(f, bias=bias)
            for f in out_features
        ])

    @staticmethod
    def _fuzzy_forall(*args, **kwargs):
        """Universal quantifier of the fuzzy mode: the values returned by torch.min."""
        return T.min(*args, **kwargs)[0]

    @staticmethod
    def _identity(x):
        """Activation used when the requested name is not in torch.nn.functional."""
        return x

    def check_shape(self, xs : list[T.Tensor]) -> tuple[int, int, Optional[int]]:
        """
        Validate the input list and return (in_arity, B, O).

        in_arity is len(xs), B is the size of axis 0 of xs[0], and O is the
        size of axis 1 of xs[1]. O is None when xs holds only the arity-0
        tensor, since the number of objects cannot be determined from it.

        Raises an error (both a ValueError and an AssertionError) when xs is
        empty, when the number of arities differs from the output by more
        than one, or when batch / object sizes disagree.
        """
        in_arity = len(xs)
        if in_arity == 0:
            raise _NLMShapeError("xs must contain at least the arity-0 tensor")
        if abs(in_arity - self.out_arity) > 1:
            raise _NLMShapeError("shrink/expand must be by at most 1 arity")

        # all batch sizes are same
        B = xs[0].size(0)
        for x in xs[1:]:
            if x.size(0) != B:
                raise _NLMShapeError(
                    f"inconsistent batch size: expected {B}, got {x.size(0)}")

        # all object sizes are same (if the maximum input arity is larger than 1)
        O = None
        if in_arity >= 2:
            O = xs[1].size(1)
            for x in xs[2:]:
                # every axis between the batch axis and the feature axis
                for j in range(1, x.dim()-1):
                    if x.size(j) != O:
                        raise _NLMShapeError(
                            f"inconsistent number of objects: expected {O}, got {x.size(j)}")

        return in_arity, B, O

    def forward(self, xs : list[T.Tensor]) -> list[T.Tensor]:
        """
        Apply one layer to xs and return one tensor per output arity.

        The output list has len(out_features) elements, which may be one
        fewer or one more than len(xs).
        """
        in_arity, B, O = self.check_shape(xs)
        if O is None and self.out_arity > 1:
            raise _NLMShapeError(
                "cannot expand to arity 1 without an arity-1 input: "
                "the number of objects is unknown")

        # note: expand is more efficient than repeat because it avoids copying
        # note2: this code must handle shrink/expand cases
        gathered = []
        for i in range(self.out_arity):
            parts = []
            if i > 0:
                # expand: insert a new object axis and broadcast it to O
                parts.append(T.unsqueeze(xs[i-1], dim=1).expand(B, O, *([-1]*i)))
            if i < in_arity:
                # original (absent when the output has one more arity than the input)
                parts.append(xs[i])
            if i+1 < in_arity:
                # reduce: quantify over the first object axis
                parts.append(self.forall(xs[i+1], dim=1))
            gathered.append(T.cat(parts, dim=-1))

        # permute
        permuted = []
        for i, x in enumerate(gathered):
            if i < 2:
                # arity 0 and 1 do not require permutation
                permuted.append(x)
                continue
            # object axes are 1..i; axis 0 (batch) and axis i+1 (features) stay put
            permuted.append(
                T.cat(
                    [ T.permute(x, [0] + list(perm) + [i+1])
                      for perm in permutations(range(1, i+1)) ],
                    dim=-1))

        return [self.activation(layer(x)) for x, layer in zip(permuted, self.layers)]


if __name__ == "__main__":
    # Smoke test: runs NLM on zero-filled inputs and prints the resulting
    # shapes for the same-arity, expanded and shrunk configurations.

    def _report(prefix, tensors):
        """Print the shape of each tensor in the list, labelled by its arity."""
        for i, t in enumerate(tensors):
            print(f"{prefix} {i}: shape {t.size()}")

    try:
        O = 13
        B = 17
        Q = 19
        x = [
            T.zeros((B, 2)),
            T.zeros((B, O, 3)),
            T.zeros((B, O, O, 5)),
            T.zeros((B, O, O, O, 7)), # max arity is 3
        ]
        _report("input  arity", x)

        # same maximum arity
        layer = NLM([Q,Q,Q,Q])
        _report("output arity", layer(x))
        print(layer)

        # increasing the maximum arity
        layer = NLM([Q,Q,Q,Q,Q])
        _report("expanded arity", layer(x))
        print(layer)

        # shrinking the maximum arity
        layer = NLM([Q,Q,Q])
        _report("shrunk arity", layer(x))
        print(layer)

        # should error: shrinking by two arities at once is not allowed
        try:
            NLM([Q,Q])(x)
        except Exception:
            pass
        else:
            print("error not raised")

        # test activations
        y = NLM([Q,Q,Q],activation="relu")(x)

        # test the probabilistic mode
        y = NLM([Q,Q,Q],mode="probabilistic")(x)

        # test input arity is 0 and 1 only
        x = [
            T.zeros((B, 2)),
            T.zeros((B, O, 3)),
        ]
        _report("input  arity", x)
        # shrinking the maximum arity
        layer = NLM([Q])
        _report("shrunk arity", layer(x))
        print(layer)

    except Exception:
        # Catch Exception rather than everything, so that KeyboardInterrupt
        # and SystemExit still terminate the script normally.
        # NOTE(review): the failure is only reported through
        # stacktrace.format(); it is not re-raised here. Confirm that this
        # is the intended behaviour for a failing run.
        import pyperplan.stacktrace as stacktrace
        stacktrace.format()

