from itertools import repeat
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import fetch_20newsgroups_vectorized
from stocBiO import *

import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
import argparse
import torch
import hypergrad as hg
import numpy as np
import math
import time
import os


class CustomTensorIterator:
    """Endless mini-batch iterator over a list of tensors.

    The tensors in ``tensor_list`` are wrapped in a ``TensorDataset`` and
    served through a ``DataLoader``. When the loader is exhausted it is
    restarted transparently, so ``next()`` keeps producing batches.

    Args:
        tensor_list: tensors passed positionally to ``TensorDataset``.
        batch_size: batch size forwarded to ``DataLoader``.
        **loader_kwargs: any further ``DataLoader`` keyword arguments.
    """

    def __init__(self, tensor_list, batch_size, **loader_kwargs):
        self.loader = DataLoader(TensorDataset(*tensor_list), batch_size=batch_size, **loader_kwargs)
        self.iterator = iter(self.loader)

    def __iter__(self):
        """Return self so the object satisfies the iterator protocol.

        Because ``__next__`` restarts the loader, a plain ``for`` loop over
        this object does not terminate on its own.
        """
        return self

    def __next__(self, *args):
        """Return the next batch, restarting the loader once it is exhausted."""
        try:
            batch = next(self.iterator)
        except StopIteration:
            # One pass over the data is finished: start a new pass.
            self.iterator = iter(self.loader)
            batch = next(self.iterator)
        return batch


