from typing import Any, Callable, Literal, NamedTuple, Sequence

import numpy as np
import torch
from torch import nn

from symo.compiler import Compiler, Factor
from symo.group import flat_shape
from symo.invariance import invariance_from_spec
from symo.utils import flatten

# Short alias used for tensor type annotations throughout this module.
Tensor = torch.Tensor

# Surrogate sizes keyed as ``sizes[group_type][dim_name]`` (the layout that
# build_surr_dims produces and GroupsSpec.shapes reads).
SurrogateSizes = dict[str, dict[str, int]]


class GroupShapes(NamedTuple):
    """Per-group axis sizes paired with the matching surrogate axis sizes.

    As built by ``GroupsSpec.shapes``, each field holds one tuple of axis
    sizes per group, in the same order as ``GroupsSpec.groups``.
    """

    # Full (non-surrogate) axis sizes of each group.
    shape: tuple[int | tuple[int, ...], ...]
    # Surrogate axis sizes of each group, aligned element-wise with ``shape``.
    surr_shape: tuple[int | tuple[int, ...], ...]


class GroupsSpec(NamedTuple):
    """Description of the parameter groups handled by a factory.

    Attributes:
        groups: one entry per group; either a single axis label of the form
            "<type>_<dim>" or a sequence of such labels.
        dim_sizes: full size of each dim name.
        dim_surr_sizes: surrogate sizes keyed as ``[type][dim]``.
        pre_hooks: optional per-group callables applied (via ``process``)
            to inputs before factory operations.
        post_hooks: optional per-group callables applied (via ``process``)
            to outputs after factory operations.
    """

    groups: tuple[str | tuple[str, ...], ...]
    dim_sizes: dict[str, int]
    dim_surr_sizes: SurrogateSizes
    pre_hooks: tuple[Callable, ...] | None = None
    post_hooks: tuple[Callable, ...] | None = None

    @property
    def num_groups(self) -> int:
        """Number of groups in the spec."""
        return len(self.groups)

    @property
    def shapes(self) -> GroupShapes:
        """Return the full and surrogate axis sizes of every group.

        A bare-string group is treated as a one-axis group, so every entry of
        the result is a tuple with one size per axis label.

        Raises:
            TypeError: if a group is neither a string nor a sequence. Such a
                group used to be skipped silently, which shifted all later
                shapes out of alignment with their group indices.
            KeyError: if a dim or (type, dim) pair has no configured size.
        """
        sizes = self.dim_sizes
        surr_sizes = self.dim_surr_sizes

        out = []
        surr_out = []
        for g in self.groups:
            # Check str first: a str is itself a Sequence (of characters).
            if isinstance(g, str):
                axes = (g,)
            elif isinstance(g, Sequence):
                axes = tuple(g)
            else:
                raise TypeError(
                    f"Group must be a str or a sequence of str, got {type(g).__name__}"
                )

            labels = [split_group_dim(axis) for axis in axes]
            out.append(tuple(sizes[dim] for _, dim in labels))
            surr_out.append(tuple(surr_sizes[name][dim] for name, dim in labels))

        return GroupShapes(tuple(out), tuple(surr_out))


def groups_spec(
    groups,
    dim_sizes,
    surr_sizes: dict | None = None,
    pre_hooks: tuple[Callable, ...] | None = None,
    post_hooks: tuple[Callable, ...] | None = None,
) -> GroupsSpec:
    """Build a ``GroupsSpec``.

    When ``surr_sizes`` is omitted, the surrogate sizes are derived from
    ``groups`` and ``dim_sizes`` with ``build_surr_dims``. The remaining
    arguments are stored on the spec unchanged.
    """
    resolved_surr = (
        build_surr_dims(groups, dim_sizes) if surr_sizes is None else surr_sizes
    )
    return GroupsSpec(
        groups=groups,
        dim_sizes=dim_sizes,
        dim_surr_sizes=resolved_surr,
        pre_hooks=pre_hooks,
        post_hooks=post_hooks,
    )


