from numpy import broadcast_to
import torch
import numpy as np
import torch.nn as nn
from IPython import embed

# import algos.pyhessian as pyhessian 

def rotate_prod_np(xs):
    """Return the leave-one-out products of ``xs``.

    Entry ``i`` of the result is the product of every element of ``xs``
    except ``xs[i]``. No division is involved, so zeros in ``xs`` are fine.
    The input is not modified (``np.tile`` produces a fresh array).
    """
    n = len(xs)
    # Row i starts as a full copy of xs ...
    grid = np.tile(xs, (n, 1))
    # ... and has its i-th entry neutralised to 1 before the row product.
    diag = np.arange(n)
    grid[diag, diag] = 1
    return grid.prod(axis=1)

def rotate_prod_2_np(xs):
    """Return the leave-two-out products of ``xs`` as an ``(n, n)`` matrix.

    Entry ``[i, j]`` is the product of every element of ``xs`` whose index is
    neither ``i`` nor ``j``; on the diagonal (``i == j``) only ``xs[i]`` is
    left out. When nothing remains (possible for ``len(xs) <= 2``) the entry
    is the empty product, 1.

    The result is a ``numpy.ma.MaskedArray``, as before, but no entry of it
    is masked: a slice whose elements are all excluded would otherwise come
    back masked from the masked product, and that mask would be carried into
    anything computed from the result.
    """
    n = len(xs)
    eye = np.eye(n, dtype=bool)
    # excluded[i, j, k] is True when k == i or k == j.
    excluded = eye[:, np.newaxis, :] | eye[np.newaxis, :, :]
    tiled = np.ma.masked_array(np.tile(xs, (n, n, 1)), mask=excluded)
    prods = tiled.prod(axis=2)
    # Fully-masked slices are empty products: make them an explicit 1 and
    # drop the mask so callers see ordinary, unmasked values everywhere.
    return np.ma.masked_array(prods.filled(1))


