# Adapted from: https://github.com/locuslab/edge-of-stability
# Original paper: Cohen, J. M., Kaur, S., Li, Y., Kolter, J. Z., & Talwalkar, A. (2021).
# "Gradient Descent on Neural Networks Typically Occurs at the Edge of Stability", ICLR 2021.
# If you use this code, please cite the original work.
#
# Our modifications: lines 9, 14-15, 20, 23-29, 32-38, 49-57, 64-97, 105-109, 115-135, 146-156, 186-191, 196, 198-199.

from os import makedirs
import os

import torch
from torch.nn.utils import parameters_to_vector

import math
import copy 
import argparse

from archs import load_architecture
from utilities import get_gd_optimizer, get_gd_directory, get_loss_and_acc, compute_losses, \
    save_files, get_hessian_eigenvalues, iterate_dataset, save_files_es, get_norm, get_norm_from_init
from data import load_dataset, take_first, DATASETS

def int_or_str(value):
    """Parse a norm-type command-line value (usable as an argparse ``type=``).

    Args:
        value: the raw string taken from the command line.

    Returns:
        ``int(value)`` when the string parses as an integer; otherwise the
        string itself when it is one of ``"nuc"``, ``"fro"`` or ``"all"``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is neither an integer nor one
            of the accepted keywords.
    """
    try:
        return int(value)
    except ValueError:
        # Not an integer: fall through to the keyword check below.
        pass

    if value in {"nuc", "fro", "all"}:
        return value
    # The message lists every accepted keyword, including 'all'.
    raise argparse.ArgumentTypeError(
        f"Invalid value: {value}. Must be an integer, 'nuc', 'fro', or 'all'.")

