import argparse
import copy
import math
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.utils.extmath import safe_sparse_dot
import matplotlib.pyplot as plt
from scipy import optimize

import tls as ls
from minibatch import MinibatchSampler
from matplotlib.font_manager import FontProperties



from torchvision import datasets


################################################################################
#
#  Bilevel Optimization
#
#  min_{x,w} f(x, w)
#  s.t. x = argmin_x g(x, w)
#
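#  here: f(x, w) is the mean cross-entropy on the validation set (valset);
#        g(x, w) is the w-weighted cross-entropy plus an L2 penalty on x,
#        computed on the (label-noisy) training set (trainset)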
#
#  f_x = df/dx
#  f_w = df/dw
#  g_x = dg/dx
#  g_w = dg/dw
#
################################################################################
# Short identifiers of the bilevel optimization methods, in a fixed order.
METHODS = "BSG_1 blfoa BOME shine shinea blfoae".split()

def variance_reduction(grad, memory, vr_info):
    """Return a SAGA-style variance-reduced estimate and update the memory.

    Args:
        grad: fresh stochastic gradient for the current minibatch.
        memory: table whose row ``batch_id`` stores the last gradient seen
            for that minibatch and whose last row is an aggregate that is
            updated in place (presumably the running average when
            ``step_weight`` is 1 / number of batches -- confirm with caller).
        vr_info: ``(batch_id, step_weight)`` pair.

    Returns:
        ``grad - memory[batch_id] + memory[-1]``, computed before ``memory``
        is modified, as a new tensor.
    """
    batch_id, step_weight = vr_info
    correction = grad - memory[batch_id]
    # Build the estimate first: the aggregate row is overwritten below.
    estimate = correction + memory[-1]
    memory[-1] += correction * step_weight
    memory[batch_id, :] = grad
    return estimate

def parse_args():
    """Parse the command line and seed the global RNGs.

    Side effects:
        Seeds ``np.random`` and ``torch`` with the parsed ``--seed``.

    Returns:
        argparse.Namespace with data, optimizer and method hyper-parameters.
    """
    parser = argparse.ArgumentParser()
    option_specs = [
        ('--dataset', dict(type=str, default="fashion", choices=["mnist", "fashion"])),
        ('--train_size', dict(type=int, default=50000)),
        ('--val_size', dict(type=int, default=5000)),
        ('--pretrain', dict(action='store_true', default=False,
                            help='whether to create data and pretrain on valset')),
        ('--epochs', dict(type=int, default=50000)),
        ('--batch_size', dict(type=int, default=200)),
        ('--iterations', dict(type=int, default=1, help='T')),
        ('--K', dict(type=int, default=10, help='k')),
        ('--data_path', dict(default='./data', help='where to save data')),
        ('--model_path', dict(default='./save_data_cleaning', help='where to save model')),
        ('--noise_rate', dict(type=float, default=0.5)),
        ('--x_lr', dict(type=float, default=0.01)),
        ('--xhat_lr', dict(type=float, default=0.01)),
        ('--w_lr', dict(type=float, default=100)),
        ('--w_momentum', dict(type=float, default=0.9)),
        ('--x_momentum', dict(type=float, default=0.9)),
        ('--eta', dict(type=float, default=0.01)),
        ('--u1', dict(type=float, default=0.1)),
        ('--seed', dict(type=int, default=0)),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    parsed = parser.parse_args()
    np.random.seed(parsed.seed)
    torch.manual_seed(parsed.seed)
    return parsed


def get_data(args):
    """Load (Fashion)MNIST, split it, corrupt training labels, and normalize.

    The official training set is shuffled with ``torch.randperm`` and split
    into the first ``args.train_size`` samples (train), the next
    ``args.val_size`` samples (val), and the remainder (teval, a held-out
    split). A fraction ``args.noise_rate`` of the training labels is then
    replaced by labels drawn uniformly at random; a replacement may coincide
    with the true label. Every split is standardized with the mean/std of the
    training images and flattened to one row per sample.

    Args:
        args: namespace providing ``dataset``, ``data_path``, ``train_size``,
            ``val_size`` and ``noise_rate``.

    Returns:
        ``(trainset, valset, testset, tevalset, old_train_y)``. Each set is an
        ``(inputs, labels)`` tuple; ``old_train_y`` holds the training labels
        before corruption.

    Raises:
        ValueError: if the validation split does not contain every class.
    """
    data = {
        'mnist': datasets.MNIST,
        'fashion': datasets.FashionMNIST,
    }

    trainset = data[args.dataset](root=args.data_path,
                                  train=True,
                                  download=True)
    testset = data[args.dataset](root=args.data_path,
                                 train=False,
                                 download=True)

    indices = torch.randperm(len(trainset))
    train_idx = indices[:args.train_size]
    val_idx = indices[args.train_size:args.train_size + args.val_size]
    teval_idx = indices[args.train_size + args.val_size:]

    train_x = trainset.data[train_idx] / 255.
    val_x = trainset.data[val_idx] / 255.
    teval_x = trainset.data[teval_idx] / 255.
    test_x = testset.data / 255.

    targets = trainset.targets if args.dataset in ["mnist", "fashion"] else torch.LongTensor(trainset.targets)
    train_y = targets[train_idx]
    val_y = targets[val_idx]
    teval_y = targets[teval_idx]
    test_y = torch.LongTensor(testset.targets)

    num_classes = test_y.unique().shape[0]
    num_val_classes = val_y.unique().shape[0]
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let a class-incomplete val split through.
    if num_val_classes != num_classes:
        raise ValueError(
            f"validation split contains {num_val_classes} of {num_classes} "
            "classes; increase --val_size")

    ### poison training data with noise rate = args.noise_rate
    num_noisy = int(args.train_size * args.noise_rate)
    rand_indices = torch.randperm(args.train_size)
    noisy_indices = rand_indices[:num_noisy]
    noisy_y = torch.randint(num_classes, size=(num_noisy,))
    old_train_y = train_y.data.clone()
    # train_y is an index copy, so trainset.targets itself is untouched.
    train_y.data[noisy_indices] = noisy_y.data

    # Standardize every split with the training-image statistics.
    mean = train_x.unsqueeze(1).mean([0, 2, 3])
    std = train_x.unsqueeze(1).std([0, 2, 3])

    def _prepare(images):
        # Normalize, then flatten each image to a single feature row.
        return torch.flatten((images - mean) / (std + 1e-4), start_dim=1)

    trainset = (_prepare(train_x), train_y)
    valset = (_prepare(val_x), val_y)
    testset = (_prepare(test_x), test_y)
    tevalset = (_prepare(teval_x), teval_y)
    return trainset, valset, testset, tevalset, old_train_y

### initialize a linear model

def get_model(in_features, out_features, device):
    """Create the ``(out_features, in_features)`` parameter of the linear model.

    The entries are Kaiming-uniform initialized (``a=sqrt(5)``). A bias vector
    is also drawn, uniform in ``[-1/sqrt(fan_in), 1/sqrt(fan_in)]``, but its
    values never end up in the returned tensor. The draw is kept only so the
    torch RNG stream, and hence every later random draw, stays fixed.
    ``model_forward`` reads its bias from the last column, which therefore
    coincides with the last weight column.
    NOTE(review): confirm a shared bias/weight column is intended.

    Returns:
        A leaf tensor on ``device`` with ``requires_grad=True``.
    """
    params = torch.zeros(out_features, in_features, requires_grad=True, device=device)

    init_weight = torch.empty((out_features, in_features))
    init_bias = torch.empty(out_features)
    nn.init.kaiming_uniform_(init_weight, a=math.sqrt(5))
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(init_weight)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    # Consumes RNG draws only; these values are not copied into params.
    nn.init.uniform_(init_bias, -bound, bound)

    params.data.copy_(init_weight.to(device))
    return params

def model_forward(x, inputs):
    """Apply the linear classifier stored in ``x`` to a batch of inputs.

    The weight matrix is the first ``inputs.shape[1]`` columns of ``x`` and
    the bias is the last column of ``x``. For a parameter with exactly
    ``in_features`` columns (as ``get_model`` creates) the bias is therefore
    also the last weight column. A parameter with one extra column gives an
    independent bias.

    Args:
        x: ``(out_features, n_cols)`` parameter with ``n_cols >= in_features``.
        inputs: ``(batch, in_features)`` tensor.

    Returns:
        ``(batch, out_features)`` logits.
    """
    # Derived from the data instead of hard-coding 28*28 image inputs.
    in_features = inputs.shape[1]
    weight = x[:, :in_features]  # (out_features, in_features)
    bias = x[:, -1]  # (out_features,)
    return inputs.mm(weight.t()) + bias.view(1, -1)

### original f, g, and gradients

def f(x, w, dataset):
    """Upper-level objective: mean cross-entropy of model ``x`` on ``dataset``.

    ``w`` is accepted for symmetry with ``g`` but is not used, so the partial
    derivative of ``f`` with respect to ``w`` is zero.
    """
    features, labels = dataset
    return F.cross_entropy(model_forward(x, features), labels)

def g(x, w, dataset):
    """Lower-level objective: weighted per-sample cross-entropy plus L2 penalty.

    Each sample's loss is multiplied by its weight clipped to ``[0, 1]``. The
    weighted losses are averaged and ``0.001 * ||x||_2^2`` is added.
    """
    features, labels = dataset
    per_sample = F.cross_entropy(model_forward(x, features), labels, reduction='none')
    weighted_mean = (per_sample * torch.clip(w, 0, 1)).mean()
    return weighted_mean + 0.001 * x.norm(2).pow(2)

def g_x(x, w, dataset, retain_graph=False, create_graph=False):
    """Return dg/dx at ``(x, w)`` on ``dataset`` (same shape as ``x``)."""
    (grad_x,) = torch.autograd.grad(g(x, w, dataset), x,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
    return grad_x

def g_xx(x,w,vs,dataset):
    """Hessian-vector product of ``g`` in ``x``: d/dx <dg/dx, vs>.

    ``vs`` is used as ``grad_outputs`` for dg/dx, so it must match ``x``'s
    shape. Returns zeros shaped like ``x`` when dg/dx does not depend on
    ``x``.
    """
    first_order = torch.autograd.grad(g(x, w, dataset), x,
                                      retain_graph=True, allow_unused=True,
                                      create_graph=True, only_inputs=True)[0]
    # Guarantees the second differentiation can run even if dg/dx is
    # constant in x (then it reports "unused" and we return zeros).
    first_order.requires_grad_(True)
    hvp = torch.autograd.grad(first_order, x, grad_outputs=vs,
                              retain_graph=True, allow_unused=True)[0]
    if hvp is None:
        return torch.zeros_like(x)
    return hvp

def g_w(x, w, dataset, retain_graph=False, create_graph=False):
    """Return dg/dw at ``(x, w)`` on ``dataset`` (same shape as ``w``)."""
    objective = g(x, w, dataset)
    (grad_w,) = torch.autograd.grad(objective, w,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
    return grad_w

def g_wx(x,w,vs,dataset):
    """Mixed second derivative of ``g``: d/dw <dg/dx, vs>.

    ``vs`` must match ``x``'s shape. The result is shaped like ``w``; zeros
    are returned when dg/dx does not depend on ``w``.
    """
    first_order = torch.autograd.grad(g(x, w, dataset), x,
                                      retain_graph=True, allow_unused=True,
                                      create_graph=True, only_inputs=True)[0]
    first_order.requires_grad_(True)
    mixed = torch.autograd.grad(first_order, w, grad_outputs=vs,
                                retain_graph=True, allow_unused=True)[0]
    if mixed is None:
        return torch.zeros_like(w)
    return mixed

def g_x_xhat_w(x, xhat, w, dataset, retain_graph=False, create_graph=False):
    """Value-function gap ``g(x, w) - g(xhat, w)`` and its gradients.

    ``xhat`` is detached, so it is treated as a constant.

    Returns:
        ``(gap, d gap/dx, d gap/dw)``.
    """
    gap = g(x, w, dataset) - g(xhat.detach(), w, dataset)
    grad_x, grad_w = torch.autograd.grad(gap, [x, w],
                                         retain_graph=retain_graph,
                                         create_graph=create_graph)
    return gap, grad_x, grad_w

def g_x_xhat_w_bo(x, xhat, w, dataset, retain_graph=False, create_graph=False):
    """Gradients of ``g(x, w) - g(xhat, w)`` w.r.t. ``x``, ``xhat`` and ``w``.

    Unlike ``g_x_xhat_w``, ``xhat`` is not detached and the gap value itself
    is not returned.

    Returns:
        ``(d/dx, d/dxhat, d/dw)``.
    """
    gap = g(x, w, dataset) - g(xhat, w, dataset)
    grad_x, grad_xhat, grad_w = torch.autograd.grad(gap, [x, xhat, w],
                                                    retain_graph=retain_graph,
                                                    create_graph=create_graph)
    return grad_x, grad_xhat, grad_w

def f_x(x, w, dataset, retain_graph=False, create_graph=False):
    """Return df/dx at ``(x, w)`` on ``dataset`` (same shape as ``x``)."""
    objective = f(x, w, dataset)
    (grad_x,) = torch.autograd.grad(objective, x,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
    return grad_x


### Define evaluation metric

def evaluate(x, testset, tevalset):
    """Score model ``x`` on the test split and on the held-out teval split.

    Returns:
        ``(test cross-entropy, test accuracy, teval cross-entropy)`` as
        Python floats. No gradients are recorded.
    """
    with torch.no_grad():
        test_inputs, test_labels = testset
        test_logits = model_forward(x, test_inputs)
        test_loss = F.cross_entropy(test_logits, test_labels).detach().item()
        correct = test_logits.argmax(-1).eq(test_labels)
        test_acc = correct.float().mean().detach().cpu().item()
        # valset feeds the upper-level objective during training, so the
        # held-out loss is measured on a separate split.
        teval_inputs, teval_labels = tevalset
        teval_logits = model_forward(x, teval_inputs)
        teval_loss = F.cross_entropy(teval_logits, teval_labels).detach().item()
    return test_loss, test_acc, teval_loss

def evaluate_importance_f1(w, clean_indices):
    """Precision, recall and F1 of predicting "clean" as ``w > 0.5``.

    ``clean_indices`` is used as a 0/1 (or boolean) indicator per sample.
    Each denominator gets ``1e-4`` added to avoid division by zero.

    Returns:
        ``(precision, recall, f1)`` as Python floats.
    """
    with torch.no_grad():
        predicted = (w > 0.5).float()
        actual = clean_indices.float()
        hits = (predicted * actual).sum()
        recall = hits / (actual.sum() + 1e-4)
        precision = hits / (predicted.sum() + 1e-4)
        f1 = 2.0 * recall * precision / (recall + precision + 1e-4)
    return tuple(v.cpu().item() for v in (precision, recall, f1))

### inner solver bfgs
def bfgs(x,w,tol,step,maxiter_hg,stepsize1,stepsize2,m,h0=0.01,ex_up=False,ws=3): 
            """Approximately solve the lower-level problem with fixed-step L-BFGS.

            Runs ``step`` iterations on ``g(., w)`` over the module-level
            ``trainset`` (a global, not an argument). Iterations ``k < ws`` are
            plain gradient steps of size 0.1. Later iterations step 0.1 along
            the L-BFGS direction from ``two_loops``, with no line search. A
            curvature pair (s, y) is stored only when ``y^T s > 1e-10``, and at
            most ``m`` pairs are kept.

            Afterwards, let ``v = grad_x f(x, w)`` on the module-level
            ``valset``:
              * ``ex_up`` false: ``et = H v``, where H is the L-BFGS
                inverse-Hessian approximation (initial matrix ``h0 * I``);
              * ``ex_up`` true: ``maxiter_hg`` probes build a separate memory
                from finite differences of ``g_x`` along unit-norm directions,
                and ``et`` is the last unit-norm probe direction.
                NOTE(review): that ``et`` is normalized and ignores the main
                memory -- confirm this is intended.

            NOTE(review): ``tol``, ``stepsize1`` and ``stepsize2`` are not used
            in the body; both step sizes are hard-coded to 0.1. ``ws`` must be
            at least 2 and ``step`` at least 1, otherwise ``grady``/``k`` are
            read before assignment. With ``ex_up`` true, ``maxiter_hg`` must be
            at least 1, otherwise ``et`` is unbound.

            Returns:
                ``(x, et)``. ``x`` is the updated iterate (a new tensor, the
                input is not modified in place). ``et`` is the vector described
                above, as a float32 tensor on ``x.device``.
            """
            y_list, s_list, mu_list = [], [], []
            y1_list, s1_list, mu1_list = [], [], []
            for k in range(1, step + 1):
                if k<ws:
                   # Warm-up: plain gradient step of size 0.1.
                   x = x- 0.1 * g_x(x,w,trainset)
                   new_grad=g_x(x,w,trainset)
                   ngrad=new_grad.detach().cpu().numpy()
                else:
                   p = two_loops(grady, m, s_list, y_list, mu_list,h0)# descent direction -H @ grad (H0 = h0 * I)
                   s= 0.1*p
                   st=torch.from_numpy(s).to(x.device).float()
                   x = x +st
                   new_grad=g_x(x,w,trainset)
                   ngrad=new_grad.detach().cpu().numpy()# grad_x g at the new iterate
                   y=ngrad-grady
                    # Keep the pair only if y^T s > 1e-10; drop the oldest beyond m pairs.
                   if (safe_sparse_dot(np.ravel(y),np.ravel(s)))>1e-10:
                       y_list.append(y.copy())
                       s_list.append(s.copy())
                       mu=1/safe_sparse_dot(np.ravel(y),np.ravel(s))
                       mu_list.append(mu)
                   if len(y_list) > m:
                       y_list.pop(0)
                       s_list.pop(0)
                       mu_list.pop(0)
                grady=ngrad
                
            fx = f_x(x, w, valset)# grad_x f on valset, same shape as x
            gradFy=fx.detach().cpu().numpy()# NumPy copy of grad_x f
            if ex_up==False:
                # two_loops returns -H @ v, so negate to get et = H @ grad_x f.
                hg = -two_loops(gradFy, m, s_list, y_list, mu_list,h0)
                et=torch.from_numpy(hg).to(x.device).float()
            else:
                for i in range (1, maxiter_hg + 1):
                    eq = - two_loops(gradFy, m, s1_list, y1_list, mu1_list,h0)# H1 @ grad_x f using the probe memory only
                    en=np.linalg.norm(np.ravel(eq))
                    eq = eq /en
                    et=torch.from_numpy(eq).to(x.device).float()
                    # Finite-difference curvature pair along the unit direction eq.
                    x1=x+et
                    grad=g_x(x1,w,trainset)
                    f1grad=grad.detach().cpu().numpy()
                    y_tilde1 = f1grad- grady
                    if safe_sparse_dot(np.ravel(y_tilde1), np.ravel(eq))>1e-10:
                      mu1 = 1 / safe_sparse_dot(np.ravel(y_tilde1), np.ravel(eq))
                      y1_list.append(y_tilde1.copy())
                      s1_list.append(eq.copy())
                      mu1_list.append(mu1)
                    if len(y1_list) > m:
                      y1_list.pop(0)
                      s1_list.pop(0)
                      mu1_list.pop(0)
            print(f'{k} iterates')
            return x, et

def two_loops(grad_y, m, s_list, y_list, mu_list,h0):
    """L-BFGS two-loop recursion.

    Returns ``-H @ grad_y`` (applied to the flattened array, result in
    ``grad_y``'s shape). Here H is the limited-memory inverse-Hessian
    approximation built from the stored pairs ``(s_i, y_i)`` with
    ``mu_i = 1 / (y_i^T s_i)`` and initial matrix ``h0 * I``. Callers use
    the result directly as a descent direction, or negate it to get
    ``H @ grad_y``.

    ``m`` is accepted for interface compatibility but is not used; the
    memory size is whatever the lists contain. ``grad_y`` is not modified.
    """
    newest_first = list(zip(reversed(s_list), reversed(y_list), reversed(mu_list)))
    q = grad_y.copy()
    alphas = []
    for s_i, y_i, mu_i in newest_first:
        a_i = mu_i * safe_sparse_dot(np.ravel(s_i), np.ravel(q))
        alphas.append(a_i)
        q -= a_i * y_i

    r = h0 * q
    for s_i, y_i, mu_i, a_i in zip(s_list, y_list, mu_list, reversed(alphas)):
        b_i = mu_i * safe_sparse_dot(np.ravel(y_i), np.ravel(r))
        r += (a_i - b_i) * s_i
    return -r

def rbfgs(x,w,tol,step,m,exup=False,h0=0.01,maxls=10):
    """Approximately solve the lower-level problem with line-searched L-BFGS.

    Runs at most ``step`` L-BFGS iterations on ``g(., w)`` over the
    module-level ``trainset`` (a global, not an argument). The step length
    comes from ``ls.strong_wolfe``; a fixed step of 0.1 is used when that
    returns ``None``. Iteration stops early once the Euclidean norm of
    ``grad_x g`` drops below ``tol``.

    When ``exup`` is true, every 5th iteration adds an extra curvature pair.
    It is probed along the current approximate hypergradient direction
    ``H @ grad_x f``, rescaled to the length of the last step.

    Args:
        x: lower-level parameter tensor (must require grad).
        w: upper-level per-sample weights.
        tol: stopping threshold on ``||grad_x g||_2``.
        step: maximum number of iterations.
        m: maximum number of stored curvature pairs.
        exup: enable the extra curvature pairs described above.
        h0: scale of the initial inverse Hessian ``h0 * I``.
        maxls: forwarded to ``ls.strong_wolfe`` (presumably its maximum
            number of line-search evaluations -- confirm in ``tls``).

    Returns:
        ``(x, et)``: the updated iterate (a new tensor) and
        ``H @ grad_x f(x, w)`` on the module-level ``valset``, as a float32
        tensor. H is the final L-BFGS inverse-Hessian approximation.
    """
    c1 = 0.0001
    c2 = 0.0009
    t = 0.1  # initial trial step handed to the line search
    tc = 1e-9
    y_list, s_list, mu_list = [], [], []

    def lf(z):
        return g(z, w, trainset)

    def lf_grad(z):
        return g_x(z, w, trainset)

    new_grad = g_x(x, w, trainset)
    grady = new_grad.detach().cpu().numpy()
    obf = g(x, w, trainset)
    k = 0  # keeps the final report well-defined when step < 1
    for k in range(1, step + 1):
        if k > 1 and k % 5 == 0 and exup:
            gradFy = f_x(x, w, valset).detach().cpu().numpy()
            eq = -two_loops(gradFy, m, s_list, y_list, mu_list, h0)
            # ``s`` is the previous step (always set by iteration k >= 5).
            eq = eq / np.linalg.norm(eq) * np.linalg.norm(s)
            et = torch.from_numpy(eq).to(x.device).float()
            f1grad = g_x(x + et, w, trainset).detach().cpu().numpy()
            y_tilde1 = f1grad - grady
            probe_curv = safe_sparse_dot(np.ravel(y_tilde1), np.ravel(eq))
            if probe_curv > 1e-20:
                y_list.append(y_tilde1.copy())
                s_list.append(eq.copy())
                mu_list.append(1 / probe_curv)

        d = two_loops(grady, m, s_list, y_list, mu_list, h0)
        p = torch.from_numpy(d).to(x.device).float()
        gtd = new_grad.view(-1).dot(p.view(-1))  # directional derivative
        # ``ls_step`` is the line-search step length; kept separate from the
        # ``step`` argument (the iteration budget) to avoid shadowing it.
        obf, new_grad, ls_step, _ls_info = ls.strong_wolfe(lf, lf_grad, x, t,
                                                           p, obf, new_grad, gtd,
                                                           c1, c2, tc,
                                                           maxls)
        if ls_step is None:
            ls_step = 0.1
            s = ls_step * d
            x = x + ls_step * p
            new_grad = g_x(x, w, trainset)
        else:
            # Only tensors need converting; Python and NumPy scalars multiply
            # NumPy arrays directly.
            if isinstance(ls_step, torch.Tensor):
                ls_step = ls_step.cpu().detach().numpy()
            s = ls_step * d
            # Cast back to float32 so a float64 step cannot promote x.
            x = x + torch.from_numpy(s).to(x.device).float()

        ngrad = new_grad.detach().cpu().numpy()
        y = ngrad - grady
        curv = safe_sparse_dot(np.ravel(y), np.ravel(s))
        if curv > 1e-10:
            y_list.append(y.copy())
            s_list.append(s.copy())
            mu_list.append(1 / curv)
        # ``while``, not ``if``: an extra exup pair and a regular pair can be
        # added in the same iteration, and a single pop would let the memory
        # grow past ``m`` for good.
        while len(y_list) > m:
            y_list.pop(0)
            s_list.pop(0)
            mu_list.pop(0)
        grady = ngrad
        grad_norm = np.linalg.norm(np.ravel(grady))  # Euclidean norm
        if grad_norm < tol:
            break

    gradFy = f_x(x, w, valset).detach().cpu().numpy()
    # two_loops returns -H @ v, so negate to get H @ grad_x f.
    hg = -two_loops(gradFy, m, s_list, y_list, mu_list, h0)
    et = torch.from_numpy(hg).to(x.device).float()
    print(f'{k} iterates')
    return x, et

def sr(x,w,tol,step,maxiter_hg,stepsize1,stepsize2,m,h0=0.01,ex_up=False,ws=3): 
            """Approximately solve the lower-level problem with fixed-step L-SR1.

            Runs at most ``step`` iterations on ``g(., w)`` over the
            module-level ``trainset`` (a global, not an argument). Iterations
            ``k < ws`` are plain gradient steps of size 0.1. Later iterations
            step ``-0.1 * H @ grad`` with H from ``two_loopsr`` (limited-memory
            SR1, initial matrix ``h0 * I``). A pair (s, y) is stored when
            ``y^T s > 1e-10``, and at most ``m`` pairs are kept. Iteration
            stops once the Euclidean norm of ``grad_x g`` is below ``tol``.

            Afterwards, let ``v = grad_x f(x, w)`` on the module-level
            ``valset``:
              * ``ex_up`` false: ``et = H v``;
              * ``ex_up`` true: ``maxiter_hg`` probes build a separate memory
                from finite differences of ``g_x`` along unit-norm directions,
                and ``et`` is the last unit-norm probe direction.
                NOTE(review): that ``et`` is normalized and ignores the main
                memory -- confirm this is intended.

            NOTE(review): the pair-acceptance test is the BFGS-style
            ``y^T s > 0`` rather than a guard on the SR1 denominator.
            ``tol`` is used, but ``stepsize1``/``stepsize2`` are not, and
            ``mu_list``/``mu1_list`` are never filled. ``ws`` must be at least
            2 and ``step`` at least 1, otherwise ``grady``/``k`` are read
            before assignment. With ``ex_up`` true, ``maxiter_hg`` must be at
            least 1, otherwise ``et`` is unbound.

            Returns:
                ``(x, et)``: the updated iterate (a new tensor) and the
                vector described above, as a float32 tensor on ``x.device``.
            """
            y_list, s_list, mu_list = [], [], []
            y1_list, s1_list, mu1_list = [], [], [] 
            for k in range(1, step + 1):
                if k<ws:
                   # Warm-up: plain gradient step of size 0.1.
                   x = x- 0.1 * g_x(x,w,trainset)
                   new_grad=g_x(x,w,trainset)
                   ngrad=new_grad.detach().cpu().numpy()
                else:
                   # two_loopsr returns +H @ grad, hence the minus sign on s.
                   p = two_loopsr(grady, m, s_list, y_list,h0)
                   s= -0.1*p
                   st=torch.from_numpy(s).to(x.device).float()
                   x = x +st
                   new_grad=g_x(x,w,trainset)
                   ngrad=new_grad.detach().cpu().numpy()# grad_x g at the new iterate
                   y=ngrad-grady
                    # Keep the pair only if y^T s > 1e-10; drop the oldest beyond m pairs.
                   if (safe_sparse_dot(np.ravel(y),np.ravel(s)))>1e-10:
                       y_list.append(y.copy())
                       s_list.append(s.copy())
                   if len(y_list) > m:
                       y_list.pop(0)
                       s_list.pop(0)
                grady=ngrad
                
                # Despite the name, this is the Euclidean (L2) norm of the gradient.
                l_inf_norm_grad = np.linalg.norm(np.ravel(grady))
                if l_inf_norm_grad < tol:
                    break
                
            fx = f_x(x, w, valset)# grad_x f on valset, same shape as x
            gradFy=fx.detach().cpu().numpy()# NumPy copy of grad_x f
            if ex_up==False:
                hg = two_loopsr(gradFy, m, s_list, y_list,h0)
                et=torch.from_numpy(hg).to(x.device).float()
            else:
                for i in range (1, maxiter_hg + 1):
                    eq =  two_loopsr(gradFy, m, s1_list, y1_list,h0)# H1 @ grad_x f using the probe memory only
                    en=np.linalg.norm(np.ravel(eq))
                    eq = eq /en
                    et=torch.from_numpy(eq).to(x.device).float()
                    # Finite-difference curvature pair along the unit direction eq.
                    x1=x+et
                    grad=g_x(x1,w,trainset)
                    f1grad=grad.detach().cpu().numpy()
                    y_tilde1 = f1grad- grady
                    if safe_sparse_dot(np.ravel(y_tilde1), np.ravel(eq))>1e-10:
                      y1_list.append(y_tilde1.copy())
                      s1_list.append(eq.copy())
                    if len(y1_list) > m:
                      y1_list.pop(0)
                      s1_list.pop(0)
            print(f'{k} iterates')
            return x, et

def two_loopsr(grad_x, m,s_list, y_list,h0):
    """Apply the limited-memory SR1 inverse-Hessian approximation to ``grad_x``.

    The correction vectors are built recursively from the stored pairs, with
    ``H_0 = h0 * I``:
        p_i = s_i - h0 * y_i - sum_{k<i} (p_k^T y_i) / (p_k^T y_k) * p_k
    The function returns
        H @ grad_x = h0 * grad_x + sum_i (p_i^T grad_x) / (p_i^T y_i) * p_i
    (dot products on flattened arrays; result in ``grad_x``'s shape).

    NOTE(review): there is no guard against ``p_i^T y_i`` being ~0.
    ``m`` is accepted for interface compatibility but is not used.
    """
    def inner(a, b):
        return safe_sparse_dot(np.ravel(a), np.ravel(b))

    q = grad_x.copy()
    corrections = []  # p_i
    curvatures = []   # p_i^T y_i, computed once per pair
    for s_i, y_i in zip(s_list, y_list):
        p_i = s_i - h0 * y_i
        for p_k, d_k in zip(corrections, curvatures):
            p_i = p_i - inner(p_k, y_i) / d_k * p_k
        corrections.append(p_i)
        curvatures.append(inner(p_i, y_i))

    result = h0 * q
    for p_i, d_i in zip(corrections, curvatures):
        result = result + inner(p_i, q) / d_i * p_i
    return result





###############################################################################
#
# Bilevel Optimization Training Methods
#
###############################################################################
def simple_train(args, x, data_x, data_y, testset, tevalset, tag='pretrain', regularize=False, n_epochs=5000):
    """Fit ``x`` directly on ``(data_x, data_y)`` with full-batch momentum SGD.

    Each epoch takes one SGD step (lr ``args.x_lr``, momentum
    ``args.x_momentum``) on the mean cross-entropy, optionally plus
    ``0.001 * ||x||^2``, then evaluates. The epoch with the lowest held-out
    (teval) loss is kept; ties go to the later epoch.

    Args:
        args: namespace providing ``x_lr`` and ``x_momentum``.
        x: parameter tensor (requires grad); updated in place.
        data_x, data_y: training inputs and labels.
        testset, tevalset: ``(inputs, labels)`` pairs passed to ``evaluate``.
        tag: prefix of the progress lines.
        regularize: add the same L2 penalty the lower-level objective uses.
        n_epochs: number of full-batch steps (default 5000).

    Returns:
        ``(test_loss, test_acc, best_x)`` at the selected epoch. ``best_x`` is
        a detached copy of the parameters. It is ``None`` (and the metrics
        0.0) when ``n_epochs`` is 0 or every held-out loss was NaN.
    """
    opt = torch.optim.SGD([x], lr=args.x_lr, momentum=args.x_momentum)

    best_teval_loss = np.inf
    final_test_loss = 0.
    final_test_acc = 0.
    best_x = None

    for epoch in range(n_epochs):
        opt.zero_grad()
        y = model_forward(x, data_x)
        loss = F.cross_entropy(y, data_y)
        if regularize:
            loss += 0.001 * x.norm(2).pow(2)
        loss.backward()
        opt.step()

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        if teval_loss <= best_teval_loss:
            best_teval_loss = teval_loss
            final_test_loss = test_loss
            final_test_acc = test_acc
            best_x = x.data.clone()
        print(f"[{tag}] epoch {epoch:5d} test loss {test_loss:10.4f} test acc {test_acc:10.4f}")
    return final_test_loss, final_test_acc, best_x

def FOA(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Bilevel training with the L-BFGS-based hypergradient from ``bfgs``.

    Runs 5000 outer iterations, each with three steps:
      * Re-solve the lower level with ``bfgs``: 50 inner steps, memory 30,
        ``h0=1``, three warm-up gradient steps.
      * Take one momentum-SGD step on ``w`` with
        ``w.grad = -g_wx(x, w) @ et``. ``f`` does not use ``w``, so there is
        no direct ``f_w`` term.
      * Project ``w`` onto [0, 1].
    Only the solver call and the ``w`` update are timed.

    NOTE(review): ``bfgs`` reads the module-level ``trainset``/``valset``,
    not the arguments passed here; the ``valset`` parameter is unused.

    Returns:
        ``(stats, running_times, test_losses, test_accs, f1s)``. The three
        per-iteration lists start with the initial ``test_loss``,
        ``test_acc`` and a time of 0.
    """
    total_time = 0.0
    stats = []
    f1sfoa = []
    test_lossesfoa = [test_loss]
    test_accsfoa = [test_acc]
    running_timefoa = [total_time]

    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)
    for epoch in range(5000):
        start = time.time()
        x, et = bfgs(x, w, tol=1 / (epoch + 1), step=50, maxiter_hg=1,
                     stepsize1=0.1, stepsize2=0.1, m=30, h0=1,
                     ex_up=False, ws=4)

        outer_opt.zero_grad()
        cross_term = g_wx(x, w, et, trainset)
        w.grad = (-cross_term).data
        outer_opt.step()
        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        f1sfoa.append(f1)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        for history, value in ((test_lossesfoa, test_loss),
                               (test_accsfoa, test_acc),
                               (running_timefoa, total_time)):
            history.append(value)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")

    return stats, running_timefoa, test_lossesfoa, test_accsfoa, f1sfoa

def SR1(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Bilevel training with the limited-memory SR1 hypergradient from ``sr``.

    Runs 7000 outer iterations, each with three steps:
      * Call ``sr`` with 20 inner steps, tol 0.1, memory 30, ``h0=0.01``,
        three warm-up gradient steps, and three probes with ``ex_up=True``.
      * Take one momentum-SGD step on ``w`` with
        ``w.grad = -g_wx(x, w) @ et``. ``f`` does not use ``w``.
      * Project ``w`` onto [0, 1].
    Only the solver call and the ``w`` update are timed.

    NOTE(review): ``sr`` reads the module-level ``trainset``/``valset``, not
    the arguments; the ``valset`` parameter is unused here.

    Returns:
        ``(stats, running_times, test_losses, test_accs, f1s)``. The three
        per-iteration lists start with the initial values.
    """
    total_time = 0.0
    stats = []
    f1ssr = []
    test_lossessr = [test_loss]
    running_timesr = [total_time]
    test_accssr = [test_acc]
    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)

    for epoch in range(7000):
        start = time.time()
        x, et = sr(x, w, tol=0.1, step=20, maxiter_hg=3, stepsize1=0.1,
                   stepsize2=0.1, m=30, h0=0.01, ex_up=True, ws=4)

        outer_opt.zero_grad()
        w.grad = (-g_wx(x, w, et, trainset)).data
        outer_opt.step()
        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        f1ssr.append(f1)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        for history, value in ((test_lossessr, test_loss),
                               (test_accssr, test_acc),
                               (running_timesr, total_time)):
            history.append(value)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")

    return stats, running_timesr, test_lossessr, test_accssr, f1ssr

def BOME(args, x, w, trainset, valset, testset, tevalset, clean_indices,test_loss,test_acc):
    """Bilevel training with BOME (value-function-based first-order method).

    Runs 10000 outer iterations, each with four steps:
      1. Warm-start ``xhat`` at ``x`` and take ``args.iterations`` SGD steps
         on ``g(., w)``.
      2. Form the gap ``q = g(x, w) - g(xhat, w)`` (``xhat`` held constant)
         and its gradient ``dq = (dq/dx, dq/dw)``.
      3. Set ``lambda = relu((u1 * q - <df, dq>) / (||dq||^2 + 1e-8))``, with
         ``df = (grad_x f, 0)`` because ``f`` does not use ``w``.
      4. Take one joint momentum-SGD step with
         ``x.grad = grad_x f + lambda * dq/dx`` and
         ``w.grad = lambda * dq/dw``, then project ``w`` onto [0, 1].
    Only steps 1-4 are timed.

    Returns:
        ``(stats, running_times, test_losses, test_accs, f1s)``. The three
        per-iteration lists start with the initial values.
    """
    xhat = copy.deepcopy(x)

    total_time = 0.0
    stats = []
    f1sbm = []
    test_lossesbome = [test_loss]
    test_accsbome = [test_acc]
    running_timebome = [total_time]

    outer_opt = torch.optim.SGD(
        [{'params': [x], 'lr': args.x_lr},
         {'params': [w], 'lr': args.w_lr}],
        momentum=args.w_momentum)
    inner_opt = torch.optim.SGD([xhat], lr=args.xhat_lr, momentum=args.x_momentum)

    # f has no w-gradient, so its w-block in the stacked gradient is zero.
    zero_fw = torch.zeros(w.numel()).to(x.device)

    for epoch in range(10000):
        xhat.data = x.data.clone()
        start = time.time()
        for _ in range(args.iterations):
            inner_opt.zero_grad()
            xhat.grad = g_x(xhat, w, trainset)
            inner_opt.step()

        fx = f_x(x, w, valset)
        gap, gx, gw_gap = g_x_xhat_w(x, xhat, w, trainset)
        df = torch.cat([fx.view(-1), zero_fw])
        dq = torch.cat([gx.view(-1), gw_gap.view(-1)])
        dq_sq_norm = dq.norm().pow(2)
        lmbd = F.relu((args.u1 * gap - df.dot(dq)) / (dq_sq_norm + 1e-8))

        outer_opt.zero_grad()
        x.grad = fx + lmbd * gx
        w.grad = lmbd * gw_gap
        outer_opt.step()
        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        f1sbm.append(f1)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        for history, value in ((test_lossesbome, test_loss),
                               (test_accsbome, test_acc),
                               (running_timebome, total_time)):
            history.append(value)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")

    return stats, running_timebome, test_lossesbome, test_accsbome, f1sbm

def F2SA(args, x, w, trainset, valset, testset, tevalset, clean_indices,test_loss,test_acc):
    """Bilevel training with F2SA (fully first-order penalty method).

    ``x`` tracks the penalized problem ``f + lambda * g``; ``xhat`` tracks
    the lower-level problem ``g`` alone. Each of 8000 outer iterations does:
      1. Warm-start ``xhat`` at ``x``. Then, for ``args.iterations`` steps,
         move ``x`` along ``grad_x f + lambda * grad_x g`` and ``xhat`` along
         ``grad_x g``, using 2000-sample minibatches (val for f, train for g).
      2. Step ``w`` along ``lambda * (g_w(x) - g_w(xhat))``, computed on a
         separate 2000-sample training minibatch and zero elsewhere. Then
         project ``w`` onto [0, 1].
      3. Increase ``lambda`` by 0.001, starting from 0.1.
    Only steps 1-2 are timed.

    NOTE(review): ``f_x`` is given ``w`` indexed by *validation* indices;
    this is harmless only because ``f`` ignores ``w``.

    Returns:
        ``(stats, running_times, test_losses, test_accs, f1s)``. The three
        per-iteration lists start with the initial values.
    """
    xhat = copy.deepcopy(x)
    batch_size = 2000
    n_val = valset[0].shape[0]
    n_train = trainset[0].shape[0]

    total_time = 0.0
    stats = []
    f1sfs = []
    test_lossesf2sa = [test_loss]
    test_accsf2sa = [test_acc]
    running_timef2sa = [total_time]

    inner_opt = torch.optim.SGD(
        [{'params': [x], 'lr': args.x_lr},
         {'params': [xhat], 'lr': args.xhat_lr}],
        momentum=args.x_momentum)
    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)

    lmbd = 0.1
    for epoch in range(8000):
        # Random draws stay in this order: train perm, val perm, w-batch.
        train_perm = torch.randperm(n_train).to(x.device)
        val_perm = torch.randperm(n_val).to(x.device)
        w_batch = torch.randperm(n_train)[:batch_size].to(x.device)

        xhat.data = x.data.clone()
        start = time.time()
        for it in range(args.iterations):
            lo, hi = it * batch_size, (it + 1) * batch_size
            idx = train_perm[lo:hi]
            idxv = val_perm[lo:hi]
            train_batch = (trainset[0][idx], trainset[1][idx])
            fx = f_x(x, w[idxv], (valset[0][idxv], valset[1][idxv]))
            gx = g_x(x, w[idx], train_batch)
            inner_opt.zero_grad()
            x.grad = fx + lmbd * gx
            xhat.grad = g_x(xhat, w[idx], train_batch)
            inner_opt.step()

        _, _, gw_batch = g_x_xhat_w(x, xhat, w[w_batch],
                                    (trainset[0][w_batch], trainset[1][w_batch]))
        gw_gap = torch.zeros_like(w)
        gw_gap[w_batch] = gw_batch

        outer_opt.zero_grad()
        w.grad = lmbd * gw_gap  # f_w = 0
        outer_opt.step()
        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)
        lmbd = lmbd + 0.001

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        f1sfs.append(f1)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        for history, value in ((test_lossesf2sa, test_loss),
                               (test_accsf2sa, test_acc),
                               (running_timef2sa, total_time)):
            history.append(value)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")

    return stats, running_timef2sa, test_lossesf2sa, test_accsf2sa, f1sfs

def SABA(args, x, w, trainset, valset, testset, tevalset, clean_indices,test_loss,test_acc):
    """Bilevel training with SABA (SAGA-style variance reduction everywhere).

    Each of 10000 iterations draws one training and one validation
    minibatch (size 2000). It forms variance-reduced estimates of
    ``grad_x g``, the Hessian-vector product ``d2g/dx2 @ v``, the cross
    term ``d2g/(dw dx) @ v`` (scattered into a full-length ``w`` vector),
    and ``grad_x f``. Then it does, in order:
      * one SGD step on ``x``;
      * the linear-system update ``v <- v - eta * (H v + grad_x f)``;
      * one SGD step on ``w`` along the cross term;
      * projection of ``w`` onto [0, 1].

    Step sizes are hard-coded per dataset (``args.w_lr`` is not used):
      * ``mnist``: w lr 10, x lr ``args.x_lr``, eta 0.1;
      * otherwise: w lr 100, x lr 0.001, eta 0.001.

    Returns:
        ``(stats, running_times, test_losses, test_accs, f1s)``. The three
        per-iteration lists start with the initial values.
    """
    batch_size = 2000

    n_val = valset[0].shape[0]
    n_train = trainset[0].shape[0]
    # Width of the flattened lower-level variable. Taken from x rather than
    # assuming 10 classes, so the memories fit any model shape.
    n_x = x.numel()
    n_w = w.shape[0]
    outer_sampler = MinibatchSampler(n_val, batch_size)
    inner_sampler = MinibatchSampler(n_train, batch_size)
    n_inner = (n_train + batch_size - 1) // batch_size
    n_outer = (n_val + batch_size - 1) // batch_size

    v = torch.zeros_like(x, dtype=torch.float, requires_grad=False).to(x.device)
    # One row per minibatch (presumably indexed by the sampler's batch id)
    # plus a last row that variance_reduction updates in place.
    memory_inner_grad = torch.zeros((n_inner + 1, n_x), dtype=torch.float).to(x.device)
    memory_hvp = torch.zeros((n_inner + 1, n_x), dtype=torch.float).to(x.device)
    memory_cross_v = torch.zeros((n_inner + 1, n_w), dtype=torch.float).to(x.device)
    memory_grad_in_outer = torch.zeros((n_outer + 1, n_x), dtype=torch.float).to(x.device)

    if args.dataset == 'mnist':
        outer_opt = torch.optim.SGD([w], lr=10, momentum=args.w_momentum)
        inner_opt = torch.optim.SGD([x], lr=args.x_lr, momentum=args.x_momentum)
        inner_step_size = 0.1
    else:
        outer_opt = torch.optim.SGD([w], lr=100, momentum=args.w_momentum)
        inner_opt = torch.optim.SGD([x], lr=0.001, momentum=args.x_momentum)
        inner_step_size = 0.001

    total_time = 0.0
    stats = []
    f1ssa = []
    test_lossessaba = [test_loss]
    test_accssaba = [test_acc]
    running_timesaba = [total_time]

    for epoch in range(10000):
        t0 = time.time()
        slice_inner, vr_inner = inner_sampler.get_batch()
        inner_batch = (trainset[0][slice_inner], trainset[1][slice_inner])

        grad_inner_o = g_x(x, w[slice_inner], inner_batch)
        hvp_o = g_xx(x, w[slice_inner], v, inner_batch)
        cross_v = torch.zeros_like(w)
        # g_wx returns a gradient shaped like w[slice_inner], so it can be
        # scattered straight into the full-length cross_v.
        cross_v[slice_inner] = g_wx(x, w[slice_inner], v, inner_batch)

        slice_outer, vr_outer = outer_sampler.get_batch()
        grad_outer_o = f_x(x, w[slice_outer], (valset[0][slice_outer], valset[1][slice_outer]))

        grad_inner = variance_reduction(grad_inner_o.flatten(), memory_inner_grad, vr_inner)
        hvp = variance_reduction(hvp_o.flatten(), memory_hvp, vr_inner)
        cross_v = variance_reduction(cross_v, memory_cross_v, vr_inner)
        grad_outer = variance_reduction(grad_outer_o.flatten(), memory_grad_in_outer, vr_outer)

        inner_opt.zero_grad()
        x.grad = grad_inner.view(grad_inner_o.shape)
        inner_opt.step()
        v = v - inner_step_size * (hvp.view(hvp_o.shape) + grad_outer.view(grad_outer_o.shape))

        outer_opt.zero_grad()
        w.grad = cross_v
        outer_opt.step()

        t1 = time.time()
        total_time += t1 - t0
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        f1ssa.append(f1)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        test_lossessaba.append(test_loss)
        test_accssaba.append(test_acc)
        running_timesaba.append(total_time)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")

    return stats,running_timesaba,test_lossessaba,test_accssaba,f1ssa

def BSG_1(args, x, w, trainset, valset, testset, tevalset, clean_indices, test_loss, test_acc,
          epochs=10000, eps=1e-4):
    """Run the BSG-1 first-order bilevel baseline for data cleaning.

    Each epoch takes ``args.iterations`` SGD steps on the inner (training)
    objective for the model parameters ``x``, then updates the per-sample
    weights ``w`` with the BSG-1 surrogate hypergradient

        w.grad = -(<f_x, g_x> / (||g_x||^2 + eps)) * g_w

    No ``f_w`` term is added (the code treats the outer objective's direct
    w-gradient as zero). ``w`` is clamped to [0, 1] after every outer step.

    Args:
        args: parsed CLI namespace; reads ``w_lr``, ``w_momentum``, ``x_lr``,
            ``x_momentum`` and ``iterations``.
        x: model parameter tensor, updated in place by the inner optimizer.
        w: per-training-sample weight tensor, updated in place by the outer
            optimizer.
        trainset, valset, testset, tevalset: ``(inputs, labels)`` pairs.
        clean_indices: boolean mask of training samples whose labels match
            the original (uncorrupted) labels; used for the F1 score.
        test_loss, test_acc: metrics of the initial model, recorded as the
            first point of the returned curves (at time 0).
        epochs: number of outer iterations (default keeps the original 10000).
        eps: damping added to ||g_x||^2 to avoid dividing by zero.

    Returns:
        ``(stats, running_times, test_losses, test_accs, f1_scores)`` where
        ``stats`` holds ``(total_time, test_loss, test_acc, teval_loss)``
        tuples, one per epoch.
    """
    total_time = 0.0
    stats = []
    # Curves start with the metrics of the initial model at time 0.
    test_lossesbsg1 = [test_loss]
    test_accsbsg1 = [test_acc]
    running_timebsg1 = [total_time]
    f1sbs = []

    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)
    inner_opt = torch.optim.SGD([x], lr=args.x_lr, momentum=args.x_momentum)

    for epoch in range(epochs):
        t0 = time.time()

        # Inner problem: a few SGD steps on g(x, w) over the training set.
        for _ in range(args.iterations):
            inner_opt.zero_grad()
            x.grad = g_x(x, w, trainset).data
            inner_opt.step()

        # Gradients required by the BSG-1 hypergradient surrogate.
        fx = f_x(x, w, valset)
        gx = g_x(x, w, trainset)
        gw = g_w(x, w, trainset)

        # Project f_x onto g_x; the resulting scalar rescales g_w.
        coeff = fx.view(-1).dot(gx.view(-1)) / (gx.norm(2).pow(2) + eps)
        outer_opt.zero_grad()
        w.grad = (-coeff * gw).data.clone()
        outer_opt.step()

        t1 = time.time()
        total_time += t1 - t0
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        f1 = evaluate_importance_f1(w, clean_indices)
        f1sbs.append(f1)
        test_lossesbsg1.append(test_loss)
        test_accsbsg1.append(test_acc)
        running_timebsg1.append(total_time)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")
    return stats, running_timebsg1, test_lossesbsg1, test_accsbsg1, f1sbs

def SHINE1(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Run the SHINE hypergradient method for 40 outer epochs.

    Each epoch re-solves the inner problem with ``rbfgs``, using a tolerance
    of 1/(100*(epoch+1)) that shrinks as training progresses. ``rbfgs``
    returns the new ``x`` and a vector ``et``. NOTE(review): ``et`` is
    presumably the quasi-Newton inverse-Hessian-vector product that SHINE
    reuses; confirm against ``rbfgs``.

    The weight gradient is set to ``-g_wx(x, w, et, trainset)``, with no
    direct f_w term (assumed zero). One SGD step with momentum follows, and
    then ``w`` is clamped to [0, 1]. ``valset`` is accepted for interface
    parity but is not used directly here.

    Returns:
        ``(stats, running_times, test_losses, test_accs, f1_scores)``; the
        three curves start with the initial ``test_loss``/``test_acc`` at
        time 0.
    """
    total_time = 0.0
    history = []
    loss_curve = [test_loss]
    acc_curve = [test_acc]
    time_curve = [total_time]
    f1_curve = []

    w_optimizer = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)

    for epoch in range(40):
        start = time.time()

        # Inner solve; the tolerance tightens as the outer loop advances.
        x, et = rbfgs(x, w, tol=1/(100*(epoch+1)), step=1000, m=30, exup=True, h0=1)

        w_optimizer.zero_grad()
        hypergrad = g_wx(x, w, et, trainset)
        w.grad = (-hypergrad).data  # direct f_w contribution assumed to be zero
        w_optimizer.step()

        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        history.append((total_time, test_loss, test_acc, teval_loss))
        f1_curve.append(f1)
        loss_curve.append(test_loss)
        acc_curve.append(test_acc)
        time_curve.append(total_time)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")

    return history, time_curve, loss_curve, acc_curve, f1_curve



if __name__ == "__main__":
    args = parse_args()

    if args.pretrain: # preprocess data and pretrain a model on validation set

        if not os.path.exists(args.data_path):
            os.makedirs(args.data_path)

        if not os.path.exists(args.model_path):
            os.makedirs(args.model_path)

        ### generate data
        trainset, valset, testset, tevalset, old_train_y = get_data(args)
        torch.save((trainset, valset, testset, tevalset, old_train_y),
                   os.path.join(args.data_path, f"{args.dataset}_data_cleaning.pt"))
        print(f"[info] successfully generated data to {args.data_path}/{args.dataset}_data_cleaning.pt")

        ### pretrain a model and save it
        n_feats = np.prod(*trainset[0].shape[1:])
        num_classes = trainset[1].unique().shape[-1]
        args.device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")

        trainset = (trainset[0].to(args.device), trainset[1].to(args.device))
        valset   = (valset[0].to(args.device),   valset[1].to(args.device))
        testset  = (testset[0].to(args.device),  testset[1].to(args.device))
        tevalset = (tevalset[0].to(args.device), tevalset[1].to(args.device))
        old_train_y = old_train_y.to(args.device)

        x = get_model(n_feats, num_classes, args.device)
        sd = x.data.clone()

        # lower bound (train on noisy train + valset)
        tmp_x = torch.cat([trainset[0], valset[0]], 0)
        tmp_y = torch.cat([trainset[1], valset[1]], 0)
        test_loss1, test_acc1, best_x1 = simple_train(args, x, tmp_x, tmp_y, testset, tevalset, regularize=True)
        torch.save(best_x1.data.cpu().clone(),
                   os.path.join(args.model_path, f"{args.dataset}_pretrained.pt"))

        # a baseline: train on valset
        x.data.copy_(sd)
        test_loss2, test_acc2, best_x2 = simple_train(args, x, valset[0], valset[1], testset, tevalset)
        torch.save(best_x2.data.cpu().clone(),
                   os.path.join(args.model_path, f"{args.dataset}_pretrained_val.pt"))

        # upper bound (train on correct train + valset)
        x.data.copy_(sd)
        tmp_x = torch.cat([trainset[0], valset[0]], 0)
        tmp_y = torch.cat([old_train_y, valset[1]], 0)
        test_loss3, test_acc3, best_x3 = simple_train(args, x, tmp_x, tmp_y, testset, tevalset)
        torch.save(best_x3.data.cpu().clone(),
                   os.path.join(args.model_path, f"{args.dataset}_pretrained_trainval.pt"))

        print(f"[pretrained] noisy train + val   : test loss {test_loss1} test acc {test_acc1}")
        print(f"[pretrained] val                 : test loss {test_loss2} test acc {test_acc2}")
        print(f"[pretrained] correct train + val : test loss {test_loss3} test acc {test_acc3}")

        torch.save({
            "pretrain_test_loss": test_loss1,
            "pretrain_test_acc": test_acc1,
            "pretrain_val_test_loss": test_loss2,
            "pretrain_val_test_acc": test_acc2,
            "pretrain_trainval_test_loss": test_loss3,
            "pretrain_trainval_test_acc": test_acc3,
            }, os.path.join(args.model_path, f"{args.dataset}_pretrained.stats"))


    else: # load pretrained model on valset and then start model training
        trainset, valset, testset, tevalset, old_train_y = torch.load(
                os.path.join(args.data_path, f"{args.dataset}_data_cleaning.pt"))
        args.device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu")

        n_feats = np.prod(*trainset[0].shape[1:])
        num_classes = trainset[1].unique().shape[-1]

        trainset = (trainset[0].to(args.device), trainset[1].to(args.device))
        valset   = (valset[0].to(args.device),   valset[1].to(args.device))
        testset  = (testset[0].to(args.device),  testset[1].to(args.device))
        tevalset = (tevalset[0].to(args.device), tevalset[1].to(args.device))
        old_train_y = old_train_y.to(args.device)

        x = get_model(n_feats, num_classes, args.device)
        x.data.copy_(torch.load(os.path.join(args.model_path, f"{args.dataset}_pretrained.pt")).to(args.device))

        # load the pretrained model on validation set
        pretrained_stats = torch.load(
            os.path.join(args.model_path, f"{args.dataset}_pretrained.stats"))

        test_loss1 = pretrained_stats['pretrain_test_loss']
        test_loss2 = pretrained_stats['pretrain_val_test_loss']
        test_loss3 = pretrained_stats['pretrain_trainval_test_loss']
        test_acc1  = pretrained_stats['pretrain_test_acc']
        test_acc2  = pretrained_stats['pretrain_val_test_acc']
        test_acc3  = pretrained_stats['pretrain_trainval_test_acc']
        print(f"[pretrained] noisy train + val   : test loss {test_loss1} test acc {test_acc1}")
        print(f"[pretrained] val                 : test loss {test_loss2} test acc {test_acc2}")
        print(f"[pretrained] correct train + val : test loss {test_loss3} test acc {test_acc3}")

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        print("original test loss ", test_loss, "original test acc ", test_acc)

        clean_indices = old_train_y.to(args.device).eq(trainset[1])
        w = torch.zeros(trainset[0].shape[0], requires_grad=True, device=args.device)
        w.data.add_(0.5)

        n_runs = 10

        all_test_losses = {}
        all_test_accs = {}
        all_running_times = {}
        all_f1_scores = {}

        all_avg_test_losses = {}
        all_avg_test_accs = {}
        all_avg_running_times = {}
        all_avg_f1_scores = {}


        methods = ['SHINE1','FOA','SR1','SABA', 'BOME', 'F2SA']

        for method in methods:
            all_test_losses[method] = []
            all_test_accs[method] = []
            all_running_times[method] = []
            all_f1_scores[method] = []
            
            for _ in range(n_runs):
                x = get_model(n_feats, num_classes, args.device)
                x.data.copy_(torch.load(os.path.join(args.model_path, f"{args.dataset}_pretrained.pt")).to(args.device))

                test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
                print("original test loss ", test_loss, "original test acc ", test_acc)

                clean_indices = old_train_y.to(args.device).eq(trainset[1])
                w = torch.zeros(trainset[0].shape[0], requires_grad=True, device=args.device)
                w.data.add_(0.5)
                stats, running_times, test_losses, test_accs,f1_score = eval(method)(args=args,
                                                                            x=x,
                                                                            w=w,
                                                                            trainset=trainset,
                                                                            valset=valset,
                                                                            testset=testset,
                                                                            tevalset=tevalset,
                                                                            clean_indices=clean_indices,
                                                                            test_loss=test_loss,
                                                                            test_acc=test_acc)
                
                all_test_losses[method].append(test_losses)
                all_test_accs[method].append(test_accs)
                all_running_times[method].append(running_times)
                all_f1_scores[method].append(f1_score)

                    
            torch.save({
                'all_test_losses': all_test_losses[method],
                'all_test_accs': all_test_accs[method],
                'all_running_times': all_running_times[method],
                'all_f1_scores': all_f1_scores[method],
            }, os.path.join('resultsall', f"{method}.pt"))
                        
            all_avg_test_losses[method] = np.mean(all_test_losses[method], axis=0)
            all_avg_test_accs[method] = np.mean(all_test_accs[method], axis=0)
            all_avg_running_times[method] = np.mean(all_running_times[method], axis=0)
            all_avg_f1_scores[method] = np.mean(all_f1_scores[method], axis=0)

            torch.save({
                'avg_test_losses': all_avg_test_losses[method],
                'avg_test_accs': all_avg_test_accs[method],
                'avg_running_times': all_avg_running_times[method],
                'avg_f1_scores': all_avg_f1_scores[method],
            }, os.path.join('results1', f"{method}_averages.pt"))