from torchvision import datasets, transforms
import torch
from .models import SimpleConvNet
from posteriors import model_to_function, extract_requires_grad_and_func, fvp, ggn, per_samplify
from utils import jac_and_hess, thermo_solve_fvp, ggnvp
from functools import partial
from optree import tree_map
import torchopt
import time
import json
from optree.integration.torch import tree_ravel
from tqdm import tqdm
from torch.utils.data import Subset

# Per-image preprocessing: convert to a tensor, then normalise with fixed
# mean/std constants (presumably the usual MNIST statistics).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

# Keep only the first 10,000 examples of each split to shorten experiments.
dataset1 = Subset(
    datasets.MNIST(root='data', train=True, download=True, transform=transform),
    list(range(10000)),
)
dataset2 = Subset(
    datasets.MNIST(root='data', train=False, transform=transform),
    list(range(10000)),
)

# Mini-batch loaders; both splits are reshuffled every epoch.
batch_size = 64
train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size, shuffle=True)

# Initial model and its trainable-parameter dict (the training loop
# re-creates both for every seed).
model = SimpleConvNet()
model_fun = model_to_function(model)
params, model_fun = extract_requires_grad_and_func(dict(model.named_parameters()), model_fun)


def forward(params, inputs):
    """Evaluate the global ``model`` functionally and return its raw outputs.

    Args:
        params: mapping of parameter name -> tensor substituted into ``model``.
        inputs: batch passed to the model's forward.

    Returns:
        Whatever ``model`` returns for ``inputs`` (the logits, for this script).
    """
    return torch.func.functional_call(model, params, inputs)

def loss_ggn(logits, labels):
    """Cross-entropy of ``logits`` against class-index ``labels``.

    Uses ``torch.nn.functional.cross_entropy`` with its default reduction
    (``'mean'``), so the result is a scalar tensor.
    """
    return torch.nn.functional.cross_entropy(logits, labels)

def loss_fn(params, inputs, labels):
    """Cross-entropy of the global ``model`` (run with ``params``) on a batch.

    Args:
        params: mapping of parameter name -> tensor substituted into ``model``.
        inputs: batch of model inputs.
        labels: class indices for each row of ``inputs``.

    Returns:
        Scalar tensor: ``cross_entropy`` with its default ``'mean'`` reduction.
    """
    return torch.nn.functional.cross_entropy(
        torch.func.functional_call(model, params, inputs),
        labels,
    )

def accuracy(params, inputs, labels):
    """Fraction of rows whose highest-scoring class (dim 1) equals the label.

    Returns:
        0-d float tensor in [0, 1].
    """
    predicted = torch.func.functional_call(model, params, inputs).argmax(dim=1)
    hits = (predicted == labels).float()
    return hits.mean()



# Module-level accumulators (not referenced elsewhere in this script; the
# per-seed metrics are collected in `results` below).
train_accs: dict = {}
train_losses: dict = {}
test_accs: dict = {}
test_losses: dict = {}
times: dict = {}

# --- Optimizer choice and base-optimizer hyper-parameters ---
optimizer_name = "ggn-thermo"
learning_rate = 1e-2
momentum = 0.
betas = (0.9, 0.999)
eps = 1e-8

# --- Curvature / linear-solver settings ---
maxiter = 20
damping = 0.01
lm_damping = False
factor = 1.05
average_regularization = False
step = 0.1
step_delay = 0.0
delay_iters = 0
noise_variance = 0

# --- Run settings ---
seeds = [0, 1, 2, 3, 4]
num_train_epochs = 10

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fisher_time = 0
results = {}  # per-seed results, keyed by seed

# Optimizer names grouped by how the per-batch update direction is computed.
_PLAIN_GRAD_OPTIMIZERS = ("adam", "adamw", "sgd")
_THERMO_OPTIMIZERS = ("ggn-thermo", "ggn-thermo-adam")
_NGD_OPTIMIZERS = ("ngd-ggn", "ngd-ggn-adam")

