import numpy as np
import numpy.ma as ma
from sklearn.model_selection import train_test_split

 
def ar1(rho, p):
    '''
    Generate an AR(1) (first-order autoregressive) covariance matrix.

    Entry (i, j) equals ``rho ** |i - j|``: the diagonal is 1 and the
    correlation decays geometrically with the index distance |i - j|.

    Args:
    rho: correlation coefficient between adjacent coordinates.
    p: dimension of the (square) matrix.

    Returns:
    numpy.ndarray of shape (p, p) and dtype float64.
    '''
    # Build the |i - j| lag matrix once and exponentiate element-wise,
    # instead of filling p * p entries in a Python-level loop.
    idx = np.arange(p)
    lags = np.abs(idx[:, None] - idx[None, :])
    return np.power(np.asarray(rho, dtype=float), lags)


def process_row(row, cn):
    """
    Map a vector to ``exp(-||row||_2 * cn)``.

    Args:
        row: array whose norm is taken with ``np.linalg.norm``.
        cn: scale factor applied to the norm inside the exponential.

    Returns:
        ``exp(-norm(row) * cn)``.
    """
    return np.exp(-np.linalg.norm(row) * cn)


def map_row(tensor, subfunc, n_features=10, cn=10):
    """
    Apply ``subfunc`` to every length-``n_features`` row of ``tensor``.

    ``tensor`` is reshaped to ``(-1, n_features)`` and ``subfunc`` is called as
    ``subfunc(row, cn=cn)`` on each row via ``np.apply_along_axis``.

    Returns:
        The stacked per-row results produced by ``np.apply_along_axis``.
    """
    rows = tensor.reshape(-1, n_features)
    return np.apply_along_axis(subfunc, axis=1, arr=rows, cn=cn)


def generate_network(M, mask=None, typ=1, d=1, q=0.5, random_state = 1):
    '''
    Generate a row-normalized weighted network based on the specified type.

    Args:
        M (int): Number of clients (nodes) in the network.
        mask (numpy.ndarray, optional): (M, M) mask of allowed links (non-zero =
            allowed), used only by the fixed-degree network (type 4). The caller's
            array is not modified. Defaults to None (all off-diagonal links allowed).
        typ (int, optional): Type of network to generate.
                             1: Circle-type network,
                             2: Erdős–Rényi (ER) graph,
                             3: Central-client network,
                             4: Fixed-degree network. Defaults to 1.
        d (int, optional): Degree parameter for network types 1 and 4. Defaults to 1.
        q (float, optional): Link probability for ER graph (type 2). Defaults to 0.5.
        random_state (int, optional): Seed passed to ``np.random.seed`` (this reseeds
            the global NumPy RNG). Defaults to 1.

    Returns:
        W (numpy.ndarray): The weighted network matrix (M x M). Every row is divided
        by its sum, so rows with at least one link sum to 1 (a row with no links
        becomes NaN). Types 1, 3 and 4 return float32; type 2 returns float64.

    Raises:
        ValueError: if ``typ`` is not 1, 2, 3 or 4.
    '''
    if typ not in (1, 2, 3, 4):
        # Previously an unknown type silently returned an all-zero matrix.
        raise ValueError(f"Unknown network type: {typ}")

    np.random.seed(random_state)

    W = np.zeros((M, M), dtype=np.float32)

    if typ == 1:
        # Circle-type network: node i links to nodes i+1, ..., i+d (mod M).
        # If d >= M the wrap-around also reaches node i itself.
        for i in range(M):
            idx = (np.arange(i + 1, i + 1 + d)) % M
            W[i, idx] = 1
        W = W / np.sum(W, axis=1, keepdims=True)

    elif typ == 2:
        # Erdős–Rényi (ER) graph: each edge exists with probability `q`.
        A = np.random.rand(M, M)
        A = (A < q).astype(int)

        # Symmetrize (undirected graph). np.triu keeps the diagonal, so a
        # self-loop A[i, i] = 1 can also occur with probability q.
        A = np.triu(A)
        A = A + A.T - np.diag(A.diagonal())

        # Ensure each node has at least one connection.
        for i in range(M):
            if np.sum(A[i]) == 0:
                j = np.random.choice([x for x in range(M) if x != i])
                A[i, j] = 1
                A[j, i] = 1
        W = A / np.sum(A, axis=1, keepdims=True)

    elif typ == 3:
        # Central-client network: node 0 links to every other node, and every
        # other node links only to node 0.
        W[0, :] = 1 / (M - 1)
        W[:, 0] = 1
        W[0, 0] = 0
        W = W / np.sum(W, axis=1, keepdims=True)

    else:
        # Fixed-degree network: each node links to `d` randomly chosen allowed
        # neighbors (or to all of them when fewer than `d` are allowed).
        if mask is None:
            mask = np.ones((M, M), dtype=bool)
        else:
            # Work on a private boolean copy: fill_diagonal below would
            # otherwise overwrite the caller's mask in place.
            mask = np.array(mask, dtype=bool)
        np.fill_diagonal(mask, False)  # Exclude self-connections
        for i in range(M):
            possible_nodes = np.where(mask[i])[0]
            if len(possible_nodes) < d:
                selected = possible_nodes
            else:
                selected = np.random.choice(possible_nodes, size=d, replace=False)
            W[i, selected] = 1
        W = W / np.sum(W, axis=1, keepdims=True)

    return W
    
    

    