class Block(nn.Module):
    """One rectangular block of a larger matrix, backed by a ``Compiler`` factor.

    The block sits at grid position ``pos``. In the assembled matrix, its
    entries start at ``offset`` (full space) or ``surr_offset`` (surrogate
    space).
    """

    factor: Compiler
    pos: tuple[int, int] | np.ndarray
    info: tuple[Any, Any]
    _offset: tuple[int, int] | np.ndarray
    _surr_offset: tuple[int, int] | np.ndarray

    def __init__(
        self,
        group: tuple | list,
        shape: tuple[list[int], list[int]],
        surr_shape: tuple[list[int], list[int]],
        pos: tuple[int, int],
        offset: tuple[int, int] | np.ndarray,
        surr_offset: tuple[int, int] | np.ndarray,
        device: str = "cpu",
    ):
        """Create the block and its ``Compiler`` factor.

        Args:
            group: invariance description forwarded to ``Compiler`` (this
                module passes the result of ``invariance_from_spec``).
            shape: (row, column) axis sizes in the full space.
            surr_shape: (row, column) axis sizes in the surrogate space.
            pos: (i, j) grid position of the block.
            offset: (row, column) start of the block in the full-space matrix.
                Stored by reference, not copied.
            surr_offset: like ``offset``, for the surrogate-space matrix.
            device: device forwarded to ``Compiler``.
        """
        super().__init__()

        self._offset = offset
        self._surr_offset = surr_offset

        self.factor = Compiler(group, shape, surr_shape, device)
        self.pos = pos

    def offset(self, surrogate: bool) -> tuple[int, int] | np.ndarray:
        """Return the surrogate-space or full-space (row, column) offset."""
        offset = self._surr_offset if surrogate else self._offset
        return offset

    def cov(
        self,
        out: Tensor | None = None,
        diag: bool = False,
        surrogate: bool = True,
    ):
        """Return this block's covariance, optionally writing it into ``out``.

        If ``diag`` is true, only the block is returned. When ``out`` is
        given, the block is also written to ``out[i:ei, j:ej]`` and that view
        is returned.

        Otherwise, the block and its transpose are returned as a pair. When
        ``out`` is given, they are written to the (i, j) region and the
        mirrored (j, i) region, and views of both regions are returned.
        """
        cov = self.factor.cov(surrogate=surrogate)
        offset = self.offset(surrogate)
        i, j = offset
        ei, ej = i + cov.shape[0], j + cov.shape[1]

        if diag:
            if out is None:
                return cov
            else:
                out[i:ei, j:ej] = cov
                return out[i:ei, j:ej]

        cov_t = cov.T

        if out is None:
            return cov, cov_t
        else:
            out[i:ei, j:ej] = cov
            # The transpose occupies rows j:ej and columns i:ei.
            out[j:ej, i:ei] = cov_t
            return out[i:ei, j:ej], out[j:ej, i:ei]

    def matvec(
        self, lhs: Tensor, rhs: Tensor | None = None
    ) -> tuple[Tensor, Tensor | None]:
        """Apply the factor to ``lhs`` and, when ``rhs`` is given, its transpose to ``rhs``.

        Returns ``(factor.matvec(lhs), None)`` when ``rhs`` is None. Otherwise
        returns the pair ``(factor.matvec(lhs, transpose=False),
        factor.matvec(rhs, transpose=True))``.
        """
        if rhs is None:
            res = self.factor.matvec(lhs)
            return res, None
        else:
            lhs_res = self.factor.matvec(lhs, transpose=False)
            rhs_res = self.factor.matvec(rhs, transpose=True)
            return lhs_res, rhs_res

    def update_with_vectors(self, lhs, rhs):
        """Update the factor in place from the vector pair ``(lhs, rhs)``."""
        self.factor.outer_estimate_(lhs, rhs)

    def update_with_cov(self, cov, surrogate=False):
        """Update the factor in place from a dense covariance block."""
        self.factor.cov_estimate_(cov, surrogate=surrogate)

    def shape(self, surrogate: bool = True):
        """Return the flattened (rows, columns) shape of the block."""
        factor = self.factor
        shape = factor.surr_dims if surrogate else factor.dims
        return flat_shape(shape)

    def __repr__(self):
        # Interpolate the offset value itself. ``self.offset`` is a method,
        # so the old code printed a bound-method repr here.
        return f"Block({self.pos}, {self.offset(surrogate=False)})"