class ScalarNet(nn.Module):
    """Chain of ``L`` bias-free ``nn.Linear(1, 1)`` layers.

    The network output is the input multiplied by the product of the ``L``
    scalar layer weights. ``gradient_comp`` and ``hessian_comp`` evaluate the
    closed-form first and second derivatives of
    ``0.5 * (prod(weights) - mu) ** 2`` with respect to the weights. Both
    divide by the individual weights, so they are only finite when no weight
    is zero.
    """

    def __init__(self, L, init_values=None):
        """Build ``L`` scalar layers, optionally setting their weights.

        Args:
            L: number of layers.
            init_values: optional indexable of at least ``L`` numbers;
                ``init_values[i]`` becomes the weight of layer ``i``. When
                ``None`` the default ``nn.Linear`` initialisation is kept.
        """
        super(ScalarNet, self).__init__()
        self.L = L
        self.layers = []
        for i in range(L):
            layer = nn.Linear(1, 1, bias=False)
            self.layers.append(layer)
            # ``self.layers`` is a plain list, so each layer is registered
            # explicitly (as "W1", "W2", ...) for nn.Module to track it.
            self.add_module("W{}".format(i + 1), layer)
        if init_values is not None:
            for i, layer in enumerate(self.layers):
                layer.weight.data = torch.Tensor([[init_values[i]]])

    def _weight_vector(self):
        """Return the current layer weights as a new 1-D tensor of length L.

        The values are copied out of the parameters, so the result is
        detached from autograd and does not alias the layers.
        """
        return torch.tensor([layer.weight.item() for layer in self.layers])

    def forward(self, x):
        """Apply the layers in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x

    def hessian_comp(self, mu):
        """Return the closed-form ``(L, L)`` Hessian for target ``mu``.

        Off-diagonal entry ``[i, j]`` is
        ``(2 * mu_hat - mu) * mu_hat / (w_i * w_j)`` and diagonal entry
        ``[i, i]`` is ``(mu_hat / w_i) ** 2``, where ``mu_hat`` is the product
        of all weights.
        """
        xs = self._weight_vector()
        mu_hat = torch.prod(xs)
        c1 = (2 * mu_hat - mu) * mu_hat
        # Outer product w_i * w_j via an (L, 1) @ (1, L) matmul.
        hes = c1 / xs.unsqueeze(1).matmul(xs.unsqueeze(0))
        # The diagonal follows a different formula; overwrite it.
        hes.fill_diagonal_(0)
        hes += torch.diag(mu_hat ** 2 / xs ** 2)
        return hes

    def gradient_comp(self, mu):
        """Return the closed-form gradient for target ``mu``.

        Entry ``i`` is ``(mu_hat - mu) * mu_hat / w_i`` with ``mu_hat`` the
        product of all weights.
        """
        xs = self._weight_vector()
        mu_hat = torch.prod(xs)
        return (mu_hat - mu) * mu_hat / xs

    def sharpness(self, mu):
        """Return the largest singular value of the Hessian for ``mu``."""
        return self.hessian_eigenvals(mu)[0]

    def hessian_eigenvals(self, mu):
        """Return the singular values of the Hessian as given by ``torch.svd``."""
        hes = self.hessian_comp(mu)
        return torch.svd(hes)[1]

    def weight_clone(self):
        """Return a copy of the layer weights as a 1-D NumPy array of length L."""
        return self._weight_vector().numpy()


class ScalarNetVecTorch():
    """Scalar chain model ``f(x) = prod(weights) * x`` held in one torch tensor.

    ``gradient_comp`` and ``hessian_comp`` are the closed-form derivatives of
    ``loss(1, mu)`` with respect to the weights. Both divide by the
    individual weights, so they are only finite when no weight is zero.
    """

    def __init__(self, init_values=None):
        """Store ``init_values`` as the weight tensor and record their count.

        NOTE(review): the ``None`` default is not usable -- ``len(None)``
        raises ``TypeError`` -- so a sequence must always be passed.
        """
        self.L = len(init_values)
        self.weights = torch.Tensor(init_values)

    def forward(self, x):
        """Return ``x`` scaled by the product of all weights."""
        return self.weights.prod() * x

    def loss(self, x, mu):
        """Return half the squared difference between ``forward(x)`` and ``mu``."""
        residual = self.forward(x) - mu
        return residual ** 2 / 2

    def hessian_comp(self, mu):
        """Return the closed-form Hessian of ``loss(1, mu)`` in the weights.

        Off-diagonal entry ``[i, j]`` is
        ``(2 * p - mu) * p / (w_i * w_j)`` and diagonal entry ``[i, i]`` is
        ``p ** 2 / w_i ** 2``, where ``p`` is the product of all weights.
        """
        w = self.weights
        p = torch.prod(w)
        coeff = (2 * p - mu) * p
        # Outer product w_i * w_j via a column-by-row matmul.
        pairwise = w.unsqueeze(1).matmul(w.unsqueeze(0))
        hessian = coeff / pairwise
        # The diagonal uses its own formula; clear it, then add it back.
        hessian.fill_diagonal_(0)
        hessian += torch.diag(p ** 2 / w ** 2)
        return hessian

    def gradient_comp(self, mu):
        """Return the closed-form gradient of ``loss(1, mu)`` in the weights.

        Entry ``i`` is ``(p - mu) * p / w_i`` with ``p`` the weight product.
        """
        w = self.weights
        p = torch.prod(w)
        scale = (p - mu) * p
        return scale / w

    def sharpness(self, mu):
        """Return the largest singular value of the Hessian for ``mu``."""
        singular_values = torch.svd(self.hessian_comp(mu))[1]
        return singular_values[0]

    def weight_clone(self):
        """Return a copy of the weights as a NumPy array."""
        snapshot = self.weights.clone()
        return snapshot.numpy()

class ScalarNetVec():
    """NumPy version of the scalar chain model ``f(x) = prod(weights) * x``.

    ``gradient_comp`` and ``hessian_comp`` are the closed-form derivatives of
    ``loss(1, mu)`` with respect to the weights. They are built from the
    module-level helpers ``rotate_prod_np`` and ``rotate_prod_2_np``.
    """

    def __init__(self, init_values=None):
        """Copy ``init_values`` into the weight array and record their count.

        NOTE(review): the ``None`` default is not usable -- ``len(None)``
        raises ``TypeError`` -- so a sequence must always be passed.
        """
        self.L = len(init_values)
        self.weights = np.array(init_values)

    def forward(self, x):
        """Return ``x`` scaled by the product of all weights."""
        return self.weights.prod() * x

    def loss(self, x, mu):
        """Return half the squared difference between ``forward(x)`` and ``mu``."""
        residual = self.forward(x) - mu
        return residual ** 2 / 2

    def diverge(self):
        """Return whether any weight exceeds ``1e5`` in absolute value."""
        return np.abs(self.weights).max() > 1e5

    def hessian_comp(self, mu):
        """Return the closed-form Hessian of ``loss(1, mu)`` in the weights.

        Off-diagonal entry ``[i, j]`` is ``(2 * p - mu)`` times the product of
        all weights other than ``w_i`` and ``w_j``; diagonal entry ``[i, i]``
        is the squared product of all weights other than ``w_i``. Here ``p``
        is the product of all weights. The array type of the result is
        whatever ``rotate_prod_2_np`` yields after scaling.
        """
        w = self.weights
        p = np.prod(w)
        # Leave-out products are used in place of
        # (2 * p - mu) * p / (w_i * w_j) -- presumably so that zero weights
        # do not cause a division by zero.
        hessian = (2 * p - mu) * rotate_prod_2_np(w)
        # The diagonal uses its own formula; clear it, then add it back.
        np.fill_diagonal(hessian, 0)
        hessian += np.diag(rotate_prod_np(w) ** 2)
        return hessian

    def gradient_comp(self, mu):
        """Return the closed-form gradient of ``loss(1, mu)`` in the weights.

        Entry ``i`` is ``(p - mu)`` times the product of all weights other
        than ``w_i``, with ``p`` the product of all weights.
        """
        w = self.weights
        error = np.prod(w) - mu
        return error * rotate_prod_np(w)

    def sharpness(self, mu):
        """Return the largest singular value of the Hessian for ``mu``."""
        return self.hessian_eigenvals(mu)[0]

    def hessian_eigenvals(self, mu):
        """Return the Hessian's singular values as given by ``np.linalg.svd``."""
        singular_values = np.linalg.svd(self.hessian_comp(mu))[1]
        return singular_values

    def weight_clone(self):
        """Return an independent copy of the weight array."""
        return self.weights.copy()