def generate_data(n_samples,n_features,coefs,intercept=0.0, mean_x = None, corr=0.0, sigma=1.0,
                  model_typ='linear',random_state=None):
    """
    Simulate a design matrix X and a response y from a linear or logistic model.

    Rows of X are drawn i.i.d. from N(mean_x, Sigma_x), where Sigma_x is the
    AR(1) covariance ``ar1(corr, n_features)``.

    Args:
        n_samples: number of rows to draw.
        n_features: number of covariates (columns of X).
        coefs: coefficient vector of shape (n_features,) or (n_features, 1).
        intercept: scalar intercept added to the linear predictor.
        mean_x: mean of X, shape (n_features,). Defaults to zeros.
        corr: AR(1) correlation parameter passed to ``ar1``.
        sigma: noise standard deviation (linear model only).
        model_typ: 'linear' -> y = X @ coefs + intercept + sigma * eps, eps ~ N(0, 1);
                   'logit'  -> y ~ Bernoulli(sigmoid(X @ coefs + intercept)).
        random_state: if not None, passed to ``np.random.seed`` (reseeds the
            global NumPy RNG) before sampling.

    Returns:
        X: array of shape (n_samples, n_features).
        y: float array of shape (n_samples, 1).

    Raises:
        ValueError: if ``model_typ`` is not 'linear' or 'logit'.
    """
    if model_typ not in ('linear', 'logit'):
        # Previously this fell through and failed later with an unbound `y`.
        raise ValueError(f"Unknown model_typ: {model_typ}")

    if random_state is not None:
        np.random.seed(random_state)

    Sigma_x = ar1(corr, n_features)
    if mean_x is None:
        mean_x = np.zeros([n_features])

    X = np.random.multivariate_normal(mean=mean_x, cov=Sigma_x, size=n_samples)

    # Force the linear predictor into a column so 1-D coefs cannot broadcast
    # against the (n_samples, 1) noise into an (n_samples, n_samples) matrix.
    eta = (X @ coefs + intercept).reshape([-1, 1])

    if model_typ == 'linear':
        y = eta + sigma * np.random.normal(size=(n_samples, 1))
    else:
        # sigmoid(eta): the intercept must sit inside the negation
        # (the old code computed 1 / (1 + exp(-X @ coefs + intercept))).
        prob = 1 / (1 + np.exp(-eta))
        # rand(n_samples, 1) consumes the same n_samples uniforms as rand(n_samples).
        y = (np.random.rand(n_samples, 1) < prob).astype(float)
    return X, y


