import torch as T
import torch.nn
import functools
import dataclasses
import pyperplan.stacktrace
from itertools import permutations
from typing import Union, Optional, Literal, Sequence

from torch_scatter import scatter


# NLM that operates on a sparse tensor.
#
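# Assume the dense shape of a ternary predicate tensor is [B,O,O,O,F].
# Then its fully sparse representation (COO: COOrdinate format) is as follows:
#
# indices.shape = [5,nnz] where indices[:, i] is the 5-D coordinate in [B,O,O,O,F] of the i-th non-zero entry
# values.shape = [nnz]
#
# Most functions below instead use the hybrid COO layout, in which the feature axis F stays dense:
# indices.shape = [4,nnz] (coordinates in [B,O,O,O]) and values.shape = [nnz,F].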
#
# The idea is based on https://openreview.net/forum?id=WKWAkkXGpWN
# The implementation is rewritten from the scratch.
# We handle sparse COO tensors (indices and values) directly rather than
# using PyTorch's built-in sparse tensors as-is, because too few operations
# are supported on them (e.g., the elementwise maximum is not implemented).
#
# Note that this implementation uses Tensor.expand to build repeated views;
# reshaping an expanded tensor afterwards may still materialize a copy.

# Product of the items of a sequence (used for shape arithmetic).
import operator
def prod(lst):
    """Return the product of the items of ``lst``.

    An empty sequence yields 1 (the multiplicative identity) instead of raising
    ``TypeError`` as ``functools.reduce`` does without an initial value.
    """
    return functools.reduce(operator.mul, lst, 1)


# We decided not to use the built-in sparse tensors directly
# because Tensor._values() does not propagate the gradient,
# which forces coalescing at every operation.

@dataclasses.dataclass
class SparseTensor:
    """Lightweight COO container used instead of ``torch.sparse_coo`` tensors.

    ``indices`` holds one row per sparse axis and one column per stored entry,
    ``values`` holds one row per stored entry (any further axes are dense), and
    ``shape`` is the full logical shape (sparse axes followed by dense axes).
    """

    indices: T.Tensor
    values: T.Tensor
    shape: tuple

    def dense_dim(self):
        """Number of trailing axes stored densely inside each value row."""
        return self.values.dim() - 1

    def sparse_dim(self):
        """Number of leading axes addressed through ``indices``."""
        n_dense = self.dense_dim()
        return len(self.shape) - n_dense

    @property
    def dtype(self):
        """dtype of the stored values."""
        return self.values.dtype



def wrap(x:T.Tensor, *args, **kwargs) -> SparseTensor:
    """Convert a torch tensor into a SparseTensor.

    A strided (dense) tensor is first converted with
    ``x.to_sparse(*args, **kwargs)``; e.g. ``wrap(x, k)`` keeps only the first
    ``k`` axes sparse. A sparse COO tensor is coalesced first if needed, because
    ``Tensor.indices()`` and ``Tensor.values()`` are only available on coalesced
    tensors. For sparse COO input, ``args``/``kwargs`` are ignored.

    Raises:
        ValueError: for any other layout (e.g. CSR).
    """
    if x.layout == T.strided:
        return wrap(x.to_sparse(*args, **kwargs))
    elif x.layout == T.sparse_coo:
        if not x.is_coalesced():
            # coalesce() is differentiable, so gradients still reach the caller's values
            x = x.coalesce()
        return SparseTensor(x.indices(), x.values(), x.size())
    else:
        raise ValueError(f"unknown layout {x.layout}: {x}")


def unwrap(x:SparseTensor) -> T.Tensor:
    return T.sparse_coo_tensor(x.indices, x.values, x.shape).coalesce()


def flatten(x:SparseTensor) -> SparseTensor:
    """Merge all sparse axes of ``x`` into a single row-major (C-order) axis.

    The result has ``indices`` of shape [1, nnz] and shape
    ``(product of the sparse extents, *dense shape)``. The ``values`` tensor is
    reused as-is, and ``x.indices`` is left untouched.
    """
    n_sparse = x.sparse_dim()
    extents = x.shape[:n_sparse]

    # Start from a fresh copy so the in-place updates below never touch x.indices.
    linear = x.indices[0].clone()
    total = extents[0]
    for axis in range(1, n_sparse):
        # Horner-style accumulation: linear = linear * extent + coordinate
        linear *= extents[axis]
        linear += x.indices[axis]
        total *= extents[axis]

    return SparseTensor(linear.unsqueeze(0), x.values, (total, *x.shape[n_sparse:]))


def unflatten(x:SparseTensor, shape:tuple) -> SparseTensor:
    """Inverse of ``flatten``: split the single sparse axis of ``x`` into the sparse axes of ``shape``.

    ``shape`` must have the same number of elements as ``x.shape`` and the same
    trailing dense shape. The returned ``indices`` keep the dtype and device of
    ``x.indices``. ``x`` itself is not modified, and ``values`` is reused as-is.
    """
    indices = x.indices
    values  = x.values
    size    = x.shape
    assert x.sparse_dim() == 1
    assert prod(size) == prod(shape)
    assert size[len(size)-x.dense_dim():] == shape[len(shape)-x.dense_dim():], \
        (f"Trying to unflatten {size} (dense shape: {size[len(size)-x.dense_dim():]}) into {shape} (dense shape: {shape[len(shape)-x.dense_dim():]}).\n"
         "The target shape must have the same dense shape as the original.")

    sparse_shape = shape[:len(shape)-x.dense_dim()]

    # Peel coordinates off starting from the fastest-varying (last) axis.
    # Out-of-place arithmetic ensures x.indices (which may be shared with the
    # caller) is never mutated. Stacking the rows keeps the integer dtype of
    # the input indices; torch.empty would default to float32.
    remainder = indices[0]      # [nnz]
    rows = []
    for dim in reversed(sparse_shape[1:]):
        rows.append(remainder % dim)
        remainder = remainder // dim
    rows.append(remainder)
    new_indices = T.stack(rows[::-1])   # [len(sparse_shape), nnz]

    return SparseTensor(new_indices, values, shape)



def sparse_apply(fn, x:SparseTensor) -> SparseTensor:
    """Apply ``fn`` to the stored values of ``x`` while keeping its sparsity pattern.

    ``fn`` receives ``x.values`` (one row per stored entry) and is expected to
    return one row per entry as well. The trailing shape of its result becomes
    the new dense shape. The ``indices`` tensor is shared with ``x``.
    """
    mapped = fn(x.values)
    dense_shape = mapped.shape[1:]
    sparse_shape = x.shape[:x.sparse_dim()]
    return SparseTensor(x.indices, mapped, sparse_shape + dense_shape)



def expand(x:SparseTensor, O:int) -> SparseTensor:
    """Add a new object axis of size ``O`` just before the feature axis.

    Every stored entry at ``(b, o_1, ..., o_k)`` is duplicated ``O`` times, at
    ``(b, o_1, ..., o_k, j)`` for ``j`` in ``range(O)``. Shape ``(B, *Os, F)``
    therefore becomes ``(B, *Os, O, F)``. Assumes the hybrid layout with a
    single dense feature axis (``x.values`` has shape ``[nnz, F]``).
    """
    values = x.values
    size   = x.shape
    F      = size[-1]
    nnz    = x.indices.shape[1]

    flat = flatten(x)

    # The entry with flat index q is copied to the O consecutive flat positions q*O + j.
    q = flat.indices
    q = q.reshape(nnz,1).expand(nnz,O).reshape(1,nnz*O)
    # Create the offsets on the indices' device; a bare arange would live on the CPU.
    s = T.arange(O, device=q.device).reshape(1,O).expand(nnz,O).reshape(1,nnz*O)

    new_indices = q*O + s

    new_values = values.reshape(nnz,1,F).expand(nnz,O,F).reshape(nnz*O,F) # [nnz*O,F]

    new_flat = SparseTensor(new_indices, new_values, (flat.shape[0]*O,F))

    return unflatten(new_flat, size[:-1] + (O,F))


def reduce(x:SparseTensor, O:int, mode:Literal["fuzzy","probabilistic"]="fuzzy") -> SparseTensor:
    """Reduce the last object axis of ``x`` (the axis just before the features).

    ``mode="fuzzy"`` takes the minimum and ``mode="probabilistic"`` the product
    over that axis, i.e. a "for all" style aggregation. Expects the hybrid
    layout with one dense feature axis: shape ``(B, *Os, O, F)`` becomes
    ``(B, *Os, F)``.

    Raises:
        ValueError: if ``mode`` is not one of the two supported modes.
    """
    values = x.values
    size   = x.shape
    F      = size[-1]
    n_buckets = prod(size[:-2])   # number of cells in the sparse part of the output

    flat = flatten(x)
    # Entries whose flat index differs only in the last object coordinate share a bucket.
    bucket = (flat.indices // O)[0]

    # NOTE(review): only stored entries take part in the reduction. A bucket in
    # which some of the O entries are absent (implicit zeros) still gets the
    # min/product of the stored ones, whereas the dense min/product would also
    # involve those zeros. Confirm whether this approximation is intended.
    if mode == "fuzzy":
        scatter_result = scatter(values, bucket, dim=0, dim_size=n_buckets, reduce="min")
    elif mode == "probabilistic":
        scatter_result = scatter(values, bucket, dim=0, dim_size=n_buckets, reduce="mul")
    else:
        raise ValueError(f"mode {mode} is not one of 'fuzzy','probabilistic'")

    # torch_scatter fills buckets that received no entry with 0 for "min" but
    # with 1 for "mul". Filtering on the values would therefore keep spurious
    # 1.0 entries in probabilistic mode, and would drop real negative rows.
    # Instead, keep exactly the buckets that received at least one stored entry.
    present = T.bincount(bucket, minlength=n_buckets) > 0
    new_values       = scatter_result[present]
    new_flat_indices = T.arange(n_buckets, device=bucket.device)[present]

    assert new_values.shape[0] == len(new_flat_indices)

    result = SparseTensor(new_flat_indices.unsqueeze(0), new_values, (n_buckets,F))

    return unflatten(result, size[:-2]+(F,))


def concat(xs : list[SparseTensor]) -> SparseTensor:
    """Concatenate hybrid sparse tensors along their last (feature) axis.

    All inputs are expected to share the sparse shape ``(B, *Os)`` of ``xs[0]``.
    Input ``k`` has values of shape ``[nnz_k, F_k]``. The stored entries of all
    inputs are stacked, and the values are laid out block-diagonally, so rows
    coming from input ``k`` are zero outside that input's feature columns. The
    result may therefore contain duplicate coordinates with disjoint feature
    slices; coalescing (which sums duplicates) merges them into the correct
    concatenation.
    """
    B, *Os, _ = xs[0].shape

    new_indices = torch.concat([x.indices for x in xs], dim=1)
    new_shape   = B, *Os, sum(x.shape[-1] for x in xs)
    # block_diag is differentiable and keeps the inputs' dtype/device, whereas a
    # plain torch.zeros buffer is always a CPU float32 tensor.
    new_values  = torch.block_diag(*[x.values for x in xs])

    return SparseTensor(new_indices, new_values, new_shape)


def permute(x:SparseTensor) -> SparseTensor:
    """Append every permutation of the object axes of ``x`` as extra features.

    For a hybrid tensor of shape ``(B, O, ..., O, F)`` with ``k`` object axes,
    the result has ``k! * F`` features. There is one copy of the values per
    ordering of the object coordinates, concatenated along the feature axis;
    the first copy uses the identity ordering. Tensors with at most one object
    axis are returned unchanged.

    Row 0 of ``indices`` (the batch) is never permuted. If the feature axis is
    itself stored sparsely (``x.dense_dim() == 0``), its index row is the last
    one and is also kept in place. All object axes are assumed to have the same
    size, since ``x.shape`` is reused for every permuted copy.
    """
    indices = x.indices       # [sparse_dim, nnz]
    rank = indices.shape[0]

    # With a dense feature axis, every index row after the batch row is an
    # object axis. Only a sparse feature axis occupies the last row.
    object_end = rank if x.dense_dim() > 0 else rank - 1
    object_axes = range(1, object_end)
    if len(object_axes) <= 1:
        return x

    trailing = list(range(object_end, rank))
    result = [
        SparseTensor(indices[[0, *dims, *trailing]], x.values, x.shape)
        for dims in permutations(object_axes)   # e.g. (2, 1) for two object axes
    ]
    return concat(result)





class SparseNLM(torch.nn.Module):
    """
    An NLM that assumes that the inputs are sparse_coo tensors.
    See the original NLM.

    ``forward`` takes a list ``xs`` in which ``xs[k]`` is a sparse COO tensor of
    shape ``[B, O, ..., O, F_k]`` with ``k`` object axes. It returns one sparse
    COO tensor per entry of ``out_features``.
    """
    def __init__(self,
                 out_features:list[int],
                 mode:Literal["fuzzy","probabilistic"]="fuzzy",
                 activation:str="relu",
                 bias:bool=True,):
        super().__init__()

        self.mode = mode

        self.out_features = out_features
        # number of output tensors, i.e. (maximum output arity) + 1
        self.out_arity = len(out_features)

        # names not found in torch.nn.functional silently fall back to the identity
        self.activation = getattr(torch.nn.functional, activation, lambda x:x)
        self.layers = torch.nn.ModuleList([
            torch.nn.LazyLinear(f, bias=bias)
            for f in out_features
        ])

    def check_shape(self, xs : list[T.Tensor]):
        """Validate the inputs and return ``(in_arity, B, O)``.

        ``O`` is the object count read from ``xs[1]``. It is ``None`` when only a
        nullary input is given, because there is no object axis to read it from.
        """
        for x in xs:
            assert x.layout == T.sparse_coo

        in_arity = len(xs)
        assert abs(in_arity - self.out_arity) <= 1, "shrink/expand must be by at most 1 arity"

        # all batch sizes are same
        B = xs[0].size(0)
        for x in xs[1:]:
            assert B == x.size(0)

        # all object sizes are same (if the maximum input arity is larger than 1)
        O = None
        if len(xs) >= 2:
            O = xs[1].size(1)
            for x in xs[2:]:
                for j in range(1, x.dim()-1):
                    assert O == x.size(j)

        return in_arity, B, O

    def _gather_inputs(self, i, ys, O):
        """Collect the tensors combined for output arity ``i``.

        In order, these are: arity ``i-1`` expanded by one object axis, arity
        ``i`` as-is, and arity ``i+1`` reduced over its last object axis. Each
        is included only when that input arity exists.
        """
        in_arity = len(ys)
        parts = []
        if i > 0:
            if O is None:
                raise ValueError("cannot expand to a higher arity: the object count O is unknown "
                                 "because no input with an object axis was given")
            parts.append(expand(ys[i-1], O))
        if i < in_arity:  # i can be out of bounds when the output arity grows
            parts.append(ys[i])
        if i + 1 < in_arity:
            parts.append(reduce(ys[i+1], O, mode=self.mode))
        return parts

    def forward(self, xs : list[T.Tensor]) -> list[T.Tensor]:
        in_arity, B, O = self.check_shape(xs)
        ys : list[SparseTensor] = [ wrap(x) for x in xs]

        outputs = []
        for i, layer in zip(range(self.out_arity), self.layers):
            combined = permute(concat(self._gather_inputs(i, ys, O)))
            # the lambda is invoked immediately by sparse_apply, so `layer` is the current one
            activated = sparse_apply(lambda v: self.activation(layer(v)), combined)
            outputs.append(unwrap(activated))
        return outputs


# Short alias for SparseNLM.
NLMS = SparseNLM


if __name__ == "__main__":
    # Ad-hoc smoke tests and exploratory checks; run this file directly.

    import pyperplan.stacktrace as stacktrace
    import inspect
    from colors import red, green, blue
    try:

        def sparse_rand(shape,percentage=10):
            x = T.rand(*shape)
            x[T.randint(low=0,high=100,size=shape) > percentage] = 0
            return x.requires_grad_()


        def test(fn):
            lines, lineno = inspect.getsourcelines(fn)
            try:
                fn()
            except Exception as e:
                errormsg = str(e).split('\n')[0]
                print(red(f"Failed  | Line {lineno} | {lines[0].strip()} ")+errormsg)
                # stacktrace.format(exit=False, arraytypes=[np.ndarray, T.Tensor])
            else:
                print(green(f"Success | Line {lineno} | {lines[0].strip()} "))


        shape = (10,10)

        x = sparse_rand(shape)
        y = sparse_rand(shape)

        def test_sparse_ops(x_sparse, y_sparse):
            # print(type(x_sparse),dir(x_sparse)) # it does have methods like log. It is throwing errors inside log
            test(lambda : x_sparse.reshape(10,10,1))
            test(lambda : x_sparse.flatten())
            test(lambda : T.maximum(x_sparse, y_sparse))
            test(lambda : T.maximum(x_sparse, y))
            test(lambda : T.maximum(x, y_sparse))
            test(lambda : T.minimum(x_sparse, y_sparse))
            test(lambda : T.minimum(x_sparse, y))
            test(lambda : T.minimum(x, y_sparse))
            test(lambda : x_sparse.max(dim=1))
            test(lambda : x_sparse.min(dim=1))
            test(lambda : x_sparse.log())
            test(lambda : x_sparse.exp())
            test(lambda : x_sparse.sin())
            test(lambda : x_sparse.cos())
            test(lambda : x_sparse.tan())
            test(lambda : x_sparse + y_sparse)
            test(lambda : x_sparse - y_sparse)
            test(lambda : x_sparse * y_sparse)
            test(lambda : x_sparse / y_sparse)
            test(lambda : x_sparse // y_sparse)
            test(lambda : x_sparse + y)
            test(lambda : x_sparse - y)
            test(lambda : x_sparse * y)
            test(lambda : x_sparse / y)
            test(lambda : x_sparse // y)
            test(lambda : x + y_sparse)
            test(lambda : x - y_sparse)
            test(lambda : x * y_sparse)
            test(lambda : x / y_sparse)
            test(lambda : x // y_sparse)


        test_sparse_ops(x.to_sparse(), y.to_sparse())
        test_sparse_ops(x.to_sparse_csr(), y.to_sparse_csr())

        # x_coo = x.to_sparse(1)
        # x_coo2 = x_coo.to_sparse(2) # no, you can't change the sparse_dim this way

        x_coo = wrap(x)   # same as to_sparse(2)
        flat1 = flatten(x_coo)
        flat2 = wrap(x.flatten())
        assert (flat1.indices == flat2.indices).all()
        assert (flat1.shape == flat2.shape)
        x_coo_flat1 = unflatten(flat1,x_coo.shape)
        x_coo_flat2 = unflatten(flat2,x_coo.shape)
        assert (x_coo_flat1.indices == x_coo_flat2.indices).all()
        assert (x_coo_flat1.shape == x_coo_flat2.shape)
        # flatten followed by unflatten must restore the original coordinates
        assert (x_coo_flat1.indices == x_coo.indices).all()

        x_coo = wrap(x,1)
        flat1 = flatten(x_coo)
        flat2 = wrap(x.reshape(10,10),1)
        assert (flat1.indices == flat2.indices).all()
        assert (flat1.shape == flat2.shape)
        x_coo_flat1 = unflatten(flat1,x_coo.shape)
        x_coo_flat2 = unflatten(flat2,x_coo.shape)
        assert (x_coo_flat1.indices == x_coo_flat2.indices).all()
        assert (x_coo_flat1.shape == x_coo_flat2.shape)
        # flatten followed by unflatten must restore the original coordinates
        assert (x_coo_flat1.indices == x_coo.indices).all()

        O = 2
        B = 10
        Q = 5
        x = [
            sparse_rand((B, 7)),
            sparse_rand((B, O, 11)),
            sparse_rand((B, O, O, 13)),
            sparse_rand((B, O, O, O, 17)),
        ]
        for i, xi in enumerate(x):
            print(f"input original arity {i}: shape {xi.shape}")

        x_coo    = [ wrap(xi,i+1)   for i, xi in enumerate(x) ]
        for i, xi in enumerate(x_coo):
            print(f"input sparse   arity {i}: shape {xi.shape}")

        x_expand = [ expand(xi, O) for xi in x_coo[:-1] ]
        for i, xi in enumerate(x_expand):
            print(f"input expand   arity {i}: shape {xi.shape}")

        x_reduce = [ reduce(xi, O) for xi in x_coo[1:] ]
        for i, xi in enumerate(x_reduce):
            print(f"input reduce   arity {i+1}: shape {xi.shape}")

        x_concat = [
            concat([ z for z in [e,o,r] if z is not None ])
            for e, o, r in zip([None]+x_expand, x_coo, x_reduce+[None])
        ]
        for i, xi in enumerate(x_concat):
            print(f"input concat   arity {i}: shape {xi.shape}")

        x_permute = [ permute(xi) for xi in x_concat ]
        for i, xi in enumerate(x_permute):
            print(f"input permute  arity {i}: shape {xi.shape}")

        # testing dense layer + sparse input
        y_target = T.rand((100, 30))

        # Naive approach: after the first application, it becomes dense. Also, lazyness does not matter.
        layer1  = torch.nn.LazyLinear(40)
        layer2  = torch.nn.LazyLinear(30)
        x_layer = sparse_rand((100, 20)).to_sparse()
        y_layer = layer2(layer1(x_layer))
        assert y_layer.shape == (100, 30)
        assert y_layer.layout == T.strided
        print(f"Applying 2 layers: {x_layer.shape}/{x_layer.layout} -> {y_layer.shape}/{y_layer.layout}")
        loss = T.square(y_target - y_layer)
        loss = loss.sum(dim=-1)
        loss = loss.mean(dim=0)
        print(loss)
        loss.backward()
        print("layer1 weight grad:",type(layer1.weight.grad), layer1.weight.grad.layout, layer1.weight.grad.shape)
        print("layer2 weight grad:",type(layer2.weight.grad), layer2.weight.grad.layout, layer2.weight.grad.shape)
        print("layer1 bias grad:",type(layer1.bias.grad), layer1.bias.grad.layout, layer1.bias.grad.shape)
        print("layer2 bias grad:",type(layer2.bias.grad), layer2.bias.grad.layout, layer2.bias.grad.shape)
        print("input grad:",type(x_layer.grad))
        print("input grad:",type(x_layer.values().grad))
        print("input grad:",type(x_layer.indices().grad))

        # hybrid COO: naive approach fails, requires addmm during backward
        # layer1  = torch.nn.LazyLinear(40)
        # layer2  = torch.nn.LazyLinear(30)
        # x_layer = sparse_rand((100, 20)).to_sparse(1)
        # y_layer = layer2(layer1(x_layer))
        # assert y_layer.shape == (100, 30)
        # assert y_layer.layout == T.strided
        # print(f"Applying 2 layers: {x_layer.shape}/{x_layer.layout} -> {y_layer.shape}/{y_layer.layout}")
        # loss = T.square(y_target - y_layer)
        # loss = loss.sum(dim=-1)
        # loss = loss.mean(dim=0)
        # print(loss)
        # loss.backward()
        # print("layer1 grad:",type(layer1.weight.grad), layer1.weight.grad.layout, layer1.weight.grad.shape)
        # print("layer2 grad:",type(layer2.weight.grad), layer2.weight.grad.layout, layer2.weight.grad.shape)
        # print("input grad:",type(x_layer.grad))
        # print("input grad:",type(x_layer.values.grad))
        # print("input grad:",type(x_layer.indices.grad))

        # hybrid COO: using sparse_apply
        y_target = T.rand((100, 5, 5, 30)).to_sparse(3)
        layer1  = torch.nn.LazyLinear(40)
        layer2  = torch.nn.LazyLinear(30)
        x_layer = wrap(sparse_rand((100, 5, 5, 20)),3)
        h_layer = sparse_apply(layer1, x_layer)
        y_layer = unwrap(sparse_apply(layer2, h_layer))
        assert y_layer.shape == (100, 5, 5, 30)
        assert y_layer.layout == T.sparse_coo
        print(f"Applying 2 layers: {x_layer.shape} -> {y_layer.shape}")
        loss = T.square(y_target - y_layer)
        loss2 = loss.sum(dim=-1)
        assert loss2.shape == (100, 5, 5)
        loss2 = loss.sum(dim=-2)
        assert loss2.shape == (100, 5, 30)
        loss2 = loss.sum(dim=-3)
        assert loss2.shape == (100, 5, 30)
        loss2 = loss.sum(dim=-4)
        assert loss2.shape == (5, 5, 30)
        loss2 = loss.sum(dim=[1,2,3])
        assert loss2.shape == (100,)
        loss2 = loss.sum() / 100
        assert loss2.layout == T.strided
        assert loss2.shape == tuple()

        loss2.backward()
        print("layer1 weight grad:",type(layer1.weight.grad), layer1.weight.grad.layout, layer1.weight.grad.shape)
        print("layer2 weight grad:",type(layer2.weight.grad), layer2.weight.grad.layout, layer2.weight.grad.shape)
        print("layer1 bias grad:",type(layer1.bias.grad), layer1.bias.grad.layout, layer1.bias.grad.shape)
        print("layer2 bias grad:",type(layer2.bias.grad), layer2.bias.grad.layout, layer2.bias.grad.shape)

        O = 2
        B = 10
        Q = 5
        x = [
            sparse_rand((B, 7)),
            sparse_rand((B, O, 11)),
            sparse_rand((B, O, O, 13)),
            sparse_rand((B, O, O, O, 17)),
        ]
        x_coo    = [ xi.to_sparse(i+1)   for i, xi in enumerate(x) ]
        # integration testing
        y = SparseNLM([Q,Q,Q,Q])(x_coo)
        for i, yi in enumerate(y):
            print(f"output arity {i}: shape {yi.shape}")

        # increasing the maximum arity
        y = SparseNLM([Q,Q,Q,Q,Q])(x_coo)
        for i, yi in enumerate(y):
            print(f"expanded arity {i}: shape {yi.shape}")

        # shrinking the maximum arity
        y = SparseNLM([Q,Q,Q])(x_coo)
        for i, yi in enumerate(y):
            print(f"shrunk arity {i}: shape {yi.shape}")

        # should error
        try:
            SparseNLM([Q,Q])(x_coo)
        except Exception:
            pass
        else:
            print("error not raised")

        # test activations
        y = SparseNLM([Q,Q,Q],activation="relu")(x_coo)

        # test the probabilistic reduction mode
        y = SparseNLM([Q,Q,Q],mode="probabilistic")(x_coo)

        layer1 = SparseNLM([Q,Q,Q])
        layer2 = SparseNLM([Q,Q])
        layer3 = SparseNLM([1])
        h_coo = layer1(x_coo)
        print([x.shape for x in h_coo])
        assert [list(x.shape) for x in h_coo] == [[B,Q],[B,O,Q],[B,O,O,Q],]
        h_coo = layer2(h_coo)
        assert [list(x.shape) for x in h_coo] == [[B,Q],[B,O,Q],]
        h_coo = layer3(h_coo)
        assert [list(x.shape) for x in h_coo] == [[B,1],]

        model = torch.nn.Sequential(layer1,layer2,layer3)
        y_coo = model(x_coo)[0]
        target = T.rand(B,1)
        loss = T.square(target-y_coo).sum() / B
        loss.backward()

        from torch.optim import Adam
        o = Adam(params=model.parameters(),lr=1.0)
        o.step()

    # Exception (not a bare except) so that Ctrl-C and sys.exit still propagate normally.
    except Exception:
        stacktrace.format(arraytypes=[T.Tensor,SparseTensor])

