import numpy as np 
import time
import os
import glob
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from sklearn.metrics import accuracy_score
from sklearn.metrics import log_loss
from matplotlib import pyplot as plt
from collections import defaultdict
"""
Helper functions in the main function
"""
def attack_and_save(save2hd=False, **kwargs):
    """
    Generate Byzantine-attacked client data with ``ByzantiumAttack`` and move it to a device.

    Keyword args (read via ``kwargs.get``):
      n_workers, byz_ratio, attack_typ, random_state: forwarded to ``ByzantiumAttack``.
      X_train, y_train: full training arrays (NumPy).
      data_path, data_name: where / under which prefix to save the per-client files
        (only used when ``save2hd`` is true).
      device: target device for the returned tensors.

    Args:
      save2hd: if true, first delete every regular file in ``data_path`` and then
        write one ``{data_name}_train_{m}.npz`` per client via ``save_data``.

    Returns:
      (Xs_all, ys_all, X_oracle_tensor, y_oracle_tensor, byz_labels):
        per-client feature / label tensors (shuffled client order), the pooled
        data of the honest clients as tensors, and the per-client Byzantine flags.
    """
    n_workers = kwargs.get("n_workers")
    byz_ratio = kwargs.get("byz_ratio")
    attack_typ = kwargs.get("attack_typ")
    random_state = kwargs.get("random_state")
    X_train = kwargs.get("X_train")
    y_train = kwargs.get("y_train")
    data_path = kwargs.get("data_path")
    data_name = kwargs.get("data_name")
    device = kwargs.get("device")

    # Construct the attack model and generate the attacked, shuffled client data.
    byz_model = ByzantiumAttack(n_workers=n_workers, byz_ratio=byz_ratio,
                                attack_typ=attack_typ, random_state=random_state)
    X_all, y_all, buckets = byz_model.generate_attack(X_train, y_train)
    Xs_all, ys_all = byz_model.get_data(X_all, y_all)
    # Bucket indices point into the original training arrays.
    X_oracle, y_oracle = byz_model.get_oracle_data(X_train, y_train, buckets)
    byz_labels = byz_model.byz_labels

    if save2hd:
        # WARNING: removes *all* regular files in data_path, not only previous outputs.
        if os.path.exists(data_path):
            for name in os.listdir(data_path):
                path = os.path.join(data_path, name)
                if os.path.isfile(path):
                    os.remove(path)
        save_data(Xs_all, ys_all, data_path, data_name)
        # The files used to be reloaded onto the device here and the result discarded;
        # that only cost device memory, so the returned tensors come from memory below.

    Xs_all = [torch.from_numpy(x).to(device) for x in Xs_all]
    ys_all = [torch.from_numpy(y).to(device) for y in ys_all]
    X_oracle_tensor = torch.from_numpy(X_oracle).to(device)
    y_oracle_tensor = torch.from_numpy(y_oracle).to(device)
    return Xs_all, ys_all, X_oracle_tensor, y_oracle_tensor, byz_labels

# Show one sample image from a single client's data (e.g. an attacked client)
def plot_sample(Xs_all, ys_all, idx=-3, mean=0.1307, std=0.3081):
    """
    Plot the first image of client ``idx`` with its label as the title.

    Args:
      Xs_all, ys_all: per-client lists of image / label tensors.
      idx: index of the client to show (default: third from last).
      mean, std: normalization to undo before plotting, i.e. ``x * std + mean``.
        The defaults are the constants this helper always used (the commonly
        quoted MNIST mean / std).
    """
    images = Xs_all[idx].detach().cpu().numpy()
    labels = ys_all[idx].detach().cpu().numpy()
    # Undo (x - mean) / std normalization for display.
    image = images[0] * std + mean
    # Assumes each sample is laid out as (C, H, W); imshow wants (H, W, C).
    plt.imshow(image.transpose(1, 2, 0))
    plt.title(labels[0])

