from __future__ import print_function
from curses import meta
from operator import is_
from re import I
from selectors import EpollSelector
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os, shutil
import argparse
import numpy as np
from PreResNet import *
from sklearn.mixture import GaussianMixture
import dataloader_cifar as dataloader
from torch.utils.tensorboard import SummaryWriter
from utils import MetaNet_Bin, verbose_prob_estimate_both, prob_prototype
import pickle
# Command-line configuration for the whole script; the parsed values are read
# through the module-level `args` object everywhere below.
parser = argparse.ArgumentParser(description='PyTorch CIFAR Training')
parser.add_argument('--batch_size', default=64, type=int, help='train batchsize')
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--meta_lr', '--meta_learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--noise_mode',  default='sym')
parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=25, type=float, help='weight for unsupervised loss')
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=450, type=int)
parser.add_argument('--r', default=0.5, type=float, help='noise ratio')
parser.add_argument('--id', default='')
# type=int so a seed given on the command line is an integer like the default;
# without it the value is a str and random.seed() treats '123' and 123 differently.
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid', default=0, type=int)
parser.add_argument('--num_class', default=10, type=int)
parser.add_argument('--data_path', default='./Data/cifar10/cifar-10-batches-py', type=str, help='path to dataset')
parser.add_argument('--dataset', default='cifar10', type=str)
parser.add_argument('--gmm_ablation', action='store_true', default=False)# for ablation study on gmm cleaner
parser.add_argument('--meta_thd', default=0.5, type=float)
args = parser.parse_args()

# Pick the GMM cleaner variant from the noise type (consumed by eval_train via `all2per`).
if args.noise_mode == 'sym':
    all2per = False # use class-aware gmm cleaner under symmetric noise
elif args.noise_mode == 'asym':
    all2per = True # use class-agnostic gmm cleaner under asymmetric noise
else:
    # `assert '<non-empty string>'` is always true, so it never rejected anything;
    # fail explicitly instead of leaving `all2per` undefined.
    raise ValueError('unknown noise type: %r' % (args.noise_mode,))

# Bind the process to the requested GPU and seed the Python and torch RNGs.
torch.cuda.set_device(args.gpuid)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)


