import timeit

import torch
from tqdm import trange
from torch.nn.functional import relu, celu
from torch import nn

from utils import get_stateless_net_with_partials
from torch import vmap
from torch._functorch.eager_transforms import jacrev, jacfwd
from torch._functorch.functional_call import functional_call

class MyNet(nn.Module):
    """Fully connected network evaluated on the concatenation of two inputs.

    Consecutive entries of ``ks`` give the ``(in_features, out_features)`` of
    each ``nn.Linear`` layer, so ``ks[0]`` must equal the size of the last
    dimension of ``x`` plus that of ``z``. The activation ``act`` is applied
    between linear layers and never after the last one.
    """

    def __init__(self, ks, act=celu):
        """Build the linear layers.

        :param ks: sequence of layer widths; ``len(ks) - 1`` layers are created
        :param act: callable applied to the output of every layer but the last
        """
        super().__init__()
        self.ks = ks
        self.fcs = nn.ModuleList([
            nn.Linear(n_in, n_out)
            for n_in, n_out in zip(self.ks[:-1], self.ks[1:])
        ])
        self.D = len(self.fcs)  ## number of linear layers
        self.act = act

    def forward(self, x, z):
        '''
        Concatenates x and z along the last dimension and passes the result
        through the network.
        :param x: [..., nx]
        :param z: [..., nz]; torch.cat does not broadcast, so all dimensions
            except the last must match those of x
        :return: [..., ks[-1]]
        '''
        xz = torch.cat([x, z], dim=-1)
        out = self.fcs[0](xz)
        ## Every remaining layer is preceded by the activation. With a single
        ## layer this loop is empty, so the first layer is not applied twice.
        for i in range(1, self.D):
            out = self.fcs[i](self.act(out))
        return out

def bench():
    """Print timings and output shapes for a two-input MyNet and its derivatives.

    Builds ``MyNet(ks=[18, 20, 20, 1], act=torch.sin)``, wraps it in a
    stateless call and derives the Jacobian (``jacrev``) and Hessian
    (``jacfwd`` of the Jacobian) with respect to ``x``, each also in a
    ``vmap``-ed form. Every ``timeit`` measurement runs the callable 10 times.
    """
    net = MyNet(ks=[18, 20, 20, 1], act=torch.sin)

    ## The parameters are passed explicitly to every stateless call
    params = dict(net.named_parameters())

    def f(params, x, z):
        """
        Stateless call to the model; x and z are handed to MyNet.forward
        unchanged. As used below this covers
        1) single inputs:  x: [nx], z: [nz]          -> [ny]
        2) batch inputs:   x: [bx, nx], z: [bx, nz]  -> [bx, ny]
        """
        return functional_call(net, params, (x, z))

    ## Derivatives are taken w.r.t. positional argument 1, i.e. x
    f_x = jacrev(f, argnums=1)      ## Jacobian: x [nx] -> [ny, nx]
    f_xx = jacfwd(f_x, argnums=1)   ## Hessian:  x [nx] -> [ny, nx, nx]

    ## Batched variants: params are shared, x and z are mapped over dim 0
    mapped_dims = (None, 0, 0)
    vf_x = vmap(f_x, in_dims=mapped_dims, out_dims=0)    ## [bx, ny, nx]
    vf_xx = vmap(f_xx, in_dims=mapped_dims, out_dims=0)  ## [bx, ny, nx, nx]

    def report_time(fn, x, z):
        ## Total wall time of 10 calls of fn on the given inputs
        print(timeit.timeit(lambda: fn(params, x, z), number=10))

    ## One unbatched point: nx = 2, nz = 16
    x_single = torch.rand(2)
    z_single = torch.rand(16)
    report_time(f, x_single, z_single)

    ## A batch of 64 points
    x = torch.rand(64, 2)
    z = torch.rand(64, 16)

    ## Note that f_x and f_xx receive the whole batch as one input here, so
    ## they differentiate w.r.t. the full [64, 2] tensor rather than per sample.
    for fn in (f, f_x, vf_x, f_xx, vf_xx):
        report_time(fn, x, z)

    print(f(params, x, z).shape)      # [bx, ny]          = [64, 1]
    print(vf_x(params, x, z).shape)   # [bx, ny, nx]      = [64, 1, 2]
    print(vf_xx(params, x, z).shape)  # [bx, ny, nx, nx]  = [64, 1, 2, 2]

    ## Disabled per-sample shape checks:
    # for res in vf_x(params, x, z):
    #     print(res.shape)  # [ny, nx]
    # for res in vf_xx(params, x, z):
    #     print(res.shape)  # [ny, nx, nx]
    

