# Copyright 2021 Yifei Ma
# with references from "sklearn.decomposition.LatentDirichletAllocation"
# with the following original authors:
# * Chyi-Kwei Yau (the said scikit-learn implementation)
# * Matthew D. Hoffman (original onlineldavb implementation)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import collections
import functools
import io
import os
import warnings

import numpy as np
import scipy as sp
import torch

import dgl

try:
    from functools import cached_property
except ImportError:
    try:
        from backports.cached_property import cached_property
    except ImportError:
        warnings.warn("cached_property not found - using property instead")
        cached_property = property


class EdgeData:
    """Per-edge quantities computed from the endpoint features of an edge batch.

    ``src_data`` / ``dst_data`` are mappings of per-edge endpoint features
    (DGL's ``edges.src`` / ``edges.dst`` in this file). Their tensors are 2-D,
    and every quantity below reduces over dim 1 (the topic dimension).
    """

    def __init__(self, src_data, dst_data):
        self.src_data, self.dst_data = src_data, dst_data

    def _summed_elog(self):
        # Elementwise sum of the source and destination "Elog" features.
        return self.src_data["Elog"] + self.dst_data["Elog"]

    @property
    def loglike(self):
        """log-sum-exp over dim 1 of the summed "Elog" features."""
        return self._summed_elog().logsumexp(1)

    @property
    def phi(self):
        """Responsibilities ``exp(summed_elog - loglike)``.

        This is mathematically a softmax over dim 1, so each row sums to 1
        up to rounding.
        """
        summed = self._summed_elog()
        log_norm = self.loglike.unsqueeze(1)
        return (summed - log_norm).exp()

    @property
    def expectation(self):
        """Row-wise dot product of the two "expectation" features."""
        product = self.src_data["expectation"] * self.dst_data["expectation"]
        return product.sum(1)


class _Dirichlet:
    def __init__(self, prior, nphi, _chunksize=int(1e6)):
        self.prior = prior
        self.nphi = nphi
        self.device = nphi.device
        self._sum_by_parts = lambda map_fn: functools.reduce(
            torch.add,
            [
                map_fn(slice(i, min(i + _chunksize, nphi.shape[1]))).sum(1)
                for i in list(range(0, nphi.shape[1], _chunksize))
            ],
        )

    def _posterior(self, _ID=slice(None)):
        return self.prior + self.nphi[:, _ID]

    @cached_property
    def posterior_sum(self):
        return self.nphi.sum(1) + self.prior * self.nphi.shape[1]

    def _Elog(self, _ID=slice(None)):
        return torch.digamma(self._posterior(_ID)) - torch.digamma(
            self.posterior_sum.unsqueeze(1)
        )

    @cached_property
    def loglike(self):
        neg_evid = -self._sum_by_parts(
            lambda s: (self.nphi[:, s] * self._Elog(s))
        )

        prior = torch.as_tensor(self.prior).to(self.nphi)
        K = self.nphi.shape[1]
        log_B_prior = torch.lgamma(prior) * K - torch.lgamma(prior * K)

        log_B_posterior = self._sum_by_parts(
            lambda s: torch.lgamma(self._posterior(s))
        ) - torch.lgamma(self.posterior_sum)

        return neg_evid - log_B_prior + log_B_posterior

    @cached_property
    def n(self):
        return self.nphi.sum(1)

    @cached_property
    def cdf(self):
        cdf = self._posterior()
        torch.cumsum(cdf, 1, out=cdf)
        cdf /= cdf[:, -1:].clone()
        return cdf

    def _expectation(self, _ID=slice(None)):
        expectation = self._posterior(_ID)
        expectation /= self.posterior_sum.unsqueeze(1)
        return expectation

    @cached_property
    def Bayesian_gap(self):
        return 1.0 - self._sum_by_parts(lambda s: self._Elog(s).exp())

    _cached_properties = [
        "posterior_sum",
        "loglike",
        "n",
        "cdf",
        "Bayesian_gap",
    ]

    def clear_cache(self):
        for name in self._cached_properties:
            try:
                delattr(self, name)
            except AttributeError:
                pass

    def update(self, new, _ID=slice(None), rho=1):
        """inplace: old * (1-rho) + new * rho"""
        self.clear_cache()
        mean_change = (self.nphi[:, _ID] - new).abs().mean().tolist()

        self.nphi *= 1 - rho
        self.nphi[:, _ID] += new * rho
        return mean_change