# Training
def train(epoch,net,net2, meta_net, optimizer, meta_optimizer, labeled_trainloader,unlabeled_trainloader, model_name=None, global_iter=0, global_meta_iter=0):
    """Train `net` and its `meta_net` for one epoch while `net2` stays frozen.

    For every labeled batch one unlabeled batch is drawn alongside it (the
    unlabeled iterator is restarted when it runs out). Each step:

    1. builds sharpened soft targets: averaged predictions of `net` and `net2`
       for the unlabeled samples, and a `w_x`-weighted blend of the given label
       and `net`'s own prediction for the labeled samples;
    2. mixes inputs and targets with a Beta(alpha, alpha) coefficient;
    3. updates `meta_net` on features that `net` produced without gradient;
    4. updates `net` on ``Lx + lamb * Lu + penalty``.

    Reads the module-level names `args`, `criterion`, `warm_up`, `CEloss`,
    `meta_bce`, `tf_writer`, `use_meta_label` and `neg_weight`.

    Args:
        epoch: current epoch index (used for the ramp-up and for logging).
        net: network being trained; must expose `fea_forward(x)` returning
            `(features, logits)`.
        net2: peer network, only used in eval mode for label co-guessing.
        meta_net: meta network called as `meta_net(features, labels)` and
            returning a pair of outputs.
        optimizer: optimizer of `net`.
        meta_optimizer: optimizer of `meta_net`.
        labeled_trainloader: yields 7-tuples
            `(x, x2, x3, x4, label, w, eval_loss)`.
        unlabeled_trainloader: yields 7-tuples of the same layout.
        model_name: tag used in the TensorBoard scalar names.
        global_iter: running step counter for the network scalars.
        global_meta_iter: running step counter for the meta-network scalars.

    Returns:
        The updated `(global_iter, global_meta_iter)` counters.
    """
    net.train()
    meta_net.train()
    net2.eval() #fix one network and train the other

    unlabeled_train_iter = iter(unlabeled_trainloader)
    num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1

    for batch_idx, (inputs_x, inputs_x2, inputs_x3, inputs_x4,  labels_x, w_x, eval_loss_x) in enumerate(labeled_trainloader):
        global_iter += 1
        global_meta_iter += 1

        # Use the builtin next(): the `.next()` method is a Python-2 style alias that
        # is not guaranteed to exist on the iterator. Only exhaustion restarts the
        # unlabeled loader; any other error now propagates instead of being hidden.
        try:
            inputs_u, inputs_u2, inputs_u3, inputs_u4, labels_un, w_u, eval_loss_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2, inputs_u3, inputs_u4,  labels_un, w_u,  eval_loss_u = next(unlabeled_train_iter)

        batch_size = inputs_x.size(0)

        # Transform label to one-hot (the `*_l` variants keep the integer labels)
        labels_x_l = labels_x.view(-1,1)
        labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
        w_x = w_x.view(-1,1).type(torch.FloatTensor)
        labels_un_l = labels_un.view(-1,1)
        labels_un = torch.zeros(labels_un.size(0), args.num_class).scatter_(1, labels_un.view(-1,1), 1)
        w_u = w_u.view(-1,1).type(torch.FloatTensor)

        inputs_x, inputs_x2,inputs_x3, inputs_x4, labels_x, labels_x_l, w_x,eval_loss_x= inputs_x.cuda(), inputs_x2.cuda(),inputs_x3.cuda(), inputs_x4.cuda(), labels_x.cuda(), labels_x_l.cuda(),w_x.cuda(),eval_loss_x.cuda()

        inputs_u, inputs_u2,inputs_u3, inputs_u4, labels_un, labels_un_l, w_u, eval_loss_u = inputs_u.cuda(), inputs_u2.cuda(),inputs_u3.cuda(), inputs_u4.cuda(), labels_un.cuda(), labels_un_l.cuda(),w_u.cuda(), eval_loss_u.cuda()

        with torch.no_grad():
            # label co-guessing of unlabeled samples
            fea_u11, outputs_u11 = net.fea_forward(inputs_u3)
            fea_u12, outputs_u12 = net.fea_forward(inputs_u4)

            outputs_u21 = net2(inputs_u3)
            outputs_u22 = net2(inputs_u4)

            pu11, pu12, pu21, pu22 = torch.softmax(outputs_u11, dim=1) , torch.softmax(outputs_u12, dim=1),  torch.softmax(outputs_u21, dim=1), torch.softmax(outputs_u22, dim=1)
            pu = (pu11+pu12+pu21+pu22)/4

            ptu = pu**(1/args.T) # temperature sharpening

            targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
            targets_u = targets_u.detach()

        with torch.no_grad():
            label_l = inputs_x.size(0)
            # label refinement of labeled samples
            fea_x1, outputs_x = net.fea_forward(inputs_x3)
            fea_x2, outputs_x2 = net.fea_forward(inputs_x4)

            px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2

        # Blend the given label with the network prediction, weighted by w_x.
        px_g = w_x.view(-1,1)*labels_x + (1-w_x.view(-1,1))*px.detach()
        ptx = px_g**(1/args.T) # temperature sharpening

        targets_x = ptx / ptx.sum(dim=1, keepdim=True)

        # mixmatch
        l = np.random.beta(args.alpha, args.alpha)
        l = max(l, 1-l) # keep the mixed sample closer to its own (first) component

        all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        # No try/except here: swallowing a failure and printing it would only defer
        # the crash to a NameError on `mixed_input` a few lines later.
        mixed_input = l * input_a + (1 - l) * input_b
        mixed_target = l * target_a + (1 - l) * target_b

        mixed_input = Variable(mixed_input)

        logits = net(mixed_input)
        # The first 2*label_l rows are the two labeled views, the rest are unlabeled.
        logits_x = logits[:label_l*2]
        logits_u = logits[label_l*2:]

        Lx, Lu, lamb = criterion(logits_x, mixed_target[:label_l*2].detach(), logits_u, mixed_target[label_l*2:].detach(), epoch+batch_idx/num_iter, warm_up)

        if args.gmm_ablation:
            # Ablation: the meta net is trained with cross-entropy against the argmax
            # of the per-sample `eval_loss_*` rows supplied by the loader.
            _, meta_output_u1_s = meta_net(fea_u11, labels_un_l)
            _, meta_output_u2_s = meta_net(fea_u12, labels_un_l)
            pseudo_label_u = torch.max(eval_loss_u, dim=1)[1]

            meta_reg_u1 = CEloss(meta_output_u1_s, pseudo_label_u.type_as(labels_un_l))
            meta_reg_u2 = CEloss(meta_output_u2_s, pseudo_label_u.type_as(labels_un_l))

            _, meta_output_x1_s = meta_net(fea_x1, labels_x_l)
            _, meta_output_x2_s = meta_net(fea_x2, labels_x_l)
            pseudo_label_x = torch.max(eval_loss_x, dim=1)[1]
            meta_reg_x1 = CEloss(meta_output_x1_s, pseudo_label_x.type_as(labels_x_l))
            meta_reg_x2 = CEloss(meta_output_x2_s, pseudo_label_x.type_as(labels_x_l))

            meta_reg = (meta_reg_u1 + meta_reg_u2 + meta_reg_x1 + meta_reg_x2)/4

            tf_writer.add_scalar('MetaNet/{}/train_loss'.format(model_name), meta_reg.detach(), global_meta_iter)

            meta_optimizer.zero_grad()
            meta_reg.backward()
            meta_optimizer.step()

        else:
            labels_un_l = labels_un_l.view(-1)
            has_u = inputs_u.size(0) > 0
            if has_u:
                meta_pred_u1, _ = meta_net(fea_u11, labels_un_l)
                meta_pred_u2, _ = meta_net(fea_u12, labels_un_l)

                # Score used to select confident samples: the loader-provided
                # eval_loss once past `use_meta_label` epochs, otherwise w_u.
                if use_meta_label < epoch and use_meta_label > 0:
                    l_u = eval_loss_u.type_as(w_u)
                else:
                    l_u = w_u

                select_idx_u = l_u.view(-1) < min(args.meta_thd, args.p_threshold)

                # Unlabeled-set samples are regressed towards 0.
                l_u = torch.zeros_like(w_u).cuda()

                meta_reg_u1 = meta_bce(meta_pred_u1.view(-1), l_u.view(-1))
                meta_reg_u2 = meta_bce(meta_pred_u2.view(-1), l_u.view(-1))

            labels_x_l = labels_x_l.view(-1)
            has_x = inputs_x.size(0) > 0
            if has_x:
                meta_pred_x1, meta_output_x1_s = meta_net(fea_x1, labels_x_l)
                meta_pred_x2, meta_output_x2_s = meta_net(fea_x2, labels_x_l)
                meta_output_x1 = torch.sigmoid(meta_output_x1_s)
                meta_output_x2 = torch.sigmoid(meta_output_x2_s)

                if use_meta_label < epoch and use_meta_label > 0:
                    l_x = eval_loss_x.type_as(w_x)
                else:
                    l_x = w_x

                select_idx_x = l_x.view(-1) > max(1-args.meta_thd, 1-args.p_threshold)

                # Labeled-set samples are regressed towards 1 at their own label.
                l_x = torch.ones_like(w_x).cuda()

                full_t = torch.zeros(labels_x_l.size(0), args.num_class).type_as(l_x)
                full_t = full_t.scatter(1, labels_x_l.view(-1,1), l_x.view(-1,1))

                # Per-class weights: neg_weight everywhere, 1 at the labeled class,
                # and whole rows zeroed for samples that were not selected.
                mask = (torch.ones(l_x.size(0), args.num_class) ).type_as(meta_output_x1) * neg_weight
                mask = mask * l_x.view(-1,1)
                mask = mask.scatter(1, labels_x_l.view(-1,1), 1)
                mask = mask * select_idx_x.view(-1,1)

                meta_reg_x1 = meta_bce(meta_output_x1.view(-1), full_t.view(-1))
                meta_reg_x2 = meta_bce(meta_output_x2.view(-1), full_t.view(-1))

            if has_u or has_x:
                if has_u and has_x:
                    meta_reg = ((meta_reg_u1 * select_idx_u).sum() + (meta_reg_u2 * select_idx_u).sum() + (meta_reg_x1 * mask.view(-1)).sum() + (meta_reg_x2 * mask.view(-1)).sum() )/(select_idx_u.sum()*2 + mask.sum()*2)
                elif has_u:
                    # Parenthesised denominator: `x / n*2` multiplied by 2 instead of
                    # averaging over the two views as the combined branch does.
                    meta_reg = ((meta_reg_u1 * select_idx_u).sum() + (meta_reg_u2 * select_idx_u).sum()) / (select_idx_u.sum()*2)
                else:
                    meta_reg = ((meta_reg_x1 * mask.view(-1)).sum() + (meta_reg_x2 * mask.view(-1)).sum()) / (mask.sum()*2)

                tf_writer.add_scalar('MetaNet/{}/train_loss'.format(model_name), meta_reg.detach(), global_meta_iter)

                meta_optimizer.zero_grad()
                meta_reg.backward()
                meta_optimizer.step()

        # regularization: KL divergence between a uniform prior and the batch-mean prediction
        prior = torch.ones(args.num_class)/args.num_class
        prior = prior.cuda()
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        penalty = torch.sum(prior*torch.log(prior/pred_mean))

        loss = Lx + lamb * Lu  + penalty

        tf_writer.add_scalar(model_name+'/Loss', loss, global_iter)
        tf_writer.add_scalar(model_name+'/Lx', Lx, global_iter)
        tf_writer.add_scalar(model_name+'/Lu', lamb *Lu, global_iter)
        tf_writer.add_scalar(model_name+'/penalty', penalty, global_iter)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        sys.stdout.write('\r')
        sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t Labeled loss: %.2f  Unlabeled loss: %.2f'
                    %(args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx+1, num_iter, Lx, Lu))
        sys.stdout.flush()

    return global_iter, global_meta_iter