def parse_args():
    """Parse command-line options and prepare the output directory.

    After parsing, ``args.save_folder`` defaults to ``'news_exp'`` when empty.
    ``args.model_name`` is generated from the algorithm and its main
    settings unless one was given with ``--model_name``. The final
    ``args.save_folder`` is ``<save_folder>/<model_name>`` and is created on
    disk if it does not exist yet. The selected algorithm is printed.

    Returns:
        argparse.Namespace: the parsed arguments with ``save_folder`` and
        ``model_name`` filled in.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', default=100, type=int, help='epoch numbers')
    parser.add_argument('--T', default=10, type=int, help='inner update iterations')
    parser.add_argument('--batch_size', type=int, default=5657)
    parser.add_argument('--val_size', type=int, default=5657)
    parser.add_argument('--eta', type=float, default=0.5, help='used in Hessian')
    parser.add_argument('--hessian_q', type=int, default=10, help='number of steps to approximate hessian')
    parser.add_argument('--alg', type=str, default='RAGD-GS', choices=['RAHGD', 'RAF2BA', 'PRAHGD', 'PRAF2BA',
                                                                      'AID', 'PAID', 'ITD', 'BA-CG', 'F2BA', 'RAGD-GS'])
    parser.add_argument('--training_size', type=int, default=5657)
    parser.add_argument('--inner_lr', type=float, default=100.0)
    parser.add_argument('--inner_mu', type=float, default=0.0)
    parser.add_argument('--outer_lr', type=float, default=100.0)
    parser.add_argument('--outer_mu', type=float, default=0.0)
    parser.add_argument('--save_folder', type=str, default='', help='path to save result')
    parser.add_argument('--model_name', type=str, default='', help='Experiment name')
    args = parser.parse_args()

    if not args.save_folder:
        args.save_folder = 'news_exp'
    # Only auto-generate the experiment name when the user did not supply one;
    # previously an explicit --model_name was silently discarded.
    if not args.model_name:
        args.model_name = '{}_bs_{}_vbs_{}_olrmu_{}_{}_ilrmu_{}_{}_eta_{}_T_{}_hessianq_{}'.format(
            args.alg, args.batch_size, args.val_size, args.outer_lr, args.outer_mu, args.inner_lr,
            args.inner_mu, args.eta, args.T, args.hessian_q)
    args.save_folder = os.path.join(args.save_folder, args.model_name)
    print(args.alg)
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(args.save_folder, exist_ok=True)

    return args


def train_model(args):
    """Run one bilevel hyperparameter-optimisation experiment and save its log.

    The inner problem is a linear softmax classifier on the vectorised
    20 Newsgroups data, trained with cross-entropy plus an exponentially
    weighted L2 penalty. The outer variables are one log-weight per feature
    for that penalty and are updated with the algorithm named by
    ``args.alg``. The training split is divided 50/50 (stratified) into
    train and validation parts; the official test split is used for
    reporting.

    Args:
        args: namespace providing ``epochs``, ``T``, ``batch_size``,
            ``hessian_q``, ``alg``, ``inner_lr``, ``inner_mu``,
            ``outer_lr``, ``outer_mu`` and ``save_folder``.

    Side effects:
        Seeds torch and numpy, sets the torch default tensor type, prints
        progress, and writes ``results.npy`` into ``args.save_folder``. The
        saved array has shape ``(args.epochs + 1, 4)`` with columns
        ``[test_loss, test_acc, cumulative_time, calls_num]``; row 0 holds
        the values before any update.
    """

    # Constant
    tol = 1e-12
    warm_start = True
    bias = False  # without bias outer_lr can be bigger (much faster convergence)
    train_log_interval = 100
    val_log_interval = 1

    # Basic Setting
    seed = 0
    torch.manual_seed(seed)
    np.random.seed(seed)

    cuda = True and torch.cuda.is_available()
    cuda = False  # GPU use is forced off here, overriding the availability check above
    default_tensor_str = 'torch.cuda.FloatTensor' if cuda else 'torch.FloatTensor'
    kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
    torch.set_default_tensor_type(default_tensor_str)
    #torch.multiprocessing.set_start_method('forkserver')

    # Functions
    def frnp(x): return torch.from_numpy(x).cuda().float() if cuda else torch.from_numpy(x).float()
    def tonp(x, cuda=cuda): return x.detach().cpu().numpy() if cuda else x.detach().numpy()

    def train_loss(params, hparams, data):
        """Inner objective: cross-entropy on ``data`` plus the weighted penalty."""
        x_mb, y_mb = data
        # print(x_mb.size()) = torch.Size([5657, 130107])
        out = out_f(x_mb,  params)
        return F.cross_entropy(out, y_mb) + reg_f(params, *hparams)

    def val_loss(opt_params, hparams):
        """Outer objective: validation cross-entropy (``hparams`` is not used).

        Every call also appends the loss and accuracy to ``val_losses`` /
        ``val_accs``.
        """
        x_mb, y_mb = next(val_iterator)
        # print(x_mb.size()) = torch.Size([5657, 130107])
        out = out_f(x_mb,  opt_params[:len(parameters)])
        val_loss = F.cross_entropy(out, y_mb)
        pred = out.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        acc = pred.eq(y_mb.view_as(pred)).sum().item() / len(y_mb)

        val_losses.append(tonp(val_loss))
        val_accs.append(acc)
        return val_loss

    def reg_f(params, l2_reg_params, l1_reg_params=None):
        """Mean of squared (and optionally absolute) weights scaled by exp(hyperparameter)."""
        r = 0.5 * ((params[0] ** 2) * torch.exp(l2_reg_params.unsqueeze(1) * ones_dxc)).mean()
        if l1_reg_params is not None:
            r += (params[0].abs() * torch.exp(l1_reg_params.unsqueeze(1) * ones_dxc)).mean()
        return r

    def out_f(x, params):
        """Linear model output ``x @ W`` plus the bias when ``params`` has two entries."""
        out = x @ params[0]
        out += params[1] if len(params) == 2 else 0
        return out

    def evaluate(params, x, y):
        """Return ``(cross-entropy loss, accuracy)`` of ``params`` on ``(x, y)``."""
        out = out_f(x,  params)
        loss = F.cross_entropy(out, y)
        pred = out.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        acc = pred.eq(y.view_as(pred)).sum().item() / len(y)

        return loss, acc

    # load twentynews and preprocess
    val_size_ratio = 0.5
    X, y = fetch_20newsgroups_vectorized(subset='train', return_X_y=True,
                                        #remove=('headers', 'footers', 'quotes')
                                        )
    x_test, y_test = fetch_20newsgroups_vectorized(subset='test', return_X_y=True,
                                                #remove=('headers', 'footers', 'quotes')
                                                )
    x_train, x_val, y_train, y_val = train_test_split(X, y, stratify=y, test_size=val_size_ratio)
    train_samples, n_features = x_train.shape
    test_samples, n_features = x_test.shape
    val_samples, n_features = x_val.shape
    n_classes = np.unique(y_train).shape[0]
    # train_samples=5657, val_samples=5657, test_samples=7532, n_features=130107, n_classes=20
    print('Dataset 20newsgroup, train_samples=%i, val_samples=%i, test_samples=%i, n_features=%i, n_classes=%i'
        % (train_samples, val_samples, test_samples, n_features, n_classes))
    ys = [frnp(y_train).long(), frnp(y_val).long(), frnp(y_test).long()]
    xs = [x_train, x_val, x_test]

    if cuda:
        xs = [from_sparse(x).cuda() for x in xs]
    else:
        xs = [from_sparse(x) for x in xs]

    # x_train.size() = torch.Size([5657, 130107])
    # y_train.size() = torch.Size([5657])
    x_train, x_val, x_test = xs
    y_train, y_val, y_test = ys

    # torch.DataLoader has problems with sparse tensor on GPU
    # Full-batch "iterators": each one yields the same [x, y] pair forever.
    iterators, train_list, val_list = [], [], []

    for bs, x, y in [(len(y_train), x_train, y_train), (len(y_val), x_val, y_val)]:
        iterators.append(repeat([x, y]))
    train_iterator, val_iterator = iterators

    # Initialize parameters
    l2_reg_params = torch.zeros(n_features).requires_grad_(True)  # one hp per feature
    l1_reg_params = (0.*torch.ones(1)).requires_grad_(True)  # one l1 hp only (best when really low)
    #l2_reg_params = (-20.*torch.ones(1)).requires_grad_(True)  # one l2 hp only (best when really low)
    #l1_reg_params = (-1.*torch.ones(n_features)).requires_grad_(True)
    hparams = [l2_reg_params]
    # hparams: the outer variables (or hyperparameters)
    ones_dxc = torch.ones(n_features, n_classes)

    # NOTE: outer_opt is constructed but no branch below calls it; every
    # algorithm updates hparams[0] by hand.
    outer_opt = torch.optim.SGD(lr=args.outer_lr, momentum=args.outer_mu, params=hparams)
    # outer_opt = torch.optim.Adam(lr=0.01, params=hparams)

    params_history = []
    val_losses, val_accs = [], []
    test_losses, test_accs = [], []
    w = torch.zeros(n_features, n_classes).requires_grad_(True)
    parameters = [w]

    # params_history: the inner iterates (from first to last)
    if bias:
        b = torch.zeros(n_classes).requires_grad_(True)
        parameters.append(b)

    # NOTE: inner_opt is constructed but not used below; only inner_opt_cg
    # (unit step size) is passed to hg.CG.
    if args.inner_mu > 0:
        #inner_opt = hg.Momentum(train_loss, inner_lr, inner_mu, data_or_iter=train_iterator)
        inner_opt = hg.HeavyBall(train_loss, args.inner_lr, args.inner_mu, data_or_iter=train_iterator)
    else:
        inner_opt = hg.GradientDescent(train_loss, args.inner_lr, data_or_iter=train_iterator)
    inner_opt_cg = hg.GradientDescent(train_loss, 1., data_or_iter=train_iterator)

    total_time = 0
    calls_num = 0
    # Columns: test loss, test accuracy, cumulative time, calls_num.
    loss_acc_time_results = np.zeros((args.epochs+1, 4))
    test_loss, test_acc = evaluate(parameters, x_test, y_test)
    loss_acc_time_results[0, 0] = test_loss
    loss_acc_time_results[0, 1] = test_acc
    loss_acc_time_results[0, 2] = 0.0
    loss_acc_time_results[0, 3] = 0

    # State shared across outer steps by the restarted / perturbed methods:
    # hparams0 is the previous outer iterate, k counts steps since the last
    # restart and s accumulates squared outer step lengths.
    if args.alg == 'PRAHGD' or args.alg == 'PRAF2BA':
        r = 0.1
        hparams0 = hparams[0] + torch.rand_like(hparams[0]) * r    # used in PRAHGD
        k, s = 0, 0
    elif args.alg == 'RAHGD' or args.alg == 'RAF2BA' or args.alg == 'RAGD-GS':
        hparams0 = hparams[0]
        k, s = 0, 0
    pk, huaT = 0, 3  # used in PAID
    for o_step in range(args.epochs):
        start_time = time.time()
        if args.alg == 'PRAHGD':
            inner_losses = []
            inner_theta = 0.005  # parameter of AGD
            outer_theta = 0.05
            B = 0.1
            parameters0 = parameters[0]
            # Inner loop: momentum (extrapolated) gradient steps on the training loss.
            for t in range(args.T):
                parameters_y = parameters[0] + (1 - inner_theta) * (parameters[0] - parameters0)
                loss_train = train_loss([parameters_y], hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, [parameters_y])
                parameters0, parameters[0] = parameters[0], parameters_y - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == args.T - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))

            # Outer step: hypergradient via hg.CG at the extrapolated point.
            hparams_y = hparams[0] + (1 - outer_theta) * (hparams[0] - hparams0)
            hg.CG(parameters, [hparams_y], args.hessian_q, inner_opt_cg, val_loss, stochastic=False, tol=tol)
            hparams0 = hparams[0]
            hparams[0] = hparams_y - args.outer_lr * hparams_y.grad
            s += float(torch.norm(hparams[0] - hparams0)) ** 2
            k += 1
            if k * s > B ** 2:
                # Restart from the previous iterate plus a random perturbation.
                hparams[0] = hparams0 + torch.rand_like(hparams0) * r
                k, s = 0, 0

            calls_num += args.batch_size * (args.hessian_q + args.T)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        elif args.alg == 'RAHGD':
            inner_losses = []
            inner_theta = 0.009  # parameter of AGD
            outer_theta = 0.05
            B = 0.1
            parameters0 = parameters[0]
            for t in range(args.T):
                parameters_y = parameters[0] + (1 - inner_theta) * (parameters[0] - parameters0)
                loss_train = train_loss([parameters_y], hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, [parameters_y])
                parameters0, parameters[0] = parameters[0], parameters_y - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == args.T - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))

            hparams_y = hparams[0] + (1 - outer_theta) * (hparams[0] - hparams0)
            hg.CG(parameters, [hparams_y], args.hessian_q, inner_opt_cg, val_loss, stochastic=False, tol=tol)
            hparams0 = hparams[0]
            hparams[0] = hparams_y - args.outer_lr * hparams_y.grad
            s += float(torch.norm(hparams[0] - hparams0)) ** 2
            k += 1
            if k * s > B ** 2:
                # Restart: discard the step and reset the counters.
                hparams[0] = hparams0
                k, s = 0, 0

            calls_num += args.batch_size * (args.hessian_q + args.T)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        elif args.alg == 'RAF2BA':
            lambda_F2BA = 500
            inner_losses = []
            inner_theta = 0.009  # parameter of AGD
            outer_theta = 0.05
            B = 0.1
            parameters0 = parameters[0]
            parameters_F2BA = parameters[0]
            parameters0_F2BA = parameters[0]
            # First inner loop: minimise the training loss.
            for t in range(args.T):
                parameters_y = parameters[0] + (1 - inner_theta) * (parameters[0] - parameters0)
                loss_train = train_loss([parameters_y], hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, [parameters_y])
                parameters0, parameters[0] = parameters[0], parameters_y - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == args.T - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))

            # Second inner loop: minimise val_loss + lambda * train_loss.
            for t in range(args.T):
                parameters_F2BA_y = parameters_F2BA + (1 - inner_theta) * (parameters_F2BA - parameters0_F2BA)
                loss_F2BA = val_loss([parameters_F2BA_y],hparams) + lambda_F2BA * train_loss([parameters_F2BA_y], hparams, [x_train, y_train])
                inner_F2BA_y_grad = torch.autograd.grad(loss_F2BA, [parameters_F2BA_y])
                parameters0_F2BA, parameters_F2BA = parameters_F2BA, parameters_F2BA_y - args.inner_lr * inner_F2BA_y_grad[0]

            hparams_y = hparams[0] + (1 - outer_theta) * (hparams[0] - hparams0)
            hparams0 = hparams[0]

            # Outer gradient: lambda times the difference of the training-loss
            # gradients (w.r.t. the hyperparameters) at the two inner solutions.
            loss_train_F2BA = train_loss([parameters_F2BA], hparams, [x_train, y_train])
            inner_grad_lambda_F2BA = torch.autograd.grad(loss_train_F2BA, hparams)[0]
            loss_train = train_loss(parameters, hparams, [x_train, y_train])
            inner_grad_lambda = torch.autograd.grad(loss_train, hparams)[0]
            outer_grad = lambda_F2BA * (inner_grad_lambda_F2BA - inner_grad_lambda) # + outer_grad_lambda_y

            hparams[0] = hparams_y - args.outer_lr * outer_grad
            s += float(torch.norm(hparams[0] - hparams0)) ** 2
            k += 1
            if k * s > B ** 2:
                hparams[0] = hparams0
                k, s = 0, 0

            calls_num += args.batch_size * (2 + 2 * args.T)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        elif args.alg == 'RAGD-GS':
            lambda_F2BA = 100
            inner_losses = []
            inner_theta = 0.007  # parameter of AGD
            outer_theta = 0.03
            parameters0 = parameters[0]
            parameters_F2BA = parameters[0]
            parameters0_F2BA = parameters[0]
            for t in range(args.T):
                parameters_y = parameters[0] + (1 - inner_theta) * (parameters[0] - parameters0)
                loss_train = train_loss([parameters_y], hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, [parameters_y])
                parameters0, parameters[0] = parameters[0], parameters_y - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == args.T - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))

            # The penalised problem uses a smaller damping (inner_theta / 7).
            for t in range(args.T):
                parameters_F2BA_y = parameters_F2BA + (1 - inner_theta/7) * (parameters_F2BA - parameters0_F2BA)
                loss_F2BA = val_loss([parameters_F2BA_y],hparams) + lambda_F2BA * train_loss([parameters_F2BA_y], hparams, [x_train, y_train])
                inner_F2BA_y_grad = torch.autograd.grad(loss_F2BA, [parameters_F2BA_y])
                parameters0_F2BA, parameters_F2BA = parameters_F2BA, parameters_F2BA_y - args.inner_lr * inner_F2BA_y_grad[0]
            # Momentum weight grows with the number of steps since the last restart.
            hparams_y = hparams[0] + (1+k)/(2+k) * (hparams[0] - hparams0)
            #hparams_y = hparams[0] + (1 - outer_theta) * (hparams[0] - hparams0)
            hparams0 = hparams[0]

            loss_train_F2BA = train_loss([parameters_F2BA], hparams, [x_train, y_train])
            inner_grad_lambda_F2BA = torch.autograd.grad(loss_train_F2BA, hparams)[0]
            loss_train = train_loss(parameters, hparams, [x_train, y_train])
            inner_grad_lambda = torch.autograd.grad(loss_train, hparams)[0]
            outer_grad = lambda_F2BA * (inner_grad_lambda_F2BA - inner_grad_lambda) # + outer_grad_lambda_y

            hparams[0] = hparams_y - args.outer_lr * outer_grad
            s += float(torch.norm(hparams[0] - hparams0)) ** 2
            k += 1
            if k**5 * s > 0.01:
                hparams[0] = hparams0
                k, s = 0, 0

            calls_num += args.batch_size * (2 + 2 * args.T)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        elif args.alg == 'F2BA':
            lambda_F2BA = 100
            inner_losses = []
            # theta = 1 makes the (1 - theta) momentum factor zero: plain gradient steps.
            inner_theta = 1  # parameter of AGD
            outer_theta = 1
            parameters0 = parameters[0]
            parameters_F2BA = parameters[0]
            parameters0_F2BA = parameters[0]
            for t in range(args.T):
                parameters_y = parameters[0] + (1 - inner_theta) * (parameters[0] - parameters0)
                loss_train = train_loss([parameters_y], hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, [parameters_y])
                parameters0, parameters[0] = parameters[0], parameters_y - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == args.T - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))

            for t in range(args.T):
                parameters_F2BA_y = parameters_F2BA + (1 - inner_theta) * (parameters_F2BA - parameters0_F2BA)
                loss_F2BA = val_loss([parameters_F2BA_y],hparams) + lambda_F2BA * train_loss([parameters_F2BA_y], hparams, [x_train, y_train])
                inner_F2BA_y_grad = torch.autograd.grad(loss_F2BA, [parameters_F2BA_y])
                parameters0_F2BA, parameters_F2BA = parameters_F2BA, parameters_F2BA_y - args.inner_lr * inner_F2BA_y_grad[0]
            hparams_y = hparams[0]
            #hparams_y = hparams[0] + (1 - outer_theta) * (hparams[0] - hparams0)
            hparams0 = hparams[0]

            loss_train_F2BA = train_loss([parameters_F2BA], hparams, [x_train, y_train])
            inner_grad_lambda_F2BA = torch.autograd.grad(loss_train_F2BA, hparams)[0]
            loss_train = train_loss(parameters, hparams, [x_train, y_train])
            inner_grad_lambda = torch.autograd.grad(loss_train, hparams)[0]
            outer_grad = lambda_F2BA * (inner_grad_lambda_F2BA - inner_grad_lambda) # + outer_grad_lambda_y

            hparams[0] = hparams_y - args.outer_lr * outer_grad

            calls_num += args.batch_size * (2 + 2 * args.T)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        elif args.alg == 'PRAF2BA':
            lambda_F2BA = 500
            inner_losses = []
            inner_theta = 0.009  # parameter of AGD
            outer_theta = 0.05
            B = 0.1
            parameters0 = parameters[0]
            parameters_F2BA = parameters[0]
            parameters0_F2BA = parameters[0]
            for t in range(args.T):
                parameters_y = parameters[0] + (1 - inner_theta) * (parameters[0] - parameters0)
                loss_train = train_loss([parameters_y], hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, [parameters_y])
                parameters0, parameters[0] = parameters[0], parameters_y - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == args.T - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))

            for t in range(args.T):
                parameters_F2BA_y = parameters_F2BA + (1 - inner_theta) * (parameters_F2BA - parameters0_F2BA)
                loss_F2BA = val_loss([parameters_F2BA_y],hparams) + lambda_F2BA * train_loss([parameters_F2BA_y], hparams, [x_train, y_train])
                inner_F2BA_y_grad = torch.autograd.grad(loss_F2BA, [parameters_F2BA_y])
                parameters0_F2BA, parameters_F2BA = parameters_F2BA, parameters_F2BA_y - args.inner_lr * inner_F2BA_y_grad[0]

            hparams_y = hparams[0] + (1 - outer_theta) * (hparams[0] - hparams0)
            hparams0 = hparams[0]

            loss_train_F2BA = train_loss([parameters_F2BA], hparams, [x_train, y_train])
            inner_grad_lambda_F2BA = torch.autograd.grad(loss_train_F2BA, hparams)[0]
            loss_train = train_loss(parameters, hparams, [x_train, y_train])
            inner_grad_lambda = torch.autograd.grad(loss_train, hparams)[0]
            outer_grad = lambda_F2BA * (inner_grad_lambda_F2BA - inner_grad_lambda) # + outer_grad_lambda_y

            hparams[0] = hparams_y - args.outer_lr * outer_grad
            s += float(torch.norm(hparams[0] - hparams0)) ** 2
            k += 1
            if k * s > B ** 2:
                hparams[0] = hparams0 + torch.rand_like(hparams0) * r
                k, s = 0, 0

            calls_num += args.batch_size * (2 + 2 * args.T)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        elif args.alg == 'ITD':
            inner_losses = []
            # create_graph=True keeps the inner updates differentiable so the
            # validation loss can be differentiated through them.
            for t in range(args.T):
                loss_train = train_loss(parameters, hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, parameters, create_graph=True)
                parameters[0] = parameters[0] - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == args.T - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))
            loss_val = val_loss(parameters, hparams[0])
            outer_grad = torch.autograd.grad(loss_val, hparams[0])[0]
            hparams[0] = hparams[0] - args.outer_lr * outer_grad
            calls_num += args.batch_size * (1 + args.T)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        elif args.alg == 'AID' or args.alg == 'PAID' or args.alg == 'BA-CG':
            if args.alg == 'AID':
                r = 0
                iteration = args.T
            elif args.alg == 'PAID':
                r = 0.1
                iteration = args.T
            else:
                # BA-CG: the number of inner steps grows slowly with the outer step.
                r = 0
                iteration = int(math.pow(o_step+1, 1/4)*2)+1
            inner_losses = []
            for t in range(iteration):
                loss_train = train_loss(parameters, hparams, [x_train, y_train])
                inner_grad = torch.autograd.grad(loss_train, parameters)
                parameters[0] = parameters[0] - args.inner_lr * inner_grad[0]
                inner_losses.append(loss_train)

                if t % train_log_interval == 0 or t == iteration - 1:
                    print('t={} loss: {}'.format(t, inner_losses[-1]))

            hg.CG(parameters, hparams, args.hessian_q, inner_opt_cg, val_loss, stochastic=False, tol=tol)
            # Add random noise (scaled by r) when the hypergradient is small
            # and at least huaT outer steps passed since the last perturbation.
            if torch.norm(hparams[0].grad) <= 0.8 * tol * 1e8 and o_step - pk >= huaT:
                hparams[0] = hparams[0] - args.outer_lr * (torch.rand_like(hparams[0]) * r + hparams[0].grad)
                pk = o_step
            else:
                hparams[0] = hparams[0] - args.outer_lr * hparams[0].grad

            calls_num += args.batch_size * (args.hessian_q + iteration)
            final_params = parameters
            for p, new_p in zip(parameters, final_params[:len(parameters)]):
                if warm_start:
                    p.data = new_p
                else:
                    p.data = torch.zeros_like(p)
            val_loss(final_params, hparams)

        iter_time = time.time() - start_time
        total_time += iter_time
        # Log every val_log_interval outer steps and always on the last outer
        # step (the bound is the number of outer epochs, not the inner-loop length T).
        if o_step % val_log_interval == 0 or o_step == args.epochs - 1:
            test_loss, test_acc = evaluate(final_params[:len(parameters)], x_test, y_test)
            loss_acc_time_results[o_step+1, 0] = test_loss
            loss_acc_time_results[o_step+1, 1] = test_acc
            loss_acc_time_results[o_step+1, 2] = total_time
            loss_acc_time_results[o_step+1, 3] = calls_num
            print('o_step={} ({:.2e}s) Val loss: {:.4e}, Val Acc: {:.2f}%'.format(o_step, iter_time, val_losses[-1],
                                                                                100*val_accs[-1]))
            print('          Test loss: {:.4e}, Test Acc: {:.2f}%'.format(test_loss, 100*test_acc))
            print('          l2_hp norm: {:.4e}'.format(torch.norm(hparams[0])))
            if len(hparams) == 2:
                print('          l1_hp : ', torch.norm(hparams[1]))
        # Early stop once both the time and the call budget are exceeded.
        # NOTE: rows after the stopping step stay zero in the saved array.
        if total_time > 150 and calls_num > 2.5e6:
            break

    file_name = 'results.npy'
    file_addr = os.path.join(args.save_folder, file_name)
    with open(file_addr, 'wb') as f:
        np.save(f, loss_acc_time_results)
    print(file_addr)

    print(loss_acc_time_results)
    print('HPO ended in {:.2e} seconds\n'.format(total_time))


def from_sparse(x):
    """Convert a SciPy sparse matrix into a torch sparse COO tensor.

    Args:
        x: a SciPy sparse matrix (anything providing ``tocoo()``).

    Returns:
        A CPU ``float32`` sparse COO tensor with the same shape and non-zero
        entries as ``x``.
    """
    coo = x.tocoo()
    # Row and column indices stacked into the (2, nnz) layout torch expects.
    indices = torch.as_tensor(np.vstack((coo.row, coo.col)), dtype=torch.long, device='cpu')
    values = torch.as_tensor(coo.data, dtype=torch.float32, device='cpu')

    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor(indices, values, size). Device and dtype are
    # pinned so the result stays a CPU float tensor, as the legacy type was.
    return torch.sparse_coo_tensor(indices, values, torch.Size(coo.shape),
                                   dtype=torch.float32, device='cpu')


def train_loss(params, hparams, data):
    """Return cross-entropy on ``data`` plus the regulariser ``reg_f``.

    Args:
        params: model parameters, passed to ``out_f`` and ``reg_f``.
        hparams: sequence of hyperparameters, unpacked into ``reg_f``.
        data: an ``(inputs, targets)`` pair.
    """
    inputs, targets = data
    data_term = F.cross_entropy(out_f(inputs, params), targets)
    return data_term + reg_f(params, *hparams)


def val_loss(opt_params, hparams):
    """Return the cross-entropy of ``opt_params`` on the next validation batch.

    The loss (as a numpy value) and the accuracy are appended to
    ``val_losses`` and ``val_accs``. ``hparams`` is accepted but not used.

    NOTE(review): this relies on the module-level names ``val_iterator``,
    ``parameters``, ``val_losses``, ``val_accs`` and ``tonp``, none of which
    is defined at module level in the visible part of this file - confirm
    they are provided elsewhere before calling this function.
    """
    inputs, targets = next(val_iterator)
    logits = out_f(inputs, opt_params[:len(parameters)])
    loss = F.cross_entropy(logits, targets)
    # Index of the largest logit is the predicted class.
    predicted = logits.argmax(dim=1, keepdim=True)
    n_correct = predicted.eq(targets.view_as(predicted)).sum().item()

    val_losses.append(tonp(loss))
    val_accs.append(n_correct / len(targets))
    return loss


def reg_f(params, l2_reg_params, l1_reg_params=None):
    """Return the exponentially weighted regulariser of ``params[0]``.

    The L2 term is ``0.5 * mean(w**2 * exp(l2))``, with ``l2_reg_params``
    broadcast along the second dimension of ``w``. When ``l1_reg_params`` is
    given, ``mean(|w| * exp(l1))`` is added.
    """
    weights = params[0]
    ones = torch.ones(weights.size())

    l2_scale = torch.exp(l2_reg_params.unsqueeze(1) * ones)
    penalty = 0.5 * ((weights ** 2) * l2_scale).mean()
    if l1_reg_params is None:
        return penalty

    l1_scale = torch.exp(l1_reg_params.unsqueeze(1) * ones)
    return penalty + (weights.abs() * l1_scale).mean()


def reg_fs(params, hparams, loss):
    """Return ``loss`` plus the regulariser ``reg_f(params, *hparams)``."""
    return loss + reg_f(params, *hparams)


def out_f(x, params):
    """Return the linear model output ``x @ params[0]``.

    ``params[1]`` is added as a bias when ``params`` holds exactly two entries.
    """
    logits = x @ params[0]
    offset = params[1] if len(params) == 2 else 0
    logits += offset
    return logits


def main():
    """Parse the command line, print the settings and run the experiment.

    The algorithm is taken from ``--alg`` (default ``'RAGD-GS'``). It is no
    longer overwritten after parsing, so the algorithm that runs always
    matches the one encoded in ``args.save_folder``.
    """
    args = parse_args()
    print(args)
    train_model(args)



# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()