"""
Helper functions for the algorithm
"""

def generate_adj_matrix(M, q, seed=None):
    """
    Generate the adjacency matrix of an undirected Erdős–Rényi graph with no isolated nodes.

    Each upper-triangular entry is set independently with probability ``q`` and
    mirrored to the lower triangle.  Any node left without neighbors is then
    connected to one uniformly chosen other node.

    Parameters:
      M    - number of nodes
      q    - edge probability
      seed - if not None, passed to ``np.random.seed`` (reseeds the global RNG)

    Returns:
      A - (M x M) symmetric integer adjacency matrix with 0s on the diagonal
          and no all-zero rows

    Raises:
      ValueError - if M == 1: a single node can never have a neighbor (the
                   previous implementation looped forever in that case).
    """
    if M == 1:
        raise ValueError("generate_adj_matrix: M == 1 cannot avoid an isolated node")
    if seed is not None:
        np.random.seed(seed)
    r = np.random.rand(M, M)
    A_upper = np.triu((r < q).astype(int), k=1)
    A = A_upper + A_upper.T
    for i in range(M):
        if not A[i].any():  # node i is isolated
            # Rejection sampling (rather than a shifted draw) keeps the random
            # stream identical to earlier runs for reproducibility.
            j = np.random.randint(M)
            while j == i:
                j = np.random.randint(M)
            A[i, j] = 1
            A[j, i] = 1
    return A

def generate_network(M, typ='circle', q=0.1, seed=None):
    '''
    Build a row-stochastic communication matrix for M clients.

    Args:
    M -- number of clients
    typ -- network type:
           'circle': client i links to the next d = max(int(M * q), 1) clients
                     (i+1, ..., i+d, modulo M).  When d >= M the window wraps
                     onto i itself, giving a self-loop.
           'er':     Erdős–Rényi graph from ``generate_adj_matrix`` with edge
                     probability q and no isolated nodes.
    q -- degree ratio ('circle') or edge probability ('er')
    seed -- random seed, used only for 'er'

    Returns:
    W -- The weighted network W (a NumPy array) whose rows sum to 1

    Raises:
    ValueError -- for an unknown ``typ``.  Previously this silently returned
                  an all-NaN matrix (0 / 0 row normalization).
    '''
    if typ == 'circle':
        W = np.zeros((M, M), dtype=np.float32)
        d = max(int(M * q), 1)
        for i in range(M):
            W[i, np.arange(i + 1, i + 1 + d) % M] = 1
    elif typ == 'er':
        W = generate_adj_matrix(M, q, seed=seed)
    else:
        raise ValueError(f"generate_network: unknown network type {typ!r} (expected 'circle' or 'er')")
    # Row-normalize so every client averages over its neighbors.
    return W / np.sum(W, axis=1, keepdims=True)

def process_weights(row, cn):
    """
    Turn a tensor into a scalar weight ``exp(-||row|| * cn)``.

    Args:
      row: tensor whose (``torch.norm`` default) norm is taken.
      cn: decay constant multiplying the norm.

    Returns:
      The weight as a Python number (via ``.item()``).
    """
    row_norm = torch.norm(row)
    weight = torch.exp(-row_norm * cn)
    return weight.item()

def get_neighbors(weight_matrix, include_diag=False):
    """
    List, for every row of ``weight_matrix``, the column indices of its nonzero entries.

    Args:
      weight_matrix: 2-D matrix (anything ``np.array`` accepts); one entry per row.
      include_diag: if True, append the row's own index when it is not already
        among its nonzero columns.

    Returns:
      A list whose i-th element is a list of NumPy integer column indices.
    """
    W = np.array(weight_matrix)

    def _row_neighbors(node):
        # Nonzero columns of this row, optionally completed with the node itself.
        nbrs = list(np.nonzero(W[node])[0])
        if include_diag and node not in nbrs:
            nbrs.append(node)
        return nbrs

    return [_row_neighbors(node) for node in range(W.shape[0])]

