import numpy as np
import sys
sys.path.append('..')

def stable_softmax(x):
    '''
    Row-wise softmax, computed in a numerically stable way.

    x: (n_samples, n_classes)
    returns: array of the same shape; each row is non-negative and sums to 1
    '''
    # Softmax is invariant to subtracting a per-row constant; removing the row
    # maximum keeps every exponent <= 0 so np.exp cannot overflow.
    exps = np.exp(x - np.max(x, axis=1, keepdims=True))
    row_totals = np.sum(exps, axis=1, keepdims=True)
    return exps / row_totals

class ProtoClassifier:
    '''
    Nearest-prototype classifier: each class is represented by the feature-wise
    mean (prototype) of its training samples, and a new sample is assigned to
    the class whose prototype is closest under the chosen distance.
    '''

    def __init__(self, X, y):
        '''
        description:
            fit the model by computing one prototype (feature-wise mean) per class
        args:
            X: (n_samples, n_features) array-like of features
            y: (n_samples,) array-like of class labels
        attributes set:
            X: float64 copy of the training features
            y: training labels as a numpy array
            classes_: sorted unique labels
            prototypes: (n_classes, n_features); row i is the mean of the
                samples labelled classes_[i]
        '''
        # np.array (not asarray) always copies, so later mutation of the
        # caller's X cannot leak into the fitted model.
        self.X = np.array(X, dtype=np.float64)
        # Convert to an array so `self.y == label` is an element-wise
        # comparison even when y is passed as a plain Python list.
        self.y = np.asarray(y)
        self.classes_ = np.array(sorted(set(self.y)))
        self.prototypes = np.stack(
            [self.X[self.y == label].mean(axis=0) for label in self.classes_]
        )

    def calc_dist(self, x, dist_type='euclidean'):
        '''
        x: (n_samples, n_features) array-like; a single 1-D sample is treated as one row
        dist_type: 'euclidean' or 'cosine' or 'manhattan'
        return dist: (n_samples, n_classes), distance from each sample to each class prototype
        raises ValueError if x is not 1-D/2-D or dist_type is unsupported
        '''
        x = np.atleast_2d(np.asarray(x, dtype=np.float64))
        if x.ndim != 2:
            # a higher-rank input would broadcast against the prototypes in a
            # meaningless way instead of failing, so reject it explicitly
            raise ValueError(f'x must be 1-D or 2-D, got shape {x.shape}')
        _x = x[:, np.newaxis, :]  # (n_samples, 1, n_features)
        _prototypes = self.prototypes[np.newaxis, :, :]  # (1, n_classes, n_features)
        if dist_type == 'euclidean':
            dist = np.linalg.norm(_x - _prototypes, axis=2)
        elif dist_type == 'manhattan':
            dist = np.sum(np.abs(_x - _prototypes), axis=2)
        elif dist_type == 'cosine':
            dot = np.sum(_x * _prototypes, axis=2)
            # clamp the norm product so a zero vector does not divide by zero
            norms = np.maximum(
                np.linalg.norm(_x, axis=2) * np.linalg.norm(_prototypes, axis=2), 1e-8
            )
            dist = 1 - dot / norms  # 1 - cosine similarity used as a distance
        else:
            raise ValueError(f'unsupported dist_type:{dist_type}')
        return dist

    def predict(self, x, dist_type='euclidean', predict_type='nearest_prototype', return_probs=False, excluded_if_same_dist=False):
        '''
        description:
            predict the class of x as the class of its nearest prototype
        args:
            x: (n_samples, n_features) array-like (a single 1-D sample is treated as one row)
            dist_type: 'euclidean', 'cosine' or 'manhattan' (see calc_dist)
            predict_type: only 'nearest_prototype' is implemented
            return_probs: whether to also return per-class probabilities, computed
                as a softmax over negated (unrounded) distances, shape (n_samples, n_classes)
            excluded_if_same_dist: if True, samples whose two nearest prototypes are
                equidistant (after rounding to 5 decimals) are labelled
                'class_not_found' (string labels) or -1 (signed integer labels)
        return:
            if return_probs is False, the predicted labels, shape (n_samples,)
            if return_probs is True, (predicted labels, probs)
        raises:
            ValueError for an unsupported predict_type, dist_type or label type
        '''
        if predict_type != 'nearest_prototype':
            # other strategies are not implemented; fail loudly instead of
            # silently returning None
            raise ValueError(f'unsupported predict_type:{predict_type}')
        dist = self.calc_dist(x, dist_type)  # (n_samples, n_classes)
        # fancy indexing returns a copy, so marking ties below never mutates classes_
        pred_class = self.classes_[np.argmin(dist, axis=1)]
        if excluded_if_same_dist:
            pred_class = self._mark_ties(pred_class, dist)
        if return_probs:
            probs = stable_softmax(-dist)
            return pred_class, probs
        return pred_class

    def _mark_ties(self, pred_class, dist, decimals=5):
        '''Return pred_class with rows whose two smallest distances tie (after rounding) set to the reserved "not found" label.'''
        if dist.shape[1] < 2:
            # with a single class there is no runner-up, so nothing can tie
            same_dist_mask = np.zeros(dist.shape[0], dtype=bool)
        else:
            # round a separate array so the exact distances remain available for probs
            sorted_dist = np.sort(np.round(dist, decimals), axis=1)
            same_dist_mask = sorted_dist[:, 0] == sorted_dist[:, 1]
        if isinstance(self.classes_[0], str):
            assert 'class_not_found' not in self.classes_, 'class_not_found is reserved for class not found'
            not_found = 'class_not_found'
            if pred_class.dtype.kind == 'U':
                # fixed-width unicode arrays would silently truncate the marker,
                # so widen the dtype to fit it first
                pred_class = pred_class.astype(
                    np.promote_types(pred_class.dtype, np.array(not_found).dtype)
                )
        elif isinstance(self.classes_[0], np.signedinteger):
            assert -1 not in self.classes_, '-1 is reserved for class not found'
            not_found = -1
        else:
            raise ValueError(f'unsupported class type:{type(self.classes_[0])}')
        pred_class[same_dist_mask] = not_found
        return pred_class

            