
import torch
import torch.nn as nn
import torch.nn.functional as F
from tools import get_missing_data

class Imputer(nn.Module):
    """Per-dataset MLP mapping a feature vector to a vector of the same width.

    The architecture is chosen by ``dataset_name``. Every ``nn.Linear`` layer is
    bias-free, and every hidden layer is followed by an in-place LeakyReLU. The
    last layer has no activation.

    Args:
        input_dim: Width of the input feature vector; the output has the same width.
        dataset_name: Architecture selector, either ``'kdd'`` or ``'adult'``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """

    def __init__(self, input_dim, dataset_name):
        super().__init__()

        self.input_dim = input_dim
        self.dataset_name = dataset_name

        # Use string equality, not identity (`is`). Identity only holds for
        # interned strings, so a dataset name built at runtime (argv, config
        # file, ...) would wrongly be rejected.
        if self.dataset_name == 'kdd':  # 121 — presumably the KDD input width; confirm
            # Hourglass shape: input -> 64 -> 32 -> 32 -> 64 -> input.
            self.impute = nn.Sequential(

                nn.Linear(self.input_dim, 64, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(64, 32, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(32, 32, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(32, 64, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(64, self.input_dim, bias=False),
            )

        elif self.dataset_name == 'adult':
            # Widening shape: input -> 64 -> 128 -> 128 -> 64 -> input.
            self.impute = nn.Sequential(
                nn.Linear(self.input_dim, 64, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(64, 128, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(128, 128, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(128, 64, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(64, self.input_dim, bias=False),
            )
        else:
            # ValueError subclasses Exception, so callers catching Exception still work.
            raise ValueError(f'Unknown dataset [{self.dataset_name}]!')

    def forward(self, x):
        """Apply the imputation MLP.

        Args:
            x: Tensor whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same leading dimensions and last dimension ``input_dim``.
        """
        return self.impute(x)


class Projector(nn.Module):
    """Per-dataset MLP projecting an ``input_dim`` vector down to ``mid_dim``.

    The architecture is chosen by ``dataset_name``. Every ``nn.Linear`` layer is
    bias-free, and every hidden layer is followed by an in-place LeakyReLU. The
    final projection has no activation.

    Args:
        input_dim: Width of the input feature vector.
        mid_dim: Width of the produced representation.
        dataset_name: Architecture selector, either ``'kdd'`` or ``'adult'``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """

    def __init__(self, input_dim, mid_dim, dataset_name):
        super().__init__()
        self.input_dim = input_dim
        self.mid_dim = mid_dim
        self.dataset_name = dataset_name

        # Use string equality, not identity (`is`). Identity only holds for
        # interned strings, so a runtime-built dataset name would be rejected.
        if self.dataset_name == 'kdd':
            # input -> 256 -> 64 -> mid_dim
            self.project = nn.Sequential(
                nn.Linear(self.input_dim, 256, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(256, 64, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(64, self.mid_dim, bias=False),

            )
        elif self.dataset_name == 'adult':
            # input -> 64 -> 32 -> 16 -> 8 -> mid_dim
            self.project = nn.Sequential(
                nn.Linear(self.input_dim, 64, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(64, 32, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(32, 16, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(16, 8, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(8, self.mid_dim, bias=False)
            )
        else:
            # ValueError subclasses Exception, so callers catching Exception still work.
            raise ValueError(f'Unknown dataset [{self.dataset_name}]!')

    def forward(self, x):
        """Project ``x`` (last dimension ``input_dim``) to last dimension ``mid_dim``."""
        return self.project(x)


class Recover(nn.Module):
    """Per-dataset MLP decoding a ``mid_dim`` vector to width ``recover_dim``.

    The architecture is chosen by ``dataset_name``. Every ``nn.Linear`` layer is
    bias-free, and every hidden layer is followed by an in-place LeakyReLU. The
    output layer has no activation.

    Args:
        mid_dim: Width of the input representation.
        recover_dim: Width of the decoded output.
        dataset_name: Architecture selector, either ``'kdd'`` or ``'adult'``.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """

    def __init__(self, mid_dim, recover_dim, dataset_name) -> None:
        super().__init__()
        self.mid_dim = mid_dim
        self.recover_dim = recover_dim
        self.dataset_name = dataset_name

        # Use string equality, not identity (`is`). Identity only holds for
        # interned strings, so a runtime-built dataset name would be rejected.
        if self.dataset_name == 'kdd':  # 121 — presumably the KDD feature width; confirm

            # mid_dim -> 64 -> 100 -> recover_dim
            self.decoder = nn.Sequential(

                nn.Linear(self.mid_dim, 64, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(64, 100, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(100, self.recover_dim, bias=False),

            )

        elif self.dataset_name == 'adult':

            # mid_dim -> 7 -> 10 -> recover_dim
            self.decoder = nn.Sequential(
                nn.Linear(self.mid_dim, 7, bias=False),
                nn.LeakyReLU(inplace=True),
                nn.Linear(7, 10, bias=False),
                nn.LeakyReLU(inplace=True),

                nn.Linear(10, self.recover_dim, bias=False)
            )
        else:
            # ValueError subclasses Exception, so callers catching Exception still work.
            raise ValueError(f'Unknown dataset [{self.dataset_name}]!')

    def forward(self, x):
        """Decode ``x`` (last dimension ``mid_dim``) to last dimension ``recover_dim``."""
        return self.decoder(x)


class MVNet(nn.Module):
    """Composite model chaining an Imputer, a Projector and a Recover decoder.

    Args:
        input_dim: Feature width. Used as the Imputer width, the Projector input
            width and the Recover output width.
        mid_dim: Width of the intermediate representation (Projector output /
            Recover input).
        dataset_name: Architecture selector forwarded to every sub-network.
    """

    def __init__(self, input_dim, mid_dim, dataset_name):
        super().__init__()

        # Registration order (imputer, projector, recover) fixes the order of
        # parameters() and state_dict() keys, so it is kept as-is.
        self.imputer = Imputer(input_dim=input_dim, dataset_name=dataset_name)
        self.projector = Projector(
            input_dim=input_dim, mid_dim=mid_dim, dataset_name=dataset_name
        )
        self.recover = Recover(
            mid_dim=mid_dim, recover_dim=input_dim, dataset_name=dataset_name
        )

    def forward(self, x, negative=False, missing_rate=0.0):
        """Run the positive or the negative pass.

        Positive pass (``negative=False``): ``x`` has last dimension
        ``input_dim``. Returns ``(imputed, representation, reconstruction)``.
        ``missing_rate`` is ignored in this mode.

        Negative pass (``negative=True``): ``x`` is fed to the decoder, so its
        last dimension must be ``mid_dim``. Returns
        ``(imputed, representation, missing_data, mask_tensor)``.
        """
        if negative:
            return self._negative_pass(x, missing_rate)

        filled = self.imputer(x)
        latent = self.projector(filled)
        return filled, latent, self.recover(latent)

    def _negative_pass(self, code, missing_rate):
        """Decode ``code``, corrupt it with ``get_missing_data``, then re-impute and project."""
        decoded = self.recover(code)
        # NOTE(review): get_missing_data comes from `tools`, and its semantics are
        # not visible here. The mask it returns must be a NumPy array, because it
        # is converted with torch.from_numpy below.
        corrupted, mask_array = get_missing_data(decoded, missing_rate)
        filled = self.imputer(corrupted)
        latent = self.projector(filled)
        return filled, latent, corrupted, torch.from_numpy(mask_array)
        
        
