import torch
import torch.nn as nn
import torch.nn.functional as F
from optree.integration.torch import tree_ravel
from posteriors import ggn, empirical_fisher, per_samplify, cg, ggnvp
from posteriors import model_to_function, extract_requires_grad_and_func
from posteriors.utils import _hess_and_jac_for_ggn
from functools import partial
from optree import tree_map, tree_leaves
import numpy as np
import time

class SimpleConvNet(nn.Module):
    """Small CNN used as the benchmark model.

    Architecture: 3x3 conv (1 -> 4 channels, no padding) + ReLU, 2x2 max-pool
    with stride 2, flatten, then two linear layers ``fc1`` and ``fc2``.

    Args:
        fc_features: Output width of ``fc1`` (the hidden layer).
        output_dim: Output width of ``fc2`` (number of logits). Default 10.
        input_size: Side length of the square, single-channel input images.
            Default 28, which yields the same ``4 * 13 * 13`` flattened
            feature size that used to be hard-coded.

    Raises:
        ValueError: If ``input_size`` is too small to leave at least one
            spatial position after the conv and pool.
    """

    def __init__(self, fc_features, output_dim=10, input_size=28):
        super().__init__()
        # The unpadded 3x3 conv shrinks each spatial dimension by 2; the
        # stride-2 2x2 max-pool then floor-halves it (e.g. 28 -> 26 -> 13).
        pooled = (input_size - 2) // 2
        if pooled < 1:
            raise ValueError(f"input_size must be at least 4, got {input_size}")
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3)
        self.fc1 = nn.Linear(in_features=4 * pooled * pooled, out_features=fc_features)
        self.fc2 = nn.Linear(in_features=fc_features, out_features=output_dim)

    def forward(self, x):
        """Map a batch of ``(N, 1, input_size, input_size)`` images to ``(N, output_dim)`` logits."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.flatten(x, 1)
        # NOTE(review): there is no nonlinearity between fc1 and fc2, so the
        # two layers compose to a single affine map. Kept as-is because it
        # defines the benchmarked model; confirm this is intended.
        x = self.fc1(x)
        return self.fc2(x)

def forward(params):
    """Evaluate the module-level ``model`` on the module-level batch ``x``.

    ``params`` replaces the model's own parameters for this call (via
    ``torch.func.functional_call``); the resulting logits are returned.
    """
    return torch.func.functional_call(model, params, x)

def loss_ggn(logits):
    """Cross-entropy between ``logits`` and the module-level targets ``y``.

    Takes network outputs (not parameters), so it serves as the loss
    whose output-space Hessian enters the GGN.
    """
    return F.cross_entropy(logits, y)

def loss_fn(params, inputs, labels):
    """Cross-entropy of the module-level ``model`` evaluated with ``params``.

    Args:
        params: Parameter dict substituted into ``model`` for this call.
        inputs: Batch passed to the model.
        labels: Targets passed to ``cross_entropy``.
    """
    predictions = torch.func.functional_call(model, params, inputs)
    return F.cross_entropy(predictions, labels)



# Define the single-sample forward function
def forward_single_sample(params, x):
    """Run the module-level ``model`` on one unbatched example ``x``.

    A size-1 batch axis is prepended before the call and stripped from the
    output, so the result carries no batch dimension.
    """
    single_batch = x.unsqueeze(0)
    logits = torch.func.functional_call(model, params, (single_batch,))
    return logits.squeeze(0)

# Compute the Jacobian of the single-sample forward function
jac_single_sample_fn = torch.func.jacrev(forward_single_sample, argnums=0)  # d(logits)/d(params)

# Use vmap to vectorize the Jacobian computation over the batch dimension
def batched_jacobian(params, x):
    """Per-example Jacobians of the model outputs with respect to ``params``.

    ``jac_single_sample_fn`` is mapped over the leading axis of ``x`` while
    ``params`` is shared by every example (``in_dims=(None, 0)``).
    """
    per_example_jac = torch.func.vmap(jac_single_sample_fn, in_dims=(None, 0))
    return per_example_jac(params, x)


def partial_ggnvp(v):
    """GGN-vector product at the module-level ``params``, divided by ``batch_size``.

    Uses ``forward`` / ``loss_ggn`` (and hence the global batch) and is the
    linear operator handed to ``cg`` in the ``'ngd-cg'`` benchmark.

    NOTE(review): ``loss_ggn`` already uses mean-reduced cross-entropy, so
    dividing by ``batch_size`` again makes this operator ``batch_size`` times
    smaller than the dense ``G = J^T H J`` built in ``timing`` — confirm this
    scaling is intended.
    """
    ggn_times_v = ggnvp(forward, loss_ggn, (params,), (v,), normalize=False)[1]
    return tree_map(lambda leaf: leaf / batch_size, ggn_times_v)

def _flat_output_hessian_and_jacobian(params, x):
    """Build the dense output-space Hessian ``H`` and parameter Jacobian ``J``.

    Returns:
        ``(hess, jac)``. ``hess`` is the Hessian of ``loss_ggn`` with respect
        to the network outputs, reshaped to
        ``(batch_size * output_dim, batch_size * output_dim)``. ``jac`` is the
        per-example Jacobian of the outputs with respect to the flattened
        parameters, reshaped to ``(-1, num_params)``.
    """
    flat_params, _ = tree_ravel(params)
    # ``flat_params_to_forward`` is the module-level helper defined by the
    # benchmark loop; it evaluates ``forward`` on the global batch.
    z = flat_params_to_forward(flat_params)
    n_out = batch_size * output_dim
    hess = tree_ravel(torch.func.jacfwd(torch.func.jacrev(loss_ggn))(z))[0]
    hess = hess.reshape(n_out, n_out)
    jac = tree_ravel(batched_jacobian(params, x))[0]
    jac = jac.reshape(-1, flat_params.shape[0])
    return hess, jac


def timing(model, params, batch, optimizer_name, maxiter=100):
    """Time one natural-gradient preconditioning step.

    Methods:
        * ``'ngd'``: build the dense ``P x P`` GGN ``G = J^T H J`` and solve
          ``(G + I) d = g`` directly.
        * ``'ngd-cg'``: one conjugate-gradient iteration on the damped
          system defined by ``partial_ggnvp``.
        * ``'ngd-woodbury'``: solve ``(J^T H J + I) d = g`` through the
          push-through (Woodbury) identity, needing only an
          ``(batch_size*output_dim)``-sized solve.
        * ``'ngd-thermo'``: compute the loss gradient, output Hessian and
          per-example Jacobian only; no linear solve is performed.

    Args:
        model: Not read here; the helpers use the module-level ``model``.
            Kept for interface compatibility.
        params: Parameter dict to differentiate with respect to.
        batch: ``(inputs, targets)`` used for the gradient and Jacobian.
        optimizer_name: One of the method names above.
        maxiter: Not read here. NOTE(review): the CG branch hard-codes
            ``maxiter=1`` — confirm whether this argument should be forwarded.

    Returns:
        Elapsed seconds (``time.perf_counter``) for the selected method.

    Raises:
        ValueError: If ``optimizer_name`` is not one of the supported names.
    """
    if optimizer_name not in ("ngd", "ngd-cg", "ngd-woodbury", "ngd-thermo"):
        raise ValueError(f"Unknown optimizer_name: {optimizer_name!r}")
    # Use the batch we were given instead of reaching for the module-level
    # x / y (the benchmark loop passes batch = (x, y), so values match).
    x, y = batch

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    if optimizer_name == "ngd":
        grad_0 = torch.func.grad(loss_fn)(params, x, y)
        hess, jac = _flat_output_hessian_and_jacobian(params, x)
        kt = hess @ jac
        G = jac.T @ kt  # dense P x P GGN
        flat_grad_0, unravel = tree_ravel(grad_0)
        flat_grad = torch.linalg.solve(G + torch.eye(G.shape[0]), flat_grad_0)
        _ = unravel(flat_grad)  # include the cost of mapping back to a pytree
    elif optimizer_name == "ngd-cg":
        grad_0 = torch.func.grad(loss_fn)(params, x, y)
        _ = cg(partial_ggnvp, grad_0, x0=grad_0, maxiter=1, damping=1, tol=1e-2)
    elif optimizer_name == "ngd-woodbury":
        grad_0 = torch.func.grad(loss_fn)(params, x, y)
        flat_grads, _ = tree_ravel(grad_0)
        hess, jac = _flat_output_hessian_and_jacobian(params, x)
        # Push-through identity (does not require H to be invertible):
        #   (I + J^T H J)^{-1} g = g - J^T (I + H J J^T)^{-1} H J g
        # H must appear exactly once; using (HJ)(HJ)^T would invert
        # I + J^T H^T H J instead of the GGN system solved by 'ngd'.
        hj = hess @ jac
        inner = torch.eye(hj.shape[0]) + hj @ jac.T
        correction = jac.T @ torch.linalg.solve(inner, hj @ flat_grads)
        _ = flat_grads - correction
    else:  # "ngd-thermo"
        # Only the setup cost is timed; the results are intentionally unused.
        _ = torch.func.grad(loss_fn)(params, x, y)
        flat_params, _ = tree_ravel(params)
        z = flat_params_to_forward(flat_params)
        _ = torch.func.jacfwd(torch.func.jacrev(loss_ggn))(z)
        _ = batched_jacobian(params, x)
    return time.perf_counter() - start
        



# Benchmark driver: for every method, seed and hidden width, build a fresh
# model, time one preconditioning step with timing(), and save the results.
optimizer_names = ['ngd', 'ngd-cg', 'ngd-woodbury', 'ngd-thermo']
batch_size = 32
# 20 hidden widths log-spaced from 10 to 10**2.5 (~316), truncated to int.
fc_features = np.logspace(1, 2.5, 20).astype(int)
output_dim = 20
# Passed to timing() and embedded in the output file names; timing()'s body
# does not currently read it.
maxiter=200
# times[name] is a (len(seeds), len(fc_features)) array of elapsed seconds.
times = {}
seeds = np.arange(0,5)
# Parameter count per width (integer dtype, inherited from fc_features).
num_params = np.zeros_like(fc_features)
for optimizer_name in optimizer_names:
    times[optimizer_name] = np.zeros((len(seeds), len(fc_features)))
    for s, seed in enumerate(seeds):

        torch.manual_seed(seed)
        # NOTE(review): torch.func transforms (grad/jacrev/jacfwd) are
        # documented to ignore an outer no_grad context — confirm it is needed.
        with torch.no_grad():
            for f, fc_feature in enumerate(fc_features):

                # The names assigned below (x, y, model, params,
                # params_unravel, flat_params_to_forward) are module globals
                # read at call time by forward, loss_ggn, loss_fn,
                # forward_single_sample, partial_ggnvp and timing.
                x = torch.ones([batch_size, 1, 28, 28])
                # Dense float targets of shape (batch, output_dim), so
                # cross_entropy treats them as class probabilities.
                # NOTE(review): each row sums to output_dim, not 1 — confirm.
                y = torch.ones([batch_size, output_dim])
                batch = (x, y)
                model = SimpleConvNet(fc_features=fc_feature, output_dim=output_dim)
                model_fun = model_to_function(model)
                params, model_fun = extract_requires_grad_and_func(dict(model.named_parameters()), model_fun)
                flat_params, params_unravel = tree_ravel(params)

                # Redefined each iteration; looks up the current global
                # params_unravel (and, via forward, model/x) when called.
                def flat_params_to_forward(fps):
                    return forward(params_unravel(fps))
                # Un-timed computation of the output Hessian and per-example
                # Jacobian (presumably a warm-up); hess/jac are not read
                # elsewhere in this file.
                z = flat_params_to_forward(flat_params)
                hess = torch.func.jacfwd(torch.func.jacrev(loss_ggn))(z)
                jac = batched_jacobian(params, x)
                result = timing(model, params, batch, optimizer_name, maxiter=maxiter)

                times[optimizer_name][s, f] = result
                num_params[f] = flat_params.shape[0]
                print(optimizer_name, result, fc_feature, num_params[f])
                # Stop growing the width for this seed once a step exceeds
                # 15 s; the remaining entries of times[...] stay 0.
                if result > 15:
                    break
    # One file per method, e.g. "timesngdb32dz20iter200.npy".
    np.save("times"+optimizer_name+"b"+str(batch_size)+"dz"+str(output_dim)+"iter"+str(maxiter)+".npy", times[optimizer_name])

np.save("num_params.npy", num_params)



    