import numpy as np 
import time
import os
import glob
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from sklearn.metrics import accuracy_score
from sklearn.metrics import log_loss
from matplotlib import pyplot as plt


"""
Auxiliary functions used in the main function
"""
def attack_and_save(save2hd=False, **kwargs):
    """
    Use ByzantiumAttack to generate poisoned data, optionally save it to disk,
    and return the per-worker data as tensors on the target device.

    Args:
      save2hd  - if truthy, the per-worker splits are also written to
                 ``data_path`` via ``save_data``.
                 NOTE: every regular file that already sits directly inside
                 ``data_path`` is deleted first (sub-directories are kept).
      **kwargs - read with ``kwargs.get`` (a missing key becomes None):
                 n_workers, byz_ratio, attack_typ, random_state
                     -> forwarded to the ByzantiumAttack constructor
                 X_train, y_train -> passed to ``generate_attack``
                 data_path, data_name -> only used when save2hd is truthy
                 device -> target device for the returned tensors

    Returns:
      Xs_all, ys_all   - lists of per-worker tensors on ``device``
      X_oracle_tensor, y_oracle_tensor - oracle data as tensors on ``device``
      byz_labels       - ``byz_model.byz_labels`` as set by ``get_data``
    """
    n_workers = kwargs.get("n_workers")
    byz_ratio = kwargs.get("byz_ratio")
    attack_typ = kwargs.get("attack_typ")
    random_state = kwargs.get("random_state")
    X_train = kwargs.get("X_train")
    y_train = kwargs.get("y_train")
    data_path = kwargs.get("data_path")
    data_name = kwargs.get("data_name")
    device = kwargs.get("device")

    # Create an instance of ByzantiumAttack and generate poisoned data
    byz_model = ByzantiumAttack(n_workers=n_workers, byz_ratio=byz_ratio,
                                attack_typ=attack_typ, random_state=random_state)
    X_all, y_all = byz_model.generate_attack(X_train, y_train)
    Xs_all, ys_all = byz_model.get_data(X_all, y_all)
    X_oracle, y_oracle = byz_model.get_oracle_data(X_all, y_all)
    byz_labels = byz_model.byz_labels

    if save2hd:
        # Remove stale files so that old shards cannot be mixed with new ones
        if os.path.exists(data_path):
            for entry in os.listdir(data_path):
                entry_path = os.path.join(data_path, entry)
                if os.path.isfile(entry_path):
                    os.remove(entry_path)
        # Save data to the specified directory
        save_data(Xs_all, ys_all, data_path, data_name)
        # The shards used to be reloaded onto the device here, but the result
        # was never used: the tensors returned below are built from the
        # in-memory arrays, so the reload only duplicated device memory.

    Xs_all = [torch.from_numpy(x).to(device) for x in Xs_all]
    ys_all = [torch.from_numpy(y).to(device) for y in ys_all]
    X_oracle_tensor = torch.from_numpy(X_oracle).to(device)
    y_oracle_tensor = torch.from_numpy(y_oracle).to(device)
    return Xs_all, ys_all, X_oracle_tensor, y_oracle_tensor, byz_labels


# Show one incorrect sample
def plot_sample(Xs_all, ys_all, idx=-3):
    """
    Display the first sample of one worker's data with its label as title.

    Args:
      Xs_all, ys_all - sequences of tensors (``.cpu().numpy()`` is called on
                       the selected element)
      idx            - which element of the sequences to display
    """
    images = Xs_all[idx].cpu().numpy()
    targets = ys_all[idx].cpu().numpy()
    # Undo the normalisation x -> (x - 0.1307) / 0.3081
    # (presumably the MNIST mean/std - TODO confirm).
    first_image = images[0] * 0.3081 + 0.1307
    # The axis permutation moves the leading axis last, as imshow expects
    # (assumes a channel-first 3-D image - TODO confirm).
    plt.imshow(first_image.transpose(1, 2, 0))
    plt.title(targets[0])

"""
Helper functions used in the algorithm
"""

