import argparse
import pickle
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.profiler
import numpy as np
import wandb
import os
import glob
import re
import random
from opacus.accountants.utils import get_noise_multiplier



# Import your custom dataset & models
from models.client_model import Bert
from models.server_model import MLPCombiner
from prepare_data import AmazonPolarityPreprocessor
from torch.utils.data import Dataset
# from transformers import BertTokenizer, BlipImageProcessor

# Command-line interface. The arguments are parsed at import time and the
# resulting module-level `args` namespace is read by the functions below.
parser = argparse.ArgumentParser(
    description='Differentially private vertical federated learning on '
                'Amazon Polarity (VAFL / ZOFO / DPZV)')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                    help='number of data loading workers (default: 32)')
parser.add_argument('--epochs', default=100, type=int, metavar='N',
                    help='number of total epochs to run')
# parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
#                     help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--server_lr', '--server-learning-rate', default=0.005, type=float,
                    metavar='SVLR', help='learning rate for server', dest='server_lr')
# parser.add_argument('--lr_un', '--unsupervised-learning-rate', default=10.0, type=float,
#                     metavar='LR', help='initial learning rate for final linear layer', dest='lr_un')
#parser.add_argument('--schedule', default=[60, 80], nargs='*', type=int,
#                    help='learning rate schedule (when to drop lr by a ratio)')
# parser.add_argument('--schedule', default=[], nargs='*', type=int,
#                     help='learning rate schedule (when to drop lr by a ratio)')
parser.add_argument('--momentum', default=0, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--warmup_rate', default=0.1, type=float,
                    help='fraction of total training steps used for linear LR warmup')
parser.add_argument('--wd', '--weight-decay', default=0., type=float,
                    metavar='W', help='weight decay (default: 0.)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=50, type=int,
                    metavar='N', help='print frequency (default: 50)')
# NOTE: despite its name, --resume is the ROOT directory under which the
# per-run folder (and its checkpoints) is created and searched.
parser.add_argument('--resume', default='./log', type=str, metavar='PATH',
                    help='root directory for run folders and checkpoints '
                         '(default: ./log)')
parser.add_argument('--no_resume', action='store_true', help='Do not resume from checkpoint')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
# parser.add_argument('--world-size', default=-1, type=int,
#                     help='number of nodes for distributed training')
# parser.add_argument('--rank', default=-1, type=int,
#                     help='node rank for distributed training')
# parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str,
#                     help='url used to set up distributed training')
# parser.add_argument('--dist-backend', default='nccl', type=str,
#                     help='distributed backend')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
# parser.add_argument('--gpus', default=None, nargs='+', type=int,
#                     help='GPU id(s) to use. Default is all visible GPUs.')
# parser.add_argument('--multiprocessing-distributed', action='store_true',
#                     help='Use multi-processing distributed training to launch '
#                          'N processes per node, which has N GPUs. This is the '
#                          'fastest way to use PyTorch for either single node or '
#                          'multi node data parallel training')

parser.add_argument('--num_clients', default=4, type=int,
                    help='number of clients')
# Only these three modes are dispatched by main(); reject anything else up front.
parser.add_argument('--mode', default="dpzv", type=str,
                    choices=['vafl', 'zofo', 'dpzv'],
                    help='VFL algorithm to use: vafl, zofo or dpzv')
# parser.add_argument('--labeled_frac', default=1.0, type=float,
#                     help='fraction of training data that is labeled')
#parser.add_argument('--local_epochs', default=1, type=float,
#                    help='Number of local iterations')
# parser.add_argument('--server_time', default=1.0, type=float,
#                     help='How long roundtrip server communication takes')

# Added for DPZV

parser.add_argument('--no_dp', action='store_true', help='Do not use DP')

parser.add_argument('--zo_mu', default=1e-3, type=float,
                    help='the scale for the perturbation of parameters in zero order update')
parser.add_argument('--dp_clip_threshold', default=10., type=float,
                    help='the clipping threshold of the gradient for achieving DP')
parser.add_argument('--dp_epsilon', default=6., type=float,
                    help='DP level parameter epsilon (values <= 0 disable the DP noise)')
parser.add_argument('--dp_delta', default=1e-5, type=float,
                    help='DP level parameter delta')
parser.add_argument('--grad_estimate_method', default='central', type=str,
                    help='The method for estimating zeroth-order gradient: central or forward')
parser.add_argument('--min_lr', default=1e-7, type=float,
                    help='Minimum learning rate')
parser.add_argument('--patience', default=3, type=int,
                    help='Patience for learning rate scheduler')
parser.add_argument('--K', default=1, type=int,
                    help='number of server-model updates performed after each '
                         'client update')


# Added for ZOFO
parser.add_argument('--num_purt', default=5, type=int,
                    help='number of perturbations for ZOFO')
