import torch
import torch.nn as nn
import math
import pandas as pd
import os

class BasicBlock(nn.Module):
    """Residual block ``x + res_scale * Linear(ReLU(x))`` with stochastic depth.

    Behaviour of the residual branch depends on the mode:

    * Training with a plain-number ``survival_rate``: one ``torch.rand`` value
      is drawn per forward call; the branch is kept (unscaled) when the draw
      is ``<= survival_rate`` and skipped otherwise, in which case the input
      is returned unchanged.
    * Eval mode, or any time ``survival_rate`` is a ``torch.Tensor``: the
      branch is always evaluated and multiplied by ``survival_rate``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Output size of the linear layer. The skip connection
            adds the branch to the input, so it has to be addable to an
            ``in_features``-wide tensor.
        res_scale: Constant multiplier applied to the residual branch.
        survival_rate: Keep probability / scaling factor described above.
    """

    # Not referenced anywhere in this file.
    expansion = 1

    def __init__(self, 
                 in_features:int, 
                 out_features:int, 
                 res_scale:float = 1.,
                 survival_rate:float = 0.5
                ):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        self.lin = nn.Linear(in_features, out_features)
        # Re-draw the weights from N(0, 2 / in_features); the bias keeps the
        # default nn.Linear initialisation.
        nn.init.normal_(self.lin.weight, 0, math.sqrt(2. / in_features))
        self.res_scale = res_scale
        self.survival_rate = survival_rate

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        rate = self.survival_rate
        # Eval mode and tensor-valued rates bypass the random gate entirely.
        always_keep = (not self.training) or isinstance(rate, torch.Tensor)

        # Short-circuiting matters: no random number is drawn when the gate
        # is bypassed, so the global RNG stream is left untouched.
        if not always_keep and not (torch.rand(1)[0] <= rate):
            return x

        branch = self.lin(self.relu(x))
        if always_keep or rate == 1:
            # In-place on the fresh linear output; ``x`` itself is not modified.
            branch *= rate

        return x + branch * self.res_scale
    
    
class BasicBlockBN(nn.Module):
    """Residual block ``x + res_scale * Linear(ReLU(BatchNorm1d(x)))`` with
    stochastic depth.

    Behaviour of the residual branch depends on the mode:

    * Training with a plain-number ``survival_rate``: one ``torch.rand`` value
      is drawn per forward call; the branch is kept (unscaled) when the draw
      is ``<= survival_rate`` and skipped otherwise. A skipped call returns
      the input unchanged and does not run the batch-norm layer.
    * Eval mode, or any time ``survival_rate`` is a ``torch.Tensor``: the
      branch is always evaluated and multiplied by ``survival_rate``. The
      tensor case mirrors ``BasicBlock`` so both block types react the same
      way to the rates handed out by ``SimpleResNet``.

    Args:
        in_features: Size of the feature dimension of the input.
        out_features: Output size of the linear layer. The skip connection
            adds the branch to the input, so it has to be addable to an
            ``in_features``-wide tensor.
        res_scale: Constant multiplier applied to the residual branch.
        survival_rate: Keep probability / scaling factor described above.
    """

    # Not referenced anywhere in this file.
    expansion = 1

    def __init__(self, 
                 in_features:int, 
                 out_features:int, 
                 res_scale:float = 1.,
                 survival_rate:float = 0.5
                ):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.relu = nn.ReLU(inplace=False)
        self.lin = nn.Linear(in_features, out_features)
        # Re-draw the weights from N(0, 2 / in_features); the bias keeps the
        # default nn.Linear initialisation.
        nn.init.normal_(self.lin.weight, 0, math.sqrt(2. / in_features))
        self.res_scale = res_scale
        self.survival_rate = survival_rate

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        rate = self.survival_rate
        # Eval mode and tensor-valued rates bypass the random gate entirely.
        always_keep = (not self.training) or isinstance(rate, torch.Tensor)

        # Short-circuiting matters: no random number is drawn when the gate
        # is bypassed, so the global RNG stream is left untouched.
        if not always_keep and not (torch.rand(1)[0] <= rate):
            return x

        branch = self.lin(self.relu(self.bn(x)))
        if always_keep or rate == 1:
            # In-place on the fresh linear output; ``x`` itself is not modified.
            branch *= rate

        return x + branch * self.res_scale
    
    
