import os
import sys
import torch.nn.functional as F
import torch.nn as nn
import torch

class MLP_Block_BP(nn.Module):
    """Single fully connected layer with optional input norm and ReLU.

    Args:
        n_neurons: Output width of the linear layer.
        in_dim_feat: Input width of the linear layer (and of the LayerNorm,
            when enabled).
        normalize_input: When true, a non-affine ``nn.LayerNorm`` is applied
            to the input before the linear layer.
        out_layer: When true, the ReLU is skipped and the raw linear output
            is returned.
    """

    def __init__(self, n_neurons, in_dim_feat, normalize_input=False, out_layer=False):
        super().__init__()

        self.normalize_input = normalize_input
        self.out_layer = out_layer
        # The norm submodule only exists when normalisation was requested.
        if normalize_input:
            self.norm = nn.LayerNorm(in_dim_feat, elementwise_affine=False)
        self.fc = nn.Linear(in_dim_feat, n_neurons, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, h):
        """Return ``fc(norm(h))``, passed through ReLU unless ``out_layer``."""
        layer_in = self.norm(h) if self.normalize_input else h
        pre_act = self.fc(layer_in)
        return pre_act if self.out_layer else self.relu(pre_act)

class MLP_Block_FF(nn.Module):
    """Fully connected layer + ReLU that remembers its latest activation.

    Args:
        n_neurons: Output width of the linear layer.
        in_dim_feat: Input width of the linear layer (and of the LayerNorm,
            when enabled).
        normalize_input: When true, a non-affine ``nn.LayerNorm`` is applied
            to the input before the linear layer.

    After each forward pass the post-ReLU activation is available as
    ``self.h``.
    """

    def __init__(self, n_neurons, in_dim_feat, normalize_input=False):
        super().__init__()

        self.normalize_input = normalize_input
        # The norm submodule only exists when normalisation was requested.
        if normalize_input:
            self.norm = nn.LayerNorm(in_dim_feat, elementwise_affine=False)
        self.fc = nn.Linear(in_dim_feat, n_neurons, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, h):
        """Compute ``relu(fc(norm(h)))``, cache it on ``self.h`` and return it."""
        layer_in = self.norm(h) if self.normalize_input else h
        # Cached so that callers can read the activation back from the block.
        self.h = self.relu(self.fc(layer_in))
        return self.h

class MLP_Block_GIFF(nn.Module):
    """Layer with a feature path and a class path whose outputs are summed.

    Args:
        n_neurons: Output width of both linear projections.
        in_dim_feat: Input width of the feature projection ``fc1`` (and of
            the LayerNorm, when enabled).
        in_dim_class: Input width of the class projection ``fc2``; unused
            when ``same_w`` is true because ``fc2`` is then not created.
        normalize_input: When true, a non-affine ``nn.LayerNorm`` (eps=1e-9)
            is applied to the feature input only.
        same_w: When true, the class input is projected with ``fc1`` as well,
            so it must have width ``in_dim_feat``.

    After each forward pass ``self.h``, ``self.c`` and ``self.f`` hold the
    feature activation, the class activation and their sum.
    """

    def __init__(self, n_neurons, in_dim_feat, in_dim_class, normalize_input=False, same_w=False):
        super().__init__()

        self.same_w = same_w
        self.normalize_input = normalize_input
        if normalize_input:
            self.norm = nn.LayerNorm(in_dim_feat, eps=1e-9, elementwise_affine=False)
        self.fc1 = nn.Linear(in_dim_feat, n_neurons, bias=True)
        # A dedicated class projection is only needed when weights are not shared.
        if not same_w:
            self.fc2 = nn.Linear(in_dim_class, n_neurons, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, h, c):
        """Return ``(f, h, c)`` where ``f = h + c`` after projection and ReLU."""
        feat_in = self.norm(h) if self.normalize_input else h
        feat_act = self.relu(self.fc1(feat_in))
        self.h = feat_act

        # The class input is never normalised; it goes straight to its projection.
        class_proj = self.fc1 if self.same_w else self.fc2
        class_act = self.relu(class_proj(c))
        self.c = class_act

        combined = feat_act + class_act
        self.f = combined
        return combined, feat_act, class_act

