import torch 
import torch.nn as nn
from src.utils.train import OptimModule

def _get_activation(activation, hidden_dim, **kwargs):
    """Build the activation module selected by name.

    Args:
        activation: One of ``'relu'``, ``'tanh'`` or ``'sin'``.
        hidden_dim: Feature width. Only used by ``'sin'``, where it is passed
            to ``Sin`` as ``dim``.
        **kwargs: Extra keyword arguments forwarded to ``Sin``; ignored for
            the other activations.

    Returns:
        A newly constructed activation module.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    if activation == 'relu':
        return nn.ReLU()
    if activation == 'tanh':
        return nn.Tanh()
    if activation == 'sin':
        return Sin(dim=hidden_dim, **kwargs)
    # Fail at construction time instead of handing back ``None``, which would
    # only blow up later when the "activation" is called.
    raise ValueError(
        f"unsupported activation {activation!r}; expected 'relu', 'tanh' or 'sin'"
    )


class SinFunc(torch.autograd.Function):
    """Autograd function for ``sin(freq * x)`` with a hand-written backward."""

    @staticmethod
    def forward(ctx, x, freq):
        """Return ``torch.sin(freq * x)`` and save both inputs for backward.

        Args:
            ctx: Autograd context.
            x: Input tensor.
            freq: Frequency tensor, multiplied element-wise (with
                broadcasting) into ``x``.
        """
        ctx.save_for_backward(x, freq)
        return torch.sin(freq * x)

    @staticmethod
    def backward(ctx, dout):
        """Return the gradients of the output w.r.t. ``x`` and ``freq``.

        With ``y = sin(freq * x)``:
        ``dy/dx = freq * cos(freq * x)`` and ``dy/dfreq = x * cos(freq * x)``.
        """
        x, freq = ctx.saved_tensors
        # Shared factor of both gradients.
        scaled = dout * torch.cos(freq * x)

        dx = dfreq = None
        if ctx.needs_input_grad[0]:
            # Chain rule: the inner derivative d(freq * x)/dx is ``freq``.
            # ``sum_to_size`` undoes any broadcasting applied in forward.
            dx = (scaled * freq).sum_to_size(x.shape)
        if ctx.needs_input_grad[1]:
            # Reduce over every dimension along which ``freq`` was broadcast.
            dfreq = (scaled * x).sum_to_size(freq.shape)
        return dx, dfreq


class Sin(nn.Module):
    """Sine activation ``sin(freq * x)`` with a per-feature frequency.

    Args:
        dim: Number of features; ``freq`` has shape ``(1, dim)``.
        w: Initial value of every frequency entry.
        train_freq: If true, ``freq`` is an ``nn.Parameter``; otherwise it is
            a fixed tensor.
    """

    def __init__(self, dim, w=10, train_freq=True):
        super().__init__()
        init = w * torch.ones(1, dim)
        if train_freq:
            self.freq = nn.Parameter(init)
        else:
            # Register the fixed frequency as a buffer so it follows the
            # module through ``.to()`` / ``.cuda()`` / dtype casts.
            # ``persistent=False`` keeps it out of ``state_dict``, exactly as
            # when it was a plain attribute.
            self.register_buffer('freq', init, persistent=False)

    def forward(self, x):
        """Apply ``SinFunc`` to ``x`` using this module's ``freq``."""
        return SinFunc.apply(x, self.freq)



