from typing import Sequence
import abc
import numpy as np
import torch
from torch import nn
from typing import Any, NamedTuple
from symo.group import Eq, I, flat_shape
from symo.factor import Factor, factor_init, ZeroFactor

Tensor = torch.Tensor


class Block(nn.Module):
    """A single rectangular block backed by a ``Factor``.

    The factor is first built for the group pair ``(l, r)``. If that factor
    is a ``ZeroFactor`` and the two groups differ, the swapped pair ``(r, l)``
    is tried as well; when the swapped factor is not zero it is kept and
    ``transpose`` is set to ``True``.
    """

    factor: Factor
    pos: tuple[int, int] | np.ndarray
    offset: tuple[int, int] | np.ndarray

    def __init__(
        self,
        group_pair: tuple[Any, Any],
        pos: tuple[int, int],
        offset: tuple[int, int] | np.ndarray,
    ):
        """Build the factor for ``group_pair`` and record its grid placement.

        Args:
            group_pair: The ``(left, right)`` groups used to index ``Eq``.
            pos: Grid position (row, column) of the block.
            offset: Element offset (row, column) of the block.
        """
        super().__init__()

        left, right = group_pair
        direct_eq = Eq[left, right]()
        swapped_eq = Eq[right, left]()

        groups_differ = left != right

        direct = factor_init(direct_eq)
        swapped = None
        # Only fall back to the swapped pair when the direct factor is zero.
        if isinstance(direct, ZeroFactor) and groups_differ:
            swapped = factor_init(swapped_eq)

        use_swapped = swapped is not None and not isinstance(swapped, ZeroFactor)
        self.factor = swapped if use_swapped else direct
        self.transpose = use_swapped

        # A swapped factor lives at the mirrored grid cell, unless the cell
        # is on the diagonal of the grid.
        if self.transpose and pos[0] != pos[1]:
            offset, pos = np.flip(offset), np.flip(pos)
        self.offset = offset
        self.pos = pos

    def cov(
        self,
        out: Tensor | None = None,
        diag: bool = False,
        surrogate: bool = True,
    ):
        """Return the factor's covariance block, optionally writing it into ``out``.

        Args:
            out: Optional 2-D tensor that receives the block at ``self.offset``.
            diag: If true, only the block itself is produced; otherwise the
                block and its transpose (written at the mirrored offset).
            surrogate: Forwarded to ``factor.cov``.

        Returns:
            With ``diag``: the block (or the slice of ``out`` it was written
            to). Without ``diag``: a pair ``(block, block.T)`` (or the two
            corresponding slices of ``out``).
        """
        block = self.factor.cov(surrogate=surrogate)
        row, col = self.offset
        row_end = row + block.shape[0]
        col_end = col + block.shape[1]

        if out is None:
            return block if diag else (block, block.T)

        out[row:row_end, col:col_end] = block
        if diag:
            return out[row:row_end, col:col_end]

        out[col:col_end, row:row_end] = block.T
        return out[row:row_end, col:col_end], out[col:col_end, row:row_end]

    def matvec(
        self, lhs: Tensor, rhs: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        """Apply the factor to one or two vectors.

        With only ``lhs`` the factor is applied using ``transpose=self.transpose``
        and the second result is ``None``. With both vectors, ``lhs`` goes
        through ``factor.matvec(..., transpose=False)`` and ``rhs`` through
        ``factor.matvec(..., transpose=True)``.
        """
        use_transpose = self.transpose
        if rhs is not None:
            forward = self.factor.matvec(lhs, transpose=False)
            backward = self.factor.matvec(rhs, transpose=True)
            return forward, backward

        return self.factor.matvec(lhs, transpose=use_transpose), None

    def update_with_vectors(self, lhs, rhs):
        """Overwrite the factor weights with ``factor.outer_estimate(eq, (lhs, rhs))``."""
        factor = self.factor
        estimate = factor.outer_estimate(factor.eq, (lhs, rhs))
        factor.weights.copy_(estimate)

    def shape(self, surrogate: bool = True):
        """Return ``flat_shape`` of the factor's ``stable_shape`` (surrogate) or ``shape``."""
        eq = self.factor.eq
        return flat_shape(eq.stable_shape if surrogate else eq.shape)

    @property
    def is_zero(self) -> bool:
        """Whether the retained factor is a ``ZeroFactor``."""
        return isinstance(self.factor, ZeroFactor)

    def __repr__(self):
        return f"Block({self.pos}, {self.offset}, {self.factor})"


class Factory(nn.Module):
    """Base module holding the group configuration and buffer/value plumbing."""

    groups: list

    def __init__(self, groups):
        super().__init__()
        self.groups = groups

    def weights(self, clone: bool = True):
        """Return all module buffers as a list, cloned unless ``clone`` is false."""
        buffers = self.buffers()
        return [buf.clone() for buf in buffers] if clone else list(buffers)

    def update_weights(self, values):
        """Copy ``values[k]`` in place into the k-th buffer, in ``buffers()`` order."""
        for index, buffer in enumerate(self.buffers()):
            buffer.copy_(values[index])

    def disassemble(self, values: list):
        """Yield per-group values, expanding ``BlockDiagGroups`` entries via ``split``."""
        for index, value in enumerate(values):
            group = self.groups[index]
            if not isinstance(group, BlockDiagGroups):
                yield value
                continue
            yield from group.split(value)

    def assemble(self, values: list):
        """Inverse of ``disassemble``: merge consecutive values of each ``BlockDiagGroups``."""
        cursor = 0
        for group in self.groups:
            if not isinstance(group, BlockDiagGroups):
                yield values[cursor]
                cursor += 1
                continue

            count = len(group.groups)
            merged = group.merge(values[cursor : cursor + count])
            cursor += count
            yield merged


class MeanFactory(Factory):
    """Averages per-group values through blocks pairing each group with ``I["__N__", 1]``."""

    blocks: list[Block]

    def __init__(
        self,
        groups,
    ):
        super().__init__(groups)
        groups_flat = list(flatten_group_settings(groups))

        blocks = mean_blocks(groups_flat)
        self.blocks = nn.ModuleList(blocks)

    def avg(self, values):
        """Return the block-wise average of ``values``.

        Args:
            values: One tensor per entry of ``self.groups``; entries belonging
                to a ``BlockDiagGroups`` are split into their parts first.

        Returns:
            A list with the same layout as ``values``. Positions without a
            (non-zero) block keep zeros.
        """
        vals = list(self.disassemble(values))
        v = vals[0]
        ones = torch.ones((1,), device=v.device, dtype=v.dtype)
        for block in self.blocks:
            p = block.pos
            # ``block.pos`` indexes the *flattened* group list, so the split
            # values must be used here, not the raw ``values`` argument.
            if block.transpose:
                lhs, rhs = ones, vals[p[1]]
            else:
                lhs, rhs = vals[p[0]], ones
            block.update_with_vectors(lhs, rhs)

        avgs = [torch.zeros_like(v) for v in vals]
        for block in self.blocks:
            avg = block.cov(surrogate=False, diag=True)
            pos = block.pos

            # For a transposed block the group sits on the column axis.
            if block.transpose:
                p = pos[1]
                avg = avg.T
                shape = block.factor.eq.shape[1]
            else:
                p = pos[0]
                shape = block.factor.eq.shape[0]

            avgs[p] = avg.reshape(shape)

        out = list(self.assemble(avgs))
        return out


class CovFactory(Factory):
    """Covariance built from diagonal blocks and upper-triangular off-diagonal blocks.

    Args:
        groups: Group settings; ``GroupCollection`` entries are flattened.
        surrogate: Selects ``stable_shape`` over ``shape`` for block sizes and
            is forwarded to the factors.
        block_diag_only: If true, only diagonal blocks are built and ``cov``
            returns them as a list instead of a dense matrix.
    """

    diag_blocks: list[Block]
    off_blocks: list[Block]

    def __init__(
        self,
        groups,
        surrogate: bool = True,
        block_diag_only: bool = False,
    ):
        super().__init__(groups)
        groups_all = list(flatten_group_settings(groups))

        self.surrogate = surrogate
        triu = triu_indices_grid(groups_all, surrogate=surrogate)

        # Forward the flag: otherwise every off-diagonal block is still
        # constructed and used by matvec/cov_update/outer_update even though
        # cov() ignores them in block-diagonal mode.
        diag_blocks, off_blocks = triu_blocks(
            triu, groups_all, block_diag_only=block_diag_only
        )

        self.triu = triu
        self.diag_blocks = nn.ModuleList(diag_blocks)
        self.off_blocks = nn.ModuleList(off_blocks)
        self.block_diag_only = block_diag_only

    def cov(self):
        """Return the covariance.

        Returns:
            A dense ``(size, size)`` tensor, or, when ``block_diag_only`` is
            set, a list with one tensor per diagonal block.
        """
        surrogate = self.surrogate
        if not self.block_diag_only:
            size = self.triu.size
            # NOTE(review): allocated with torch's default dtype/device, not
            # those of the block factors - confirm this is intended.
            cov = torch.zeros((size, size))

            for block in self.diag_blocks:
                block.cov(out=cov, diag=True, surrogate=surrogate)

            # Off-diagonal blocks also fill their mirrored (transposed) slot.
            for block in self.off_blocks:
                block.cov(out=cov, surrogate=surrogate)

            return cov
        else:
            cov = []
            for block in self.diag_blocks:
                c = block.cov(diag=True, surrogate=surrogate)
                cov.append(c)

            return cov

    def matvec(self, vectors):
        """Apply all blocks to ``vectors`` and return the accumulated results.

        ``vectors`` follows the layout of ``self.groups``; the result has the
        same layout.
        """
        vecs: list[Tensor] = list(self.disassemble(vectors))
        cumvecs: list[Tensor] = [torch.zeros_like(v) for v in vecs]

        for block in self.diag_blocks:
            d = block.pos[0]
            vec = vecs[d]
            res, _ = block.matvec(vec)
            cumvecs[d] += res

        for block in self.off_blocks:
            i, j = block.pos
            # Block (i, j) maps the j-th vector into slot i; its transpose
            # maps the i-th vector into slot j.
            lhs, rhs = vecs[j], vecs[i]
            lhs_res, rhs_res = block.matvec(lhs, rhs)
            assert rhs_res is not None
            cumvecs[i] += lhs_res
            cumvecs[j] += rhs_res

        out = list(self.assemble(cumvecs))
        return out

    def cov_update(self, cov):
        """Re-estimate every block's factor from its slice of the dense ``cov``."""
        blocks = [*self.diag_blocks, *self.off_blocks]
        surrogate = self.surrogate
        for block in blocks:
            r, c = block.offset
            rs, cs = block.shape(surrogate=surrogate)
            cov_block = cov[r : r + rs, c : c + cs]
            block.factor.cov_estimate_(cov_block, surrogate=surrogate)

    def outer_update(self, updates):
        """Update every block from the pair of update vectors at its grid position."""
        upds = list(self.disassemble(updates))

        for block in self.diag_blocks:
            i, _ = block.pos
            vec = upds[i]
            block.update_with_vectors(vec, vec)

        for block in self.off_blocks:
            i, j = block.pos
            lhs, rhs = upds[i], upds[j]
            block.update_with_vectors(lhs, rhs)


class TriuGrid(NamedTuple):
    """Index layout of an upper-triangular block grid, as built by ``triu_indices_grid``."""

    # Sum of the per-group block heights; used as the side length of the dense matrix.
    size: int
    # Array of shape (n, 2) holding one (i, i) pair per diagonal block.
    diag_indices: np.ndarray
    # Array of shape (m, 2) holding the (row, col) pairs of the upper triangle with row != col.
    off_indices: np.ndarray
    # Start offset of each group: exclusive running sum of the heights, shifted by ``origin``.
    offsets: np.ndarray


def triu_indices_grid(
    groups,
    surrogate: bool = True,
    origin: int = 0,
) -> TriuGrid:
    """Compute block positions and element offsets for an upper-triangular grid.

    Args:
        groups: Flat sequence of groups; group ``g`` contributes the first
            entry of ``flat_shape`` of ``Eq[g, g]()``'s shape as its height.
        surrogate: Use ``stable_shape`` instead of ``shape``.
        origin: Constant added to every offset.

    Returns:
        A ``TriuGrid``. For an empty ``groups`` an empty grid of size 0 is
        returned instead of failing.
    """
    n = len(groups)

    sizes = []
    for g in groups:
        eq = Eq[g, g]()
        shape = eq.stable_shape if surrogate else eq.shape
        shape = flat_shape(shape)
        height = shape[0]
        sizes.append(height)

    # Nothing to lay out: ``cumsizes[-1]`` below would raise on an empty list.
    if not sizes:
        return TriuGrid(
            0,
            np.empty((0, 2), dtype=int),
            np.empty((0, 2), dtype=int),
            np.empty((0,), dtype=int),
        )

    cumsizes = np.cumsum(sizes)
    # Plain Python int, matching the ``TriuGrid.size`` annotation.
    size = int(cumsizes[-1])
    # Exclusive prefix sum: block k starts where blocks 0..k-1 end.
    offsets = np.array([0, *cumsizes[:-1]], dtype=int)
    offsets += origin

    rows, cols = np.triu_indices(n)
    on_diag = rows == cols
    diag = rows[on_diag]

    diag_inds = np.array((diag, diag), dtype=int).T
    off_inds = np.array((rows[~on_diag], cols[~on_diag]), dtype=int).T

    return TriuGrid(size, diag_inds, off_inds, offsets)


def mean_blocks(groups):
    """Build one block per group paired with ``I["__N__", 1]``, dropping zero blocks.

    Block ``k`` is created at grid position ``(k, 0)``; all blocks share the
    same ``(0, 0)`` offset array.
    """
    origin = np.array((0, 0), dtype=int)
    candidates = (
        Block(
            (group, I["__N__", 1]),
            pos=np.array((row, 0), dtype=int),
            offset=origin,
        )
        for row, group in enumerate(groups)
    )
    return [block for block in candidates if not block.is_zero]


def triu_blocks(
    triu: TriuGrid, groups, block_diag_only: bool = False
) -> tuple[list[Block], list[Block]]:
    """Instantiate the diagonal and off-diagonal blocks described by ``triu``.

    Args:
        triu: Grid layout providing block positions and offsets.
        groups: Flat sequence of groups, indexed by the grid positions.
        block_diag_only: If true, skip the off-diagonal blocks entirely.

    Returns:
        ``(diag_blocks, off_blocks)``. Off-diagonal blocks whose factor is
        zero are left out.

    Raises:
        ValueError: If a diagonal block turns out to have a zero factor.
    """
    diag_blocks = []
    for index, group in enumerate(groups):
        cell = triu.diag_indices[index]
        diagonal = Block((group, group), pos=cell, offset=triu.offsets[cell])

        if diagonal.is_zero:
            raise ValueError(f"Expected non-zero diagonal factor for {group}")

        diag_blocks.append(diagonal)

    if block_diag_only:
        return diag_blocks, []

    # NOTE: an earlier, commented-out variant built these blocks through
    # ProcessPoolExecutor.map; construction here is sequential.
    off_blocks = []
    for cell in triu.off_indices:
        row_group = groups[cell[0]]
        col_group = groups[cell[1]]
        candidate = Block(
            (row_group, col_group),
            pos=cell,
            offset=triu.offsets[cell],
        )
        if not candidate.is_zero:
            off_blocks.append(candidate)

    return diag_blocks, off_blocks


class GroupCollection(abc.ABC):
    """Abstract container of groups that are treated as one entry of a group list.

    Deriving from ``abc.ABC`` is what makes the ``@abc.abstractmethod``
    markers effective: subclasses must implement ``split`` and ``merge``
    before they can be instantiated.
    """

    def __init__(self, groups: Sequence):
        self._groups = groups

    @property
    def groups(self) -> Sequence:
        """The wrapped groups, in order."""
        return self._groups

    @abc.abstractmethod
    def split(self, value: Tensor) -> Sequence[Tensor]:
        """Split one combined value into one value per wrapped group."""
        ...

    @abc.abstractmethod
    def merge(self, values) -> Tensor:
        """Combine per-group values back into a single value."""
        ...


class BlockDiagGroups(GroupCollection):
    """Identity kronecker groups I ⊗ G.

    A value for this collection is the concatenation, along ``dim``, of one
    equally sized part per wrapped group.
    """

    def __init__(self, diag_groups, dim: int):
        super().__init__(diag_groups)
        self.dim = dim

    @property
    def size(self) -> int:
        """Number of wrapped groups."""
        return len(self.groups)

    def split(self, value):
        """Split ``value`` along ``dim`` into exactly ``size`` equal parts.

        Raises:
            ValueError: If there are no groups or the extent of ``value``
                along ``dim`` is not a multiple of ``size``.
        """
        parts = self.size
        extent = value.shape[self.dim]
        if parts == 0 or extent % parts != 0:
            raise ValueError(
                f"Cannot split extent {extent} along dim {self.dim} "
                f"into {parts} equal parts"
            )
        # ``torch.split`` interprets an int as the chunk *size*;
        # ``tensor_split`` interprets it as the number of sections, which is
        # what ``Factory.assemble`` relies on (one piece per group).
        tensors = torch.tensor_split(value, parts, dim=self.dim)
        return tensors

    def merge(self, values):
        """Concatenate the per-group ``values`` along ``dim``."""
        tensor = torch.concat(values, dim=self.dim)
        return tensor


def flatten_group_settings(groups: list):
    """Return a flat list of groups, expanding each ``GroupCollection`` one level."""
    flat = []
    for entry in groups:
        members = entry.groups if isinstance(entry, GroupCollection) else (entry,)
        flat.extend(members)
    return flat


# class FullGrid(NamedTuple):
#     size: tuple[int, int]
#     offsets: np.ndarray


# def full_indices_grid(
#     groups: tuple[Any, Any],
#     surrogate: bool = True,
#     origin: np.ndarray | None = None,
# ):
#     rs, cs = groups
#     cs = cs if isinstance(cs, list) else [cs]
#     rs = rs if isinstance(rs, list) else [rs]

#     n = len(cs)
#     m = len(rs)

#     col_sizes = []
#     for ni in range(n):
#         eq = Eq[rs[0], cs[ni]]()
#         shape = eq.stable_shape if surrogate else eq.shape
#         shape = flat_shape(shape)
#         height = shape[1]
#         col_sizes.append(height)

#     row_sizes = []
#     for mi in range(m):
#         eq = Eq[rs[mi], cs[0]]()
#         shape = eq.stable_shape if surrogate else eq.shape
#         shape = flat_shape(shape)
#         height = shape[0]
#         row_sizes.append(height)

#     col_cumsizes = np.cumsum(col_sizes)
#     col_size = col_cumsizes[-1]
#     col_offsets = np.array([0, *col_cumsizes[:-1]], dtype=int)

#     row_cumsizes = np.cumsum(row_sizes)
#     row_size = row_cumsizes[-1]
#     row_offsets = np.array([0, *row_cumsizes[:-1]], dtype=int)

#     row_offs, col_offs = np.meshgrid(row_offsets, col_offsets, indexing="ij")
#     offsets = np.concat([row_offs.reshape(-1, 1), col_offs.reshape(-1, 1)], axis=1)

#     size = (row_size, col_size)
#     grid = FullGrid(size, offsets)
#     return grid