def get_adj_from_neib(neighbors, M=None):
    """
    Build an integer 0/1 adjacency matrix from per-node neighbor index lists.

    Args:
      neighbors: sequence where ``neighbors[i]`` holds the column indices of row i.
      M: matrix size; any falsy value (None or 0) falls back to ``len(neighbors)``.

    Returns:
      (M x M) int array with 1 at every listed (i, j).
    """
    size = M if M else len(neighbors)
    adj = np.zeros((size, size), dtype=int)
    for row, cols in enumerate(neighbors):
        adj[row, cols] = 1
    return adj

"""
Helper functions for file operations
"""
def download_parameters(param_path='.', m=0, n_iter=0):
    """
    Load the object stored by client ``m`` for iteration ``n_iter``.

    Reads ``{param_path}/tmp_weights_{m}_{n_iter}.pth`` with ``torch.load``.
    Only use this on files produced by this code: ``torch.load`` may unpickle.
    """
    return torch.load(f'{param_path}/tmp_weights_{m}_{n_iter}.pth')

def upload_parameters(param_m, param_path='.', m=0, n_iter=0):
    """
    Save ``param_m`` as ``{param_path}/tmp_weights_{m}_{n_iter}.pth``.

    The object is first written with ``torch.save`` to a uniquely named
    temporary file in the same directory and then moved into place with
    ``os.replace``.  A reader polling for the final filename therefore never
    observes a partially written checkpoint.
    """
    import uuid

    param_m_path = f'{param_path}/tmp_weights_{m}_{n_iter}.pth'
    # Same directory keeps os.replace on one filesystem; the unique suffix avoids
    # collisions between concurrent writers.
    tmp_path = f'{param_m_path}.{uuid.uuid4().hex}.part'
    try:
        torch.save(param_m, tmp_path)
        os.replace(tmp_path, param_m_path)
    except BaseException:
        # Do not leave a stray partial file behind on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def check_param_path(param_path='.', m=0, n_iter=0, timeout=60, min_size=180700):
    """
    Block until client ``m``'s file for iteration ``n_iter`` exists and reaches ``min_size`` bytes.

    Args:
      param_path: directory holding ``tmp_weights_{m}_{n_iter}.pth`` files.
      m: client index.
      n_iter: iteration index.
      timeout: approximate seconds to wait in each phase (file appearing,
        file reaching ``min_size``) before giving up.
      min_size: byte threshold used as a "fully written" heuristic.  180700 is
        the value this helper always used; presumably the size of a complete
        checkpoint for the model in use — confirm for other models.

    Raises:
      TimeoutError: if either phase exceeds ``timeout``.  TimeoutError is a
        subclass of Exception, so existing ``except Exception`` handlers still apply.
    """
    filename = f'{param_path}/tmp_weights_{m}_{n_iter}.pth'
    ct = 0
    while not os.path.exists(filename):
        time.sleep(1)
        ct += 1
        if ct > timeout:
            print(f'filename:{filename}')
            print(f'client:{m}\t iter:{n_iter}')
            raise TimeoutError(f"Exceeded max wait time {ct}s, something went wrong")
    # Bounded, so a checkpoint that is legitimately smaller than min_size cannot
    # hang the caller forever.
    waited = 0.0
    while os.path.getsize(filename) < min_size:
        if waited > timeout:
            raise TimeoutError(
                f"{filename} stayed below {min_size} bytes for {waited:.1f}s, "
                "something went wrong")
        time.sleep(0.5)
        waited += 0.5

