import os

import wandb

import torch
from torch.nn.utils import parameters_to_vector
from torch.optim.lr_scheduler import MultiStepLR

import argparse

from si import si_load_architecture, si_get_global_hessian_eigenvalues, si_get_spherical_hessian_eigen, si_get_spherical_hess_grad_product
from utilities import get_gd_optimizer,  get_loss_and_acc, compute_losses, iterate_dataset, recalc_bn_stats
    # get_gd_directory, save_files, save_files_final,get_hessian_eigenvalues,
from data import load_dataset, take_first, DATASETS, num_input_channels, image_size
import torchinfo


def main(dataset: str, arch_id: str, loss: str, opt: str, lr: float, max_steps: int, neigs: int = 0,
         physical_batch_size: int = 1000, eig_freq: int = -1, geig_freq: int = -1, test_freq: int = -1,
         save_freq: int = -1,
         beta: float = 0.0, siwd: float = 0.0, nsiwd: float = 0.0, siwd_decay_sche: list = None, lr_decay_sche: list = None,
         sisigma: float = 0.0,
         test_mode: str = 'like_training',
         loss_goal: float = None, acc_goal: float = None, abridged_size: int = 5000,
         seed: int = 0, use_double: bool = False, load_path: str = '',
         siwd_decay_factor: float = 0.1,
         ):
    """Train ``arch_id`` on ``dataset`` with full-batch (optionally momentum) GD, logging to wandb.

    Every iteration accumulates gradients over the whole training set in chunks of
    ``physical_batch_size`` examples, logs train metrics (and, periodically, test
    metrics and Hessian eigenvalue estimates), then takes one optimizer step. The loop
    runs ``max_steps + 1`` iterations: the last one only evaluates and logs the final
    model and does not update it.

    Requires a CUDA device (the network is moved with ``.cuda()``) and an active wandb
    run (``wandb.watch`` / ``wandb.log`` / ``wandb.run.dir`` are used).

    Args:
        dataset: dataset name forwarded to ``load_dataset``.
        arch_id: architecture id forwarded to ``si_load_architecture``.
        loss: loss name forwarded to ``load_dataset`` and ``get_loss_and_acc``.
        opt: optimizer name forwarded to ``get_gd_optimizer``.
        lr: initial learning rate.
        max_steps: number of optimizer steps to take.
        neigs: number of Hessian eigenvalues to compute and log.
        physical_batch_size: number of examples per forward/backward chunk.
        eig_freq: compute spherical-Hessian eigenvalues every this many steps (-1: never).
        geig_freq: compute global-Hessian eigenvalues every this many steps (-1: never).
        test_freq: evaluate on the test set every this many steps (-1: never).
        save_freq: save a ``state_dict`` into the wandb run dir every this many steps (-1: never).
        beta: momentum parameter forwarded to ``get_gd_optimizer``.
        siwd: weight decay of the scale-invariant parameter group.
        nsiwd: weight decay of the non-scale-invariant parameter group.
        siwd_decay_sche: steps at which weight decay is multiplied by ``siwd_decay_factor``.
            NOTE: this is applied to every optimizer param group, not only the SI one.
        lr_decay_sche: ``MultiStepLR`` milestones; ``None`` disables LR scheduling.
        sisigma: if > 0, after each step add Gaussian noise with std
            ``sisigma * si_norm`` (norm measured before the step) to every SI parameter.
        test_mode: ``'like_training'`` evaluates without changing the network's mode;
            ``'recalc'`` calls ``recalc_bn_stats`` on the training set and evaluates in eval mode.
        loss_goal: stop once the train loss drops below this value.
        acc_goal: stop once the train accuracy exceeds this value.
        abridged_size: size of the training subset (``take_first``) used for Hessian computations.
        seed: seed used for network initialization.
        use_double: make float64 the default dtype.
        load_path: optional path of a ``state_dict`` to load into the network.
        siwd_decay_factor: multiplicative factor applied at each ``siwd_decay_sche`` milestone.
    """
    if siwd_decay_sche is None:
        siwd_decay_sche = []

    if use_double:
        torch.set_default_dtype(torch.float64)

    train_dataset, test_dataset = load_dataset(dataset, loss)
    abridged_train = take_first(train_dataset, abridged_size)

    loss_fn, acc_fn = get_loss_and_acc(loss)

    torch.manual_seed(seed)
    network = si_load_architecture(arch_id, dataset).cuda()
    if load_path:
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(load_path)
        network.load_state_dict(checkpoint)

    wandb.watch(network)
    torchinfo.summary(
        network,
        input_size=(physical_batch_size, num_input_channels(dataset), image_size(dataset), image_size(dataset)),
        dtypes=[torch.get_default_dtype()]
    )

    # Re-seed with a fixed value after initialization, so randomness drawn during
    # training (e.g. the sisigma noise) does not depend on `seed`.
    torch.manual_seed(7)

    optimizer = get_gd_optimizer([
        {'params': network.nsi_parameters(), 'weight_decay': nsiwd},
        {'params': network.si_parameters(), 'weight_decay': siwd},
    ], opt, lr, beta)

    if lr_decay_sche is not None:
        lr_sche = MultiStepLR(optimizer, milestones=lr_decay_sche)
    else:
        lr_sche = None

    phase = 0

    for step in range(0, max_steps + 1):  # one extra iteration to log the final model
        for sc in siwd_decay_sche:
            if step == sc:
                siwd *= siwd_decay_factor
                # Every param group is scaled below (including the non-SI one), so keep
                # the non-SI value in sync; it is passed to the global-Hessian helper.
                nsiwd *= siwd_decay_factor
                for group in optimizer.param_groups:
                    group['weight_decay'] *= siwd_decay_factor
                phase += 1
        if lr_sche is not None:
            lr = optimizer.param_groups[0]['lr']

        cur_log = {}

        si_norm = network.si_norm()
        # Effective LR of the SI parameters: lr / ((1 - lr * siwd) * ||w_si||^2).
        # -1 is used as a sentinel when the SI norm is not positive.
        if si_norm > 0:
            elr = lr / ((1 - lr * siwd) * si_norm ** 2)
        else:
            elr = -1

        cur_log.update({
            'phase': phase,
            'lr': lr,
            'two_over_lr': 2 / lr,
            'siwd': siwd,
            'elr/si': elr,
            'two_over_elr/si': 2 / elr,
        })

        if test_freq != -1 and step % test_freq == 0:
            if test_mode == 'recalc':
                recalc_bn_stats(network, train_dataset, physical_batch_size)
                network.eval()

            test_loss, test_acc = compute_losses(network, [loss_fn, acc_fn],
                                                 test_dataset, physical_batch_size)

            if test_mode == 'recalc':
                network.train()

            cur_log.update({
                'loss/test': test_loss,
                'acc/test': test_acc,
            })

            print(f"{step}\t{test_loss:.3f}\t{test_acc:.3f}")

        if geig_freq != -1 and step % geig_freq == 0:
            geigs = si_get_global_hessian_eigenvalues(
                network, loss_fn, abridged_train, neigs=neigs,
                physical_batch_size=physical_batch_size, siwd=siwd, nsiwd=nsiwd
            )
            cur_log.update({
                f'geigs/{k}': geigs[k] for k in range(neigs)
            })
            print("global eigenvalues: ", geigs)

        if eig_freq != -1 and step % eig_freq == 0:
            seigs, seigv = si_get_spherical_hessian_eigen(
                network, loss_fn, abridged_train, neigs=neigs,
                physical_batch_size=physical_batch_size
            )
            cur_log.update({
                f'seigs/{k}': seigs[k] for k in range(neigs)
            })
            print("spherical eigenvalues: ", seigs)
        else:
            seigv = None

        if save_freq != -1 and step % save_freq == 0:
            torch.save(network.state_dict(), os.path.join(wandb.run.dir, f'model-{step}.pth'))

        train_loss, train_acc = 0, 0

        # Full-batch gradient: each chunk's loss is divided by the dataset size, so the
        # accumulated .grad is that of the dataset-averaged loss.
        optimizer.zero_grad()
        for (X, y) in iterate_dataset(train_dataset, physical_batch_size):
            out = network(X)
            cur_loss = loss_fn(out, y) / len(train_dataset)
            cur_loss.backward()
            with torch.no_grad():
                train_loss += cur_loss
                train_acc += acc_fn(out, y) / len(train_dataset)

        cur_log.update({
            'loss/train': train_loss,
            'acc/train': train_acc,
        })
        network.log_dict(cur_log)
        if seigv is not None:
            # NOTE(review): runs between the training backward pass and optimizer.step();
            # assumes the helper does not accumulate into .grad -- confirm.
            hess_grad_prod = si_get_spherical_hess_grad_product(network, loss_fn, abridged_train, physical_batch_size)
            network.log_eigg(cur_log, seigv, hess_grad_prod)

        wandb.log(cur_log)

        if (loss_goal is not None and train_loss < loss_goal) or (acc_goal is not None and train_acc > acc_goal):
            break

        # The final iteration exists only to log the final model; do not update it.
        if step == max_steps:
            break

        optimizer.step()

        if sisigma > 0:
            with torch.no_grad():
                for p in network.si_parameters():
                    # Noise scale uses the SI norm measured before optimizer.step().
                    p.data += torch.randn_like(p.data) * sisigma * si_norm

        if lr_sche is not None:
            lr_sche.step()



