import os
import torch
import torch.nn as nn
from activations import RePU, Swish, GeLU, GumbelLU, GudermanLU, AlgebraicLU


def activation_dictionary_generator():
    """Return a freshly built dict mapping activation names to activation classes.

    Keys are lower-case strings; values are the classes themselves (not
    instances), so the caller decides how to construct them. A new dict is
    created on every call, so mutating the result affects no other caller.
    """
    return {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'repu': RePU,
        'swish': Swish,
        'gelu': GeLU,
        'gumbellu': GumbelLU,
        'gudermanlu': GudermanLU,
        'algebraiclu': AlgebraicLU,
    }


######### Base model #########


class FullyConnected(nn.Module):
    """Fully connected network: Linear layers separated by one shared activation.

    Layout (for ``depth >= 1``)::

        Linear(input_dim, width) -> act
        [Linear(width, width) -> act] * (depth - 1)
        Linear(width, width)            # no activation after the last layer

    The same activation module instance is reused after every layer, so any
    parameters it holds are shared across layers. ``pre_layers`` exposes every
    width -> width Linear layer (all Linear layers except the first).
    """

    # Build a neural network according to given model parameters
    def __init__(self,
                 width,  # units in every layer after the input projection
                 depth,  # number of activations; the stack holds depth + 1 Linear layers
                 input_dim,  # Input dimension
                 activation,  # key into activation_dictionary_generator()
                 Cw,  # weight variance scale for the PDLT initialization
                 Cb,  # bias variance for the PDLT initialization
                 control,  # for CSwish, RePu
                 power,  # for CSwish, RePu
                 default_initialization=0,  # 0 -> apply init_params to every layer
                 args=None,  # accepted for call-site compatibility; not used
                 ):
        """Build the layer stack and, if requested, apply the PDLT initialization.

        Raises:
            KeyError: if ``activation`` is not a known activation name.
        """
        # Inherit from PyTorch template
        super(FullyConnected, self).__init__()

        # Internalize model parameters
        self.width = width
        self.depth = depth
        self.input_dim = input_dim
        self.activation_dictionary = activation_dictionary_generator()

        try:
            activation_cls = self.activation_dictionary[activation]
        except KeyError:
            raise KeyError(
                f"Unknown activation {activation!r}; expected one of "
                f"{sorted(self.activation_dictionary)}"
            ) from None
        # relu/tanh take no constructor arguments; the others take (control, power).
        if activation in ('relu', 'tanh'):
            self.activation = activation_cls()
        else:
            self.activation = activation_cls(control, power)

        self.control = control
        self.power = power
        self.Cw = Cw
        self.Cb = Cb

        # Record actual depth-to-width ratio
        self.ratio = self.depth / self.width

        # Layers are collected in a plain list so they are registered only once,
        # through `stack`; `pre_layers` additionally exposes the width -> width
        # Linear layers. Creation order is fixed because each nn.Linear draws
        # from the global RNG when constructed.
        self.pre_layers = nn.ModuleList([])
        layers = [nn.Linear(input_dim, width), self.activation]

        # Hidden width -> width layers, each followed by the shared activation.
        for _ in range(self.depth - 1):
            linear_layer = nn.Linear(width, width)
            self.pre_layers.append(linear_layer)
            layers.extend((linear_layer, self.activation))

        # Last layer: width -> width, with no activation afterwards.
        linear_layer = nn.Linear(self.width, self.width)
        self.pre_layers.append(linear_layer)
        layers.append(linear_layer)

        self.num_pre_layers = len(self.pre_layers)

        # Put the layers together
        self.stack = nn.Sequential(*layers)
        if default_initialization == 0:
            self.stack.apply(self.init_params)

    # PDLT initialization scheme
    def init_params(self, m):
        """Re-initialize an ``nn.Linear`` module in place; other modules are ignored.

        Weights are drawn from N(0, Cw / fan_in), with fan_in = ``m.weight.shape[1]``.
        Biases (if present) are drawn from N(0, Cb).

        Raises:
            ValueError: if ``Cw`` or ``Cb`` is negative (no real standard deviation).
        """
        if not isinstance(m, nn.Linear):
            return
        if self.Cw < 0 or self.Cb < 0:
            raise ValueError(
                f"Cw and Cb must be non-negative variances, got Cw={self.Cw!r}, Cb={self.Cb!r}"
            )
        in_width = m.weight.shape[1]
        # Sample unit normals, then rescale in place; no_grad keeps the
        # rescaling out of autograd without reaching for `.data`.
        nn.init.normal_(m.weight, mean=0, std=1)
        if m.bias is not None:
            nn.init.normal_(m.bias, mean=0, std=1)
        with torch.no_grad():
            m.weight.mul_((self.Cw / in_width) ** 0.5)
            if m.bias is not None:
                m.bias.mul_(self.Cb ** 0.5)

    # Necessary for pytorch
    def forward(self, x):
        """Run ``x`` through the layer stack and return the last layer's output."""
        output = self.stack(x)
        return output

    def get_ratio(self):
        """Return depth / width."""
        return self.ratio

    def get_width(self):
        """Return the layer width."""
        return self.width

    def get_depth(self):
        """Return the depth passed to the constructor."""
        return self.depth

    def get_variance_params(self):
        """Return the initialization variances as ``(Cw, Cb)``."""
        return (self.Cw, self.Cb)