# When set, the zeroth-order trainers use zo_update (all gradients are built
# first, then applied) instead of the in-place mezo_update.
parser.add_argument('--no_mezo', action='store_true',
                    help='disable the memory-efficient in-place (MeZO-style) '
                         'parameter update in the zeroth-order trainers')

args = parser.parse_args()


def clip_tensor(tensor, max_norm):
    """
    Rescale a tensor so that its overall L2 norm is at most `max_norm`.

    The norm is taken over ALL elements of `tensor` (it is not computed per
    row/sample). Tensors whose norm is already within the bound are returned
    with their values unchanged.

    Args:
        tensor (torch.Tensor): The input tensor (floating point).
        max_norm (float): The maximum allowed L2 norm.

    Returns:
        torch.Tensor: A new tensor equal to `tensor * min(1, max_norm / ||tensor||)`.
    """
    norm = torch.norm(tensor, p=2)
    # The 1e-6 avoids division by zero. torch.clamp keeps the whole computation
    # on the tensor's device; Python's built-in min() would compare a 0-dim
    # tensor against an int and force a device-to-host sync on every call.
    scale = torch.clamp(max_norm / (norm + 1e-6), max=1.0)
    return tensor * scale

def embedding_dp(embedding, args):
    """
    Privatize an embedding: clip its norm, then add Gaussian noise.

    Args:
        embedding (torch.Tensor): The embedding produced by a client model.
        args: Namespace providing `dp_clip_threshold` and, when DP noise is
            enabled, `dpzero_gaussian_std` (set by `main`).

    Returns:
        torch.Tensor: The clipped embedding, with noise added when a positive
        noise standard deviation is configured.
    """
    embedding = clip_tensor(embedding, args.dp_clip_threshold)
    # `dpzero_gaussian_std` is only assigned when dp_epsilon > 0. Without this
    # guard, running with DP disabled crashed with an AttributeError here.
    # Clip-only matches what DPZV_Trainer.train does when dp_epsilon <= 0.
    noise_std = getattr(args, "dpzero_gaussian_std", None)
    if not noise_std:
        return embedding
    noise = torch.normal(
        mean=0,
        std=noise_std,
        size=embedding.size(),
        device=embedding.device,
        dtype=embedding.dtype,
    )
    return embedding + noise

def vafl_train(loader, models, optimizers, criterion, args, epoch):
    """
    Train for one epoch with first-order (backprop) client updates.

    For every batch, all clients first publish a privatized embedding computed
    without gradients. Then each client in turn recomputes its own embedding
    with gradients, the server combines it with the other clients' published
    embeddings, and both that client and the server take an optimizer step.

    Args:
        loader: Iterable of `(texts, labels)` batches; `texts` is indexed per
            client. Batches whose `texts` is None are skipped.
        models: Client models followed by the server model (`models[-1]`).
        optimizers: Optimizers in the same order as `models`.
        criterion: Loss function; its result is reduced with `.mean()`.
        args: Run configuration (`device`, `num_clients`, `epochs`,
            `global_step`, plus the fields read by `adjust_lr`/`embedding_dp`).
        epoch (int): Zero-based epoch index, used for logging only.

    Returns:
        float: Mean loss over all (batch, client) updates, or NaN if no
        update was performed.
    """
    total_loss = 0.0
    total_count = 0
    device = args.device

    for batch_idx, (texts, labels) in enumerate(loader):
        if texts is None:
            continue  # Skip empty batches

        for client in range(args.num_clients + 1):
            adjust_lr(args, optimizer=optimizers[client])
        texts = [t.to(device) for t in texts]
        labels = labels.to(device)

        # Published (privatized, gradient-free) embeddings of every client.
        embeddings = []
        with torch.no_grad():
            for client in range(args.num_clients):
                embedding = models[client](texts[client])
                embeddings.append(embedding_dp(embedding, args))
        args.global_step += 1

        for client in range(args.num_clients):
            optimizers[client].zero_grad()
            optimizers[-1].zero_grad()
            embedding_view = [client_view.detach().clone() for client_view in embeddings]
            # Only this client's slot carries a graph back to its parameters.
            # NOTE(review): this fresh embedding is not passed through
            # embedding_dp (zofo_train does privatize it) — confirm intent.
            embedding_view[client] = models[client](texts[client])
            # The server must consume `embedding_view`. Feeding the detached
            # `embeddings` list left the client without any gradient, so its
            # optimizer step was a no-op.
            output = models[-1](embedding_view)
            loss = criterion(output, labels).mean()
            loss.backward()
            optimizers[client].step()
            optimizers[-1].step()

            total_loss += loss.item()
            total_count += 1
        if batch_idx % 100 == 0:
            acc1 = accuracy(output, labels, topk=(1,))[0]
            print(f"Epoch: {epoch+1}/{args.epochs}, Batch: {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}, Acc: {acc1:.2f}")
    if total_count == 0:
        # Nothing was trained (empty loader or only skipped batches).
        return float("nan")
    return total_loss / total_count

