from __future__ import annotations
import numpy as np
from scipy import sparse
from scipy.sparse import linalg as spla
from scipy.sparse.linalg import svds
from scipy.stats import rankdata, kendalltau
from sklearn.neighbors import BallTree


class STAGE:
    def __init__(self, r=None, k=None, rtol=1e-6, neighbour_min=6,
                 pca_full_dim=False, embedding="laplacian"):
        if (r is None) == (k is None):
            raise ValueError("Specify exactly one of r or k")
        self.r = r
        self.k = k
        self.rtol = rtol
        self.neighbour_min = neighbour_min
        self.pca_full_dim = pca_full_dim
        self.embedding = embedding

        self.X = None
        self.edges = None
        self.tangents_raw = None
        self.tangents = None
        self.signs = None
        self.y = None
        self.order = None

    def fit(self, X):
        self.X = np.asarray(X, dtype=np.float64, order="C")
        if self.X.ndim != 2:
            raise ValueError("X must be 2-D")

        self._build_graph()
        self._estimate_tangents()
        self._align_tangents()
        self._embed()
        return self

    def fit_transform(self, X):
        self.fit(X)
        return self.y, self.order

    def _build_graph(self):
        tree = BallTree(self.X)
        if self.r is not None:
            neigh_ind = tree.query_radius(self.X, self.r)
        else:
            neigh_ind = tree.query(self.X, k=self.k + 1,
                                   return_distance=False)

        rows, cols = [], []
        for i, neigh in enumerate(neigh_ind):
            neigh = neigh[neigh != i]
            rows.append(np.full(neigh.size, i, np.int32))
            cols.append(neigh.astype(np.int32))
        rows = np.concatenate(rows)
        cols = np.concatenate(cols)
        mask = rows < cols
        self.edges = np.stack((rows[mask], cols[mask]), axis=1)

    def _estimate_tangents(self):
        n, d = self.X.shape
        adj = [[] for _ in range(n)]
        for i, j in self.edges:
            adj[i].append(j)
            adj[j].append(i)

        tangents = np.empty((n, d))
        rng = np.random.default_rng()

        for i, neigh in enumerate(adj):
            if len(neigh) < 2:
                tangents[i] = np.zeros(d)
                continue
            Y = self.X[neigh] - self.X[neigh].mean(axis=0, keepdims=True)
            m = Y.shape[0]
            if self.pca_full_dim or m < d:
                _, _, vt_top = svds(Y, k=1, which='LM', tol=1e-4)
                tangents[i] = vt_top[0]
            else:
                v = rng.standard_normal(d)
                v /= np.linalg.norm(v)
                for _ in range(3):
                    v = Y.T @ (Y @ v)
                    v /= np.linalg.norm(v)
                tangents[i] = v
        self.tangents_raw = tangents

    def _align_tangents(self):
        n = self.tangents_raw.shape[0]
        r, c = self.edges[:, 0], self.edges[:, 1]
        dot = np.einsum("ij,ij->i", self.tangents_raw[r],
                        self.tangents_raw[c])
        Q = sparse.csr_matrix(
            (np.concatenate([dot, dot]),
             (np.concatenate([r, c]), np.concatenate([c, r]))),
            shape=(n, n))
        eigval, eigvec = spla.eigsh(Q, k=1, which="LA")
        signs = np.sign(eigvec[:, 0])
        signs[signs == 0] = 1.0
        self.signs = signs
        self.tangents = self.tangents_raw * signs[:, None]

    def _embed(self):
        if self.embedding == "laplacian":
            self.y = self._embed_laplacian()
        elif self.embedding == "linreg":
            self.y = self._embed_linreg()
        else:
            raise ValueError("embedding must be laplacian or linreg")
        self.order = rankdata(self.y, method='average')

    def _embed_laplacian(self):
        n = self.X.shape[0]
        r, c = self.edges[:, 0], self.edges[:, 1]
        m = self.edges.shape[0]
        data = np.tile([1.0, -1.0], m)
        row_idx = np.repeat(np.arange(m), 2)
        col_idx = np.concatenate([r, c])
        B = sparse.csr_matrix((data, (row_idx, col_idx)), shape=(m, n))
        w = 0.5 * np.einsum("ij,ij->i",
                            self.tangents[r] + self.tangents[c],
                            self.X[c] - self.X[r])

        def mv(z):
            return B.T @ (B @ z)

        L = spla.LinearOperator((n, n), matvec=mv, dtype=np.float64)
        try:
            y, info = spla.cg(L, B.T @ w, rtol=self.rtol, atol=0.0)
        except TypeError:
            y, info = spla.cg(L, B.T @ w, tol=self.rtol)
        if info != 0:
            raise RuntimeError(f"CG did not converge (info={info})")
        return y - y.mean()

    def _embed_linreg(self):
        n = self.X.shape[0]
        adj = [[] for _ in range(n)]
        for i, j in self.edges:
            adj[i].append(j)
            adj[j].append(i)

        rows, cols, data = [], [], []
        t_vec = []
        row_id = 0
        for i, neigh in enumerate(adj):
            if not neigh:
                continue
            mu_i = self.X[neigh].mean(axis=0)
            v_i = self.tangents[i]
            for j in neigh:
                rows.extend([row_id, row_id])
                cols.extend([j, n + i])
                data.extend([1.0, -1.0])
                t_vec.append(np.dot(self.X[j] - mu_i, v_i))
                row_id += 1

        m = row_id
        A = sparse.csr_matrix((data, (rows, cols)), shape=(m, 2 * n))
        t_vec = np.asarray(t_vec, dtype=np.float64)
        res = spla.lsqr(A, t_vec, atol=0.0, btol=self.rtol)
        y = res[0][:n]
        return y - y.mean()


def stage_embedding(X, *, r=None, k=None, rtol=1e-6, neighbour_min=6,
                    pca_full_dim=False, embedding="laplacian",
                    return_intermediates=False):
    """Fit a :class:`STAGE` model on ``X`` in one call.

    All keyword arguments except ``return_intermediates`` are forwarded to
    the ``STAGE`` constructor. Returns the fitted model itself when
    ``return_intermediates`` is true, otherwise the tuple ``(y, order)``.
    """
    fitted = STAGE(
        r=r,
        k=k,
        rtol=rtol,
        neighbour_min=neighbour_min,
        pca_full_dim=pca_full_dim,
        embedding=embedding,
    ).fit(X)
    return fitted if return_intermediates else (fitted.y, fitted.order)


def evaluate_kendall_abs(order, truth):
    """Return the absolute value of Kendall's tau between ``order`` and ``truth``.

    A perfectly reversed ordering scores the same as a perfect one. If
    ``kendalltau`` yields NaN (e.g. for constant input), NaN is returned.
    """
    result = kendalltau(order, truth)
    return abs(result[0])


def evaluate_kendall(order, truth):
    """Return the signed Kendall tau rank correlation of ``order`` vs ``truth``."""
    return kendalltau(order, truth)[0]