class Factory(nn.Module):
    """Base module that holds a ``GroupsSpec`` and exposes the factor weights."""

    groups_spec: GroupsSpec

    def __init__(self, groups_spec):
        super().__init__()
        self.groups_spec = groups_spec

    def weights(self, clone: bool = True):
        """Return every buffer whose name ends in ``.factor.weights``, as a tuple.

        The buffers are returned in ``named_buffers()`` order. With ``clone``
        true (the default), each tensor is copied via ``Tensor.clone``;
        otherwise the live buffers themselves are returned.
        """
        found = tuple(
            buf
            for name, buf in self.named_buffers()
            if name.endswith(".factor.weights")
        )
        if not clone:
            return found
        return tuple(buf.clone() for buf in found)

    def update_weights(self, values):
        """Copy ``values[k]`` in place into the k-th weights buffer.

        Entries of ``values`` beyond the number of buffers are ignored.
        """
        targets = self.weights(clone=False)
        for position, target in enumerate(targets):
            source = values[position]
            target.copy_(source)


class MeanFactory(Factory):
    """Factory of per-group mean estimators whose blocks come from ``mean_blocks``."""

    blocks: list[Block]

    def __init__(
        self,
        groups_spec: GroupsSpec,
    ):
        super().__init__(groups_spec)
        self.blocks = nn.ModuleList(mean_blocks(groups_spec))

    def avg(self, values):
        """Feed ``values`` into the mean blocks and return the per-group results.

        The steps are:

        1. ``values`` go through the spec's pre-hooks.
        2. For every block, the value of its group updates the factor against
           a length-1 ones vector.
        3. The factor applied to that ones vector becomes the group's result.

        Groups without a block yield zeros. Results go through the post-hooks.
        """
        spec = self.groups_spec
        before, after = spec.pre_hooks, spec.post_hooks

        inputs = process(values, before)
        first = inputs[0]
        unit = torch.ones((1,), device=first.device, dtype=first.dtype)
        means = [torch.zeros_like(item) for item in inputs]

        for blk in self.blocks:
            row, _ = blk.pos
            blk.update_with_vectors(inputs[row], unit)
            estimate, _ = blk.matvec(unit)
            means[row] = estimate

        return process(means, after)


