import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import numpy as np
from models.client_model import Bert
from models.server_model import MLPCombiner
from prepare_data import AmazonPolarityPreprocessor

def measure_vafl_time(models, optimizers, batch, device, criterion):
    """Time one first-order ("vafl") training step over every client.

    The batch is moved to ``device`` and every client's embedding is computed
    once without gradient tracking.  Then, for each client in turn, that
    client's embedding is recomputed with gradients enabled, combined with
    detached copies of the other clients' embeddings, passed through the
    server model, and the mean loss is backpropagated.  Both the client's
    optimizer and the server's optimizer take a step.

    Args:
        models: Sequence of callables; ``models[-1]`` is the server model
            (called with a list of embeddings), all others are client models
            (each called with its own entry of ``texts``).
        optimizers: Sequence parallel to ``models``.
        batch: ``(texts, labels)`` pair, where ``texts`` holds one tensor per
            client.
        device: Device the batch is moved to (``torch.device`` or string).
        criterion: Loss callable; ``.mean()`` is applied to its result.

    Returns:
        Elapsed time in milliseconds.  CUDA events are used when ``device`` is
        a CUDA device; otherwise wall-clock time is measured.
    """
    import time  # local import: only needed for the non-CUDA timing fallback

    texts, labels = batch
    use_cuda = torch.device(device).type == "cuda"
    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        # torch.cuda.Event cannot be used without CUDA; fall back to the host clock.
        start_time = time.perf_counter()

    texts = [t.to(device) for t in texts]
    labels = labels.to(device)
    num_clients = len(models) - 1

    # Gradient-free snapshot of every client's embedding.
    with torch.no_grad():
        embeddings = [models[client](texts[client]) for client in range(num_clients)]

    for client in range(num_clients):
        optimizers[client].zero_grad()
        optimizers[-1].zero_grad()
        embedding_view = [snapshot.detach().clone() for snapshot in embeddings]
        # Only the active client's slot carries an autograd graph.
        embedding_view[client] = models[client](texts[client])
        # The server must consume the view holding the differentiable
        # embedding; feeding the detached snapshots would leave the client
        # model without any gradient.
        output = models[-1](embedding_view)
        loss = criterion(output, labels).mean()
        loss.backward()
        optimizers[client].step()
        optimizers[-1].step()

    if use_cuda:
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event)
    return (time.perf_counter() - start_time) * 1000.0

def measure_zofo_time(models, optimizers, batch, device, criterion, num_purt=5, zo_mu=1e-3):
    """Time one "zofo" step: zeroth-order client updates, first-order server update.

    Every client's embedding is first computed without gradient tracking.
    Then, for each client in turn:

    * the client's embedding is recomputed with gradients enabled;
    * ``num_purt`` random directions ``z`` are drawn, the server loss is
      evaluated at ``embedding + zo_mu * z`` and ``embedding - zo_mu * z``,
      and the central differences ``(loss_plus - loss_minus) / (2 * zo_mu)``
      weighted by their own ``z`` are averaged into an estimate of the
      gradient with respect to the embedding;
    * that estimate is backpropagated through the client model and the
      client's optimizer steps;
    * the server model is updated by ordinary backpropagation on the
      gradient-free embeddings.

    Args:
        models: Sequence of callables; ``models[-1]`` is the server model
            (called with a list of embeddings), all others are client models.
        optimizers: Sequence parallel to ``models``.
        batch: ``(texts, labels)`` pair, where ``texts`` holds one tensor per
            client.
        device: Device the batch is moved to (``torch.device`` or string).
        criterion: Loss callable; ``.mean()`` is applied to its result.
        num_purt: Number of random perturbation directions per client.
        zo_mu: Perturbation step size.

    Returns:
        Elapsed time in milliseconds.  CUDA events are used when ``device`` is
        a CUDA device; otherwise wall-clock time is measured.

    Raises:
        ValueError: If ``num_purt`` is smaller than 1.
    """
    import time  # local import: only needed for the non-CUDA timing fallback

    if num_purt < 1:
        raise ValueError("num_purt must be at least 1")

    texts, labels = batch
    use_cuda = torch.device(device).type == "cuda"
    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        # torch.cuda.Event cannot be used without CUDA; fall back to the host clock.
        start_time = time.perf_counter()

    texts = [t.to(device) for t in texts]
    labels = labels.to(device)
    num_clients = len(models) - 1

    # Gradient-free snapshot of every client's embedding.
    with torch.no_grad():
        embeddings = [models[client](texts[client]) for client in range(num_clients)]

    for client in range(num_clients):
        optimizers[client].zero_grad()
        optimizers[-1].zero_grad()
        embedding = models[client](texts[client])
        base = embedding.detach()
        embeddings_view_plus = list(embeddings)
        embeddings_view_minus = list(embeddings)
        partial_grad = torch.zeros_like(base)

        # The loss probes need no autograd graph.
        with torch.no_grad():
            for _ in range(num_purt):
                z = torch.randn_like(base)
                # Symmetric perturbation so the 2 * zo_mu divisor is correct.
                embeddings_view_plus[client] = base + zo_mu * z
                embeddings_view_minus[client] = base - zo_mu * z
                loss_plus = criterion(models[-1](embeddings_view_plus), labels).mean()
                loss_minus = criterion(models[-1](embeddings_view_minus), labels).mean()
                # Each difference must be paired with the direction that
                # produced it, otherwise the estimate carries no signal.
                partial_grad += ((loss_plus - loss_minus) / (2 * zo_mu)) * z
            partial_grad /= num_purt

        embedding.backward(gradient=partial_grad)
        optimizers[client].step()

        # First-order server update on the gradient-free embeddings.
        output = models[-1](embeddings)
        loss = criterion(output, labels).mean()
        loss.backward()
        optimizers[-1].step()

    if use_cuda:
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event)
    return (time.perf_counter() - start_time) * 1000.0