class DocData(_Dirichlet):
    """Document-side Dirichlet statistics; ``nphi`` is (n_docs by n_topics).

    The rows of ``nphi`` line up with the "doc" nodes of the graphs passed in.
    """

    def prepare_graph(self, G, key="Elog"):
        """Store ``self._<key>()`` as the "doc" node feature ``key`` on G's device."""
        compute = getattr(self, "_" + key)
        values = compute()
        G.nodes["doc"].data[key] = values.to(G.device)

    def update_from(self, G, mult):
        """Replace ``nphi`` with G's aggregated doc "nphi" scaled by ``mult``.

        Returns the mean absolute change reported by ``update``.
        """
        aggregated = G.nodes["doc"].data["nphi"]
        scaled = aggregated * mult
        return self.update(scaled.to(self.device))


class _Distributed(collections.UserList):
    """split on dim=0 and store on multiple devices"""

    def __init__(self, prior, nphi):
        self.prior = prior
        self.nphi = nphi
        super().__init__([_Dirichlet(self.prior, nphi) for nphi in self.nphi])

    def split_device(self, other, dim=0):
        split_sections = [x.shape[0] for x in self.nphi]
        out = torch.split(other, split_sections, dim)
        return [y.to(x.device) for x, y in zip(self.nphi, out)]


class WordData(_Distributed):
    """Distributed nphi (n_topics by n_words), transposed to/from graph node data.

    Each part holds a topic-slice of shape (n_topics_part, n_words) on its own
    device. Graph "word" nodes store the transposed layout (one row per word
    node). When a graph covers only some words, its "word" nodes carry an
    ``_ID`` feature with the global word ids.
    """

    @staticmethod
    def _word_ids(G):
        """Global word ids of G's "word" nodes, or ``slice(None)`` if absent."""
        if "_ID" in G.nodes["word"].data:
            return G.nodes["word"].data["_ID"]
        return slice(None)

    @staticmethod
    def _index_on(_ID, device):
        """Move a tensor index to ``device``; slices are returned unchanged.

        PyTorch requires an index tensor to be on the CPU or on the same
        device as the tensor it indexes, while ``_ID`` lives on the graph's
        device.
        """
        if isinstance(_ID, torch.Tensor):
            return _ID.to(device)
        return _ID

    def prepare_graph(self, G, key="Elog"):
        """Store ``part._<key>(_ID)`` of all parts, concatenated and transposed,
        as the "word" node feature ``key`` on G's device."""
        _ID = self._word_ids(G)
        out = [
            getattr(part, "_" + key)(self._index_on(_ID, part.device)).to(G.device)
            for part in self
        ]
        G.nodes["word"].data[key] = torch.cat(out).T

    def update_from(self, G, mult, rho):
        """Blend G's aggregated word "nphi" (scaled by ``mult``) into each part.

        Returns the mean over parts of the mean absolute change reported by
        ``_Dirichlet.update``.
        """
        nphi = G.nodes["word"].data["nphi"].T * mult
        _ID = self._word_ids(G)

        mean_change = [
            part.update(new, self._index_on(_ID, part.device), rho)
            for part, new in zip(self, self.split_device(nphi))
        ]
        return np.mean(mean_change)


class Gamma(collections.namedtuple("Gamma", "concentration, rate")):
    """Gamma parameters in torch form, with numpy-style aliases.

    torch parameterizes the gamma distribution by (concentration, rate).
    numpy uses (shape, scale) with ``shape == concentration`` and
    ``scale == 1 / rate``.
    """

    @property
    def shape(self):
        """numpy-style shape parameter (equal to ``concentration``)."""
        return self.concentration

    @property
    def scale(self):
        """numpy-style scale parameter (equal to ``1 / rate``)."""
        return 1 / self.rate

    def sample(self, shape, device):
        """Draw gamma samples with sample size ``shape`` on ``device``."""
        concentration = torch.as_tensor(self.concentration, device=device)
        rate = torch.as_tensor(self.rate, device=device)
        distribution = torch.distributions.gamma.Gamma(concentration, rate)
        return distribution.sample(shape)