class CovFactory(Factory):
    """Block-structured covariance assembled from per-group-pair ``Block`` s.

    One diagonal block is built per group. Unless ``block_diag_only`` is set,
    an off-diagonal block is also built for every upper-triangular pair
    (i, j), i < j, whose invariance is non-empty (see ``triu_blocks``).
    """

    diag_blocks: list[Block]
    off_blocks: list[Block]

    def __init__(
        self,
        groups_spec: GroupsSpec,
        block_diag_only: bool = False,
    ):
        super().__init__(groups_spec)

        # Layouts of the full-space and surrogate-space matrices.
        triu = triu_indices_grid(groups_spec, surrogate=False)
        surr_triu = triu_indices_grid(groups_spec, surrogate=True)

        diag_blocks, off_blocks = triu_blocks(
            groups_spec,
            triu,
            surr_triu,
            block_diag_only=block_diag_only,
        )

        self.triu = triu
        self.surr_triu = surr_triu
        self.diag_blocks = nn.ModuleList(diag_blocks)
        self.off_blocks = nn.ModuleList(off_blocks)
        self.block_diag_only = block_diag_only

    def cov(self, surrogate: bool = False, device=None, dtype=None):
        """Return the covariance.

        In full mode, a dense ``(size, size)`` tensor is returned. Every
        diagonal block is written in, and every off-diagonal block is written
        together with its mirrored transpose. Regions without a block stay
        zero.

        In ``block_diag_only`` mode, a list of the diagonal block covariances
        is returned instead. ``device`` and ``dtype`` only affect the dense
        output.
        """
        if not self.block_diag_only:
            shape = self.cov_shape(surrogate)
            cov = torch.zeros(shape, device=device, dtype=dtype)

            for block in self.diag_blocks:
                block.cov(out=cov, diag=True, surrogate=surrogate)

            for block in self.off_blocks:
                block.cov(out=cov, surrogate=surrogate)

            return cov
        else:
            cov = []
            for block in self.diag_blocks:
                c = block.cov(diag=True, surrogate=surrogate)
                cov.append(c)

            return cov

    def matvec(self, vectors):
        """Apply the block covariance to per-group ``vectors``.

        Pre-hooks are applied to the inputs and post-hooks to the outputs.
        Diagonal block i maps ``vecs[i]`` into group i. Off-diagonal block
        (i, j) is applied to ``vecs[j]``, with the result accumulated into
        group i, and is applied transposed to ``vecs[i]``, with the result
        accumulated into group j.
        """
        pre_hooks = self.groups_spec.pre_hooks
        post_hooks = self.groups_spec.post_hooks

        vecs = process(vectors, pre_hooks)

        cumvecs: list[Tensor] = [torch.zeros_like(v) for v in vecs]

        for block in self.diag_blocks:
            d = block.pos[0]
            vec = vecs[d]
            res, _ = block.matvec(vec)
            cumvecs[d] += res

        for block in self.off_blocks:
            i, j = block.pos
            lhs, rhs = vecs[j], vecs[i]
            lhs_res, rhs_res = block.matvec(lhs, rhs)
            # Block.matvec always returns a tensor here because rhs is given;
            # this assert only narrows the Optional type.
            assert rhs_res is not None
            cumvecs[i] += lhs_res
            cumvecs[j] += rhs_res

        out = process(cumvecs, post_hooks)
        return out

    def cov_shape(self, surrogate: bool = False):
        """Return the ``(size, size)`` shape of the dense covariance."""
        s = self.surr_triu.size if surrogate else self.triu.size
        return (s, s)

    def cov_block_diag_update(
        self, block_diag_covs: Sequence[Tensor], surrogate: bool = False
    ):
        """Update each diagonal block from the matching dense covariance.

        Pairs are matched positionally with ``zip``.

        Raises:
            RuntimeError: if the factory was not built with
                ``block_diag_only=True``. This replaces an ``assert``, which
                would be stripped under ``python -O``.
        """
        if not self.block_diag_only:
            raise RuntimeError(
                "cov_block_diag_update requires a CovFactory built with "
                "block_diag_only=True"
            )
        for cov, block in zip(block_diag_covs, self.diag_blocks):
            block.update_with_cov(cov, surrogate=surrogate)

    def cov_update(self, cov, surrogate: bool = False):
        """Update every block from its region of the dense covariance ``cov``.

        Each block's region is located from its offset and flattened shape in
        the requested (surrogate or full) space.
        """
        blocks = [*self.diag_blocks, *self.off_blocks]
        for block in blocks:
            r, c = block.offset(surrogate=surrogate)
            rs, cs = block.shape(surrogate=surrogate)
            re, ce = r + rs, c + cs
            cov_block = cov[r:re, c:ce]
            block.update_with_cov(cov_block, surrogate=surrogate)

    def outer_update(self, updates):
        """Update blocks from per-group vectors (after pre-hooks).

        Diagonal block i uses the pair ``(u_i, u_i)``. Off-diagonal block
        (i, j) uses the pair ``(u_i, u_j)``.
        """
        pre_hooks = self.groups_spec.pre_hooks
        upds = process(updates, pre_hooks)

        for block in self.diag_blocks:
            i, _ = block.pos
            vec = upds[i]
            block.update_with_vectors(vec, vec)

        for block in self.off_blocks:
            i, j = block.pos
            lhs, rhs = upds[i], upds[j]
            block.update_with_vectors(lhs, rhs)


class TriuGrid(NamedTuple):
    """Block layout of the upper triangle of a group-by-group square matrix.

    Produced by ``triu_indices_grid``.
    """

    # Total number of rows (== columns) of the assembled matrix.
    size: int
    # (num_groups, 2) integer array of (i, i) pairs, one per diagonal block.
    diag_indices: np.ndarray
    # (num_pairs, 2) integer array of (i, j) pairs with i < j.
    off_indices: np.ndarray
    # Start row/column of each group's block (shifted by ``origin``).
    offsets: np.ndarray


def stable_size(group: Literal["O", "B", "S"], size: int) -> int:
    """Look up the surrogate size for a group type at a given size key.

    Only the size keys 2 and 4 are defined for each group type. Any other
    group type or size raises ``KeyError``.
    """
    # TODO: for testing, the "S" entry was meant to become {2: 4, 4: 4}.
    table = {
        "O": {2: 1, 4: 3},
        "B": {2: 1, 4: 3},
        "S": {2: 2, 4: 4},
    }
    per_group = table[group]
    return per_group[size]