def perturb_embedding(embedding, random_seed: int, args, scaling_factor=1):
    """
    Return `embedding` shifted along a seeded standard-normal direction.

    The global torch RNG is re-seeded with `random_seed`, so the same seed
    always yields the same direction. The input tensor is not modified and
    the result carries no autograd graph.

    Args:
        embedding (torch.Tensor): Tensor to perturb.
        random_seed (int): Seed selecting the random direction.
        args: Namespace providing the perturbation scale `zo_mu`.
        scaling_factor: Multiplier applied to the direction (default 1).

    Returns:
        torch.Tensor: `embedding + scaling_factor * z * args.zo_mu`.
    """
    # Side effect: resets the process-wide torch RNG state.
    torch.manual_seed(random_seed)
    direction = torch.normal(
        mean=0,
        std=1,
        size=embedding.size(),
        device=embedding.device,
        dtype=embedding.dtype,
    )
    with torch.no_grad():
        return embedding + scaling_factor * direction * args.zo_mu

def project_gradient(loss_diff, size, random_seed: int):
    """
    Scale a seeded standard-normal direction by a finite-difference value.

    Re-seeds the global torch RNG with `random_seed` and draws a tensor of
    shape `size` with the dtype and device of `loss_diff`.

    Args:
        loss_diff (torch.Tensor): Finite-difference estimate to project.
        size: Shape of the direction to draw.
        random_seed (int): Seed selecting the random direction.

    Returns:
        torch.Tensor: `loss_diff * z`.
    """
    # Side effect: resets the process-wide torch RNG state.
    torch.manual_seed(random_seed)
    direction = torch.normal(
        mean=0,
        std=1,
        size=size,
        device=loss_diff.device,
        dtype=loss_diff.dtype,
    )
    return loss_diff * direction

def zofo_train(loader, models, optimizers, criterion, args, epoch):
    """
    Train for one epoch with a zeroth-order estimate at the embedding level.

    For each client, the loss gradient with respect to its (privatized)
    embedding is estimated from `args.num_purt` symmetric random probes of the
    server loss. That estimate is back-propagated through the client model;
    the server is then updated by ordinary backprop on the published
    embeddings.

    Args:
        loader: Iterable of `(texts, labels)` batches; `texts` is indexed per
            client. Batches whose `texts` is None are skipped.
        models: Client models followed by the server model (`models[-1]`).
        optimizers: Optimizers in the same order as `models`.
        criterion: Loss function; its result is reduced with `.mean()`.
        args: Run configuration (`device`, `num_clients`, `num_purt`, `zo_mu`,
            `epochs`, `global_step`, plus fields read by helper functions).
        epoch (int): Zero-based epoch index, used for logging only.

    Returns:
        float: Mean server loss over all (batch, client) updates, or NaN if
        no update was performed.

    Raises:
        ValueError: If `args.num_purt` is smaller than 1.
    """
    if args.num_purt < 1:
        raise ValueError("num_purt must be at least 1 for ZOFO training")

    total_loss = 0.0
    total_count = 0
    device = args.device

    for batch_idx, (texts, labels) in enumerate(loader):
        if texts is None:
            continue  # Skip empty batches

        for client in range(args.num_clients + 1):
            adjust_lr(args, optimizer=optimizers[client])
        texts = [t.to(device) for t in texts]
        labels = labels.to(device)

        # Published (privatized, gradient-free) embeddings of every client.
        embeddings = []
        with torch.no_grad():
            for client in range(args.num_clients):
                embedding = models[client](texts[client])
                embeddings.append(embedding_dp(embedding, args))
        args.global_step += 1

        for client in range(args.num_clients):
            optimizers[client].zero_grad()
            optimizers[-1].zero_grad()
            text_local = texts[client]
            embeddings_view_plus = embeddings.copy()
            embeddings_view_minus = embeddings.copy()
            embedding = models[client](text_local)
            embedding = embedding_dp(embedding, args)
            base = embedding.detach()

            # Average of per-probe estimates delta_i * z_i. Each finite
            # difference must be projected on the direction that produced it;
            # averaging the differences and projecting on the last direction
            # only mixes unrelated probes.
            partial_grad = torch.zeros_like(base)
            for _ in range(args.num_purt):
                random_seed = np.random.randint(1000000000)
                # perturb_embedding is NOT in-place, so the symmetric probes
                # are +1 and -1 (giving e + mu*z and e - mu*z), matching the
                # 2*mu denominator below.
                embeddings_view_plus[client] = perturb_embedding(base, random_seed, args, scaling_factor=1)
                embeddings_view_minus[client] = perturb_embedding(base, random_seed, args, scaling_factor=-1)
                with torch.no_grad():
                    output_plus = models[-1](embeddings_view_plus)
                    output_minus = models[-1](embeddings_view_minus)
                    loss_plus = criterion(output_plus, labels).mean()
                    loss_minus = criterion(output_minus, labels).mean()
                    delta = (loss_plus - loss_minus) / (2 * args.zo_mu)
                partial_grad += project_gradient(delta, embedding.size(), random_seed=random_seed)
            partial_grad /= args.num_purt

            # backward(inputs=...) rejects tensors that do not require grad,
            # so frozen parameters have to be filtered out.
            trainable = [p for p in models[client].parameters() if p.requires_grad]
            if trainable:
                embedding.backward(gradient=partial_grad, inputs=trainable)
                optimizers[client].step()

            # First-order server update on the published embeddings.
            output = models[-1](embeddings)
            loss = criterion(output, labels).mean()
            loss.backward()
            optimizers[-1].step()

            total_loss += loss.item()
            total_count += 1

        if batch_idx % 100 == 0:
            acc1 = accuracy(output, labels, topk=(1,))[0]
            print(f"Epoch: {epoch+1}/{args.epochs}, Batch: {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}, Acc: {acc1:.2f}")
    if total_count == 0:
        # Nothing was trained (empty loader or only skipped batches).
        return float("nan")
    return total_loss / total_count