if __name__ == "__main__":
    # Command-line entry point: parse options, start a wandb run, then train.
    cli = argparse.ArgumentParser(description="Train using gradient descent.")

    # What to train, with which loss, and for how long (positional arguments).
    cli.add_argument("dataset", type=str, choices=DATASETS, help="which dataset to train")
    cli.add_argument("arch_id", type=str, help="which network architectures to train")
    cli.add_argument("loss", type=str, choices=["ce", "mse", "cels0.1"], help="which loss function to use")
    cli.add_argument("lr", type=float, help="the learning rate")
    cli.add_argument("max_steps", type=int, help="the maximum number of gradient steps to train for")

    # Optimization, seeding and regularization.
    cli.add_argument("--opt", type=str, choices=["gd", "polyak", "nesterov"],
                     help="which optimization algorithm to use", default="gd")
    cli.add_argument("--seed", type=int,
                     help="the random seed used when initializing the network weights", default=0)
    cli.add_argument("--beta", type=float, help="momentum parameter (used if opt = polyak or nesterov)")
    cli.add_argument("--siwd", default=5e-4, type=float,
                     help="weight decay parameter for scale-invariant parameters")
    cli.add_argument("--nsiwd", default=5e-4, type=float,
                     help="weight decay parameter for non-scale-invariant parameters")
    cli.add_argument("--siwd-decay-sche", nargs='+', type=float,
                     help="decay schedule for the weight decay parameter for scale-invariant parameters")
    cli.add_argument("--lr-decay-sche", nargs='+', type=float,
                     help="decay schedule for LR")

    # Batching and early stopping.
    cli.add_argument("--physical-batch-size", type=int,
                     help="the maximum number of examples that we try to fit on the GPU at once", default=1000)
    cli.add_argument("--acc-goal", type=float,
                     help="terminate training if the train accuracy ever crosses this value")
    cli.add_argument("--loss-goal", type=float, help="terminate training if the train loss ever crosses this value")

    # Evaluation, Hessian diagnostics and checkpointing.
    cli.add_argument("--neigs", type=int,
                     help="the number of top eigenvalues to compute", default=3)
    cli.add_argument("--test-freq", type=int, default=10,
                     help="the frequency at which we evaluate the model on the test set")
    cli.add_argument("--test-mode", type=str, default='like_training', choices=['like_training', 'recalc'])
    cli.add_argument("--eig-freq", type=int, default=250,
                     help="the frequency at which we compute the top Hessian eigenvalues (-1 means never)")
    cli.add_argument("--geig-freq", type=int, default=250,
                     help="the frequency at which we compute the top global Hessian eigenvalues (-1 means never)")
    cli.add_argument("--abridged-size", type=int, default=5000,
                     help="when computing top Hessian eigenvalues, use an abridged dataset of this size")
    cli.add_argument("--save-freq", type=int, default=-1,
                     help="the frequency at which we save the model")

    # Noise, precision and warm start.
    cli.add_argument("--sisigma", default=0, type=float,
                     help="spherical noise scale")
    cli.add_argument("--use-double", default=False, action='store_true')
    cli.add_argument("--load-path", type=str, default='')

    opts = cli.parse_args()

    wandb.init(
        project='__project_name__',
        save_code=True,
        config=vars(opts),
        settings=wandb.Settings(start_method='fork'),
    )

    # Forward every parsed option to main() by keyword.
    main_kwargs = dict(
        dataset=opts.dataset,
        arch_id=opts.arch_id,
        loss=opts.loss,
        opt=opts.opt,
        lr=opts.lr,
        max_steps=opts.max_steps,
        neigs=opts.neigs,
        physical_batch_size=opts.physical_batch_size,
        eig_freq=opts.eig_freq,
        geig_freq=opts.geig_freq,
        test_freq=opts.test_freq,
        save_freq=opts.save_freq,
        beta=opts.beta,
        siwd=opts.siwd,
        nsiwd=opts.nsiwd,
        sisigma=opts.sisigma,
        siwd_decay_sche=opts.siwd_decay_sche,
        lr_decay_sche=opts.lr_decay_sche,
        test_mode=opts.test_mode,
        loss_goal=opts.loss_goal,
        acc_goal=opts.acc_goal,
        abridged_size=opts.abridged_size,
        seed=opts.seed,
        use_double=opts.use_double,
        load_path=opts.load_path,
    )
    main(**main_kwargs)
