import torch
import torch.nn as nn
from torch.optim import Adam

class Mixup_value_net(nn.Module):
    """Mixup value network h_theta.

    A fully connected network that maps an (x, y, k) feature vector to a
    selection probability in (0, 1) (the last layer is a Sigmoid). It is
    trained by ``train_net`` from (reward, log probability, mean of h values)
    triples buffered with ``put_data``.

    Attributes:
        data: Buffered (reward, log probability, mean of h values) triples
            waiting to be consumed by ``train_net``.
        network: The MLP that produces the selection probability.
        optimizer: Adam optimizer over this module's parameters.
    """

    # Allowed band for the mean of h values; a mean outside
    # [PENALTY_LOW, PENALTY_HIGH] is penalized linearly in ``train_net``.
    PENALTY_LOW = 0.01
    PENALTY_HIGH = 0.99
    # Scale of the out-of-band penalty relative to the RL term.
    PENALTY_WEIGHT = 1e3

    def __init__(self, input_dim, hidden_dim=100, lr=1e-3):
        """Initialize Mixup value network.

        Args:
            input_dim: Input dimension of the network
                (= input example dim + label dim + 2 * knn option dim).
            hidden_dim: Width of each of the four hidden layers.
            lr: Learning rate of the Adam optimizer.
        """
        super().__init__()
        self.data = []
        self.network = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
        self.optimizer = Adam(self.parameters(), lr=lr)

    def forward(self, inputs):
        """Forward propagation of the Mixup value network.

        Args:
            inputs: (x, y, k) triples, concatenated along the last dimension
                to ``input_dim`` features.

        Returns:
            Selection probability h_theta(x, y, k), reshaped to a column
            tensor of shape (-1, 1).
        """
        return self.network(inputs).view(-1, 1)

    def put_data(self, item):
        """Save the data (reward, log probability, mean of h values) for the training.

        Args:
            item: A (reward, log probability, mean of h values) triple.
        """
        self.data.append(item)

    @staticmethod
    def _range_penalty(mean, low, high):
        """Return how far ``mean`` lies outside ``[low, high]`` (0 inside the band)."""
        # clamp(min=0) is a hinge that stays on ``mean``'s device/dtype, unlike
        # comparing against a freshly allocated zeros tensor.
        return torch.clamp(mean - high, min=0) + torch.clamp(low - mean, min=0)

    def train_net(self):
        """Train Mixup value network using saved data (REINFORCE-style update).

        For every buffered (R, prob, mean) triple, the loss
        ``prob * R + PENALTY_WEIGHT * range_penalty(mean)`` is back-propagated.
        Gradients accumulate across triples. A single optimizer step is then
        taken and the buffer is cleared. Does nothing if the buffer is empty.
        """
        if not self.data:
            # Avoid an optimizer step with no fresh gradients: on PyTorch
            # versions where zero_grad() zeros instead of clearing grads,
            # Adam's momentum would still move the parameters.
            return
        self.optimizer.zero_grad()
        for R, prob, mean in self.data:
            # NOTE(review): classic REINFORCE minimizes -log_prob * R; here
            # the sign is +, which is correct only if R is a cost (lower is
            # better). Confirm against the caller that produces R.
            rl_term = prob * R
            penalty_term = self._range_penalty(mean, self.PENALTY_LOW, self.PENALTY_HIGH)
            loss = rl_term + penalty_term * self.PENALTY_WEIGHT
            loss.backward()
        self.optimizer.step()
        self.data = []