import logging
import os
import random
from itertools import chain
from typing import Optional, Tuple, Union
import networkx as nx
import numpy as np
import sklearn.datasets as skdatasets
import torch
from server import DeFW_Server, DVRGTFW_Server, DsgFw_Server

logger = logging.getLogger(__name__)


def set_all_seed(seed: int) -> None:
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy's global generator and torch's default
    generator; when CUDA is available, all CUDA devices are seeded as well.

    Args:
        seed (int): the seed value shared by all generators
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def LoadDataset(
    dataset: str, device: Union[torch.device, str]='cpu'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load a svmlight-format dataset as dense float32 tensors.

    The file is read from ``<current working directory>/data/<dataset>.bz2``
    (e.g. "rcv1", "gisette", "real-sim"). The sparse feature matrix is
    densified with ``toarray()`` before being converted to a tensor.

    Args:
        dataset (str): dataset name, i.e. the file stem inside ``data/``
        device (Union[torch.device, str]): device the tensors are created on

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the feature tensor and the label
        tensor, both float32
    """
    data_path = os.path.join(os.getcwd(), "data", f"{dataset}.bz2")
    features, targets = skdatasets.load_svmlight_file(data_path)

    data = torch.tensor(features.toarray(), dtype=torch.float32, device=device)
    label = torch.tensor(targets, dtype=torch.float32, device=device)

    logger.info(f"Load {dataset} with the shape {data.shape}")
    return data, label


def LoadModel(
    method,
    x: torch.Tensor,
    y: torch.Tensor,
    scale: float,
    n_nodes: int,
    gossip_matrix: torch.Tensor,
    batch_size: Optional[int] = None,
    rho: Optional[float] = None,
    base_lr: float = 1.0,
    is_convex: bool = True,
    **kwargs,
):
    """Instantiate the server object for the requested algorithm.

    Args:
        method: algorithm name, one of "DVRGTFW", "DsgFW" or "DeFW"
        x (torch.Tensor): forwarded to the server as ``data``
        y (torch.Tensor): forwarded to the server as ``label``
        scale (float): forwarded to the server as ``scale``
        n_nodes (int): forwarded to the server as ``n_nodes``
        gossip_matrix (torch.Tensor): forwarded as ``gossip_matrix``
        batch_size (Optional[int]): forwarded as ``batch_size``
        rho (Optional[float]): only used by "DVRGTFW", where it is required
        base_lr (float): forwarded as ``base_lr``
        is_convex (bool): forwarded as ``is_convex``
        **kwargs: accepted for call-site compatibility and ignored

    Returns:
        The constructed ``DVRGTFW_Server``, ``DsgFw_Server`` or
        ``DeFW_Server`` instance.

    Raises:
        ValueError: if ``method`` is "DVRGTFW" and ``rho`` is None
        NotImplementedError: if ``method`` is not a supported name
    """
    # Keyword arguments shared by every server constructor.
    common_args = dict(
        data=x,
        label=y,
        scale=scale,
        n_nodes=n_nodes,
        gossip_matrix=gossip_matrix,
        batch_size=batch_size,
        base_lr=base_lr,
        is_convex=is_convex,
    )

    if method == "DVRGTFW":
        # Explicit check instead of `assert`: assertions are removed when
        # Python runs with -O, which would let rho=None reach the server.
        if rho is None:
            raise ValueError("rho must be provided for the DVRGTFW method")
        logger.info("Utilizing DVRGTFW algorithm")
        logger.info(
            f"The batch size is {batch_size}"
        )
        return DVRGTFW_Server(rho=rho, **common_args)

    if method == 'DsgFW':
        logger.info("Utilizing DsgFW algorithm")
        return DsgFw_Server(**common_args)

    if method == 'DeFW':
        logger.info("Utilizing DeFW algorithm")
        return DeFW_Server(**common_args)

    raise NotImplementedError("Unsupported method!")


def generate_filename(hyperparam_dict):
    """Derive the three result file names for one hyper-parameter setting.

    Keys and values are stringified in dictionary order and joined with
    underscores (``key1_value1_key2_value2...``), followed by ``.pkl``.

    Args:
        hyperparam_dict: mapping from hyper-parameter name to its value

    Returns:
        tuple: the loss, step and consensus-time file names, in that order
    """
    tokens = [str(item) for pair in hyperparam_dict.items() for item in pair]
    suffix = "_".join(tokens) + ".pkl"
    return tuple(
        prefix + suffix for prefix in ("loss_", "step_", "consensus_time_")
    )



from math import floor
from collections import defaultdict

import numpy as np


def generate_scw(d:int, e:int) -> np.ndarray:
    """Generate the weighted adjacency matrix of a strongly connected graph.

    A directed ring ``0 -> 1 -> ... -> d-1 -> 0`` is built and symmetrised
    (``w + w.T``). Afterwards ``e`` additional directed edges between two
    distinct, uniformly drawn nodes are added; drawing an existing edge
    increases its weight by one.

    Args:
        d (int): the number of nodes (matrix dimension)
        e (int): the number of additional random directed edges

    Returns:
        np.ndarray: the (d, d) weighted adjacency matrix

    Raises:
        ValueError: if ``e > 0`` while ``d < 2``; two distinct nodes can
            never be drawn, so the sampling loop would not terminate.
    """
    w = np.zeros((d, d))

    # Directed ring: i -> i + 1, closed by the edge d - 1 -> 0.
    for i in range(d - 1):
        w[i, i + 1] = 1
    w[d - 1, 0] = 1

    # Make every ring edge bidirectional.
    w = w + w.T

    if e > 0 and d < 2:
        raise ValueError(
            "at least two nodes are required to add extra edges"
        )

    e_num = 0
    while e_num < e:
        # Rejection sampling: pairs with x == y (self-loops) are redrawn.
        x, y = floor(d * np.random.rand()), floor(d * np.random.rand())
        if x != y:
            w[x, y] += 1
            e_num += 1

    return w

class Graph:
    """Directed graph stored as adjacency lists, with a strong-connectivity test.

    Vertices are addressed by the integers ``0 .. V - 1``.
    """

    def __init__(self, vertices):
        self.V = vertices  # No. of vertices
        self.graph = defaultdict(list)  # vertex -> list of successor vertices

    def addEdge(self, u, v):
        """Add the directed edge ``u -> v``."""
        self.graph[u].append(v)

    def DFSUtil(self, v, visited):
        """Mark in ``visited`` every vertex reachable from ``v``.

        The traversal uses an explicit stack rather than recursion, so long
        paths (e.g. a ring with thousands of nodes) cannot exceed Python's
        recursion limit.

        Args:
            v: the start vertex
            visited: list of booleans indexed by vertex, updated in place
        """
        visited[v] = True
        stack = [v]
        while stack:
            node = stack.pop()
            for successor in self.graph[node]:
                if not visited[successor]:
                    visited[successor] = True
                    stack.append(successor)

    def getTranspose(self):
        """Return a new graph in which every edge is reversed."""
        g = Graph(self.V)

        for u, successors in self.graph.items():
            for v in successors:
                g.addEdge(v, u)

        return g

    def isSC(self):
        """Return True if the graph is strongly connected.

        Vertex 0 must reach every vertex in the graph and in its transpose
        (i.e. every vertex must also reach vertex 0).
        """
        visited = [False] * self.V
        self.DFSUtil(0, visited)
        if not all(visited):
            return False

        # Reachability from 0 in the reversed graph == reachability to 0.
        visited = [False] * self.V
        self.getTranspose().DFSUtil(0, visited)
        return all(visited)
    
def test_sc(w:np.ndarray) -> bool:
    """Check whether an adjacency matrix describes a strongly connected graph.

    Every strictly positive off-diagonal entry ``w[i, j]`` is treated as the
    directed edge ``i -> j``; diagonal entries are ignored.

    Args:
        w (np.ndarray): the adjacency matrix with shape (d, d)

    Returns:
        bool: True if the induced directed graph is strongly connected
    """
    size = len(w)
    digraph = Graph(size)

    edges = (
        (src, dst)
        for src in range(size)
        for dst in range(size)
        if src != dst and w[src, dst] > 0
    )
    for src, dst in edges:
        digraph.addEdge(src, dst)

    return digraph.isSC()

def get_balance(w:np.ndarray) -> np.ndarray:
    """Re-weight a strongly connected adjacency matrix until it is balanced.

    A matrix is balanced when every node's in-degree (column sum) equals its
    out-degree (row sum). In each pass, every node whose in-degree exceeds
    its out-degree adds the difference to its lightest outgoing edge (first
    one on ties). Degrees are only recomputed after a full pass.

    Args:
        w (np.ndarray): the adjacency matrix with shape (d, d); modified in
            place

    Returns:
        np.ndarray: the same array object ``w``, now balanced
    """
    col_sums = w.sum(axis=0)
    row_sums = w.sum(axis=1)

    while not np.all(col_sums == row_sums):
        # Nodes receiving more weight than they send out, in index order.
        for node in np.where(col_sums > row_sums)[0]:
            targets = np.where(w[node, :] > 0)[0]
            lightest = targets[np.argmin(w[node, targets])]
            w[node, lightest] += col_sums[node] - row_sums[node]

        col_sums = w.sum(axis=0)
        row_sums = w.sum(axis=1)

    return w

def test_balanace(w:np.ndarray) -> bool:
    """test whether the matrix w is balanace.

    Args:
        w (np.ndarray): the adjacent matrix with the shape (d, d)

    Returns:
        bool: whether the matrix is balanace.
    """
    return np.allclose(w.sum(axis=0), w.sum(axis=1))

def get_doubly(w:np.ndarray) -> np.ndarray:
    """Turn a balanced adjacency matrix into a doubly stochastic mixing matrix.

    The diagonal of ``w`` is overwritten IN PLACE so that every row sums to
    ``max(row sums) + 1``. With ``P`` denoting ``w`` divided by that row sum,
    the returned matrix is the lazy version ``(P + I) / 2``.

    Args:
        w (np.ndarray): the balanced adjacency matrix with shape (d, d).
            Its diagonal is overwritten; it is expected to be zero on entry
            (NOTE(review): confirm against callers), otherwise the row-sum
            check below fails in general.

    Returns:
        np.ndarray: a NEW (d, d) array holding the doubly stochastic matrix
        (not a reference to ``w``).

    Raises:
        AssertionError: if the row/column sums of the modified ``w`` are not
            all equal, or if the result has a negative eigenvalue.
    """
    d = len(w)
    out_degrees = w.sum(axis=1)
    # One above the largest row sum, so every diagonal entry is at least 1.
    max_out_degrees = max(out_degrees) + 1

    for i, out_degree in enumerate(out_degrees):
        w[i, i] = max_out_degrees - out_degree

    row_stochastic = w @ np.ones((d, 1))
    col_stochastic = np.ones((1, d)) @ w

    assert np.all(row_stochastic == row_stochastic[0])
    assert np.all(col_stochastic == row_stochastic[0])

    print("The double stochastic matrix is constructed")
    # keepdims=True divides every ROW by its own sum. Without it the (d,)
    # vector of row sums broadcasts across columns, which is only correct
    # because all row sums happen to be equal here.
    positive_w = (w / w.sum(axis=1, keepdims=True) + np.eye(d)) / 2
    assert np.all(np.linalg.eigvals(positive_w) >= 0)
    return positive_w


def generate_cycle(d:int, e:int) -> np.ndarray:
    """Generate a symmetric mixing matrix for a ring (cycle) graph.

    Starting from the symmetrised ring returned by ``generate_scw(d, 0)``,
    each edge (i, j) gets the weight ``1 / max(degree_i, degree_j)``, the
    diagonal is set so that every row sums to one, and the result is
    averaged with the identity.

    Args:
        d (int): the dimension size (number of nodes)
        e (int): unused; no additional edges are generated

    Returns:
        np.ndarray: the symmetric (d, d) mixing matrix
    """
    ring = generate_scw(d, 0)

    # Degrees are taken from the raw ring, before any re-weighting.
    degree = ring.sum(axis=0)
    for row in range(d):
        for col in range(d):
            if row == col or not ring[row, col] > 0:
                continue
            ring[row, col] = 1 / max(degree[row], degree[col])

    # The self-weight absorbs whatever is missing for the row to sum to 1.
    for node in range(d):
        ring[node, node] = 1 - ring[node, :].sum()

    return (np.eye(d) + ring) / 2

def generate_graph(num_nodes:int, p:float) -> np.ndarray:
    """Build a Laplacian-based mixing matrix for a G(n, p) random graph.

    A random graph is drawn with ``nx.gnp_random_graph(num_nodes, p)`` and,
    with ``L`` its graph Laplacian, the matrix ``W = I - L / lambda_max(L)``
    is returned. The rows of ``L`` sum to zero, so the rows of ``W`` sum to
    one.

    The drawn graph is NOT checked for connectivity.

    Args:
        num_nodes (int): the number of nodes
        p (float): the edge probability passed to ``nx.gnp_random_graph``

    Returns:
        np.ndarray: the (num_nodes, num_nodes) matrix ``W``. If the drawn
        graph has no edges, ``L`` is zero and the identity matrix is
        returned (instead of the NaN matrix produced by dividing by a zero
        ``lambda_max``).
    """
    G = nx.gnp_random_graph(num_nodes, p)
    L = nx.laplacian_matrix(G).toarray()

    if not L.any():
        # No edges: L == 0, so W = I - 0. Avoids the 0 / 0 division below.
        return np.eye(num_nodes)

    # The Laplacian of an undirected graph is symmetric, so the symmetric
    # eigen-solver applies; it returns real eigenvalues in ascending order.
    lambda_max = np.linalg.eigvalsh(L)[-1]

    # Build the W matrix.
    W = np.eye(num_nodes) - L / lambda_max
    return W