import torch
from torch import nn
import numpy as np
from collections import OrderedDict

class Net(nn.Module):
    """Fully connected regressor: p_dim -> 20 -> 20 -> 1 with ReLU in between.

    The layers live in ``self.fc_unit`` (an ``nn.Sequential``) under the
    names ``fc1``, ``rl1``, ``fc2``, ``rl2`` and ``fc3``.
    """

    def __init__(self, p_dim):
        """Build the layer stack for inputs whose last dimension is ``p_dim``."""
        super().__init__()
        # Layers are created in forward order so that parameter
        # initialisation consumes the random stream in a fixed order.
        stack = OrderedDict()
        stack['fc1'] = nn.Linear(p_dim, 20)
        stack['rl1'] = nn.ReLU()
        stack['fc2'] = nn.Linear(20, 20)
        stack['rl2'] = nn.ReLU()
        stack['fc3'] = nn.Linear(20, 1)
        self.fc_unit = nn.Sequential(stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        return self.fc_unit(x)

class Net1_dp(nn.Module):
    """Single-hidden-layer regressor with dropout on the hidden activations.

    Architecture: ``Linear(p_dim, hid_dim) -> ReLU -> dropout -> Linear(hid_dim, 1)``.

    Args:
        p_dim: Size of the last dimension of the input.
        hid_dim: Width of the hidden layer.
        dropout: Drop probability handed to ``dropout_layer``.
        is_training: Initial value of the module's ``training`` flag.
            Dropout is applied only while this flag is truthy.  Because the
            flag is the standard ``nn.Module.training`` attribute, later
            calls to ``train()`` / ``eval()`` override it.
    """

    def __init__(self, p_dim, hid_dim, dropout, is_training=True):
        super().__init__()
        # Must come after super().__init__(), which resets ``training`` to True.
        self.training = is_training
        self.lin1 = nn.Linear(p_dim, hid_dim)
        self.lin2 = nn.Linear(hid_dim, 1)
        self.relu = nn.ReLU()
        self.dropout = dropout

    def forward(self, x):
        """Return the network output for ``x``; dropout is active only in training mode."""
        hidden = self.relu(self.lin1(x))
        if self.training:
            hidden = dropout_layer(hidden, self.dropout)
        return self.lin2(hidden)

class Net2_dp(nn.Module):
    """Four-hidden-layer regressor with dropout after every hidden activation.

    Architecture: four ``Linear -> ReLU -> dropout`` stages (the first maps
    ``p_dim`` to ``hid_dim``, the rest ``hid_dim`` to ``hid_dim``) followed
    by ``Linear(hid_dim, 1)``.

    Args:
        p_dim: Size of the last dimension of the input.
        hid_dim: Width of every hidden layer.
        dropout: Drop probability handed to ``dropout_layer``.
        is_training: Initial value of the module's ``training`` flag.
            Dropout is applied only while this flag is truthy.  Because the
            flag is the standard ``nn.Module.training`` attribute, later
            calls to ``train()`` / ``eval()`` override it.
    """

    def __init__(self, p_dim, hid_dim, dropout, is_training=True):
        super().__init__()
        # Must come after super().__init__(), which resets ``training`` to True.
        self.training = is_training
        # Attribute names are kept as lin1..lin5 so state-dict keys stay stable.
        self.lin1 = nn.Linear(p_dim, hid_dim)
        self.lin2 = nn.Linear(hid_dim, hid_dim)
        self.lin3 = nn.Linear(hid_dim, hid_dim)
        self.lin4 = nn.Linear(hid_dim, hid_dim)
        self.lin5 = nn.Linear(hid_dim, 1)
        self.relu = nn.ReLU()
        self.dropout = dropout

    def forward(self, x):
        """Return the network output for ``x``; dropout is active only in training mode."""
        hidden = x
        # Every hidden stage is identical: linear map, ReLU, optional dropout.
        for layer in (self.lin1, self.lin2, self.lin3, self.lin4):
            hidden = self.relu(layer(hidden))
            if self.training:
                hidden = dropout_layer(hidden, self.dropout)
        return self.lin5(hidden)


def dropout_layer(x, dropout):
    """Apply inverted dropout to ``x``.

    Each element is zeroed independently with probability ``dropout``; the
    surviving elements are scaled by ``1 / (1 - dropout)``.

    Args:
        x: Input tensor.
        dropout: Drop probability; must lie in ``[0, 1]``.

    Returns:
        ``x`` itself when ``dropout == 0``, an all-zero tensor shaped like
        ``x`` when ``dropout == 1``, otherwise a new masked and rescaled
        tensor on the same device as ``x``.

    Raises:
        AssertionError: If ``dropout`` is outside ``[0, 1]``.
    """
    assert 0 <= dropout <= 1
    if dropout == 1:
        return torch.zeros_like(x)
    if dropout == 0:
        return x
    # Draw the mask on x's device: a CPU mask cannot be multiplied with a
    # tensor living on another device.
    keep = torch.rand(x.shape, device=x.device) > dropout
    # Cast the mask to x's dtype so reduced-precision inputs are not
    # silently promoted to float32 by the multiplication.
    return keep.to(x.dtype) * x / (1.0 - dropout)


def SIS(x, y, percent):
    """Pick the columns of ``x`` with the largest marginal score against ``y``.

    Column ``j`` is scored as ``|sum(x[:, j] * y) / sum(x[:, j] ** 2)|``.
    The indices of the ``floor(percent * p)`` highest-scoring columns are
    returned, best first, where ``p = x.shape[1]``.

    NOTE(review): the name suggests sure independence screening - confirm.

    Args:
        x: 2-D array of predictors, one column per feature.
        y: Response values; flattened with ``reshape(-1)`` before use.
        percent: Fraction of the columns to keep.

    Returns:
        Integer index array of the selected columns, sorted by decreasing score.
    """
    n_features = x.shape[1]
    response = y.reshape(-1)
    scores = np.zeros(n_features)
    for j in range(n_features):
        column = x[:, j]
        numerator = (column * response).sum()
        denominator = (column ** 2).sum()
        scores[j] = np.abs(numerator / denominator)
    n_keep = int(np.floor(percent * n_features))
    # Negating the scores makes argsort return the largest score first.
    return np.argsort(-scores)[:n_keep]