def save_data(Xs_all, ys_all, save_path, data_name='mnist'):
    """
    Write one ``.npz`` archive per client.

    Parameters:
      Xs_all, ys_all: lists of per-client feature / label NumPy arrays
      save_path: output directory (created if missing)
      data_name: file prefix

    Files are named ``{data_name}_train_{m}.npz`` (e.g. mnist_train_0.npz,
    mnist_train_1.npz, ...) and hold the arrays under the keys ``images`` and
    ``labels``.
    """
    os.makedirs(save_path, exist_ok=True)
    for idx, pair in enumerate(zip(Xs_all, ys_all)):
        images, labels = pair
        out_file = os.path.join(save_path, f"{data_name}_train_{idx}.npz")
        np.savez(out_file, images=images, labels=labels)
        # Progress on a single console line.
        print(f"\rSaved: {out_file}", end='')

def load_data_to_device(save_path, device, data_name='mnist'):
    """
    Load every ``{data_name}_train_<int>.npz`` in ``save_path`` onto ``device``.

    Files are ordered by their numeric index (so ``_train_10`` follows
    ``_train_2``).  Names that match the glob but not ``_train_<digits>.npz``
    exactly, such as ``mnist_train_oracle_0.npz``, are skipped.  Each
    ``images`` array is flattened to ``(n_samples, -1)``.

    Returns:
      (Xs_list, ys_list): lists of image / label tensors on ``device``; both
      empty when no matching file exists.
    """
    pattern = os.path.join(glob.escape(save_path), f"{glob.escape(data_name)}_train_*.npz")
    index_re = re.compile(rf"{re.escape(data_name)}_train_(\d+)\.npz")
    indexed = []
    for path in glob.glob(pattern):
        match = index_re.fullmatch(os.path.basename(path))
        if match is not None:
            indexed.append((int(match.group(1)), path))
    indexed.sort()
    file_list = [path for _, path in indexed]

    Xs_list, ys_list = [], []
    for path in file_list:
        # Open each archive once and close it right after reading.
        with np.load(path) as data:
            images = data['images']
            labels = data['labels']
        Xs_list.append(torch.from_numpy(images.reshape([images.shape[0], -1])).to(device))
        ys_list.append(torch.from_numpy(labels).to(device))
    print("\rFiles loaded in order:", file_list, end='')
    print(f"\rLoaded {len(file_list)} files and moved data to {device}.")
    return Xs_list, ys_list

def split_and_save_oracle_data(X_all, y_all, M, save_path, data_name='mnist'):
    """
    Split the oracle data into M parts and save each as an ``.npz`` archive.

    1. Split X_all and y_all into M parts using np.array_split (parts may
       differ in size by one row).
    2. Save part m as ``{data_name}_train_oracle_{m}.npz`` in ``save_path``,
       under the keys ``images`` and ``labels``.

    Parameters:
      X_all, y_all: complete data as NumPy arrays
      M: number of splits
      save_path: directory to save files (created if missing)
      data_name: file prefix; defaults to 'mnist', the previously hard-coded prefix

    Saves files named like: mnist_train_oracle_0.npz, mnist_train_oracle_1.npz, ...
    """
    os.makedirs(save_path, exist_ok=True)
    X_splits = np.array_split(X_all, M)
    y_splits = np.array_split(y_all, M)

    for m, (X_split, y_split) in enumerate(zip(X_splits, y_splits)):
        filename = os.path.join(save_path, f"{data_name}_train_oracle_{m}.npz")
        np.savez(filename, images=X_split, labels=y_split)
        print(f"Saved: {filename}")

"""
Helper functions for data processing
"""

def shuffle_data(label, X=None, seed=None):
    """
    Apply one random permutation to ``label`` (and to ``X`` if given).

    ``np.random.seed(seed)`` is always called, so the global NumPy RNG is
    reseeded (from OS entropy when ``seed`` is None).

    Returns:
      ``label[perm]`` if ``X`` is None, otherwise ``(X[perm], label[perm])``.
    """
    np.random.seed(seed)
    perm = np.random.permutation(len(label))
    if X is None:
        return label[perm]
    return X[perm], label[perm]