def warmup(epoch,net,optimizer,dataloader):
    """Run one epoch of plain supervised training of `net` on `dataloader`.

    The loss is the module-level `CEloss`; when `args.noise_mode` is 'asym' the
    module-level `conf_penalty` of the logits is added to it. Progress is
    written to stdout on a single, continuously rewritten line.

    Args:
        epoch: current epoch index, only used in the progress line.
        net: network to train.
        optimizer: optimizer of `net`.
        dataloader: yields `(inputs, labels, path)` triples; the third item is
            not used here.
    """
    net.train()
    num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
    use_penalty = args.noise_mode=='asym'
    for batch_idx, (inputs, labels, _) in enumerate(dataloader):
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = CEloss(outputs, labels)
        # penalize confident prediction for asymmetric noise; every other mode
        # uses plain cross-entropy, so the total is always defined.
        total = loss + conf_penalty(outputs) if use_penalty else loss
        total.backward()
        optimizer.step()

        sys.stdout.write('\r')
        sys.stdout.write('%s:%.1f-%s | Epoch [%3d/%3d] Iter[%3d/%3d]\t CE-loss: %.4f'
                %(args.dataset, args.r, args.noise_mode, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item()))
        sys.stdout.flush()

def test(epoch,net1, net2, meta_net1=None, meta_net2=None, is_warm=True):
    """Evaluate the two-network ensemble on the module-level `test_loader`.

    The ensemble prediction is the argmax of the summed logits of `net1` and
    `net2`. When `is_warm` is false, the two meta networks are evaluated as
    well: their second outputs are soft-maxed, summed and arg-maxed.

    The accuracy is printed and appended to the module-level `test_log`.

    Returns:
        `(acc, meta_acc)` in percent; `meta_acc` is None when `is_warm` is true.
    """
    net1.eval()
    net2.eval()
    with_meta = not is_warm
    if with_meta:
        meta_net1.eval()
        meta_net2.eval()

    n_seen = 0
    n_hit = 0
    n_meta_hit = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.cuda()
            targets = targets.cuda()

            logits_a = net1(inputs)
            logits_b = net2(inputs)
            predicted = torch.max(logits_a + logits_b, 1)[1]

            n_seen += targets.size(0)
            n_hit += predicted.eq(targets).cpu().sum().item()

            if not with_meta:
                continue

            # Meta networks are fed the backbone features of the same batch.
            fea_a, _ = net1.fea_forward(inputs)
            fea_b, _ = net2.fea_forward(inputs)
            _, meta_scores_a = meta_net1(fea_a)
            meta_probs_a = torch.softmax(meta_scores_a, dim=-1)
            _, meta_scores_b = meta_net2(fea_b)
            meta_probs_b = torch.softmax(meta_scores_b, dim=-1)
            meta_predicted = torch.max(meta_probs_a + meta_probs_b, 1)[1]
            n_meta_hit += meta_predicted.eq(targets).cpu().sum().item()

    acc = 100.*n_hit/n_seen
    print("\n| Test Epoch #%d\t Accuracy: %.2f%%\n" %(epoch,acc))
    test_log.write('Epoch:%d   Accuracy:%.2f\n'%(epoch,acc))
    test_log.flush()

    if is_warm:
        return acc, None

    meta_acc = 100.*n_meta_hit/n_seen
    print("\n| Test Epoch #%d\t Meta Accuracy: %.2f%%\n" %(epoch,meta_acc))
    return acc, meta_acc

