import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import argparse
import random
import numpy as np
import os
import pandas as pd

from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition.analytic import LogExpectedImprovement
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from torch.quasirandom import SobolEngine
from torch.utils.data import DataLoader, TensorDataset, Subset
from gpytorch.priors import LogNormalPrior
from botorch.models.transforms import Normalize, Standardize
from sklearn.preprocessing import StandardScaler

import torchquantum as tq

class QFCModel(tq.QuantumModule):
    """
    Quantum-classical model for Human Activity Recognition (HAR) classification.
    Features a learnable classical pre-processing layer, a variational
    quantum circuit, and a classical post-processing head.

    Forward pipeline: ``pre_fc`` maps ``input_dim`` features to
    ``2 ** num_qubits`` values, which are L2-normalized and amplitude-encoded
    onto the qubits; ``len(matrix)`` variational layers follow; Pauli-Z
    measurements of all qubits feed ``head``. ``forward`` returns
    log-probabilities of shape ``(batch, num_classes)``.

    Args:
        matrix: per-layer gate-selection matrices indexed
            ``matrix[layer][qubit][gate]`` with gate 0/1/2 = RX/RY/RZ; an entry
            equal to 1 enables that trainable rotation.
        num_qubits: number of wires in the circuit.
        num_classes: number of output classes.
        input_dim: number of input features (default 561).
    """

    class QLayer(tq.QuantumModule):
        """A single layer of the variational quantum circuit.

        ``matrix[i]`` selects trainable RX/RY/RZ rotations on qubit ``i``; the
        layer ends with a ring of CNOTs on wires ``(i, (i + 1) % n_wires)``.
        """

        def __init__(self, matrix, num_qubits, layer_num):
            super().__init__()
            self.n_wires = num_qubits
            self.matrix = matrix

            # (key suffix, gate class) in the column order of a matrix row.
            gate_types = (("rx", tq.RX), ("ry", tq.RY), ("rz", tq.RZ))

            # Keys keep the form "qlayer{L}_qubit{i}_gate_{rx|ry|rz}" because they
            # become state_dict names, which are matched by name across models
            # with different gate selections.
            self.gates = nn.ModuleDict()
            # Wire of each entry of self.gates, in insertion order, so forward()
            # does not have to parse it back out of the key string.
            self._gate_wires = []
            for i in range(self.n_wires):
                key = f"qlayer{layer_num}_qubit{i}_gate_"
                for col, (suffix, gate_cls) in enumerate(gate_types):
                    if matrix[i][col] == 1:
                        self.gates[key + suffix] = gate_cls(has_params=True, trainable=True)
                        self._gate_wires.append(i)

            self.entanglers = nn.ModuleList([tq.CNOT(has_params=False, trainable=False) for _ in range(self.n_wires)])

        def forward(self, qdev: tq.QuantumDevice):
            """Apply the selected rotations, then the CNOT ring, to ``qdev``."""
            for gate, wire in zip(self.gates.values(), self._gate_wires):
                gate(qdev, wires=wire)
            for i, cnot in enumerate(self.entanglers):
                cnot(qdev, wires=[i, (i + 1) % self.n_wires])

    def __init__(self, matrix, num_qubits, num_classes, input_dim=561):
        super().__init__()
        self.n_wires = num_qubits
        self.num_classes = num_classes
        self.num_layers = len(matrix)
        self.input_dim = input_dim

        # Classical pre-processing: project the features onto 2**n_wires
        # amplitudes for the encoder.
        self.pre_fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2 ** self.n_wires)
        )

        self.encoder = tq.AmplitudeEncoder()
        self.q_layers = nn.ModuleList([self.QLayer(matrix[i], num_qubits, i) for i in range(self.num_layers)])
        self.measure = tq.MeasureAll(tq.PauliZ)

        # One measured value per wire feeds the classical head.
        in_dim = self.n_wires
        self.head = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, num_classes)`` log-probabilities."""
        qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=x.shape[0], device=x.device)

        x = self.pre_fc(x)
        # L2-normalize each sample so the encoded amplitudes form a unit-norm state.
        x = F.normalize(x, p=2, dim=1)

        self.encoder(qdev, x)
        for q_layer in self.q_layers:
            q_layer(qdev)
        x = self.measure(qdev)

        x = self.head(x)
        return F.log_softmax(x, dim=1)

def load_har_data(data_path):
    """
    Read the UCI HAR train and test splits from their text files.

    Each split directory ``<data_path>/<split>/`` is expected to contain
    ``X_<split>.txt`` (features), ``y_<split>.txt`` (1-based activity labels)
    and ``subject_<split>.txt`` (subject ids).

    Returns:
        train_data: (X_train, y_train, subject_train)
        test_data: (X_test, y_test, subject_test)
        Labels are shifted to 0-based class ids.
    """
    def _read_split(split):
        split_dir = os.path.join(data_path, split)
        features = np.loadtxt(os.path.join(split_dir, f'X_{split}.txt'))
        # Files store labels starting at 1; subtract 1 for 0-based class ids.
        labels = np.loadtxt(os.path.join(split_dir, f'y_{split}.txt'), dtype=int) - 1
        subjects = np.loadtxt(os.path.join(split_dir, f'subject_{split}.txt'), dtype=int)
        return features, labels, subjects

    train_data = _read_split('train')
    test_data = _read_split('test')
    return train_data, test_data

def get_federated_har_dataflows(data_path, num_clients, samples_per_client, primary_activities_per_client, primary_percentage, seed):
    """
    Creates non-IID HAR data distribution for federated learning.
    Each client gets different primary activities with activity imbalance.

    Features are standardized with statistics fitted on the training split.
    Each client's training set draws ``primary_percentage`` of its samples
    evenly from its primary activities and the rest evenly from the remaining
    activities (integer division may leave a few samples unassigned). Each
    client's test set takes up to 50 samples per primary and 10 per secondary
    activity from the shared test split.

    Args:
        data_path: Path to UCI HAR Dataset directory
        num_clients: Number of federated clients
        samples_per_client: Training samples per client
        primary_activities_per_client: List of primary activities for each client
        primary_percentage: Fraction (0-1) of each client's training samples drawn from primary activities
        seed: Random seed for reproducibility of the client splits

    Returns:
        A list with one ``{'train': DataLoader, 'test': DataLoader}`` dict per client.

    Raises:
        ValueError: if a client has no primary activities.
    """
    num_activities = 6  # HAR activity classes, labels 0-5
    (X_train, y_train, _), (X_test, y_test, _) = load_har_data(data_path)

    # Standardize with training-set statistics only.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    full_train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))
    full_test_dataset = TensorDataset(torch.FloatTensor(X_test), torch.LongTensor(y_test))

    indices_by_class = [[] for _ in range(num_activities)]
    for i, label in enumerate(y_train):
        indices_by_class[label].append(i)

    # The test split is shared by all clients, so index it by class once.
    test_indices_by_class = {}
    for idx, label in enumerate(y_test):
        test_indices_by_class.setdefault(label, []).append(idx)

    # Per-class test quotas: with 3 primary and 3 secondary activities this is
    # at most 3*50 + 3*10 = 180 test samples per client.
    test_samples_per_primary = 50
    test_samples_per_secondary = 10

    # Pinned host memory only speeds up host-to-GPU copies.
    pin_memory = torch.cuda.is_available()

    rng = random.Random(seed)
    client_dataflows = []

    for client_id in range(num_clients):
        print(f"Creating data for Client {client_id + 1}")

        client_primary_activities = primary_activities_per_client[client_id]
        if not client_primary_activities:
            raise ValueError(f"Client {client_id + 1} has no primary activities")
        secondary_activities = [c for c in range(num_activities) if c not in client_primary_activities]

        # The epsilon keeps float products such as 100 * 0.29 (= 28.999999999999996)
        # from truncating one sample short.
        num_primary_samples = int(samples_per_client * primary_percentage + 1e-9)
        num_secondary_samples = samples_per_client - num_primary_samples

        num_samples_per_primary = num_primary_samples // len(client_primary_activities)
        num_samples_per_secondary = num_secondary_samples // len(secondary_activities) if secondary_activities else 0

        client_train_indices = []
        for pa in client_primary_activities:
            pool = indices_by_class[pa]
            client_train_indices.extend(rng.sample(pool, min(num_samples_per_primary, len(pool))))
        if num_samples_per_secondary > 0:
            for sa in secondary_activities:
                pool = indices_by_class[sa]
                client_train_indices.extend(rng.sample(pool, min(num_samples_per_secondary, len(pool))))
        rng.shuffle(client_train_indices)

        train_subset = Subset(full_train_dataset, client_train_indices)

        client_test_indices = []
        for pa in client_primary_activities:
            if pa in test_indices_by_class:
                pool = test_indices_by_class[pa]
                client_test_indices.extend(rng.sample(pool, min(test_samples_per_primary, len(pool))))
        for sa in secondary_activities:
            if sa in test_indices_by_class and test_samples_per_secondary > 0:
                pool = test_indices_by_class[sa]
                client_test_indices.extend(rng.sample(pool, min(test_samples_per_secondary, len(pool))))

        test_subset = Subset(full_test_dataset, client_test_indices)

        train_loader = DataLoader(train_subset, batch_size=32, shuffle=True, num_workers=2, pin_memory=pin_memory)
        test_loader = DataLoader(test_subset, batch_size=32, shuffle=False, num_workers=2, pin_memory=pin_memory)

        client_dataflows.append({'train': train_loader, 'test': test_loader})

        print(f"  Client {client_id + 1}: {len(client_train_indices)} training samples, {len(client_test_indices)} test samples")
        print(f"  Primary activities: {client_primary_activities}")

    return client_dataflows

def train(dataflow, model, device, optimizer):
    """Run one training epoch over ``dataflow["train"]``.

    The model must return log-probabilities; the loss is ``F.nll_loss``.

    Returns:
        (avg_loss, avg_accuracy): sample-weighted mean loss over the epoch and
        training accuracy in percent, both measured on the pre-update outputs
        of each batch. Both are 0.0 if the loader yields no samples.
    """
    model.train()
    total_correct, total_samples, total_loss = 0, 0, 0.0

    for inputs, targets in dataflow["train"]:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        pred = outputs.argmax(dim=1, keepdim=True)
        total_correct += pred.eq(targets.view_as(pred)).sum().item()
        total_samples += batch_size
        # nll_loss is a per-batch mean; weight it by batch size so a short
        # final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size

    if total_samples == 0:
        return 0.0, 0.0
    return total_loss / total_samples, 100. * total_correct / total_samples

def valid_test(dataflow, model, device):
    """Standard validation/testing loop.

    Evaluates ``model`` (in eval mode, without gradients) on
    ``dataflow["test"]``.

    Returns:
        Accuracy in percent over the samples the loader actually yields, or
        0.0 if it yields none.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in dataflow["test"]:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()
            total += inputs.size(0)

    if total == 0:
        return 0.0
    return 100. * correct / total