class model(torch.nn.Module):
    """Classifier: a ``FullyConnected`` feature stack followed by a linear head.

    Instances also carry bookkeeping lists (``performance``, ``test_performance``,
    ``NTK``) that are written into checkpoints. ``epoch`` and ``name`` are NOT
    set here; they must be assigned (presumably by the training loop) before
    saving a checkpoint with ``state_dict(save_dir)``.
    """

    # Build a neural network according to given model parameters
    def __init__(self,
                 width,  # width of every FullyConnected layer
                 depth,  # forwarded to FullyConnected
                 input_dim,  # Input dimension
                 classes,  # number of categories
                 activation,  # Activation function name
                 Cw,  # IC distribution parameters
                 Cb,
                 control,  # forwarded to the activation constructor
                 power,  # forwarded to the activation constructor
                 default_initialization,  # forwarded to FullyConnected
                 args=None
                 ):
        # Inherit from PyTorch template
        super(model, self).__init__()
        self.net = FullyConnected(width, depth, input_dim, activation, Cw, Cb, control, power,
                                  default_initialization, args)
        self.final_layer = torch.nn.Linear(width, classes)
        self.width = width
        self.input_dim = input_dim
        self.depth = depth
        self.classes = classes
        self.Cw = Cw
        self.Cb = Cb
        self.control = control
        self.power = power
        self.activation = activation
        self.performance = []
        self.test_performance = []
        self.NTK = []

    # Necessary for pytorch
    def forward(self, x):
        """Return class logits: the feature stack followed by the linear head."""
        z_out = self.net(x)
        output = self.final_layer(z_out)
        return output

    def state_dict(self, save_dir=None, **kwargs):
        """
        Overrides the state dict method for nn modules.

        With ``save_dir``: write a checkpoint file named
        ``f'{self.name}_epoch_{self.epoch}.ckpt'`` into ``save_dir`` (created if
        missing), record its path in ``self.ckpt_name``, and return None. The
        checkpoint contains the feature-stack parameters (``'net'``), the
        classifier head (``'final_layer'``), the bookkeeping lists, and Cw/Cb.

        Without ``save_dir``: behave exactly like ``torch.nn.Module.state_dict``
        (``kwargs`` such as ``prefix``/``keep_vars``/``destination`` are
        forwarded). This keeps PyTorch code that calls ``state_dict()`` working,
        including when this model is nested inside another Module.

        Raises:
            TypeError: if extra keyword arguments are given together with ``save_dir``.
        """
        if save_dir is None:
            return super(model, self).state_dict(**kwargs)
        if kwargs:
            raise TypeError(
                f'unexpected keyword arguments when saving a checkpoint: {sorted(kwargs)}'
            )

        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_dir, exist_ok=True)

        model_dict = {
            'epoch': self.epoch,
            'net': self.net.state_dict(),
            # The classifier head lives outside `net`; without this entry a
            # checkpoint could not restore the full model.
            'final_layer': self.final_layer.state_dict(),
            'performance': self.performance,
            'test_performance': self.test_performance,
            'optim': {},
            'NTK': self.NTK,
            'Cb': self.Cb,
            'Cw': self.Cw,
        }

        name = f'{self.name}_epoch_{self.epoch}.ckpt'
        print(f'Saving {name} in {save_dir}')
        self.ckpt_name = os.path.join(save_dir, name)
        torch.save(model_dict, self.ckpt_name)


def load_state_dict(fname=None, map_location=None):
    """
    Load a checkpoint in the format written by ``model.state_dict(save_dir)``.

    Args:
        fname: path of the checkpoint file. Required; the ``None`` default is
            kept only for backward compatibility of the signature.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            checkpoint containing GPU tensors on a CPU-only machine.

    Returns:
        ``(epoch, performance, test_performance, net_params, NTK)``, where
        ``net_params`` is the state dict stored under the ``'net'`` key.

    Raises:
        ValueError: if ``fname`` is None.
        KeyError: if the checkpoint lacks one of the expected keys.

    Note: ``torch.load`` unpickles the file; only load checkpoints from
    trusted sources.
    """
    if fname is None:
        raise ValueError('load_state_dict requires a checkpoint path (fname)')
    state_dict = torch.load(fname, map_location=map_location)
    epoch = state_dict['epoch']
    performance = state_dict['performance']
    test_performance = state_dict['test_performance']
    NTK = state_dict['NTK']
    net_params = state_dict['net']
    return epoch, performance, test_performance, net_params, NTK

