from __future__ import print_function
import logging
import os
import sys
import datetime
import time
import setproctitle
import numpy as np
import argparse
from tqdm import tqdm, trange
import copy

# from numpy import linalg as LA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.cuda.amp import GradScaler, autocast


from pyhessian.hessian import hessian, my_hessian, ffcv_hessian
from pyhessian.utils import (
    hessian_vector_product,
    group_product,
    normalization,
    group_product_scalar,
    group_add_simple,
    get_grad,
    group_add,
    get_params,
)

from utils import *

# Tiny positive constant. NOTE(review): nothing in the visible part of this
# file reads it -- presumably a divide-by-zero guard; confirm before removing.
EPS = 1e-24


def _to_device(data, target, use_cuda):
    """Return ``(data, target)``, moved with ``.cuda()`` when ``use_cuda`` is true."""
    if use_cuda:
        return data.cuda(), target.cuda()
    return data, target


def _forward_loss(model, criterion, data, target, use_amp):
    """Run one forward pass and return ``(output, loss)``; uses autocast when ``use_amp``."""
    if use_amp:
        with autocast():
            output = model(data)
            loss = criterion(output, target)
    else:
        output = model(data)
        loss = criterion(output, target)
    return output, loss


def _prepare_hessian(model, criterion, batch, args):
    """Clear ``model``'s grads, build a ``my_hessian`` on ``batch`` and call ``ready(model)``."""
    model.zero_grad(set_to_none=True)
    hessian_comp = my_hessian(
        criterion,  # smoothed version
        data=batch,
        cuda=args.cuda,
        mode=args.hessian_mode,
    )
    hessian_comp.ready(model)
    return hessian_comp


