import numpy as np
import torch
import torch.nn as nn

class SkipAdd(nn.Module):
    '''
    Residual (skip) connection packaged as a module.

    forward(x, y) returns x + y. Keeping the addition in its own module
    presumably lets per-module tooling see it as a separate layer.
    '''
    def forward(self, x, y):
        combined = x + y
        return combined
    
def skipadd_flops_counter(self, inputs, outputs):
    '''
    Cost hook for SkipAdd: report the element count of the first input.

    `self` and `outputs` are accepted but ignored. `inputs` is indexed
    positionally, so it must be a sequence whose first item has `.numel()`
    (presumably the positional tensors passed to SkipAdd.forward -- confirm
    against the profiler that reads `__flops__`).

    Returns:
        (inputs[0].numel(), 0). NOTE(review): the meaning of the trailing 0
        is defined by the consuming profiler (likely a parameter count) --
        confirm.
    '''
    first_operand = inputs[0]
    element_count = first_operand.numel()
    return element_count, 0

# Attach the counter to the SkipAdd class itself rather than to one instance,
# so attribute lookup on every SkipAdd instance finds it.
setattr(SkipAdd, "__flops__", skipadd_flops_counter)

class MLPBlockReLU(nn.Module):
    '''
    Single pre-norm residual MLP block with one ReLU hidden activation:

        h_{i+1} = h_{i} + W2 * ReLU(W1 * BatchNorm(h_{i}))

    Only the hidden layer is activated; nothing is applied after the
    residual addition. Both linear layers are bias-free.

    Args:
        d1: width of the block input/output (the residual stream).
        d2: width of the hidden layer between fc1 and fc2.
        block_id: optional identifier stored on the block and on its
            SkipAdd submodule.
        dropout_rate: currently unused -- no dropout layer is built or
            applied. Kept so existing call sites keep working.
        norm_affine: if True, the BatchNorm gets a learnable weight/bias.
            Defaults to False (plain normalization, the previous behavior).
    '''
    def __init__(self, d1, d2, block_id=None, dropout_rate=0.3, norm_affine=False):
        super().__init__()
        self.block_id = block_id
        # Pre-norm. NOTE(review): BatchNorm1d normalizes dim 1 while Linear
        # acts on the last dim, so x is presumably (batch, d1) -- confirm.
        self.norm = nn.BatchNorm1d(d1, affine=norm_affine)
        self.fc1 = nn.Linear(d1, d2, bias=False)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(d2, d1, bias=False)
        # The residual addition runs through its own module, tagged with the
        # same id as this block.
        self.skip_add = SkipAdd()
        self.skip_add.block_id = block_id

    def forward(self, x):
        '''Return x + fc2(ReLU(fc1(norm(x)))).'''
        # Only the branch input is normalized; the skip path carries raw x.
        nx = self.norm(x)
        res = self.fc2(self.act(self.fc1(nx)))
        return self.skip_add(x, res)