class SirenMLP(OptimModule):
    """MLP whose forward pass is evaluated through ``SirenMLPFunc``.

    The network holds ``num_layers`` hidden ``nn.Linear`` layers (the first
    maps ``in_dim -> hidden_dim``, the rest ``hidden_dim -> hidden_dim``)
    followed by one output layer ``hidden_dim -> out_dim``, i.e.
    ``num_layers + 1`` linear layers in total.

    Args:
        in_dim: Input feature width.
        out_dim: Output feature width.
        hidden_dim: Width of the hidden layers.
        num_layers: Number of hidden linear layers before the output layer.
        activation: Activation name understood by ``_get_activation``.
        **activation_kwargs: Forwarded to ``_get_activation``.

    Attributes (set in ``__init__``):
        layers: ``nn.ModuleList`` of the linear layers.
        weights / biases: ``nn.ParameterList`` objects holding the very same
            ``weight`` / ``bias`` parameters as ``layers`` (aliases, not
            copies), in layer order.
        activation: The single activation module shared by all layers.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, num_layers=2, activation='sin', **activation_kwargs):
        super().__init__()
        self.num_layers = num_layers
        activation = _get_activation(activation, hidden_dim, **activation_kwargs)
        self.layers = nn.ModuleList()
        self.weights, self.biases = nn.ParameterList(), nn.ParameterList()
        self.activation = activation
        # Hidden stack. Previously the last loop iteration was special-cased
        # to emit hidden_dim -> out_dim, which made the output layer below
        # receive an ``out_dim``-wide input and fail for num_layers >= 2.
        for i in range(num_layers):
            self._add_linear(in_dim if i == 0 else hidden_dim, hidden_dim)
        # Output projection.
        self._add_linear(hidden_dim, out_dim)

    def _add_linear(self, fan_in, fan_out):
        """Append a ``Linear(fan_in, fan_out)`` and alias its parameters."""
        layer = nn.Linear(fan_in, fan_out)
        self.layers.append(layer)
        self.weights.append(layer.weight)
        self.biases.append(layer.bias)

    def forward(self, x):
        """Run ``x`` through all layers via ``SirenMLPFunc``.

        NOTE(review): ``SirenMLPFunc`` applies the activation after every
        layer, including the output layer; with ``'sin'`` the frequency has
        shape ``(1, hidden_dim)``, so ``out_dim`` has to broadcast against
        ``hidden_dim`` -- confirm this is intended.
        """
        # ``nn.Linear`` stores weights as (out_features, in_features) while
        # SirenMLPFunc contracts ``'...i,...ij->...j'`` and therefore expects
        # (in_features, out_features); pass transposed views.
        weights = [weight.t() for weight in self.weights]
        return SirenMLPFunc.apply(x, weights, list(self.biases), self.activation)


class SirenMLPFunc(torch.autograd.Function):
    """MLP ``x -> activation(x @ W_k + b_k)`` with a recomputing backward.

    ``forward`` saves only the network input and the parameters; the
    per-layer intermediate values are recomputed inside ``backward``.

    Call it as ``SirenMLPFunc.apply(x, weights, biases, activation)``:

    * ``x``: input tensor of shape ``(..., in_features)``.
    * ``weights``: sequence of 2-D tensors, each ``(in_features, out_features)``.
    * ``biases``: sequence of tensors broadcastable to each layer's output.
    * ``activation``: callable applied after *every* layer. If it is an
      ``nn.Module``, its parameters receive gradients as well.
    """

    @classmethod
    def apply(cls, x, weights, biases, activation):
        """Flatten the parameter sequences and dispatch to autograd.

        Autograd only tracks tensors that are *direct* arguments of
        ``Function.apply``; tensors hidden inside lists would never receive a
        gradient. The sequences are therefore unpacked here.

        Raises:
            ValueError: If ``weights`` and ``biases`` differ in length.
        """
        weights, biases = tuple(weights), tuple(biases)
        if len(weights) != len(biases):
            raise ValueError(
                f"got {len(weights)} weights but {len(biases)} biases"
            )
        act_params = cls._activation_params(activation)
        return super().apply(
            x, activation, len(weights), *weights, *biases, *act_params
        )

    @staticmethod
    def _activation_params(activation):
        """Return the parameters of ``activation`` if it is an ``nn.Module``."""
        if isinstance(activation, torch.nn.Module):
            return tuple(activation.parameters())
        return ()

    @staticmethod
    def forward(ctx, x, activation, num_layers, *tensors):
        """Evaluate the MLP.

        ``tensors`` is laid out as ``num_layers`` weights, then ``num_layers``
        biases, then any activation parameters (see ``apply``).
        """
        ctx.activation = activation
        ctx.num_layers = num_layers
        ctx.save_for_backward(x, *tensors)
        weights = tensors[:num_layers]
        biases = tensors[num_layers:2 * num_layers]
        for weight, bias in zip(weights, biases):
            x = activation(torch.einsum('...i,ij->...j', x, weight) + bias)
        return x

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, dout):
        """Recompute the forward pass and back-propagate ``dout`` through it.

        Returns one gradient per ``forward`` argument: the input gradient,
        ``None`` for ``activation`` and ``num_layers``, then the weight, bias
        and activation-parameter gradients in the order they were passed.
        """
        x, *tensors = ctx.saved_tensors
        n = ctx.num_layers
        activation = ctx.activation
        weights, biases = tensors[:n], tensors[n:2 * n]
        num_act = len(tensors) - 2 * n

        # The activation graph is rebuilt from the module's live parameters,
        # so gradients must be requested w.r.t. those same tensors.
        live_params = SirenMLPFunc._activation_params(activation)[:num_act]
        trainable = [(k, p) for k, p in enumerate(live_params) if p.requires_grad]
        d_act = [None] * num_act

        # Recompute the forward pass. ``once_differentiable`` runs this method
        # under ``no_grad``; grad mode is re-enabled only around the
        # activation so that each layer gets its own small, isolated graph.
        steps = []
        hidden = x
        for weight, bias in zip(weights, biases):
            pre = torch.einsum('...i,ij->...j', hidden, weight) + bias
            pre.requires_grad_(True)
            with torch.enable_grad():
                post = activation(pre)
            steps.append((hidden, pre, post))
            hidden = post.detach()

        d_weights, d_biases = [None] * n, [None] * n
        for k in range(n - 1, -1, -1):
            hidden, pre, post = steps[k]
            grads = torch.autograd.grad(
                post, [pre] + [p for _, p in trainable], dout, allow_unused=True
            )
            d_pre = grads[0] if grads[0] is not None else torch.zeros_like(pre)
            # The activation is shared, so its parameter gradients accumulate
            # over all layers.
            for (idx, _), grad in zip(trainable, grads[1:]):
                if grad is not None:
                    d_act[idx] = grad if d_act[idx] is None else d_act[idx] + grad

            # pre = hidden @ W + b  =>  dW = hidden^T @ d_pre (summed over all
            # leading dimensions), db = d_pre reduced to the bias shape.
            flat_in = hidden.reshape(-1, hidden.shape[-1])
            flat_grad = d_pre.reshape(-1, d_pre.shape[-1])
            d_weights[k] = flat_in.t() @ flat_grad
            d_biases[k] = d_pre.sum_to_size(biases[k].shape)

            if k > 0 or ctx.needs_input_grad[0]:
                dout = (d_pre @ weights[k].t()).sum_to_size(hidden.shape)
            else:
                dout = None

        dx = dout if ctx.needs_input_grad[0] else None
        return (dx, None, None, *d_weights, *d_biases, *d_act)


if __name__ == "__main__":
    # Smoke test: one forward/backward pass through a small SirenMLP,
    # printing the output and the resulting gradients.
    inputs = torch.randn(1, 3, requires_grad=True)
    net = SirenMLP(
        in_dim=3,
        out_dim=1,
        hidden_dim=64,
        num_layers=3,
        activation='sin',
        w=10,
        train_freq=True,
    )
    out = net(inputs)
    print(out)

    # Back-propagate an all-ones upstream gradient.
    upstream = torch.ones_like(out)
    out.backward(upstream)
    print(inputs.grad)
    first_layer = net.layers[0]
    print(first_layer.weight.grad)

