from __future__ import print_function
import logging
import os
import sys
import datetime
import time
import setproctitle
import numpy as np
import argparse
from tqdm import tqdm, trange
import copy

# from numpy import linalg as LA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.cuda.amp import GradScaler, autocast


from pyhessian.hessian import hessian, my_hessian, ffcv_hessian
from pyhessian.utils import (
    hessian_vector_product,
    group_product,
    normalization,
    group_product_scalar,
    group_add_simple,
    get_grad,
    group_add,
    get_params,
)

from utils import *


EPS = 1e-24


def train(args):

    # logger
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        filename=os.path.join("../log/" + args.name + "_" + time_now + ".log"),
        format="[%(asctime)s] - %(message)s",
        datefmt="%Y/%m/%d %H:%M:%S",
        # level=logging.DEBUG ############ Changed to Warning
    )
    # logger.info(args)
    logger.warning(args)

    # set random seed to reproduce the work
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)

    # get dataset / model / optimizer / rho_scheduler / criterion / lr_scheduler
    if args.ffcv:
        get_data_ = get_data_ffcv
        # hessian_ = ffcv_hessian

    else:
        get_data_ = get_data
        # hessian_ = my_hessian

    ### drop_last=True (for ffcv and train)
    train_loader, test_loader = get_data_(
        dataset=args.dataset,
        train_bs=args.batch_size,
        test_bs=args.test_batch_size,
        data_augmentation=False,
        normalization=True,
        shuffle=True,
        cutout=args.cutout,
        n_data=args.n_data,
    )

    train_loader2, _ = get_data_(
        dataset=args.dataset,
        train_bs=args.batch_size,
        test_bs=args.test_batch_size,
        data_augmentation=False,
        normalization=True,
        shuffle=True,
        cutout=args.cutout,
        n_data=args.n_data,
    )

    model = get_model(args.model, dataset=args.dataset, num_classes=args.num_classes)

    model = model.to(memory_format=torch.channels_last)

    if args.cuda:
        model = model.cuda()
    if args.parallel:
        model = torch.nn.DataParallel(model)

    optimizer = optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )
    
    criterion = get_criterion(args.criterion, args.smoothing)

    lr_scheduler = get_lr_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        milestones=args.milestones,
        gamma=args.gamma,
        epochs=args.epochs,
    )

    # loaders
    bs_list = [1]  # ,16,32,64,128,256,512,1024,2048,4096]
    data_loaders = []
    for bs in bs_list:
        sample_loader, _ = get_data_(
            dataset=args.dataset,
            train_bs=bs,
            test_bs=bs,
            data_augmentation=False,
            normalization=True,
            shuffle=True,
            cutout=args.cutout,
            n_data=args.n_data,
        )
        data_loaders.append(sample_loader)

    sample_loader, _ = get_data_(
        dataset=args.dataset,
        train_bs=bs,
        test_bs=bs,
        data_augmentation=False,
        normalization=True,
        shuffle=True,
        cutout=args.cutout,
        n_data=args.n_data,
    )
    full_batch_loader = sample_loader

    hessian_loader_, _ = get_data_(
        dataset=args.dataset,
        train_bs=args.hessian_test_batch_size,
        test_bs=args.hessian_test_batch_size,
        data_augmentation=False,
        normalization=True,
        shuffle=True,
        cutout=args.cutout,
        n_data=args.n_data,
    )

    hessian_loader = []
    for data_h, target_h in hessian_loader_:
        if args.cuda:
            data_h, target_h = data_h.cuda(), target_h.cuda()
        hessian_loader.append((data_h, target_h))
        break

    # GradScaler
    if args.ffcv:
        scaler = GradScaler()

    # starts values
    option_off = False

    # Default values

    # Ju_item = 0.
    # delta_cos = 0.
    # delta_cos2 = 0.
    # cos = 0.
    hessian_total_iter = 0.0
    top_eigenvalues = [0.0] * args.num_classes  # args.hessian_topn
    grad_norms = [0.0] * 11  # (len(bs_list)+1)
    acc = 1 / args.num_classes
    test_loss = np.log(args.num_classes)
    start_time = time.time()
    gtv = 0.0
    gtg = 0.0


    disable_running_stats(model)  ####################################
    # training (epoch) from 1 to 201
    for epoch in range(1, args.epochs + 1):
        print("Current Epoch: ", epoch)
        train_loss = 0.0
        total_num = 0
        correct = 0
        lr = optimizer.__dict__["param_groups"][0]["lr"]

        # training (batch_idx)
        with tqdm(total=len(train_loader)) as progressbar:
            for batch_idx, (data, target) in enumerate(train_loader):
                step = (epoch - 1) * len(train_loader) + batch_idx
                if args.cuda:
                    data, target = data.cuda(), target.cuda()

                model.train()
                for data_h, target_h in hessian_loader:
                    if args.cuda:
                        data_h, target_h = data_h.cuda(), target_h.cuda()

                        model.zero_grad(set_to_none=True)
                        hessian_comp = my_hessian(
                            criterion,  ######### smoothed version
                            data=(data_h, target_h),
                            cuda=args.cuda,
                            mode=args.hessian_mode,  ###########
                        )
                        hessian_comp.ready(model)
                        top_eigenvalues_, top_eigenvectors = hessian_comp.eigenvalues(
                            maxIter=args.hessian_maxiter,
                            tol=args.hessian_tol,
                            top_n=args.hessian_topn,
                        )
                        for idx in range(args.hessian_topn):
                            top_eigenvalues[idx] = top_eigenvalues_[idx]

                        evc1 = top_eigenvectors[0]

                        model.zero_grad(set_to_none=True)
                        hessian_total_iter = hessian_comp.total_iter
                        break

                model.train()
                # grad_norms_ = []
                grad_square = 0.0
                grad_std = 0.0
                grad_norm = 0.0

                total_grad = [torch.zeros_like(g) for g in evc1]
                len_full = 0
                L_t = 0.0
                for data_h, target_h in full_batch_loader:
                    if args.cuda:
                        data_h, target_h = data_h.cuda(), target_h.cuda()
                    if args.ffcv:
                        with autocast():
                            output = model(data_h)
                            c_loss = criterion(output, target_h)
                    else:
                        output = model(data_h)
                        c_loss = criterion(output, target_h)
                    L_t += c_loss.item()
                    grad = torch.autograd.grad(c_loss, model.parameters())
                    total_grad = group_add_simple(total_grad, grad)
                    len_full += 1
                L_t /= len_full
                grad_norm = (
                    group_product(total_grad, total_grad).sqrt().item() / len_full
                )
                model.zero_grad(set_to_none=True)
                grad_norms[-1] = grad_norm

                ##### subloop
                gtHg_list = []
                gtg_list = []
                L_t_next_list = []

                n_exp = len(train_loader2) - 1  #########
                for batch_idx2, (data2, target2) in enumerate(train_loader2):
                    if batch_idx2 >= n_exp:

                        ######
                        L_t_next = np.mean(L_t_next_list)
                        L_t_next_std = np.std(L_t_next_list)
                        gtHg = np.mean(gtHg_list)
                        gtHg_std = np.std(gtHg_list)
                        gtHg_over_GtG = gtHg / grad_norms[-1] ** 2
                        
                        gtg = np.mean(gtg_list)

                        ######
                        gtHg_over_gtg = gtHg/gtg # gtHg/gtg
                        
                        # gtg = 0 # group_product(grad, grad).sqrt().item() ###### sqrt   
                        gtHg_over_gtg = 0  # gtHg/gtg**2
                        gtG = 0  # group_product(grad, total_grad).item()
                        gtHg_over_gtG = 0  # gtHg/gtG
                        gtq = 0  # group_product(grad, evc1).abs().item()
                        gtng = 0  # gtng = group_product(grad, ng).item()
                        ngtng = (
                            0  # ngtng = group_product(ng, ng).sqrt().item() ###### sqrt
                        )
                        qtng = 0  # qtng = group_product(evc1, ng).abs().item()
                        normalized_g = 0  #  normalization(grad)
                        normalized_Hg = 0  # normalization(Hg)
                        # normalized_ng = normalization(ng)
                        cos_gq = (
                            0  # group_product(normalized_g, evc1).abs().item() ### abs
                        )
                        cos_gng = 0  # cos_gng = group_product(normalized_g, normalized_ng).item()
                        cos_qng = 0  # cos_qng = group_product(evc1, normalized_ng).abs().item() ### abs
                        cos_gHg = 0  # group_product(normalized_g, normalized_Hg).item()
                        cos_qHg = (
                            0  # group_product(evc1, normalized_Hg).abs().item() ### abs
                        )
                        if step != 0:
                            cos_gg = 0  # group_product(normalized_g, tmp_g).item()
                            cos_qq = (
                                0  # group_product(evc1, tmp_q).abs().item() ### abs
                            )
                            # cos_ngng =  group_product(normalized_ng, tmp_ng).item()
                            cos_HgHg = 0  # group_product(normalized_Hg, tmp_Hg).item()
                        else:
                            cos_gg = 0.0
                            cos_qq = 0.0
                            cos_ngng = 0.0
                            cos_HgHg = 0.0
                        tmp_g = 0  # normalized_g
                        tmp_q = 0  # evc1
                        # tmp_ng = normalized_ng
                        tmp_Hg = 0  # normalized_Hg
                        break

                    if args.cuda:
                        data2, target2 = data2.cuda(), target2.cuda()

                    model2 = copy.deepcopy(model)
                    base_optimizer2 = optim.SGD(
                        model2.parameters(),
                        lr=args.lr,
                        momentum=0.,
                        weight_decay=0.,
                    )
                    model2.train()
                    output = model2(data2)
                    c_loss = criterion(output, target2)

                    # optimizer step
                    ## (simple)
                    # single gradient method
                    base_optimizer2.zero_grad(set_to_none=True)
                    c_loss.backward()
                    grad = get_grad(model2)

                    # preconditioner_log.step()
                    # ng = preconditioner_log.natural_gradient()

                    model2.zero_grad(set_to_none=True)
                    ###### \theta_{t+1}.... Debugging
                    model2.train()
                    for data_h, target_h in hessian_loader:
                        if args.cuda:
                            data_h, target_h = data_h.cuda(), target_h.cuda()

                            model2.zero_grad(set_to_none=True)
                            hessian_comp = my_hessian(
                                criterion,  ######### smoothed version
                                data=(data_h, target_h),
                                cuda=args.cuda,
                                mode=args.hessian_mode,  ###########
                            )
                            hessian_comp.ready(model2)

                            Hg = hessian_comp.Hv(grad)

                            break
                    model2.zero_grad(set_to_none=True)

                    gtHg_each = group_product(grad, Hg).item()

                    gtHg_list.append(gtHg_each)

                    model2.train()
                    output = model2(data2)
                    c_loss = criterion(output, target2)
                    c_loss.backward()
                    base_optimizer2.step()

                    len_full = 0
                    L_t_next_each = 0.0
                    for data_h, target_h in full_batch_loader:
                        if args.cuda:
                            data_h, target_h = data_h.cuda(), target_h.cuda()
                        if args.ffcv:
                            with autocast():
                                output = model2(data_h)
                                c_loss = criterion(output, target_h)
                        else:
                            output = model2(data_h)
                            c_loss = criterion(output, target_h)
                        L_t_next_each += c_loss.item()
                        len_full += 1
                    L_t_next_each /= len_full
                    L_t_next_list.append(L_t_next_each)

                model.train()

                enable_running_stats(model)  ####################################
                output = model(data)
                disable_running_stats(model)  ####################################
                c_loss = criterion(output, target)
                c_loss.backward()

                optimizer.step()
                optimizer.zero_grad(set_to_none=True)  #############

                ### log + progressbar
                c_loss_item = c_loss.item()
                train_loss += target.size()[0] * c_loss.item()
                total_num += target.size()[0]
                _, predicted = output.max(1)
                correct_in_batch = predicted.eq(target).sum().item()
                correct += correct_in_batch

                progressbar.set_postfix(
                    loss=train_loss / total_num, acc=100.0 * correct / total_num
                )
                progressbar.update(target.size(0))

                ### log and evaluation
                with torch.no_grad():
                    e_train_loss = c_loss_item
                    e_train_acc = correct_in_batch / target.size()[0]
                    e_reg = 0
                    train_time = time.time()
                    if step != 0 and step % args.test_every_n_steps == 0:
                        model.eval()
                        with autocast():  ###################################
                            if args.criterion == "mse":
                                acc, test_loss = test(
                                    model, test_loader, print_opt=False, cr="mse"
                                )
                            else:
                                acc, test_loss = test(
                                    model, test_loader, print_opt=False
                                )
                    test_time = time.time()
                    log_list = (
                        [
                            step,
                            train_time - start_time,
                            test_time - train_time,
                            lr,
                            e_train_loss,
                            e_train_acc,
                            e_reg,  ###################
                            0,
                            test_loss,
                            acc,
                            hessian_total_iter,
                            gtq,
                            gtg,
                        ]
                        + top_eigenvalues
                        + grad_norms
                        + [
                            cos_gq,
                            cos_gg,
                            cos_qq,
                            0,  # cos_gng,
                            0,  # cos_qng,
                            0,  # cos_ngng,
                            0,  # gtng,
                            0,  # ngtng,
                            0,  # qtng,
                            grad_square,
                            grad_std,
                            gtHg,
                            gtHg_over_GtG,
                            cos_gHg,
                            cos_qHg,
                            cos_HgHg,
                            gtHg_over_gtg,
                            gtHg_over_gtG,
                            L_t_next,
                            L_t_next_std,
                            gtHg_std,
                            L_t,
                        ]
                    )

                    log_input = (*log_list,)
                    # logger.info(('%d'+'\t%.4f'*(len(log_input)-1))%(log_input))
                    logger.warning(
                        ("%d" + "\t%.5f" * (len(log_input) - 1)) % (log_input)
                    )
                    ##### Changed to Warning

        ### scheduler
        lr_scheduler.step()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Training ")
    parser.add_argument(
        "--method",
        type=str,
        default="SGD",
        help="SGD",
    )
    parser.add_argument(
        "--dataset", type=str, default="subcifar10", help="cifar10 / mnist / cifar100"
    )
    parser.add_argument("--cutout", action="store_true", help="do we use cutout or not")
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="input batch size for training (default: 128)",
    )
    parser.add_argument("--n-data", type=int, default=1024)
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=1024,
        help="input batch size for testing (default: 1024)",
    )
    parser.add_argument(
        "--hessian-test-batch-size",
        type=int,
        default=1024,
        help="input batch size for testing (default: 1024)",
    )
    parser.add_argument(
        "--epochs",
        "--epoch",
        type=int,
        default=100,
        help="number of epochs to train (default: 100)",
    )
    parser.add_argument(
        "--lr", type=float, default=0.01, help="learning rate (default: 0.01)"
    )
    parser.add_argument(
        "--weight-decay", default=0.0, type=float, help="weight decay (default: 0.0)"
    )
    parser.add_argument(
        "--momentum", default=0.0, type=float, help="momentum (default: 0.0)"
    )
    parser.add_argument("--model", type=str, default="Simple100", help="")

    parser.add_argument("--no-cuda", action="store_true", help="do we use gpu or not")
    parser.add_argument(
        "--no-parallel", action="store_true", help="do we use parallel or not"
    )
    parser.add_argument("--name", type=str, default="noname", help="choose saving name")
    parser.add_argument("--seed", type=int, default=1, help="random seed (default: 1)")

    parser.add_argument("--criterion", type=str, default="cross-entropy")
    parser.add_argument(
        "--lr-scheduler", type=str, default="constant", help="cosine / multistep/ constant"
    )
    parser.add_argument("--milestones", nargs="*", type=int, default=None)
    parser.add_argument("--gamma", type=float, default=0.2)

    parser.add_argument("--test-every-n-steps", type=int, default=200)

    parser.add_argument(
        "--label-smoothing", action="store_true", help="label smoothing"
    )
    parser.add_argument("--smoothing", type=float, default=0.0)
    parser.add_argument("--ffcv", action="store_true")

    parser.add_argument(
        "--hessian-topn", "--igs-topn", "--igs-top-n", type=int, default=1
    )
    parser.add_argument(
        "--hessian-tol", "--igs-tol", default=0.001, type=float, help="tol for IGS"
    )
    parser.add_argument(
        "--hessian-maxiter", "--igs-maxiter", "--maxiter-igs", type=int, default=100
    )
    parser.add_argument(
        "--hessian-mode", type=str, default="eval", help="eval or train"
    )

    args = parser.parse_args()

    ## default: ON
    args.cuda = not (args.no_cuda)
    args.parallel = not (args.no_parallel)

    ## modification
    if args.label_smoothing:
        args.criterion = "label_smoothing"

    if args.lr_scheduler == "multistep" and args.milestones is None:
        args.milestones = [
            int(args.epochs * 0.3),
            int(args.epochs * 0.6),
            int(args.epochs * 0.8),
        ]

    if args.dataset in ["cifar10", "subcifar10", "mnist", "submnist", "substl"]:
        args.num_classes = 10
    elif args.dataset in ["cifar100", "subcifar100"]:
        args.num_classes = 100
    else:
        raise ValueError("Unknown dataset")

    if args.hessian_topn is None:
        args.hessian_topn = args.num_classes

    # time
    time_now = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    setproctitle.setproctitle(args.name + "_" + time_now)
    print(time_now)

    for arg in vars(args):
        print(arg, getattr(args, arg))

    train(args)