for seed in seeds:
    torch.manual_seed(seed)
    # torch.func.grad is a function transform and does not honour an *outer*
    # torch.no_grad(), so the gradients below are still computed.
    with torch.no_grad():
        # Base optimizer that consumes the (possibly preconditioned) direction.
        if optimizer_name in ("adam", "ngd-ggn-adam", "ggn-thermo-adam"):
            optimizer = torchopt.adam(lr=learning_rate, betas=betas, eps=eps)
        elif optimizer_name == "adamw":
            optimizer = torchopt.adamw(lr=learning_rate, betas=betas, eps=eps, weight_decay=0)
        elif optimizer_name in ("ngd-ggn", "ggn-thermo", "sgd"):
            optimizer = torchopt.sgd(lr=learning_rate, momentum=momentum)
        else:
            raise ValueError(f"Unknown optimizer_name: {optimizer_name!r}")

        # Fresh model/parameters per seed, initialised after manual_seed.
        model = SimpleConvNet()
        model_fun = model_to_function(model)
        params, model_fun = extract_requires_grad_and_func(dict(model.named_parameters()), model_fun)

        opt_state = optimizer.init(params)
        # Warm start (x0) for the thermodynamic solver, carried across batches.
        cache = None

        # Initialize dictionaries for storing results of the current seed
        results[seed] = {
            'train_accs': [],
            'train_losses': [],
            'test_accs': [],
            'test_losses': [],
            'times': []
        }

        print(f"\nOptimizer {optimizer_name} starting for seed {seed}")
        for epoch in range(num_train_epochs):
            start_time = time.time()
            pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
            epoch_train_loss, epoch_test_loss = 0, 0

            for x, y in pbar:
                if optimizer_name in _PLAIN_GRAD_OPTIMIZERS:
                    grad = torch.func.grad(loss_fn)(params, x, y)

                elif optimizer_name in _THERMO_OPTIMIZERS:
                    def forward_ggn(params):
                        return torch.func.functional_call(model, params, x)

                    def loss_ggn_thermo(logits):
                        return torch.nn.functional.cross_entropy(logits, y)

                    def partial_ggnvp(v):
                        # NOTE(review): scales by the nominal batch_size even for a
                        # smaller final batch -- confirm whether len(x) was intended.
                        return tree_map(lambda t: t / batch_size, ggnvp(forward_ggn, loss_ggn_thermo, (params,), (v,), normalize=False)[1])

                    grad_0 = torch.func.grad(loss_fn, 0)(params, x, y)
                    if cache is None:
                        cache = grad_0
                    grad = thermo_solve_fvp(partial_ggnvp,
                                            grad_0,
                                            x0=cache,
                                            iterations=maxiter,
                                            step=step,
                                            damping=damping,
                                            average_regularization=average_regularization,
                                            noise_variance=noise_variance)

                    # Optionally keep integrating past the returned solution to
                    # produce the warm start for the next batch.
                    if delay_iters > 0:
                        cache = thermo_solve_fvp(partial_ggnvp,
                                                 grad_0,
                                                 x0=grad,
                                                 iterations=delay_iters,
                                                 step=step_delay,
                                                 damping=damping,
                                                 average_regularization=average_regularization,
                                                 noise_variance=noise_variance)
                    else:
                        cache = grad

                elif optimizer_name in _NGD_OPTIMIZERS:
                    # Distinct names so this script-level loop does not rebind the
                    # module-level `forward` / `loss_ggn` helpers.
                    def forward_ngd(params):
                        return torch.func.functional_call(model, params, x)

                    def loss_ngd(logits):
                        return torch.nn.functional.cross_entropy(logits, y)

                    def partial_ggnvp(v):
                        return tree_map(lambda t: t / batch_size, ggnvp(forward_ngd, loss_ngd, (params,), (v,), normalize=False)[1])

                    grad_0 = torch.func.grad(loss_fn)(params, x, y)
                    start = time.time()
                    # Dense GGN, damped, then solved exactly against the gradient.
                    G = ggn(forward_ngd, loss_ngd, normalize=True)(params)
                    identity = torch.eye(G.shape[1], dtype=G.dtype, device=G.device)
                    if average_regularization:
                        G = (1 - damping) * G + damping * identity
                    else:
                        G = G + damping * identity

                    flat_grad_0, unravel = tree_ravel(grad_0)
                    flat_grad = torch.linalg.solve(G, flat_grad_0)
                    fisher_time = time.time() - start
                    grad = unravel(flat_grad)

                # Metrics are taken at the pre-update parameters so that the
                # recorded loss and accuracy describe the same model.
                loss = loss_fn(params, x, y)
                train_acc = accuracy(params, x, y)
                pbar.set_description(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

                updates, opt_state = optimizer.update(grad, opt_state)

                if lm_damping:
                    # NOTE(review): `_add`, `_mul` and `adjust_damping` appear not to be
                    # imported or defined in this file, and `grad_0` / `partial_ggnvp`
                    # exist only for the GGN-based optimizers -- confirm before enabling.
                    delta_loss = loss_fn(_add(params, _mul(-1, grad)), x, y) - loss
                    damping = adjust_damping(damping, delta_loss, grad, grad_0, partial_ggnvp, factor=factor)

                params = torchopt.apply_updates(params, updates)
                results[seed]['train_losses'].append(loss.item())
                results[seed]['train_accs'].append(train_acc.item())
                epoch_train_loss += loss.item()

            epoch_train_loss /= len(train_loader)
            results[seed]['times'].append(time.time() - start_time)

            # NOTE(review): the model is never switched to eval mode; if
            # SimpleConvNet uses dropout/batch-norm, confirm this is intended.
            for test_inputs, test_labels in test_loader:
                results[seed]['test_accs'].append(accuracy(params, test_inputs, test_labels).item())
                # Evaluate the loss once and reuse it for both bookkeeping paths.
                test_loss = loss_fn(params, test_inputs, test_labels).item()
                results[seed]['test_losses'].append(test_loss)
                epoch_test_loss += test_loss
            epoch_test_loss /= len(test_loader)

            print(f"Epoch {epoch+1} - Train Loss: {epoch_train_loss:.4f}, Test Loss: {epoch_test_loss:.4f}")


import os
from datetime import datetime
import numpy as np 

def save_data():
    """Persist the run configuration and collected results.

    Creates ``data/<YYYY-mm-dd_HH-MM-SS>/`` (reusing it if it already exists),
    writes the module-level hyper-parameters there as indented JSON in
    ``config.txt``, and stores the module-level ``results`` dict in
    ``results.npy`` via ``np.save``. Because a dict is saved as a pickled object
    array, reading it back needs ``np.load(..., allow_pickle=True)``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    dir_name = f"data/{timestamp}"
    os.makedirs(dir_name, exist_ok=True)

    # Key order and names are part of the on-disk format; note that
    # "learning_rates" holds the single scalar `learning_rate`.
    config = dict(
        optimizer_name=optimizer_name,
        learning_rates=learning_rate,
        maxiter=maxiter,
        damping=damping,
        lm_damping=lm_damping,
        factor=factor,
        seeds=seeds,
        eps=eps,
        num_train_epochs=num_train_epochs,
        betas=betas,
        batch_size=batch_size,
        step=step,
        average_regularization=average_regularization,
        momentum=momentum,
        delay_iters=delay_iters,
        step_delay=step_delay,
        noise_variance=noise_variance,
    )

    filename = f"{dir_name}/config.txt"
    with open(filename, 'w') as config_file:
        json.dump(config, config_file, indent=4)
    print(f"Configuration saved to {filename}")

    np.save(f"{dir_name}/results.npy", results)

# Write the config and per-seed results once every seed has finished training.
save_data()