def get_matform(Xs=None,ys=None,coefs=None,n_workers=10):
    """
    Prepare the stacked ("star") form of a multi-worker regression problem.

    Args:
        Xs: list of per-worker design matrices, each of shape (n_m, n_features).
            When given, its length must equal ``n_workers``.
        ys: list of per-worker responses, stacked vertically into ``y_star``.
            Required whenever ``Xs`` is given.
        coefs: coefficient vector of shape (n_features,) or (n_features, 1);
            replicated once per worker into ``coefs_star``.
        n_workers: number of workers.

    Returns:
        coefs_star: float64 array of shape (n_features * n_workers, 1); all
            zeros when ``coefs`` is None.
        X_star: block-diagonal matrix whose m-th diagonal block is ``Xs[m]``,
            or None when ``Xs`` is None.
        y_star: ``np.vstack(ys)``, or None when ``Xs`` is None.

    Raises:
        ValueError: if ``len(Xs) != n_workers`` or ``coefs`` does not have
            ``n_features`` entries.
    """
    from scipy.linalg import block_diag

    X_star = y_star = None
    if Xs is not None:
        if len(Xs) != n_workers:
            # Otherwise X_star (first n_workers blocks) and y_star (all ys)
            # would have inconsistent row counts.
            raise ValueError(
                f"len(Xs)={len(Xs)} does not match n_workers={n_workers}")
        _, n_features = np.shape(Xs[0])
        X_star = block_diag(*Xs)
        y_star = np.vstack(ys)
    else:
        n_features = len(coefs)

    coefs_star = np.zeros([n_features * n_workers, 1])
    if coefs is not None:
        # Accept 1-D or column coefficients and keep the float64 dtype.
        coefs_col = np.asarray(coefs, dtype=float).reshape(-1, 1)
        if coefs_col.shape[0] != n_features:
            raise ValueError(
                f"coefs has {coefs_col.shape[0]} entries, expected {n_features}")
        coefs_star = np.tile(coefs_col, (n_workers, 1))
    return coefs_star, X_star, y_star


class ByzantiumDFL():
    """
    Simulate multi-worker data for decentralized federated learning in which a
    fraction of the workers are Byzantine (their local data is corrupted).

    Args:
        n_workers: number of workers.
        model_typ: 'linear' or 'logit'; forwarded to ``generate_data``.
        coefs_true: true coefficients of the honest workers. Stored through
            ``np.atleast_1d``; ``len(coefs_true)`` is used as the number of features.
        intercept_true: true intercept.
        random_state: base integer seed; worker ``m`` is sampled with seed
            ``random_state + m``. ``None`` gives unseeded (non-reproducible) draws.
    """

    def __init__(self, n_workers=10,
                 model_typ='linear',coefs_true=None,intercept_true=0.0,
                 random_state=None):
        self.random_state = random_state
        self.model_typ = model_typ
        self.coefs_true = np.atleast_1d(coefs_true)
        self.intercept_true = intercept_true
        self.n_workers = n_workers

    def _seed(self, offset=0):
        """Return ``random_state + offset``, or None when the instance is unseeded."""
        if self.random_state is None:
            return None
        return self.random_state + offset

    def generate_byz_data(self,byz_ratio=0,n_samples=10, is_iid=True, byz_typ = 1):
        """
        Generate per-worker data; the last ``M_byz`` workers (before shuffling)
        are Byzantine.

        Args:
            byz_ratio: fraction of Byzantine workers;
                ``M_byz = floor(n_workers * byz_ratio)`` (with a tiny tolerance
                against floating-point truncation).
            n_samples: number of samples per worker.
            is_iid: if True every worker draws X with zero mean and zero AR(1)
                correlation; otherwise worker m gets its own correlation drawn
                from U[0.2, 0.3) and mean drawn from U[-0.5, 0.5) per feature.
            byz_typ: corruption applied to the Byzantine workers:
                1: bit-flipping (y -> -y for 'linear', y -> 1 - y for 'logit'),
                2: out-of-distribution covariates (X -> 0.7 * X + V, V ~ U[0, 1)),
                3: model-parameter corruption (all but the first 10% of the
                   coefficients set to 0).

        Returns:
            X, y: stacked data of all workers in generation order (honest first,
                then Byzantine), shapes (n_samples * n_workers, n_features) and
                (n_samples * n_workers, 1).
            Xs, ys: per-worker lists in a shuffled worker order.

        Side effects:
            Sets ``self.byz_labels``: boolean array aligned with ``Xs``/``ys``,
            True for Byzantine workers. Reseeds the global NumPy RNG.

        Raises:
            ValueError: if there is at least one Byzantine worker and
                ``byz_typ`` is not 1, 2 or 3.
        """
        n_workers = self.n_workers
        # The epsilon keeps e.g. int(100 * 0.29) from truncating 28.999... to 28.
        M_byz = int(n_workers * byz_ratio + 1e-9)
        M_normal = n_workers - M_byz
        if M_byz > 0 and byz_typ not in (1, 2, 3):
            # Previously the Byzantine rows were silently left as zeros.
            raise ValueError(f"Unknown byz_typ: {byz_typ}")

        coefs_true = self.coefs_true
        intercept_true = self.intercept_true
        n_features = len(coefs_true)

        X = np.zeros([n_samples * n_workers, n_features])
        y = np.zeros([n_samples * n_workers, 1])

        if is_iid:
            corrs = np.zeros([n_workers])
            means = np.zeros([n_workers, n_features])
        else:
            np.random.seed(self.random_state)
            corrs = np.random.uniform(low=0.2, high=0.3, size=n_workers)
            means = np.random.uniform(low=-0.5, high=0.5, size=[n_workers, n_features])

        for m in range(n_workers):
            rows = slice(n_samples * m, n_samples * (m + 1))
            is_byz = m >= M_normal

            coefs_m = coefs_true
            if is_byz and byz_typ == 3:
                # Model Parameter (MP) corruption: keep only the first 10% of coefs.
                coefs_m = coefs_true * 1.0
                coefs_m[int(0.1 * n_features):] = 0

            # Worker m is always sampled with seed random_state + m, so this single
            # call reproduces the data of the former per-branch calls.
            X_m, y_m = generate_data(n_samples, n_features, coefs_m, intercept_true,
                                     mean_x=means[m], corr=corrs[m],
                                     model_typ=self.model_typ,
                                     random_state=self._seed(m))

            if is_byz and byz_typ == 1:
                # Bit-Flipping (BF) corruption of the response.
                if self.model_typ == 'linear':
                    y_m = -y_m
                elif self.model_typ == 'logit':
                    y_m = 1 - y_m
            elif is_byz and byz_typ == 2:
                # Out-of-Distribution (OOD) corruption of the covariates.
                # NOTE(review): this reseeds with the same seed generate_data just
                # used, so V comes from the same restarted random stream as X_m and
                # is not independent of it; kept as-is for reproducibility.
                np.random.seed(self._seed(m))
                V = np.random.uniform(size=[n_samples, n_features])
                X_m = 0.7 * X_m + V

            X[rows] = X_m
            y[rows] = y_m

        Xs, ys = np.split(X, n_workers), np.split(y, n_workers)

        # Shuffle the worker order so Byzantine workers are not always last.
        np.random.seed(self.random_state)
        indices = np.arange(n_workers)
        np.random.shuffle(indices)
        Xs = [Xs[i] for i in indices]
        ys = [ys[i] for i in indices]

        # Worker k of the shuffled lists was generated as worker indices[k].
        self.byz_labels = (np.arange(n_workers) >= M_normal)[indices]

        return X, y, Xs, ys
    
    

    
