from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

class RandomNetworkAdversary(nn.Module):
    """Randomly initialised MLP with one fixed dropout mask per environment.

    Intended for producing random actions to be added to the actions
    generated by the policy (the addition itself is not done in this class).
    For every output channel the network emits a softmax distribution over
    ``softmax_bins`` discrete bins.

    Note: OpenAI et al. 2019 found out that if they used a continuous space
          and a tanh non-linearity, actions would always be close to 0.
          Section B.3 https://arxiv.org/abs/1910.07113

    Q: Why do we need dropouts here?

    A: If we were using a CPU-based simulator as in OpenAI et al. 2019, we
       would use a different RNA network for each CPU. However, this is not
       feasible for a GPU-based simulator as that would mean creating
       N_envs RNA networks, which would overwhelm the GPU memory.
       Therefore, dropout is a nice approximation of this: the same network
       is used for every env, but each env gets its own dropout mask.

    Args:
        num_envs: Number of rows in each dropout mask. The batch given to
            ``forward`` must broadcast against ``(num_envs, num_feats)``.
        in_dims: Size of the input feature vector.
        out_dims: Number of output channels.
        softmax_bins: Number of bins per output channel.
        device: Device the layers and dropout masks are placed on.
    """

    def __init__(self, num_envs, in_dims, out_dims, softmax_bins, device):
        super(RandomNetworkAdversary, self).__init__()

        self.in_dims = in_dims
        self.out_dims = out_dims
        self.softmax_bins = softmax_bins
        self.num_envs = num_envs

        self.device = device

        # Hidden widths of the two stages.
        self.num_feats1 = 512
        self.num_feats2 = 1024

        # First stage
        self.fc1 = nn.Linear(in_dims, self.num_feats1).to(self.device)
        self.fc1_1 = nn.Linear(self.num_feats1, self.num_feats1).to(self.device)

        # Second stage
        self.fc2 = nn.Linear(self.num_feats1, self.num_feats2).to(self.device)
        self.fc2_1 = nn.Linear(self.num_feats2, self.num_feats2).to(self.device)

        # Output stage: one logit per (channel, bin) pair.
        self.fc3 = nn.Linear(self.num_feats2, out_dims * softmax_bins).to(self.device)

        # Re-initialises the weights, switches to eval mode and creates
        # ``dropout_masks1`` / ``dropout_masks2``.
        self._refresh()

    def _refresh(self):
        """Re-randomise the network: new weights, eval mode, new masks."""
        self._init_weights()
        self.eval()
        self.refresh_dropout_masks()

    def _init_weights(self):
        """Re-draw every layer's weight matrix with Kaiming-uniform init.

        Biases are left untouched.
        """
        print('initialising weights for random network')

        for layer in (self.fc1, self.fc1_1, self.fc2, self.fc2_1, self.fc3):
            nn.init.kaiming_uniform_(layer.weight)

    def _sample_mask(self, num_feats, keep_prob):
        """Return a ``(num_envs, num_feats)`` mask of 0/1 values.

        Each entry is 1 with probability ``keep_prob``. The mask is drawn
        directly on the device that holds the layer weights, so no host
        allocation or host-to-device copy is needed.
        """
        mask = torch.empty((self.num_envs, num_feats), device=self.fc1.weight.device)
        return mask.bernoulli_(keep_prob)

    def refresh_dropout_masks(self):
        """Draw a fresh keep-probability and a fresh mask for each stage."""
        # One uniformly random keep-probability per stage.
        keep_prob1, keep_prob2 = torch.rand(2).tolist()

        self.dropout_masks1 = self._sample_mask(self.num_feats1, keep_prob1)
        self.dropout_masks2 = self._sample_mask(self.num_feats2, keep_prob2)

    def forward(self, x):
        """Map ``x`` to per-channel bin probabilities.

        Args:
            x: Input whose leading dimension must broadcast against the
                ``num_envs`` rows of the dropout masks.

        Returns:
            Tensor of shape ``(-1, out_dims, softmax_bins)`` whose last
            dimension sums to 1.
        """
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc1_1(x)
        x = self.dropout_masks1 * x

        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc2_1(x)
        x = self.dropout_masks2 * x

        x = self.fc3(x)

        # The outputs are discretised into bins: normalise the logits of
        # each channel into a distribution over its bins. Selecting a bin
        # from these probabilities is left to the caller.
        x = x.view(-1, self.out_dims, self.softmax_bins)
        output = F.softmax(x, dim=-1)

        return output

if __name__ == "__main__":

    # Smoke test: build the network and push one random batch through it.
    num_envs = 1024

    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    RNA = RandomNetworkAdversary(num_envs=num_envs, in_dims=16, out_dims=16, softmax_bins=32, device=device)

    # Draw the batch directly on the target device; wrapping an existing
    # tensor in torch.tensor() would copy it and emit a UserWarning.
    x = torch.randn(num_envs, 16, device=RNA.device)
    y = RNA(x)

    # Expected shape: (num_envs, out_dims, softmax_bins).
    print('output shape:', tuple(y.shape))

    