def build_surr_dims(groups, dims):
    """Derive surrogate sizes for every (group type, dim) pair in ``groups``.

    Each group is a sequence of axis labels of the form "<type>_<dim>". A bare
    string is treated as a one-axis group, matching ``GroupsSpec.shapes``;
    previously it was iterated character by character and crashed. An empty
    group contributes nothing.

    Sizes are chosen per label type:

    * Type "I" axes keep their full size, so ``surr["I"][dim] = dims[dim]``.
    * Any other type counts the axes of that type sharing a dim within one
      group and uses ``stable_size(type, 2 * count)``. Across groups the
      largest value wins.

    Returns:
        A ``{type: {dim: size}}`` dict. Keys appear in first-seen order, so
        the result does not depend on set iteration order.

    Raises:
        ValueError: if a label does not contain exactly one "_".
        KeyError: for an "I" dim missing from ``dims`` or an unsupported
            ``stable_size`` entry.
    """
    surr_dims: dict[str, dict[str, int]] = {}
    for g in groups:
        axes = (g,) if isinstance(g, str) else tuple(g)

        # dim -> {type: number of axes of that type on this dim}, non-"I" only.
        counts: dict[str, dict[str, int]] = {}
        for axis in axes:
            tt, dim = axis.split("_")
            if tt == "I":
                surr_dims.setdefault(tt, {})[dim] = dims[dim]
            else:
                per_type = counts.setdefault(dim, {})
                per_type[tt] = per_type.get(tt, 0) + 1

        for dim, per_type in counts.items():
            for k, v in per_type.items():
                size = stable_size(k, v * 2)
                per_dim = surr_dims.setdefault(k, {})
                per_dim[dim] = max(size, per_dim.get(dim, size))
    return surr_dims


def triu_indices_grid(
    groups_spec: GroupsSpec,
    surrogate: bool = True,
    origin: int = 0,
) -> TriuGrid:
    """Lay out the upper-triangular block grid of a group-by-group matrix.

    Each group's block height is the row count of ``flat_shape((shape,
    shape))``, using surrogate or full shapes as requested.

    Returns:
        A ``TriuGrid`` with these fields:

        * ``size``: total height as a Python ``int``, matching the
          ``TriuGrid.size`` annotation; previously this was a numpy scalar.
        * ``offsets``: exclusive prefix sums of the block heights, shifted
          by ``origin``.
        * ``diag_indices`` and ``off_indices``: (i, i) and (i, j), i < j,
          index pairs.

        An empty group list yields ``size == 0`` and empty arrays instead of
        raising ``IndexError``.
    """
    n = groups_spec.num_groups
    rows, cols = np.triu_indices(n)

    group_shapes = groups_spec.shapes
    shapes = group_shapes.surr_shape if surrogate else group_shapes.shape

    # Height of each group's diagonal block once flattened.
    sizes = np.array(
        [flat_shape((shape, shape))[0] for shape in shapes],
        dtype=int,
    )
    cumsizes = np.cumsum(sizes)

    size = int(cumsizes[-1]) if cumsizes.size else 0
    # Exclusive prefix sums: the start row/column of each group's block.
    offsets = (cumsizes - sizes).astype(int)
    offsets += origin

    on_diag = rows == cols
    diag = rows[on_diag]

    diag_inds = np.array((diag, diag), dtype=int).T
    off_inds = np.array((rows[~on_diag], cols[~on_diag]), dtype=int).T

    return TriuGrid(size, diag_inds, off_inds, offsets)


def mean_blocks(groups_spec: GroupsSpec):
    """Build one mean-estimation ``Block`` per group with a non-empty invariance.

    Each group is paired with a single dummy "I" axis, so every block has
    column shape ``(1,)``. Groups whose invariance against that dummy axis
    is empty get no block. Offsets are advanced only for groups that do.

    Each block receives its own copy of the running offset arrays. ``Block``
    stores the array it is given by reference, and the running arrays are
    mutated in place below. Previously every block aliased the same arrays
    and ended up reporting the final offset.
    """
    blocks = []
    shapes, surr_shapes = groups_spec.shapes

    offs = np.array((0, 0), dtype=int)
    surr_offs = np.array((0, 0), dtype=int)

    gc = ("I_dummy",)
    for i, gr in enumerate(groups_spec.groups):
        pos = np.array((i, 0), dtype=int)

        gr_shape = shapes[i]
        gr_surr_shape = surr_shapes[i]

        group_inv = invariance_from_spec(gr, gc)

        if len(group_inv) == 0:
            continue

        block = Block(
            group_inv,
            (gr_shape, (1,)),
            (gr_surr_shape, (1,)),
            pos=pos,
            offset=offs.copy(),
            surr_offset=surr_offs.copy(),
        )

        blocks.append(block)

        # NOTE(review): this advances by the first axis size only. For
        # multi-axis groups it may differ from the flattened sizes used by
        # triu_indices_grid; confirm this is intended.
        offs[0] += gr_shape[0]
        surr_offs[0] += gr_surr_shape[0]

    return blocks