def main(dataset: str, arch_id: str, loss: str, opt: str, lr: float, max_steps: int, neigs: int = 0,
         physical_batch_size: int = 1000, eig_freq: int = 3000, iterate_freq: int = -1, save_freq: int = 10000,
         beta: float = 0.0, nproj: int = 0,
         loss_goal: float = None, acc_goal: float = None, abridged_size: int = 5000, seed: int = 0, norm_input = "all",
         es_loss_goal: float = None, es_acc_goal: float = None, init_scaling: float = 1.0):
    """Train a network with full-batch gradient descent, logging losses, Hessian eigenvalues and norms.

    If the output directory already contains a ``num_iterations`` file, training
    resumes from the files saved there; otherwise a fresh network is created,
    its parameters are multiplied by ``init_scaling``, and it is saved as
    ``full_model_initial``.

    Args:
        dataset: dataset name, forwarded to ``load_dataset`` / ``load_architecture``.
        arch_id: architecture identifier, forwarded to ``load_architecture``.
        loss: loss name, forwarded to ``load_dataset`` / ``get_loss_and_acc``.
        opt: optimizer name, forwarded to ``get_gd_optimizer``.
        lr: learning rate.
        max_steps: number of steps to run in this invocation (must be positive).
        neigs: number of Hessian eigenvalues requested from ``get_hessian_eigenvalues``.
        physical_batch_size: batch size used when iterating over a dataset.
        eig_freq: record eigenvalues and norms every ``eig_freq`` steps; a
            non-positive value disables eigenvalue/norm recording entirely.
        iterate_freq: record random projections of the parameters every
            ``iterate_freq`` steps; non-positive disables it.
        save_freq: write a checkpoint every ``save_freq`` steps; non-positive disables it.
        beta: momentum parameter, forwarded to ``get_gd_optimizer`` / ``get_gd_directory``.
        nproj: number of random projection vectors.
        loss_goal: stop once the train loss is below this value (ignored if None).
        acc_goal: stop once the train accuracy is above this value (ignored if None).
        abridged_size: number of leading training examples used for the eigenvalue computation.
        seed: random seed set before the network is created.
        norm_input: a single norm type, or ``"all"`` for ``[1, "fro", "nuc"]``.
        es_loss_goal: accepted for interface compatibility; not used by this function.
        es_acc_goal: accepted for interface compatibility; not used by this function.
        init_scaling: factor applied to the initial parameters of a fresh run;
            values other than 1.0 also add a ``scaling_*`` sub-directory.

    Raises:
        ValueError: if ``max_steps`` is not positive.
        FileNotFoundError: if resuming and one of the expected files is missing.
    """
    if max_steps <= 0:
        # The final save below relies on at least one loop iteration having run.
        raise ValueError(f"max_steps must be positive, got {max_steps}")

    directory = os.path.expanduser(get_gd_directory(dataset, lr, arch_id, seed, opt, loss, beta))
    if init_scaling != 1.0:
        directory = directory + f"/scaling_{init_scaling}"
    print(f"output directory: {directory}")
    makedirs(directory, exist_ok=True)

    train_dataset, test_dataset = load_dataset(dataset, loss)
    abridged_train = take_first(train_dataset, abridged_size)

    loss_fn, acc_fn = get_loss_and_acc(loss)

    torch.manual_seed(seed)
    network = load_architecture(arch_id, dataset).cuda()
    if os.path.exists(directory+"/num_iterations"):
        # NOTE: weights_only=False unpickles arbitrary objects; only resume from directories you trust.
        num_iter = torch.load(directory+"/num_iterations", weights_only=False)
    else:
        num_iter = 0
        with torch.no_grad():  # Disable gradient tracking for this operation
            for param in network.parameters():
                param.mul_(init_scaling)
        network_init = copy.deepcopy(network)
        save_files(directory, [("full_model_initial", network_init)])  # save initial model

    # Fixed seed so the projection vectors are identical across runs and resumes.
    torch.manual_seed(7)
    projectors = torch.randn(nproj, len(parameters_to_vector(network.parameters())))

    if norm_input == "all":
        norm_types = [1, "fro", "nuc"]
    else:
        norm_types = [norm_input]

    # Row counts of the per-frequency logs, covering past steps plus this invocation.
    total_steps = num_iter + max_steps
    eig_rows_total = _num_records(total_steps, eig_freq)
    iterate_rows_total = _num_records(total_steps, iterate_freq)

    if num_iter == 0:
        train_loss, test_loss, train_acc, test_acc = (torch.zeros(max_steps) for _ in range(4))
        iterates = torch.zeros(iterate_rows_total, len(projectors))
        eigs = torch.zeros(eig_rows_total, neigs)
        norms = torch.zeros(eig_rows_total, len(norm_types))
        norms_from_init = torch.zeros(eig_rows_total, len(norm_types))
    else:
        required_files = ["full_model_final", "eigs", "iterates", "test_acc", "train_acc", "train_loss",
                          "test_loss", "norms", "full_model_initial", "norms_from_init"]
        if not all(os.path.exists(directory+"/"+name) for name in required_files):
            raise FileNotFoundError("Number of past iterations is nonzero but a saved file does not exist")

        def resume_per_step(name):
            # Keep the first num_iter recorded values and append room for max_steps more.
            past = torch.load(directory+"/"+name, weights_only=False)[:num_iter]
            return torch.cat((past, torch.zeros(max_steps)))

        def resume_per_freq(name, freq, total_rows, width):
            # Keep the rows recorded so far and pad with zeros up to total_rows rows.
            past = torch.load(directory+"/"+name, weights_only=False)[:_num_records(num_iter, freq)]
            return torch.cat((past, torch.zeros(total_rows - past.shape[0], width)), dim=0)

        network = torch.load(directory+"/full_model_final", weights_only=False)
        network_init = torch.load(directory+"/full_model_initial", weights_only=False)
        train_loss = resume_per_step("train_loss")
        test_loss = resume_per_step("test_loss")
        train_acc = resume_per_step("train_acc")
        test_acc = resume_per_step("test_acc")
        iterates = resume_per_freq("iterates", iterate_freq, iterate_rows_total, len(projectors))
        eigs = resume_per_freq("eigs", eig_freq, eig_rows_total, neigs)
        norms = resume_per_freq("norms", eig_freq, eig_rows_total, len(norm_types))
        norms_from_init = resume_per_freq("norms_from_init", eig_freq, eig_rows_total, len(norm_types))
        print("Successfully loaded all files")

    # Build the optimizer only now, so that on resume it holds the parameters of the
    # loaded network rather than those of the freshly initialised one it replaced.
    # Optimizer state (e.g. momentum buffers) is not saved, so it starts fresh.
    optimizer = get_gd_optimizer(network.parameters(), opt, lr, beta)

    def record_eigs_and_norms(row, eigs_log, norms_log, norms_from_init_log):
        # Write the current eigenvalues and norms of `network` into row `row` of the given logs.
        eigs_log[row, :] = get_hessian_eigenvalues(network, loss_fn, abridged_train, neigs=neigs,
                                                   physical_batch_size=physical_batch_size)
        for col, norm_type in enumerate(norm_types):
            norms_log[row, col] = get_norm(network, norm_type)
            norms_from_init_log[row, col] = get_norm_from_init(network, network_init, norm_type)

    def checkpoint_items(at_step, eigs_log, norms_log, norms_from_init_log):
        # Build the (name, object) list to save, truncating every log to what has been recorded.
        eig_rows = at_step // eig_freq + 1 if eig_freq > 0 else 0
        iterate_rows = at_step // iterate_freq + 1 if iterate_freq > 0 else 0
        return [("full_model_final", network), ("norms", norms_log[:eig_rows, :]),
                ("norms_from_init", norms_from_init_log[:eig_rows, :]),
                ("eigs", eigs_log[:eig_rows]), ("iterates", iterates[:iterate_rows]),
                ("train_loss", train_loss[:at_step+1]), ("test_loss", test_loss[:at_step+1]),
                ("train_acc", train_acc[:at_step+1]), ("test_acc", test_acc[:at_step+1]),
                ("num_iterations", at_step+1)]

    ls, _ = compute_losses(network, [loss_fn, acc_fn], train_dataset, physical_batch_size)
    print(ls)
    # Largest power of ten not exceeding the starting loss; an extra snapshot is
    # saved each time the train loss reaches the current threshold.
    next_power = 10**(math.floor(math.log10(ls)))

    for step in range(num_iter, num_iter + max_steps):
        train_loss[step], train_acc[step] = compute_losses(network, [loss_fn, acc_fn], train_dataset,
                                                           physical_batch_size)
        test_loss[step], test_acc[step] = compute_losses(network, [loss_fn, acc_fn], test_dataset, physical_batch_size)

        if eig_freq > 0 and step % eig_freq == 0:
            record_eigs_and_norms(step // eig_freq, eigs, norms, norms_from_init)
            print("eigenvalues: ", eigs[step//eig_freq, :], ", norms: ", norms[step // eig_freq, :],", norms_from_init: ", norms_from_init[step // eig_freq, :])

        if iterate_freq > 0 and step % iterate_freq == 0:
            iterates[step // iterate_freq, :] = projectors.mv(parameters_to_vector(network.parameters()).cpu().detach())

        if save_freq > 0 and step % save_freq == 0:
            save_files(directory, checkpoint_items(step, eigs, norms, norms_from_init))
        print(f"{step}\t{train_loss[step]:.4f}\t{train_acc[step]:.4f}\t{test_loss[step]:.4f}\t{test_acc[step]:.4f}")

        if train_loss[step] <= next_power:
            # Snapshot on copies so the regular logs only hold values taken at eig_freq multiples.
            norms_ = copy.deepcopy(norms)
            norms_from_init_ = copy.deepcopy(norms_from_init)
            eigs_ = copy.deepcopy(eigs)
            if eig_freq > 0:  # the logs have no rows when eigenvalue recording is disabled
                record_eigs_and_norms(step // eig_freq, eigs_, norms_, norms_from_init_)
            save_files_es(directory, next_power, checkpoint_items(step, eigs_, norms_, norms_from_init_))
            next_power = next_power/10

        if (loss_goal is not None and train_loss[step] < loss_goal) or \
                (acc_goal is not None and train_acc[step] > acc_goal):
            break

        # One full-batch gradient step: gradients are accumulated over physical batches,
        # each batch loss being divided by the dataset size.
        optimizer.zero_grad()
        for (X, y) in iterate_dataset(train_dataset, physical_batch_size):
            batch_loss = loss_fn(network(X.cuda()), y.cuda()) / len(train_dataset)
            batch_loss.backward()
        optimizer.step()

    # Final measurement and save, using the last step index reached by the loop.
    if eig_freq > 0:
        record_eigs_and_norms(step // eig_freq, eigs, norms, norms_from_init)
    save_files(directory, checkpoint_items(step, eigs, norms, norms_from_init))


def _num_records(total_steps, freq):
    """Return the number of log rows needed to record every ``freq``-th step of ``total_steps`` steps.

    A non-positive ``freq`` means the quantity is never recorded, so zero rows are needed.
    """
    return math.ceil(total_steps / freq) if freq > 0 else 0


# Command-line entry point: parse the arguments and forward them to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train using gradient descent.")
    parser.add_argument("dataset", type=str, choices=DATASETS, help="which dataset to train")
    parser.add_argument("arch_id", type=str, help="which network architectures to train")
    parser.add_argument("loss", type=str, choices=["ce", "mse"], help="which loss function to use")
    parser.add_argument("lr", type=float, help="the learning rate")
    parser.add_argument("max_steps", type=int, help="the maximum number of gradient steps to train for")
    parser.add_argument("--opt", type=str, choices=["gd", "polyak", "nesterov"],
                        help="which optimization algorithm to use", default="gd")
    parser.add_argument("--seed", type=int, help="the random seed used when initializing the network weights",
                        default=0)
    parser.add_argument("--beta", type=float, help="momentum parameter (used if opt = polyak or nesterov)")
    parser.add_argument("--physical_batch_size", type=int,
                        help="the maximum number of examples that we try to fit on the GPU at once", default=1000)
    parser.add_argument("--acc_goal", type=float,
                        help="terminate training if the train accuracy ever crosses this value")
    parser.add_argument("--loss_goal", type=float, help="terminate training if the train loss ever crosses this value")
    # default=0 matches main()'s own default; without it argparse passes None,
    # which main() then uses as a tensor dimension.
    parser.add_argument("--neigs", type=int, default=0, help="the number of top eigenvalues to compute")
    parser.add_argument("--eig_freq", type=int, default=-1,
                        help="the frequency at which we compute the top Hessian eigenvalues (-1 means never)")
    parser.add_argument("--nproj", type=int, default=0, help="the dimension of random projections")
    parser.add_argument("--iterate_freq", type=int, default=-1,
                        help="the frequency at which we save random projections of the iterates")
    parser.add_argument("--abridged_size", type=int, default=5000,
                        help="when computing top Hessian eigenvalues, use an abridged dataset of this size")
    parser.add_argument("--save_freq", type=int, default=-1,
                        help="the frequency at which we save resuls")
    # The help text names the norm types that main() actually uses for "all".
    parser.add_argument("--norm_type", type=int_or_str, default="all",
                        help="which type of norm should be computed during training: integer, nuc, or fro, use all for [1,fro,nuc]")
    parser.add_argument("--es_acc_goal", type=float,
                        help="save model as early stopping if the train accuracy ever crosses this value")
    parser.add_argument("--es_loss_goal", type=float, help="save model as early stopping if the train loss ever crosses this value")
    parser.add_argument("--init_scaling", type=float, help="multiply initial weights by this", default=1.0)
    args = parser.parse_args()

    main(dataset=args.dataset, arch_id=args.arch_id, loss=args.loss, opt=args.opt, lr=args.lr, max_steps=args.max_steps,
         neigs=args.neigs, physical_batch_size=args.physical_batch_size, eig_freq=args.eig_freq,
         iterate_freq=args.iterate_freq, save_freq=args.save_freq, beta=args.beta,
         nproj=args.nproj, loss_goal=args.loss_goal, acc_goal=args.acc_goal, abridged_size=args.abridged_size,
         seed=args.seed, norm_input=args.norm_type, es_loss_goal=args.es_loss_goal, es_acc_goal=args.es_acc_goal,
         init_scaling=args.init_scaling)