class DPZV_Trainer():
    """
    Zeroth-order trainer for the client models with optional DP noise.

    Client parameters are perturbed along a seeded random direction, the
    resulting per-sample loss difference is clipped (and noised when
    `args.dp_epsilon > 0`), and the scalar estimate is applied along the same
    direction. The server model is trained with ordinary backprop through
    `server_optimizer`.
    """

    def __init__(self, args, server_optimizer):
        """
        Args:
            args: Run configuration namespace (kept by reference).
            server_optimizer: Optimizer of the server model (`models[-1]`).
        """
        self.args = args
        self.server_optimizer = server_optimizer
        # Current client learning rate; refreshed by zo_adjust_lr().
        self.lr = args.lr

    def _get_learning_rate(self):
        """Return the learning rate currently used for client updates."""
        return self.lr

    def zo_perturb_parameters(self, model: nn.Module, random_seed: int, scaling_factor=1):
        """
        Shift every parameter of `model` in place by `scaling_factor * z * zo_mu`.

        `z` is standard normal and fully determined by `random_seed`, so
        calling this with factors +1, -2, +1 and the same seed visits
        theta + mu*z, theta - mu*z and returns to theta (up to rounding).
        Re-seeds the global torch RNG.
        """
        args = self.args
        torch.manual_seed(random_seed)
        with torch.no_grad():
            for name, param in model.named_parameters():
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                param.data = param.data + scaling_factor * z * args.zo_mu

    def dpzero_clip(self, loss_diff, C=1.):
        """
        Clip each entry of `loss_diff` to [-C, C] and average the result.

        Args:
            loss_diff (torch.Tensor): Per-sample finite-difference values.
            C (float): Clipping threshold.

        Returns:
            tuple: (mean of the clipped values as a 0-dim tensor,
            fraction of entries whose magnitude exceeded `C` as a float).
        """
        clipping_rate = (torch.abs(loss_diff) > C).float().mean().item()
        # Same as multiplying by min(1, C / |x|), without the division by zero
        # for entries that are exactly 0.
        clipped = torch.clamp(loss_diff, min=-C, max=C)
        return clipped.mean(), clipping_rate

    def zo_forward(self, model, inputs):
        """
        Run `model` on `inputs` in eval mode without tracking gradients.

        Note that the model is left in eval mode afterwards.
        """
        model.eval()

        with torch.inference_mode():
            output = model(inputs)
        return output

    def zo_adjust_lr(self):
        """Apply the LR schedule to the server optimizer and mirror it in `self.lr`."""
        new_lr = adjust_lr(self.args, self.server_optimizer)
        self.lr = new_lr

    def zo_update(self, model, seed, projected_grad):
        """
        Update the parameters with the estimated gradients.

        All per-parameter gradients are materialized first and applied
        afterwards. Weight decay is skipped for parameters whose name contains
        "bias", "layer_norm" or "layernorm".
        """
        args = self.args
        with torch.no_grad():
            torch.manual_seed(seed)
            grad = {}
            for name, param in model.named_parameters():
                # Resample z (same seed, same order => same direction as the probe)
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                if "bias" not in name and "layer_norm" not in name and "layernorm" not in name:
                    grad[name] = projected_grad * z + args.weight_decay * param.data
                else:
                    grad[name] = projected_grad * z
            for name, param in model.named_parameters():
                param.data = param.data - self._get_learning_rate() * grad[name]

    def mezo_update(self, model, seed, projected_grad):
        """
        Update the parameters with the estimated gradients.

        Same rule as `zo_update`, but each parameter is updated as soon as its
        direction has been resampled, so no gradient dictionary is kept.
        """
        args = self.args
        with torch.no_grad():
            torch.manual_seed(seed)
            for name, param in model.named_parameters():
                # Resample z (same seed, same order => same direction as the probe)
                z = torch.normal(mean=0, std=1, size=param.data.size(), device=param.data.device, dtype=param.data.dtype)
                if "bias" not in name and "layer_norm" not in name and "layernorm" not in name:
                    grad = projected_grad * z + args.weight_decay * param.data
                else:
                    grad = projected_grad * z
                param.data = param.data - self._get_learning_rate() * grad

    def train(self, loader, models, criterion, args, epoch):
        """
        Train for one epoch.

        Args:
            loader: Iterable of `(texts, labels)` batches; batches whose
                `texts` is None are skipped.
            models: Client models followed by the server model (`models[-1]`).
            criterion: Loss returning one value per sample (the per-sample
                differences are clipped individually).
            args: Only `args.device` is read from this argument; everything
                else comes from the namespace given to the constructor.
            epoch (int): Zero-based epoch index, used for logging only.

        Returns:
            float: Mean server loss over all server updates, or NaN if no
            update was performed.
        """
        total_loss = 0.0
        total_count = 0
        device = args.device
        args = self.args  # using the trainer's own args

        for batch_idx, (texts, labels) in enumerate(loader):
            if texts is None:
                continue  # Skip empty batches

            self.zo_adjust_lr()
            texts = [t.to(device) for t in texts]
            labels = labels.to(device)
            embeddings = []
            with torch.no_grad():
                for client in range(args.num_clients):
                    embeddings.append(models[client](texts[client]))
            args.global_step += 1
            for client in range(args.num_clients):
                text_local = texts[client]
                embeddings_view_plus = embeddings.copy()
                embeddings_view_minus = embeddings.copy()
                seed = np.random.randint(100000)
                with torch.no_grad():
                    # theta + mu*z
                    self.zo_perturb_parameters(model=models[client], random_seed=seed, scaling_factor=1)
                    embeddings_view_plus[client] = models[client](text_local)
                    # theta - mu*z (the perturbation is in place, hence -2)
                    self.zo_perturb_parameters(model=models[client], random_seed=seed, scaling_factor=-2)
                    embeddings_view_minus[client] = models[client](text_local)
                    # back to theta
                    self.zo_perturb_parameters(model=models[client], random_seed=seed, scaling_factor=1)
                    embeddings[client] = models[client](text_local)
                    output_plus = models[-1](embeddings_view_plus)
                    output_minus = models[-1](embeddings_view_minus)
                    loss_1 = criterion(output_plus, labels)
                    loss_2 = criterion(output_minus, labels)
                    mu_multiplier = 2
                    loss_diff = (loss_1 - loss_2) / (mu_multiplier * args.zo_mu)
                # Clipping is always applied; noise only when DP is enabled.
                projected_grad, clipping_rate = self.dpzero_clip(loss_diff, args.dp_clip_threshold)
                if args.dp_epsilon > 0:
                    projected_grad += torch.randn(1).item() * args.dpzero_gaussian_std
                if args.no_mezo:
                    self.zo_update(models[client], seed, projected_grad)
                else:
                    self.mezo_update(models[client], seed, projected_grad)
                for k in range(args.K):
                    # Without this reset the server gradients of every
                    # previous step keep accumulating in `.grad`.
                    self.server_optimizer.zero_grad()
                    output = models[-1](embeddings)
                    loss = criterion(output, labels).mean()
                    loss.backward()
                    self.server_optimizer.step()

                    total_loss += loss.item()
                    total_count += 1

            if batch_idx % 100 == 0:
                acc1 = accuracy(output, labels, topk=(1,))[0]
                print(f"Epoch: {epoch+1}/{args.epochs}, Batch: {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}, Acc: {acc1:.2f}")
        if total_count == 0:
            # Nothing was trained (empty loader or only skipped batches).
            return float("nan")
        return total_loss / total_count
            
