import torch
import os
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


class SoftSelector(nn.Module):
    """Score samples with a small MLP and build a soft-selection objective.

    Each sample receives a weight ``pi`` in (0, 1) from ``self.pi`` (an MLP
    ending in a sigmoid).  :meth:`forward` returns the objective::

        dot(pi, loss_xy)
        + lambda1 * |sum(pi) - subset_size|
        - lambda2 * entropy(softmax(pi))

    Args:
        targets: Per-sample class labels (tensor, NumPy array or sequence);
            converted with ``torch.as_tensor``.
        subset_size: Target value for ``sum(pi)`` and the ``k`` used for the
            top-k weights returned by :meth:`forward`.
        lambda1: Weight of the ``|sum(pi) - subset_size|`` penalty.
        lambda2: Weight of the entropy term (subtracted from the objective).
        num_classes: Expected width of both the ``y`` and ``m_logits`` inputs.
        feature_dim: Expected width of the ``x`` input (default 2048).
        h_dim: Expected width of each row of ``h`` (default 16).
    """

    def __init__(self, targets, subset_size=500, lambda1=0.1, lambda2=0.1,
                 num_classes=10, feature_dim=2048, h_dim=16):
        super().__init__()
        # Input layout must match the torch.cat order used in forward().
        in_features = (h_dim            # h.mean(dim=0), broadcast to every row
                       + feature_dim    # x (data input)
                       + num_classes    # y (presumably one-hot y_true)
                       + num_classes)   # m_logits (sa_feature)
        self.pi = nn.Sequential(
            nn.Linear(in_features, 256), nn.LeakyReLU(),
            nn.Linear(256, 16), nn.LeakyReLU(),
            nn.Linear(16, 1), nn.Sigmoid(),
        )
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.subset_size = subset_size
        self.num_classes = num_classes
        # torch.where needs a tensor; accept lists / NumPy arrays as well.
        targets = torch.as_tensor(targets)
        self.data_len = len(targets)
        # tgt_indices[c] is the tuple returned by torch.where(targets == c).
        self.tgt_indices = [torch.where(targets == c) for c in range(num_classes)]

    def forward(self, h, m_logits, x, y, loss_xy):
        """Compute per-sample weights and the soft-selection objective.

        Args:
            h: Tensor whose mean over dim 0 is appended to every sample's
                features (rows presumably have ``h_dim`` entries).
            m_logits: ``(n, num_classes)`` tensor (expected width).
            x: ``(n, feature_dim)`` tensor (expected width).
            y: ``(n, num_classes)`` tensor (expected width).
            loss_xy: ``(n,)`` per-sample losses, weighted by ``pi``.

        Returns:
            ``(obj_fn, pi_xy, top_k)``: the scalar objective, the ``(n,)``
            weights, and the ``min(subset_size, n)`` largest weights in
            descending order.
        """
        n = x.size(0)
        # Broadcast h's mean to the current batch size (not the dataset
        # length) so forward also works on mini-batches.
        h_summary = h.mean(dim=0, keepdim=True).expand(n, -1)
        features = torch.cat([h_summary, x, y, m_logits], dim=1)
        # squeeze(-1) rather than squeeze(): keep a 1-D result when n == 1.
        pi_xy = self.pi(features).squeeze(-1)
        # Clamp k so topk does not raise when fewer than subset_size samples.
        top_k, _ = torch.topk(pi_xy, min(self.subset_size, pi_xy.numel()))
        size_penalty = (pi_xy.sum() - self.subset_size).abs()
        obj_fn = (torch.dot(pi_xy, loss_xy)
                  + self.lambda1 * size_penalty
                  - self.lambda2 * self.entropy(pi_xy))
        return obj_fn, pi_xy, top_k

    def entropy(self, pi_xy):
        """Return the entropy of the categorical distribution ``softmax(pi_xy)`` over dim 0."""
        return torch.distributions.Categorical(probs=pi_xy.softmax(dim=0)).entropy()