from u_base import *
from MMLP import *
from dataloader import *

class GradientAttributor:
    """Gradient-based feature attribution for a multi-output model.

    For each model output (label), the attribution of an input feature is the
    gradient of that output with respect to the feature, averaged over the
    samples that were passed in.
    """

    def __init__(self, model):
        """Store ``model`` after switching it to evaluation mode (``model.eval()``)."""
        self.model = model.eval()

    def _target_device(self, inputs):
        """Return the device of the model's first parameter, else ``inputs.device``."""
        parameters = getattr(self.model, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        return inputs.device if first_param is None else first_param.device

    def attribute(self, inputs, label_batch_size=1):
        """Return the sample-averaged gradient of every output w.r.t. the inputs.

        Args:
            inputs: tensor whose first axis indexes samples (assumed ``(n, d)``).
            label_batch_size: kept for backward compatibility. Gradients are
                always computed one label at a time, because a single backward
                pass over several labels only yields the *sum* of their
                gradients; the result therefore does not depend on this value.

        Returns:
            ``numpy`` array with one row per model output; ``(c, d)`` when the
            model maps ``(n, d)`` inputs to ``(n, c)`` outputs.

        Raises:
            ValueError: if ``label_batch_size`` is not positive or ``inputs``
                contains no samples (the mean over zero samples is undefined).
        """
        if label_batch_size < 1:
            raise ValueError("label_batch_size must be a positive integer")
        if inputs.shape[0] == 0:
            raise ValueError("inputs must contain at least one sample")

        # Work on a private leaf tensor so the caller's tensor is never touched.
        inputs = inputs.clone().detach().to(self._target_device(inputs)).requires_grad_(True)
        outputs = self.model(inputs)
        # Read the label count from the model output instead of a module global.
        num_labels = outputs.shape[1]

        per_label_means = []
        for label in range(num_labels):
            column = outputs[:, label]
            grads = torch.autograd.grad(
                outputs=column,
                inputs=inputs,
                grad_outputs=torch.ones_like(column),
                create_graph=False,
                # The graph is only needed again while labels remain.
                retain_graph=(label + 1 < num_labels),
            )[0]  # (n, d)
            # Average over samples right away so that only (d,) per label is
            # kept instead of the full (c, n, d) gradient tensor.
            per_label_means.append(grads.mean(dim=0))
        return torch.stack(per_label_means, dim=0).cpu().numpy()  # (c, d)

    def attribute_(self, inputs, batch_num=1, sample_batch_size=np.power(2, 31, dtype=np.int64)):
        """Compute :meth:`attribute` over ``inputs`` in consecutive sample batches.

        Args:
            inputs: tensor whose first axis indexes samples.
            batch_num: number of batches to process; ``1`` means a single pass
                over all of ``inputs``.
            sample_batch_size: number of samples per batch.

        Returns:
            The attribution averaged over every processed sample. Each batch is
            weighted by its sample count, so a shorter final batch does not
            skew the average; empty batches are skipped.

        Raises:
            ValueError: if the requested batches contain no samples at all.
        """
        if batch_num == 1:
            return self.attribute(inputs)

        step = int(sample_batch_size)
        weighted_sum = None
        samples_seen = 0
        for batch in range(batch_num):
            chunk = inputs[batch * step:(batch + 1) * step]
            count = chunk.shape[0]
            if count == 0:
                # Past the end of the data: nothing to contribute.
                continue
            contribution = self.attribute(chunk).astype(np.float64) * count
            weighted_sum = contribution if weighted_sum is None else weighted_sum + contribution
            samples_seen += count
        if samples_seen == 0:
            raise ValueError("no samples available in the requested batches")
        return weighted_sum / samples_seen
    
def top_features(attributions, part=0.1):
    """Return, for every row, the indices of its highest-scoring features.

    Args:
        attributions: 2-D array-like of scores, one row per label and one
            column per feature.
        part: fraction of the features to keep per row.

    Returns:
        A list with one entry per row; each entry lists ``int(d * part)``
        column indices (clamped to ``[0, d]``) ordered from the highest score
        to the lowest. When the fraction rounds down to zero features the
        entries are empty lists.
    """
    attributions = np.asarray(attributions)
    _, d = np.shape(attributions)
    # Clamp so that a tiny or negative fraction selects nothing and an
    # oversized one selects everything.
    k = max(0, min(d, int(d * part)))
    # Reverse the ascending order first, then take the head: unlike the
    # ``[-k:]`` idiom this stays correct for k == 0 (``[-0:]`` is the whole row).
    return [np.argsort(row)[::-1][:k].tolist() for row in attributions]

if __name__ == '__main__':
    # Benchmark driver: for every dataset, train the MLP, rank features per
    # label by squared mean gradient, then fit one base learner per label on
    # that label's selected features and store the evaluation result.
    datasnames = ['20NG', 'Corel5k', 'HumanGO', 'Ohsumed', 'TMC2007_500', 'Yelp', 'Bibtex', 'Bookmarks', 'Delicious', 'Imdb']
    num_label = [20, 374, 14, 23, 22, 5, 159, 208, 983, 28]
    num_train = [12933, 3332, 2053, 9301, 19140, 7240, 4941, 58862, 10743, 81363]
    path = "data/"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Budget for samples * features * labels processed in one attribution
    # pass; it must stay below what the GPU can hold. The value used here is
    # 2**29; the original guideline was 2**30 entries for 4 GB of video memory.
    suf_comp = np.power(2,29, dtype=np.int64)
    selected_ratio = 0.1
    for dataIdx, (dataname, n_label, n_train) in enumerate(zip(datasnames, num_label, num_train)):
        print(dataname)
        Xtr,Ytr,Xte,Yte = readData(path+dataname, n_label, n_train)
        # n, d, m, c stay module-level names; code elsewhere in this file may
        # read them as globals.
        n,d = np.shape(Xtr)
        m,c = np.shape(Yte)
        print(dataIdx,"==========",n,d,m,c)

        model = TabularMLP(input_dim=d, num_classes=c)
        t0 = time()
        model.trainMLP(Xtr,Ytr)
        # Largest sample count whose samples * d * c product fits the budget
        # (at least one sample, so a batch is never empty).
        sample_batch = max(1, int(suf_comp // (d * c)))
        # Ceiling division: every training sample falls into some batch and no
        # batch exceeds the budget. Floor division here would silently drop
        # the trailing samples from the attribution.
        numbatch = max(1, (n + sample_batch - 1) // sample_batch)
        GA = GradientAttributor(model)
        scores = GA.attribute_(torch.tensor(Xtr, dtype=torch.float32), numbatch, sample_batch) # (c,d)
        # Square the scores so features are ranked by gradient magnitude
        # regardless of sign.
        scores = np.power(scores,2)
        indices_list = top_features(scores, selected_ratio) # (c, 0.1d)
        t1 = time()
        predictions = []
        for i in range(c):
            selected = indices_list[i]
            learner = Baser(basemode='dt')
            learner.fit(Xtr[:,selected],Ytr[:,i])
            # Column 1 of predict_proba is taken as the per-label score.
            predictions.append(learner.predict_proba(Xte[:,selected])[:,1])
        # One column per label after transposing.
        predictions = np.transpose(predictions)
        saveResult(dataname, 'FALS', evaluate(predictions, Yte), t1-t0, time()-t1)