def generate_adj_matrix(M, q, seed=None):
    """
    Generate an undirected Erdős–Rényi graph's adjacency matrix A,
    ensuring no node is isolated (i.e., no row is all zeros) whenever
    the graph has at least two nodes.

    Args:
      M    - number of nodes
      q    - edge generation probability
      seed - random seed (optional); when given, the *global* NumPy RNG
             is re-seeded with it

    Returns:
      A - (M x M) symmetric integer adjacency matrix with zero diagonal.
          For M >= 2 no row is all zero. For M < 2 there is no possible
          neighbour, so the (all-zero) matrix is returned unchanged.
    """
    if seed is not None:
        np.random.seed(seed)
    r = np.random.rand(M, M)
    A_upper = np.triu((r < q).astype(int), k=1)
    A = A_upper + A_upper.T

    # With fewer than two nodes an isolated node cannot be connected to
    # anything; without this guard the rejection loop below never ends.
    if M < 2:
        return A

    # Ensure no row is entirely zero (no isolated node)
    for i in range(M):
        if not A[i].any():
            # Rejection sampling keeps the random stream identical to the
            # historical behaviour for a given seed.
            j = np.random.randint(M)
            while j == i:
                j = np.random.randint(M)
            A[i, j] = 1
            A[j, i] = 1
    return A

def generate_network(M, typ='circle', q=0.1, seed=None):
    '''
    Generate a row-normalised network topology.

    Args:
    M -- number of clients
    typ -- network type:
           'circle' - node i is linked to the next d = max(int(M * q), 1)
                      nodes (indices taken modulo M)
           'er'     - Erdős–Rényi graph from generate_adj_matrix(M, q, seed)
    q -- degree ratio ('circle') or connection probability ('er')
    seed -- forwarded to generate_adj_matrix; unused for 'circle'

    Returns:
    W -- Weighted network matrix (NumPy array) in which every row is
         divided by its row sum. For 'circle' the dtype is float32.

    Raises:
    ValueError -- if typ is not one of the supported network types
                  (previously this silently produced a matrix of NaNs).
    '''
    if typ == 'circle':
        W = np.zeros((M, M), dtype=np.float32)
        d = max(int(M * q), 1)
        for i in range(M):
            # Wrap around so that the last nodes connect back to the first
            idx = (np.arange(i + 1, i + 1 + d)) % M
            W[i, idx] = 1
    elif typ == 'er':
        W = generate_adj_matrix(M, q, seed=seed)
    else:
        raise ValueError(
            f"Unknown network type {typ!r}; expected 'circle' or 'er'")
    # Row-normalise so that each row sums to one
    W = W / np.sum(W, axis=1, keepdims=True)
    return W

def process_weights(row, cn):
    """
    Map a tensor to the scalar exp(-||row|| * cn).

    Args:
      row - tensor whose norm (``torch.norm`` default) is taken
      cn  - multiplier applied to the negated norm

    Returns:
      The result as a plain Python number (via ``.item()``).
    """
    return (-torch.norm(row) * cn).exp().item()

def get_neighbors(weight_matrix, include_diag=False):
    """
    List, for every row of a weight matrix, the column indices of its
    non-zero entries.

    Args:
      weight_matrix - array-like; converted with ``np.array``
      include_diag  - if True, the row's own index is appended to its list
                      when it is not already present

    Returns:
      A list with one list of column indices per row.
    """
    W = np.array(weight_matrix)

    def _row_neighbors(node):
        linked = list(np.nonzero(W[node])[0])
        if include_diag and node not in linked:
            linked.append(node)
        return linked

    return [_row_neighbors(node) for node in range(W.shape[0])]

def get_adj_from_neib(neighbors, M=None):
    """
    Build a 0/1 adjacency matrix from per-node neighbour lists.

    Args:
      neighbors - sequence where entry i holds the column indices to set in row i
      M         - matrix size; when falsy (None or 0) ``len(neighbors)`` is used

    Returns:
      (M x M) integer NumPy array with ones at the listed positions.
    """
    size = M if M else len(neighbors)
    adjacency = np.zeros((size, size), dtype=int)
    for node, linked in enumerate(neighbors):
        adjacency[node, linked] = 1
    return adjacency

"""
Helper functions for file operations
"""

def download_parameters(param_path='.', m=0, n_iter=0):
    """
    Load the parameters stored for client ``m`` at iteration ``n_iter``.

    Args:
      param_path - directory holding the ``tmp_weights_{m}_{n_iter}.pth`` files
      m          - client index used in the file name
      n_iter     - iteration index used in the file name

    Returns:
      Whatever object ``torch.load`` reads from that file.
    """
    # NOTE(review): torch.load unpickles the file; only read files written by
    # trusted peers.
    return torch.load(f'{param_path}/tmp_weights_{m}_{n_iter}.pth')