def measure_dpzv_time(models, optimizers, batch, device, criterion, zo_mu=1e-3, dp_clip=10.0, dp_noise=0.1):
    """Time one "dpzv" step: clipped, noised zeroth-order client updates.

    Every client's embedding is first computed without gradient tracking.
    Then, for each client in turn:

    * a seed is drawn and used to generate one random direction ``z`` per
      parameter; the client's parameters are moved to ``+zo_mu * z`` and
      ``-zo_mu * z`` around their current value and the client embedding is
      evaluated at both points, after which the parameters are restored;
    * the per-sample loss difference ``(loss_plus - loss_minus) / (2 * zo_mu)``
      is rescaled so its norm does not exceed ``dp_clip`` and Gaussian noise
      with standard deviation ``dp_noise`` is added;
    * each parameter is updated by ``-lr * mean(loss_diff) * z`` using the
      same ``z`` (regenerated from the seed), with ``lr`` read from the
      client's optimizer's first parameter group;
    * the server model is updated by ordinary backpropagation.

    Args:
        models: Sequence; ``models[-1]`` is the server model (called with a
            list of embeddings), all others are client models exposing
            ``parameters()``.
        optimizers: Sequence parallel to ``models``.  Client optimizers are
            only consulted for their learning rate.
        batch: ``(texts, labels)`` pair, where ``texts`` holds one tensor per
            client.
        device: Device the batch is moved to (``torch.device`` or string).
        criterion: Loss callable returning a tensor (per-sample losses when
            constructed with ``reduction='none'``).
        zo_mu: Perturbation step size.
        dp_clip: Upper bound on the norm of the loss-difference vector.
        dp_noise: Standard deviation of the added Gaussian noise.

    Returns:
        Elapsed time in milliseconds.  CUDA events are used when ``device`` is
        a CUDA device; otherwise wall-clock time is measured.
    """
    import time  # local import: only needed for the non-CUDA timing fallback

    def _shift_params(params, generator, seed, scale):
        # Add scale * z to every parameter, where z is reproduced from `seed`.
        if generator is None:
            return
        generator.manual_seed(seed)
        for p in params:
            z = torch.randn(p.shape, generator=generator, device=p.device, dtype=p.dtype)
            p.add_(scale * z)

    texts, labels = batch
    use_cuda = torch.device(device).type == "cuda"
    if use_cuda:
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
    else:
        # torch.cuda.Event cannot be used without CUDA; fall back to the host clock.
        start_time = time.perf_counter()

    texts = [t.to(device) for t in texts]
    labels = labels.to(device)
    num_clients = len(models) - 1

    # Gradient-free snapshot of every client's embedding.
    with torch.no_grad():
        embeddings = [models[client](texts[client]) for client in range(num_clients)]

    for client in range(num_clients):
        model = models[client]
        text_local = texts[client]
        params = list(model.parameters())
        # Assumes all parameters of one client live on the same device.
        generator = torch.Generator(device=params[0].device) if params else None
        seed = int(np.random.randint(100000))
        lr = optimizers[client].param_groups[0]['lr']
        embeddings_view_plus = list(embeddings)
        embeddings_view_minus = list(embeddings)

        with torch.no_grad():
            # theta + mu * z
            _shift_params(params, generator, seed, zo_mu)
            embeddings_view_plus[client] = model(text_local)

            # theta - mu * z
            _shift_params(params, generator, seed, -2 * zo_mu)
            embeddings_view_minus[client] = model(text_local)

            # Back to theta (up to floating-point round-off).
            _shift_params(params, generator, seed, zo_mu)
            embeddings[client] = model(text_local)

            output_plus = models[-1](embeddings_view_plus)
            output_minus = models[-1](embeddings_view_minus)
            loss_diff = (criterion(output_plus, labels) - criterion(output_minus, labels)) / (2 * zo_mu)

            # Clip and add noise
            loss_norm = torch.norm(loss_diff)
            scale = torch.clamp(dp_clip / (loss_norm + 1e-6), max=1.0)
            loss_diff = scale * loss_diff
            if dp_noise > 0:
                loss_diff = loss_diff + torch.normal(
                    mean=0, std=dp_noise, size=loss_diff.size(), device=loss_diff.device
                )

            # Update parameters along the direction that was actually probed.
            # No per-parameter printing here: it would dominate the timing.
            _shift_params(params, generator, seed, -lr * loss_diff.mean())

        # First-order server update; clear gradients left by earlier steps.
        optimizers[-1].zero_grad()
        output = models[-1](embeddings)
        loss = criterion(output, labels).mean()
        loss.backward()
        optimizers[-1].step()

    if use_cuda:
        end_event.record()
        torch.cuda.synchronize()
        return start_event.elapsed_time(end_event)
    return (time.perf_counter() - start_time) * 1000.0