def train(args):
    """Train a model with SGD and log curvature / gradient statistics every step.

    For each training batch the function, in this order:

    1. computes the top ``args.hessian_topn`` Hessian eigenpairs on one fixed
       batch (``my_hessian.eigenvalues``),
    2. averages per-sample gradient norms over up to 100 batches of the
       first sample loader, and accumulates the gradient over the whole
       ``full_batch_loader``,
    3. runs the forward/backward pass on the training batch to obtain the
       step gradient ``g`` and ``Hv(g)`` from ``my_hessian``,
    4. optionally (``args.details``) probes copies of the model moved along
       ``-lr * alpha * g`` for ``alpha`` in ``0.125 .. 1.25``,
    5. performs the optimizer step and writes one tab-separated record to
       the log file.

    Args:
        args: namespace providing ``name``, ``seed``, ``cuda``, ``parallel``,
            ``ffcv``, ``dataset``, ``batch_size``, ``test_batch_size``,
            ``hessian_test_batch_size``, ``cutout``, ``n_data``, ``model``,
            ``num_classes``, ``lr``, ``momentum``, ``weight_decay``,
            ``criterion``, ``smoothing``, ``lr_scheduler``, ``milestones``,
            ``gamma``, ``epochs``, ``hessian_mode``, ``hessian_maxiter``,
            ``hessian_tol``, ``hessian_topn``, ``details`` and
            ``test_every_n_steps``.

    Raises:
        ValueError: if the Hessian loader or the per-sample loader yields no
            batch at all.

    Side effects:
        Creates ``../log`` if needed and writes
        ``../log/<name>_<timestamp>.log``; prints the epoch number and a tqdm
        progress bar.
    """
    # logger
    # The timestamp is normally the module-level ``time_now`` set by the
    # ``__main__`` section; fall back to "now" when train() is imported.
    run_stamp = globals().get("time_now")
    if run_stamp is None:
        run_stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    log_dir = "../log"
    # basicConfig(filename=...) fails if the directory does not exist.
    os.makedirs(log_dir, exist_ok=True)
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        filename=os.path.join(log_dir, args.name + "_" + run_stamp + ".log"),
        format="[%(asctime)s] - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        # No level is configured, so records are written with warning().
    )
    logger.warning(args)

    # set random seed to reproduce the work
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # get dataset / model / optimizer / criterion / lr_scheduler
    get_data_ = get_data_ffcv if args.ffcv else get_data

    def build_train_loader(batch_size):
        """Return an un-augmented, shuffled training loader with this batch size."""
        loader, _ = get_data_(
            dataset=args.dataset,
            train_bs=batch_size,
            test_bs=batch_size,
            data_augmentation=False,
            normalization=True,
            shuffle=True,
            cutout=args.cutout,
            n_data=args.n_data,
        )
        return loader

    train_loader, test_loader = get_data_(
        dataset=args.dataset,
        train_bs=args.batch_size,
        test_bs=args.test_batch_size,
        data_augmentation=False,
        normalization=True,
        shuffle=True,
        cutout=args.cutout,
        n_data=args.n_data,
    )

    model = get_model(args.model, dataset=args.dataset, num_classes=args.num_classes)
    model = model.to(memory_format=torch.channels_last)
    if args.cuda:
        model = model.cuda()
    if args.parallel:
        model = torch.nn.DataParallel(model)

    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    criterion = get_criterion(args.criterion, args.smoothing)

    lr_scheduler = get_lr_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        milestones=args.milestones,
        gamma=args.gamma,
        epochs=args.epochs,
    )

    # loaders
    bs_list = [1]  # ,16,32,64,128,256,512,1024,2048,4096]
    # One loader per batch size (the append used to sit outside the loop and
    # therefore kept only the last loader).
    data_loaders = [build_train_loader(bs) for bs in bs_list]

    # Iterated completely every step to accumulate the gradient over the
    # whole training set; uses the last batch size of ``bs_list``.
    full_batch_loader = build_train_loader(bs_list[-1])

    # Only the first batch of the Hessian loader is ever used.
    hessian_batch = None
    for data_h, target_h in build_train_loader(args.hessian_test_batch_size):
        hessian_batch = _to_device(data_h, target_h, args.cuda)
        break
    if hessian_batch is None:
        raise ValueError("Hessian data loader produced no batch")

    # Default values (logged until the first real measurement replaces them)
    hessian_total_iter = 0.0
    top_eigenvalues = [0.0] * args.num_classes  # args.hessian_topn
    grad_norms = [0.0] * 11  # (len(bs_list)+1)
    acc = 1 / args.num_classes
    test_loss = np.log(args.num_classes)
    start_time = time.time()
    nsample = 100  # max number of batches used for per-sample grad statistics

    # Direction vectors of the previous step, for step-to-step cosines.
    prev_g = None
    prev_q = None
    prev_Hg = None

    disable_running_stats(model)
    # training (epoch)
    for epoch in range(1, args.epochs + 1):
        print("Current Epoch: ", epoch)
        train_loss = 0.0
        total_num = 0
        correct = 0
        lr = optimizer.param_groups[0]["lr"]

        # training (batch_idx)
        with tqdm(total=len(train_loader)) as progressbar:
            for batch_idx, (data, target) in enumerate(train_loader):
                step = (epoch - 1) * len(train_loader) + batch_idx
                data, target = _to_device(data, target, args.cuda)

                # ---- Hessian top eigenpairs at the current parameters ----
                # (Previously nested under ``if args.cuda:``, which left
                # ``evc1`` undefined on CPU runs.)
                model.train()
                hessian_comp = _prepare_hessian(model, criterion, hessian_batch, args)
                top_eigenvalues_, top_eigenvectors = hessian_comp.eigenvalues(
                    maxIter=args.hessian_maxiter,
                    tol=args.hessian_tol,
                    top_n=args.hessian_topn,
                )
                for idx in range(args.hessian_topn):
                    top_eigenvalues[idx] = top_eigenvalues_[idx]
                evc1 = top_eigenvectors[0]
                model.zero_grad(set_to_none=True)
                hessian_total_iter = hessian_comp.total_iter

                # ---- gradient-norm statistics over the first loader ----
                model.train()
                grad_norms_ = []
                norm_sum = 0.0
                square_sum = 0.0
                n_seen = 0
                for data_h, target_h in data_loaders[0]:
                    data_h, target_h = _to_device(data_h, target_h, args.cuda)
                    _, c_loss = _forward_loss(
                        model, criterion, data_h, target_h, args.ffcv
                    )
                    grad = torch.autograd.grad(c_loss, model.parameters())

                    squared_norm = group_product(grad, grad)
                    norm_sum += squared_norm.sqrt().item()
                    square_sum += squared_norm.item()

                    model.zero_grad(set_to_none=True)
                    n_seen += 1
                    if n_seen == nsample:
                        break
                if n_seen == 0:
                    raise ValueError("per-sample data loader produced no batch")
                # Normalise by the number of batches actually seen so that a
                # loader shorter than ``nsample`` still yields statistics.
                mean_norm = norm_sum / n_seen
                grad_norms_.append(mean_norm)
                grad_square = (square_sum / n_seen) ** (1 / 2)
                grad_std = np.abs(grad_square**2 - mean_norm**2) ** (1 / 2)

                # ---- gradient norm of one batch for every other batch size ----
                for sample_loader in data_loaders[1:]:
                    for data_h, target_h in sample_loader:
                        data_h, target_h = _to_device(data_h, target_h, args.cuda)
                        _, c_loss = _forward_loss(
                            model, criterion, data_h, target_h, args.ffcv
                        )
                        grad = torch.autograd.grad(c_loss, model.parameters())
                        grad_norms_.append(group_product(grad, grad).sqrt().item())
                        model.zero_grad(set_to_none=True)
                        break

                # ---- gradient accumulated over the whole training set ----
                total_grad = [torch.zeros_like(g) for g in grad]
                len_full = 0
                for data_h, target_h in full_batch_loader:
                    data_h, target_h = _to_device(data_h, target_h, args.cuda)
                    _, c_loss = _forward_loss(
                        model, criterion, data_h, target_h, args.ffcv
                    )
                    grad = torch.autograd.grad(c_loss, model.parameters())
                    total_grad = group_add_simple(total_grad, grad)
                    len_full += 1

                grad_norm = (
                    group_product(total_grad, total_grad).sqrt().item() / len_full
                )
                model.zero_grad(set_to_none=True)

                for idx, value in enumerate(grad_norms_):
                    grad_norms[idx] = value
                grad_norms[-1] = grad_norm

                # ---- f(x): forward pass on the training batch ----
                model.train()
                enable_running_stats(model)
                output, c_loss = _forward_loss(
                    model, criterion, data, target, args.ffcv
                )
                c_loss_item = c_loss.item()
                disable_running_stats(model)

                ### log + progressbar
                batch_size = target.size()[0]
                train_loss += batch_size * c_loss_item
                total_num += batch_size
                _, predicted = output.max(1)
                correct_in_batch = predicted.eq(target).sum().item()
                correct += correct_in_batch

                progressbar.set_postfix(
                    loss=train_loss / total_num, acc=100.0 * correct / total_num
                )
                # The bar's total is the number of batches, so advance by one
                # batch (it used to advance by the number of samples).
                progressbar.update(1)

                # ---- step gradient g (single gradient method) ----
                optimizer.zero_grad(set_to_none=True)
                c_loss.backward()
                grad = get_grad(model)
                model.zero_grad(set_to_none=True)

                # ---- Hv(g) on the Hessian batch ----
                # (Also previously nested under ``if args.cuda:``.)
                model.train()
                hessian_comp = _prepare_hessian(model, criterion, hessian_batch, args)
                Hg = hessian_comp.Hv(grad)
                model.zero_grad(set_to_none=True)

                # ---- optional probes along the update direction ----
                losses = []
                HSns = []
                HSns_new = []
                cos_g_new_q_s = []
                cos_g_new_q_new_s = []
                lambdas = []
                if args.details:
                    for alpha_ in np.arange(1, 11):  # alpha = 0.125 .. 1.25
                        alpha = 0.125 * alpha_
                        # Copy of the model moved by -lr * alpha * g.
                        model_copy = copy.deepcopy(model)
                        params = get_params(model_copy)
                        group_add(params, grad, alpha=-lr * alpha)

                        output_copy = model_copy(data)
                        c_loss_copy = criterion(output_copy, target)
                        losses.append(c_loss_copy.item())

                        model_copy.train()
                        if alpha_ % 2 != 0:
                            continue
                        # Curvature quantities only at every second probe.
                        hessian_comp = _prepare_hessian(
                            model_copy, criterion, hessian_batch, args
                        )
                        grad_new = hessian_comp.gradsH

                        Hg_new = hessian_comp.Hv(grad_new)
                        gtHg_new = group_product(grad_new, Hg_new).item()
                        grad_norm_new = group_product(grad_new, grad_new).item()
                        HSns_new.append(gtHg_new / grad_norm_new)

                        Hg_copy = hessian_comp.Hv(grad)
                        gtHg_copy = group_product(grad, Hg_copy).item()
                        HSns.append(gtHg_copy / grad_norms[-1] ** 2)

                        normalized_g_new = normalization(grad_new)
                        cos_g_new_q_s.append(
                            group_product(normalized_g_new, evc1).abs().item()
                        )
                        (
                            top_eigenvalues_new,
                            top_eigenvectors_new,
                        ) = hessian_comp.eigenvalues(
                            maxIter=args.hessian_maxiter,
                            tol=args.hessian_tol,
                            top_n=1,
                        )
                        evc2 = top_eigenvectors_new[0]
                        lambdas.append(top_eigenvalues_new[0])
                        cos_g_new_q_new_s.append(
                            group_product(normalized_g_new, evc2).abs().item()
                        )

                # ---- optimizer step ----
                # Grads were cleared above, so a second forward/backward pass
                # (running stats disabled, no autocast) feeds the update.
                model.train()
                output = model(data)
                c_loss = criterion(output, target)
                c_loss.backward()

                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                # ---- scalar statistics of g, Hg and the top eigenvector ----
                gtHg = group_product(grad, Hg).item()
                gtg = group_product(grad, grad).sqrt().item()  # sqrt: this is |g|

                gtHg_over_gtg = gtHg / gtg**2
                gtHg_over_GtG = gtHg / grad_norms[-1] ** 2

                # NOTE(review): ``total_grad`` is the un-averaged sum over
                # ``len_full`` batches, whereas ``grad_norms[-1]`` is divided
                # by ``len_full`` -- confirm this scaling is intended.
                gtG = group_product(grad, total_grad).item()
                gtHg_over_gtG = gtHg / gtG

                gtq = group_product(grad, evc1).abs().item()

                normalized_g = normalization(grad)
                normalized_Hg = normalization(Hg)
                cos_gq = group_product(normalized_g, evc1).abs().item()
                cos_gHg = group_product(normalized_g, normalized_Hg).item()
                cos_qHg = group_product(evc1, normalized_Hg).abs().item()

                if prev_g is not None:
                    cos_gg = group_product(normalized_g, prev_g).item()
                    cos_qq = group_product(evc1, prev_q).abs().item()
                    cos_HgHg = group_product(normalized_Hg, prev_Hg).item()
                else:
                    # First step overall: no previous direction to compare to.
                    cos_gg = 0.0
                    cos_qq = 0.0
                    cos_HgHg = 0.0

                prev_g = normalized_g
                prev_q = evc1
                prev_Hg = normalized_Hg

                ### log and evaluation
                with torch.no_grad():
                    e_train_loss = c_loss_item
                    e_train_acc = correct_in_batch / target.size()[0]
                    train_time = time.time()
                    if step != 0 and step % args.test_every_n_steps == 0:
                        model.eval()
                        # NOTE(review): evaluation always runs under autocast,
                        # unlike the passes above which use it only with ffcv.
                        with autocast():
                            if args.criterion == "mse":
                                acc, test_loss = test(
                                    model, test_loader, print_opt=False, cr="mse"
                                )
                            else:
                                acc, test_loss = test(
                                    model, test_loader, print_opt=False
                                )
                    test_time = time.time()
                    # Column order is the log-file format; literal zeros are
                    # placeholders for quantities that are no longer computed.
                    log_list = (
                        [
                            step,
                            train_time - start_time,
                            test_time - train_time,
                            lr,
                            e_train_loss,
                            e_train_acc,
                            0,  # e_reg
                            0,
                            test_loss,
                            acc,
                            hessian_total_iter,
                            gtq,
                            gtg,
                        ]
                        + top_eigenvalues
                        + grad_norms
                        + [
                            cos_gq,
                            cos_gg,
                            cos_qq,
                            0,  # cos_gng,
                            0,  # cos_qng,
                            0,  # cos_ngng,
                            0,  # gtng,
                            0,  # ngtng,
                            0,  # qtng,
                            grad_square,
                            grad_std,
                            gtHg,
                            gtHg_over_GtG,
                            cos_gHg,
                            cos_qHg,
                            cos_HgHg,
                            gtHg_over_gtg,
                            gtHg_over_gtG,
                        ]
                        + losses
                        + HSns
                        + HSns_new
                        + cos_g_new_q_s
                        + cos_g_new_q_new_s
                        + lambdas
                    )

                    log_input = tuple(log_list)
                    logger.warning(
                        ("%d" + "\t%.5f" * (len(log_input) - 1)) % log_input
                    )

        ### scheduler
        lr_scheduler.step()