def eval_train(model,all_loss, log, epoch, meta_net, meta_optimizer, global_meta_iter, net_name, warm_up):    
    """Score every training sample with `model`, then fit `meta_net` on the result.

    Phase 1 (no gradient): iterates the module-level `eval_loader`, which yields
    `(inputs, targets, index)`, and records per-sample cross-entropy (module-level
    `CE`), labels and soft-max outputs at position `index`. Features and labels
    of each batch are cached for phase 2.

    Phase 2: the min-max normalised losses are appended to `all_loss` and turned
    into a per-sample probability `prob` by `prob_prototype` (ablation mode) or
    `verbose_prob_estimate_both` (otherwise).

    Phase 3: `meta_net` is put in train mode and updated once per cached batch,
    logging the loss to the module-level `tf_writer`.

    Args:
        model: network exposing `fea_forward(x)` -> `(features, logits)`.
        all_loss: list of previous normalised loss tensors; appended to in place.
        log: passed through to `verbose_prob_estimate_both`.
        epoch: not used inside this function.
        meta_net: meta network called as `meta_net(features, labels)`.
        meta_optimizer: optimizer of `meta_net`.
        global_meta_iter: running step counter for the TensorBoard scalar.
        net_name: tag used in the TensorBoard scalar name.
        warm_up: not used inside this function.

    Returns:
        A 7-tuple `(prob, all_loss, None, None, meta_prob, global_meta_iter,
        meta_logits)`. In ablation mode the fifth item is `prob` itself and the
        last item is the `meta_logits` tensor; otherwise the fifth item is
        `meta_prob.numpy()` and the last item is None.
    """
    model.eval()
     
    # NOTE(review): hard-coded sample count; presumably the CIFAR training-set
    # size. Confirm it matches len(eval_loader.dataset).
    train_len = 50000

    if use_meta_label < 0:
        meta_net.eval()
        meta_prob = torch.zeros(train_len)
    if args.gmm_ablation:
        meta_logits = torch.zeros(train_len, args.num_class) 
        meta_logits_list = []

    losses = torch.zeros(train_len)    
    labels = torch.zeros(train_len)
    fea_list = []
    label_list = []
    index_list = []
    #if args.adapt_thd:
    # `logits` is filled below but not read again inside this function.
    logits = torch.zeros(train_len,args.num_class)
    with torch.no_grad():
        for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
            inputs, targets = inputs.cuda(), targets.cuda() 
            fea, outputs = model.fea_forward(inputs)
            loss = CE(outputs, targets)  
            
            if args.gmm_ablation:
                # Average of the meta net's and the model's class distributions.
                _, meta_logits_batch = meta_net(fea, targets) 
                meta_logits_batch = (torch.softmax(meta_logits_batch,dim=-1) + torch.softmax(outputs,dim=-1))/2 
                for b,j in enumerate(index):
                    meta_logits[j] = meta_logits_batch[b].cpu()
                meta_logits_list.append(meta_logits_batch.cpu())
                fea_list.append(fea.cpu().numpy())
                label_list.append(targets.cpu().numpy())

            if use_meta_label < 0:  
                # NOTE(review): values stored here are discarded later, because
                # `meta_prob` is re-created with zeros before the training phase.
                meta_pred, _ = meta_net(fea, targets) 
                for b,j in enumerate(index):
                    meta_prob[j] = meta_pred[b].item()
            elif not args.gmm_ablation:
                fea_list.append(fea.cpu().numpy())
                label_list.append(targets.cpu().numpy())


            index_list.append(index)

            # Scatter the batch results back to their dataset positions.
            for b in range(inputs.size(0)):
                losses[index[b]]=loss[b]
                labels[index[b]] = targets[b] 
                
                #if args.adapt_thd:
                logits[index[b]] = torch.softmax(outputs[b],dim=-1).cpu()        
    
    # Min-max normalise the losses to [0, 1].
    losses = (losses-losses.min())/(losses.max()-losses.min())   
    all_loss.append(losses)

    if args.r==0.9: # average loss over last 5 epochs to improve convergence stability
        history = torch.stack(all_loss)
        input_loss = history[-5:].mean(0)
        input_loss = input_loss.reshape(-1,1)
    else:
        input_loss = losses.reshape(-1,1)
    
    # fit a two-component GMM to the loss
    if args.gmm_ablation:
        #get prob from predicted label and prototypes
        prob = prob_prototype(meta_logits, labels)
    else:
        prob, _, _  = verbose_prob_estimate_both(input_loss, labels, args.num_class, log, all2per) 

    if True:
        meta_net.train()
        if True:
            meta_prob = torch.zeros(train_len)   
            # One meta-net update per batch cached during the evaluation pass.
            for i in range(len(fea_list)):
                fea = torch.Tensor(fea_list[i]).cuda()
                l = torch.LongTensor(label_list[i])
                l = l.cuda() # learn embedding
                if args.gmm_ablation:
                    # Ablation: cross-entropy against the argmax of the cached averaged distribution.
                    batch_logits = meta_logits_list[i].cuda() 
                    pseudo_label = torch.max(batch_logits, dim=1)[1]
                    meta_pred, meta_outputs_s = meta_net(fea, l.view(-1))
                    meta_loss = CEloss(meta_outputs_s, pseudo_label.type_as(l)) 
                else:

                    batch_t = torch.Tensor(prob[index_list[i]]).cuda()

                    # Samples whose probability lies below meta_thd or above 1 - meta_thd.
                    select_idx = torch.logical_or(batch_t<args.meta_thd, batch_t>(1-args.meta_thd))
                    
                    # get hard label to train metanet
                    h_t = np.where(prob[index_list[i]] > args.p_threshold, np.ones_like(prob[index_list[i]]), np.zeros_like(prob[index_list[i]]))
                    t = torch.Tensor(h_t).cuda()  


                    fea, l, t = Variable(fea), Variable(l), Variable(t)

                    
                    meta_pred, meta_outputs_s = meta_net(fea, l.view(-1)) 
                    meta_outputs = torch.sigmoid(meta_outputs_s)
                    
                    for b,j in enumerate( index_list[i]):
                        meta_prob[j] = meta_pred[b].item()

                    if not select_idx.sum().item() > 0:
                        # no sample in this batch satisfies the selection condition
                        continue
                    
                    # Target is the hard label t at the sample's class and 0 elsewhere.
                    full_t = torch.zeros(t.size(0), args.num_class).type_as(t)
                    full_t = full_t.scatter(1, l.view(-1,1), t.view(-1,1))
                    mask = (torch.ones(t.size(0), args.num_class)).type_as(meta_outputs) * neg_weight 
                        
                    mask = mask * t.view(-1,1) #neg label in noise samples are not used to train metanet 
                    mask = mask.scatter(1, l.view(-1,1), 1)

                    
                    # NOTE(review): this multiplies by `t` a second time, zeroing every
                    # row with t == 0; `select_idx` is only used for the skip above.
                    # train() multiplies by its selection mask at this point - confirm
                    # which behaviour is intended. If every t in the batch is 0,
                    # mask.sum() below is 0.
                    mask = mask * t.view(-1,1)

                    meta_loss = meta_bce(meta_outputs.view(-1), full_t.view(-1))
                    meta_loss = (meta_loss * mask.view(-1)).sum()/mask.sum() 

                
                tf_writer.add_scalar('MetaNet/{}/train_loss'.format(net_name), meta_loss.item(), global_meta_iter)
                global_meta_iter += 1

                meta_optimizer.zero_grad()
                meta_loss.backward()
                meta_optimizer.step()
    if args.gmm_ablation:
        return prob, all_loss, None, None, prob, global_meta_iter, meta_logits 
    else:     
        return prob,all_loss, None, None, meta_prob.numpy(), global_meta_iter,None 