def main():
    """Benchmark the per-batch time of the 'vafl', 'zofo' and 'dpzv' steps.

    Builds the Amazon-polarity data loaders, creates ``num_clients`` ``Bert``
    client models plus one ``MLPCombiner`` server model (each with its own
    Adam optimizer), and then times each method on the first ``num_batches``
    batches of the training loader, printing the per-batch and average times.

    The same models and optimizers are reused for all three methods, so each
    method starts from the weights left behind by the previous one.
    """
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Parameters
    num_clients = 4
    batch_size = 32
    num_batches = 5

    # Minimal stand-in for the argument object the preprocessor reads.
    class Args:
        def __init__(self):
            self.batch_size = batch_size
            self.num_clients = num_clients
            self.data = "datasets/amazon_polarity"
    args = Args()

    preprocessor = AmazonPolarityPreprocessor(args)
    preprocessor.preprocess_and_partition()
    train_loader, _ = preprocessor.create_dataloaders()

    # Initialize models and optimizers: clients first, server last.
    models = []
    optimizers = []
    for _ in range(num_clients):
        client_model = Bert().to(device)
        models.append(client_model)
        optimizers.append(optim.Adam(client_model.parameters(), lr=0.001))

    server_model = MLPCombiner(
        input_size=num_clients*768,
        hidden_size=768,
        num_classes=2
    ).to(device)
    models.append(server_model)
    optimizers.append(optim.Adam(server_model.parameters(), lr=0.001))

    # Per-sample losses; the measure functions reduce them as needed.
    criterion = nn.CrossEntropyLoss(reduction='none')

    # Measure times for each method
    timers = {
        'vafl': measure_vafl_time,
        'zofo': measure_zofo_time,
        'dpzv': measure_dpzv_time,
    }
    for method, measure in timers.items():
        times = []
        for i, batch in enumerate(train_loader):
            if i >= num_batches:
                break
            if batch[0] is None:
                continue

            elapsed_time = measure(models, optimizers, batch, device, criterion)
            times.append(elapsed_time)
            print(f"[{method}] Batch {i}: {elapsed_time:.2f} ms")

        # Every batch may have been skipped (or the loader may be empty);
        # avoid dividing by zero in that case.
        if not times:
            print(f"\nNo batches were measured for {method}\n")
            continue

        avg_time = sum(times) / len(times)
        print(f"\nAverage CUDA time per batch for {method}: {avg_time:.2f} ms\n")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