class LatentDirichletAllocation:
    """LDA model that works with a HeteroGraph with doc->word meta paths.
    The model alters the attributes of G arbitrarily.
    This is inspired by [1] and its corresponding scikit-learn implementation.

    Inputs
    ---
    * n_words: vocabulary size, i.e., the number of columns of the topic-word
      statistics; "word" node ids (or their "_ID" data) index into it.
    * n_components: latent feature dimension; automatically set priors if missing.
    * prior: {"doc": ..., "word": ...} Dirichlet prior parameters;
      both default to 1/n_components.
    * rho: new_nphi = (1-rho)*old_nphi + rho*nphi; default to 1 for full gradients.
    * mult: {"doc": ..., "word": ...} multipliers for nphi-updates;
      a large value effectively disables prior. Default to 1 for both.
    * init: {"doc": ..., "word": ...} Gamma(concentration, rate) initializers;
      default to sklearn's (100.0, 100.0); the sample points concentrate around 1.0.
    * device_list: devices across which word_data is split by topic to
      accelerate word_data updates; default to ["cpu"].
    * verbose: print e-step / m-step diagnostics; default to True.

    Notes
    ---
    Some differences between this and sklearn.decomposition.LatentDirichletAllocation:
    * default word perplexity is normalized by training set instead of testing set.

    References
    ---
    [1] Matthew Hoffman, Francis Bach, David Blei. Online Learning for Latent
    Dirichlet Allocation. Advances in Neural Information Processing Systems 23
    (NIPS 2010).
    [2] Reactive LDA Library blogpost by Yingjie Miao for a similar Gibbs model
    """

    def __init__(
        self,
        n_words,
        n_components,
        prior=None,
        rho=1,
        mult=None,
        init=None,
        device_list=None,
        verbose=True,
    ):
        self.n_words = n_words
        self.n_components = n_components

        if prior is None:
            prior = {"doc": 1.0 / n_components, "word": 1.0 / n_components}
        self.prior = prior

        # Build fresh containers per instance. Mutable defaults in the
        # signature would be shared (and mutable through self.mult etc.)
        # by every model constructed without these arguments.
        if mult is None:
            mult = {"doc": 1, "word": 1}
        if init is None:
            init = {"doc": (100.0, 100.0), "word": (100.0, 100.0)}
        if device_list is None:
            device_list = ["cpu"]

        self.rho = rho
        self.mult = mult
        self.init = init

        assert not isinstance(device_list, str), "plz wrap devices in a list"
        self.device_list = device_list[:n_components]  # avoid edge cases
        self.verbose = verbose

        self._init_word_data()

    def _init_word_data(self):
        """Sample initial topic-word counts, split by topic over device_list.

        Topics are divided into len(device_list) near-equal contiguous
        sections. Each section gets a (section_size, n_words) Gamma sample on
        its device.
        """
        split_sections = np.diff(
            np.linspace(0, self.n_components, len(self.device_list) + 1).astype(
                int
            )
        )
        word_nphi = [
            Gamma(*self.init["word"]).sample((s, self.n_words), device)
            for s, device in zip(split_sections, self.device_list)
        ]
        self.word_data = WordData(self.prior["word"], word_nphi)

    def _init_doc_data(self, n_docs, device):
        """Sample initial doc-topic counts of shape (n_docs, n_components)."""
        doc_nphi = Gamma(*self.init["doc"]).sample(
            (n_docs, self.n_components), device
        )
        return DocData(self.prior["doc"], doc_nphi)

    def save(self, f):
        """torch.save the hyperparameters and the word_data nphi parts to f.

        Cached statistics are cleared first; they are not part of the saved
        dict. n_words, n_components, device_list and verbose are not saved.
        """
        for w in self.word_data:
            w.clear_cache()
        torch.save(
            {
                "prior": self.prior,
                "rho": self.rho,
                "mult": self.mult,
                "init": self.init,
                "word_data": [part.nphi for part in self.word_data],
            },
            f,
        )

    def _prepare_graph(self, G, doc_data, key="Elog"):
        """Write the doc and word node features named ``key`` into G."""
        doc_data.prepare_graph(G, key)
        self.word_data.prepare_graph(G, key)

    def _e_step(self, G, doc_data=None, mean_change_tol=1e-3, max_iters=100):
        """_e_step implements doc data sampling until convergence or max_iters.

        Word statistics stay fixed. Doc statistics are re-estimated on the
        reversed (word -> doc) graph until the mean absolute change drops
        below ``mean_change_tol`` or ``max_iters`` iterations have run. If
        ``doc_data`` is None, a fresh random initialization is used.
        Returns the updated DocData.
        """
        if doc_data is None:
            doc_data = self._init_doc_data(G.num_nodes("doc"), G.device)

        G_rev = G.reverse()  # word -> doc
        self.word_data.prepare_graph(G_rev)

        # Defaults keep the verbose report well-defined when max_iters <= 0
        # and the loop body never runs.
        num_iters = 0
        mean_change = float("nan")
        for num_iters in range(1, max_iters + 1):
            doc_data.prepare_graph(G_rev)
            G_rev.update_all(
                lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
                dgl.function.sum("phi", "nphi"),
            )
            mean_change = doc_data.update_from(G_rev, self.mult["doc"])
            if mean_change < mean_change_tol:
                break

        if self.verbose:
            # perplexity prints its own diagnostics first (as before).
            ppl = self.perplexity(G, doc_data)
            print(
                f"e-step num_iters={num_iters} with mean_change={mean_change:.4f}, "
                f"perplexity={ppl:.4f}"
            )

        return doc_data

    transform = _e_step

    def predict(self, doc_data):
        """Return (n_docs, n_words) scores equal to E[doc-topic] @ E[topic-word].

        The product is computed per topic-part on each part's device and
        accumulated on doc_data.device.
        """
        pred_scores = [
            # d_exp @ w._expectation(), without materializing the expectation:
            # (x / s) @ (nphi + prior) == x' @ nphi + rowsum(x') * prior
            (lambda x: x @ w.nphi + x.sum(1, keepdims=True) * w.prior)(
                d_exp / w.posterior_sum.unsqueeze(0)
            )
            for (d_exp, w) in zip(
                self.word_data.split_device(doc_data._expectation(), dim=1),
                self.word_data,
            )
        ]
        x = torch.zeros_like(pred_scores[0], device=doc_data.device)
        for p in pred_scores:
            x += p.to(x.device)
        return x

    def sample(self, doc_data, num_samples):
        """draw independent words and return the marginal probabilities,
        i.e., the expectations in Dirichlet distributions.

        Returns ``(ids, expectation)``, both of shape (n_docs, num_samples).
        For each doc and sample slot j, a topic is drawn from the doc's
        posterior-mean CDF. The word is then the j-th word drawn for that
        topic; the per-topic word draws are shared across docs.
        """

        def fn(cdf):
            # Inverse-CDF sampling: num_samples column indices per row.
            u = torch.rand(cdf.shape[0], num_samples, device=cdf.device)
            return torch.searchsorted(cdf, u).to(doc_data.device)

        topic_ids = fn(doc_data.cdf)
        word_ids = torch.cat([fn(part.cdf) for part in self.word_data])
        ids = torch.gather(
            word_ids, 0, topic_ids
        )  # pick components by topic_ids

        # compute expectation scores on sampled ids
        src_ids = (
            torch.arange(ids.shape[0], dtype=ids.dtype, device=ids.device)
            .reshape((-1, 1))
            .expand(ids.shape)
        )
        unique_ids, inverse_ids = torch.unique(
            ids, sorted=False, return_inverse=True
        )

        # One edge per (doc, sample); word nodes are the distinct sampled ids,
        # mapped back to global word ids through "_ID".
        G = dgl.heterograph(
            {("doc", "", "word"): (src_ids.ravel(), inverse_ids.ravel())}
        )
        G.nodes["word"].data["_ID"] = unique_ids
        self._prepare_graph(G, doc_data, "expectation")
        G.apply_edges(
            lambda e: {"expectation": EdgeData(e.src, e.dst).expectation}
        )
        expectation = G.edata.pop("expectation").reshape(ids.shape)

        return ids, expectation

    def _m_step(self, G, doc_data):
        """_m_step implements word data sampling and stores word_z stats.
        mean_change is in the sense of full graph with rho=1.

        Word statistics are blended with rate ``self.rho``. The mean change
        is stored in ``self._last_mean_change`` for ``fit``'s stopping rule.
        """
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.update_all(
            lambda edges: {"phi": EdgeData(edges.src, edges.dst).phi},
            dgl.function.sum("phi", "nphi"),
        )
        self._last_mean_change = self.word_data.update_from(
            G, self.mult["word"], self.rho
        )

        if self.verbose:
            print(f"m-step mean_change={self._last_mean_change:.4f}, ", end="")
            Bayesian_gap = np.mean(
                [part.Bayesian_gap.mean().tolist() for part in self.word_data]
            )
            print(f"Bayesian_gap={Bayesian_gap:.4f}")

    def partial_fit(self, G):
        """Run one e-step (fresh doc init) and one m-step on G; return self."""
        doc_data = self._e_step(G)
        self._m_step(G, doc_data)
        return self

    def fit(self, G, mean_change_tol=1e-3, max_epochs=10):
        """Repeat partial_fit until the m-step mean change < mean_change_tol
        or max_epochs epochs have run; return self."""
        for i in range(max_epochs):
            if self.verbose:
                print(f"epoch {i+1}, ", end="")
            self.partial_fit(G)

            if self._last_mean_change < mean_change_tol:
                break
        return self

    def perplexity(self, G, doc_data=None):
        """ppl = exp{-sum[log(p(w1,...,wn|d))] / n}
        Follows Eq (15) in Hoffman et al., 2010.

        If doc_data is None, it is first estimated with an e-step on G.
        Edge, doc and word terms are each normalized separately (per edge,
        per unit of doc count, per unit of word count) before being summed.
        """
        if doc_data is None:
            doc_data = self._e_step(G)

        # compute E[log p(docs | theta, beta)]
        G = G.clone()
        self._prepare_graph(G, doc_data)
        G.apply_edges(
            lambda edges: {"loglike": EdgeData(edges.src, edges.dst).loglike}
        )
        edge_elbo = (G.edata["loglike"].sum() / G.num_edges()).tolist()
        if self.verbose:
            print(f"neg_elbo phi: {-edge_elbo:.3f}", end=" ")

        # compute E[log p(theta | alpha) - log q(theta | gamma)]
        doc_elbo = (doc_data.loglike.sum() / doc_data.n.sum()).tolist()
        if self.verbose:
            print(f"theta: {-doc_elbo:.3f}", end=" ")

        # compute E[log p(beta | eta) - log q(beta | lambda)]
        # The denominator n for extrapolation perplexity is undefined.
        # We use the train set, whereas sklearn uses the test set.
        word_elbo = sum(
            [part.loglike.sum().tolist() for part in self.word_data]
        ) / sum([part.n.sum().tolist() for part in self.word_data])
        if self.verbose:
            print(f"beta: {-word_elbo:.3f}")

        ppl = np.exp(-edge_elbo - doc_elbo - word_elbo)
        # An edgeless graph yields 0/0 = nan by construction; only warn otherwise.
        if G.num_edges() > 0 and np.isnan(ppl):
            warnings.warn("numerical issue in perplexity")
        return ppl


