import torch.nn.functional as F
import torch


def aggregate_models(global_model, client_models, client_indices, train_loaders):
    """
    Aggregate client models into ``global_model`` using FedAvg weighted averaging.

    The weight for each client is proportional to the number of samples in its dataset.
    This follows the FedAvg algorithm from the paper:
    w_{t+1} = sum_{k in S_t} (n_k / n) * w_{t+1}^k
    where n_k is the number of samples for client k and n is the total number of samples.

    Args:
        global_model: Module whose state dict is replaced by the weighted average.
        client_models: Indexable collection of client modules; entry ``idx`` is
            read for every ``idx`` in ``client_indices``.
        client_indices: Indices of the clients taking part in this aggregation.
        train_loaders: Indexable collection of loaders; ``len(loader.dataset)``
            gives the sample count n_k of the corresponding client.

    Returns:
        None. ``global_model`` is updated in place via ``load_state_dict``.

    Notes:
        If no client is selected, or the selected clients hold zero samples in
        total, the average is undefined and ``global_model`` is left untouched
        (rather than being overwritten with zeros or raising ZeroDivisionError).
    """
    global_dict = global_model.state_dict()

    n_samples = [len(train_loaders[idx].dataset) for idx in client_indices]
    total_samples = sum(n_samples)

    # Nothing to average: keep the current global weights as they are.
    if total_samples == 0:
        return

    # Accumulate in float32 regardless of the stored dtype so that half-precision
    # and integer entries can be averaged with fractional weights.
    aggregated_dict = {
        key: torch.zeros_like(tensor, dtype=torch.float32)
        for key, tensor in global_dict.items()
    }

    for idx, n_k in zip(client_indices, n_samples):
        client_dict = client_models[idx].state_dict()
        weight = n_k / total_samples

        for key, accumulator in aggregated_dict.items():
            # Client models may live on another device than the global model.
            contribution = client_dict[key].to(
                device=accumulator.device, dtype=torch.float32
            )
            accumulator += contribution * weight

    for key, averaged in aggregated_dict.items():
        original = global_dict[key]
        if not (original.is_floating_point() or original.is_complex()):
            # Integer entries (e.g. counters such as BatchNorm's
            # num_batches_tracked) must be rounded: a plain cast truncates, so
            # a float average of 9.999 would otherwise become 9.
            averaged = averaged.round()
        global_dict[key] = averaged.to(dtype=original.dtype)

    global_model.load_state_dict(global_dict)


def evaluate_model(model, data_loader, device):
    """
    Compute the top-1 classification accuracy of ``model`` in percent.

    The model is switched to evaluation mode (and left in it) and gradients are
    disabled while iterating over the loader.

    Args:
        model: Module mapping a batch of inputs to per-class scores; the
            prediction is the argmax over dimension 1 of its output.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Accuracy as a percentage in [0, 100]. Returns 0.0 when the loader
        yields no samples instead of raising ZeroDivisionError.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    # An empty loader has no defined accuracy; report 0.0 rather than crash.
    if total == 0:
        return 0.0
    return 100 * correct / total


def client_train_with_temp(
    model, train_loader, optimizer, epochs, device, temperature=1.0, print_flag=True
):
    """
    Train a client model on local data for FedChill.

    Each batch is optimised with cross-entropy on temperature-scaled logits,
    i.e. ``cross_entropy(outputs / temperature, targets)``. With the default
    ``temperature=1.0`` this is plain cross-entropy.

    Args:
        model: Module to train; it is put into training mode and updated in place.
        train_loader: Iterable yielding ``(inputs, targets)`` batches. It is
            iterated once per epoch.
        optimizer: Optimizer stepping the parameters of ``model``.
        epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to before the forward pass.
        temperature: Positive softmax temperature dividing the logits.
        print_flag: When true, print the mean batch loss after every epoch.

    Returns:
        None.

    Raises:
        ValueError: If ``temperature`` is not strictly positive (zero would
            yield infinite/NaN logits, a negative value would invert them).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)

            # Dividing by 1.0 is exact, so one expression covers both the
            # scaled and the unscaled case.
            loss = F.cross_entropy(outputs / temperature, targets)

            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        if print_flag:
            # Count batches ourselves: works for loaders without __len__ and
            # avoids dividing by zero when the loader is empty.
            mean_loss = running_loss / num_batches if num_batches else 0.0
            print(f"Epoch [{epoch+1}/{epochs}], Loss: {mean_loss:.4f}")