def fixed_50_buckets_five_labels(labels, X=None, seed=None, num_buckets=50,
                                 labels_per_bucket=5, ratio_per_label=0.04):
    """
    Split a labelled dataset into label-skewed buckets.

    Bucket ``i`` is assigned the label window
    ``unique_labels[(i + j) % num_labels]`` for ``j < labels_per_bucket``.
    The sample indices of each label are shuffled and then handed out,
    sweeping over the buckets repeatedly, in chunks of
    ``int(count * ratio_per_label)`` (at least 1) until every sample of every
    visited label is assigned.  With the defaults and 10 classes, each class
    appears in 25 buckets.

    Args:
      labels: label array (assumed 1-D).
      X: optional features aligned with ``labels`` along axis 0.
      seed: passed to ``np.random.seed`` (reseeds the global NumPy RNG).
      num_buckets, labels_per_bucket, ratio_per_label: partition parameters;
        the defaults reproduce the original fixed 50-bucket / 5-label layout.

    Returns:
      ``(X_buckets, y_buckets, buckets)`` if ``X`` is given, else
      ``(y_buckets, buckets)``; ``buckets[i]`` is the shuffled list of indices
      into ``labels`` assigned to bucket ``i``.
    """
    np.random.seed(seed)

    unique_labels = np.unique(labels)
    num_labels = len(unique_labels)

    # Step 1: Collect and shuffle the indices corresponding to each label
    label_indices = {label: np.where(labels == label)[0].tolist() for label in unique_labels}
    for label in label_indices:
        np.random.shuffle(label_indices[label])

    label_ptrs = {label: 0 for label in unique_labels}
    label_total = {label: len(label_indices[label]) for label in unique_labels}

    # Step 2: Assign a sliding window of labels to each bucket (none if there are no labels)
    bucket_label_sets = [
        [unique_labels[(i + j) % num_labels] for j in range(labels_per_bucket)] if num_labels else []
        for i in range(num_buckets)
    ]

    # Step 3: Allocate samples to each bucket
    buckets = [[] for _ in range(num_buckets)]
    exhausted = False
    while not exhausted:
        exhausted = True
        for i in range(num_buckets):
            for label in bucket_label_sets[i]:
                remaining = label_total[label] - label_ptrs[label]
                if remaining > 0:
                    exhausted = False
                    # At least one sample per visit: for a class with fewer than
                    # 1 / ratio_per_label samples int(...) is 0, and the loop
                    # would otherwise never make progress.
                    take_n = min(max(1, int(label_total[label] * ratio_per_label)), remaining)
                    start = label_ptrs[label]
                    buckets[i].extend(label_indices[label][start:start + take_n])
                    label_ptrs[label] += take_n

    # Step 4: Shuffle within each bucket
    for bucket in buckets:
        np.random.shuffle(bucket)

    # Step 5: Return the results
    y_shuffled = [labels[buckets[i]] for i in range(num_buckets)]
    if X is not None:
        X_shuffled = [X[buckets[i]] for i in range(num_buckets)]
        return X_shuffled, y_shuffled, buckets
    return y_shuffled, buckets

def generate_label_attack(ym, seed=None, num_classes=None):
    """
    Label-flipping attack: shift every label by one class, ``(ym + 1) % K``.

    Args:
      ym: integer label array of one client.
      seed: passed to ``np.random.seed``.  No randomness is used here; the call
        is kept so the global RNG state after this function is unchanged from
        earlier behaviour.
      num_classes: K.  If None, K is inferred as ``max(ym) + 1`` from this
        client's own labels, so a client that lacks the highest classes wraps
        at a smaller K.  Pass the dataset's class count for a consistent shift.

    Returns:
      New array of shifted labels (``ym`` itself is not modified).
    """
    np.random.seed(seed)
    if num_classes is None:
        if ym.size == 0:
            # np.max would raise on an empty array; nothing to shift.
            return ym.copy()
        num_classes = np.max(ym) + 1
    return (ym + 1) % num_classes