def doc_subgraph(G, doc_ids):
    """Return a doc -> word graph restricted to ``doc_ids`` and their words.

    The result B has "doc" and "word" node types (as used by partial_fit).
    Its "word" nodes carry ``B.nodes["word"].data["_ID"]``, copied from the
    sampled block's word-node "_ID" feature. This is presumably the word ids
    in G, which WordData uses to index the global topic-word matrix
    (TODO confirm against the DGL sampler's conventions).
    """
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
    # Sample one full hop on the reversed (word -> doc) graph, seeded at the
    # requested docs, so all words incident to them appear in the block.
    _, _, (block,) = sampler.sample(
        G.reverse(), {"doc": torch.as_tensor(doc_ids)}
    )
    # NOTE(review): rebuilds a graph from the block's private `_graph` with a
    # positional ntype list ["_", "word", "doc", "_"]. This depends on DGL
    # internals and ntype ordering; confirm against the installed DGL
    # version. `.reverse()` turns the word -> doc edges back into doc -> word.
    B = dgl.DGLGraph(
        block._graph, ["_", "word", "doc", "_"], block.etypes
    ).reverse()
    B.nodes["word"].data["_ID"] = block.nodes["word"].data["_ID"]
    return B


if __name__ == "__main__":
    # Smoke test on a tiny two-doc, five-word graph. It exercises fit,
    # transform, predict, sample and perplexity, then online updates on
    # per-doc subgraphs, then a save/load round trip through memory.
    print("Testing LatentDirichletAllocation ...")
    toy_graph = dgl.heterograph(
        {("doc", "", "word"): [(0, 0), (1, 3)]}, {"doc": 2, "word": 5}
    )
    lda = LatentDirichletAllocation(n_words=5, n_components=10, verbose=False)
    lda.fit(toy_graph)
    lda.transform(toy_graph)
    lda.predict(lda.transform(toy_graph))
    # sample() needs torch.searchsorted, which may be missing in old torch.
    if hasattr(torch, "searchsorted"):
        lda.sample(lda.transform(toy_graph), 3)
    lda.perplexity(toy_graph)

    for doc_id in range(2):
        lda.partial_fit(doc_subgraph(toy_graph, [doc_id]))

    with io.BytesIO() as buffer:
        lda.save(buffer)
        buffer.seek(0)
        print(torch.load(buffer))

    print("Testing LatentDirichletAllocation passed!")