def linear_rampup(current, warm_up, rampup_length=16):
    """Return the unsupervised-loss weight for training progress `current`.

    The weight grows linearly from 0 at `current == warm_up` to
    `args.lambda_u` at `current == warm_up + rampup_length`, and is clamped
    to that range outside of it.
    """
    progress = (current-warm_up) / rampup_length
    fraction = float(np.clip(progress, 0.0, 1.0))
    return args.lambda_u*fraction

class SemiLoss(object):
    """Semi-supervised loss pair used for the mixed labeled/unlabeled batch."""

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
        """Return `(Lx, Lu, weight)`.

        `Lx` is the soft-target cross-entropy of the labeled logits, `Lu` the
        mean squared error between the unlabeled soft-max and its targets
        (the integer 0 when there are no unlabeled rows), and `weight` the
        value of `linear_rampup(epoch, warm_up)`.
        """
        Lu = 0
        if outputs_u.size(0) > 0:
            squared_gap = (torch.softmax(outputs_u, dim=1) - targets_u)**2
            Lu = torch.mean(squared_gap)

        log_probs_x = F.log_softmax(outputs_x, dim=1)
        per_sample_ce = torch.sum(log_probs_x * targets_x, dim=1)
        Lx = -torch.mean(per_sample_ce)

        return Lx, Lu, linear_rampup(epoch,warm_up)