# Number of target classes for every dataset name accepted by ``--dataset``.
_NUM_CLASSES_BY_DATASET = {
    "cifar10": 10,
    "subcifar10": 10,
    "mnist": 10,
    "submnist": 10,
    "stl": 10,
    "substl": 10,
    "svhn": 10,
    "subsvhn": 10,
    "cifar100": 100,
    "subcifar100": 100,
    "imagenet": 1000,
    "subimagenet": 1000,
    "timagenet": 200,
    "subtimagenet": 200,
}


def build_arg_parser():
    """Create the command-line parser of the training script.

    Returns:
        argparse.ArgumentParser: parser with all training, scheduler and
        Hessian options registered.
    """
    parser = argparse.ArgumentParser(description="Training ")
    parser.add_argument("--method", type=str, default="SGD", help="SGD")

    parser.add_argument(
        "--dataset", type=str, default="subcifar10", help="cifar10 / mnist / cifar100"
    )
    parser.add_argument("--cutout", action="store_true", help="do we use cutout or not")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1024,
        help="input batch size for training (default: 1024)",
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1024,
        help="input batch size for testing (default: 1024)",
    )
    parser.add_argument(
        "--hessian-test-batch-size",
        type=int,
        default=1024,
        # The old help text was a copy of --test-batch-size's.
        help="input batch size for Hessian computation (default: 1024)",
    )
    parser.add_argument("--n-data", type=int, default=1024, help="(default: 1024)")
    parser.add_argument(
        "--epochs",
        "--epoch",
        type=int,
        default=1000,
        help="number of epochs to train (default: 1000)",
    )
    parser.add_argument(
        "--lr", type=float, default=0.01, help="learning rate (default: 0.01)"
    )
    parser.add_argument(
        "--weight-decay", default=0.0, type=float, help="weight decay (default: 0.0)"
    )
    parser.add_argument(
        "--momentum", default=0.0, type=float, help="momentum (default: 0.0)"
    )
    parser.add_argument("--model", type=str, default="Simple100", help="")

    parser.add_argument("--no-cuda", action="store_true", help="do we use gpu or not")
    parser.add_argument(
        "--no-parallel", action="store_true", help="do we use parallel or not"
    )
    parser.add_argument(
        "--saving-folder", type=str, default="pretrained/", help="choose saving name"
    )
    parser.add_argument("--savemodels", action="store_true", help="save models")
    parser.add_argument("--name", type=str, default="test", help="choose saving name")
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")
    parser.add_argument("--criterion", type=str, default="cross-entropy")
    parser.add_argument(
        "--lr-scheduler", type=str, default="constant", help="cosine / multistep/ constant"
    )
    parser.add_argument("--milestones", nargs="*", type=int, default=None)
    parser.add_argument("--gamma", type=float, default=0.2)
    parser.add_argument("--test-every-n-steps", type=int, default=200)
    parser.add_argument(
        "--label-smoothing", action="store_true", help="label smoothing"
    )
    parser.add_argument("--smoothing", type=float, default=0.0)
    parser.add_argument("--ffcv", action="store_true")
    parser.add_argument(
        "--hessian-topn", "--igs-topn", "--igs-top-n", type=int, default=1
    )
    parser.add_argument(
        "--hessian-tol", "--igs-tol", default=0.001, type=float, help="tol for IGS"
    )
    parser.add_argument(
        "--hessian-maxiter", "--igs-maxiter", "--maxiter-igs", type=int, default=100
    )
    parser.add_argument(
        "--hessian-mode", type=str, default="eval", help="eval or train"
    )
    parser.add_argument("--details", action="store_true", help="to log details")
    return parser


