"""
This script defines a Flower client for federated learning experiments reflecting SoteriaFL implementation.
Each client locally trains a model using the received global parameters, adds differential privacy, and
performs a shifted compression of the gradients before sending them to the server for aggregation. The client
also evaluates the model's performance and privacy leakage (MIA).
It can be deployed locally or across distributed systems by adjusting `server_address`.

Clients are assigned dataset shards, train models locally, and optionally evaluate privacy leakage
using canary samples. After training, results and plots are saved to track performance and leakage.
"""

# Libraries
from collections import OrderedDict
from torch.utils.data import DataLoader
import torch
import flwr as fl
import argparse
from torch.utils.data import Subset, random_split
import torch.nn.functional as F
import numpy as np
import copy
import opacus # type: ignore
import gc

import sys
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.append(parent_dir)
from public import models
from public import utils
from public import config as cfg



# Define Flower client 
class FlowerClient(fl.client.NumPyClient):
    def __init__(self, 
                 model: torch.nn.Module,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 optimizer: torch.optim.Optimizer,
                 criterion: torch.nn.Module,
                 num_examples: dict, 
                 client_id: int,
                 train_fn: callable,
                 evaluate_fn: callable,
                 device: torch.device,
                 privacy_audit: bool = True,
                 canary_frac: float = 0.2, 
                 p_value: float = 0.05,
                 k_plus: float = 1 / 3, 
                 k_min: float = 1 / 3,
                 config: dict = {'dataset':'mnist', 'batch':64},
                 exp_n: int = 0,
                 scaling_dp: int = 0
                ):
        
        # Define the client
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.num_examples = num_examples
        self.client_id = client_id
        self.train_fn = train_fn
        self.evaluate_fn = evaluate_fn
        self.device = device
        self.canary_frac = canary_frac
        self.p_value = p_value
        self.k_plus = k_plus
        self.k_min = k_min
        self.privacy_audit = privacy_audit
        self.privacy_estimate = -1
        self.accuracy_mia = -1
        self.acc_privacy_estimate = -1
        self.acc_accuracy_mia = -1
        self.config = config
        self.acc_scores = None
        self.scaling_dp = scaling_dp
        
        # initialize reference vector
        self.reference_s = []
        for p in self.get_parameters():
            self.reference_s.append(np.zeros_like(p))
        
        # initialize k for sparsity
        n_params = sum(p.numel() for p in self.model.parameters())
        # self.k = int(n_params / np.log2(self.config['rounds'][exp_n]))  # as in SoteriaFL
        self.k = int(n_params / 20) # as in SoteriaFL
        w = (n_params / self.k) - 1
        self.gamma = np.sqrt((1 + 2 * w) / (2 * (1 + w)**3))

        # prepare dataset auditing
        canaries, non_canaries = random_split(self.train_loader.dataset, [self.canary_frac, 1 - self.canary_frac])
        self.n_canaries = len(canaries)
        self.scores = np.zeros(self.n_canaries)

        # subsample canaries & make new dataloader
        true_in_out = torch.distributions.bernoulli.Bernoulli(torch.ones(self.n_canaries) * 0.5).sample()
        self.true_in_out = true_in_out.numpy()
        canaries_in_idx = torch.nonzero(true_in_out)
        subsampled_train_data = torch.utils.data.ConcatDataset([
            non_canaries,
            torch.utils.data.Subset(canaries, canaries_in_idx)
        ])
        self.subsampled_train_loader = DataLoader(subsampled_train_data, batch_size=len(subsampled_train_data), shuffle=True)
        self.canary_loader = DataLoader(canaries, batch_size=len(subsampled_train_data), shuffle=False)

        # local differential privacy initialization
        if True:
            # Calculate sample rate = (batch_size / total_number_of_samples)
            if cfg.privacy_audit:
                sample_rate = min(1.0, self.subsampled_train_loader.batch_size / len(self.subsampled_train_loader.dataset))
            else:
                sample_rate = min(1.0, self.train_loader.batch_size / len(self.train_loader.dataset))

            self.sigma = opacus.accountants.utils.get_noise_multiplier(
                target_epsilon=cfg.epsilon,
                target_delta=cfg.delta,
                sample_rate=sample_rate,
                epochs=int(self.config['epochs']), 
                accountant='rdp',  
            ) 
            # self.sigma = self.config['sigma'][self.scaling_dp]

            self.privacy_engine = opacus.privacy_engine.PrivacyEngine(accountant='rdp', secure_mode=False)
            if cfg.privacy_audit:
                self.model, self.optimizer, self.subsampled_train_loader = self.privacy_engine.make_private(
                    module=self.model,
                    optimizer=self.optimizer,
                    data_loader=self.subsampled_train_loader,
                    noise_multiplier=self.sigma,
                    # max_grad_norm=self.config['sensitivity'][self.scaling_dp],
                    max_grad_norm=cfg.sensitivity
                    )
            else:
                self.model, self.optimizer, self.train_loader = self.privacy_engine.make_private(
                    module=self.model,
                    optimizer=self.optimizer,
                    data_loader=self.train_loader,
                    noise_multiplier=self.sigma,
                    # max_grad_norm=self.config['sensitivity'][self.scaling_dp],
                    max_grad_norm=cfg.sensitivity
                    )           
            
            if client_id == 1:
                print(f"\n\033[94mLocal Differential Privacy with introduced noise_value_sd: {self.sigma}\033[0m\n")


    def get_parameters(self, config=None):
        return [val.cpu().numpy() for _, val in self.model.state_dict().items()]


    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        self.model.load_state_dict(state_dict, strict=True)


    def fit(self, params_in, config):
        """Run one local DP training round and return the sparsified model.

        Steps:
            1. Load the received parameters `params_in` into the local model.
            2. Train with `models.train_with_opacus` on the canary-subsampled loader
               (auditing enabled) or on the full training loader (auditing disabled).
            3. Compute the local update (trained - received) and sparsify it with
               `compress_parameters`. With auditing enabled, the reference vector
               `self.reference_s` is subtracted first and updated afterwards.
            4. Set the local model to `params_in + sparsified update`.
            5. With auditing enabled, score all canaries, estimate the privacy
               leakage for this round and for the accumulated scores, and save them.

        Args:
            params_in: Global parameters as a list of NumPy arrays in state-dict order.
            config: Round configuration; `config["current_round"]` is read.

        Returns:
            tuple: `(params_out, self.num_examples["train"], {})`.

        Raises:
            NotImplementedError: If auditing is enabled and `cfg.score_fn` is neither
                'whitebox' nor 'blackbox'.
        """
        # Reject an unknown score function before spending a training round on it.
        if self.privacy_audit and cfg.score_fn not in ('whitebox', 'blackbox'):
            raise NotImplementedError(f'score function {cfg.score_fn} is not known')

        self.set_parameters(params_in)

        # Local Differential Privacy (DP) training. With auditing enabled, the loader
        # that leaves out part of the canaries is used.
        loader = self.subsampled_train_loader if self.privacy_audit else self.train_loader
        models.train_with_opacus(
            self.model,
            self.device,
            loader,
            self.optimizer,
            self.criterion,
            self.sigma,
            self.config["epochs"],
            self.client_id,
            cfg.sensitivity
            )

        # Local update: difference between trained and received parameters.
        params_trained = self.get_parameters(config)
        grads = [trained - received for received, trained in zip(params_in, params_trained)]

        if self.privacy_audit:
            # Shifted random-k sparsification: compress the update relative to the
            # reference vector, then move the reference towards the compressed value.
            # NOTE(review): the reference vector is not added back to the update that is
            # sent to the server — confirm this matches the intended aggregation scheme.
            shifted_grads = [grad - s for grad, s in zip(grads, self.reference_s)]
            sparse_grads = self.compress_parameters(shifted_grads, k=self.k)
            self.reference_s = [s + self.gamma * sparse for s, sparse in zip(self.reference_s, sparse_grads)]
        else:
            # Plain random-k sparsification of the update.
            sparse_grads = self.compress_parameters(grads, k=self.k)

        # Parameters returned to the server, also loaded into the local model.
        params_out = [received + grad for received, grad in zip(params_in, sparse_grads)]
        self.set_parameters(params_out)

        if self.privacy_audit:
            # Unit-norm client update; a zero update is left as is to avoid dividing by zero.
            client_update = utils.parameters_to_1d(params_out) - utils.parameters_to_1d(params_in)
            update_norm = np.linalg.norm(client_update)
            if update_norm > 0:
                client_update = client_update / update_norm

            # One score per canary, in canary_loader order; used to predict membership.
            scores = []
            if cfg.score_fn == 'whitebox':
                # White-box scores are computed with the received (pre-training) weights.
                self.set_parameters(params_in)
                for samples, targets in self.canary_loader:
                    scores.extend(self.score_with_pseudograd_batch(samples, targets, client_update))
                self.set_parameters(params_out)
            else:  # 'blackbox' (validated above)
                for samples, targets in self.canary_loader:
                    scores.extend(self.score_blackbox_batch(samples, targets, client_update))

            # Uniform float array, whether the scorer returned NumPy scalars or 0-d tensors.
            scores = np.asarray([float(score) for score in scores])

            # Accumulated scores over all rounds seen so far.
            if self.acc_scores is None:
                self.acc_scores = scores.copy()
            else:
                self.acc_scores = self.acc_scores + scores

            # Lower-bound privacy budget evaluation for this round and for the accumulated scores.
            self.accuracy_mia, self.privacy_estimate = self.evaluate_privacy(scores)
            self.acc_accuracy_mia, self.acc_privacy_estimate = self.evaluate_privacy(self.acc_scores)

            # The four values are passed twice, exactly as before.
            utils.save_audit_metrics(
                config["current_round"],
                self.accuracy_mia,
                self.privacy_estimate,
                self.acc_accuracy_mia,
                self.acc_privacy_estimate,
                self.accuracy_mia,
                self.privacy_estimate,
                self.acc_accuracy_mia,
                self.acc_privacy_estimate,
                client_id=self.client_id,
                history_folder=f"histories/{self.config['model_name']}/{self.config['dataset']}/"
                )

        gc.collect()
        torch.cuda.empty_cache()
        if self.client_id == 1:
            # Single quotes inside the f-string keep this line valid before Python 3.12.
            print(f"\033[91mTraining Round: {config['current_round']}\033[0m")
        return params_out, self.num_examples["train"], {}
    
    
    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy, f1_score = self.evaluate_fn(self.model, self.device, self.val_loader, self.criterion, self.client_id)
        
        # save loss and accuracy client
        utils.save_client_metrics(config["current_round"], loss, accuracy, f1_score, client_id=self.client_id,
                                    history_folder=f"histories/{self.config['model_name']}/{self.config['dataset']}/")
        
        return float(loss), self.num_examples["val"], {
            "accuracy": float(accuracy),
            "f1_score": float(f1_score),
            "privacy_estimate": self.privacy_estimate,
            "accuracy_mia": self.accuracy_mia,
            "accumulative_privacy_estimate": self.acc_privacy_estimate,
            "accumulative_accuracy_mia": self.acc_accuracy_mia,
            "privacy_estimate_mean": self.privacy_estimate,
            "accuracy_mia_mean": self.accuracy_mia,
            "accumulative_privacy_estimate_mean": self.acc_privacy_estimate,
            "accumulative_accuracy_mia_mean": self.acc_accuracy_mia
        }


    def evaluate_privacy(self, scores):
        ground_truth = copy.deepcopy(self.true_in_out)
        score_indices_sorted = np.argsort(scores)[::-1]
        classified_in = score_indices_sorted[:int(self.n_canaries * self.k_plus + 1)]
        classified_out = score_indices_sorted[int(self.n_canaries * (1 - self.k_min) + 1):]
        abstained = np.setdiff1d(score_indices_sorted, np.concatenate((classified_in, classified_out)))
        classification = np.zeros(self.n_canaries)
        classification[classified_in] = 1
        classification[abstained] = 2
        ground_truth[abstained] = 2
        W = ground_truth == classification
        num_correct = W.sum() - len(abstained)
        accuracy_mia = num_correct / (self.n_canaries - len(abstained))
        
        # tpr = np.sum(classification == true_in_out) / len(canaries_in_idx)
        # tnr = np.sum((1 - classification) == (1 - true_in_out)) / len(canaries_out_idx)
        # fpr = np.sum(classification == (1 - true_in_out)) / len(canaries_out_idx)
        # fnr = np.sum((1 - classification) == true_in_out) / len(canaries_in_idx)

        # compute empirical privacy estimate, which should be < epsilon w/ high probability
        privacy_estimate = utils.get_eps_audit(
            m=self.n_canaries,
            r=self.n_canaries - len(abstained),
            v=num_correct,
            delta=cfg.delta,
            p=0.05)
        
        # Kairouz privacy estimate from https://proceedings.mlr.press/v37/kairouz15.html
        # privacy_estimate = np.max([np.log(1 - cfg.delta - fpr) - np.log(fnr), 
                            # np.log(1 - cfg.delta - fnr) - np.log(fpr)])
                            
        return accuracy_mia, privacy_estimate
    
    
    def score_with_pseudograd_batch(self, samples, targets, client_update):
        '''
        Computes membership inference attack scores for a batch by 
        computing the inner product between the 'pseudogradient'
        represented by client update and the true gradients
        for each sample in the batch.
        '''
        self.model.to(self.device)  # Ensure model is on the correct device
        samples = samples.to(self.device)
        targets = targets.to(self.device)
        
        # Forward pass
        predictions = self.model(samples)
        losses = torch.nn.functional.cross_entropy(predictions, targets, reduction='none')
        
        scores = []
        for loss in losses:
            # Compute gradients for each sample
            audit_grad = torch.autograd.grad(loss, self.model.parameters(), retain_graph=True)
            # audit_grad = parameters_to_1d(audit_grad)
            audit_grad = np.concatenate([x.cpu().flatten() for x in audit_grad])
            score = np.dot(client_update, - audit_grad)
            scores.append(score)
        
        return scores


    def score_blackbox_batch(self, samples, targets, client_update):
        with torch.no_grad():
            self.model.to(self.device)  # Ensure model is on the correct device
            samples = samples.to(self.device)
            targets = targets.to(self.device)
            
            # Forward pass
            predictions = self.model(samples)
            losses = torch.nn.functional.cross_entropy(predictions, targets, reduction='none').cpu()

            return -losses


    def compress_parameters(self, params, k):
        """
        Compresses the model parameters using random-k sparsification.

        Args:
            params (List[np.ndarray]): List of NumPy arrays representing model parameters.
            k (int): Number of coordinates to retain during compression.

        Returns:
            List[np.ndarray]: Compressed parameters where only k coordinates are retained
                            (and scaled by d/k according to the random-k operator).
        """
        # Flatten all parameters (excluding scalars) into a single array
        flattened_params = np.concatenate([p.flatten() for p in params if p.ndim > 0])
        d = flattened_params.size

        # If k >= d, no compression happens (just return original parameters)
        if k >= d:
            print(f"No compression applied: k ({k}) >= d ({d}).")
            return params
        assert k > 0, "k must be a positive integer."

        # Randomly select k indices out of d
        indices = np.random.choice(d, k, replace=False)

        # Create a boolean mask of size d with exactly k entries as True
        mask = np.zeros(d, dtype=bool)
        mask[indices] = True

        # Apply the mask and scale by d/k
        scaling_factor = d / k
        compressed_flattened = scaling_factor * flattened_params * mask

        # Reconstruct the parameter shapes
        sparsified_params = []
        start_idx = 0
        for p in params:
            if p.ndim == 0:
                # If it's a scalar, just keep it uncompressed
                sparsified_params.append(p)
            else:
                flat_len = p.size
                sliced = compressed_flattened[start_idx:start_idx + flat_len]
                sparsified_params.append(sliced.reshape(p.shape))
                start_idx += flat_len

        return sparsified_params


    def prune_params_basedon_grads(self, grads, params, pruning_rate=0.3):
        """
        Prunes 'pruning_rate' fraction (e.g., 0.3) of the weights with the largest gradients.
        Args:
            grads (List[np.ndarray]): List of NumPy arrays representing gradients.
            params (List[np.ndarray]): List of NumPy arrays representing model parameters.
            pruning_rate (float): Fraction of weights to prune (e.g., 0.3 for 30%).

        Returns:
            List[np.ndarray]: Pruned parameters with the largest weights set to zero.
        """
        
        assert len(grads) == len(params), "Number of gradients and parameters must match."
        assert pruning_rate > 0 and pruning_rate < 1, "Pruning rate must be in (0, 1)."
        
        # Flatten all parameters (excluding scalars) into a single array
        flattened_params = np.concatenate([p.flatten() for p in params if p.ndim > 0])
        flattened_grads = np.concatenate([np.abs(g.flatten()) for g in grads if g.ndim > 0])

        # Determine the threshold for pruning
        threshold = np.percentile(flattened_grads, 100 * (1 - pruning_rate))

        # Create a mask for weights below the threshold
        mask = flattened_grads <= threshold  # Keep small or midrange gradients
        # print(f"Pruned parameters: {np.sum(mask)} out of {len(flattened_params)}, pruning rate {pruning_rate}.")

        # Apply the mask to set pruned weights to zero
        pruned_params = flattened_params * mask

        # Reconstruct the parameter shapes
        pruned_params_list = []
        start_idx = 0
        for p in params:
            if p.ndim == 0:
                # If it's a scalar, just keep it uncompressed
                pruned_params_list.append(p)
            else:
                flat_len = p.size
                sliced = pruned_params[start_idx:start_idx + flat_len]
                pruned_params_list.append(sliced.reshape(p.shape))
                start_idx += flat_len

        return pruned_params_list


    def prune_largest_grads(self, grads, pruning_rate=0.3):
        """
        Prunes 'pruning_rate' fraction (e.g., 0.3) of the largest gradients.
        Args:
            grads (List[np.ndarray]): List of NumPy arrays representing gradients.
            pruning_rate (float): Fraction of weights to prune (e.g., 0.3 for 30%).

        Returns:
            List[np.ndarray]: Pruned parameters with the largest weights set to zero.
        """
        
        assert pruning_rate > 0 and pruning_rate < 1, "Pruning rate must be in (0, 1)."
        
        # Flatten all parameters (excluding scalars) into a single array
        flattened_grads = np.concatenate([g.flatten() for g in grads if g.ndim > 0])
        flattened_abs_grads = np.concatenate([np.abs(g.flatten()) for g in grads if g.ndim > 0])

        # Determine the threshold for pruning
        threshold = np.percentile(flattened_abs_grads, 100 * (1 - pruning_rate))

        # Create a mask for weights below the threshold
        mask = flattened_abs_grads <= threshold  # Keep small or midrange gradients
        # print(f"Pruned parameters: {np.sum(mask)} out of {len(flattened_params)}, pruning rate {pruning_rate}.")

        # Apply the mask to set pruned weights to zero
        pruned_grads = flattened_grads * mask

        # Reconstruct the parameter shapes
        pruned_grads_list = []
        start_idx = 0
        for g in grads:
            if g.ndim == 0:
                # If it's a scalar, just keep it uncompressed
                pruned_grads_list.append(g)
            else:
                flat_len = g.size
                sliced = pruned_grads[start_idx:start_idx + flat_len]
                pruned_grads_list.append(sliced.reshape(g.shape))
                start_idx += flat_len

        return pruned_grads_list


