import argparse
import copy
import math
import numpy as np
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from sklearn.utils.extmath import safe_sparse_dot
import matplotlib.pyplot as plt
from scipy import optimize

import tls as ls
from minibatch import MinibatchSampler
from matplotlib.font_manager import FontProperties



from torchvision import datasets

#Ablation study on Q of qNBO
################################################################################

#  Bilevel Optimization
#
#  min_{x,w} f(x, w)
#  s.t. x = argmin_x g(x, w)
#
#  here: f(x, w) is on valset
#        g(x, w) is on trainset
#
#  f_x = df/dx
#  f_w = df/dw
#  g_x = dg/dx
#  g_w = dg/dw
#
################################################################################

def variance_reduction(grad, memory, vr_info):
    """Return a variance-reduced gradient estimate and update the gradient table in place.

    The update resembles a SAGA-style estimator. It is not referenced
    elsewhere in this file.

    Args:
        grad: freshly computed gradient for one slot (array or tensor).
        memory: table of stored gradients indexed by slot; its last row
            ``memory[-1]`` holds a running aggregate that is read and updated.
        vr_info: ``(idx, weight)``, where ``idx`` is the slot whose stored
            gradient is replaced and ``weight`` scales the change added to the
            aggregate row.

    Returns:
        ``grad - memory[idx] + memory[-1]``, evaluated before ``memory`` is mutated.
    """
    slot, weight = vr_info
    delta = grad - memory[slot]
    # Build the estimate before touching the aggregate row.
    estimate = delta + memory[-1]
    # Update the aggregate first, then overwrite the slot (same order as before).
    memory[-1] += delta * weight
    memory[slot, :] = grad
    return estimate

def parse_args():
    """Parse the command line for the data-cleaning experiment.

    As a side effect, seeds both NumPy's and PyTorch's global RNGs with
    ``--seed`` so that data splitting, label noise and model initialisation
    are reproducible.

    Returns:
        The populated ``argparse.Namespace``.
    """
    # (flag, add_argument keyword arguments), in help-listing order.
    option_table = [
        ('--dataset', dict(type=str, default="mnist", choices=["mnist", "fashion"])),
        ('--train_size', dict(type=int, default=50000)),
        ('--val_size', dict(type=int, default=5000)),
        ('--pretrain', dict(action='store_true',
                            default=False, help='whether to create data and pretrain on valset')),
        ('--epochs', dict(type=int, default=500)),
        ('--batch_size', dict(type=int, default=200)),
        ('--iterations', dict(type=int, default=1, help='T')),
        ('--K', dict(type=int, default=10, help='k')),
        ('--data_path', dict(default='./data', help='where to save data')),
        ('--model_path', dict(default='./save_data_cleaning', help='where to save model')),
        ('--noise_rate', dict(type=float, default=0.5)),
        ('--x_lr', dict(type=float, default=0.01)),
        ('--xhat_lr', dict(type=float, default=0.01)),
        ('--w_lr', dict(type=float, default=100)),
        ('--w_momentum', dict(type=float, default=0.9)),
        ('--x_momentum', dict(type=float, default=0.9)),
        ('--eta', dict(type=float, default=0.01)),
        ('--u1', dict(type=float, default=0.1)),
        ('--seed', dict(type=int, default=0)),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)
    parsed = parser.parse_args()

    # Seed global RNGs once, right after parsing.
    np.random.seed(parsed.seed)
    torch.manual_seed(parsed.seed)
    return parsed


def get_data(args):
    """Download MNIST/FashionMNIST, split it, corrupt training labels and normalise.

    The official training split is shuffled with ``torch.randperm`` and cut
    into train (``args.train_size``), validation (``args.val_size``) and a
    held-out "teval" remainder. A fraction ``args.noise_rate`` of the training
    labels is replaced with uniformly random classes. A random class can equal
    the true one, so the effective noise rate is slightly lower.

    Pixels are scaled to [0, 1], standardised with the *training* mean/std,
    and flattened to one row per image.

    Returns:
        ``(trainset, valset, testset, tevalset, old_train_y)``: each set is an
        ``(inputs, labels)`` tuple, and ``old_train_y`` holds the training
        labels before corruption.
    """
    dataset_cls = {
        'mnist': datasets.MNIST,
        'fashion': datasets.FashionMNIST,
    }[args.dataset]
    raw_train = dataset_cls(root=args.data_path, train=True, download=True)
    raw_test = dataset_cls(root=args.data_path, train=False, download=True)

    perm = torch.randperm(len(raw_train))
    n_train, n_val = args.train_size, args.val_size
    train_idx = perm[:n_train]
    val_idx = perm[n_train:n_train + n_val]
    teval_idx = perm[n_train + n_val:]

    train_x = raw_train.data[train_idx] / 255.
    val_x = raw_train.data[val_idx] / 255.
    teval_x = raw_train.data[teval_idx] / 255.
    test_x = raw_test.data / 255.

    if args.dataset in ["mnist", "fashion"]:
        all_targets = raw_train.targets
    else:
        all_targets = torch.LongTensor(raw_train.targets)
    train_y = all_targets[train_idx]
    val_y = all_targets[val_idx]
    teval_y = all_targets[teval_idx]
    test_y = torch.LongTensor(raw_test.targets)

    num_classes = test_y.unique().shape[0]
    assert val_y.unique().shape[0] == num_classes

    # Label poisoning: overwrite a random subset of training labels.
    n_noisy = int(args.train_size * args.noise_rate)
    noisy_idx = torch.randperm(args.train_size)[:n_noisy]
    random_labels = torch.randint(num_classes, size=(n_noisy,))
    old_train_y = train_y.data.clone()
    train_y.data[noisy_idx] = random_labels.data

    # Channel statistics from the training images only.
    mean = train_x.unsqueeze(1).mean([0, 2, 3])
    std = train_x.unsqueeze(1).std([0, 2, 3])

    def _standardise(images):
        # Normalise to ~zero mean / unit std and flatten each image to a row.
        return torch.flatten((images - mean) / (std + 1e-4), start_dim=1)

    return (
        (_standardise(train_x), train_y),
        (_standardise(val_x), val_y),
        (_standardise(test_x), test_y),
        (_standardise(teval_x), teval_y),
        old_train_y,
    )