def upload_parameters(param_m, param_path='.', m=0, n_iter=0):
    """
    Store the parameters of client ``m`` at iteration ``n_iter``.

    The object is first written to a temporary file in the same directory
    and then moved onto ``{param_path}/tmp_weights_{m}_{n_iter}.pth`` with
    ``os.replace``. Code that polls for the final file name (see
    ``check_param_path``) therefore does not see a half-written checkpoint.

    Args:
      param_m    - object to serialise with ``torch.save``
      param_path - target directory (must already exist)
      m          - client index used in the file name
      n_iter     - iteration index used in the file name
    """
    param_m_path = f'{param_path}/tmp_weights_{m}_{n_iter}.pth'
    # The pid keeps temporary names of concurrent writer processes apart.
    tmp_path = f'{param_m_path}.{os.getpid()}.tmp'
    try:
        torch.save(param_m, tmp_path)
        os.replace(tmp_path, param_m_path)
    except Exception:
        # Do not leave a partial temporary file behind on failure
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def check_param_path(param_path='.', m=0, n_iter=0, min_size=180700, max_wait=60):
    """
    Block until the parameter file of client ``m`` at iteration ``n_iter``
    exists and has reached a minimum size.

    Args:
      param_path - directory holding the ``tmp_weights_{m}_{n_iter}.pth`` files
      m          - client index used in the file name
      n_iter     - iteration index used in the file name
      min_size   - minimum file size in bytes before the file is considered
                   complete (default 180700, the value previously hard-coded)
      max_wait   - waiting budget in seconds, applied separately to the
                   "file exists" phase and to the "file is large enough" phase

    Raises:
      TimeoutError - if either phase exceeds ``max_wait`` seconds. Previously
                     the size phase had no limit and could wait forever for a
                     file that would never reach the threshold.
    """
    filename = f'{param_path}/tmp_weights_{m}_{n_iter}.pth'

    # Phase 1: wait for the file to appear (polled once per second)
    ct = 0
    while not os.path.exists(filename):
        if ct >= max_wait:
            print(f'filename: {filename}')
            print(f'client: {m}\t iter: {n_iter}')
            raise TimeoutError(f"Exceeded maximum waiting time {ct}s, something went wrong")
        time.sleep(1)
        ct += 1

    # Phase 2: wait until the writer has produced at least min_size bytes
    size_wait = 0.0
    while os.path.getsize(filename) < min_size:
        if size_wait >= max_wait:
            raise TimeoutError(
                f"{filename} stayed below {min_size} bytes for {size_wait}s; "
                f"pass a smaller min_size if the checkpoint is legitimately small")
        time.sleep(0.5)
        size_wait += 0.5

def save_data(Xs_all, ys_all, save_path, data_name='mnist'):
    """
    Write each (X, y) split to its own ``.npz`` file.

    Args:
      Xs_all, ys_all: parallel sequences of arrays, one pair per split
      save_path: directory to write into (created if it does not exist)
      data_name: file-name prefix

    Split number m is stored as ``{data_name}_train_{m}.npz`` with the arrays
    under the keys ``images`` and ``labels``, e.g. mnist_train_0.npz,
    mnist_train_1.npz, ...
    """
    os.makedirs(save_path, exist_ok=True)
    prefix = os.path.join(save_path, f"{data_name}_train_")
    for m, (X_split, y_split) in enumerate(zip(Xs_all, ys_all)):
        filename = f"{prefix}{m}.npz"
        np.savez(filename, images=X_split, labels=y_split)
        # '\r' with end='' keeps the progress output on a single line
        print(f"\rSaved: {filename}", end='')

