import numpy as np


class Ball(object):
    """A hypersphere summarising a group of feature vectors.

    The center is the per-feature mean of the samples and the radius is the
    largest Euclidean distance from any sample to that center, so every
    sample used to build the ball lies inside it.
    """

    def __init__(self, feas, label, task_id):
        """
        feas: n*d , where n and d denote the number of samples and features respectively
              (any array-like accepted by np.asarray)
        label: label stored with the ball, returned by _get_label()
        task_id: id of the task the ball was built in, returned by _get_task_id()

        Raises ValueError if feas is not a non-empty 2-D array.
        """
        feas = np.asarray(feas)
        # Validate up front: an empty or non-2-D input would otherwise fail
        # deep inside numpy (mean/max of an empty array, axis=1 on 1-D data).
        if feas.ndim != 2 or feas.shape[0] == 0:
            raise ValueError(
                "feas must be a non-empty 2-D array of shape (n, d), got shape %s"
                % (feas.shape,)
            )
        self.num, self.dim = feas.shape
        self._ball_center = self._init_center(feas)
        self._ball_radius = self._init_radius(feas, self._ball_center)
        self._ball_label = label
        self._task_id = task_id

    def _get_center(self):
        """Return the current ball center (a length-d vector)."""
        return self._ball_center

    def _init_center(self, feas):
        """Return the mean of the samples in ``feas`` along the sample axis."""
        return np.mean(feas, axis=0)

    def _update_center(self, center):
        """Replace the ball center with ``center``."""
        self._ball_center = center

    def _get_radius(self):
        """Return the current ball radius."""
        return self._ball_radius

    def _update_radius(self, radius):
        """Replace the ball radius with ``radius``."""
        # Must write _ball_radius: that is the attribute _get_radius() reads.
        # (Writing a separate ``self.radius`` left the update invisible.)
        self._ball_radius = radius

    def _init_radius(self, feas, center):
        """Return the largest Euclidean distance from any sample to ``center``."""
        dist = np.linalg.norm(feas - center, axis=1)
        return np.max(dist)

    def _get_label(self):
        """Return the label this ball was created with."""
        return self._ball_label

    def _get_smp_num(self):
        """Return the number of samples (n) used to build the ball."""
        return self.num

    def _get_dim(self):
        """Return the feature dimensionality (d)."""
        return self.dim

    def _get_task_id(self):
        """Return the id of the task this ball was created in."""
        return self._task_id


class BallList(object):
    """Ordered collection of Ball objects with array-valued accessors.

    Radii can be inflated by a "blur" factor that grows linearly with the
    number of tasks elapsed since each ball was created.
    """

    def __init__(self, blur_r=0.05):
        # Relative radius growth per elapsed task, applied in _get_radius.
        self._blur_r = blur_r
        self._ball_list = []

    def _add_ball(self, feas, label, task_id):
        """Build a Ball from ``feas`` and append it to the collection."""
        new_ball = Ball(feas, label, task_id)
        self._ball_list.append(new_ball)

    def _get_balls(self):
        """Return the underlying list of balls (not a copy)."""
        return self._ball_list

    def _get_center(self):
        """Return every ball's center stacked into one numpy array."""
        centers = [ball._get_center() for ball in self._ball_list]
        return np.array(centers)

    def _get_radius(self, curr_task):
        """Return blurred radii as a numpy array.

        Each radius is scaled by ``1 + (curr_task - task_id) * blur_r``, where
        ``task_id`` is the task the ball was created in.
        """
        blurred = []
        for ball in self._ball_list:
            base = ball._get_radius()
            elapsed = curr_task - ball._get_task_id()
            blurred.append(base * (1 + elapsed * self._blur_r))
        return np.array(blurred)

    def _get_radius_orig(self):
        """Return the un-blurred radius of every ball as a numpy array."""
        radii = [ball._get_radius() for ball in self._ball_list]
        return np.array(radii)

    def _get_labels(self):
        """Return the label of every ball as a numpy array."""
        labels = [ball._get_label() for ball in self._ball_list]
        return np.array(labels)