#! -*- coding: utf-8
import typing

import cvxpy as cp
import networkx as nx
import numpy as np
import torch

from .dynamic_graph import DynamicGraph

__all__ = ["Matcha"]

# Default GML topology for each supported node count. Matcha.__init__ looks up
# this table by ``n_nodes`` when no ``gmlfile`` is passed. The paths are
# relative ("./..."), so they resolve against the current working directory,
# not against this module's location.
DEFAULT_GML_FILES = {
    11: "./configs/graphs/gml/gaia.gml",
    22: "./configs/graphs/gml/amazon_us.gml",
    40: "./configs/graphs/gml/geantdistance.gml",
    53: "./configs/graphs/gml/n53_seed0.gml",
    79: "./configs/graphs/gml/exodus.gml",
    87: "./configs/graphs/gml/ebone.gml",
}

class Matcha(DynamicGraph):
    """Dynamic communication graph sampled with MATCHA (matching decomposition sampling).

    The base topology is read from a GML file and turned into a pairwise delay
    matrix. The topology is decomposed into disjoint matchings. A first SDP
    picks an activation probability for each matching, subject to the
    communication budget ``C_b``. A second SDP picks the mixing step size
    ``alpha``. Every call to :meth:`get_neighbors` draws a fresh mixing matrix.

    Attributes:
        delay_matrix: ``(n_nodes, n_nodes)`` array of link delays with ``0`` on
            the diagonal and ``np.inf`` where two nodes are not connected.
        p: Activation probability of each matching (clipped to ``[0, 1]``).
        alpha: Mixing step size applied to every active edge.
        matchings: Edge lists (pairs of node indices) of the disjoint matchings.
    """

    def __init__(self, n_nodes: int, gmlfile: str = None,
                 compute_shortest_path: bool = False, C_b: float = 0.5, solver: str = "SCS",
                 penalty: str = "no", nrepeat: int = 5, seed: int = 11):
        """Build the delay matrix, optimise the MATCHA parameters and initialise the graph.

        Args:
            n_nodes: Number of nodes; must equal the node count of the GML topology.
            gmlfile: Path to the GML topology. Defaults to the entry for
                ``n_nodes`` in ``DEFAULT_GML_FILES``.
            compute_shortest_path: If True, non-adjacent but connected node
                pairs get their shortest-path delay (see :meth:`compute_delay_matrix`).
            C_b: Communication budget. The expected number of active matchings
                is at most ``C_b`` times the number of matchings.
            solver: CVXPY solver name, used for both SDPs.
            penalty, nrepeat, seed: Forwarded unchanged to ``DynamicGraph``.

        Raises:
            KeyError: If ``gmlfile`` is None and there is no default file for ``n_nodes``.
            AssertionError: If the delay-matrix size does not match ``n_nodes``.
            ValueError: If the topology has no edges.
            RuntimeError: If either SDP returns no solution.
        """
        if gmlfile is None:
            if n_nodes not in DEFAULT_GML_FILES:
                raise KeyError(
                    f"No default GML file for n_nodes={n_nodes}; "
                    f"available: {sorted(DEFAULT_GML_FILES)}. Pass gmlfile explicitly.")
            gmlfile = DEFAULT_GML_FILES[n_nodes]
        self.delay_matrix = self.compute_delay_matrix(
            gmlfile, compute_shortest_path=compute_shortest_path)
        assert self.delay_matrix is not None, "failed compute delay matrix."
        assert self.delay_matrix.shape[0] == n_nodes, \
            f"Unmatch delay matrix size and graph node: {self.delay_matrix.shape} != {n_nodes}"

        # Forward the caller's solver choice; it must reach both SDPs.
        self.p, self.alpha, self.matchings = self.optimize_pa(
            self.delay_matrix, C_b=C_b, solver=solver)

        # Identity placeholder. get_neighbors() overwrites self.w_list with a
        # freshly sampled mixing matrix before delegating to DynamicGraph.
        w = torch.tensor(np.eye(n_nodes))
        super().__init__([w],
                         penalty=penalty, nrepeat=nrepeat, seed=seed)

    def get_neighbors(self, i, idx: int = None) -> typing.Tuple[typing.Dict[int, float], typing.Dict[int, float]]:
        """Resample the mixing matrix, then delegate to ``DynamicGraph.get_neighbors``.

        NOTE(review): a new matrix is sampled on *every* call. If callers query
        each node separately within one round, different nodes may see
        different topologies. Confirm this matches DynamicGraph's intended usage.
        """
        # Only the matrix is needed here; the list of active edges is discarded.
        w_k, _ = self.sample_mixing()
        self.w_list = [torch.tensor(w_k)]
        return super().get_neighbors(i, idx=idx)

    def compute_delay_matrix(self, gmlname: str = './gml/amazon_us.gml',
                             compute_shortest_path: bool = False,) -> np.ndarray:
        """Read a GML topology and return its pairwise delay matrix.

        Each link's delay is ``distance * 0.0085 + 4.62``, where ``distance``
        is the edge's ``distance`` attribute. Nodes are indexed in the sorted
        order of their GML labels.

        Args:
            gmlname: Path to the GML file. Every edge must carry a ``distance`` attribute.
            compute_shortest_path: If True, replace every entry with the smallest
                total delay (sum of per-link delays) over any path, so connected
                but non-adjacent pairs get a finite delay.

        Returns:
            ``(n, n)`` float array with ``0`` on the diagonal and ``np.inf`` for
            pairs with no link (or, with ``compute_shortest_path``, no path).

        Raises:
            ValueError: If the two directions of a link carry different distances.
        """
        M_over_A = 4.62

        G = nx.read_gml(gmlname)
        nodes = sorted(G.nodes())
        n = len(nodes)
        node_index = {node: i for i, node in enumerate(nodes)}

        d_matrix = np.full((n, n), np.inf)
        np.fill_diagonal(d_matrix, 0.0)
        for u, v, data in G.edges(data=True):
            i = node_index[u]
            j = node_index[v]
            if i == j:
                # A self-loop costs nothing to "communicate"; keep the diagonal at 0.
                continue
            # Use the edge's own attribute dict. In a multigraph, G.adj[u][v]
            # is keyed by edge key, not by attribute name.
            d_value = data["distance"] * 0.0085 + M_over_A
            d_matrix[i, j] = d_value
            if d_matrix[j, i] == np.inf:
                d_matrix[j, i] = d_value
            elif d_matrix[j, i] != d_value:
                raise ValueError(
                    f"Edge ({u}, {v}) has inconsistent distances: {d_matrix[j, i]} vs {d_value}")
        if not compute_shortest_path:
            return d_matrix

        # All-pairs shortest paths (Floyd-Warshall) over the per-link *delays*,
        # so every entry is in the same units. Unreachable pairs stay at inf
        # instead of raising. The right-hand sum is materialised before
        # np.minimum writes back into d_matrix.
        for k in range(n):
            np.minimum(d_matrix, d_matrix[:, k, None] + d_matrix[None, k, :], out=d_matrix)
        return d_matrix

    def optimize_pa(self, delay_matrix: np.ndarray,
                    C_b: float = 0.5, solver: str = "SCS",
                    ) -> typing.Tuple[np.ndarray, float, typing.List[typing.List[typing.Tuple[int, int]]]]:
        """Decompose the topology into matchings and solve the two MATCHA SDPs.

        1. Greedily peel off maximal matchings until every edge is covered.
        2. Choose activation probabilities ``p`` that maximise ``gamma`` subject
           to ``sum_j p_j L_j - gamma * (I - J) >> 0``, ``0 <= p <= 1`` and
           ``sum(p) <= C_b * M``.
        3. Choose ``alpha`` that minimises ``rho`` subject to
           ``I - 2 alpha L_mean + beta (L_mean^2 + 2 L_var) - J << rho I`` and
           ``alpha^2 <= beta``.

        Args:
            delay_matrix: Square matrix. Pair ``(i, j)`` with ``i < j`` is an
                edge when its entry is finite and positive.
            C_b: Communication budget as a fraction of the number of matchings.
            solver: CVXPY solver name, used for both problems.

        Returns:
            A tuple ``(p, alpha, matchings)``:

            - ``p``: activation probabilities, clipped to ``[0, 1]``.
            - ``alpha``: step size, capped at ``1 / max_degree``.
            - ``matchings``: lists of ``(u, v)`` index pairs.

        Raises:
            ValueError: If the topology has no edges.
            RuntimeError: If either SDP returns no solution.
        """
        n = delay_matrix.shape[0]

        G = nx.Graph()
        G.add_nodes_from(range(n))
        for i in range(n):
            for j in range(i + 1, n):
                if np.isfinite(delay_matrix[i, j]) and delay_matrix[i, j] > 0.0:
                    G.add_edge(i, j)
        if G.number_of_edges() == 0:
            raise ValueError("Topology has no edges; cannot build MATCHA matchings.")

        # Step 1: disjoint maximal matchings whose union is the whole edge set.
        matchings: typing.List[typing.List[typing.Tuple[int, int]]] = []
        H = G.copy()
        while H.number_of_edges() > 0:
            M_edges = nx.maximal_matching(H)
            matchings.append(list(M_edges))
            H.remove_edges_from(M_edges)
        M = len(matchings)

        # Graph Laplacian of each matching.
        L_list: typing.List[np.ndarray] = []
        for edges in matchings:
            A = np.zeros((n, n))
            for u, v in edges:
                A[u, v] = A[v, u] = 1.0
            L_list.append(np.diag(A.sum(1)) - A)

        # Step 2: activation probabilities.
        p = cp.Variable(M)
        gamma = cp.Variable()
        Q = np.eye(n) - np.ones((n, n)) / n
        L_tot = sum(p[j] * L_list[j] for j in range(M))
        prob_p = cp.Problem(
            cp.Maximize(gamma),
            [
                L_tot - gamma * Q >> 0,
                p >= 0, p <= 1,
                cp.sum(p) <= C_b * M
            ],
        )
        prob_p.solve(solver=solver, verbose=False)
        if p.value is None:
            raise RuntimeError("p SDP failed")
        # Approximate solvers such as SCS may return values marginally outside
        # [0, 1]. Clip them so p * (1 - p) below is non-negative and the values
        # are valid Bernoulli probabilities in sample_mixing().
        p_star = np.clip(p.value, 0.0, 1.0)

        L_mean = sum(p_star[j] * L_list[j] for j in range(M))
        L_var = sum(p_star[j] * (1 - p_star[j]) * L_list[j] for j in range(M))

        # Second moment of the sampled Laplacian:
        #   E[L^2] = L_mean^2 + sum_j p_j (1 - p_j) L_j^2,
        # and L_j^2 = 2 L_j for a matching Laplacian. Hence
        #   E[L^2] = L_mean @ L_mean + 2 * L_var  (L_mean squared, not L_var).
        # Symmetrise explicitly so the PSD constraint sees an exactly symmetric matrix.
        E_L2 = L_mean @ L_mean + 2 * L_var
        E_L2 = 0.5 * (E_L2 + E_L2.T)

        # Step 3: step size. alpha**2 is relaxed to beta >= alpha**2; E_L2 is
        # PSD, so the relaxation is tight at the optimum.
        alpha = cp.Variable()
        beta = cp.Variable()
        rho = cp.Variable()
        S = (np.eye(n)
             - 2 * alpha * L_mean
             + beta * E_L2
             - np.ones((n, n)) / n)
        prob_a = cp.Problem(
            cp.Minimize(rho),
            [
                alpha**2 - beta <= 0,
                S << rho * np.eye(n)
            ],
        )
        prob_a.solve(solver=solver, verbose=False)
        alpha_opt = alpha.value
        if alpha_opt is None:
            raise RuntimeError("alpha SDP failed")

        # Cap alpha so that 1 - alpha * deg stays >= 0 on the diagonal of the
        # sampled W = I - alpha * L, even when every matching of a node is active.
        Delta_max = max(max(dict(G.degree()).values()), 1)
        alpha_clipped = min(float(alpha_opt), 1.0 / Delta_max)

        return p_star, float(alpha_clipped), matchings

    def sample_mixing(self,) -> typing.Tuple[np.ndarray, typing.List[typing.Tuple[int, int]]]:
        """Draw one random MATCHA mixing matrix.

        Each matching ``j`` is activated independently with probability
        ``self.p[j]``. The active edges form ``W = I - alpha * L_active``.
        Negative entries are then clipped to zero and each row is renormalised
        to sum to one.

        Returns:
            ``(W, edges)``: the ``(n, n)`` mixing matrix and the list of active edges.
        """
        # Size from the delay matrix, not from the nodes that appear in some
        # matching. An isolated node belongs to no matching, which would shrink
        # the matrix and push later node indices out of range.
        n = self.delay_matrix.shape[0]
        Wk = np.eye(n)
        deg = np.zeros(n, dtype=int)
        edges: typing.List[typing.Tuple[int, int]] = []

        # One draw per matching, in matching order. self.rs is presumably the
        # seeded RNG created by DynamicGraph (TODO confirm).
        for j, Ej in enumerate(self.matchings):
            if self.rs.random() < self.p[j]:
                edges.extend(Ej)
                for u, v in Ej:
                    Wk[u, v] += self.alpha
                    Wk[v, u] += self.alpha
                    deg[u] += 1
                    deg[v] += 1

        Wk[np.diag_indices_from(Wk)] -= self.alpha * deg

        Wk[Wk < 0] = 0.0
        Wk /= Wk.sum(1, keepdims=True)
        return Wk, edges