class NegEntropy(object):
    """Negative entropy of the soft-max distribution, averaged over the batch."""

    def __call__(self,outputs):
        """Return mean over rows of sum_c p_c * log(p_c) for logits `outputs`.

        The log-probabilities come from `log_softmax` rather than from
        `softmax(...).log()`: when a probability underflows to 0 the latter
        yields `0 * -inf = NaN`, while here the term is a finite `0 * log_p`.
        """
        log_probs = F.log_softmax(outputs, dim=1)
        return torch.mean(torch.sum(log_probs.exp()*log_probs, dim=1))

def create_model():
    """Build a `ResNet18` with `args.num_class` outputs and move it to the GPU."""
    return ResNet18(num_classes=args.num_class).cuda()

# Run artefacts: a TensorBoard directory plus two text logs next to it.
# The path layout is unchanged; note the directory uses '<dataset>/<run>' while
# the text logs use '<dataset>_<run>' directly under ./checkpoint.
_ckpt_dir = './checkpoint/%s/%.1f_%s_%s'%(args.dataset,args.r,args.noise_mode, args.id)
_log_prefix = './checkpoint/%s_%.1f_%s_%s'%(args.dataset,args.r,args.noise_mode, args.id)

# Start from an empty run directory. makedirs also creates missing parents
# (os.mkdir failed when ./checkpoint/<dataset> did not exist yet) and recreates
# the directory right after it has been wiped.
if os.path.exists(_ckpt_dir):
    shutil.rmtree(_ckpt_dir)