class ZOOVFL_Trainer(DPZV_Trainer):
    """
    Zeroth-order trainer that privatizes the embeddings instead of the
    gradient estimate.

    Every embedding sent to the server goes through `embedding_dp`; the
    finite-difference estimate itself is averaged without clipping or noise.
    """

    def __init__(self, args, server_optimizer):
        super().__init__(args, server_optimizer)

    def train(self, loader, models, criterion, args, epoch):
        """
        Train for one epoch.

        Args:
            loader: Iterable of `(texts, labels)` batches; batches whose
                `texts` is None are skipped.
            models: Client models followed by the server model (`models[-1]`).
            criterion: Loss function; the probe losses are averaged with
                `.mean()` after differencing.
            args: Only `args.device` is read from this argument; everything
                else comes from the namespace given to the constructor.
            epoch (int): Zero-based epoch index, used for logging only.

        Returns:
            float: Mean server loss over all server updates, or NaN if no
            update was performed.
        """
        total_loss = 0.0
        total_count = 0
        device = args.device
        args = self.args  # using the trainer's own args

        for batch_idx, (texts, labels) in enumerate(loader):
            if texts is None:
                continue  # Skip empty batches

            self.zo_adjust_lr()
            texts = [t.to(device) for t in texts]
            labels = labels.to(device)
            embeddings = []
            with torch.no_grad():
                for client in range(args.num_clients):
                    embedding = models[client](texts[client])
                    embeddings.append(embedding_dp(embedding, args))
            args.global_step += 1
            for client in range(args.num_clients):
                text_local = texts[client]
                embeddings_view_plus = embeddings.copy()
                embeddings_view_minus = embeddings.copy()
                seed = np.random.randint(100000)
                with torch.no_grad():
                    # theta + mu*z
                    self.zo_perturb_parameters(model=models[client], random_seed=seed, scaling_factor=1)
                    embedding = models[client](text_local)
                    embeddings_view_plus[client] = embedding_dp(embedding, args)
                    # theta - mu*z (the perturbation is in place, hence -2)
                    self.zo_perturb_parameters(model=models[client], random_seed=seed, scaling_factor=-2)
                    embedding = models[client](text_local)
                    embeddings_view_minus[client] = embedding_dp(embedding, args)
                    # back to theta
                    self.zo_perturb_parameters(model=models[client], random_seed=seed, scaling_factor=1)
                    embedding = models[client](text_local)
                    embeddings[client] = embedding_dp(embedding, args)
                    output_plus = models[-1](embeddings_view_plus)
                    output_minus = models[-1](embeddings_view_minus)
                    loss_1 = criterion(output_plus, labels)
                    loss_2 = criterion(output_minus, labels)
                    mu_multiplier = 2
                    loss_diff = (loss_1 - loss_2) / (mu_multiplier * args.zo_mu)
                    projected_grad = loss_diff.mean()
                if args.no_mezo:
                    self.zo_update(models[client], seed, projected_grad)
                else:
                    self.mezo_update(models[client], seed, projected_grad)
                for k in range(args.K):
                    # Without this reset the server gradients of every
                    # previous step keep accumulating in `.grad`.
                    self.server_optimizer.zero_grad()
                    output = models[-1](embeddings)
                    loss = criterion(output, labels).mean()
                    loss.backward()
                    self.server_optimizer.step()

                    total_loss += loss.item()
                    total_count += 1

            if batch_idx % 100 == 0:
                acc1 = accuracy(output, labels, topk=(1,))[0]
                print(f"Epoch: {epoch+1}/{args.epochs}, Batch: {batch_idx}/{len(loader)}, Loss: {loss.item():.4f}, Acc: {acc1:.2f}")
        if total_count == 0:
            # Nothing was trained (empty loader or only skipped batches).
            return float("nan")
        return total_loss / total_count