def generate_feature_attack(Xm, seed=None, signal_scale=0.3, noise_scale=3):
    '''
    Feature attack: damp the features and add Gaussian noise.

    Returns ``signal_scale * Xm + noise_scale * N(0, 1)`` noise of Xm's shape
    (defaults: ``0.3 * Xm + 3 * noise``).

    Args:
      Xm: feature array of one client.
      seed: if not None, passed to ``np.random.seed`` (global NumPy RNG).
      signal_scale, noise_scale: attack strength; defaults are the previously
        hard-coded constants.

    For floating-point ``Xm`` the noise is cast to ``Xm.dtype``, so the result
    keeps that dtype.  For integer ``Xm`` the noise stays float64, because
    casting it to an integer dtype would truncate or wrap it; the result is
    then float64.
    '''
    if seed is not None:
        np.random.seed(seed)
    noise = np.random.randn(*Xm.shape)
    if np.issubdtype(Xm.dtype, np.floating):
        noise = noise.astype(Xm.dtype)  # keep dtype consistent with Xm
    return signal_scale * Xm + noise_scale * noise

# def generate_feature_attack(Xm, seed=None):
#     """
#     Multiply half of the matrix (in the height dimension H) of each channel by -2.
#     Assumes "half matrix" refers to the upper half in height.
#     
#     Args:
#     Xm: Tensor or array of shape (N, C, H, W), e.g., a NumPy array
#     seed: Random seed for reproducibility, set to None if not needed
#     
#     Returns:
#     Modified Xm
#     """
#     if seed is not None:
#         np.random.seed(seed)
#     N, C, H, W = Xm.shape
#     half_H = H // 2  # Define half of the height
#     Xm_byz = Xm.copy()
#     Xm_byz[:, :, :half_H, :] = -0.2 * Xm_byz[:, :, :half_H, :] + np.random.randn(*Xm_byz[:, :, :half_H, :].shape)
#     # To use the bottom half instead, replace with Xm[:, :, half_H:, :] *= -2
#     return Xm_byz

"""
Validation and evaluation
"""
def evaluate(X_test, y_test, model, device):
    """
    Print and return the test accuracy and log loss of ``model``.

    Args:
      X_test: test features, passed unchanged to ``model.predict`` / ``model.predict_proba``.
      y_test: true labels (NumPy array or torch tensor on any device).
      model: object exposing ``predict`` and ``predict_proba``.
      device: kept for backward compatibility.  Torch tensors are moved to the
        CPU whatever their device (previously only ``device == 'cuda'`` was
        handled, so e.g. 'cuda:0' reached sklearn as a GPU tensor).

    Returns:
      (accuracy, log_loss)
    """
    def _to_numpy(v):
        return v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v

    y_test = _to_numpy(y_test)
    y_pred = _to_numpy(model.predict(X_test))
    acc = accuracy_score(y_test, y_pred)
    print("Test accuracy:", acc)
    y_proba = _to_numpy(model.predict_proba(X_test))
    val = log_loss(y_test, y_proba)
    print("Test loss:", val)
    return acc, val

"""
Batch-wise evaluation using a DataLoader
"""

