from Optimizer import * 
import math
import torch

def trimmed_mean(tensor, trim_ratio=0.1):
    """
    Compute the coordinate-wise trimmed mean along dimension 0.

    With ``n = tensor.size(0)``, the ``int(n * trim_ratio)`` smallest and the
    same number of largest values along dim 0 are discarded independently for
    every coordinate, and the remaining values are averaged.

    Args:
        tensor: Input tensor whose first dimension indexes the values to be
            aggregated (e.g. parameters stacked with ``torch.stack(..., dim=0)``).
        trim_ratio: Proportion to trim from each end, default is 0.1. The
            number trimmed per side is capped at ``(n - 1) // 2`` so that at
            least one value always survives; for ``trim_ratio < 0.5`` the cap
            never applies.

    Returns:
        Tensor of shape ``tensor.shape[1:]`` holding the trimmed mean.
    """
    n = tensor.size(0)
    # Without the cap, trim_ratio >= 0.5 would trim every value, the slice
    # below would be empty and .mean() would return NaN.
    trim_count = min(int(n * trim_ratio), (n - 1) // 2)
    if trim_count > 0:
        sorted_tensor, _ = torch.sort(tensor, dim=0)
        trimmed_tensor = sorted_tensor[trim_count:n - trim_count]
    else:
        trimmed_tensor = tensor
    return trimmed_tensor.mean(dim=0)


class RobustOptimizer(DFLOptimizerInit):
    """
    Decentralized optimizer that aggregates neighbor parameters with a robust
    coordinate-wise rule instead of a plain average.

    Each round (``_update``) every worker m first replaces its parameters by
    the aggregate of the parameters of the workers listed in
    ``neighbors[m]`` and then takes one local step with ``fit_onestep``.

    Extra args (everything else is forwarded to ``DFLOptimizerInit``):
        agg_method: 'median' (coordinate-wise median), 'trimmed_mean'
            (coordinate-wise trimmed mean) or 'mean' (coordinate-wise average).
        byz_ratio: Fraction trimmed from each end by 'trimmed_mean'
            (capped at 0.49).
    """

    def __init__(self, neighbors, lr_constant,
                 model_type='lenet5',
                 n_classes=10, n_workers=10,
                 epochs=10, device='cpu', random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True, agg_method='median', byz_ratio=0.1):
        super().__init__(
            neighbors=neighbors,
            lr_constant=lr_constant,  # same learning rate for all
            model_type=model_type,
            n_classes=n_classes, n_workers=n_workers,
            epochs=epochs, device=device,
            random_state=random_state, input_dim=input_dim,
            pretrained=pretrained, custom_init=custom_init
        )
        self.agg_method = agg_method
        self.byz_ratio = byz_ratio  # per-side trim ratio for 'trimmed_mean'

    def _trimmed_mean(self, tensor, n_neighbors=1):
        """
        Trimmed mean along dim 0 with trim ratio ``min(byz_ratio, 0.49)``.

        ``n_neighbors`` is accepted for call-site compatibility but unused.
        """
        # Capping below 0.5 guarantees at least one value survives trimming.
        trim_ratio = min(self.byz_ratio, 0.49)
        return trimmed_mean(tensor, trim_ratio=trim_ratio)

    def _aggregate_robust(self, m=0):
        """
        Robustly aggregate the parameters of worker ``m``'s neighbors.

        Reads ``self.param_lists`` (indexed by worker id; each entry is that
        worker's list of parameter tensors, as set by ``_update``) for every
        id in ``self.neighbors[m]`` and reduces each parameter position across
        those neighbors with ``self.agg_method``.

        Returns:
            List with one aggregated tensor per parameter position.

        Raises:
            ValueError: if ``self.agg_method`` is not 'median',
                'trimmed_mean' or 'mean'.
        """
        param_lists = [self.param_lists[i] for i in self.neighbors[m]]
        n_neighbors = len(param_lists)
        if self.agg_method == 'median':
            # torch.median returns (values, indices); with an even number of
            # neighbors it returns the lower of the two middle values.
            def combine(stacked):
                return torch.median(stacked, dim=0)[0]
        elif self.agg_method == 'trimmed_mean':
            def combine(stacked):
                return self._trimmed_mean(stacked, n_neighbors=n_neighbors)
        elif self.agg_method == 'mean':
            def combine(stacked):
                return torch.mean(stacked, dim=0)
        else:
            raise ValueError(f"Unknown aggregation method: {self.agg_method}")
        # zip(*param_lists) groups the same parameter tensor of every neighbor.
        return [combine(torch.stack(params, dim=0)) for params in zip(*param_lists)]

    def _update(self, Xs, ys):
        """
        Perform one DFL update step for all workers (first aggregate, then fit_onestep).

        All workers' parameters are read once into ``self.param_lists`` before
        the loop, and every worker aggregates from that list.
        NOTE(review): this is a true per-round snapshot only if
        ``get_parameters()`` returns copies -- confirm in DFLOptimizerInit.

        Returns:
            Sum of the per-worker losses divided by ``self.n_workers``.
        """
        loss_vals = []
        self.param_lists = self.get_parameters()
        for m, (Xm, ym) in enumerate(zip(Xs, ys)):
            model = self.models_[m]
            # 1) Aggregate from neighbors
            model.set_parameters(self._aggregate_robust(m=m))
            # 2) Local update
            loss_vals.append(model.fit_onestep(Xm, ym))

        loss_tensor = torch.stack(loss_vals)  # one entry per processed worker
        return loss_tensor.sum() / self.n_workers


class BRIDGE(RobustOptimizer):
    """
    BRIDGE-style robust decentralized optimizer: local step first, then
    robust aggregation.

    Each worker keeps ``coef_past``, its parameters at the end of the previous
    round. In a round, worker m takes one local step from those parameters
    (giving ``coef_half``), robustly aggregates its neighbors' ``coef_past``
    and sets ``coef_current = coef_half - coef_past + aggregate``, i.e. the
    local step displacement is applied on top of the neighborhood aggregate.
    """

    def __init__(self, neighbors, lr_constant,
                 model_type='lenet5',
                 n_classes=10, n_workers=10,
                 epochs=10, device='cpu', random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True, agg_method='median', byz_ratio=0.1):
        super().__init__(
            neighbors=neighbors,
            lr_constant=lr_constant,
            model_type=model_type,
            n_classes=n_classes, n_workers=n_workers,
            epochs=epochs, device=device,
            random_state=random_state, input_dim=input_dim,
            pretrained=pretrained, custom_init=custom_init,
            agg_method=agg_method, byz_ratio=byz_ratio
        )

    def _aggregate_robust(self, m=0):
        """
        Robustly aggregate the ``coef_past`` of worker ``m``'s neighbors.

        Returns:
            List with one aggregated tensor per parameter position.

        Raises:
            ValueError: if ``self.agg_method`` is not 'median',
                'trimmed_mean' or 'mean'.
        """
        param_lists = [self.models_[i].coef_past for i in self.neighbors[m]]
        n_neighbors = len(param_lists)
        if self.agg_method == 'median':
            aggregated = [
                torch.median(torch.stack(params, dim=0), dim=0)[0]
                for params in zip(*param_lists)]
        elif self.agg_method == 'trimmed_mean':
            aggregated = [
                self._trimmed_mean(torch.stack(params, dim=0), n_neighbors=n_neighbors)
                for params in zip(*param_lists)]
        elif self.agg_method == 'mean':
            aggregated = [
                torch.mean(torch.stack(params, dim=0), dim=0)
                for params in zip(*param_lists)]
        else:
            raise ValueError(f"Unknown aggregation method: {self.agg_method}")
        return aggregated

    def _update(self, Xs, ys):
        """
        Perform one DFL update step for all workers (update first, then aggregate).

        ``coef_past`` is written back only after every worker has aggregated,
        so each worker sees its neighbors' previous-round values and the
        result does not depend on the order in which workers are processed.

        Returns:
            Sum of the per-worker losses divided by ``self.n_workers``.
        """
        loss_vals = []
        new_coefs = []
        for m, (Xm, ym) in enumerate(zip(Xs, ys)):
            model = self.models_[m]
            loss_val = model.fit_onestep(Xm, ym)
            coef_half = model.get_parameters()
            loss_vals.append(loss_val)
            # Aggregate neighbors' previous-round parameters.
            agg_params = self._aggregate_robust(m=m)
            coef_current = [ch - cp + agg for ch, cp, agg in zip(coef_half, model.coef_past, agg_params)]
            model.set_parameters(coef_current)
            new_coefs.append(coef_current)

        # Deferred write-back: coef_past is what _aggregate_robust reads.
        for model, coef_current in zip(self.models_, new_coefs):
            model.coef_past = coef_current

        loss_tensor = torch.stack(loss_vals)
        avg_loss = loss_tensor.sum() / self.n_workers
        return avg_loss

    def refit(self, Xs, ys, X_val=None, y_val=None, print_freq=100):
        """
        Run ``self.epochs`` update rounds starting from the current parameters.

        Each model's ``coef_past`` is initialised from its current parameters.
        Every ``print_freq`` epochs the validation set (if given) is evaluated
        and appended to ``self.history_``, and the loss averaged over the last
        ``print_freq`` epochs is printed with per-epoch timing and an ETA.

        Returns:
            self
        """
        import time  # the file imports no ``time`` explicitly

        for model in self.models_:
            model.train()
            model.coef_past = model.get_parameters()
        # Accumulated across the whole reporting window and reset only after
        # printing, so the division by print_freq yields a real average.
        running_loss = 0.0
        for epoch in range(self.epochs):
            start_time = time.time()
            loss_ = self._update(Xs, ys)
            running_loss += loss_

            if (epoch + 1) % print_freq == 0:
                if X_val is not None:
                    accs, losss = self.evaluate(X_val, y_val)
                    self.history_['acc'].append(accs)
                    self.history_['loss'].append(losss)
                loss_val = (running_loss / print_freq).item()
                running_loss = 0.0
                t_elapsed = time.time() - start_time
                remaining = t_elapsed * (self.epochs - epoch - 1)
                print(f"\rEpoch [{epoch+1}/{self.epochs}], Avg Loss: {loss_val:.4f}, TC: {t_elapsed:3.2f}s, ETA: {remaining/60:3.2f}min", end='')
        return self


class GradientTrack(RobustOptimizer):
    """
    Robust gradient-tracking optimizer for decentralized learning.

    Per-worker state:
      * ``self.coefs_half[m]``       -- worker m's parameter iterate;
      * ``self.models_[m].G``        -- worker m's gradient tracker, used as the
                                        descent direction;
      * ``self.models_[m].grad_past``-- worker m's local gradient from the
                                        previous round.
    Parameters and trackers are both mixed with the robust rule chosen by
    ``agg_method`` ('median', 'trimmed_mean' or 'mean').
    """

    def __init__(
        self, neighbors, lr_constant,
        model_type='lenet5',
        n_classes=10, n_workers=10,
        epochs=10, device='cpu', random_state=None,
        input_dim=None,
        pretrained=False,
        custom_init=True,
        agg_method='median',
        byz_ratio=0.1
    ):
        # Everything is configured by RobustOptimizer / DFLOptimizerInit.
        super().__init__(
            neighbors=neighbors, lr_constant=lr_constant,
            model_type=model_type, n_classes=n_classes, n_workers=n_workers,
            epochs=epochs, device=device, random_state=random_state,
            input_dim=input_dim, pretrained=pretrained,
            custom_init=custom_init, agg_method=agg_method,
            byz_ratio=byz_ratio,
        )

    def _aggregate_robust(self, grads_or_params, m=0):
        """
        Aggregate, slot by slot, the entries of ``grads_or_params`` that
        belong to the neighbors of worker ``m``.

        Args:
            grads_or_params: Sequence indexed by worker id; entry i is worker
                i's list of tensors (parameters or gradients).
            m: Worker whose neighborhood ``self.neighbors[m]`` is used.

        Returns:
            List with one aggregated tensor per slot of the workers' lists.

        Raises:
            ValueError: if ``self.agg_method`` is not 'median',
                'trimmed_mean' or 'mean'.
        """
        neighborhood = [grads_or_params[idx] for idx in self.neighbors[m]]
        count = len(neighborhood)
        method = self.agg_method

        if method == 'median':
            def combine(stack):
                return torch.median(stack, dim=0)[0]
        elif method == 'trimmed_mean':
            def combine(stack):
                return self._trimmed_mean(stack, n_neighbors=count)
        elif method == 'mean':
            def combine(stack):
                return torch.mean(stack, dim=0)
        else:
            raise ValueError(f"Unknown aggregation method: {self.agg_method}")

        result = []
        for same_slot in zip(*neighborhood):
            result.append(combine(torch.stack(same_slot, dim=0)))
        return result

    def _update(self, Xs, ys):
        """
        Run one synchronous gradient-tracking round for every worker.

          1) x_new[m] = robust_agg(coefs_half over neighbors of m) - lr_constant * G[m]
          2) load x_new[m] into worker m, take its local gradient and loss there,
             form G[m] + grad_now - grad_past and store grad_now as grad_past
          3) robustly aggregate those corrected trackers over the neighbors of m
          4) commit x_new to coefs_half and the aggregated trackers to G

        Steps 1 and 3 only read values that predate step 4, so the outcome
        does not depend on worker order.

        Returns:
            Mean of the per-worker losses.
        """
        workers = self.models_
        lr = self.lr_constant

        # Step 1: mix parameters, then move along the current tracker.
        next_params = []
        for m in range(self.n_workers):
            mixed = self._aggregate_robust(self.coefs_half, m=m)
            next_params.append([p - lr * g for p, g in zip(mixed, workers[m].G)])

        # Step 2: local gradients at the new parameters.
        losses = []
        corrected = []
        for m, (Xm, ym) in enumerate(zip(Xs, ys)):
            worker = workers[m]
            worker.set_parameters(next_params[m])
            grad_now, loss_now = worker.compute_gradients_loss(Xm, ym)
            losses.append(loss_now)
            corrected.append([
                tracker + gn - gp
                for tracker, gn, gp in zip(worker.G, grad_now, worker.grad_past)
            ])
            worker.grad_past = [gn.clone() for gn in grad_now]

        # Step 3: mix the corrected trackers.
        next_trackers = [self._aggregate_robust(corrected, m=m) for m in range(self.n_workers)]

        # Step 4: commit the new state.
        for m in range(self.n_workers):
            self.coefs_half[m] = next_params[m]
            workers[m].G = next_trackers[m]

        return torch.stack(losses).mean()

    def fit(self, Xs, ys, X_val=None, y_val=None, print_freq=100):
        """
        Train from scratch.

        Re-creates the models and history, then for each worker records its
        initial parameters in ``coefs_half`` and seeds both ``G`` and
        ``grad_past`` with the local gradient at those parameters, before
        calling ``refit`` (inherited; not shown here).

        Returns:
            self
        """
        self._initialize_models()
        self._initialize_history()

        self.coefs_half = []
        for idx, worker in enumerate(self.models_):
            worker.train()
            self.coefs_half.append(worker.get_parameters())
            first_grad, _ = worker.compute_gradients_loss(Xs[idx], ys[idx])
            worker.G = first_grad
            worker.grad_past = [g.clone() for g in first_grad]

        self.refit(Xs, ys, X_val=X_val, y_val=y_val, print_freq=print_freq)
        return self



def clip_torch(z, tau):
    """
    Clip each row of ``z`` to L2 norm at most ``tau``.

    Rows are taken along the last dimension:
    - a 2-D tensor of shape (n, d) has each of its n rows clipped independently;
    - a 1-D tensor is clipped as a single vector;
    - a 0-D tensor is first promoted to shape (1,).
    Rows whose norm is <= tau keep their values. Longer rows are rescaled to
    have norm exactly tau.

    Args:
        z: Tensor to clip, e.g. shape (n, d).
        tau: Clipping threshold (Python number or 0-D tensor), assumed >= 0.

    Returns:
        Tensor with the same shape as ``z`` (shape (1,) for 0-D input).
    """
    if z.dim() == 0:
        z = z.unsqueeze(0)

    # Per-row L2 norms; keepdim so they broadcast back over each row.
    row_norms = torch.norm(z, dim=-1, keepdim=True)

    # Scale is tau / norm for rows outside the ball and 1 otherwise.
    # torch.where evaluates both branches, but tau / norm is only selected
    # where norm > tau >= 0, so any division by zero there is discarded.
    scale = torch.where(row_norms > tau, tau / row_norms, torch.ones_like(row_norms))
    return z * scale

def clipped_gossip(w_m, neighbor_params, delta_max=0.9):
    """
    Aggregate neighbor parameters around ``w_m`` with a clipped gossip step.

    All parameters are flattened into one vector per model. The matrix of
    offsets ``w_j - w_m`` (one row per neighbor) is passed to
    ``clip_torch(..., tau)``. The result is the mean over neighbors of
    ``w_m + clipped offset``, reshaped back to the structure of ``w_m``.

    The radius ``tau`` is the square root of the mean squared distance to the
    ``k = max(floor(n_neighbors * (1 - delta_max)), 1)`` closest neighbors.
    The farthest neighbors therefore do not influence ``tau``.

    Only the entries of ``neighbor_params`` are averaged. ``w_m`` itself
    contributes only if the caller includes it among the neighbors.

    Args:
      w_m (list[Tensor]):
          Local model parameter list; each element is a Tensor.
      neighbor_params (list[list[Tensor]]):
          List of neighbor parameter lists, each with the same structure as
          w_m. If empty, copies of ``w_m`` are returned.
      delta_max (float):
          Fraction of the farthest neighbors excluded when choosing ``tau``.

    Returns:
      aggregated_list (list[Tensor]): Aggregated parameters, with the same
      structure as w_m.
    """
    if not neighbor_params:
        # Nothing to move towards; torch.stack would otherwise fail on [].
        return [p.clone() for p in w_m]

    # Record shapes/sizes so the flat result can be split back.
    shapes = [p.shape for p in w_m]
    sizes = [p.numel() for p in w_m]
    # reshape (unlike view) also accepts non-contiguous tensors.
    w_m_flat = torch.cat([p.reshape(-1) for p in w_m], dim=0)

    # (n_neighbors, total_dim), flattened in the same order as w_m.
    neighbor_stack = torch.stack(
        [torch.cat([p.reshape(-1) for p in nbr], dim=0) for nbr in neighbor_params],
        dim=0)

    # Offsets to the local model and their squared lengths.
    diff = neighbor_stack - w_m_flat
    dists = (diff ** 2).sum(dim=1)
    n_neighbors = len(neighbor_params)

    # tau = sqrt(mean of the k smallest squared distances).
    k = max(int(math.floor(n_neighbors * (1 - delta_max))), 1)
    sorted_dists, _ = torch.sort(dists)
    tau = torch.sqrt(sorted_dists[:k].mean())

    # Clip offsets, add back w_m_flat, then average across neighbors.
    clipped = clip_torch(diff, tau) + w_m_flat
    aggregated_flat = clipped.mean(dim=0)

    # Split the flat vector back into w_m's tensor shapes.
    parts = torch.split(aggregated_flat, sizes)
    return [part.view(shape) for part, shape in zip(parts, shapes)]



class ClippedGossip(RobustOptimizer):
    """
    Decentralized optimizer that aggregates with ``clipped_gossip``.

    Each round every worker m sets its parameters to
    ``clipped_gossip(own params, neighbors' params, delta_max=2 * byz_ratio)``
    and then takes one local step. All inputs are taken from a single
    ``get_parameters()`` call made before the loop. The per-worker models are
    built with ``momentum=0.9`` (see ``_initialize_models``).
    """

    def __init__(self, neighbors, lr_constant,
                 model_type='lenet5',
                 n_classes=10, n_workers=10,
                 epochs=10, device='cpu', random_state=None,
                 input_dim=None,
                 pretrained=False,
                 custom_init=True, byz_ratio=0.1):
        # agg_method keeps RobustOptimizer's default; it is not used here
        # because _aggregate_robust is overridden.
        super().__init__(
                    neighbors=neighbors,
                    lr_constant=lr_constant,  # Same learning rate for all workers
                    model_type=model_type,
                    n_classes=n_classes, n_workers=n_workers,
                    epochs=epochs, device=device,
                    random_state=random_state, input_dim=input_dim,
                    pretrained=pretrained, custom_init=custom_init,
                    byz_ratio=byz_ratio)

    def _aggregate_robust(self, param_lists, m=0):
        """
        Clipped-gossip aggregate for worker ``m``.

        ``param_lists`` is indexed by worker id. Worker ``m``'s own entry is
        the clipping center; the entries of ``self.neighbors[m]`` are averaged.
        ``delta_max = 2 * byz_ratio`` is forwarded to ``clipped_gossip``.
        """
        w_m = param_lists[m]
        neighbor_params = [param_lists[i] for i in self.neighbors[m]]
        return clipped_gossip(w_m, neighbor_params, delta_max=2 * self.byz_ratio)

    def _update(self, Xs, ys):
        """
        Perform one round of DFL update for all workers (aggregate first, then fit_onestep).

        Returns:
            Sum of the per-worker losses divided by ``self.n_workers``.
        """
        loss_vals = []
        # Read once so every worker aggregates from the same parameter lists.
        # NOTE(review): a true snapshot only if get_parameters() returns copies.
        param_lists = self.get_parameters()
        for m, (Xm, ym) in enumerate(zip(Xs, ys)):
            model = self.models_[m]
            # Aggregation
            model.set_parameters(self._aggregate_robust(param_lists, m=m))
            # Local update
            loss_vals.append(model.fit_onestep(Xm, ym))

        # Unweighted average over workers (sum divided by n_workers).
        loss_tensor = torch.stack(loss_vals)
        avg_loss = loss_tensor.sum() / self.n_workers
        return avg_loss

    def _initialize_models(self):
        """
        Build one ``Optimizer`` per worker with ``momentum=0.9``.

        This method uses momentum, so the initialization settings are
        modified accordingly. ``self.lr`` may be a scalar shared by all
        workers or a per-worker sequence.
        NOTE(review): ``self.lr`` is presumably set by DFLOptimizerInit --
        confirm.

        Returns:
            self
        """
        # If self.lr is a scalar instead of a list, all workers use the same learning rate
        if isinstance(self.lr, (float, int)):
            lr_list = [self.lr] * self.n_workers
        else:
            lr_list = self.lr
        self.models_ = []
        for m in range(self.n_workers):
            model = Optimizer(
                model_type=self.model_type,
                num_classes=self.n_classes,
                lr=lr_list[m],
                device=self.device,
                random_state=self.random_state,
                input_dim=self.input_dim,
                pretrained=self.pretrained,
                custom_init=self.custom_init,
                momentum=0.9)
            self.models_.append(model)
        return self

