import os
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.optim
import argparse
import hypergrad as hg
import time
from itertools import repeat
from torch.nn import functional as F
from torchvision import datasets
from stocBiO import *
import math
from torch.utils.tensorboard import SummaryWriter
import csv
from itertools import product
from matplotlib import pyplot as plt
def parse_args():
    """Parse the command line and prepare the experiment output directory.

    Besides parsing, this function derives ``args.model_name`` from the chosen
    algorithm and hyper-parameters, appends it to ``args.save_folder`` and
    makes sure that directory exists.

    Returns:
        argparse.Namespace: parsed options with ``model_name`` and
        ``save_folder`` filled in.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', default=10, type=int)
    parser.add_argument('--batch_size', type=int, default=20000)
    parser.add_argument('--test_size', type=int, default=10000)
    parser.add_argument('--epochs', type=int, default=100, help='K')
    parser.add_argument('--iterations', type=int, default=200, help='T')
    parser.add_argument('--outer_lr', type=float, default=0.1, help='beta')
    parser.add_argument('--inner_lr', type=float, default=0.1, help='alpha')
    parser.add_argument('--eta', type=float, default=0.5, help='used in Hessian')
    parser.add_argument('--data_path', default='data/', help='The temporary data storage path')
    parser.add_argument('--training_size', type=int, default=20000)
    parser.add_argument('--validation_size', type=int, default=5000)
    parser.add_argument('--noise_rate', type=float, default=0.2)
    parser.add_argument('--hessian_q', type=int, default=3)
    parser.add_argument('--save_folder', type=str, default='', help='path to save result')
    parser.add_argument('--model_name', type=str, default='', help='Experiment name')
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--alg', type=str, default='F2BA', choices=['PRAHGD', 'BA-CG', 'RAF2BA', 'PRAF2BA',
                                                          'RAHGD', 'AID', 'ITD', 'PAID','RAGD-GS',
                                                          'F2BA'])
    args = parser.parse_args()

    if not args.save_folder:
        args.save_folder = 'Bilevel_exp/mnist_exp'
    # NOTE: the experiment name is always derived from the hyper-parameters,
    # so any value passed via --model_name is overwritten here.
    args.model_name = '{}_{}_bs_{}_olr_{}_ilr_{}_eta_{}_noiser_{}_q_{}_ite_{}'.format(args.alg,
                       args.training_size, args.batch_size, args.outer_lr, args.inner_lr, args.eta,
                       args.noise_rate, args.hessian_q, args.iterations)
    args.save_folder = os.path.join(args.save_folder, args.model_name)
    # exist_ok avoids the check-then-create race when several runs with the
    # same configuration start at the same time.
    os.makedirs(args.save_folder, exist_ok=True)
    return args


def get_data_loaders(args):
    """Build the MNIST train and test loaders.

    Both splits are read from (and, if missing, downloaded to)
    ``args.data_path`` and normalised with mean 0.1307 / std 0.3081.  The
    training loader walks the dataset in its stored order in batches of
    ``args.batch_size``; the test loader uses batches of ``args.test_size``.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    def mnist_split(is_train):
        # Same preprocessing pipeline for both splits.
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return datasets.MNIST(root=args.data_path, train=is_train,
                              download=True, transform=preprocess)

    train_set = mnist_split(True)
    # Sequential sampling keeps the batch contents identical across runs.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        sampler=torch.utils.data.sampler.SequentialSampler(train_set),
        batch_size=args.batch_size,
        num_workers=0,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(mnist_split(False),
                                              batch_size=args.test_size)
    return train_loader, test_loader