def evaluate_batch(model, dataloader, device):
    """
    Evaluate the model using a DataLoader.

    Args:
      model: object exposing ``pred_batch(dataloader)`` and ``predproba_batch(dataloader)``
      dataloader: test DataLoader yielding (X, y) batches
      device: kept for backward compatibility; not used here

    Functionality:
      1. Collect all true labels by iterating ``dataloader`` once;
      2. Get predictions from model.pred_batch and model.predproba_batch
         (presumably each iterates ``dataloader`` again — confirm);
      3. Compute, print, and return (accuracy, log_loss).

    The separate passes only line up if the loader yields batches in a fixed
    order.  A RandomSampler (``shuffle=True``) triggers a RuntimeWarning,
    because the metrics would otherwise be silently wrong.
    """
    import warnings

    def _to_numpy(v):
        return v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v

    if isinstance(getattr(dataloader, 'sampler', None), torch.utils.data.RandomSampler):
        warnings.warn(
            "evaluate_batch: dataloader uses a RandomSampler (shuffle=True); "
            "labels and predictions may be misaligned. Use shuffle=False.",
            RuntimeWarning,
        )

    all_y = torch.cat([torch.as_tensor(y) for _, y in dataloader], dim=0)
    y_test = all_y.cpu().numpy()

    y_pred = _to_numpy(model.pred_batch(dataloader))
    y_proba = _to_numpy(model.predproba_batch(dataloader))

    acc = accuracy_score(y_test, y_pred)
    val = log_loss(y_test, y_proba)
    print("Test accuracy:", acc)
    print("Test loss:", val)
    return acc, val

def evaluate_list(X_test, y_test, model, device='cpu'):
    """
    Compute accuracy and log loss for each of several models' predictions.

    ``model.predict`` / ``model.predict_proba`` are expected to return one
    prediction set per client / model (indexed by m).

    Args:
      X_test: test features passed to the model.
      y_test: true labels (NumPy array or torch tensor on any device).
      model: object exposing list-valued ``predict`` and ``predict_proba``.
      device: kept for backward compatibility.  Torch tensors are converted
        whatever their device (previously only ``device == 'cuda'`` was).

    Returns:
      (accs, val_losss): per-entry accuracy and log-loss lists.
    """
    def _to_numpy(v):
        return v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v

    y_test = _to_numpy(y_test)
    y_preds = model.predict(X_test)
    y_probas = model.predict_proba(X_test)
    accs = []
    val_losss = []
    for y_pred, y_proba in zip(y_preds, y_probas):
        accs.append(accuracy_score(y_test, _to_numpy(y_pred)))
        val_losss.append(log_loss(y_test, _to_numpy(y_proba)))
    return accs, val_losss

"""
PyTorch-related helpers
"""
def data_to_tensor(X, Y, device='cpu'):
    """
    Copy features / targets into new PyTorch tensors on ``device``.

    X, Y may be NumPy arrays or (nested) Python lists.  X becomes float32 and
    Y becomes int64 (``torch.long``).  ``device`` is anything torch accepts
    ('cpu', 'cuda', ...).  The tensors are created on the device directly.
    """
    features = torch.tensor(X, dtype=torch.float32, device=device)
    targets = torch.tensor(Y, dtype=torch.long, device=device)
    return features, targets


class NpzDataset(Dataset):
    """
    In-memory ``Dataset`` over an ``.npz`` archive with ``images`` and ``labels`` arrays.

    Both arrays are read into memory at construction.  Items are returned as
    ``(float32 image tensor, int64 label tensor)``.
    """

    def __init__(self, npz_file_path):
        # Read both arrays eagerly, then close the archive so no file handle
        # stays open for the dataset's lifetime.
        with np.load(npz_file_path) as data:
            # [N, ...]; the per-sample layout is whatever was saved.
            # TODO confirm (other helpers in this file treat samples as C x H x W).
            self.images = data['images']
            self.labels = data['labels']  # presumably [N], aligned with images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]

        # Convert to PyTorch Tensor
        image = torch.tensor(image, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.long)

        return image, label