class MLP_Net_BP(nn.Module):
    """Three ``MLP_Block_BP`` layers applied one after another.

    The first block receives the raw input without normalisation; the second
    and third normalise their input. The third block is built with
    ``out_layer=True``.

    Args:
        n_neurons: With ``configure == True``, an indexable holding the three
            layer widths (indices 0, 1, 2). Otherwise a single int used as
            the width of all three layers.
        in_dim_feat: Width of the network input.
        configure: Selects per-layer widths (True) or one shared width.

    Attributes:
        blocks: ``nn.Sequential`` holding the three blocks.
        n_blocks: Number of blocks (3).
        activation_num: ``in_dim_feat`` plus the three layer widths.
        gradient_num: Number of weights plus number of biases of the three
            linear layers, plus one.
    """

    def __init__(self, n_neurons, in_dim_feat, configure = False):
        super().__init__()
        self.configure = configure

        # Normalise both calling conventions to one list of three widths.
        # The counters below used to index ``n_neurons`` directly, which
        # raised TypeError whenever a plain int was passed (configure=False).
        if configure == True:
            widths = [n_neurons[0], n_neurons[1], n_neurons[2]]
        else:
            widths = [n_neurons] * 3
        in_dims = [in_dim_feat, widths[0], widths[1]]

        blocks = [
            MLP_Block_BP(
                n_neurons=width,
                in_dim_feat=in_dim,
                normalize_input=(idx > 0),  # the raw input is not normalised
                out_layer=(idx == 2),       # only the last block is the output layer
            )
            for idx, (in_dim, width) in enumerate(zip(in_dims, widths))
        ]
        self.blocks = nn.Sequential(*blocks)
        self.n_blocks = len(blocks)

        self.activation_num = in_dim_feat + sum(widths)
        n_weights = sum(in_dim * width for in_dim, width in zip(in_dims, widths))
        n_biases = sum(widths)
        # NOTE(review): the trailing "+ 1" is carried over unchanged from the
        # original formula; what it accounts for is not visible here - confirm.
        self.gradient_num = n_weights + n_biases + 1

    def forward(self, x):
        """Run ``x`` through the three blocks and return the last output."""
        h = x
        for block in self.blocks:
            h = block(h)
        return h

class MLP_Net_FF(nn.Module):
    """Two chained ``MLP_Block_FF`` layers whose activations are concatenated.

    Args:
        n_neurons: With ``configure == True``, an indexable whose entries 0
            and 1 are the two layer widths. Otherwise a single int used for
            both layers.
        in_dim_feat: Width of the network input.
        configure: Selects per-layer widths (True) or one shared width.

    Only two blocks are built; the first one does not normalise its input,
    the second one does.
    """

    def __init__(self, n_neurons, in_dim_feat, configure = False):
        super().__init__()

        self.configure = configure
        if configure == True:
            first_width, second_width = n_neurons[0], n_neurons[1]
        else:
            first_width = second_width = n_neurons

        self.blocks = nn.Sequential(
            MLP_Block_FF(n_neurons=first_width, in_dim_feat=in_dim_feat, normalize_input=False),
            MLP_Block_FF(n_neurons=second_width, in_dim_feat=first_width, normalize_input=True),
        )
        self.n_blocks = len(self.blocks)

    def forward(self, x):
        """Return every block's activation, flattened and joined along dim 1."""
        out = x
        for block in self.blocks:
            out = block(out)
        # Each block cached its activation on ``.h`` during the pass above.
        flat_acts = []
        for block in self.blocks:
            act = block.h
            flat_acts.append(act.view(act.shape[0], -1))
        return torch.cat(flat_acts, dim=1)

class MLP_Net_GIFF(nn.Module):
    """Three chained ``MLP_Block_GIFF`` layers with stacked combined outputs.

    Each block receives the previous block's feature activation ``h`` and
    class activation ``c``. The first block never shares weights and does not
    normalise its input; blocks two and three normalise and honour ``same_w``.

    Args:
        n_neurons: With ``configure == True``, an indexable holding the three
            layer widths (indices 0, 1, 2). Otherwise a single int used as
            the width of all three layers.
        in_dim_feat: Width of the feature input ``x``.
        in_dim_class: Width of the class input ``c`` given to the first block.
        same_w: Passed to blocks two and three; when true they project the
            class path with their feature weights.
        configure: Selects per-layer widths (True) or one shared width.
    """

    def __init__(self, n_neurons, in_dim_feat, in_dim_class, same_w, configure = False):
        super().__init__()

        self.configure = configure
        if configure == True:
            widths = [n_neurons[0], n_neurons[1], n_neurons[2]]
        else:
            widths = [n_neurons] * 3

        blocks = [
            MLP_Block_GIFF(n_neurons=widths[0], in_dim_feat=in_dim_feat, in_dim_class=in_dim_class,
                           normalize_input=False, same_w=False)
        ]
        for prev_width, width in zip(widths, widths[1:]):
            # forward() hands each later block the *previous block's* class
            # activation, whose width is prev_width. The configured branch
            # used to pass the raw in_dim_class here, giving fc2 the wrong
            # input size whenever in_dim_class differed from prev_width.
            blocks.append(
                MLP_Block_GIFF(n_neurons=width, in_dim_feat=prev_width, in_dim_class=prev_width,
                               normalize_input=True, same_w=same_w)
            )
        self.blocks = nn.Sequential(*blocks)
        self.n_blocks = len(blocks)

    def forward(self, x, c):
        """Run both paths through all blocks and stack each block's ``f``.

        Every block's combined output is flattened to ``(shape[0], -1)``. In
        configured mode the flattened outputs are right-padded with zeros to
        the largest width so they can be stacked. The result is stacked
        along dim 1.
        """
        h = x
        for block in self.blocks:
            _, h, c = block(h, c)

        # Each block cached its combined output on ``.f`` during the pass above.
        hs = [blk.f.view(blk.f.shape[0], -1) for blk in self.blocks]
        if self.configure == True:
            max_size = max(t.size(1) for t in hs)
            hs = [F.pad(t, (0, max_size - t.size(1))) for t in hs]
        return torch.stack(hs, dim=1)