class SimpleResNet(nn.Module):
    """Fully connected residual network.

    Layout: ``Linear(in_features, h_features)`` -> ``depth`` residual blocks
    of width ``h_features`` -> ``Linear(h_features, out_features)``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension produced by the output layer.
        h_features: Width of the (internal) hidden layers.
        depth: Number of residual blocks.
        block: Block class; it is called with the keyword arguments
            ``in_features``, ``out_features``, ``res_scale`` and
            ``survival_rate``.
        res_scale: Residual scaling factor forwarded to every block.
        survival_rates: One survival rate per block, or ``None`` to use ``1``
            for every block.

    Raises:
        ValueError: If ``survival_rates`` is given and its length differs
            from ``depth``.
    """

    def __init__(self, 
                 in_features, 
                 out_features,
                 h_features,
                 depth=3,
                 block=BasicBlock,
                 res_scale=1.,
                 survival_rates=None,
                ):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.h_features = h_features  # width of (internal) hidden layers
        self.depth = depth
        self.res_scale = res_scale

        self.block = block

        self.survival_rates = self._resolve_survival_rates(survival_rates)

        # Initialize layers
        self.fc_in = nn.Linear(in_features, h_features)  # fully connected layer to match dimensions
        self.resnet_layers = self._make_resnet_layers()
        self.fc_out = nn.Linear(h_features, out_features)
        # Not applied in forward(). The axis is named explicitly because the
        # implicit-dim form is deprecated and would pick dim 0 for 3-D input,
        # whereas fc_out puts the outputs on the last axis.
        self.smax = nn.Softmax(dim=-1)

    def _resolve_survival_rates(self, survival_rates):
        """Return the per-block rates, defaulting to all ones and checking length."""
        if survival_rates is None:
            return [1] * self.depth
        # A real exception instead of `assert`: assertions vanish under
        # `python -O`, after which a wrong-length list would silently build a
        # network with the wrong number of blocks.
        if len(survival_rates) != self.depth:
            raise ValueError("survival_rates must be an array of length depth")
        return survival_rates

    def _make_resnet_layers(self):
        """Build one block per entry of ``self.survival_rates``."""
        layers = [self.block(in_features=self.h_features, 
                             out_features=self.h_features,
                             res_scale=self.res_scale,
                             survival_rate=p
                            ) for p in self.survival_rates]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on ``x`` and return the squeezed output of ``fc_out``."""
        x = self.fc_in(x)
        x = self.resnet_layers(x)
        x = self.fc_out(x)

        # Disabled alternative output normalisation, kept for reference:
        #x = x**2/torch.sum(x**2, axis=-1).unsqueeze(axis=-1)

        # NOTE: squeeze() without a dim drops *every* size-1 axis, so a batch
        # of one sample loses its batch axis as well. Kept as is because
        # callers may rely on the resulting shape.
        return x.squeeze()

    def reset_survival_rates(self, survival_rates=None):
        """Replace the survival rate of every block.

        Args:
            survival_rates: One rate per block, or ``None`` to reset every
                block to ``1``.

        Raises:
            ValueError: If the length differs from ``depth``; nothing is
                modified in that case.
        """
        self.survival_rates = self._resolve_survival_rates(survival_rates)

        for layer, rate in zip(self.resnet_layers, self.survival_rates):
            layer.survival_rate = rate
    

class GradNormContainer(object):
    """Callable that records the mean squared value of a gradient tensor.

    Every call appends one 0-d NumPy array to ``grad_norms``.

    NOTE(review): the call signature looks like a module backward hook, but
    the registration is not visible here. PyTorch invokes such hooks as
    ``hook(module, grad_input, grad_output)``, so if this object is registered
    that way the argument named ``grad_in`` actually receives the gradients
    with respect to the module *outputs* -- confirm at the registration site.
    """

    def __init__(self):
        # One entry per call, in call order.
        self.grad_norms = []

    def __call__(self, module, grad_out, grad_in):
        """Append ``mean(grad_in[0] ** 2)`` to ``grad_norms``.

        Args:
            module: Unused.
            grad_out: Unused.
            grad_in: Sequence whose first element is the tensor to measure.
        """
        # detach() first: .numpy() refuses tensors that require grad (e.g.
        # when backward runs with create_graph=True), and the recorded value
        # should not extend the autograd graph anyway.
        grad = grad_in[0].detach()
        self.grad_norms.append((grad ** 2).mean().cpu().numpy())
        
        
class GradSamplesContainer:
    """Callable that stores gradient samples, grouped by epoch, while active.

    ``activate(epoch)`` opens a fresh list that is shared between
    ``grad_samples`` and ``dict[epoch]``; every call made while active appends
    ``grad_in[0][0]`` (converted to NumPy) to that list. ``deactivate()``
    stops recording and detaches ``grad_samples`` from the stored list, so
    the samples remain available under ``dict[epoch]``.

    NOTE(review): the call signature looks like a module backward hook, but
    the registration is not visible here. PyTorch invokes such hooks as
    ``hook(module, grad_input, grad_output)``, so if this object is registered
    that way the argument named ``grad_in`` actually receives the gradients
    with respect to the module *outputs* -- confirm at the registration site.
    """

    def __init__(self):
        self.dict = {}           # epoch -> list of recorded samples
        self.grad_samples = []   # list currently being filled
        self.active = False      # recording switch

    def __call__(self, module, grad_out, grad_in):
        """Record ``grad_in[0][0]`` if recording is switched on."""
        if not self.active:
            return
        sample = grad_in[0][0].detach().cpu().numpy()
        # append() mutates the list in place, so dict[epoch] sees the sample too.
        self.grad_samples.append(sample)

    def activate(self, epoch):
        """Start recording into a new, empty list stored under ``epoch``."""
        self.active = True
        bucket = []
        self.grad_samples = bucket
        self.dict[epoch] = bucket

    def deactivate(self):
        """Stop recording; previously stored samples stay in ``dict``."""
        self.active = False
        # Rebind rather than clear, so the list held by dict[epoch] is kept.
        self.grad_samples = []