@torch.no_grad()
def evaluate(loader, models, criterion, device):
    """
    Evaluate the client/server ensemble on a whole data loader.

    Metrics are aggregated over every sample of every batch (weighted by batch
    size), rather than reported for the final batch only.

    Args:
        loader: Iterable of `(texts, labels)` batches; `texts` is indexed per
            client. Batches whose `texts` is None are skipped.
        models: Client models followed by the server model (`models[-1]`).
        criterion: Loss function; its per-batch result is reduced with
            `.mean()`.
        device: Device the inputs are moved to.

    Returns:
        tuple[float, float]: (mean loss, top-1 accuracy in percent). Both are
        NaN when the loader yields no usable batch.
    """
    # NOTE(review): all models are left in eval mode on return, and none of
    # the training routines in this file switches them back to train mode.
    for model in models:
        model.eval()

    # Derive the clients from `models` instead of the module-level `args`.
    client_models, server_model = models[:-1], models[-1]

    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    for texts, labels in loader:
        if texts is None:
            continue  # Skip empty batches
        texts = [t.to(device) for t in texts]
        labels = labels.to(device)
        embeddings = [client(view) for client, view in zip(client_models, texts)]
        output = server_model(embeddings)

        batch_size = labels.size(0)
        loss_sum += criterion(output, labels).mean().item() * batch_size
        num_correct += (output.argmax(dim=1) == labels).sum().item()
        num_seen += batch_size

    if num_seen == 0:
        return float("nan"), float("nan")
    return loss_sum / num_seen, 100.0 * num_correct / num_seen

def accuracy(output, target, topk=(1,)):
    """
    Compute top-k accuracy (in percent) for each k in `topk`.

    Args:
        output (torch.Tensor): Scores; the top-k search runs along dim 1.
        target (torch.Tensor): Ground-truth class indices, one per row.
        topk (tuple[int]): The k values to report.

    Returns:
        list[torch.Tensor]: One 0-dim tensor per entry of `topk`.
    """
    with torch.no_grad():
        num_samples = target.size(0)
        # Indices of the max(topk) best classes, transposed to (k, batch).
        ranked = output.topk(max(topk), 1, True, True)[1].t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        return [
            hits[:k].reshape(-1).float().sum().mul_(100.0 / num_samples)
            for k in topk
        ]