### initialize a linear model

def get_model(in_features, out_features, device):
    """Create the parameter matrix ``x`` of the linear classifier.

    ``x`` has shape ``(out_features, in_features)``, is a leaf tensor with
    ``requires_grad=True`` on ``device``, and is initialised with PyTorch's
    default ``nn.Linear`` weight scheme (Kaiming-uniform with ``a=sqrt(5)``).

    NOTE(review): ``model_forward`` reads the weight as
    ``x[:, :in_features]`` and the bias as ``x[:, -1]``. With this shape the
    bias is therefore the last weight column rather than a separate
    parameter. Saved checkpoints depend on this shape, so it is documented
    here rather than changed.
    """
    x = torch.zeros(out_features, in_features, requires_grad=True, device=device)

    weight = torch.empty((out_features, in_features))
    bias = torch.empty(out_features)
    nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
    fan_in, _ = nn.init._calculate_fan_in_and_fan_out(weight)
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    # The bias sample is never stored in ``x`` (see NOTE above). It is still
    # drawn so the global RNG stream, and every later random draw, stays
    # identical to earlier runs of this script.
    nn.init.uniform_(bias, -bound, bound)

    # Earlier versions copied the weight and bias into slices of ``x`` and then
    # overwrote the whole tensor with the weight; only that last copy mattered.
    x.data.copy_(weight.clone().to(device))
    return x

def model_forward(x, inputs):
    """Apply the linear classifier stored in ``x`` to a batch of ``inputs``.

    Args:
        x: parameter matrix with one row per class. Columns
            ``[:in_features]`` act as the weight matrix ``A`` and the last
            column acts as the bias ``b``.
        inputs: ``(batch, in_features)`` matrix of flattened features.

    Returns:
        ``(batch, out_features)`` logits ``inputs @ A.T + b``.
    """
    # Take the feature width from the data instead of assuming 28x28 images.
    # This is identical for MNIST and also works for other flattened inputs.
    in_features = inputs.shape[1]
    A = x[:, :in_features]  # (out_features, in_features)
    b = x[:, -1]            # (out_features,)
    return inputs.mm(A.t()) + b.view(1, -1)

### original f, g, and gradients

def f(x, w, dataset):
    """Upper-level objective: mean cross-entropy of model ``x`` on ``dataset``.

    ``w`` is accepted only for signature symmetry with ``g``. It is not used,
    so df/dw is zero.
    """
    inputs, labels = dataset
    logits = model_forward(x, inputs)
    return F.cross_entropy(logits, labels)

def g(x, w, dataset):
    """Lower-level objective: sample-weighted cross-entropy plus an L2 penalty.

    Each per-sample loss is multiplied by ``w`` clipped to [0, 1]. The weighted
    losses are averaged, and ``0.001 * ||x||_2^2`` is added.
    """
    inputs, labels = dataset
    per_sample = F.cross_entropy(model_forward(x, inputs), labels, reduction='none')
    weighted = per_sample * torch.clip(w, 0, 1)
    return weighted.mean() + 0.001 * x.norm(2).pow(2)

