import numpy as np
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from collections import OrderedDict



class NeuralTangentFeature(nn.Module):
    """Fully connected ReLU network with an optional functional forward pass.

    The network is a stack of ``nn.Linear`` layers whose sizes are given by
    consecutive entries of ``dims``; an ``nn.ReLU`` follows every linear
    layer except the last one.  The layers are stored in ``self.layers``
    (an ``nn.Sequential``) under the names ``linear_<i>`` / ``relu_<i>``.

    Args:
        dims: Indexable sequence of layer widths, ``dims[0]`` being the
            input size and ``dims[-1]`` the output size.
    """

    # Initialisation schemes understood by ``reset_parameters``.
    _INIT_METHODS = ('default', 'normal', 'xavier', 'kaiming')

    def __init__(self, dims):
        super().__init__()

        layers = OrderedDict()
        last_linear = len(dims) - 2
        for i, (fan_in, fan_out) in enumerate(zip(dims, dims[1:])):
            layers['linear_%d' % i] = nn.Linear(fan_in, fan_out)
            # No activation after the final (output) layer.
            if i < last_linear:
                layers['relu_%d' % i] = nn.ReLU()
        self.layers = nn.Sequential(layers)

        self.reset_parameters('xavier')

    @property
    def param_dict(self):
        """Return an ``OrderedDict`` mapping parameter names to parameters.

        The keys are the names produced by ``named_parameters()`` (for
        example ``'layers.linear_0.weight'``), which is the format
        ``forward`` expects for its ``params`` argument.
        """
        return OrderedDict(self.named_parameters())

    def forward(self, x, params=None):
        """Apply the network to ``x``.

        Args:
            x: Input tensor passed to the first linear layer.
            params: Optional mapping from parameter name (as in
                ``param_dict``) to tensor.  When given, these tensors are
                used instead of the module's own weights and biases.

        Returns:
            With ``params=None``: the output of ``self.layers`` with all
            size-1 dimensions removed by ``squeeze()``.
            With ``params`` given: the output of the last layer, *not*
            squeezed.

        Raises:
            KeyError: If ``params`` lacks the weight or bias of a layer.
        """
        if params is None:
            return self.layers(x).squeeze()

        # NOTE(review): unlike the branch above, this path does not squeeze
        # its result.  Kept as-is so callers see the same output shape.
        for key, module in self.named_modules():
            if isinstance(module, nn.Linear):
                x = F.linear(x, weight=params[key + '.weight'],
                             bias=params[key + '.bias'])
            elif isinstance(module, nn.ReLU):
                x = F.relu(x)
        return x

    def reset_parameters(self, method='default'):
        """Re-initialise every ``nn.Linear`` layer in place.

        Args:
            method: One of

                * ``'default'`` - weights and biases uniform in
                  ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``;
                * ``'normal'``  - weights and biases standard normal;
                * ``'xavier'``  - Xavier-normal weights, biases untouched;
                * ``'kaiming'`` - Kaiming-uniform weights with
                  ``a=sqrt(5)``, biases untouched.

        Raises:
            ValueError: If ``method`` is not one of the names above
                (previously such a call was a silent no-op).
        """
        if method not in self._INIT_METHODS:
            raise ValueError(
                'unknown initialisation method %r; expected one of %s'
                % (method, ', '.join(self._INIT_METHODS)))

        for m in self.modules():
            if not isinstance(m, nn.Linear):
                continue
            if method == 'default':
                bound = 1. / math.sqrt(m.weight.size(1))
                nn.init.uniform_(m.weight, -bound, bound)
                if m.bias is not None:
                    nn.init.uniform_(m.bias, -bound, bound)
            elif method == 'normal':
                nn.init.normal_(m.weight, mean=0, std=1)
                if m.bias is not None:
                    nn.init.normal_(m.bias, mean=0, std=1)
            elif method == 'xavier':
                # Only the weight is re-drawn; the bias is left unchanged.
                nn.init.xavier_normal_(m.weight)
            else:  # 'kaiming'
                # Only the weight is re-drawn; the bias is left unchanged.
                nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
    


    # def forward(self, input):
    #     return F.linear(input, self.w_sig * self.weight/np.sqrt(self.in_features), self.beta * self.bias)


# class LinearNeuralTangentKernel(nn.Linear): 
    
#     def __init__(self, in_features, out_features, bias=True, beta=0.1, w_sig = 1):
#         self.beta = beta
#         super(LinearNeuralTangentKernel, self).__init__(in_features, out_features)
#         self.reset_parameters()
#         self.w_sig = w_sig
      
#     def reset_parameters(self):
#         torch.nn.init.normal_(self.weight, mean=0, std=1)
#         if self.bias is not None:
#             torch.nn.init.normal_(self.bias, mean=0, std=1)

#     def forward(self, input):
#         return F.linear(input, self.w_sig * self.weight/np.sqrt(self.in_features), self.beta * self.bias)

#     def extra_repr(self):
#         return 'in_features={}, out_features={}, bias={}, beta={}'.format(
#             self.in_features, self.out_features, self.bias is not None, self.beta
#         )

# class FourLayersNet(nn.Module):

#     def __init__(self, n_wid, n_out = 1, beta=0.1):
#         super(FourLayersNet, self).__init__()
#         self.fc1 = LinearNeuralTangentKernel(2, n_wid, beta=beta)
#         self.fc2 = LinearNeuralTangentKernel(n_wid, n_wid, beta=beta)
#         self.fc3 = LinearNeuralTangentKernel(n_wid, n_wid, beta=beta)
#         self.fc4 = LinearNeuralTangentKernel(n_wid, n_out, beta=beta)

#     def forward(self, x):
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = F.relu(self.fc3(x))
#         x = self.fc4(x)
#         return x