import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import argparse
import random
import numpy as np
import os
import pandas as pd

from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.acquisition.analytic import LogExpectedImprovement
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood
from torch.quasirandom import SobolEngine
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
from gpytorch.priors import LogNormalPrior
from botorch.models.transforms import Normalize, Standardize

import torchquantum as tq

class QFCModel(tq.QuantumModule):
    """
    Quantum-classical model for MNIST classification.

    Features a learnable classical pre-processing layer, a variational
    quantum circuit, and a classical post-processing head:
    flatten -> ``pre_fc`` (to ``2 ** num_qubits`` features) -> L2-normalise ->
    amplitude encoding -> ``len(matrix)`` ``QLayer`` blocks -> Pauli-Z
    measurement on every wire -> MLP head -> log-softmax.

    Args:
        matrix: ``matrix[layer][qubit]`` is a 3-element 0/1 mask selecting
            trainable RX, RY and RZ rotations for that qubit in that layer.
        num_qubits: number of wires of the quantum device.
        num_classes: number of output classes.
        in_features: size of the flattened input (defaults to 28 * 28 for MNIST).
    """

    class QLayer(tq.QuantumModule):
        """One variational layer: masked per-qubit rotations followed by a ring of CNOTs."""

        def __init__(self, matrix, num_qubits, layer_num):
            super().__init__()
            self.n_wires = num_qubits
            self.matrix = matrix

            # Gate names double as state_dict keys, and the federated code matches
            # tensors across models by name, so the
            # "qlayer{L}_qubit{Q}_gate_{rx|ry|rz}" format must stay stable.
            self.gates = nn.ModuleDict()
            # Wire each gate acts on, recorded explicitly instead of re-parsing the key.
            self.gate_wires = {}
            gate_types = (("rx", tq.RX), ("ry", tq.RY), ("rz", tq.RZ))
            for i in range(self.n_wires):
                for column, (suffix, gate_cls) in enumerate(gate_types):
                    if matrix[i][column] == 1:
                        key = f"qlayer{layer_num}_qubit{i}_gate_{suffix}"
                        self.gates[key] = gate_cls(has_params=True, trainable=True)
                        self.gate_wires[key] = i

            self.entanglers = nn.ModuleList([tq.CNOT(has_params=False, trainable=False) for _ in range(self.n_wires)])

        def forward(self, qdev: tq.QuantumDevice):
            for key, gate in self.gates.items():
                gate(qdev, wires=self.gate_wires[key])
            # A CNOT ring needs two distinct wires: with a single wire,
            # (i + 1) % n_wires == i and the CNOT's control and target would coincide.
            if self.n_wires < 2:
                return
            for i in range(self.n_wires):
                self.entanglers[i](qdev, wires=[i, (i + 1) % self.n_wires])

    def __init__(self, matrix, num_qubits, num_classes, in_features=28 * 28):
        super().__init__()
        self.n_wires = num_qubits
        self.num_classes = num_classes
        self.num_layers = len(matrix)
        self.pre_fc = nn.Linear(in_features, 2 ** self.n_wires)
        self.encoder = tq.AmplitudeEncoder()
        self.q_layers = nn.ModuleList([self.QLayer(matrix[i], num_qubits, i) for i in range(self.num_layers)])
        self.measure = tq.MeasureAll(tq.PauliZ)
        # The measurement yields one value per wire, hence n_wires input features.
        in_dim = self.n_wires
        self.head = nn.Sequential(
            nn.Linear(in_dim, 64),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        bsz = x.shape[0]
        qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=bsz, device=x.device)
        # reshape rather than view so non-contiguous inputs are accepted too.
        x = x.reshape(bsz, -1)
        # Unit L2 norm so the 2 ** n_wires features form a valid amplitude vector.
        x = F.normalize(self.pre_fc(x), p=2, dim=1)
        self.encoder(qdev, x)
        for q_layer in self.q_layers:
            q_layer(qdev)
        x = self.measure(qdev)
        return F.log_softmax(self.head(x), dim=1)

