import torch
import torch.nn as nn
from torch.nn import functional as F


class Memory(object):
    """Bank of per-cluster feature centers with running-mean updates.

    ``feature_centers`` is indexed by cluster id along dim 0 and compared to
    2-D ``features`` via ``features @ feature_centers.t()``, so it is expected
    to have shape ``(num_clusters, feature_dim)``.
    """

    def __init__(self, feature_centers, cluster_numbers, assist_value=0.00001,
                 *, device='cuda'):
        """Store the initial centers and per-cluster sample counts.

        Args:
            feature_centers: tensor of initial centers, one row per cluster.
            cluster_numbers: mutable per-cluster counts (e.g. a list or a
                tensor) used as running-mean denominators. It is stored by
                reference, so ``update`` mutates the caller's object in place.
            assist_value: stored as ``self.assist_value``; not used by the
                logic in this class.
            device: device the centers are moved to. Defaults to ``'cuda'``,
                matching the previously hard-coded ``.cuda()`` call.
        """
        # NOTE(review): dim=0 normalizes each feature dimension across
        # clusters, not each center row (that would be dim=1), and centers are
        # not re-normalized after `update`. Confirm this is intended.
        self.feature_centers = F.normalize(feature_centers.to(device), dim=0)
        self.cluster_numbers = cluster_numbers
        self.assist_value = assist_value
        print('Cluster Numbers: ', self.cluster_numbers)
        print('Assist Value {}'.format(assist_value))

    def update(self, features, preds, assist_preds=None):
        """Fold each sample into its predicted cluster's running mean.

        For row ``i`` assigned to cluster ``k = int(preds[i])``, the count
        ``cluster_numbers[k]`` is incremented to ``n`` and the center becomes
        ``(center * (n - 1) + feature) / n``. Rows are processed in order, so
        several rows of the same cluster in one batch all accumulate.

        Args:
            features: 2-D tensor, one feature vector per row. It is detached,
                so no autograd history reaches the centers.
            preds: per-row cluster indices, each convertible with ``int``.
            assist_preds: accepted for backward compatibility; ignored.
        """
        # Detach once up front rather than once per row.
        features = features.detach()
        for i in range(features.size(0)):
            pred = int(preds[i])
            feature = features[i, :]
            self.cluster_numbers[pred] += 1
            num_features = self.cluster_numbers[pred]
            feature_center = self.feature_centers[pred]
            self.feature_centers[pred] = (
                feature_center * (num_features - 1) + feature
            ) / num_features

    def get_pred(self, features, get_score=False):
        """Assign each feature row to its most similar center.

        Similarity is the dot product ``features @ feature_centers.t()``.

        Args:
            features: 2-D tensor, one feature vector per row.
            get_score: if true, also return softmax-normalized similarities.

        Returns:
            ``preds`` (argmax cluster index per row), or ``(preds, sim_score)``
            when ``get_score`` is true, where ``sim_score`` is the row-wise
            softmax of the similarities.
        """
        # detach() alone suffices: mm does not write to its inputs, so a
        # clone would only add an extra copy of the centers.
        sim = torch.mm(features, self.feature_centers.detach().t())
        preds = torch.argmax(sim, dim=1)
        if not get_score:
            return preds
        return preds, F.softmax(sim, dim=1)