import torch
from torch import nn
from higher_grad import hessian, inference_and_hgrad, higher_grad


class SigmoidLayer(nn.Module):
    """Tiny logistic unit: ``Linear(2, 1)`` followed by ``Sigmoid``.

    Two sub-modules are registered, in this order: ``flatten`` and
    ``linear_sigmoid_stack``. Only ``linear_sigmoid_stack`` takes part in
    the forward pass.
    """

    def __init__(self):
        super().__init__()
        # Registered for completeness; ``forward`` never applies it.
        self.flatten = nn.Flatten()
        affine = nn.Linear(2, 1)
        squash = nn.Sigmoid()
        self.linear_sigmoid_stack = nn.Sequential(affine, squash)

    def forward(self, x):
        """Apply the linear layer and then the sigmoid to ``x``."""
        return self.linear_sigmoid_stack(x)


def hessian_xx_ans(model, x):
    """Hand-derived Hessian of the model output with respect to ``x``.

    Assumes ``model`` computes ``sigmoid(w . x + b)`` with a two-element
    weight row stored in its state dict under
    ``'linear_sigmoid_stack.0.weight'``. The result is the sigmoid's second
    derivative, ``y * (1 - y) * (1 - 2 * y)``, times the 2x2 outer product
    of the weight row with itself.
    """
    params = model.state_dict()
    y = model(x)
    weight_row = params['linear_sigmoid_stack.0.weight'][0]
    w0, w1 = weight_row[0], weight_row[1]

    # Second derivative of the sigmoid, expressed through its output.
    curvature = y*(1-y) * (1-2*y)
    weight_outer = torch.tensor([[w0*w0, w0*w1], [w1*w0, w1*w1]])
    return curvature * weight_outer


def hessian_ww_ans(model, x):
    """Hand-derived Hessian of the model output w.r.t. ``(w0, w1, b)``.

    Assumes ``model`` computes ``sigmoid(w0 * x[0] + w1 * x[1] + b)``. The
    pre-activation's gradient in the parameters is ``(x[0], x[1], 1)``, so
    the Hessian is the sigmoid's second derivative,
    ``y * (1 - y) * (1 - 2 * y)``, times the 3x3 outer product of that
    vector with itself.

    Only ``model(x)`` is needed here; the parameters themselves do not
    appear in the formula, so the state dict is not read.

    Args:
        model: Callable returning the sigmoid output for ``x``.
        x: Input whose first two entries are single-element tensors or
            numbers.

    Returns:
        A 3x3 tensor (broadcast against the shape of ``model(x)``).
    """
    y = model(x)
    x0 = x[0]
    x1 = x[1]
    h = y*(1-y) * (1-2*y) * torch.tensor([[x0*x0, x0*x1, x0], [x1*x0, x1*x1, x1], [x0, x1, 1]])

    return h


def hgrad_x_ans(model, x):
    """Hand-derived gradient of the model output with respect to ``x``.

    Assumes ``model`` computes ``sigmoid(w . x + b)`` with a two-element
    weight row stored in its state dict under
    ``'linear_sigmoid_stack.0.weight'``. The gradient is the sigmoid's
    first derivative, ``y * (1 - y)``, times that weight row.

    ``x`` is only passed to ``model`` and is not indexed, so any input the
    model accepts may be used, including a single-row batch.

    Args:
        model: Module exposing ``state_dict()`` and callable on ``x``.
        x: Input forwarded to ``model``.

    Returns:
        ``y * (1 - y) * [w0, w1]``, broadcast against the shape of
        ``model(x)``.
    """
    st_dict = model.state_dict()
    y = model(x)
    w0 = st_dict['linear_sigmoid_stack.0.weight'][0][0]
    w1 = st_dict['linear_sigmoid_stack.0.weight'][0][1]
    h = y * (1 - y) * torch.tensor([w0, w1])

    return h