def triu_blocks(
    groups_spec: GroupsSpec,
    triu: TriuGrid,
    surr_triu: TriuGrid,
    block_diag_only: bool = False,
) -> tuple[list[Block], list[Block]]:
    """Create the diagonal and off-diagonal ``Block`` s of the grid.

    Every group gets a diagonal block. A group whose self-invariance is empty
    raises ``ValueError``. Unless ``block_diag_only`` is set, each pair
    (i, j) in ``triu.off_indices`` also gets a block when its invariance is
    non-empty.

    Offsets come from ``triu`` (full space) and ``surr_triu`` (surrogate
    space), indexed by the block's (i, j) position.

    Returns:
        ``(diag_blocks, off_blocks)``. ``off_blocks`` is empty in
        block-diagonal mode.
    """
    groups = groups_spec.groups
    shapes, surr_shapes = groups_spec.shapes

    diag_blocks = []
    for idx, grp in enumerate(groups):
        position = triu.diag_indices[idx]
        start = triu.offsets[position]
        surr_start = surr_triu.offsets[position]
        own_shape = shapes[idx]
        own_surr_shape = surr_shapes[idx]

        inv = invariance_from_spec(grp, grp)
        if len(inv) == 0:
            raise ValueError(f"Expected non-zero diagonal factor for ({grp}, {grp})")

        diag_blocks.append(
            Block(
                inv,
                (own_shape, own_shape),
                (own_surr_shape, own_surr_shape),
                pos=position,
                offset=start,
                surr_offset=surr_start,
            )
        )

    if block_diag_only:
        return diag_blocks, []

    off_blocks = []
    for position in triu.off_indices:
        row, col = position

        start = triu.offsets[position]
        surr_start = surr_triu.offsets[position]

        inv = invariance_from_spec(groups[row], groups[col])
        if len(inv) == 0:
            continue

        off_blocks.append(
            Block(
                inv,
                (shapes[row], shapes[col]),
                (surr_shapes[row], surr_shapes[col]),
                pos=position,
                offset=start,
                surr_offset=surr_start,
            )
        )

    return diag_blocks, off_blocks


def split_group_dim(value: str):
    """Split an axis label of the form "<group>_<dim>" into ``(group, dim)``.

    Raises ValueError unless the label contains exactly one underscore.
    """
    pieces = value.split("_")
    group_name, dim_name = pieces
    return group_name, dim_name


def process(
    values: Sequence[Tensor], funcs: Sequence[Callable] | None
) -> Sequence[Tensor]:
    """Apply per-position hooks to ``values``.

    ``funcs[k]`` is applied to ``values[k]`` through ``exec_if_not_none``, so
    a ``None`` or other non-callable entry leaves that value unchanged.

    Values with no matching hook (when ``funcs`` is shorter than ``values``)
    now pass through unchanged. Previously ``zip`` silently dropped them,
    which made later per-group indexing fail. Hooks beyond ``len(values)``
    are ignored.

    Returns:
        ``values`` itself when ``funcs`` is None, otherwise a new tuple with
        one entry per value.
    """
    if funcs is None:
        return values

    hooks = tuple(funcs)
    n_hooks = len(hooks)
    return tuple(
        exec_if_not_none(hooks[k] if k < n_hooks else None, v)
        for k, v in enumerate(values)
    )


def exec_if_not_none(func, value):
    """Return ``func(value)`` when ``func`` is callable, else ``value`` unchanged."""
    # callable(None) is False, so a None hook also falls through here.
    if not callable(func):
        return value
    return func(value)