def get_adjacency_mat(W,set_diag=False):
    """
    Binarize a weight matrix into a 0/1 integer adjacency matrix.

    Every non-zero entry of ``W`` becomes 1 and every zero entry becomes 0.
    When ``set_diag`` is True the diagonal is additionally forced to 1;
    otherwise diagonal entries follow the same non-zero rule as the rest.

    Args:
        W: numpy.ndarray weight matrix (not modified).
        set_diag: whether to force the diagonal to 1.

    Returns:
        Integer numpy.ndarray with the same shape as ``W``.
    """
    # Non-zero weight -> 1, zero weight -> 0.
    adjacency = (W != 0).astype(int)
    if not set_diag:
        return adjacency
    np.fill_diagonal(adjacency, 1)
    return adjacency



def grouped_mean(loss, M):
    """
    Average ``loss`` over ``M`` consecutive, equally sized groups.

    ``loss`` is reshaped row-major to ``(M, -1)``, so group ``m`` is the m-th
    contiguous run of ``loss.size // M`` entries.

    Args:
    loss: numpy.ndarray whose size is a multiple of ``M``
        (e.g. shape (n * M,) or (n * M, 1)).
    M: number of groups.

    Returns:
    numpy.ndarray of shape (M,) with the mean of each group.
    """
    # One row per group, then average across each row.
    groups = loss.reshape(M, -1)
    return groups.mean(axis=1)