os.makedirs(_ckpt_dir, exist_ok=True)
tf_writer =  SummaryWriter(_ckpt_dir)

# These two stay open for the whole run; they are written to by the training code.
stats_log=open(_log_prefix+'_stats.txt','w')
test_log=open(_log_prefix+'_acc.txt','w')

# Record the artefact locations ('-r' directory, '-f' file) in log_path.txt.
with open('log_path.txt', 'w+') as f:
    f.write('-r ' + _ckpt_dir + '\n')
    f.write('-f ' + _log_prefix + '_stats.txt\n')
    f.write('-f ' + _log_prefix + '_acc.txt\n')

# Weight applied to the non-labeled classes in the meta-network BCE masks.
neg_weight = 1/args.num_class 

#set DNNs warm-up epochs and CPC warm-up epochs
if args.dataset=='cifar10':
    warm_up = 10
    use_meta_label = warm_up + int(0.05*args.num_epochs) 
elif args.dataset=='cifar100':
    warm_up = 30
    if args.r >= 0.8:
        # high noise ratio
        use_meta_label = warm_up + int(0.1*args.num_epochs)
        neg_weight = neg_weight*5 # history loss balance in dividemix 
    else:
        use_meta_label = warm_up + int(0.05*args.num_epochs)
else:
    # Without this branch `warm_up` / `use_meta_label` stay undefined and the
    # script only fails with a NameError once the epoch loop starts.
    raise ValueError('unsupported dataset: %r' % (args.dataset,))

# A negative value marks the GMM-ablation mode for the code that reads use_meta_label.
if args.gmm_ablation:
    use_meta_label = -1


# Noise file lives inside the dataset directory, keyed by ratio and mode.
noise_file='%s/%.1f_%s.json'%(args.data_path,args.r,args.noise_mode)

loader = dataloader.cifar_dataloader(
    args.dataset,
    r=args.r,
    noise_mode=args.noise_mode,
    batch_size=args.batch_size,
    num_workers=5,
    root_dir=args.data_path,
    log=stats_log,
    noise_file=noise_file,
)

print('| Building net')
# Two backbones followed by their two meta networks (creation order kept).
net1 = create_model()
net2 = create_model()
meta_net1 = MetaNet_Bin( 512, args.num_class).cuda()
meta_net2 = MetaNet_Bin( 512, args.num_class).cuda()
cudnn.benchmark = True

criterion = SemiLoss()

# All four optimizers share the same SGD hyper-parameters except the learning rate.
_sgd_options = dict(momentum=0.9, weight_decay=5e-4)
optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, **_sgd_options)
optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, **_sgd_options)
meta_optimizer1 = optim.SGD(meta_net1.parameters(), lr=args.meta_lr, **_sgd_options)
meta_optimizer2 = optim.SGD(meta_net2.parameters(), lr=args.meta_lr, **_sgd_options)

# Loss functions: per-sample CE, batch-mean CE, and element-wise BCE for the meta nets.
CE = nn.CrossEntropyLoss(reduction='none')
CEloss = nn.CrossEntropyLoss()
if args.noise_mode=='asym':
    conf_penalty = NegEntropy()
meta_bce = nn.BCELoss(reduction='none')

