import torch
from .base import _BaseAggregator


class TM(_BaseAggregator):
    """Coordinate-wise Trimmed-Mean aggregator.

    For every coordinate, the ``b`` largest and ``b`` smallest values across
    the inputs are discarded and the remaining ``n - 2b`` values are averaged.
    When there are too few inputs to drop ``2b`` values and still keep at
    least one, ``b`` is reduced to the largest value that keeps one.
    """

    def __init__(self, b):
        # Number of values trimmed from EACH end, per coordinate.
        self.b = b
        super(TM, self).__init__()

    def __call__(self, inputs):
        """Return the coordinate-wise trimmed mean of ``inputs``.

        Args:
            inputs: sequence of tensors that ``torch.stack`` can combine
                (presumably all the same shape, e.g. per-worker gradients --
                confirm against callers).

        Returns:
            Tensor with the shape of a single input, holding the mean of the
            middle ``n - 2b`` values of each coordinate.

        Raises:
            RuntimeError: if ``inputs`` is empty or ``self.b`` is negative,
                i.e. no valid trim amount exists.
        """
        n = len(inputs)
        # Largest trim amount that still leaves at least one value per
        # coordinate (n - 2b > 0  <=>  b <= (n - 1) // 2). This equals
        # self.b whenever self.b is already small enough.
        b = min(self.b, (n - 1) // 2)
        if b < 0:
            raise RuntimeError(
                "Trimmed mean needs at least one input and b >= 0 "
                "(got {} inputs, b={})".format(n, self.b)
            )

        stacked = torch.stack(inputs, dim=0)
        if n > 10:
            # Large input sets: skip the full sort. Subtract the sums of the
            # b largest and b smallest values from the total. topk with k=0
            # yields an empty tensor whose sum is 0, so b == 0 works too.
            largest, _ = torch.topk(stacked, b, dim=0)
            smallest, _ = torch.topk(stacked, b, dim=0, largest=False)
            kept_sum = (
                stacked.sum(dim=0) - largest.sum(dim=0) - smallest.sum(dim=0)
            )
            return kept_sum / (n - 2 * b)

        # Small input sets: sort each coordinate and average the middle
        # n - 2b values. Use an explicit end index: with b == 0 the slice
        # ``[b:-b]`` would be ``[0:0]`` (empty), making the mean NaN.
        sorted_stacked, _ = torch.sort(stacked, dim=0)
        return sorted_stacked[b:n - b].mean(dim=0)

    def __str__(self):
        return "Trimmed Mean (b={})".format(self.b)