class NormalMLP(nn.Module):
    '''
    "Normal" residual MLP model. The pipeline is:
      1. an optional bias-free input projection,
      2. a stack of MLPBlockReLU blocks whose hidden widths exceed the
         representation width,
      3. an optional bias-free output projection,
      4. a sigmoid.

    Args:
        input_dim: feature width of `x`.
        output_dim: feature width of the returned prediction.
        reps_dim: width of the residual stream the blocks operate on.
        mid_dims: hidden width of each block, one entry per block. Every
            entry must be larger than `reps_dim`.
        in_proj: if True, project input_dim -> reps_dim first. Otherwise `x`
            feeds the blocks directly, so `reps_dim` must equal `input_dim`.
        out_proj: if True, project reps_dim -> output_dim at the end.
            Otherwise the block output is the result, so `reps_dim` must
            equal `output_dim`.

    Raises:
        AssertionError: if the dimensions are inconsistent.
    '''
    def __init__(self,
                 input_dim,
                 output_dim,
                 reps_dim,
                 mid_dims,
                 in_proj=False,
                 out_proj=False):
        super().__init__()

        assert (reps_dim == input_dim) or (reps_dim == output_dim), \
            "ERROR: After considering the super-resolution mode, the representation dimension should equal input (or output) in Normal Model!"
        assert all(middle_dim > reps_dim for middle_dim in mid_dims), \
            "ERROR: all middle dimension should be larger than reps dimension"
        # Without a projection, the residual stream *is* the input (or the
        # output), so the widths must match exactly. Otherwise forward()
        # either hits a width mismatch or silently returns reps_dim features
        # instead of output_dim.
        assert in_proj or reps_dim == input_dim, \
            "ERROR: without in_proj, the representation dimension must equal the input dimension in Normal Model!"
        assert out_proj or reps_dim == output_dim, \
            "ERROR: without out_proj, the representation dimension must equal the output dimension in Normal Model!"

        # Not used by forward(); kept so the module tree and attributes are
        # unchanged for existing code.
        self.act_relu = nn.ReLU()

        self.in_proj = in_proj
        self.out_proj = out_proj

        # Submodules are registered in this order: in_fc, blocks, out_fc. The
        # order fixes parameters() ordering and the random-init sequence.
        if self.in_proj:
            self.in_fc = nn.Linear(input_dim, reps_dim, bias=False)

        self.blocks = nn.ModuleList([
            MLPBlockReLU(reps_dim, middle_dim, block_id=i + 1)
            for i, middle_dim in enumerate(mid_dims)
        ])

        if self.out_proj:
            self.out_fc = nn.Linear(reps_dim, output_dim, bias=False)

    def forward(self, x, y_star):
        '''
        Return sigmoid(out_fc(blocks(in_fc(x)))), skipping disabled projections.

        `y_star` is accepted but not used here (original note: "for effective
        profile").
        '''
        h = x
        if self.in_proj:
            h = self.in_fc(h)

        for block in self.blocks:
            h = block(h)

        if self.out_proj:
            h = self.out_fc(h)

        return torch.sigmoid(h)
    
class HourGlassMLP(nn.Module):
    '''
    Hourglass-shaped residual MLP. The pipeline is:
      1. a bias-free Linear lifts the input to a wider representation,
      2. a stack of MLPBlockReLU blocks refines it; each block's hidden
         width is narrower than the representation,
      3. a bias-free Linear maps it to the output,
      4. a sigmoid.

    Args:
        input_dim: feature width of `x`; must be smaller than `reps_dim`.
        output_dim: feature width of the prediction.
        reps_dim: width of the residual stream inside the blocks.
        mid_dims: hidden width of each block, one entry per block. Each must
            be smaller than `reps_dim`.
        compact_mid: if True, additionally require every entry of `mid_dims`
            to be smaller than `input_dim` (to save parameters).

    Raises:
        AssertionError: if any width constraint above is violated.
    '''
    def __init__(self,
                 input_dim,
                 output_dim,
                 reps_dim,
                 mid_dims,
                 compact_mid=False):
        super().__init__()

        # All shape checks run before any layer is constructed.
        assert input_dim < reps_dim, \
            "ERROR: representation dimension should be wider than input in HourGlass"
        assert all(reps_dim > width for width in mid_dims), \
            "ERROR: all middle dimension should be lower than reps dimension"
        if compact_mid:
            assert all(input_dim > width for width in mid_dims), \
                "ERROR: For saving #params, middle_dim should be lower than input dim in HourGlass"

        # Submodules are registered in this order: up, down, blocks. The order
        # fixes parameters() ordering and the random-init sequence.
        self.up = nn.Linear(input_dim, reps_dim, bias=False)
        self.down = nn.Linear(reps_dim, output_dim, bias=False)

        block_list = []
        for index, width in enumerate(mid_dims, start=1):
            block_list.append(MLPBlockReLU(reps_dim, width, block_id=index))
        self.blocks = nn.ModuleList(block_list)

    def forward(self, x, y_star):
        '''
        Return sigmoid(down(blocks(up(x)))).

        `y_star` is accepted but not used here (original note: "for effective
        profile").
        '''
        hidden = self.up(x)
        for mlp_block in self.blocks:
            hidden = mlp_block(hidden)
        logits = self.down(hidden)
        return torch.sigmoid(logits)