def load_data_to_device(save_path, device, data_name='mnist'):
    """
    Load all files in the specified directory whose name is exactly
    {data_name}_train_<number>.npz, ordered by that number, and move them to
    the specified device.

    Files that merely share the prefix (for example the
    ``{data_name}_train_oracle_<m>.npz`` files) are ignored; previously they
    made the function fail with an AttributeError.

    Args:
      save_path: directory containing the ``.npz`` files
      device: target device passed to ``Tensor.to``
      data_name: file-name prefix

    Returns:
      (Xs, ys): two lists of tensors, one entry per file. Each X is reshaped
      to (number of samples, -1), i.e. every sample is flattened.

    Raises:
      ValueError: if no matching file is found.
    """
    shard_re = re.compile(rf"{re.escape(data_name)}_train_(\d+)\.npz")
    pattern = os.path.join(save_path, f"{data_name}_train_*.npz")

    # Keep only real shard files and remember their numeric index
    numbered = []
    for path in glob.glob(pattern):
        found = shard_re.fullmatch(os.path.basename(path))
        if found is not None:
            numbered.append((int(found.group(1)), path))
    numbered.sort(key=lambda pair: pair[0])
    file_list = [path for _, path in numbered]

    if not file_list:
        raise ValueError(f"No files matching {pattern} were found")

    Xs_list, ys_list = [], []
    for path in file_list:
        # Open each archive once and close it as soon as the arrays are read
        with np.load(path) as shard:
            images = shard['images']
            labels = shard['labels']
        Xs_list.append(
            torch.from_numpy(images.reshape([images.shape[0], -1])).to(device))
        ys_list.append(torch.from_numpy(labels).to(device))

    print("\rFiles loaded in order:", file_list, end='')
    print(f"\rLoaded {len(file_list)} files and moved data to {device}.")
    return Xs_list, ys_list

def split_and_save_oracle_data(X_all, y_all, M, save_path, data_name='mnist'):
    """
    1. Split X_all and y_all into M parts using np.array_split
       (the parts may differ in size by one when the length is not a
       multiple of M)
    2. Save part m to ``{data_name}_train_oracle_{m}.npz`` with the arrays
       stored under the keys ``images`` and ``labels``

    Args:
      X_all, y_all: Full dataset in np.array format
      M: number of splits
      save_path: directory to save the files (created if it does not exist)
      data_name: file-name prefix; defaults to 'mnist', the prefix that was
                 previously hard-coded

    Example filenames: mnist_train_oracle_0.npz, mnist_train_oracle_1.npz, ...
    """
    os.makedirs(save_path, exist_ok=True)

    X_splits = np.array_split(X_all, M)
    y_splits = np.array_split(y_all, M)

    for m, (X_split, y_split) in enumerate(zip(X_splits, y_splits)):
        filename = os.path.join(save_path, f"{data_name}_train_oracle_{m}.npz")
        np.savez(filename, images=X_split, labels=y_split)
        print(f"Saved: {filename}")

"""
Auxiliary functions for data processing
"""

def shuffle_data(label, X=None, seed=None):
    """
    Apply one random permutation to the labels and, optionally, the features.

    Args:
      label - array of labels (indexed with an index array)
      X     - optional array of features with the same length as ``label``
      seed  - passed to ``np.random.seed``; the *global* NumPy RNG is
              re-seeded on every call, also when seed is None

    Returns:
      ``(X_shuffled, y_shuffled)`` when X is given, otherwise ``y_shuffled``.
    """
    np.random.seed(seed)
    order = np.random.permutation(len(label))
    if X is None:
        return label[order]
    return X[order], label[order]