def get_federated_mnist_dataflows(num_clients, samples_per_client, primary_classes_per_client, primary_percentage, seed):
    """
    Creates non-IID MNIST data distribution for federated learning.

    Each client receives ``int(samples_per_client * primary_percentage)`` training
    samples split evenly across its primary classes. The remainder is split evenly
    across the other MNIST classes. Integer division is used, so the total can fall
    slightly short of ``samples_per_client``.

    Each client's test set holds up to 50 samples per primary class and up to 10
    per secondary class. Samples are drawn independently for every client, so two
    clients may share indices.

    Args:
        num_clients: number of clients to build loaders for.
        samples_per_client: target number of training samples per client.
        primary_classes_per_client: ``primary_classes_per_client[c]`` lists the
            primary labels of client ``c``; must be non-empty.
        primary_percentage: fraction of training samples taken from primary classes.
        seed: seed for the sampling RNG.

    Returns:
        list of ``{'train': DataLoader, 'test': DataLoader}`` dicts, one per client.

    Raises:
        ValueError: if a client has no primary classes.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST normalization values
    ])

    full_train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    full_test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    def _indices_by_label(dataset):
        """Map each label to the ascending list of dataset indices carrying it."""
        groups = {}
        for idx, label in enumerate(torch.as_tensor(dataset.targets).tolist()):
            groups.setdefault(label, []).append(idx)
        return groups

    # Read the stored targets instead of iterating the datasets, which would run
    # the image transform over all 70k samples just to learn their labels.
    # Index order is unchanged, so seeded sampling picks the same samples.
    # The test grouping is built once here rather than once per client.
    indices_by_class = _indices_by_label(full_train_dataset)
    test_indices_by_class = _indices_by_label(full_test_dataset)

    # Up to 50 per primary and 10 per secondary class (300 samples for 5 primary classes).
    test_samples_per_primary = 50
    test_samples_per_secondary = 10

    # Pinned memory only helps host-to-GPU copies.
    pin_memory = torch.cuda.is_available()

    rng = random.Random(seed)
    client_dataflows = []

    for client_id in range(num_clients):
        print(f"Creating data for Client {client_id + 1}")

        client_primary_classes = primary_classes_per_client[client_id]
        if not client_primary_classes:
            raise ValueError(f"Client {client_id + 1} has no primary classes")
        secondary_classes = [c for c in range(10) if c not in client_primary_classes]

        num_primary_samples = int(samples_per_client * primary_percentage)
        num_secondary_samples = samples_per_client - num_primary_samples

        num_samples_per_primary = num_primary_samples // len(client_primary_classes)
        num_samples_per_secondary = num_secondary_samples // len(secondary_classes) if secondary_classes else 0

        client_train_indices = []
        for pc in client_primary_classes:
            pool = indices_by_class[pc]
            client_train_indices.extend(rng.sample(pool, min(num_samples_per_primary, len(pool))))
        if num_samples_per_secondary > 0:
            for sc in secondary_classes:
                pool = indices_by_class[sc]
                client_train_indices.extend(rng.sample(pool, min(num_samples_per_secondary, len(pool))))
        rng.shuffle(client_train_indices)

        train_subset = Subset(full_train_dataset, client_train_indices)

        client_test_indices = []
        for pc in client_primary_classes:
            if pc in test_indices_by_class:
                pool = test_indices_by_class[pc]
                client_test_indices.extend(rng.sample(pool, min(test_samples_per_primary, len(pool))))
        for sc in secondary_classes:
            if sc in test_indices_by_class and test_samples_per_secondary > 0:
                pool = test_indices_by_class[sc]
                client_test_indices.extend(rng.sample(pool, min(test_samples_per_secondary, len(pool))))

        test_subset = Subset(full_test_dataset, client_test_indices)

        train_loader = DataLoader(train_subset, batch_size=64, shuffle=True, num_workers=2, pin_memory=pin_memory)
        test_loader = DataLoader(test_subset, batch_size=64, shuffle=False, num_workers=2, pin_memory=pin_memory)

        client_dataflows.append({'train': train_loader, 'test': test_loader})

        print(f"  Client {client_id + 1}: {len(client_train_indices)} training samples, {len(client_test_indices)} test samples")
        print(f"  Primary classes: {client_primary_classes}")

    return client_dataflows

def train(dataflow, model, device, optimizer):
    """
    Standard training loop for one epoch over ``dataflow["train"]``.

    Args:
        dataflow: mapping whose ``"train"`` entry yields ``(inputs, targets)`` batches.
        model: module returning log-probabilities (it is trained with ``F.nll_loss``).
        device: device the batches are moved to.
        optimizer: optimizer stepped once per batch.

    Returns:
        ``(avg_loss, accuracy_percent)``. ``avg_loss`` is the per-sample mean NLL over
        the epoch and accuracy is in [0, 100]. Both are 0.0 if no samples were seen.
    """
    model.train()
    total_correct, total_samples, total_loss = 0, 0, 0.0

    for inputs, targets in dataflow["train"]:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.nll_loss(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = inputs.size(0)
        pred = outputs.detach().argmax(dim=1)
        total_correct += (pred == targets).sum().item()
        total_samples += batch_size
        # nll_loss returns the batch mean; weight it by batch size so a short
        # final batch counts no more than its samples deserve.
        total_loss += loss.item() * batch_size

    if total_samples == 0:
        return 0.0, 0.0
    avg_loss = total_loss / total_samples
    avg_accuracy = 100. * total_correct / total_samples
    return avg_loss, avg_accuracy

def valid_test(dataflow, model, device):
    """
    Standard validation/testing loop over ``dataflow["test"]``.

    Runs the model in eval mode without gradients and compares the argmax of its
    outputs with the targets. The denominator is the number of samples actually
    iterated, so any iterable of ``(inputs, targets)`` batches works.

    Returns:
        Accuracy in percent, or 0.0 if the loader yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in dataflow["test"]:
            inputs, targets = inputs.to(device), targets.to(device)
            pred = model(inputs).argmax(dim=1)
            correct += (pred == targets).sum().item()
            total += inputs.size(0)

    if total == 0:
        return 0.0
    return 100. * correct / total

