import torch
import torch.nn as nn
import pdb
import numpy as np



class CacheMemory(nn.Module):
    """Non-parametric key/value memory queried by cosine similarity.

    The memory holds ``source_data`` (keys, shape ``[tot_data, n_features]``)
    and ``target_data`` (values, shape ``[tot_data, n_features_target]``).
    A query batch is compared against the L2-normalized keys, the
    similarities are turned into weights with a temperature softmax, and the
    values are read out according to ``mode``:

    * ``"average"`` -- softmax-weighted sum of all values.
    * ``"max"`` -- value of the single best-matching key.

    Args:
        source_data: Array-like or tensor of keys, or ``None`` if the memory
            is assigned later through the ``source_data`` attribute.
        target_data: Array-like or tensor of values, or ``None`` if assigned
            later through the ``target_data`` attribute.
        mode: Read-out strategy, ``"average"`` or ``"max"``.
        T: Softmax temperature dividing the similarities; must be non-zero.
    """

    _MODES = ("average", "max")

    def __init__(self, source_data, target_data, mode="average", T=1):
        super().__init__()
        # Always define both attributes so that a missing memory is reported
        # with a clear message in forward() rather than an AttributeError.
        self.source_data = self._as_float_tensor(source_data)
        self.target_data = self._as_float_tensor(target_data)
        self.mode = mode
        self.T = T

    @staticmethod
    def _as_float_tensor(data):
        """Convert array-like ``data`` to a float32 tensor; pass ``None`` through."""
        if data is None:
            return None
        return torch.as_tensor(data, dtype=torch.float32)

    def forward(self, x, pre_feats=False):
        """Read from the memory using the query batch ``x``.

        Args:
            x: Query tensor of shape ``[batch_size, n_features]``.
            pre_feats: When true, also return the softmax weights.

        Returns:
            ``out`` of shape ``[batch_size, n_features_target]``, or the tuple
            ``(out, coff)`` with ``coff`` of shape ``[batch_size, tot_data]``
            when ``pre_feats`` is true.

        Raises:
            RuntimeError: If ``source_data`` or ``target_data`` is not set.
            ValueError: If ``mode`` is neither ``"average"`` nor ``"max"``.
        """
        if self.source_data is None or self.target_data is None:
            raise RuntimeError(
                "CacheMemory needs both source_data and target_data to be set "
                "before forward() is called"
            )
        if self.mode not in self._MODES:
            raise ValueError(
                "unsupported mode {!r}; expected one of {}".format(
                    self.mode, self._MODES
                )
            )

        # Keep the memory on the query's device; storing the moved tensors
        # makes this a no-op on every later call.
        self.source_data = self.source_data.to(x.device)
        self.target_data = self.target_data.to(x.device)

        # Normalize into locals so the stored keys are never overwritten.
        # normalize() clamps the norm away from zero, so an all-zero row
        # yields zeros instead of NaN.
        keys = nn.functional.normalize(self.source_data, dim=1)
        queries = nn.functional.normalize(x, dim=1)

        sim_mat = torch.matmul(queries, keys.t())  # [batch_size, tot_data]
        coff = torch.softmax(sim_mat / self.T, dim=1)

        if self.mode == "average":
            out = torch.matmul(coff, self.target_data)
        else:
            # "max": gather the best-matching row directly instead of
            # multiplying by a dense one-hot matrix.
            out = self.target_data[torch.argmax(coff, dim=1)]

        if pre_feats:
            return out, coff
        return out









##