def fixed_50_buckets_five_labels(labels, X=None, seed=None, num_buckets=50,
                                 labels_per_bucket=5, ratio_per_label=0.04):
    """
    Distribute samples over buckets so that each bucket only holds samples
    from a small, fixed set of labels.

    Bucket i is assigned the labels ``unique_labels[(i + j) % num_labels]``
    for ``j in range(labels_per_bucket)`` (a sliding window over the sorted
    unique labels). The buckets are then visited round-robin; on each visit a
    bucket takes the next chunk of still-unassigned samples of each of its
    labels, until every label that belongs to some bucket is used up.
    Labels that belong to no bucket are not assigned at all.

    Args:
      labels            - NumPy array of labels
      X                 - optional NumPy array of features, indexed like labels
      seed              - passed to ``np.random.seed`` (re-seeds the global RNG)
      num_buckets       - number of buckets (default 50)
      labels_per_bucket - number of labels per bucket (default 5)
      ratio_per_label   - fraction of a label's samples taken per visit
                          (default 0.04); at least one sample is taken per
                          visit so that small classes cannot stall the loop

    Returns:
      ``(X_shuffled, y_shuffled, buckets)`` when X is given, otherwise
      ``(y_shuffled, buckets)``. ``buckets`` is a list of index lists and the
      other two are lists holding the corresponding slices per bucket.
    """
    np.random.seed(seed)

    unique_labels = np.unique(labels)
    num_labels = len(unique_labels)

    # Step 1: Collect sample indices for each label and shuffle them
    label_indices = {label: np.where(labels == label)[0].tolist() for label in unique_labels}
    for label in label_indices:
        np.random.shuffle(label_indices[label])

    label_ptrs = {label: 0 for label in unique_labels}
    label_total = {label: len(label_indices[label]) for label in unique_labels}
    # Samples taken per visit. The floor of at least 1 matters: with fewer
    # than 1 / ratio_per_label samples the truncated chunk would be 0, the
    # pointer would never advance and the loop below would never terminate.
    chunk_size = {label: max(1, int(label_total[label] * ratio_per_label))
                  for label in unique_labels}

    # Step 2: Assign labels to each bucket (sliding window style)
    if num_labels:
        bucket_label_sets = [
            [unique_labels[(i + j) % num_labels] for j in range(labels_per_bucket)]
            for i in range(num_buckets)
        ]
    else:
        # No labels at all: avoid the modulo-by-zero and leave buckets empty
        bucket_label_sets = [[] for _ in range(num_buckets)]

    # Step 3: Distribute samples into each bucket
    buckets = [[] for _ in range(num_buckets)]
    exhausted = False

    while not exhausted:
        exhausted = True
        for i in range(num_buckets):
            for label in bucket_label_sets[i]:
                start = label_ptrs[label]
                remaining = label_total[label] - start
                if remaining > 0:
                    exhausted = False
                    take_n = min(chunk_size[label], remaining)
                    buckets[i].extend(label_indices[label][start:start + take_n])
                    label_ptrs[label] = start + take_n

    # Step 4: Shuffle samples within each bucket
    for bucket in buckets:
        np.random.shuffle(bucket)

    # Step 5: Return results
    y_shuffled = [labels[buckets[i]] for i in range(num_buckets)]
    if X is not None:
        X_shuffled = [X[buckets[i]] for i in range(num_buckets)]
        return X_shuffled, y_shuffled, buckets
    else:
        return y_shuffled, buckets

def generate_label_attack(ym, seed=None):
    """
    Poison labels by shifting every label to the next class, cyclically.

    Args:
      ym   - NumPy array of integer labels
      seed - passed to ``np.random.seed``; no random numbers are drawn here,
             but the global NumPy RNG is still re-seeded

    Returns:
      ``(ym + 1) % K`` where K is ``max(ym) + 1``, i.e. the number of classes
      is inferred from the largest label present in ``ym``.
    """
    np.random.seed(seed)
    num_classes = np.max(ym) + 1
    return (ym + 1) % num_classes

def generate_feature_attack(Xm, seed=None):
    """
    Poison features by shrinking them and adding strong Gaussian noise.

    Args:
      Xm   - NumPy array of features
      seed - if not None, the global NumPy RNG is re-seeded with it first

    Returns:
      ``0.3 * Xm + 3 * noise`` with standard-normal noise of Xm's shape.
    """
    if seed is not None:
        np.random.seed(seed)
    gaussian = np.random.randn(*Xm.shape)
    return 0.3 * Xm + 3 * gaussian

"""
Validation and Evaluation
"""
def evaluate(X_test, y_test, model, device):
    """
    Compute and print test accuracy and log loss of a model.

    Args:
      X_test - test features, passed to ``model.predict`` and
               ``model.predict_proba``
      y_test - true labels; a torch tensor (on any device) is converted to a
               NumPy array first, anything else is used as is
      model  - object providing ``predict`` and ``predict_proba``
      device - kept for interface compatibility; the conversion of ``y_test``
               now depends on its type rather than on this string, so
               'cuda:0' or ``torch.device`` objects work as well

    Returns:
      (acc, val) - accuracy and log loss
    """
    if isinstance(y_test, torch.Tensor):
        y_test = y_test.detach().cpu().numpy()
    y_pred = model.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    print("Test accuracy:", acc)
    y_proba = model.predict_proba(X_test)
    val = log_loss(y_test, y_proba)
    print("Test loss:", val)
    return acc, val

"""
Batch evaluation using a DataLoader
"""

