import copy
import warnings

import torch
import torch.nn.functional as F

from models.resnet import ResNet18

class FedAvgServer:
    """Federated Averaging (FedAvg) server that owns the global model.

    Client models are combined into ``self.global_model`` by taking an
    (optionally sample-count-weighted) average of their floating-point
    ``state_dict`` entries.
    """

    def __init__(self, model_args, device):
        """Build the global model.

        Args:
            model_args: Keyword arguments forwarded to ``ResNet18``.
            device: Device the global model is moved to.
        """
        self.device = device
        # The server holds a global model, which is updated by averaging client models.
        self.global_model = ResNet18(**model_args).to(self.device)

    @staticmethod
    def _aggregation_weights(num_clients, client_sample_counts):
        """Return one mixing weight per client; the weights sum to 1.

        Weights are proportional to ``client_sample_counts`` when it is given,
        has one entry per client, and sums to a positive total. Otherwise
        every client gets ``1 / num_clients``.
        """
        uniform = [1.0 / num_clients] * num_clients
        if client_sample_counts is None:
            return uniform
        if len(client_sample_counts) != num_clients:
            # Almost certainly a caller bug; keep the historical uniform
            # fallback but make it visible instead of silent.
            warnings.warn(
                f"client_sample_counts has {len(client_sample_counts)} entries "
                f"but {num_clients} client models were given; "
                "falling back to uniform weights.",
                stacklevel=3,
            )
            return uniform
        total_samples = sum(client_sample_counts)
        if total_samples <= 0:
            return uniform
        return [count / total_samples for count in client_sample_counts]

    def aggregate_weights(self, client_models, client_sample_counts=None):
        """Overwrite the global model with the average of ``client_models``.

        Args:
            client_models: Sequence of models whose ``state_dict`` keys match
                the global model's. An empty sequence is a no-op.
            client_sample_counts: Optional per-client sample counts aligned
                with ``client_models``. When the lengths match and the total
                is positive, client ``i`` is weighted by
                ``count_i / sum(counts)``. Otherwise all clients are weighted
                equally (a length mismatch additionally emits a warning).

        Floating-point entries (weights, biases, BatchNorm running stats) are
        averaged. Non-floating entries (e.g. BatchNorm's
        ``num_batches_tracked``) are copied from the first client.
        """
        if not client_models:
            return

        alpha_weights = self._aggregation_weights(len(client_models), client_sample_counts)

        # state_dict() builds a fresh dict on every call, so materialize each
        # client's once instead of once per key.
        client_states = [model.state_dict() for model in client_models]
        global_dict = self.global_model.state_dict()

        # Only values of existing keys are replaced; the dict keeps its
        # _metadata, which load_state_dict consults.
        for k in list(global_dict):
            reference = global_dict[k]
            if reference.dtype.is_floating_point:
                acc = torch.zeros_like(reference)
                for state, alpha in zip(client_states, alpha_weights):
                    # Move to the accumulator's device in case a client model
                    # lives on a different device than the global model.
                    acc += alpha * state[k].to(acc.device)
                global_dict[k] = acc
            else:
                global_dict[k] = client_states[0][k].clone()

        # Load the averaged weights into the global model
        self.global_model.load_state_dict(global_dict)

    def get_global_model(self):
        """Return the global model instance (not a copy)."""
        return self.global_model


class FedAvgClient:
    """FedAvg client that trains its own copy of the model on local data."""

    def __init__(self, client_id, model_args, device, lr=1e-3):
        """Create the local model and its Adam optimizer.

        Args:
            client_id: Identifier stored on the client.
            model_args: Keyword arguments forwarded to ``ResNet18``.
            device: Device the local model and batches are moved to.
            lr: Adam learning rate.
        """
        self.client_id = client_id
        self.device = device
        # Each client has its own instance of the model
        self.model = ResNet18(**model_args).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def set_weights(self, global_model_state_dict):
        """Copy the given (global) weights into this client's model.

        ``load_state_dict`` copies every tensor into the model's existing
        parameters and buffers, so the client never aliases the caller's
        tensors. A prior deep copy of the whole state dict would only double
        peak memory.
        """
        # NOTE(review): the optimizer built in __init__ is not reset here, so
        # Adam's moment estimates carry over between rounds — confirm intended.
        self.model.load_state_dict(global_model_state_dict)

    def train_task(self, train_loader, epochs):
        """Train the local model with cross-entropy for ``epochs`` passes.

        Args:
            train_loader: Iterable yielding ``(x, y)`` batch pairs.
            epochs: Number of passes over ``train_loader``.
        """
        self.model.train()
        for _ in range(epochs):
            for x, y in train_loader:
                # Skip batches with only 1 sample to avoid BatchNorm errors
                if x.size(0) < 2:
                    continue

                x, y = x.to(self.device), y.to(self.device)

                # Standard classification loss
                loss = F.cross_entropy(self.model(x), y)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

class FedProxClient(FedAvgClient):

    def __init__(self, client_id, model_args, device, lr=1e-3, mu=0.01):
        super().__init__(client_id, model_args, device, lr)
        self.mu = mu
        self.global_model_params = None

    def set_weights(self, global_model_state_dict):
        super().set_weights(global_model_state_dict)
        # Keep a copy of the global model's parameters for the proximal term
        self.global_model_params = [param.detach().clone() for param in self.model.parameters()]

    def train_task(self, train_loader, epochs):

        self.model.train()
        if self.global_model_params is None:
            raise ValueError("Global model parameters not set. Call set_weights first.")
            
        for epoch in range(epochs):
            for x, y in train_loader:
                # Skip batches with only 1 sample to avoid BatchNorm errors
                if x.size(0) < 2:
                    continue
                    
                x, y = x.to(self.device), y.to(self.device)
                
                logits = self.model(x)
                loss_ce = F.cross_entropy(logits, y)

                # FedProx proximal term
                prox_term = 0.0
                for local_param, global_param in zip(self.model.parameters(), self.global_model_params):
                    prox_term += (local_param - global_param).pow(2).sum()
                
                loss = loss_ce + (self.mu / 2) * prox_term
                
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step() 