def evaluate_model(x_candidate, num_qubits, num_layers, epochs, device, dataflow, num_classes):
    """Decode a candidate vector into a gate matrix, then train and test that model.

    Entry ``layer * 3 * num_qubits + 3 * qubit + g`` of ``x_candidate`` is
    rounded to 0/1 to select gate ``g`` (0=RX, 1=RY, 2=RZ) on ``qubit`` in
    ``layer``. The model is trained for ``epochs`` epochs with Adam and cosine
    LR annealing.

    Returns:
        (final_train_acc, test_acc, binary_matrix): last epoch's training
        accuracy (0 when ``epochs`` is 0), test accuracy, both in percent,
        and the decoded ``[layer][qubit][gate]`` matrix.
    """
    gates_per_layer = 3 * num_qubits
    binary_matrix = [
        [
            [round(float(x_candidate[layer * gates_per_layer + 3 * qubit + gate])) for gate in range(3)]
            for qubit in range(num_qubits)
        ]
        for layer in range(num_layers)
    ]

    model = QFCModel(binary_matrix, num_qubits, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    last_train_acc = 0
    for _ in range(epochs):
        _, last_train_acc = train(dataflow, model, device, optimizer)
        scheduler.step()

    return last_train_acc, valid_test(dataflow, model, device), binary_matrix

def bo_step(X, Y, bounds, num_qubits, num_layers, epochs, device, dataflow, num_classes, *, return_train_acc=False):
    """Perform one step of Bayesian Optimization over gate-selection vectors.

    Fits a SingleTaskGP (normalized inputs, standardized outcome) to the
    observations ``(X, Y)``, maximizes Log Expected Improvement within
    ``bounds`` to get one candidate, then trains and evaluates that candidate
    with evaluate_model.

    Args:
        X: (n, d) tensor of previously evaluated candidate vectors.
        Y: (n, 1) tensor of their objective values.
        bounds: (2, d) tensor of lower/upper bounds for the candidate search.
        return_train_acc: if True, also return the candidate's final training
            accuracy from the same training run, so callers need not retrain.

    Returns:
        ``(new_x, new_y, config)``: the (1, d) candidate, a (1, 1) double
        tensor holding its test accuracy in PERCENT, and its decoded gate
        matrix. With ``return_train_acc=True``, the training accuracy
        (percent) is appended as a fourth element.
    """
    gp = SingleTaskGP(X, Y, input_transform=Normalize(d=X.shape[-1]), outcome_transform=Standardize(m=1))
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_mll(mll)
    ei = LogExpectedImprovement(model=gp, best_f=Y.max())
    new_x, _ = optimize_acqf(ei, bounds=bounds, q=1, num_restarts=10, raw_samples=1024)

    train_acc, test_acc, config = evaluate_model(new_x[0], num_qubits, num_layers, epochs, device, dataflow, num_classes)
    new_y = torch.tensor([[test_acc]], dtype=torch.double)

    if return_train_acc:
        return new_x, new_y, config, train_acc
    return new_x, new_y, config


def getXinit(num_qubits, num_layers, extra_sobol=1, seed=None):
    """Generate initial BO points: all-RX, all-RY, all-RZ plus optional Sobol points.

    Returns:
        Double tensor of shape ``(3 + max(extra_sobol, 0), 3 * num_qubits * num_layers)``.
        The first three rows enable only RX, only RY and only RZ on every
        qubit of every layer; any further rows are scrambled Sobol points in
        [0, 1).
    """
    dim = num_qubits * num_layers * 3
    n_slots = num_qubits * num_layers  # one (RX, RY, RZ) triple per qubit per layer

    X_init = torch.tensor(
        [[1.0, 0.0, 0.0] * n_slots, [0.0, 1.0, 0.0] * n_slots, [0.0, 0.0, 1.0] * n_slots],
        dtype=torch.double,
    )
    if extra_sobol > 0:
        sobol = SobolEngine(dimension=dim, scramble=True, seed=seed)
        X_init = torch.cat([X_init, sobol.draw(n=extra_sobol).double()], dim=0)

    return X_init


def client_architecture_search(client_id, dataflow, num_qubits, num_layers, num_classes, device, bo_iters, epochs_per_model):
    """
    Perform Bayesian Optimization architecture search for a specific client.

    Evaluates the three homogeneous architectures from getXinit, then
    ``bo_iters`` BO proposals. Every architecture is trained exactly once, and
    its reported train and test accuracies come from that same run. The GP is
    fit on test accuracies scaled to [0, 1].

    Returns:
        (best_config, bo_results): the gate matrix with the highest test
        accuracy (the first one on ties) and one result dict per evaluation.
    """
    print(f"\n=== Architecture Search for Client {client_id + 1} ===")

    dim = 3 * num_qubits * num_layers
    bounds = torch.stack([torch.zeros(dim), torch.ones(dim)]).double()

    X_init = getXinit(num_qubits, num_layers, extra_sobol=0, seed=None)
    n_init = X_init.shape[0]
    Y_init = torch.zeros(n_init, 1, dtype=torch.double)

    bo_results = []
    configs = []    # decoded gate matrix of every evaluated architecture, in evaluation order
    test_accs = []  # matching test accuracies, in percent

    def record(train_acc, test_acc, config):
        configs.append(config)
        test_accs.append(test_acc)
        bo_results.append({
            'Client': client_id + 1,
            'Evaluation': len(bo_results) + 1,
            'Config': str(config),
            'Train_Acc': train_acc,
            'Test_Acc': test_acc
        })

    print(f"Evaluating {n_init} initial architectures...")
    for i in range(n_init):
        train_acc, test_acc, config = evaluate_model(X_init[i], num_qubits, num_layers, epochs_per_model, device, dataflow, num_classes)
        Y_init[i, 0] = test_acc / 100.0
        record(train_acc, test_acc, config)
        print(f"  Initial {i + 1}: Train={train_acc:.2f}%, Test={test_acc:.2f}%")

    # Full observation history (all evaluated points), fed back to the GP.
    X_hist, Y_hist = X_init.clone(), Y_init.clone()

    print(f"Starting {bo_iters} BO iterations...")
    for i in range(bo_iters):
        # Take the train accuracy from bo_step's own training run instead of
        # retraining the candidate just to report it.
        new_x, new_y, config, train_acc = bo_step(
            X_hist, Y_hist, bounds, num_qubits, num_layers, epochs_per_model, device, dataflow, num_classes,
            return_train_acc=True,
        )
        X_hist = torch.cat([X_hist, new_x])
        Y_hist = torch.cat([Y_hist, new_y / 100.0])  # bo_step reports percent

        test_acc = float(new_y)
        record(train_acc, test_acc, config)
        print(f"  BO {i + 1}: Train={train_acc:.2f}%, Test={test_acc:.2f}%")

    best_idx = max(range(len(test_accs)), key=test_accs.__getitem__)
    best_config = configs[best_idx]
    best_acc = test_accs[best_idx]
    print(f"Client {client_id + 1} best architecture: {best_acc:.2f}% accuracy")

    return best_config, bo_results

def make_selective_copy(global_model, client_config, device):
    """
    Create a local model with client-specific architecture by copying relevant parameters from global model.

    A fresh QFCModel is built from ``client_config`` with the global model's
    qubit and class counts. Every state entry whose name exists in the global
    model with the same shape is overwritten with a copy of the global value.
    All other entries keep their fresh initialization, and the reason (missing
    name vs. shape mismatch) is printed.

    Returns:
        The new local model, on ``device``.
    """
    local_model = QFCModel(client_config, global_model.n_wires, global_model.num_classes).to(device)

    global_state = global_model.state_dict()
    local_state = local_model.state_dict()

    for name in list(local_state):
        local_tensor = local_state[name]
        global_tensor = global_state.get(name)
        if global_tensor is None:
            print(f"Parameter {name} in local model not found in global")
        elif global_tensor.shape != local_tensor.shape:
            print(f"Parameter {name} has shape {tuple(local_tensor.shape)} locally but "
                  f"{tuple(global_tensor.shape)} in global; keeping local initialization")
        else:
            local_state[name] = global_tensor.clone()

    local_model.load_state_dict(local_state)
    return local_model

def aggregate(local_models, global_model):
    """
    Aggregate local models into global model using parameter-specific averaging.
    Only parameters with matching shapes are aggregated.

    Each global state entry becomes the unweighted mean of the local entries
    with the same name and shape (cast back to the global dtype). Entries that
    no local model provides keep their current value. Coverage statistics are
    printed.

    Returns:
        ``global_model``, updated in place.
    """
    from collections import Counter

    global_state = global_model.state_dict()
    temp_state = {k: torch.zeros_like(v) for k, v in global_state.items()}
    count_state = {k: 0 for k in global_state}

    for local_model in local_models:
        for name, tensor in local_model.state_dict().items():
            if name in global_state and tensor.shape == global_state[name].shape:
                temp_state[name] += tensor
                count_state[name] += 1

    total_params = len(global_state)
    updated_params = sum(1 for count in count_state.values() if count > 0)
    # Guard against a model with an empty state_dict.
    coverage = 100 * updated_params / total_params if total_params else 0.0
    print(f"  Parameter update coverage: {updated_params}/{total_params} ({coverage:.1f}%)")
    # Histogram: number of contributing clients -> number of state entries.
    print(f"  Update distribution: {dict(Counter(count_state.values()))}")

    for name, count in count_state.items():
        if count > 0:
            global_state[name] = (temp_state[name] / count).to(global_state[name].dtype)

    global_model.load_state_dict(global_state)
    return global_model

def create_global_architecture(client_configs):
    """Merge per-client gate matrices into a single "union" architecture.

    Args:
        client_configs: gate matrices indexed ``[layer][qubit][gate]``; the
            layer and qubit counts are taken from the first one.

    Returns:
        A matrix of the same shape whose entry is 1 when the sum of that entry
        over all clients is positive, else 0. With 0/1 inputs this keeps a
        gate if any client uses it.

    Raises:
        IndexError: if ``client_configs`` (or its first config) is empty.
    """
    reference = client_configs[0]
    layer_count = len(reference)
    qubit_count = len(reference[0])

    return [
        [
            [int(sum(cfg[layer][qubit][gate] for cfg in client_configs) > 0) for gate in range(3)]
            for qubit in range(qubit_count)
        ]
        for layer in range(layer_count)
    ]

def _train_local_client(local_model, client_dataflow, device, local_epochs, patience=5):
    """Train one client's model in place; return its last-epoch training accuracy (percent)."""
    max_epochs = local_epochs * 2  # clients deliberately train up to twice the nominal local epochs
    optimizer = optim.Adam(local_model.parameters(), lr=1e-3, weight_decay=1e-4)
    # Anneal over the epochs actually run. With T_max=local_epochs the cosine
    # schedule reached zero halfway through and then climbed back to the
    # initial learning rate for the second half.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, max_epochs))

    final_train_acc = 0
    best_acc = 0
    no_improve = 0
    for _ in range(max_epochs):
        _, acc = train(client_dataflow, local_model, device, optimizer)
        final_train_acc = acc

        # Early stopping on training accuracy.
        if acc > best_acc:
            best_acc = acc
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                break

        scheduler.step()

    return final_train_acc


def federated_learning(client_configs, client_dataflows, num_qubits, num_classes, device, global_epochs, local_epochs):
    """Federated training of the union architecture across all clients.

    The global architecture is the gate-wise union of ``client_configs`` (see
    create_global_architecture). Its initial weights are the name/shape-matched
    average of freshly initialized models built from each client's own config.

    Each round:
    - every client trains a copy of the global model locally;
    - the copies are averaged back into the global model;
    - the global model is evaluated on every client's test and training data.

    Returns:
        A list with one metrics dict per global round.
    """
    print(f"\n=== Starting Federated Learning ===")

    global_matrix = create_global_architecture(client_configs)
    print(f"Global architecture created: {global_matrix}")

    global_model = QFCModel(global_matrix, num_qubits, num_classes).to(device)

    temp_models = [QFCModel(client_config, num_qubits, num_classes).to(device) for client_config in client_configs]
    global_model = aggregate(temp_models, global_model)
    print("Global model initialized with averaged client architectures")

    global_results = []

    for global_epoch in range(global_epochs):
        print(f"\n--- Global Epoch {global_epoch + 1}/{global_epochs} ---")

        local_models = []
        client_train_accs = []

        # NOTE: each client trains a full copy of the global (union) architecture;
        # the per-client configs are not used for local training here.
        for client_id, (_, client_dataflow) in enumerate(zip(client_configs, client_dataflows)):
            print(f"Training Client {client_id + 1}...")

            local_model = QFCModel(global_matrix, num_qubits, num_classes).to(device)
            local_model.load_state_dict(global_model.state_dict())

            final_train_acc = _train_local_client(local_model, client_dataflow, device, local_epochs)

            print(f"  Client {client_id + 1} final train accuracy: {final_train_acc:.2f}%")
            local_models.append(local_model)
            client_train_accs.append(final_train_acc)

        global_model = aggregate(local_models, global_model)

        client_train_avg = sum(client_train_accs) / len(client_train_accs)

        global_test_accs = [valid_test(df, global_model, device) for df in client_dataflows]
        global_test_acc = sum(global_test_accs) / len(global_test_accs)

        global_on_client_train_accs = [
            evaluate_global_on_client_train(df, global_model, device) for df in client_dataflows
        ]
        avg_global_on_train = sum(global_on_client_train_accs) / len(global_on_client_train_accs)

        global_results.append({
            'Global_Epoch': global_epoch + 1,
            'Client_Train_Acc_Avg': client_train_avg,  # Average of client local train accuracies
            'Global_on_Client_Train_Avg': avg_global_on_train,  # Global model evaluated on client training data
            'Test_Accuracy': global_test_acc,  # Average of global model on all client test sets
            'Client_Train_Accs': client_train_accs,  # Individual client accuracies
            'Global_on_Client_Train_Accs': global_on_client_train_accs,  # Global model on each client's training data
            'Global_on_Client_Test_Accs': global_test_accs  # Global model on each client's test data
        })

        print(f"Average client train accuracy: {client_train_avg:.2f}%")
        print(f"Global model on client train data: {avg_global_on_train:.2f}%")
        print(f"Global model test accuracy: {global_test_acc:.2f}%")

    return global_results

def evaluate_global_on_client_train(client_dataflow, global_model, device):
    """Evaluate global model on a specific client's training data.

    Runs ``global_model`` in eval mode without gradients over
    ``client_dataflow["train"]``.

    Returns:
        Accuracy in percent over the samples yielded, or 0.0 if none are yielded.
    """
    global_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in client_dataflow["train"]:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = global_model(inputs)
            pred = outputs.argmax(dim=1, keepdim=True)
            correct += pred.eq(targets.view_as(pred)).sum().item()
            total += inputs.size(0)

    if total == 0:
        return 0.0
    return 100. * correct / total

def main():
    """Run per-client architecture search, then federated learning, and save CSV results.

    Phase 1 runs Bayesian-optimization architecture search for each client.
    Phase 2 federates training of the union architecture. Results are written
    to ``--out_dir``:
    - architecture_search_results.csv
    - federated_learning_results.csv
    - client_architectures.csv
    """
    parser = argparse.ArgumentParser(description="Federated Quantum Architecture Search for HAR")
    parser.add_argument("--seed", type=int, default=44)
    parser.add_argument("--num_clients", type=int, default=3, help="Number of federated clients")
    parser.add_argument("--num_qubits", type=int, default=6, help="Number of qubits in quantum circuits")
    parser.add_argument("--num_layers", type=int, default=3, help="Number of layers in quantum circuits")
    parser.add_argument("--num_classes", type=int, default=6, help="Number of activity classes in HAR")
    parser.add_argument("--samples_per_client", type=int, default=800, help="Training samples per client")
    parser.add_argument("--primary_percentage", type=float, default=0.85, help="Percentage of primary activity data per client")
    parser.add_argument("--bo_iters", type=int, default=20, help="BO iterations per client")
    parser.add_argument("--epochs_per_model", type=int, default=15, help="Training epochs per architecture evaluation")
    parser.add_argument("--global_epochs", type=int, default=30, help="Number of federated learning rounds")
    parser.add_argument("--local_epochs", type=int, default=5, help="Local training epochs per round")
    parser.add_argument("--data_path", type=str, default="./data/UCI HAR Dataset", help="Path to UCI HAR Dataset")
    parser.add_argument("--out_dir", type=str, default=None, help="Output directory (auto-generated if not specified)")

    args = parser.parse_args()

    if args.out_dir is None:
        args.out_dir = f"FL_HAR_{args.num_clients}clients_results"
    # Create the output directory up front so an unwritable path fails
    # immediately instead of after the long search and training phases.
    os.makedirs(args.out_dir, exist_ok=True)

    # Seed the global torch/numpy RNGs (model initialization, dropout, and
    # other library draws from them). Activity assignment and data splitting
    # below use their own seeded random.Random instances. CUDA kernels may
    # still be nondeterministic.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")
    print(f"Results will be saved to: {args.out_dir}")

    all_activities = list(range(6))  # HAR has 6 activities: WALKING, WALKING_UPSTAIRS, WALKING_DOWNSTAIRS, SITTING, STANDING, LAYING
    activities_per_client = 3  # Each client gets exactly 3 primary activities

    rng = random.Random(args.seed)
    primary_activities_per_client = [
        sorted(rng.sample(all_activities, activities_per_client)) for _ in range(args.num_clients)
    ]

    print(f"Primary activities distribution (3 activities per client, overlaps allowed):")
    activity_names = ['WALKING', 'WALKING_UPSTAIRS', 'WALKING_DOWNSTAIRS', 'SITTING', 'STANDING', 'LAYING']
    for i, activities in enumerate(primary_activities_per_client):
        activity_labels = [activity_names[act] for act in activities]
        print(f"  Client {i+1}: {activities} ({', '.join(activity_labels)})")

    print(f"Creating federated HAR dataset for {args.num_clients} clients...")
    client_dataflows = get_federated_har_dataflows(
        data_path=args.data_path,
        num_clients=args.num_clients,
        samples_per_client=args.samples_per_client,
        primary_activities_per_client=primary_activities_per_client,
        primary_percentage=args.primary_percentage,
        seed=args.seed
    )

    print(f"\n{'='*50}")
    print("PHASE 1: CLIENT-SPECIFIC ARCHITECTURE SEARCH")
    print(f"{'='*50}")

    client_configs = []
    all_bo_results = []

    for client_id in range(args.num_clients):
        config, bo_results = client_architecture_search(
            client_id=client_id,
            dataflow=client_dataflows[client_id],
            num_qubits=args.num_qubits,
            num_layers=args.num_layers,
            num_classes=args.num_classes,
            device=device,
            bo_iters=args.bo_iters,
            epochs_per_model=args.epochs_per_model
        )
        client_configs.append(config)
        all_bo_results.extend(bo_results)

    print(f"\n{'='*50}")
    print("PHASE 2: FEDERATED LEARNING")
    print(f"{'='*50}")

    global_results = federated_learning(
        client_configs=client_configs,
        client_dataflows=client_dataflows,
        num_qubits=args.num_qubits,
        num_classes=args.num_classes,
        device=device,
        global_epochs=args.global_epochs,
        local_epochs=args.local_epochs
    )

    bo_df = pd.DataFrame(all_bo_results)
    bo_df.to_csv(os.path.join(args.out_dir, "architecture_search_results.csv"), index=False)

    fl_df = pd.DataFrame(global_results)
    fl_df.to_csv(os.path.join(args.out_dir, "federated_learning_results.csv"), index=False)

    config_data = [
        {
            'Client': i + 1,
            'Architecture': str(config),
            'Primary_Activities': str(primary_activities_per_client[i])
        }
        for i, config in enumerate(client_configs)
    ]
    config_df = pd.DataFrame(config_data)
    config_df.to_csv(os.path.join(args.out_dir, "client_architectures.csv"), index=False)

    print(f"\n{'='*50}")
    print("RESULTS SUMMARY")
    print(f"{'='*50}")

    for i, config in enumerate(client_configs):
        print(f"Client {i + 1} final architecture: {config}")

    # With --global_epochs 0 there are no federated rounds to report.
    if global_results:
        final_acc = global_results[-1]['Test_Accuracy']
        print(f"\nFinal federated model accuracy: {final_acc:.2f}%")
    else:
        print("\nNo federated learning rounds were run.")
    print(f"Results saved to: {args.out_dir}/")

if __name__ == "__main__":
    main()