def evaluate_model(x_candidate, num_qubits, num_layers, epochs, device, dataflow, num_classes):
    """
    Decode a design vector, then build, train and test the corresponding model.

    ``x_candidate`` is a flat sequence of length ``3 * num_qubits * num_layers``,
    ordered layer-major, then qubit, then gate (RX, RY, RZ). Each entry is rounded
    to an int flag.

    Returns:
        ``(final_train_acc, test_acc, binary_matrix)``: the training accuracy of the
        last epoch (0 if ``epochs`` is 0), the accuracy returned by ``valid_test``, and
        the decoded ``[layer][qubit][gate]`` mask.
    """
    per_layer = 3 * num_qubits
    binary_matrix = [
        [
            [round(float(x_candidate[layer * per_layer + 3 * qubit + gate])) for gate in range(3)]
            for qubit in range(num_qubits)
        ]
        for layer in range(num_layers)
    ]

    model = QFCModel(binary_matrix, num_qubits, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    last_train_acc = 0
    for _ in range(epochs):
        _, last_train_acc = train(dataflow, model, device, optimizer)
        scheduler.step()

    return last_train_acc, valid_test(dataflow, model, device), binary_matrix

def bo_step(X, Y, bounds, num_qubits, num_layers, epochs, device, dataflow, num_classes):
    """
    Propose one architecture with GP-based Bayesian Optimization and evaluate it.

    A ``SingleTaskGP`` (inputs normalised, outcomes standardised) is fitted to
    ``(X, Y)``. Log Expected Improvement over ``Y.max()`` is then maximised within
    ``bounds``, and the single proposed point is trained and tested through
    ``evaluate_model``.

    Returns:
        ``(new_x, new_y, config)``: the proposed point as returned by ``optimize_acqf``
        (q=1), its test accuracy in percent as a ``(1, 1)`` double tensor, and the
        decoded gate mask. The training accuracy from ``evaluate_model`` is discarded.
    """
    surrogate = SingleTaskGP(
        X,
        Y,
        input_transform=Normalize(d=X.shape[-1]),
        outcome_transform=Standardize(m=1),
    )
    fit_gpytorch_mll(ExactMarginalLogLikelihood(surrogate.likelihood, surrogate))

    acquisition = LogExpectedImprovement(model=surrogate, best_f=Y.max())
    candidates, _acq_value = optimize_acqf(
        acquisition, bounds=bounds, q=1, num_restarts=10, raw_samples=1024
    )

    _train_acc, test_acc, config = evaluate_model(
        candidates[0], num_qubits, num_layers, epochs, device, dataflow, num_classes
    )
    return candidates, torch.tensor([[test_acc]], dtype=torch.double), config

def getXinit(num_qubits, num_layers, extra_sobol=1, seed=None):
    """
    Build the initial Bayesian-optimisation design.

    The first three rows are the uniform architectures that place only RX, only RY,
    or only RZ on every qubit of every layer (one-hot within each 3-gate group).
    When ``extra_sobol > 0``, that many scrambled Sobol points in the unit cube are
    appended (seeded by ``seed``).

    Returns:
        double tensor of shape ``(3 + max(extra_sobol, 0), 3 * num_qubits * num_layers)``.
    """
    slots = num_qubits * num_layers
    # Row k of the 3x3 identity is the one-hot mask for gate k (RX, RY, RZ);
    # tiling it across every (layer, qubit) slot yields the three uniform designs.
    X_init = torch.eye(3, dtype=torch.double).repeat(1, slots)
    if extra_sobol <= 0:
        return X_init

    engine = SobolEngine(dimension=3 * slots, scramble=True, seed=seed)
    return torch.cat([X_init, engine.draw(n=extra_sobol).double()], dim=0)

def client_architecture_search(client_id, dataflow, num_qubits, num_layers, num_classes, device, bo_iters, epochs_per_model):
    """
    Perform Bayesian Optimization architecture search for a specific client.

    The search first evaluates the three uniform designs from ``getXinit``
    (all-RX / all-RY / all-RZ). It then runs ``bo_iters`` rounds, each of which:
    fits a GP to (design, test-accuracy fraction) pairs, maximises LogEI over the
    unit cube, and trains and evaluates the proposed design exactly once.
    Design vectors are continuous in ``[0, 1]^d`` and are rounded to 0/1 gate flags
    when decoded.

    Returns:
        ``(best_config, bo_results)``. ``best_config`` is the decoded
        ``[layer][qubit][gate]`` mask of the design with the highest test accuracy.
        ``bo_results`` holds one result dict per evaluation.
    """
    print(f"\n=== Architecture Search for Client {client_id + 1} ===")

    dim = 3 * num_qubits * num_layers
    bounds = torch.stack([torch.zeros(dim), torch.ones(dim)]).double()
    bo_results = []

    def _propose_candidate(X, Y):
        """Fit a GP to (X, Y) and return the LogEI-maximising point (q=1)."""
        gp = SingleTaskGP(X, Y, input_transform=Normalize(d=X.shape[-1]), outcome_transform=Standardize(m=1))
        fit_gpytorch_mll(ExactMarginalLogLikelihood(gp.likelihood, gp))
        acqf = LogExpectedImprovement(model=gp, best_f=Y.max())
        candidate, _ = optimize_acqf(acqf, bounds=bounds, q=1, num_restarts=10, raw_samples=1024)
        return candidate

    def _decode(vector):
        """Round a flat layer/qubit/gate-ordered design vector into a nested 0/1 mask."""
        stride = 3 * num_qubits
        return [
            [[round(float(vector[layer * stride + 3 * q + g])) for g in range(3)] for q in range(num_qubits)]
            for layer in range(num_layers)
        ]

    def _record(train_acc, test_acc, config):
        """Append one evaluation row for this client."""
        bo_results.append({
            'Client': client_id + 1,
            'Evaluation': len(bo_results) + 1,
            'Config': str(config),
            'Train_Acc': train_acc,
            'Test_Acc': test_acc
        })

    X_init = getXinit(num_qubits, num_layers, extra_sobol=0, seed=None)
    n_init = X_init.shape[0]
    # The GP is fitted on accuracies as fractions in [0, 1]; results are logged in percent.
    Y_init = torch.zeros(n_init, 1, dtype=torch.double)

    print(f"Evaluating {n_init} initial architectures...")
    for i in range(n_init):
        train_acc, test_acc, config = evaluate_model(X_init[i], num_qubits, num_layers, epochs_per_model, device, dataflow, num_classes)
        Y_init[i, 0] = test_acc / 100.0
        _record(train_acc, test_acc, config)
        print(f"  Initial {i + 1}: Train={train_acc:.2f}%, Test={test_acc:.2f}%")

    X_best, Y_best = X_init.clone(), Y_init.clone()

    print(f"Starting {bo_iters} BO iterations...")
    for i in range(bo_iters):
        new_x = _propose_candidate(X_best, Y_best)
        # Train/evaluate the proposal exactly once. The reported train and test
        # accuracies then come from the same model, and no extra training run is spent.
        train_acc, test_acc, config = evaluate_model(new_x[0], num_qubits, num_layers, epochs_per_model, device, dataflow, num_classes)
        X_best = torch.cat([X_best, new_x])
        Y_best = torch.cat([Y_best, torch.tensor([[test_acc / 100.0]], dtype=torch.double)])

        _record(train_acc, test_acc, config)
        print(f"  BO {i + 1}: Train={train_acc:.2f}%, Test={test_acc:.2f}%")

    best_idx = int(torch.argmax(Y_best))
    best_config = _decode(X_best[best_idx])

    best_acc = float(Y_best[best_idx] * 100.0)
    print(f"Client {client_id + 1} best architecture: {best_acc:.2f}% accuracy")

    return best_config, bo_results


def make_selective_copy(global_model, client_config, device):
    """
    Create a local model with client-specific architecture by copying relevant parameters from global model.

    A fresh ``QFCModel`` is built from ``client_config`` with the global model's wire
    and class counts. Every local state_dict entry whose name and shape both match a
    global entry is overwritten with a copy of the global tensor. Other entries keep
    their fresh initialisation, and a message says whether the name was missing or
    the shape differed.

    Returns:
        The new local model on ``device``.
    """
    local_model = QFCModel(client_config, global_model.n_wires, global_model.num_classes).to(device)

    global_state = global_model.state_dict()
    local_state = local_model.state_dict()

    for name, local_tensor in local_state.items():
        global_tensor = global_state.get(name)
        if global_tensor is None:
            print(f"Parameter {name} in local model not found in global")
        elif global_tensor.shape != local_tensor.shape:
            print(f"Parameter {name} shape mismatch: local {tuple(local_tensor.shape)} vs global {tuple(global_tensor.shape)}; keeping local initialisation")
        else:
            # Only existing keys are reassigned, so mutating during iteration is safe.
            local_state[name] = global_tensor.clone()

    local_model.load_state_dict(local_state)
    return local_model

def aggregate(local_models, global_model):
    """
    Aggregate local models into global model using parameter-specific averaging.

    For every global state_dict entry, the new value is the unweighted mean of the
    same-named, same-shaped tensors across ``local_models``. Entries that no local
    model supplies with a matching shape keep their current global value. Means are
    cast back to the global tensor's dtype. The function also prints the update
    coverage and how many models contributed to each entry.

    Returns:
        ``global_model`` (updated in place).
    """
    global_state = global_model.state_dict()
    sums = {name: torch.zeros_like(tensor) for name, tensor in global_state.items()}
    counts = dict.fromkeys(global_state, 0)

    for local_model in local_models:
        for name, tensor in local_model.state_dict().items():
            target = global_state.get(name)
            if target is not None and tensor.shape == target.shape:
                sums[name] += tensor
                counts[name] += 1

    total_params = len(global_state)
    updated_params = sum(1 for count in counts.values() if count > 0)
    # Guard the percentage: a module without parameters/buffers has an empty state_dict.
    coverage = 100 * updated_params / total_params if total_params else 0.0
    print(f"  Parameter update coverage: {updated_params}/{total_params} ({coverage:.1f}%)")

    update_counts = {}
    for count in counts.values():
        update_counts[count] = update_counts.get(count, 0) + 1
    print(f"  Update distribution: {update_counts}")

    for name, count in counts.items():
        if count > 0:
            global_state[name] = (sums[name] / count).to(global_state[name].dtype)

    global_model.load_state_dict(global_state)
    return global_model

def create_global_architecture(client_configs):
    """
    Merge per-client gate masks into one global mask (element-wise union).

    ``client_configs[c][layer][qubit][gate]`` is a 0/1 flag. The result uses the same
    ``[layer][qubit][gate]`` layout, with dimensions taken from the first config.
    Each entry is 1 where the flags summed over all clients are positive, else 0.

    Raises:
        ValueError: if ``client_configs`` is empty.
    """
    if not client_configs:
        raise ValueError("client_configs must contain at least one client configuration")

    num_layers = len(client_configs[0])
    num_qubits = len(client_configs[0][0]) if num_layers else 0

    return [
        [
            [
                1 if sum(cfg[layer][qubit][gate] for cfg in client_configs) > 0 else 0
                for gate in range(3)
            ]
            for qubit in range(num_qubits)
        ]
        for layer in range(num_layers)
    ]

def _train_client_locally(local_model, client_dataflow, device, local_epochs, patience=5):
    """
    Train one client's model copy for up to ``2 * local_epochs`` epochs.

    Uses Adam (lr 1e-3, weight decay 1e-4) with a cosine schedule that spans the
    whole epoch budget. Training stops early once the training accuracy has not
    improved for ``patience`` consecutive epochs.

    Returns:
        Training accuracy (percent) of the last epoch run, or 0 if no epoch ran.
    """
    optimizer = optim.Adam(local_model.parameters(), lr=1e-3, weight_decay=1e-4)
    max_epochs = local_epochs * 2  # Double the epochs
    # T_max must equal the number of epochs actually scheduled. With T_max=local_epochs,
    # the cosine curve bottomed out halfway through the local run and then rose again.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

    final_train_acc = 0
    best_acc = 0
    no_improve = 0
    for _ in range(max_epochs):
        _, acc = train(client_dataflow, local_model, device, optimizer)
        final_train_acc = acc

        if acc > best_acc:
            best_acc = acc
            no_improve = 0
        else:
            no_improve += 1
            if no_improve >= patience:
                break

        scheduler.step()

    return final_train_acc


def federated_learning(client_configs, client_dataflows, num_qubits, num_classes, device, global_epochs, local_epochs):
    """
    Run federated training over the union of all client architectures.

    The global model uses the element-wise union of ``client_configs``. Its weights
    start as the average of freshly initialised per-client models. Each round:
    - every client trains a copy of the global model locally;
    - the copies are averaged back into the global model;
    - the global model is evaluated on every client's test and training data.

    Returns:
        list with one metrics dict per global round.
    """
    print(f"\n=== Starting Federated Learning ===")

    global_matrix = create_global_architecture(client_configs)
    print(f"Global architecture created: {global_matrix}")

    global_model = QFCModel(global_matrix, num_qubits, num_classes).to(device)

    # Initialise global weights as the average of models built with each client's
    # own architecture (only name/shape-matching entries are averaged).
    temp_models = [QFCModel(client_config, num_qubits, num_classes).to(device) for client_config in client_configs]
    global_model = aggregate(temp_models, global_model)
    print("Global model initialized with averaged client architectures")

    global_results = []

    for global_epoch in range(global_epochs):
        print(f"\n--- Global Epoch {global_epoch + 1}/{global_epochs} ---")

        local_models = []
        client_train_accs = []

        # NOTE(review): every client trains the full union architecture; its own
        # client_config is not used here (make_selective_copy builds per-client
        # sub-architectures) — confirm this is intended.
        for client_id, (_client_config, client_dataflow) in enumerate(zip(client_configs, client_dataflows)):
            print(f"Training Client {client_id + 1}...")

            local_model = QFCModel(global_matrix, num_qubits, num_classes).to(device)
            local_model.load_state_dict(global_model.state_dict())

            final_train_acc = _train_client_locally(local_model, client_dataflow, device, local_epochs)

            print(f"  Client {client_id + 1} final train accuracy: {final_train_acc:.2f}%")
            local_models.append(local_model)
            client_train_accs.append(final_train_acc)

        global_model = aggregate(local_models, global_model)

        global_train_acc = sum(client_train_accs) / len(client_train_accs)

        global_test_accs = [valid_test(client_dataflow, global_model, device) for client_dataflow in client_dataflows]
        global_test_acc = sum(global_test_accs) / len(global_test_accs)

        global_on_client_train_accs = [
            evaluate_global_on_client_train(client_dataflow, global_model, device)
            for client_dataflow in client_dataflows
        ]
        avg_global_on_train = sum(global_on_client_train_accs) / len(global_on_client_train_accs)

        global_results.append({
            'Global_Epoch': global_epoch + 1,
            'Client_Train_Acc_Avg': global_train_acc,  # Average of client local train accuracies
            'Global_on_Client_Train_Avg': avg_global_on_train,  # Global model evaluated on client training data
            'Test_Accuracy': global_test_acc,  # Average of global model on all client test sets
            'Client_Train_Accs': client_train_accs,  # Individual client accuracies
            'Global_on_Client_Train_Accs': global_on_client_train_accs,  # Global model on each client's training data
            'Global_on_Client_Test_Accs': global_test_accs  # Global model on each client's test data
        })

        print(f"Average client train accuracy: {global_train_acc:.2f}%")
        print(f"Global model on client train data: {avg_global_on_train:.2f}%")
        print(f"Global model test accuracy: {global_test_acc:.2f}%")

    return global_results

def evaluate_global_on_client_train(client_dataflow, global_model, device):
    """
    Evaluate global model on a specific client's training data.

    Runs ``global_model`` in eval mode without gradients over
    ``client_dataflow["train"]`` and compares the argmax of its outputs with the
    targets. The model is left in eval mode.

    Returns:
        Accuracy in percent, or 0.0 if the loader yields no samples.
    """
    global_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in client_dataflow["train"]:
            inputs, targets = inputs.to(device), targets.to(device)
            pred = global_model(inputs).argmax(dim=1)
            correct += (pred == targets).sum().item()
            total += inputs.size(0)

    if total == 0:
        return 0.0
    return 100. * correct / total

def main():
    """
    Command-line entry point: non-IID data split, per-client BO search, federated training.

    Writes architecture_search_results.csv, federated_learning_results.csv and
    client_architectures.csv to ``--out_dir``.
    """
    parser = argparse.ArgumentParser(description="Federated Quantum Architecture Search for MNIST")
    parser.add_argument("--seed", type=int, default=44)
    parser.add_argument("--num_clients", type=int, default=3, help="Number of federated clients")
    parser.add_argument("--num_qubits", type=int, default=6, help="Number of qubits in quantum circuits")
    parser.add_argument("--num_layers", type=int, default=3, help="Number of layers in quantum circuits")
    parser.add_argument("--num_classes", type=int, default=10)
    parser.add_argument("--samples_per_client", type=int, default=1200, help="Training samples per client")
    parser.add_argument("--primary_percentage", type=float, default=0.85, help="Percentage of primary class data per client")
    parser.add_argument("--bo_iters", type=int, default=25, help="BO iterations per client")
    parser.add_argument("--epochs_per_model", type=int, default=15, help="Training epochs per architecture evaluation")
    parser.add_argument("--global_epochs", type=int, default=50, help="Number of federated learning rounds")
    parser.add_argument("--local_epochs", type=int, default=5, help="Local training epochs per round")
    parser.add_argument("--out_dir", type=str, default=None, help="Output directory (auto-generated if not specified)")

    args = parser.parse_args()

    if args.out_dir is None:
        args.out_dir = f"FL_MNIST_{args.num_clients}clients_results"

    # --seed already drives the class assignment and data split below. Seed the
    # global RNGs too so that weight initialisation, dropout and DataLoader
    # shuffling are reproducible as well.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")
    print(f"Results will be saved to: {args.out_dir}")
    # Create the output directory up front so an unusable path fails before any training.
    os.makedirs(args.out_dir, exist_ok=True)

    all_classes = list(range(10))  # MNIST has 10 classes
    classes_per_client = 5  # Each client gets exactly 5 primary classes

    rng = random.Random(args.seed)
    primary_classes_per_client = [
        sorted(rng.sample(all_classes, classes_per_client)) for _ in range(args.num_clients)
    ]

    print(f"Primary classes distribution (5 classes per client, overlaps allowed):")
    for i, classes in enumerate(primary_classes_per_client):
        print(f"  Client {i+1}: {classes}")

    print(f"Creating federated dataset for {args.num_clients} clients...")
    client_dataflows = get_federated_mnist_dataflows(
        num_clients=args.num_clients,
        samples_per_client=args.samples_per_client,
        primary_classes_per_client=primary_classes_per_client,
        primary_percentage=args.primary_percentage,
        seed=args.seed
    )

    print(f"\n{'='*50}")
    print("PHASE 1: CLIENT-SPECIFIC ARCHITECTURE SEARCH")
    print(f"{'='*50}")

    client_configs = []
    all_bo_results = []

    for client_id in range(args.num_clients):
        config, bo_results = client_architecture_search(
            client_id=client_id,
            dataflow=client_dataflows[client_id],
            num_qubits=args.num_qubits,
            num_layers=args.num_layers,
            num_classes=args.num_classes,
            device=device,
            bo_iters=args.bo_iters,
            epochs_per_model=args.epochs_per_model
        )
        client_configs.append(config)
        all_bo_results.extend(bo_results)

    print(f"\n{'='*50}")
    print("PHASE 2: FEDERATED LEARNING")
    print(f"{'='*50}")

    global_results = federated_learning(
        client_configs=client_configs,
        client_dataflows=client_dataflows,
        num_qubits=args.num_qubits,
        num_classes=args.num_classes,
        device=device,
        global_epochs=args.global_epochs,
        local_epochs=args.local_epochs
    )

    bo_df = pd.DataFrame(all_bo_results)
    bo_df.to_csv(os.path.join(args.out_dir, "architecture_search_results.csv"), index=False)

    fl_df = pd.DataFrame(global_results)
    fl_df.to_csv(os.path.join(args.out_dir, "federated_learning_results.csv"), index=False)

    config_data = [
        {
            'Client': i + 1,
            'Architecture': str(config),
            'Primary_Classes': str(primary_classes_per_client[i])
        }
        for i, config in enumerate(client_configs)
    ]
    config_df = pd.DataFrame(config_data)
    config_df.to_csv(os.path.join(args.out_dir, "client_architectures.csv"), index=False)

    print(f"\n{'='*50}")
    print("RESULTS SUMMARY")
    print(f"{'='*50}")

    for i, config in enumerate(client_configs):
        print(f"Client {i + 1} final architecture: {config}")

    # global_results is empty when --global_epochs is 0.
    if global_results:
        final_acc = global_results[-1]['Test_Accuracy']
        print(f"\nFinal federated model accuracy: {final_acc:.2f}%")
    else:
        print("\nNo federated learning rounds were run; no final accuracy to report.")
    print(f"Results saved to: {args.out_dir}/")

# Run the full pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