def benchorig():
    """Print timings for the stateless network helpers provided by ``utils``.

    Builds ``MyNet(ks=[2, 20, 20, 1], act=torch.sin)``, obtains
    ``params, f, vf_x, vf_xx`` from ``get_stateless_net_with_partials`` and
    times 1000 calls of each callable on 256 points sampled uniformly inside
    ``bounds``.
    """
    # model = GeneralNet(ks=[2, 20, 20, 1], act=torch.sin)
    # NOTE(review): MyNet.forward takes two inputs (x, z), while the callables
    # below are invoked with a single input xc. Presumably this benchmark was
    # written for the single-input GeneralNet above; confirm against
    # get_stateless_net_with_partials before relying on it.
    model = MyNet(ks=[2, 20, 20, 1], act=torch.sin)

    ## The helper returns the parameters itself, so no separate parameter
    ## dict (or optimizer) is needed for the timings.
    params, f, vf_x, vf_xx = get_stateless_net_with_partials(model)

    ## One [lower, upper] row per input dimension
    bounds = torch.tensor([[-1, 1], [-1, 1]]) * 1
    lower, upper = bounds[:, 0], bounds[:, -1]
    xc = lower + (upper - lower) * torch.rand(256, 2)  ## Sample 256 points within the specified domain

    print(timeit.timeit(lambda: f(params, xc), number=1000))
    print(timeit.timeit(lambda: vf_x(params, xc), number=1000))
    print(timeit.timeit(lambda: vf_xx(params, xc), number=1000))

    ## Disabled experiment: optimize the points xc (not the model parameters)
    ## to minimize the squared norm of the Jacobian.
    # xc.requires_grad = True  ## we will optimize xc, not the model self.params as usual
    # opt = torch.optim.Adam([xc], lr=1e-2)
    # for i in (pbar := trange(300)):
    #     opt.zero_grad()
    #     y_x = vf_x(params, xc)  ## [bx, ny, nx]
    #     loss = y_x.square().sum(2).sum(1).mean()  ## compute mean of squared norms

    #     loss.backward()
    #     opt.step()
    #     pbar.set_description(f"Finding CPs: {loss.item():.2e}")


def plot_mesh_and_boundary():
    """Load ``interfaces.stl`` and sample surface points with face normals.

    The k3d plotting code is currently commented out, so the sampled points
    and normals are computed and then discarded; nothing is returned.
    """
    ## Imported locally so the rest of the module works without these packages
    import trimesh
    import k3d

    # presumably the file holds a single mesh object that offers sample() and
    # face_normals; verify for other inputs
    surface = trimesh.load("interfaces.stl")

    ## Sample points and look up the normal of the face each point came from
    ## From https://github.com/mikedh/trimesh/issues/1285#issuecomment-880854466
    sample_count = 10000
    sampled = surface.sample(sample_count, return_index=True)
    points, face_idx = sampled
    normals = surface.face_normals[face_idx]

    # ## Plot
    # plot = k3d.plot(height=1000)
    # plot += k3d.points(points, point_size=0.2, color=0xff0000)
    # plot += k3d.vectors(points, normals, color=0)
    # plot.display()

if __name__ == "__main__":
    # Script entry point: only the functorch timing benchmark is run.
    bench()