def finalize_args(args):
    """Fill in the derived settings that ``train`` reads from ``args``.

    Sets ``cuda`` / ``parallel`` (both on unless disabled), switches the
    criterion for ``--label-smoothing``, supplies default multistep
    milestones at 30% / 60% / 80% of the epochs, and derives ``num_classes``
    from the dataset name.

    Args:
        args: namespace returned by ``build_arg_parser().parse_args()``;
            modified in place.

    Returns:
        The same namespace, for convenience.

    Raises:
        ValueError: if the dataset name is unknown, or if ``hessian_topn``
            exceeds ``num_classes`` (``train`` stores the eigenvalues in a
            list of ``num_classes`` slots).
    """
    ## default: ON
    args.cuda = not args.no_cuda
    args.parallel = not args.no_parallel

    ## modification
    if args.label_smoothing:
        args.criterion = "label_smoothing"

    if args.lr_scheduler == "multistep" and args.milestones is None:
        args.milestones = [
            int(args.epochs * 0.3),
            int(args.epochs * 0.6),
            int(args.epochs * 0.8),
        ]

    if args.dataset not in _NUM_CLASSES_BY_DATASET:
        raise ValueError("Unknown dataset")
    args.num_classes = _NUM_CLASSES_BY_DATASET[args.dataset]

    # Only reachable if the parser default for --hessian-topn becomes None.
    if args.hessian_topn is None:
        args.hessian_topn = args.num_classes
    if args.hessian_topn > args.num_classes:
        # Fail before any data is loaded instead of with an IndexError
        # during the first training step.
        raise ValueError(
            "--hessian-topn (%d) must not exceed the number of classes (%d)"
            % (args.hessian_topn, args.num_classes)
        )
    return args


if __name__ == "__main__":
    args = finalize_args(build_arg_parser().parse_args())

    # time
    # ``time_now`` must stay a module-level name: train() reads it to build
    # the log-file name.
    time_now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    setproctitle.setproctitle(args.name + "_" + time_now)
    print(time_now)

    for arg in vars(args):
        print(arg, getattr(args, arg))

    train(args)