def evaluate_batch(model, dataloader, device):
    """
    Evaluate the model on the data served by a DataLoader.

    Parameters:
      model: object providing ``pred_batch(dataloader)`` (predicted classes)
             and ``predproba_batch(dataloader)`` (predicted probabilities)
      dataloader: iterable yielding ``(X, y)`` batches
      device: not used by this function

    Functionality:
      1. Iterate the dataloader once to gather the true labels;
      2. Let the model iterate it again for predictions and probabilities;
      3. Compute accuracy and log loss, print both, return (accuracy, log_loss)

    NOTE(review): labels and predictions come from separate passes over the
    dataloader, so this presumably requires a non-shuffling loader - confirm.
    """
    # Gather the labels batch by batch, converting non-tensors on the way
    label_batches = []
    for _, y in dataloader:
        label_batches.append(y if isinstance(y, torch.Tensor) else torch.tensor(y))
    y_test = torch.cat(label_batches, dim=0).cpu().numpy()

    # Predictions (classes) and probabilities from the model's batch interface
    y_pred = model.pred_batch(dataloader)
    y_proba = model.predproba_batch(dataloader)

    acc = accuracy_score(y_test, y_pred)
    val = log_loss(y_test, y_proba)
    print("Test accuracy:", acc)
    print("Test loss:", val)
    return acc, val

def evaluate_list(X_test, y_test, model, device='cpu'):
    """
    Compute accuracy and log loss for each of several prediction sets.

    Args:
      X_test - test features, passed to ``model.predict`` and
               ``model.predict_proba``
      y_test - true labels; a torch tensor (on any device) is converted to a
               NumPy array first, anything else is used as is
      model  - object whose ``predict`` / ``predict_proba`` return one entry
               per evaluated model (indexed 0 .. len - 1)
      device - kept for interface compatibility; the conversion of ``y_test``
               now depends on its type rather than on this string, so
               'cuda:0' or ``torch.device`` objects work as well

    Returns:
      (accs, val_losss) - lists of accuracies and log losses, one per entry
    """
    if isinstance(y_test, torch.Tensor):
        y_test = y_test.detach().cpu().numpy()
    accs = []
    val_losss = []
    y_preds = model.predict(X_test)
    y_probas = model.predict_proba(X_test)
    for m in range(len(y_preds)):
        accs.append(accuracy_score(y_test, y_preds[m]))
        val_losss.append(log_loss(y_test, y_probas[m]))
    return accs, val_losss

"""
torch related
"""
def data_to_tensor(X, Y, device='cpu'):
    """
    Turn features and labels into PyTorch tensors on the given device.

    X is converted to float32 and Y to long (int64); both may be
    numpy.ndarray or Python lists. ``device`` can be 'cpu', 'cuda', etc.

    Returns:
      (X_tensor, Y_tensor)
    """
    features = torch.tensor(X, dtype=torch.float32)
    targets = torch.tensor(Y, dtype=torch.long)
    return features.to(device), targets.to(device)