def hgrad_xww_ans(model, x):
    """Hand-derived third derivative ``d^3 f / dx dw dw`` of the model output.

    Assumes ``model`` computes ``sigmoid(w0 * x[0] + w1 * x[1] + b)`` with
    the weight row stored in its state dict under
    ``'linear_sigmoid_stack.0.weight'``, and that ``model(x)`` holds a
    single element (every entry below is converted to a scalar by
    ``torch.tensor``).

    With ``u = (x[0], x[1], 1)`` the entry for input ``k`` and parameters
    ``i, j`` is
    ``sigma''' * w_k * u_i * u_j + sigma'' * (d u_i/d x_k * u_j + u_i * d u_j/d x_k)``.

    Returns:
        A tensor of shape ``(2, 3, 3)``: the first index selects the input
        component, the last two select the parameters ``(w0, w1, b)``.
    """
    params = model.state_dict()
    y = model(x)
    x0, x1 = x[0], x[1]
    weight_row = params['linear_sigmoid_stack.0.weight'][0]
    w0, w1 = weight_row[0], weight_row[1]

    # Second and third derivatives of the sigmoid, written via its output.
    d2 = y*(1-y)*(1-2*y)
    d3 = (1-6*y+6*y**2)*y*(1-y)

    # Third-derivative factor scaled by the weight of the differentiated input.
    a0 = d3*w0
    a1 = d3*w1

    # Slice for d/dx0: extra d2 terms appear wherever a w0 index is involved.
    wrt_x0 = [
        [a0*x0**2+2*d2*x0, a0*x0*x1+d2*x1, a0*x0+d2],
        [a0*x0*x1+d2*x1, a0*x1**2, a0*x1],
        [a0*x0+d2, a0*x1, a0],
    ]
    # Slice for d/dx1: extra d2 terms appear wherever a w1 index is involved.
    wrt_x1 = [
        [a1*x0**2, a1*x0*x1+d2*x0, a1*x0],
        [a1*x0*x1+d2*x0, a1*x1**2+2*d2*x1, a1*x1+d2],
        [a1*x0, a1*x1+d2, a1],
    ]

    return torch.tensor([wrt_x0, wrt_x1])


if __name__ == '__main__':
    # Demo: print what the higher_grad helpers return next to the
    # hand-derived closed forms defined above, for one fixed input.
    device = 'cpu'
    model = SigmoidLayer().to(device)
    print(list(model.parameters()))
    print(model.state_dict())

    x = torch.tensor([1., 2.], requires_grad=True)
    y = model(x)
    print(f'x: {x}, y: {y}')

    print('----hessian----')
    hess_xx = hessian(model, x, x)
    print(f'ddf/dxdx: {hess_xx}')
    hess_xx_expected = hessian_xx_ans(model, x)
    print(f'correct answer: {hess_xx_expected}')
    # model.parameters() returns a fresh generator on every call, so it is
    # requested anew for each argument slot.
    hess_ww = hessian(model, x, model.parameters())
    print(f'ddf/dwdw: {hess_ww}')
    hess_ww_expected = hessian_ww_ans(model, x)
    print(f'correct answer: {hess_ww_expected}')

    print('----higher_grad----')
    grad_x = inference_and_hgrad(model, x, x)
    print(f'df/dx: {grad_x}')
    grad_x_expected = hgrad_x_ans(model, x)
    print(f'correct answer: {grad_x_expected}')

    # Same quantity two ways: nested higher_grad calls, then the one-shot
    # helper.
    first_order_w = higher_grad(model(x), model.parameters())
    second_order_w = higher_grad(first_order_w, model.parameters())
    print(f'ddf/dwdw: {second_order_w}')
    second_order_w_direct = inference_and_hgrad(
        model, x, model.parameters(), model.parameters())
    print(f'ddf/dwdw: {second_order_w_direct}')
    second_order_w_expected = hessian_ww_ans(model, x)
    print(f'correct answer: {second_order_w_expected}')

    third_order = inference_and_hgrad(
        model, x, model.parameters(), model.parameters(), x)
    print(f'dddf/dxdwdw: {third_order}')
    third_order_expected = hgrad_xww_ans(model, x)
    print(f'correct answer: {third_order_expected}')