def adjust_lr(args, optimizer=None):
    """
    Linear warmup followed by linear decay, driven by fields of `args`.

    The rate rises from 0 to `args.lr` during the first `args.warmup_steps`
    steps, then falls linearly and is 0 from `args.total_steps` onwards. The
    current position is `args.global_step`.

    :param args: Namespace with `global_step`, `warmup_steps`, `total_steps`
        and the peak rate `lr`.
    :param optimizer: Optional optimizer; when given, every param group's
        'lr' is overwritten with the computed value.
    :return: The learning rate for the current step.
    """
    step = args.global_step
    warmup = args.warmup_steps
    horizon = args.total_steps

    if step < warmup:
        # Warmup: 0 -> peak
        new_lr = args.lr * float(step) / float(warmup)
    else:
        elapsed = step - warmup
        decay_span = horizon - warmup
        if elapsed < decay_span:
            # Decay: peak -> 0
            remaining_frac = 1.0 - float(elapsed) / float(decay_span)
            new_lr = args.lr * remaining_frac
        else:
            # Past the schedule horizon
            new_lr = 0.0

    if optimizer is not None:
        for group in optimizer.param_groups:
            group['lr'] = new_lr
    return new_lr


###############################################
# (E) MAIN
###############################################
def main():
    """
    Run a full experiment using the module-level `args`.

    Steps: seed the RNGs, start a wandb run, build the data loaders and the
    client/server models, optionally resume from the newest checkpoint in the
    run folder, calibrate the DP noise, then train and evaluate once per
    epoch while keeping only the most recent checkpoint on disk.

    Raises:
        ValueError: If `args.mode` is not one of 'vafl', 'zofo' or 'dpzv'.
    """
    # Fail before any expensive setup; an unknown mode used to surface only
    # after the first evaluation, as a NameError on `train_loss`.
    valid_modes = ("vafl", "zofo", "dpzv")
    if args.mode not in valid_modes:
        raise ValueError(f"Unknown --mode {args.mode!r}; expected one of {valid_modes}")

    if args.seed is not None:
        random.seed(args.seed)
        # The zeroth-order routines draw their perturbation seeds from
        # np.random, so it must be seeded too for reproducible runs.
        np.random.seed(args.seed % (2 ** 32))
        torch.manual_seed(args.seed)
    save_folder_terms = [
        f'mode{args.mode}',
        # f'mu{args.zo_mu}',
        # f'b{args.batch_size}',
        # f'cla{args.client_arch}',
        # f'sva{args.server_arch}',
        f'lr{args.lr:g}',
        # f'slr{args.server_lr:g}',
        # f'cthr{args.dp_clip_threshold}',
        f'dpeps{args.dp_epsilon}',
        # f'st{args.server_time}',
        f'seed{args.seed}',
        # f'e{",".join(map(str, args.schedule))},200',
        # f'mom{args.momentum}',
        ]

    save_folder = os.path.join(args.resume, '_'.join(save_folder_terms))
    wandb.init(
        project="DPZV_LLM",  # change to your project name
        name='_'.join(save_folder_terms),             # optional run name
        config={                            # track some hyperparams
            "epochs": args.epochs,
            "batch_size": args.batch_size,
            "learning_rate": args.lr,
            "server_learning_rate": args.server_lr,
            "DP Clip Threshold": args.dp_clip_threshold,
            "DP Epsilon": args.dp_epsilon,
            "ZO mu": args.zo_mu,
            "DP clip threshold": args.dp_clip_threshold,
            "seed": args.seed,
            "K": args.K
        }
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    print(f"Using device: {device}")

    # 1. Load data
    preprocessor = AmazonPolarityPreprocessor(args)
    # preprocessor.load_dataset()
    preprocessor.preprocess_and_partition()
    train_loader, val_loader = preprocessor.create_dataloaders()

    # 2. Client models and their optimizers
    models = []
    optimizers = []
    for client in range(args.num_clients):
        client_model = Bert().to(device)
        models.append(client_model)
        client_optimizer = optim.Adam(filter(lambda p: p.requires_grad, client_model.parameters()),
                                         lr=args.server_lr)
        optimizers.append(client_optimizer)
    # # 2. Dataset + DataLoader
    # train_dataset = MOSEIDataset(train_data)
    # dev_dataset = MOSEIDataset(dev_data)
    # test_dataset = MOSEIDataset(test_data)

    # train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,collate_fn=custom_collate_fn)
    # dev_loader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False,collate_fn=custom_collate_fn)
    # test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,collate_fn=custom_collate_fn)

    # 3. Initialize parameters for scheduler
    args.total_steps = args.epochs * len(train_loader)
    args.warmup_steps = int(args.warmup_rate * args.total_steps)
    args.global_step = 0

    num_clients = args.num_clients
    num_classes = 2


    # 4. Late Fusion: the server model is always the LAST entry of `models`
    server_model = MLPCombiner(
        input_size=num_clients*768, hidden_size=768, num_classes=num_classes
    ).to(device)
    server_optimizer = optim.Adam(filter(lambda p: p.requires_grad, server_model.parameters()),
                                         lr=args.server_lr)
    models.append(server_model)
    optimizers.append(server_optimizer)


    # 5. Loss function: per-sample cross entropy. The zeroth-order trainer
    # clips per-sample loss differences, hence reduction='none'.
    criterion = nn.CrossEntropyLoss(reduction='none')

    # 6. Check for existing checkpoints. makedirs also creates a missing
    # parent directory (e.g. ./log), where os.mkdir raised FileNotFoundError.
    os.makedirs(save_folder, exist_ok=True)
    checkpoint_list = glob.glob(os.path.join(save_folder, "checkpoint-epoch*.pt"))

    start_epoch = 0
    if args.no_resume:
        print("Do not Resume. Starting from scratch.")
    elif checkpoint_list:
        # Sort them to find the "latest" by epoch number
        # checkpoint-epoch10.pt -> extract '10' via regex
        def extract_epoch_number(fname):
            match = re.search(r"checkpoint-epoch(\d+)\.pt", fname)
            return int(match.group(1)) if match else -1

        checkpoint_list.sort(key=extract_epoch_number)
        latest_checkpoint = checkpoint_list[-1]  # last in sorted list
        print(f"Found checkpoint: {latest_checkpoint}. Restoring...")

        checkpoint = torch.load(latest_checkpoint, map_location=device)
        for client in range(args.num_clients):
            models[client].load_state_dict(checkpoint[f"client_model{client}"])
            optimizers[client].load_state_dict(checkpoint[f"client_optimizer{client}"])
        models[-1].load_state_dict(checkpoint["server_model"])
        optimizers[-1].load_state_dict(checkpoint["server_optimizer"])
        start_epoch = checkpoint["epoch"]  # start from this epoch
        args.global_step = checkpoint["global_step"]
    else:
        print("No existing checkpoints found. Starting from scratch.")

    # 7. Compute noise multiplier. DP noise is off when --no_dp is given or
    # epsilon is non-positive.
    if args.dp_epsilon > 0 and not args.no_dp:
        sample_rate = args.batch_size / len(train_loader.dataset)
        try:
            multiplier = get_noise_multiplier(target_epsilon=args.dp_epsilon,
                                                target_delta=args.dp_delta,
                                                epochs=args.epochs,
                                                sample_rate=sample_rate,
                                                accountant='gdp'
                                                )
        except ValueError:
            # Fall back to the library's default accountant
            multiplier = get_noise_multiplier(target_epsilon=args.dp_epsilon,
                                                target_delta=args.dp_delta,
                                                epochs=args.epochs,
                                                sample_rate=sample_rate,
                                                # accountant='gdp'
                                                )
        args.dpzero_gaussian_std = multiplier * 2 * args.dp_clip_threshold / args.batch_size
    else:
        # Always define the attribute: the training routines read it
        # unconditionally, and a std of 0 adds no noise.
        args.dpzero_gaussian_std = 0.0

    # 8. Train/Eval Loop
    num_epochs = args.epochs

    if args.mode == "dpzv":
        dpzv_trainer = DPZV_Trainer(args, optimizers[-1])
    for epoch in range(start_epoch, num_epochs):


        if args.mode == "vafl":
            train_loss = vafl_train(train_loader, models, optimizers, criterion, args, epoch)

        elif args.mode == "zofo":
            train_loss = zofo_train(train_loader, models, optimizers, criterion, args, epoch)


        elif args.mode == "dpzv":
            train_loss = dpzv_trainer.train(train_loader, models, criterion, args, epoch)

        val_loss, val_acc = evaluate(val_loader, models, criterion, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "eva_loss": val_loss,
            "acc": val_acc,
        })

        # Save checkpoint each epoch
        ckpt_path = os.path.join(save_folder, f"checkpoint-epoch{epoch+1}.pt")
        save_dict = {
            "epoch": epoch + 1,
            "global_step" : args.global_step,
        }

        for client in range(args.num_clients):
            save_dict[f"client_model{client}"] = models[client].state_dict()
            save_dict[f"client_optimizer{client}"] = optimizers[client].state_dict()
        save_dict["server_model"] = models[-1].state_dict()
        save_dict["server_optimizer"] = optimizers[-1].state_dict()
        torch.save(save_dict, ckpt_path)
        # wandb.save(ckpt_path)

        # Remove the previous checkpoint only AFTER the new one is on disk, so
        # a failed save never leaves the run without any checkpoint.
        prev_ckpt_path = os.path.join(save_folder, f"checkpoint-epoch{epoch}.pt")
        if os.path.exists(prev_ckpt_path):
            os.remove(prev_ckpt_path)

    # # 9. Final Test
    #  val_loss, val_acc = evaluate(val_loader, models, criterion, device)
    # wandb.log({"test_loss": val_loss})
    # print(f"Test loss: {test_loss:.4f}")

# Script entry point: start the experiment when this file is run directly.
if __name__ == "__main__":
    main()