def get_prob(z,model_typ='linear'):
    """
    Map a linear predictor ``z`` through the inverse link selected by ``model_typ``.

    - 'logit': logistic sigmoid ``1 / (1 + exp(-z))``.
    - 'probit': ``exp(z)``.
    - 'linear' or any other value: ``z`` unchanged (identity).

    NOTE(review): the 'probit' branch returns ``exp(z)``, which is not the usual
    probit inverse link (the standard normal CDF) and can exceed 1 — confirm intent.
    """
    if model_typ == 'logit':
        return 1 / (1 + np.exp(-z))
    if model_typ == 'probit':
        return np.exp(z)
    # 'linear' and unknown types share the identity link.
    return z
    

def list_train_test_split(Xs, ys, test_size=0.2,
                          random_state=None):
    """
    Split each worker's (X, y) pair into train and test parts.

    ``train_test_split`` is called once per pair, in order, with the same
    ``test_size`` and ``random_state``.

    Returns:
        Tuple ``(X_train_list, X_test_list, y_train_list, y_test_list)`` of
        lists, each holding one entry per worker.
    """
    # Buckets in the order train_test_split returns its pieces:
    # X_train, X_test, y_train, y_test.
    buckets = ([], [], [], [])
    for X_m, y_m in zip(Xs, ys):
        pieces = train_test_split(
            X_m, y_m, test_size=test_size, random_state=random_state
        )
        for bucket, piece in zip(buckets, pieces):
            bucket.append(piece)

    X_train_list, X_test_list, y_train_list, y_test_list = buckets
    return X_train_list, X_test_list, y_train_list, y_test_list




from scipy.stats import trim_mean
def aggregate_robust(A, vec, method='median', proportiontocut=0.1):
    """
    Perform robust aggregation of parameter vectors across nodes in a decentralized network.

    Args:
    -----------
    A : numpy.ndarray
        Adjacency matrix of shape (n_workers, n_workers). Node j is a neighbor of
        node i (can transmit data to node i) iff A[i, j] != 0; for a 0/1 matrix
        this means A[i, j] == 1.
    vec : numpy.ndarray
        Parameter vectors of shape (n_workers * n_features, 1). Each node's vector
        is a contiguous block of length n_features.
    method : str, optional
        The robust aggregation method to use. Options are 'median' or
        'trimmed_mean'. Default is 'median'.
    proportiontocut : float, optional
        The proportion of data to trim from each end when using 'trimmed_mean'.
        For each node it is clamped to [1/k, 0.5 - 1/k] (the lower bound wins when
        they cross), where k is that node's number of neighbors. Nodes with fewer
        than 3 neighbors use a plain mean. Default is 0.1.

    Returns:
    --------
    new_vec : numpy.ndarray
        Aggregated parameter vectors of shape (n_workers * n_features, 1). A node
        with no neighbors keeps its own vector unchanged.

    Raises:
    -------
    ValueError
        If ``method`` is not 'median' or 'trimmed_mean'.
    """
    if method not in ('median', 'trimmed_mean'):
        raise ValueError(f"Unknown method: {method}")

    n_workers = A.shape[0]
    n_features = vec.shape[0] // n_workers

    # One row per node: (n_workers, n_features).
    vec_reshaped = vec.reshape(n_workers, n_features)
    new_vec = np.zeros_like(vec_reshaped)

    for i in range(n_workers):
        # Nodes that can transmit data to node i.
        neighbors = np.flatnonzero(A[i, :] != 0)
        n_nb = neighbors.size
        if n_nb == 0:
            # Aggregating an empty set would yield NaN; keep the node's own value.
            new_vec[i, :] = vec_reshaped[i, :]
            continue

        neighbor_vectors = vec_reshaped[neighbors, :]  # (n_nb, n_features)

        if method == 'median':
            aggregated_vector = np.median(neighbor_vectors, axis=0)
        elif n_nb < 3:
            # Too few neighbors to trim anything meaningfully.
            aggregated_vector = np.mean(neighbor_vectors, axis=0)
        else:
            # Per-node clamp held in a local: the old code overwrote the
            # `proportiontocut` argument here, so the clamp computed for one node
            # leaked into every later node (order-dependent trimming).
            cut = max(min(proportiontocut, 0.5 - 1 / n_nb), 1 / n_nb)
            aggregated_vector = trim_mean(neighbor_vectors, proportiontocut=cut, axis=0)

        new_vec[i, :] = aggregated_vector

    # Back to the stacked column layout (n_workers * n_features, 1).
    return new_vec.reshape(n_workers * n_features, 1)