def train_model(args, train_loader, test_loader):
    """Run the bilevel optimisation selected by ``args.alg`` and log the results.

    The upper-level variable ``lambda_x`` holds one entry per training sample;
    it is passed through a sigmoid and multiplies the per-sample loss in the
    lower-level objective.  The lower-level variable ``parameters`` is a
    ``(num_classes, 785)`` matrix used as a linear classifier (columns 0..783
    are weights, column 784 is the bias).  After every outer iteration the
    average train/test loss, the elapsed time and the running gradient-call
    counter are written to TensorBoard and into an array that is finally saved
    as ``<args.save_folder>/<args.seed>.npy``.  The outer loop stops after
    ``args.epochs`` iterations or once more than 160 seconds have elapsed.

    Args:
        args: Namespace produced by ``parse_args``.
        train_loader: Loader over the training set.  All of its batches are
            materialised once; the first batch is used for the lower-level
            problem and the last batch provides the validation samples.
        test_loader: Loader over the test set, used only for evaluation.

    Returns:
        None.
    """
    writer = SummaryWriter(log_dir=args.save_folder)

    # The run is pinned to the CPU; tensors still go through ``.to(device)``.
    device = torch.device('cpu')
    print(device)

    # NOTE(review): the seed is hard-coded; ``args.seed`` is only used for the
    # name of the results file written at the end of this function.
    np.random.seed(1204)
    torch.manual_seed(1204)
    parameters = torch.randn((args.num_classes, 785), requires_grad=True)
    parameters = nn.init.kaiming_normal_(parameters, mode='fan_out').to(device)
    lambda_x = torch.zeros((args.training_size)).to(device)
    lambda_x.requires_grad = True
    # Columns: train loss, test loss, elapsed seconds, gradient calls, grad_norm_inner.
    loss_time_results = np.zeros((args.epochs+1, 5))
    batch_num = args.training_size//args.batch_size
    train_loss_avg = loss_train_avg(train_loader, parameters, device, batch_num)
    test_loss_avg = loss_test_avg(test_loader, parameters, device)
    loss_time_results[0, :] = [train_loss_avg.cpu(), test_loss_avg.cpu(), (0.0), (0), (0.0)]
    print('Epoch: {:d} Train Loss: {:.4f} Test Loss: {:.4f}'.format(0, train_loss_avg, test_loss_avg))

    # Materialise every training batch once so the loops below can index them.
    images_list, labels_list = [], []
    for images, labels in train_loader:
        images_list.append(images)
        labels_list.append(labels)

    # Validation data: the first ``validation_size`` samples of the last batch,
    # tiled so that ``images_outer`` has about ``training_size`` rows.
    images_outer, labels_outer = images_list[-1], labels_list[-1]
    images_outer = torch.reshape(images_outer, (images_outer.size()[0], -1)).to(device)
    images_temp, labels_temp = images_outer[0:args.validation_size, :], labels_outer[0:args.validation_size]
    images_outer = torch.cat([images_temp] * (args.training_size // args.validation_size))
    labels_outer = torch.cat([labels_temp] * (args.training_size // args.validation_size))

    def loss_inner(parameters, weight, data_all):
        """Lower-level loss in the (params list, hparams list, data) form used by ``hg``."""
        data = data_all[0]
        labels = data_all[1]
        data = torch.reshape(data, (data.size()[0], -1)).to(device)
        labels_cp = nositify(labels, args.noise_rate, args.num_classes).to(device)
        output = torch.matmul(data, torch.t(parameters[0][:, 0:784]))+parameters[0][:, 784]
        loss = F.cross_entropy(output, labels_cp, reduction='none')
        loss_regu = torch.mean(torch.mul(loss, torch.sigmoid(weight[0]))) + 0.001*torch.pow(torch.norm(parameters[0]),2)
        return loss_regu

    def loss_outer(parameters, lambda_x):
        """Upper-level loss on the tiled validation set; ``lambda_x`` is not used."""
        output = torch.matmul(images_outer, torch.t(parameters[0][:, 0:784]))+parameters[0][:, 784]
        loss = F.cross_entropy(output, labels_outer)
        return loss

    def loss_outer_F2BA(parameters, lambda_x):
        """Upper-level loss on the un-tiled validation samples; ``lambda_x`` is not used."""
        output = torch.matmul(images_temp, torch.t(parameters[0][:, 0:784]))+parameters[0][:, 784]
        loss = F.cross_entropy(output, labels_temp)
        return loss

    def out_f(data, parameters):
        """Logits of the linear classifier stored in ``parameters[0]``."""
        output = torch.matmul(data, torch.t(parameters[0][:, 0:784]))+parameters[0][:, 784]
        return output

    def reg_f(params, hparams, loss):
        """Sigmoid-weighted mean of ``loss`` plus an L2 penalty on ``params[0]``."""
        loss_regu = torch.mean(torch.mul(loss, torch.sigmoid(hparams))) + 0.001*torch.pow(torch.norm(params[0]), 2)
        return loss_regu

    tol = 1e-12
    train_iterator = repeat([images_list[0], labels_list[0]])
    inner_opt_cg = hg.GradientDescent(loss_inner, 1., data_or_iter=train_iterator)

    def noisy_train_batch():
        """Return the flattened first training batch and its labels after ``nositify``."""
        images, labels = images_list[0], labels_list[0]
        images = torch.reshape(images, (images.size()[0], -1)).to(device)
        labels_cp = nositify(labels, args.noise_rate, args.num_classes).to(device)
        return images, labels_cp

    def inner_descent(parameters, weight, num_steps, momentum=None):
        """Run up to ``num_steps`` gradient steps on the lower-level problem.

        With ``momentum`` set, each gradient is taken at the extrapolated point
        ``parameters + momentum * (parameters - previous)``; otherwise at
        ``parameters`` itself.  Stops early once the gradient norm drops below
        ``tol``.  Returns ``(parameters, gradient_calls)``.
        """
        previous = parameters
        calls = 0
        for _ in range(num_steps):
            if momentum is None:
                point = parameters
            else:
                point = parameters + momentum * (parameters - previous)
            images, labels_cp = noisy_train_batch()
            output = out_f(images, [point])
            inner_update = gradient_gy(args, labels_cp, point, images, weight, output, reg_f)
            if float(torch.norm(inner_update)) < tol:
                break
            previous = parameters
            parameters = point - args.inner_lr * inner_update
            calls += args.batch_size
        return parameters, calls

    def inner_descent_pair(parameters, weight, lambda_F2BA, momentum=None, momentum_F2BA=None):
        """Advance the two lower-level sequences used by the F2BA-style branches.

        Both sequences start from ``parameters``.  The first follows
        ``gradient_gy``; the second follows ``gradient_gy_lambda`` with
        ``loss_outer_F2BA`` and ``lambda_F2BA``.  Each sequence can have its own
        momentum coefficient (``None`` disables extrapolation).  The loop ends
        after ``args.iterations`` steps or when both gradient norms are below
        ``tol``.  Returns ``(parameters, parameters_F2BA, gradient_calls)``.
        """
        parameters_F2BA = parameters
        previous, previous_F2BA = parameters, parameters_F2BA
        calls = 0
        for _ in range(args.iterations):
            if momentum is None:
                point = parameters
            else:
                point = parameters + momentum * (parameters - previous)
            if momentum_F2BA is None:
                point_F2BA = parameters_F2BA
            else:
                point_F2BA = parameters_F2BA + momentum_F2BA * (parameters_F2BA - previous_F2BA)

            images, labels_cp = noisy_train_batch()

            output = out_f(images, [point])
            inner_update = gradient_gy(args, labels_cp, point, images, weight, output, reg_f)

            output_F2BA = out_f(images, [point_F2BA])
            inner_update_F2BA = gradient_gy_lambda(args, labels_cp, point_F2BA, images, weight, output_F2BA,
                                                   reg_f, loss_outer_F2BA, lambda_F2BA)

            if float(torch.norm(inner_update)) < tol and float(torch.norm(inner_update_F2BA)) < tol:
                break
            previous = parameters
            parameters = point - args.inner_lr * inner_update
            previous_F2BA = parameters_F2BA
            parameters_F2BA = point_F2BA - args.inner_lr * inner_update_F2BA
            # One gradient call per sequence.
            calls += 2 * args.batch_size
        return parameters, parameters_F2BA, calls

    def f2ba_outer_update(parameters, parameters_F2BA, lambda_y, lambda_F2BA):
        """Return ``lambda_F2BA * (d/d lambda g(parameters_F2BA) - d/d lambda g(parameters))``.

        ``g`` is ``reg_f`` applied to the cross-entropy on a noisy training
        batch.  No direct upper-level term is added: the upper-level losses in
        this function do not use lambda.
        """
        images, labels_cp = noisy_train_batch()

        output = out_f(images, [parameters])
        output_F2BA = out_f(images, [parameters_F2BA])

        # NOTE(review): cross_entropy uses its default mean reduction here,
        # while loss_inner uses reduction='none' -- confirm this is intended.
        # ``parameters`` is passed to reg_f as a bare tensor, so its
        # ``params[0]`` selects the first row; that penalty does not depend on
        # lambda_y and therefore does not affect the gradients taken below.
        loss = F.cross_entropy(output, labels_cp)
        loss_regu = reg_f(parameters, lambda_y, loss)
        inner_grad_lambda_y = torch.autograd.grad(loss_regu, lambda_y)[0]

        loss_F2BA = F.cross_entropy(output_F2BA, labels_cp)
        loss_regu_F2BA = reg_f(parameters_F2BA, lambda_y, loss_F2BA)
        inner_grad_lambda_y_F2BA = torch.autograd.grad(loss_regu_F2BA, lambda_y)[0]

        return lambda_F2BA * (inner_grad_lambda_y_F2BA - inner_grad_lambda_y)

    start_time = time.time()
    calls_num = 0
    pk, huaT = 0, 10  # PAID: epoch of the last perturbation / minimum gap between perturbations
    if args.alg in ['PRAHGD', 'RAHGD', 'RAF2BA', 'PRAF2BA', 'RAGD-GS', 'F2BA']:
        lambda_x0 = lambda_x  # previous outer iterate, used for the outer momentum
        k, s = 0, 0           # steps since the last reset / accumulated squared step length
    for epoch in range(args.epochs):
        # Never updated by any branch below; it is logged as a constant 0.
        grad_norm_inner = 0.0

        if args.alg in ('PRAHGD', 'RAHGD'):
            inner_theta = 0.005  # parameter of AGD
            outer_theta = 0.005
            B = 0.1
            r = 1e-2
            parameters, calls = inner_descent(parameters, lambda_x, args.iterations,
                                              momentum=(1 - inner_theta))
            calls_num += calls

            lambda_y = lambda_x + (1-outer_theta) * (lambda_x - lambda_x0)
            hg.CG([parameters], [lambda_y], args.hessian_q, inner_opt_cg, loss_outer, stochastic=False, tol=tol)
            calls_num += args.batch_size * 3
            lambda_x0 = lambda_x
            lambda_x = lambda_y - args.outer_lr * lambda_y.grad
            s += float(torch.norm(lambda_x - lambda_x0))**2
            k += 1
            if k * s > B**2:
                # Fall back to the previous iterate; PRAHGD also adds a small
                # uniform random perturbation.
                if args.alg == 'PRAHGD':
                    lambda_x = lambda_x0 + torch.rand_like(lambda_x) * r
                else:
                    lambda_x = lambda_x0
                k, s = 0, 0

        elif args.alg in ('RAF2BA', 'PRAF2BA'):
            lambda_F2BA = 700
            inner_theta = 0.005  # parameter of AGD
            outer_theta = 0.005
            B = 0.1
            r = 1e-2
            parameters, parameters_F2BA, calls = inner_descent_pair(
                parameters, lambda_x, lambda_F2BA,
                momentum=(1 - inner_theta), momentum_F2BA=(1 - inner_theta/7))
            calls_num += calls

            lambda_y = lambda_x + (1-outer_theta) * (lambda_x - lambda_x0)
            outer_update = f2ba_outer_update(parameters, parameters_F2BA, lambda_y, lambda_F2BA)

            lambda_x0 = lambda_x
            lambda_x = lambda_y - args.outer_lr * outer_update
            s += float(torch.norm(lambda_x - lambda_x0))**2
            k += 1
            if k * s > B**2:
                if args.alg == 'PRAF2BA':
                    lambda_x = lambda_x0 + torch.rand_like(lambda_x) * r
                else:
                    lambda_x = lambda_x0
                k, s = 0, 0

        elif args.alg in ('PAID', 'AID', 'BA-CG'):
            r = 1e-2
            if args.alg == 'BA-CG':
                # The number of inner steps grows slowly with the outer iteration.
                inner_steps = int(math.pow(epoch+1, 1/4)*2)+1
            else:
                inner_steps = args.iterations
            parameters, calls = inner_descent(parameters, lambda_x, inner_steps)
            calls_num += calls

            hg.CG([parameters], [lambda_x], args.hessian_q, inner_opt_cg, loss_outer, stochastic=False, tol=tol)
            if args.alg == 'PAID' and torch.norm(lambda_x.grad) <= 0.8 * tol and epoch - pk > huaT:
                # Perturb lambda_x when its gradient has (numerically) vanished.
                lambda_x -= args.outer_lr * torch.rand_like(lambda_x) * r
                pk = epoch
            lambda_x = lambda_x - args.outer_lr * lambda_x.grad

        elif args.alg == 'ITD':
            parameters, calls = inner_descent(parameters, lambda_x, args.iterations)
            calls_num += calls

            loss_o = loss_outer([parameters], lambda_x)
            grad_lambda_x = torch.autograd.grad(loss_o, lambda_x)[0]
            lambda_x = lambda_x - args.outer_lr * grad_lambda_x
            # Rebind to detached leaves so the autograd graph of this outer
            # iteration is released.  Setting ``requires_grad`` on the
            # temporary returned by ``detach()`` without rebinding has no effect.
            lambda_x = lambda_x.detach().requires_grad_(True)
            parameters = parameters.detach().requires_grad_(True)

        elif args.alg == 'RAGD-GS':
            lambda_F2BA = 700
            inner_theta = 0.005  # parameter of AGD
            # The penalty is the local ``lambda_F2BA``; the argument namespace
            # has no attribute of that name.
            parameters, parameters_F2BA, calls = inner_descent_pair(
                parameters, lambda_x, lambda_F2BA,
                momentum=(1-inner_theta*4), momentum_F2BA=(1 - inner_theta))
            calls_num += calls

            lambda_y = lambda_x + (k+1)/(k+2) * (lambda_x - lambda_x0)
            outer_update = f2ba_outer_update(parameters, parameters_F2BA, lambda_y, lambda_F2BA)

            lambda_x0 = lambda_x
            lambda_x = lambda_y - args.outer_lr * outer_update
            s += float(torch.norm(lambda_x - lambda_x0))**2
            k += 1
            if (k+1)**5*s>0.4:
                lambda_x = lambda_x0
                k, s = 0, 0

        elif args.alg == 'F2BA':
            lambda_F2BA = 300
            outer_theta = 1  # makes the outer momentum coefficient (1 - outer_theta) zero
            parameters, parameters_F2BA, calls = inner_descent_pair(
                parameters, lambda_x, lambda_F2BA)
            calls_num += calls

            lambda_y = lambda_x + (1-outer_theta) * (lambda_x - lambda_x0)
            outer_update = f2ba_outer_update(parameters, parameters_F2BA, lambda_y, lambda_F2BA)

            lambda_x0 = lambda_x
            lambda_x = lambda_y - args.outer_lr * outer_update
            # k and s are tracked for parity with the other branches; this
            # branch never resets lambda_x.
            s += float(torch.norm(lambda_x - lambda_x0))**2
            k += 1

        train_loss_avg = loss_train_avg(train_loader, parameters, device, batch_num)
        test_loss_avg = loss_test_avg(test_loader, parameters, device)
        end_time = time.time()
        print('Epoch: {:d} Train Loss: {:.4f} Test Loss: {:.4f} Time: {:.4f}'.format(epoch+1, train_loss_avg,
                                                                            test_loss_avg, (end_time-start_time)))
        writer.add_scalar('Loss/train', train_loss_avg, epoch)
        writer.add_scalar('Loss/test', test_loss_avg, epoch)
        writer.add_scalar('Time', (end_time-start_time), epoch)
        writer.add_scalar('Calls_num', calls_num, epoch)
        writer.add_scalar('Grad_norm_inner', grad_norm_inner, epoch)

        loss_time_results[epoch+1, 0] = train_loss_avg
        loss_time_results[epoch+1, 1] = test_loss_avg
        loss_time_results[epoch+1, 2] = (end_time-start_time)
        loss_time_results[epoch + 1, 3] = calls_num
        loss_time_results[epoch+1, 4] = grad_norm_inner

        # Wall-clock budget: rows of loss_time_results for epochs that were
        # not run stay at zero.
        if end_time - start_time > 160:
            break

    file_name = str(args.seed)+'.npy'
    file_addr = os.path.join(args.save_folder, file_name)
    with open(file_addr, 'wb') as f:
        np.save(f, loss_time_results)

    writer.close()
def loss_train_avg(data_loader, parameters, device, batch_num):
    """Average ``loss_f_funciton`` over the first ``batch_num`` batches.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` batches.
        parameters: Classifier matrix passed on to ``loss_f_funciton``.
        device: Device the batches are moved to before evaluation.
        batch_num: Number of leading batches to evaluate; must be positive.

    Returns:
        torch.Tensor: detached scalar holding the mean of the per-batch losses.

    Raises:
        ValueError: if ``batch_num`` is not positive or the loader yields no
            batch, i.e. when there is nothing to average.
    """
    if batch_num <= 0:
        raise ValueError('batch_num must be positive, got {}'.format(batch_num))
    loss_avg, num = 0.0, 0
    # Evaluation only: no autograd graph is needed for the reported loss.
    with torch.no_grad():
        for images, labels in data_loader:
            images = torch.reshape(images, (images.size()[0], -1)).to(device)
            labels = labels.to(device)
            loss_avg += loss_f_funciton(labels, parameters, images)
            num += 1
            # Stop before the loader prepares a batch that would be discarded.
            if num >= batch_num:
                break
    if num == 0:
        raise ValueError('data_loader yielded no batches')
    loss_avg = loss_avg/num
    return loss_avg.detach()


def loss_test_avg(data_loader, parameters, device):
    """Average ``loss_f_funciton`` over every batch of ``data_loader``.

    Args:
        data_loader: Iterable yielding ``(images, labels)`` batches.
        parameters: Classifier matrix passed on to ``loss_f_funciton``.
        device: Device the batches are moved to before evaluation.

    Returns:
        torch.Tensor: detached scalar holding the mean of the per-batch losses.

    Raises:
        ValueError: if the loader yields no batch.
    """
    loss_avg, num = 0.0, 0
    # Evaluation only: no autograd graph is needed for the reported loss.
    with torch.no_grad():
        for images, labels in data_loader:
            images = torch.reshape(images, (images.size()[0], -1)).to(device)
            labels = labels.to(device)
            loss_avg += loss_f_funciton(labels, parameters, images)
            num += 1
    if num == 0:
        raise ValueError('data_loader yielded no batches')
    loss_avg = loss_avg/num
    return loss_avg.detach()


def loss_f_funciton(labels, parameters, data):
    """Return the mean cross-entropy of a linear classifier on one batch.

    ``parameters`` packs the classifier row-wise: columns 0..783 are the
    weights applied to ``data`` and column 784 is the per-class bias.
    """
    weights = parameters[:, 0:784]
    bias = parameters[:, 784]
    logits = torch.matmul(data, weights.t()) + bias
    return F.cross_entropy(logits, labels)


def nositify(labels, noise_rate, n_class):
    """Return a copy of ``labels`` in which a fraction of entries is flipped.

    ``int(noise_rate * len(labels))`` distinct positions are drawn at random
    and each selected label is shifted by a random non-zero offset modulo
    ``n_class``, so every selected label ends up in a different class.

    Args:
        labels: 1-D integer tensor of class indices in ``[0, n_class)``.
        noise_rate: Fraction of labels to corrupt.
        n_class: Number of classes.

    Returns:
        torch.Tensor: a new tensor with the corrupted labels; ``labels``
        itself is left untouched so that repeated calls on the same cached
        batch do not accumulate noise.
    """
    noisy = labels.clone()
    num = int(noise_rate*(labels.size()[0]))
    if n_class < 2:
        # There is no other class to flip to.
        return noisy
    # Offsets in [1, n_class - 1] guarantee that the label actually changes.
    randint = torch.randint(1, n_class, (num,))
    index = torch.randperm(labels.size()[0])[:num]
    noisy[index] = (noisy[index]+randint) % n_class
    return noisy


def build_val_data(args, val_index, images_list, labels_list, device):
    """Assemble three validation batches from the cached training batches.

    When ``val_index`` has fewer than three entries, batch 0 is used for all
    three slots; otherwise the negated ``val_index`` supplies the batch
    positions (so ``val_index`` must support unary minus, e.g. an array).
    The labels of the first slot are used as they are, the labels of the
    other two slots go through ``nositify``.

    Returns:
        list: ``[val_images_list, val_labels_list]``, each with three entries;
        images are flattened to ``(batch, -1)`` and everything is moved to
        ``device``.
    """
    val_index = [0, 0, 0] if len(val_index) < 3 else -(val_index)

    val_images_list, val_labels_list = [], []
    for slot in range(3):
        batch_id = val_index[slot]
        images, labels = images_list[batch_id], labels_list[batch_id]
        flat_images = torch.reshape(images, (images.size()[0], -1)).to(device)
        if slot == 0:
            # Only the first slot keeps its labels as given.
            slot_labels = labels.to(device)
        else:
            slot_labels = nositify(labels, args.noise_rate, args.num_classes).to(device)
        val_images_list.append(flat_images)
        val_labels_list.append(slot_labels)

    return [val_images_list, val_labels_list]


def main():
    """Parse the command line, build the MNIST loaders and run the experiment."""
    # The algorithm is whatever --alg selected.  It must not be changed after
    # parse_args(), because the output folder name has already been derived
    # from it there and the results would be stored under the wrong name.
    args = parse_args()
    print(args)
    train_loader, test_loader = get_data_loaders(args)
    train_model(args, train_loader, test_loader)



# Run the experiment only when this file is executed as a script.
if __name__ == '__main__':
    main()