all_loss = [[],[]] # save the history of losses from two networks
best_acc = 0
best_epoch = 0
# Step counters for the TensorBoard scalars of each network / meta network.
global_iter_1 = 0
global_iter_2 = 0
global_meta_iter_1 = 0
global_meta_iter_2 = 0
# Main schedule: `warm_up` epochs of plain supervised training for both networks,
# then per epoch: score the training set with each network (eval_train), split it
# with the *other* network's scores (co-divide) and train. Every epoch ends with
# an evaluation on the test set.
for epoch in range(args.num_epochs+1): 
    # Step schedule: learning rates drop by 10x for the second half of training.
    lr=args.lr
    meta_lr = args.meta_lr
    if epoch >= args.num_epochs/2:
        lr /= 10
        meta_lr /= 10      
    for param_group in optimizer1.param_groups:
        param_group['lr'] = lr       
    for param_group in optimizer2.param_groups:
        param_group['lr'] = lr
    for param_group in meta_optimizer1.param_groups:
        param_group['lr'] = meta_lr 
    for param_group in meta_optimizer2.param_groups:
        param_group['lr'] = meta_lr          
    test_loader = loader.run('test')

    eval_loader = loader.run('eval_train_log')   

    is_warm = epoch < warm_up

    if is_warm:       
        warmup_trainloader = loader.run('warmup')
        print('Warmup Net1')
        warmup(epoch,net1,optimizer1,warmup_trainloader)    
        print('\nWarmup Net2')
        warmup(epoch,net2,optimizer2,warmup_trainloader) 
    else:
        prob1,all_loss[0],per_class_thd1, noise_trans_idx1, meta_prob1, global_meta_iter_1, meta_logits_1 =eval_train(net1,all_loss[0], stats_log, epoch, meta_net1, meta_optimizer1, global_meta_iter_1, 'net1', warm_up) 

        prob2,all_loss[1],per_class_thd2, noise_trans_idx2, meta_prob2, global_meta_iter_2, meta_logits_2 =eval_train(net2,all_loss[1], stats_log, epoch, meta_net2, meta_optimizer2, global_meta_iter_2, 'net2', warm_up)          

        if per_class_thd1 is None:
            per_class_thd = None
        else:
            per_class_thd = [(thd1+thd2)/2 for thd1, thd2 in zip(per_class_thd1, per_class_thd2)]

        # Samples whose probability exceeds the threshold form the labeled set.
        pred1 = (prob1 > args.p_threshold)      
        pred2 = (prob2 > args.p_threshold)
        meta_pred1 = (meta_prob1 > args.p_threshold)      
        meta_pred2 = (meta_prob2 > args.p_threshold)  

        stats_log.write('Net1: labeled data -> {}, labeled ratio -> {}\n'.format(pred1.sum(), pred1.sum()/pred1.shape[0]))    
        stats_log.write('Net2: labeled data -> {}, labeled ratio -> {}\n'.format(pred2.sum(), pred2.sum()/pred2.shape[0])) 
        tf_writer.add_scalar('Clean/Net1',  pred1.sum()/pred1.shape[0], epoch)
        tf_writer.add_scalar('Clean/Net2',  pred2.sum()/pred2.shape[0], epoch)
        tf_writer.add_scalar('Clean/Meta_Net1',  meta_pred1.sum()/meta_pred1.shape[0], epoch)
        tf_writer.add_scalar('Clean/Meta_Net2',  meta_pred2.sum()/meta_pred2.shape[0], epoch)

        print('Train Net1')
        # co-divide: net1 is trained on the split produced from net2's scores
        labeled_trainloader, unlabeled_trainloader = loader.run(
            'train',pred2,prob2,noise_trans_idx=noise_trans_idx2,
            eval_train_loss = all_loss[1][-1] if not args.gmm_ablation else meta_logits_2,
            tf_writer=tf_writer, epoch=epoch, model_name='net2',
            meta_pred = meta_pred2,  meta_prob = meta_prob2, use_meta_label=use_meta_label)

        global_iter_1, global_meta_iter_1 = train(epoch,net1,net2, meta_net1, optimizer1, meta_optimizer1, labeled_trainloader, unlabeled_trainloader, 'net1', global_iter_1, global_meta_iter_1) # train net1  

        print('\nTrain Net2')
        # co-divide: net2 is trained on the split produced from net1's scores
        labeled_trainloader, unlabeled_trainloader = loader.run(
            'train',pred1,prob1,noise_trans_idx=noise_trans_idx1,
            eval_train_loss = all_loss[0][-1] if not args.gmm_ablation else meta_logits_1,
            tf_writer=tf_writer, epoch=epoch, model_name='net1',
            meta_pred=meta_pred1, meta_prob=meta_prob1, use_meta_label=use_meta_label)

        global_iter_2, global_meta_iter_2 = train(epoch,net2,net1, meta_net2,optimizer2, meta_optimizer2,labeled_trainloader, unlabeled_trainloader, 'net2', global_iter_2, global_meta_iter_2) # train net2         

    acc, meta_acc = test(epoch,net1,net2, meta_net1, meta_net2, is_warm)  

    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch

    tf_writer.add_scalar('Test',  acc, epoch) 
    if not is_warm:
        tf_writer.add_scalar('Meta_Test',  meta_acc, epoch)

    print("\n| Current Best Epoch #%d\t Accuracy: %.2f%%\n" %(best_epoch,best_acc))  
    test_log.write('Current Best Epoch:%d   Accuracy:%.2f\n'%(best_epoch,best_acc))
    test_log.flush()  

# Close the log sinks so buffered TensorBoard events and stats lines reach disk.
tf_writer.close()
stats_log.close()
test_log.close()