"""
ByzantiumAttack class
"""
class ByzantiumAttack:
    """
    Simulate Byzantine (attacked) clients on a label-skewed data partition.

    Args:
      n_workers: number of clients.
      byz_ratio: fraction of Byzantine clients.  The count is
        ``floor(n_workers * byz_ratio)``, with a small tolerance so float
        round-off (e.g. ``100 * 0.29 == 28.999999999999996``) does not drop a client.
      attack_typ: 'labelatt' or 'featureatt'.
      random_state: base seed for the global NumPy RNG.
    """

    def __init__(self, n_workers=10, byz_ratio=0, attack_typ='labelatt',
                 random_state=2024):
        self.n_workers = n_workers
        self.byz_ratio = byz_ratio
        self.attack_typ = attack_typ
        self.random_state = random_state
        # Per-client boolean Byzantine flags (shuffled order); set by get_data().
        self.byz_labels = None

    def _num_byzantine(self):
        # Floor of n_workers * byz_ratio, tolerant to binary float round-off.
        return int(self.n_workers * self.byz_ratio + 1e-9)

    def generate_attack(self, X_true, y_true):
        """
        Partition the data into buckets and corrupt the Byzantine clients.

        Buckets come from ``fixed_50_buckets_five_labels`` with its default
        50 buckets, independent of ``n_workers`` (specifically for MNIST
        experiments).  Buckets ``M_normal .. n_workers-1`` are attacked:

          labelatt:   y -> (y + 1) % (max(y) + 1), a deterministic label shift
                      (see generate_label_attack)
          featureatt: X -> 0.3 * X + 3 * N(0, 1) noise
                      (see generate_feature_attack)

        Any other ``attack_typ`` prints 'att type error!' once per Byzantine
        client and leaves that client's data unchanged.

        Returns:
          (X, y, buckets): per-bucket feature / label copies, and the bucket
          index lists into the original ``X_true`` / ``y_true``.
        """
        M_byz = self._num_byzantine()
        M_normal = self.n_workers - M_byz

        X_true, y_true, buckets = fixed_50_buckets_five_labels(
            y_true, X=X_true, seed=self.random_state)  # Specifically for MNIST experiments

        # Copies, so the attacks below never alias the bucketed inputs.
        X = [x * 1.0 for x in X_true]
        y = [labels * 1 for labels in y_true]

        for m in range(M_byz):
            k = m + M_normal
            if self.attack_typ == 'labelatt':
                y[k] = generate_label_attack(y[k], seed=self.random_state + m)
            elif self.attack_typ == 'featureatt':
                X[k] = generate_feature_attack(X[k], seed=self.random_state + m)
            else:
                print('att type error!')

        return X, y, buckets

    def get_data(self, X, y):
        """
        Shuffle the first ``n_workers`` clients and record which are Byzantine.

        The global NumPy RNG is reseeded with ``random_state`` before shuffling,
        so the client order is reproducible.  Sets ``self.byz_labels`` to a
        boolean array aligned with the returned order.

        Returns:
          (Xs, ys): per-client data lists in shuffled order.
        """
        n_workers = self.n_workers
        M_normal = n_workers - self._num_byzantine()

        np.random.seed(self.random_state)
        order = np.arange(n_workers)
        np.random.shuffle(order)
        Xs = [X[i] for i in order]
        ys = [y[i] for i in order]

        # Clients at pre-shuffle position >= M_normal are the Byzantine ones.
        self.byz_labels = (np.arange(n_workers) >= M_normal)[order]
        return Xs, ys

    def get_oracle_data(self, X, y, buckets):
        """
        Pool the data of the honest clients, i.e. the first
        ``n_workers - floor(n_workers * byz_ratio)`` buckets.

        Args:
          X, y: original (un-bucketed) arrays that ``buckets`` index into.
          buckets: per-bucket index lists.

        Returns:
          (X_oracle, y_oracle), concatenated in bucket order.  Both are empty
          when there are no honest clients.
        """
        M_normal = self.n_workers - self._num_byzantine()
        # Force an integer dtype: an empty Python list would otherwise turn the
        # concatenation into a float array, which cannot be used as an index.
        honest = [np.asarray(b, dtype=np.intp) for b in buckets[:M_normal]]
        all_indices = np.concatenate(honest) if honest else np.empty(0, dtype=np.intp)
        X_oracle = X[all_indices]
        y_oracle = y[all_indices]
        return X_oracle, y_oracle