def g_x(x, w, dataset, retain_graph=False, create_graph=False):
    """Gradient of the lower-level objective ``g`` with respect to ``x``.

    ``retain_graph``/``create_graph`` are forwarded to ``torch.autograd.grad``.
    """
    objective = g(x, w, dataset)
    (grad_x,) = torch.autograd.grad(objective, x,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
    return grad_x

def g_xx(x, w, vs, dataset):
    """Hessian-vector product of ``g`` in ``x``: ``d/dx <dg/dx, vs>``.

    Returns ``torch.zeros_like(x)`` when autograd reports no dependence.
    """
    # First derivative with a graph attached, so it can be differentiated again.
    dg_dx = torch.autograd.grad(g(x, w, dataset), x,
                                retain_graph=True, allow_unused=True,
                                create_graph=True, only_inputs=True)[0]
    dg_dx.requires_grad_(True)
    hvp = torch.autograd.grad(dg_dx, x, grad_outputs=vs,
                              retain_graph=True, allow_unused=True)[0]
    if hvp is None:
        return torch.zeros_like(x)
    return hvp

def g_w(x, w, dataset, retain_graph=False, create_graph=False):
    """Gradient of the lower-level objective ``g`` with respect to the sample weights ``w``.

    ``retain_graph``/``create_graph`` are forwarded to ``torch.autograd.grad``.
    """
    objective = g(x, w, dataset)
    (grad_w,) = torch.autograd.grad(objective, w,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
    return grad_w

def g_wx(x, w, vs, dataset):
    """Mixed second derivative of ``g`` applied to ``vs``: ``d/dw <dg/dx, vs>``.

    Returns ``torch.zeros_like(w)`` when autograd reports that ``w`` does not
    influence ``dg/dx``.
    """
    dg_dx = torch.autograd.grad(g(x, w, dataset), x,
                                retain_graph=True, allow_unused=True,
                                create_graph=True, only_inputs=True)[0]
    dg_dx.requires_grad_(True)
    cross = torch.autograd.grad(dg_dx, w, grad_outputs=vs,
                                retain_graph=True, allow_unused=True)[0]
    return torch.zeros_like(w) if cross is None else cross

def g_x_xhat_w(x, xhat, w, dataset, retain_graph=False, create_graph=False):
    """Value and gradients of ``g(x, w) - g(xhat, w)`` with ``xhat`` held fixed.

    Returns:
        ``(loss, d loss/dx, d loss/dw)``. ``xhat`` is detached, so no gradient
        flows into it.
    """
    gap = g(x, w, dataset) - g(xhat.detach(), w, dataset)
    grad_x, grad_w = torch.autograd.grad(gap, [x, w],
                                         retain_graph=retain_graph,
                                         create_graph=create_graph)
    return gap, grad_x, grad_w

def g_x_xhat_w_bo(x, xhat, w, dataset, retain_graph=False, create_graph=False):
    """Gradients of ``g(x, w) - g(xhat, w)`` with respect to ``x``, ``xhat`` and ``w``.

    Unlike ``g_x_xhat_w``, ``xhat`` is not detached and must require grad.
    Only the three gradients are returned, not the loss value.
    """
    gap = g(x, w, dataset) - g(xhat, w, dataset)
    grad_x, grad_xhat, grad_w = torch.autograd.grad(gap, [x, xhat, w],
                                                    retain_graph=retain_graph,
                                                    create_graph=create_graph)
    return grad_x, grad_xhat, grad_w

def f_x(x, w, dataset, retain_graph=False, create_graph=False):
    """Gradient of the upper-level objective ``f`` with respect to ``x``.

    ``retain_graph``/``create_graph`` are forwarded to ``torch.autograd.grad``.
    """
    objective = f(x, w, dataset)
    (grad_x,) = torch.autograd.grad(objective, x,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)
    return grad_x


### Define evaluation metric

def evaluate(x, testset, tevalset):
    """Score model ``x`` without tracking gradients.

    Returns:
        ``(test_loss, test_acc, teval_loss)`` as Python floats. ``teval`` is a
        held-out split kept separate from the validation set, because the
        validation set is used for training ``w``.
    """
    with torch.no_grad():
        inputs, labels = testset
        logits = model_forward(x, inputs)
        test_loss = F.cross_entropy(logits, labels).detach().item()
        hits = logits.argmax(-1) == labels
        test_acc = hits.float().mean().detach().cpu().item()

        teval_inputs, teval_labels = tevalset
        teval_logits = model_forward(x, teval_inputs)
        teval_loss = F.cross_entropy(teval_logits, teval_labels).detach().item()
    return test_loss, test_acc, teval_loss

def evaluate_importance_f1(w, clean_indices):
    """Precision, recall and F1 of ``w > 0.5`` as a detector of clean samples.

    Args:
        w: per-sample weights; a sample counts as "predicted clean" when its
            weight is strictly greater than 0.5.
        clean_indices: boolean mask of samples whose label was not corrupted.

    Returns:
        ``(precision, recall, f1)`` as Python floats. Small ``1e-4`` terms keep
        the divisions finite.
    """
    with torch.no_grad():
        predicted_clean = (w > 0.5).float()
        actually_clean = clean_indices.float()
        true_pos = (predicted_clean * actually_clean).sum()
        recall = true_pos / (actually_clean.sum() + 1e-4)
        precision = true_pos / (predicted_clean.sum() + 1e-4)
        f1 = 2.0 * recall * precision / (recall + precision + 1e-4)
    return tuple(metric.cpu().item() for metric in (precision, recall, f1))

### inner solver bfgs
def bfgs(x,w,tol,step,maxiter_hg,stepsize1,stepsize2,m,h0=0.01,ex_up=False,ws=3):
    """Approximately solve the lower-level problem with L-BFGS and build the
    hypergradient direction ``et``.

    Phase 1 (``k = 1 .. step``): while ``k < ws`` take plain gradient steps
    ``x <- x - 0.1 * g_x``. Afterwards take quasi-Newton steps
    ``x <- x + 0.1 * two_loops(grad)``, i.e. ``-0.1 * H grad``. Each curvature
    pair ``(s, y)`` with ``y.s > 1e-10`` is stored, keeping at most ``m`` pairs.

    Phase 2: with ``g = grad_x f(x, w, valset)``:
      * ``ex_up=False``: ``et = H g`` using the phase-1 memory.
      * ``ex_up=True``: ``maxiter_hg`` rounds, starting from an empty memory.
        Each round computes ``u = H g``, normalises it, and probes ``g_x`` at
        ``x + u`` to add a new curvature pair. ``et`` is the normalised ``u``
        from the last round.

    Args:
        x, w: lower-level parameters and sample weights (tensors).
        tol, stepsize1, stepsize2: currently unused; both phase-1 step sizes
            are fixed at 0.1.
        step: number of phase-1 iterations (>= 1).
        maxiter_hg: number of extra rounds when ``ex_up`` is true (>= 1 then).
        m: maximum number of stored curvature pairs.
        h0: initial inverse-Hessian scale; ``two_loops`` uses ``H0 = h0 * I``.
        ex_up: choose the phase-2 variant described above.
        ws: first iteration index that uses quasi-Newton steps (>= 2).

    Returns:
        ``(x, et)``: the updated lower-level iterate and the direction tensor,
        both on ``x.device``.

    Raises:
        ValueError: for ``step < 1``, ``ws < 2``, or ``ex_up`` with
            ``maxiter_hg < 1``. These previously surfaced as
            ``UnboundLocalError``.

    NOTE(review): reads the module-level ``trainset``/``valset`` globals, not
    arguments. With ``ex_up=True`` the returned ``et`` has unit norm and is
    computed before the final round's pair is stored, while the ``ex_up=False``
    path returns the unnormalised ``H g``. Confirm this asymmetry is intended.
    """
    if step < 1:
        raise ValueError("bfgs: step must be >= 1")
    if ws < 2:
        raise ValueError("bfgs: ws must be >= 2 (a gradient is needed before the first quasi-Newton step)")
    if ex_up and maxiter_hg < 1:
        raise ValueError("bfgs: maxiter_hg must be >= 1 when ex_up=True")

    # NumPy results are moved back onto x's device. A hard-coded .cuda() here
    # crashed on the CPU fallback that __main__ selects when CUDA is absent.
    device = x.device

    y_list, s_list, mu_list = [], [], []
    y1_list, s1_list, mu1_list = [], [], []
    for k in range(1, step + 1):
        if k < ws:
            # Warm-up: plain gradient step.
            x = x - 0.1 * g_x(x, w, trainset)
            new_grad = g_x(x, w, trainset)
            ngrad = new_grad.detach().cpu().numpy()
        else:
            p = two_loops(grady, m, s_list, y_list, mu_list, h0)  # p = -H grad (H0 = h0 * I)
            s = 0.1 * p
            st = torch.from_numpy(s).to(device).float()
            x = x + st
            new_grad = g_x(x, w, trainset)
            ngrad = new_grad.detach().cpu().numpy()  # \nabla_x g at the new iterate
            y = ngrad - grady
            # Update the memory only when the curvature condition holds.
            ys = safe_sparse_dot(np.ravel(y), np.ravel(s))
            if ys > 1e-10:
                y_list.append(y.copy())
                s_list.append(s.copy())
                mu_list.append(1 / ys)
            if len(y_list) > m:
                y_list.pop(0)
                s_list.pop(0)
                mu_list.pop(0)
        grady = ngrad

    fx = f_x(x, w, valset)  # \nabla_x f on the validation set, same shape as x
    gradFy = fx.detach().cpu().numpy()
    if not ex_up:
        hg = -two_loops(gradFy, m, s_list, y_list, mu_list, h0)  # = H gradFy
        et = torch.from_numpy(hg).to(device).float()
    else:
        for _ in range(1, maxiter_hg + 1):
            eq = -two_loops(gradFy, m, s1_list, y1_list, mu1_list, h0)  # = H gradFy
            en = np.linalg.norm(np.ravel(eq))
            eq = eq / en
            et = torch.from_numpy(eq).to(device).float()
            # Probe the gradient one unit step along eq to obtain a curvature pair.
            x1 = x + et
            grad = g_x(x1, w, trainset)
            f1grad = grad.detach().cpu().numpy()
            y_tilde1 = f1grad - grady
            yq = safe_sparse_dot(np.ravel(y_tilde1), np.ravel(eq))
            if yq > 1e-10:
                y1_list.append(y_tilde1.copy())
                s1_list.append(eq.copy())
                mu1_list.append(1 / yq)
            if len(y1_list) > m:
                y1_list.pop(0)
                s1_list.pop(0)
                mu1_list.pop(0)
    print(f'{k} iterates')
    return x, et

def two_loops(grad_y, m, s_list, y_list, mu_list, h0):
    """L-BFGS two-loop recursion.

    Computes ``-H @ grad_y``, where ``H`` is the inverse-Hessian approximation
    implied by the stored pairs ``(s_i, y_i)`` with ``mu_i = 1 / (y_i . s_i)``
    and initial matrix ``h0 * I``. Arrays of any shape are treated as flat
    vectors for the inner products.

    ``m`` is not used here; the callers enforce the memory limit.
    """
    def _inner(a, b):
        # Flat inner product of two same-shaped arrays.
        return safe_sparse_dot(np.ravel(a), np.ravel(b))

    q = grad_y.copy()
    alphas = []
    # Newest-to-oldest pass. The in-place update keeps q's dtype unchanged.
    for s_i, y_i, mu_i in zip(reversed(s_list), reversed(y_list), reversed(mu_list)):
        a_i = mu_i * _inner(s_i, q)
        alphas.append(a_i)
        q -= a_i * y_i

    r = h0 * q
    # Oldest-to-newest pass, consuming the alphas in reverse order.
    for s_i, y_i, mu_i, a_i in zip(s_list, y_list, mu_list, reversed(alphas)):
        b_i = mu_i * _inner(y_i, r)
        r += (a_i - b_i) * s_i
    return -r


###############################################################################
#
# Bilevel Optimization Training Methods
#
###############################################################################
def simple_train(args, x, data_x, data_y, testset, tevalset, tag='pretrain', regularize=False, n_epochs=500): # directly train on the dataset
    """Full-batch SGD training of ``x`` on ``(data_x, data_y)`` with model selection.

    Each epoch takes one SGD-with-momentum step (``args.x_lr``,
    ``args.x_momentum``) on the mean cross-entropy. If ``regularize`` is true,
    ``0.001 * ||x||_2^2`` is added. After every step the model is evaluated,
    and the snapshot with the lowest held-out ("teval") loss is kept; ties go
    to the later epoch.

    Args:
        n_epochs: number of full-batch steps. It defaults to the previously
            hard-coded 500 and is independent of ``args.epochs``.

    Returns:
        ``(test_loss, test_acc, best_x)`` of the selected snapshot.
        ``best_x`` is ``None`` if no epoch produced a comparable (non-NaN)
        teval loss.
    """
    opt = torch.optim.SGD([x], lr=args.x_lr, momentum=args.x_momentum)

    best_teval_loss = np.inf
    final_test_loss = 0.
    final_test_acc = 0.
    best_x = None

    for epoch in range(n_epochs):
        opt.zero_grad()
        y = model_forward(x, data_x)
        loss = F.cross_entropy(y, data_y)
        if regularize:
            loss += 0.001 * x.norm(2).pow(2)
        loss.backward()
        opt.step()

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        if teval_loss <= best_teval_loss:
            best_teval_loss = teval_loss
            final_test_loss = test_loss
            final_test_acc  = test_acc
            best_x = x.data.clone()
        print(f"[{tag}] epoch {epoch:5d} test loss {test_loss:10.4f} test acc {test_acc:10.4f}")
    return final_test_loss, final_test_acc, best_x


def blfoa(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Bilevel data-cleaning run using ``bfgs`` with ``h0=1`` and ``ex_up=False``.

    Each outer epoch:
      1. refines ``x`` with ``bfgs`` (10 steps, memory 30, ``ws=4``) and
         obtains the direction ``et``;
      2. sets ``w.grad = -g_wx(x, w, et, trainset)`` (``f`` does not depend on
         ``w``) and takes an SGD-with-momentum step (``args.w_lr``,
         ``args.w_momentum``);
      3. clamps ``w`` to [0, 1], evaluates, and logs.

    Only the ``bfgs`` call and the ``w`` update are timed. The run stops after
    ``args.epochs`` epochs or once the timed total exceeds 30 s. ``valset`` is
    not used here; ``bfgs`` reads the module-level ``valset``.

    Returns:
        ``(stats, times, test_losses, test_accs)``. The three curves start
        with time 0 and the ``test_loss``/``test_acc`` passed in.
    """
    total_time = 0.0
    stats = []
    times = [total_time]
    losses = [test_loss]
    accs = [test_acc]

    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)
    for epoch in range(args.epochs):
        start = time.time()

        x, et = bfgs(x, w, tol=1/(epoch+1), step=10, maxiter_hg=3, stepsize1=0.1, stepsize2=0.1, m=30, h0=1, ex_up=False, ws=4)

        outer_opt.zero_grad()
        w.grad = (-g_wx(x, w, et, trainset)).data  # f_w = 0: only the implicit term remains
        outer_opt.step()

        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        times.append(total_time)
        losses.append(test_loss)
        accs.append(test_acc)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")
        if total_time > 30:
            break
    return stats, times, losses, accs



def blfoat10(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Bilevel data-cleaning run using ``bfgs`` with ``h0=10`` and ``ex_up=True``.

    Each outer epoch:
      1. refines ``x`` with ``bfgs`` (10 steps, 3 extra hypergradient rounds,
         memory 30, ``ws=4``) and obtains the direction ``et``;
      2. sets ``w.grad = -g_wx(x, w, et, trainset)`` (``f`` does not depend on
         ``w``) and takes an SGD-with-momentum step (``args.w_lr``,
         ``args.w_momentum``);
      3. clamps ``w`` to [0, 1], evaluates, and logs.

    Only the ``bfgs`` call and the ``w`` update are timed. The run stops after
    ``args.epochs`` epochs or once the timed total exceeds 30 s. ``valset`` is
    not used here; ``bfgs`` reads the module-level ``valset``.

    Returns:
        ``(stats, times, test_losses, test_accs)``. The three curves start
        with time 0 and the ``test_loss``/``test_acc`` passed in.
    """
    total_time = 0.0
    stats = []
    times = [total_time]
    losses = [test_loss]
    accs = [test_acc]

    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)
    for epoch in range(args.epochs):
        start = time.time()

        x, et = bfgs(x, w, tol=1/(epoch+1), step=10, maxiter_hg=3, stepsize1=0.1, stepsize2=0.1, m=30, h0=10, ex_up=True, ws=4)

        outer_opt.zero_grad()
        w.grad = (-g_wx(x, w, et, trainset)).data  # f_w = 0: only the implicit term remains
        outer_opt.step()

        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        times.append(total_time)
        losses.append(test_loss)
        accs.append(test_acc)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")
        if total_time > 30:
            break
    return stats, times, losses, accs


def blfoat15(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Bilevel data-cleaning run using ``bfgs`` with ``h0=15`` and ``ex_up=True``.

    Each outer epoch:
      1. refines ``x`` with ``bfgs`` (10 steps, 3 extra hypergradient rounds,
         memory 30, ``ws=4``) and obtains the direction ``et``;
      2. sets ``w.grad = -g_wx(x, w, et, trainset)`` (``f`` does not depend on
         ``w``) and takes an SGD-with-momentum step (``args.w_lr``,
         ``args.w_momentum``);
      3. clamps ``w`` to [0, 1], evaluates, and logs.

    Only the ``bfgs`` call and the ``w`` update are timed. The run stops after
    ``args.epochs`` epochs or once the timed total exceeds 30 s. ``valset`` is
    not used here; ``bfgs`` reads the module-level ``valset``.

    Returns:
        ``(stats, times, test_losses, test_accs)``. The three curves start
        with time 0 and the ``test_loss``/``test_acc`` passed in.
    """
    total_time = 0.0
    stats = []
    times = [total_time]
    losses = [test_loss]
    accs = [test_acc]

    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)
    for epoch in range(args.epochs):
        start = time.time()

        x, et = bfgs(x, w, tol=1/(epoch+1), step=10, maxiter_hg=3, stepsize1=0.1, stepsize2=0.1, m=30, h0=15, ex_up=True, ws=4)

        outer_opt.zero_grad()
        w.grad = (-g_wx(x, w, et, trainset)).data  # f_w = 0: only the implicit term remains
        outer_opt.step()

        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        times.append(total_time)
        losses.append(test_loss)
        accs.append(test_acc)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")
        if total_time > 30:
            break
    return stats, times, losses, accs


def blfoat20(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Bilevel data-cleaning run using ``bfgs`` with ``h0=20`` and ``ex_up=True``.

    Each outer epoch:
      1. refines ``x`` with ``bfgs`` (10 steps, 3 extra hypergradient rounds,
         memory 30, ``ws=4``) and obtains the direction ``et``;
      2. sets ``w.grad = -g_wx(x, w, et, trainset)`` (``f`` does not depend on
         ``w``) and takes an SGD-with-momentum step (``args.w_lr``,
         ``args.w_momentum``);
      3. clamps ``w`` to [0, 1], evaluates, and logs.

    Only the ``bfgs`` call and the ``w`` update are timed. The run stops after
    ``args.epochs`` epochs or once the timed total exceeds 30 s. ``valset`` is
    not used here; ``bfgs`` reads the module-level ``valset``.

    Returns:
        ``(stats, times, test_losses, test_accs)``. The three curves start
        with time 0 and the ``test_loss``/``test_acc`` passed in.
    """
    total_time = 0.0
    stats = []
    times = [total_time]
    losses = [test_loss]
    accs = [test_acc]

    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)
    for epoch in range(args.epochs):
        start = time.time()

        x, et = bfgs(x, w, tol=1/(epoch+1), step=10, maxiter_hg=3, stepsize1=0.1, stepsize2=0.1, m=30, h0=20, ex_up=True, ws=4)

        outer_opt.zero_grad()
        w.grad = (-g_wx(x, w, et, trainset)).data  # f_w = 0: only the implicit term remains
        outer_opt.step()

        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        times.append(total_time)
        losses.append(test_loss)
        accs.append(test_acc)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")
        if total_time > 30:
            break
    return stats, times, losses, accs

def blfoat25(args, x, w, trainset, valset, testset, tevalset,clean_indices,test_loss,test_acc):
    """Bilevel data-cleaning run using ``bfgs`` with ``h0=25`` and ``ex_up=True``.

    Each outer epoch:
      1. refines ``x`` with ``bfgs`` (10 steps, 3 extra hypergradient rounds,
         memory 30, ``ws=4``) and obtains the direction ``et``;
      2. sets ``w.grad = -g_wx(x, w, et, trainset)`` (``f`` does not depend on
         ``w``) and takes an SGD-with-momentum step (``args.w_lr``,
         ``args.w_momentum``);
      3. clamps ``w`` to [0, 1], evaluates, and logs.

    Only the ``bfgs`` call and the ``w`` update are timed. The run stops after
    ``args.epochs`` epochs or once the timed total exceeds 30 s. ``valset`` is
    not used here; ``bfgs`` reads the module-level ``valset``.

    Returns:
        ``(stats, times, test_losses, test_accs)``. The three curves start
        with time 0 and the ``test_loss``/``test_acc`` passed in.
    """
    total_time = 0.0
    stats = []
    times = [total_time]
    losses = [test_loss]
    accs = [test_acc]

    outer_opt = torch.optim.SGD([w], lr=args.w_lr, momentum=args.w_momentum)
    for epoch in range(args.epochs):
        start = time.time()

        x, et = bfgs(x, w, tol=1/(epoch+1), step=10, maxiter_hg=3, stepsize1=0.1, stepsize2=0.1, m=30, h0=25, ex_up=True, ws=4)

        outer_opt.zero_grad()
        w.grad = (-g_wx(x, w, et, trainset)).data  # f_w = 0: only the implicit term remains
        outer_opt.step()

        total_time += time.time() - start
        w.data.clamp_(0.0, 1.0)

        test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
        f1 = evaluate_importance_f1(w, clean_indices)
        stats.append((total_time, test_loss, test_acc, teval_loss))
        times.append(total_time)
        losses.append(test_loss)
        accs.append(test_acc)
        print(f"[info] epoch {epoch:5d} | te loss {test_loss:6.4f} | te acc {test_acc:4.2f} | teval loss {teval_loss:6.4f} | time {total_time:6.2f} | w-min {w.min().item():4.2f} w-max {w.max().item():4.2f} | f1 {f1[2]:4.2f}")
        if total_time > 30:
            break
    return stats, times, losses, accs


if __name__ == "__main__":
    # Script entry point.
    #   --pretrain : build the noisy data split and train three reference models.
    #   otherwise  : run the h0 ablation of the bfgs-based bilevel solver
    #                (h0 = 1, 10, 15, 20, 25) and plot test loss / accuracy
    #                against running time.
    # NOTE: trainset/valset must remain module-level names because bfgs()
    # reads them as globals, so this block is not wrapped in a main() function.
    args = parse_args()

    if args.pretrain: # preprocess data and pretrain a model on validation set

        os.makedirs(args.data_path, exist_ok=True)
        os.makedirs(args.model_path, exist_ok=True)

        ### generate data
        trainset, valset, testset, tevalset, old_train_y = get_data(args)
        torch.save((trainset, valset, testset, tevalset, old_train_y),
                   os.path.join(args.data_path, f"{args.dataset}_data_cleaning.pt"))
        print(f"[info] successfully generated data to {args.data_path}/{args.dataset}_data_cleaning.pt")

        ### pretrain a model and save it
        # Product over the trailing dims. np.prod(*shape[1:]) would pass any
        # extra dims as axis arguments; it only worked because the data is 2-D.
        n_feats = int(np.prod(trainset[0].shape[1:]))
        num_classes = trainset[1].unique().shape[-1]
        args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        trainset = (trainset[0].to(args.device), trainset[1].to(args.device))
        valset   = (valset[0].to(args.device),   valset[1].to(args.device))
        testset  = (testset[0].to(args.device),  testset[1].to(args.device))
        tevalset = (tevalset[0].to(args.device), tevalset[1].to(args.device))
        old_train_y = old_train_y.to(args.device)

        x = get_model(n_feats, num_classes, args.device)
        sd = x.data.clone()

        # lower bound (train on noisy train + valset)
        tmp_x = torch.cat([trainset[0], valset[0]], 0)
        tmp_y = torch.cat([trainset[1], valset[1]], 0)
        test_loss1, test_acc1, best_x1 = simple_train(args, x, tmp_x, tmp_y, testset, tevalset, regularize=True)
        torch.save(best_x1.data.cpu().clone(),
                   os.path.join(args.model_path, f"{args.dataset}_pretrained.pt"))

        # a baseline: train on valset
        x.data.copy_(sd)
        test_loss2, test_acc2, best_x2 = simple_train(args, x, valset[0], valset[1], testset, tevalset)
        torch.save(best_x2.data.cpu().clone(),
                   os.path.join(args.model_path, f"{args.dataset}_pretrained_val.pt"))

        # upper bound (train on correct train + valset)
        x.data.copy_(sd)
        tmp_x = torch.cat([trainset[0], valset[0]], 0)
        tmp_y = torch.cat([old_train_y, valset[1]], 0)
        test_loss3, test_acc3, best_x3 = simple_train(args, x, tmp_x, tmp_y, testset, tevalset)
        torch.save(best_x3.data.cpu().clone(),
                   os.path.join(args.model_path, f"{args.dataset}_pretrained_trainval.pt"))

        print(f"[pretrained] noisy train + val   : test loss {test_loss1} test acc {test_acc1}")
        print(f"[pretrained] val                 : test loss {test_loss2} test acc {test_acc2}")
        print(f"[pretrained] correct train + val : test loss {test_loss3} test acc {test_acc3}")

        torch.save({
            "pretrain_test_loss": test_loss1,
            "pretrain_test_acc": test_acc1,
            "pretrain_val_test_loss": test_loss2,
            "pretrain_val_test_acc": test_acc2,
            "pretrain_trainval_test_loss": test_loss3,
            "pretrain_trainval_test_acc": test_acc3,
            }, os.path.join(args.model_path, f"{args.dataset}_pretrained.stats"))


    else: # load pretrained model on valset and then start model training
        trainset, valset, testset, tevalset, old_train_y = torch.load(
                os.path.join(args.data_path, f"{args.dataset}_data_cleaning.pt"))
        args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        # See the pretrain branch for why this is np.prod(shape[1:]).
        n_feats = int(np.prod(trainset[0].shape[1:]))
        num_classes = trainset[1].unique().shape[-1]

        trainset = (trainset[0].to(args.device), trainset[1].to(args.device))
        valset   = (valset[0].to(args.device),   valset[1].to(args.device))
        testset  = (testset[0].to(args.device),  testset[1].to(args.device))
        tevalset = (tevalset[0].to(args.device), tevalset[1].to(args.device))
        old_train_y = old_train_y.to(args.device)

        pretrained_path = os.path.join(args.model_path, f"{args.dataset}_pretrained.pt")

        # load the statistics recorded by the --pretrain run
        pretrained_stats = torch.load(
            os.path.join(args.model_path, f"{args.dataset}_pretrained.stats"))

        test_loss1 = pretrained_stats['pretrain_test_loss']
        test_loss2 = pretrained_stats['pretrain_val_test_loss']
        test_loss3 = pretrained_stats['pretrain_trainval_test_loss']
        test_acc1  = pretrained_stats['pretrain_test_acc']
        test_acc2  = pretrained_stats['pretrain_val_test_acc']
        test_acc3  = pretrained_stats['pretrain_trainval_test_acc']
        print(f"[pretrained] noisy train + val   : test loss {test_loss1} test acc {test_acc1}")
        print(f"[pretrained] val                 : test loss {test_loss2} test acc {test_acc2}")
        print(f"[pretrained] correct train + val : test loss {test_loss3} test acc {test_acc3}")

        # Ablation variants in run order. Each run restarts from the pretrained
        # model (noisy train + val) with every sample weight set to 0.5.
        # Functions are referenced directly instead of being looked up by name
        # through eval().
        variants = [
            ('h0=1', blfoa),
            ('h0=10', blfoat10),
            ('h0=15', blfoat15),
            ('h0=20', blfoat20),
            ('h0=25', blfoat25),
        ]
        curves = {}  # label -> (running_times, test_losses, test_accs)
        for label, method in variants:
            x = get_model(n_feats, num_classes, args.device)
            x.data.copy_(torch.load(pretrained_path).to(args.device))

            test_loss, test_acc, teval_loss = evaluate(x, testset, tevalset)
            print("original test loss ", test_loss, "original test acc ", test_acc)

            clean_indices = old_train_y.to(args.device).eq(trainset[1])
            w = torch.zeros(trainset[0].shape[0], requires_grad=True, device=args.device)
            w.data.add_(0.5)

            stats, run_times, run_losses, run_accs = method(args=args,
                                   x=x,
                                   w=w,
                                   trainset=trainset,
                                   valset=valset,
                                   testset=testset,
                                   tevalset=tevalset,
                                   clean_indices=clean_indices, test_loss=test_loss, test_acc=test_acc)
            curves[label] = (run_times, run_losses, run_accs)

        lw = 1.5
        plot_order = ['h0=25', 'h0=20', 'h0=15', 'h0=10', 'h0=1']
        # The legend size must live in the FontProperties: plt.legend ignores
        # `fontsize` whenever `prop` is given, so fontsize=8 was previously dropped.
        legend_font = FontProperties(weight='bold', size=8)

        # (index into curves[label], y-axis label, output file)
        figures = [
            (1, 'Test loss', 'dclfoadqm.pdf'),
            (2, 'Test accuracy', 'dcfoadqm.pdf'),
        ]
        for metric_idx, ylabel, out_file in figures:
            plt.figure(figsize=(6,5))
            for label in plot_order:
                run_times = curves[label][0]
                plt.plot(run_times, curves[label][metric_idx], '-', label=label, linewidth=lw)
            plt.xlabel('Running time(s)', fontsize=20)
            plt.ylabel(ylabel, fontsize=20)
            plt.xticks(fontsize=15)
            plt.yticks(fontsize=15)
            plt.legend(ncol=3, prop=legend_font)
            plt.savefig(out_file, dpi=300, bbox_inches='tight')