class NpzDataset(Dataset):
    """
    Dataset backed by a single ``.npz`` file holding the arrays ``images``
    and ``labels``; both arrays are read into memory on construction.
    """

    def __init__(self, npz_file_path):
        # Load all data into memory at once; the context manager closes the
        # underlying archive as soon as both arrays have been read.
        with np.load(npz_file_path) as data:
            self.images = data['images']  # presumably [N, H, W, C] - TODO confirm
            self.labels = data['labels']  # presumably [N] - TODO confirm

    def __len__(self):
        """Return the number of samples (length of ``images``)."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return sample ``idx`` as (float32 image tensor, long label tensor)."""
        image = self.images[idx]
        label = self.labels[idx]

        # Convert to PyTorch Tensor
        image = torch.tensor(image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return image, label


"""
ByzantiumAttack class
"""
class ByzantiumAttack():
    """
    Generate Byzantine-poisoned training data for a set of workers.

    The data is arranged worker by worker; the last
    ``int(n_workers * byz_ratio)`` worker slices are poisoned. ``get_data`` /
    ``get_data_on_device`` then split the data per worker, shuffle the worker
    order and record in ``byz_labels`` which shuffled workers are Byzantine.
    """

    def __init__(self, n_workers=10, byz_ratio=0, attack_typ='labelatt',
                 random_state=2024):
        """
        Args:
          n_workers    - number of workers the data is split across
          byz_ratio    - fraction of workers that are Byzantine
          attack_typ   - 'labelatt' or 'featureatt'
          random_state - seed used for bucketing, attacks and worker shuffling
        """
        self.n_workers = n_workers
        self.byz_ratio = byz_ratio
        self.attack_typ = attack_typ
        self.random_state = random_state
        # Boolean array set by get_data / get_data_on_device
        self.byz_labels = None

    def generate_attack(self, X_true, y_true):
        """
        Re-order the data with ``fixed_50_buckets_five_labels`` and poison the
        slices belonging to the Byzantine workers.

        -- attack_typ: 'labelatt' (labels replaced via generate_label_attack)
                       or 'featureatt' (features replaced via
                       generate_feature_attack)

        Returns:
          (X, y) - NumPy arrays holding the bucketed data, bucket after
                   bucket, with the last M_byz worker slices poisoned.

        Raises:
          ValueError - if at least one worker is Byzantine and attack_typ is
                       not a supported attack (previously only a message was
                       printed and unpoisoned data was returned).
        """
        n_workers = self.n_workers
        M_byz = int(n_workers * self.byz_ratio)
        M_normal = n_workers - M_byz
        X_buckets, y_buckets, _ = fixed_50_buckets_five_labels(
            y_true, X=X_true, seed=self.random_state)
        # The bucketing helper returns one array per bucket. They must be
        # joined into single arrays: multiplying the list itself by 1.0
        # raises a TypeError, and the slicing below needs real arrays.
        X = np.concatenate(X_buckets) * 1.0
        y = np.concatenate(y_buckets) * 1
        n_samples = int(X.shape[0] / n_workers)

        for m in range(M_byz):
            start = n_samples * (M_normal + m)
            stop = start + n_samples
            if self.attack_typ == 'labelatt':
                y[start:stop] = generate_label_attack(
                    y[start:stop], seed=self.random_state + m)
            elif self.attack_typ == 'featureatt':
                X[start:stop] = generate_feature_attack(
                    X[start:stop], seed=self.random_state + m)
            else:
                raise ValueError(
                    f"Attack type error! Unknown attack_typ {self.attack_typ!r}; "
                    f"expected 'labelatt' or 'featureatt'")

        return X, y

    def _shuffle_workers(self, Xs, ys):
        """Shuffle the worker order reproducibly and record byz_labels."""
        n_workers = self.n_workers
        M_normal = n_workers - int(n_workers * self.byz_ratio)

        # Generate index list and shuffle
        np.random.seed(self.random_state)
        indices = np.arange(n_workers)
        np.random.shuffle(indices)

        # Byzantine flags (boolean values), in the shuffled worker order
        self.byz_labels = (np.arange(n_workers) >= M_normal)[indices]
        # Rebuild the two lists using shuffled indices
        return [Xs[i] for i in indices], [ys[i] for i in indices]

    def get_data_on_device(self, X, y, device='cpu'):
        """
        Convert the data to tensors on ``device`` and split it per worker.

        Returns:
          (X_tensor, y_tensor, Xs, ys) - the full tensors plus the shuffled
          per-worker lists. Sets ``self.byz_labels``.
        """
        n_workers = self.n_workers
        X_tensor, y_tensor = data_to_tensor(X, y, device=device)
        # NOTE(review): np.split is applied to torch tensors here; this
        # relies on NumPy falling back to the tensors' own methods - confirm
        # on the torch/NumPy versions in use.
        Xs, ys = np.split(X_tensor, n_workers), np.split(y_tensor, n_workers)
        Xs, ys = self._shuffle_workers(Xs, ys)
        return X_tensor, y_tensor, Xs, ys

    def get_data(self, X, y):
        """
        Split the data into ``n_workers`` equal parts (``np.split`` raises if
        that is not possible) and shuffle the worker order.

        Returns:
          (Xs, ys) - shuffled per-worker lists. Sets ``self.byz_labels``.
        """
        n_workers = self.n_workers
        Xs, ys = np.split(X, n_workers), np.split(y, n_workers)
        return self._shuffle_workers(Xs, ys)

    def get_oracle_data(self, X, y):
        """
        Return the data of the normal (non-Byzantine) workers only.

        The first n_workers - int(n_workers * self.byz_ratio) are normal
        """
        n_workers = self.n_workers
        n_samples = int(X.shape[0] / n_workers)
        M_normal = n_workers - int(n_workers * self.byz_ratio)

        X_oracle = X[:M_normal * n_samples]
        y_oracle = y[:M_normal * n_samples]
        return X_oracle, y_oracle


