import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import Misc
import sys
import os

### Best estimate using X,A explicitly
class Best_Estimate(nn.Module):
    """Single linear model over the concatenated features ``[x, a]``.

    ``task`` selects the output transform: ``'classification'`` applies a
    sigmoid to the linear output, ``'regression'`` returns it unchanged.

    ``phase`` selects how the sensitive attribute is supplied:

    * ``0`` (training; ``group`` must be ``None``): the observed ``a`` is
      concatenated to ``x`` column-wise.
    * ``1`` (DFP; ``group`` is ``'Female'`` or ``'Male'``): ``a`` is not read
      and every row gets the fixed encoding ``[1, 0]`` (Female) or
      ``[0, 1]`` (Male) instead.

    Args:
        input_size: in-features of ``linearx``; must equal ``x.shape[1]`` plus
            the width of the attribute encoding (2 in the DFP phase).
        hidden_size: unused; kept so every model in this module shares the
            same ``(input_size, hidden_size)`` constructor signature.
    """

    # Fixed two-column attribute encodings substituted for ``a`` in phase 1.
    _GROUP_CODES = {'Female': (1, 0), 'Male': (0, 1)}

    def __init__(self, input_size, hidden_size):
        super(Best_Estimate, self).__init__()
        # Not used by forward(); kept so the module's attributes/repr are unchanged.
        self.relu = nn.ReLU()
        self.linearx = nn.Linear(input_size, 1)

    def forward(self, x, a, task, phase, group):
        """Return predictions of shape ``(batch, 1)``.

        Args:
            x: 2-D feature tensor (rows are samples).
            a: 2-D attribute tensor concatenated to ``x`` in phase 0; not
                read in phase 1.
            task: ``'classification'`` or ``'regression'``.
            phase: ``0`` (training) or ``1`` (DFP).
            group: ``None`` in phase 0; ``'Female'`` or ``'Male'`` in phase 1.

        Raises:
            ValueError: for any unsupported ``task``/``phase``/``group``
                combination.
        """
        if task not in ('classification', 'regression'):
            raise ValueError(f"unsupported task: {task!r}")

        if phase == 0 and group is None:
            features = torch.cat((x, a), dim=1)
        elif phase == 1 and group in self._GROUP_CODES:
            # Build the encoding on x's device/dtype so the concat also works
            # for CUDA inputs and does not mix an int64 tensor into the batch.
            code = torch.tensor(self._GROUP_CODES[group], dtype=x.dtype, device=x.device)
            features = torch.cat((x, code.expand(x.shape[0], -1)), dim=1)
        else:
            raise ValueError(f"unsupported phase/group combination: {phase!r}/{group!r}")

        out = self.linearx(features)
        return torch.sigmoid(out) if task == 'classification' else out



### Unawareness using X only
class Unawareness(torch.nn.Module):
    """MLP baseline that sees only the features ``x`` (no sensitive attribute).

    Three ``Linear + ReLU`` hidden layers feed a single-output ``linear4``.
    ``forward`` returns both the prediction and the last hidden activation.

    Args:
        input_size: number of input features in ``x``.
        hidden_size: width of every hidden layer.
    """

    def __init__(self, input_size, hidden_size):
        super(Unawareness, self).__init__()
        # NOTE(review): linearx is never used in forward(), but it owns
        # parameters. Removing it would change state_dict keys, the optimizer's
        # parameter list and the seeded init of the later layers, so it stays.
        self.linearx = torch.nn.Linear(input_size, 1)
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
        self.linear3 = torch.nn.Linear(hidden_size, hidden_size)
        self.linear4 = torch.nn.Linear(hidden_size, 1)
        self.relu = torch.nn.ReLU()

    def _hidden(self, x):
        """Return the activation after ``linear3`` + ReLU."""
        h = self.relu(self.linear1(x))
        h = self.relu(self.linear2(h))
        return self.relu(self.linear3(h))

    def forward(self, x, task, phase, group):
        """Return ``(y_pred, hidden)``.

        ``y_pred`` has shape ``(batch, 1)``. It is sigmoid-activated for
        ``'classification'`` and raw for ``'regression'``. ``hidden`` is the
        last hidden activation.

        Raises:
            ValueError: if ``task`` is unsupported or ``phase``/``group`` is
                not ``None`` (the only configuration this model implements).
        """
        if task not in ('classification', 'regression'):
            raise ValueError(f"unsupported task: {task!r}")
        if phase is not None or group is not None:
            raise ValueError(f"Unawareness supports only phase=None, group=None; got {phase!r}/{group!r}")

        act_3 = self._hidden(x)
        out_4 = self.linear4(act_3)
        y_pred = torch.sigmoid(out_4) if task == 'classification' else out_4
        return y_pred, act_3
            