def parse_args():
    """Parse the client's command-line options.

    Options:
        --id (required): data partition, an integer from 1 to 100.
        --dataset: one of mnist, cifar10, imdb, fmnist (default mnist).
        --exp_n: experiment number (default 0).
        --scaling_dp: scaling factor for differential privacy (default 0).
    """
    cli = argparse.ArgumentParser(description="Flower")

    cli.add_argument(
        "--id",
        required=True,
        type=int,
        choices=range(1, 101),
        help="Specifies the artificial data partition",
    )
    cli.add_argument(
        "--dataset",
        default="mnist",
        type=str,
        choices=["mnist", "cifar10", "imdb", "fmnist"],
        help="Dataset name",
    )
    cli.add_argument("--exp_n", default=0, type=int, help="exp number")
    cli.add_argument(
        "--scaling_dp",
        default=0,
        type=int,
        help="scaling factor for differential privacy",
    )

    return cli.parse_args()
    









# main
def main()->None:
    """Build one federated client from the command-line arguments and run it.

    Loads the client's dataset file, draws a seeded train/validation subset,
    creates the model, optimizer and loss, connects a `FlowerClient` to the
    server and plots the saved client metrics once the client has finished.
    """
    # Arguments
    args = parse_args()

    # check gpu and set manual seed
    device = utils.check_gpu(seed=cfg.seed, client_id=args.id)
    utils.set_seed(cfg.seed)
    config = cfg.experiments[args.dataset]

    # model and history folder
    model_factory = models.model_dict[config["dataset"]]
    model = model_factory(config["model_args"]).to(device)
    if args.id == 1:
        utils.print_num_parameters(model)

    # Load data (path is relative to the current working directory)
    data = torch.load(f'../data/client_datasets/IID_data_client_{args.id}.pt', weights_only=False)

    # Split the dataset: validation size is 30% of the training size
    train_size = config['client_train_samples'][args.exp_n]
    val_size = int(train_size * 0.3)
    total_requested = train_size + val_size
    if total_requested > len(data):
        raise ValueError(
            f"Requested train+val samples ({total_requested}) exceed dataset size ({len(data)})!"
        )
    # Seed again right before sampling so that the chosen subset is reproducible.
    torch.manual_seed(cfg.seed)
    indices = torch.randperm(len(data))[:total_requested]
    subset_data = Subset(data, indices)
    split_generator = torch.Generator().manual_seed(cfg.seed)
    train_dataset, val_dataset = random_split(
        subset_data,
        [train_size, val_size],
        generator=split_generator,
    )
    num_examples = {"train": train_size,"val": val_size}
    print(f"Num samples: {num_examples}")

    # Create the data loaders; the training loader yields the whole set as one batch
    train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_test"], shuffle=False)

    # Optimizer and Loss function (MSE for a single output, cross-entropy otherwise)
    optimizer = torch.optim.SGD(model.parameters(), lr=cfg.lr, momentum=cfg.momentum)
    if config["n_classes"] == 1:
        criterion = F.mse_loss
    else:
        criterion = F.cross_entropy

    # Start Flower client
    numpy_client = FlowerClient(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        num_examples=num_examples,
        client_id=args.id,
        train_fn=models.simple_train,
        evaluate_fn=models.simple_test,
        device=device,
        privacy_audit=cfg.privacy_audit,
        canary_frac=cfg.canary_frac,
        p_value=cfg.p_value,
        k_plus=cfg.k_plus,
        k_min=cfg.k_min,
        config=config,
        exp_n=args.exp_n,
        scaling_dp=args.scaling_dp,
    )
    client = numpy_client.to_client()
    fl.client.start_client(server_address="[::]:8088", client=client) # local host

    # read saved data and plot
    utils.plot_client_metrics(args.id, config, show=False)
    






# Run the client only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