### Group-specific model using X,A explicitly
class GSM(torch.nn.Module):
    """Group-specific model: one linear head per group, applied to ``x`` only.

    In phase 0 the batch is split by the attribute ``a`` and each head is
    applied to its own rows. In phase 1 (DFP) the head chosen by ``group``
    is applied to the whole of ``x``. ``'classification'`` applies a
    sigmoid; ``'regression'`` returns the raw linear output.

    Args:
        input_size: number of features in ``x``.
        hidden_size: unused; kept so every model in this module shares the
            same ``(input_size, hidden_size)`` constructor signature.
    """

    def __init__(self, input_size, hidden_size):
        super(GSM, self).__init__()
        self.linearF = torch.nn.Linear(input_size, 1)
        self.linearM = torch.nn.Linear(input_size, 1)
        # Not used by forward(); kept so the module's attributes/repr are unchanged.
        self.relu = torch.nn.ReLU()

    @staticmethod
    def _output(out, task):
        """Apply the task's output transform (sigmoid for classification)."""
        return torch.sigmoid(out) if task == 'classification' else out

    def forward(self, x, a, task, phase, group):
        """Return group-specific predictions.

        Phase 0 (``group`` is ``None``) returns ``(y_pred_F, y_pred_M)`` for
        the rows selected by ``a``. Phase 1 returns a single ``(batch, 1)``
        tensor from the head named by ``group`` (``'Female'``/``'Male'``).

        Raises:
            ValueError: for any unsupported ``task``/``phase``/``group``
                combination.
        """
        if task not in ('classification', 'regression'):
            raise ValueError(f"unsupported task: {task!r}")

        if phase == 0 and group is None:
            # torch.where(cond) with one argument returns the indices where
            # cond is non-zero. For a two-column one-hot ``a``, idx_F
            # therefore holds rows whose column 1 is hot, and idx_M rows
            # whose column 0 is hot. Rows with a tied argmax/argmin (e.g. all
            # zeros) land in neither group.
            # NOTE(review): Best_Estimate in this module encodes Female as
            # [1, 0] (column 0). Confirm which column really means Female and
            # that callers split the labels the same way.
            idx_F = torch.where(a.argmax(-1))[0]
            idx_M = torch.where(a.argmin(-1))[0]
            return (self._output(self.linearF(x[idx_F]), task),
                    self._output(self.linearM(x[idx_M]), task))

        if phase == 1 and group == 'Female':
            return self._output(self.linearF(x), task)
        if phase == 1 and group == 'Male':
            return self._output(self.linearM(x), task)

        raise ValueError(f"unsupported phase/group combination: {phase!r}/{group!r}")
                


### Group-specific model under LDP (using X,Z)
class GSM_LDP(torch.nn.Module):
    """Group-specific model trained with the attribute proxy ``z`` (x and z).

    Like ``GSM``, it has one linear head per group (``linearF``/``linearM``)
    applied to ``x`` only. In phase 0 the batch is split by ``z`` rather than
    the true attribute, and *both* heads are evaluated on *both* ``z``-groups.
    ``'classification'`` applies a sigmoid; ``'regression'`` returns the raw
    linear output.

    Args:
        input_size: number of features in ``x``.
        hidden_size: unused; kept so every model in this module shares the
            same ``(input_size, hidden_size)`` constructor signature.
    """

    def __init__(self, input_size, hidden_size):
        super(GSM_LDP, self).__init__()
        self.linearF = torch.nn.Linear(input_size, 1)
        self.linearM = torch.nn.Linear(input_size, 1)
        # Not used by forward(); kept so the module's attributes/repr are unchanged.
        self.relu = torch.nn.ReLU()

    @staticmethod
    def _output(out, task):
        """Apply the task's output transform (sigmoid for classification)."""
        return torch.sigmoid(out) if task == 'classification' else out

    def forward(self, x, z, task, phase, group):
        """Return group-specific predictions.

        Phase 0 (``group`` is ``None``) returns four tensors
        ``(AF_ZF, AF_ZM, AM_ZF, AM_ZM)``: the F head (``AF_*``) and the M head
        (``AM_*``), each applied to the rows that ``z`` assigns to F (``*_ZF``)
        and to M (``*_ZM``). Presumably the caller combines these into a
        noise-corrected loss — TODO confirm. Phase 1 returns a single
        ``(batch, 1)`` tensor from the head named by ``group``.

        Raises:
            ValueError: for any unsupported ``task``/``phase``/``group``
                combination.
        """
        if task not in ('classification', 'regression'):
            raise ValueError(f"unsupported task: {task!r}")

        if phase == 0 and group is None:
            # torch.where(cond) with one argument returns the indices where
            # cond is non-zero. For a two-column one-hot ``z``, idx_F holds
            # rows whose column 1 is hot and idx_M rows whose column 0 is hot.
            # NOTE(review): Best_Estimate in this module encodes Female as
            # [1, 0] (column 0). Confirm the column convention for z.
            idx_F = torch.where(z.argmax(-1))[0]
            idx_M = torch.where(z.argmin(-1))[0]
            x_ZF, x_ZM = x[idx_F], x[idx_M]
            return (self._output(self.linearF(x_ZF), task),
                    self._output(self.linearF(x_ZM), task),
                    self._output(self.linearM(x_ZF), task),
                    self._output(self.linearM(x_ZM), task))

        if phase == 1 and group == 'Female':
            return self._output(self.linearF(x), task)
        if phase == 1 and group == 'Male':
            return self._output(self.linearM(x), task)

        raise ValueError(f"unsupported phase/group combination: {phase!r}/{group!r}")
                


### Noise estimate network
class Noise_Estimate(torch.nn.Module):
    """Two-hidden-layer MLP producing a sigmoid-activated scalar per row.

    Architecture: ``Linear -> ReLU -> Linear -> ReLU -> Linear -> sigmoid``.

    Args:
        input_size: number of input features.
        hidden_size: width of both hidden layers.
    """

    def __init__(self, input_size, hidden_size):
        super(Noise_Estimate, self).__init__()
        # Creation order matters for seeded initialization; keep it stable.
        self.linear1 = torch.nn.Linear(input_size, hidden_size)
        self.linear2 = torch.nn.Linear(hidden_size, hidden_size)
        self.linear3 = torch.nn.Linear(hidden_size, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return values in (0, 1) with shape ``(batch, 1)``."""
        hidden = x
        for layer in (self.linear1, self.linear2):
            hidden = self.relu(layer(hidden))
        logits = self.linear3(hidden)
        return torch.sigmoid(logits)