from __future__ import print_function
from cProfile import label
import sys
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.models as models
#from torchvision.models.resnet import ResNet 
import random
import os, shutil
import argparse
import numpy as np
import dataloader_clothing1M as dataloader
from torch.utils.tensorboard import SummaryWriter   
from utils import hardness_estimate, verbose_prob_estimate, MetaNet_Bin, prob_prototype
import pickle
parser = argparse.ArgumentParser(description='PyTorch Clothing1M Training')
parser.add_argument('--batch_size', default=32, type=int, help='train batchsize') 
parser.add_argument('--lr', '--learning_rate', default=0.002, type=float, help='initial learning rate')
parser.add_argument('--meta_lr', '--meta_learning_rate', default=0.002, type=float, help='initial meta_learning rate')
parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=80, type=int)
parser.add_argument('--id', default='clothing1m')
parser.add_argument('--data_path', default='./Data/clothing1M', type=str, help='path to dataset')
# type=int so a seed given on the command line is not passed to the seeders as a str.
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid', default=0, type=int)
parser.add_argument('--num_class', default=14, type=int)
parser.add_argument('--num_batches', default=1000, type=int)
parser.add_argument('--adapt_thd', action='store_true', default=False)
parser.add_argument('--meta_thd', default=0.3, type=float)
parser.add_argument('--use_meta_label', default=5, type=int)
parser.add_argument('--neg_weight', default=1, type=float)
# Probability with which each negative (non-label) class entry enters the
# meta-net BCE mask; read by train() and eval_train().
# NOTE(review): default 1.0 (keep every negative entry) is an assumption — confirm.
parser.add_argument('--neg_ratio', default=1.0, type=float,
                    help='fraction of negative classes sampled into the meta-net BCE mask')
parser.add_argument('--gmm_ablation', action='store_true', default=False)
# NOTE(review): --adapt_thd additionally reads args.percentile, args.thd_decay and
# args.square_decay (eval_train -> hardness_estimate); they are not declared here.
args = parser.parse_args()

if args.gmm_ablation:
    # Disable the `use_meta_label > 0` switch that replaces clean weights with eval losses.
    args.use_meta_label=-1

# Select the GPU and seed every RNG used by the script.
torch.cuda.set_device(args.gpuid)
np.random.seed(args.seed)
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Training
def train(epoch,net,net2, meta_net,optimizer,meta_optimizer,labeled_trainloader,unlabeled_trainloader, per_class_thd=None, log=None, model_name=None, global_iter=0, global_meta_iter=0):
    """Train ``net`` and its ``meta_net`` for one epoch on the co-divided data.

    ``net2`` is put in eval mode and only contributes to the co-guessed targets
    of the unlabeled split. For every labeled batch this:

    * builds sharpened soft targets for the labeled batch (given label blended
      with ``net``'s prediction by the clean weight ``w_x``) and for the
      unlabeled batch (average of both nets over two views);
    * mixes inputs/targets with a Beta(alpha, alpha) coefficient and trains
      ``net`` on the first ``2 * batch_size`` mixed samples (``Lx``) plus a
      uniform-prior penalty;
    * updates ``meta_net`` on ``net`` features computed without gradients,
      either with pseudo-label CE (``--gmm_ablation``) or a masked binary
      clean/noisy objective.

    Args:
        epoch: current epoch (gates ``args.use_meta_label``).
        net / net2: trained network and its frozen peer.
        meta_net / meta_optimizer: meta network paired with ``net`` and its optimizer.
        optimizer: optimizer over ``net``'s parameters.
        labeled_trainloader / unlabeled_trainloader: loaders of the co-divided splits.
        per_class_thd: per-class thresholds, used only with ``--adapt_thd``.
        log: optional text log for warnings.
        model_name: TensorBoard tag prefix.
        global_iter / global_meta_iter: TensorBoard step counters.

    Returns:
        The updated ``(global_iter, global_meta_iter)``.
    """
    if args.adapt_thd:
        perclass_thd = torch.Tensor(per_class_thd).cuda()

    net.train()
    meta_net.train()
    net2.eval() #fix one network and train the other

    unlabeled_train_iter = iter(unlabeled_trainloader)
    num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
    for batch_idx, (inputs_x, inputs_x2, inputs_x3, inputs_x4, labels_x, w_x, eval_loss_x) in enumerate(labeled_trainloader):
        global_iter += 1
        global_meta_iter += 1
        # Restart the unlabeled loader when it runs out before the labeled one.
        # Use the builtin next(): DataLoader iterators no longer provide the
        # Python-2 style ``.next()`` method, and only exhaustion is expected here.
        try:
            inputs_u, inputs_u2, inputs_u3, inputs_u4, labels_un, w_u, eval_loss_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2, inputs_u3, inputs_u4, labels_un, w_u, eval_loss_u = next(unlabeled_train_iter)

        batch_size = inputs_x.size(0)

        # Transform label to one-hot; the *_l variants keep the class indices.
        labels_x_l = labels_x.view(-1,1)
        labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)
        w_x = w_x.view(-1,1).type(torch.FloatTensor)

        labels_un_l = labels_un.view(-1,1)
        labels_un = torch.zeros(labels_un.size(0), args.num_class).scatter_(1, labels_un.view(-1,1), 1)
        w_u = w_u.view(-1,1).type(torch.FloatTensor)

        inputs_x, inputs_x2, inputs_x3, inputs_x4 = inputs_x.cuda(), inputs_x2.cuda(), inputs_x3.cuda(), inputs_x4.cuda()
        labels_x, w_x, eval_loss_x, labels_x_l = labels_x.cuda(), w_x.cuda(), eval_loss_x.cuda(), labels_x_l.cuda()

        inputs_u, inputs_u2, inputs_u3, inputs_u4 = inputs_u.cuda(), inputs_u2.cuda(), inputs_u3.cuda(), inputs_u4.cuda()
        labels_un, w_u, eval_loss_u, labels_un_l = labels_un.cuda(), w_u.cuda(), eval_loss_u.cuda(), labels_un_l.cuda()

        with torch.no_grad():
            # Backbone features for the meta net. They are computed without
            # gradients, so the meta loss never back-propagates into ``net``.
            fea_u11, _ = net.fea_forward(inputs_u3)
            fea_u12, _ = net.fea_forward(inputs_u4)
            fea_x1, _ = net.fea_forward(inputs_x3)
            fea_x2, _ = net.fea_forward(inputs_x4)

            # Label co-guessing of unlabeled samples: average both nets over two views.
            outputs_u11 = net(inputs_u3)
            outputs_u12 = net(inputs_u4)
            outputs_u21 = net2(inputs_u3)
            outputs_u22 = net2(inputs_u4)

            pu11, pu12, pu21, pu22 = torch.softmax(outputs_u11, dim=1), torch.softmax(outputs_u12, dim=1), torch.softmax(outputs_u21, dim=1), torch.softmax(outputs_u22, dim=1)
            pu = (pu11+pu12+pu21+pu22)/4

            ptu = pu**(1/args.T) # temperature sharpening
            targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
            targets_u = targets_u.detach()

            # Label refinement of labeled samples: net's prediction over two views.
            label_l = inputs_x.size(0)
            outputs_x = net(inputs_x3)
            outputs_x2 = net(inputs_x4)
            px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2

        # Blend the given one-hot label with the prediction according to w_x.
        px_g = w_x.view(-1,1)*labels_x + (1-w_x.view(-1,1))*px.detach()
        ptx = px_g**(1/args.T) # temperature sharpening
        targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize

        if label_l > 0: # the batch contains labeled (clean) samples
            # mixmatch
            l = np.random.beta(args.alpha, args.alpha)
            l = max(l, 1-l)

            all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
            all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

            idx = torch.randperm(all_inputs.size(0))

            input_a, input_b = all_inputs, all_inputs[idx]
            target_a, target_b = all_targets, all_targets[idx]

            # Only the first 2*label_l mixed samples (anchored on the two labeled views) are trained on.
            mixed_input = l * input_a[:label_l*2] + (1 - l) * input_b[:label_l*2]
            mixed_target = l * target_a[:label_l*2] + (1 - l) * target_b[:label_l*2]

            logits = net(mixed_input)
            Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target.detach(), dim=1))

            # regularization: KL(uniform prior || mean prediction)
            prior = torch.ones(args.num_class)/args.num_class
            prior = prior.cuda()
            pred_mean = torch.softmax(logits, dim=1).mean(0)
            penalty = torch.sum(prior*torch.log(prior/pred_mean))

            loss = Lx + penalty

            tf_writer.add_scalar(model_name+'/Loss', loss.detach(), global_iter)
            tf_writer.add_scalar(model_name+'/Lx', Lx.detach(), global_iter)
            tf_writer.add_scalar(model_name+'/penalty', penalty.detach(), global_iter)

            if args.gmm_ablation:
                # Meta net learns the argmax of eval_loss_* as a class pseudo-label.
                _, meta_output_u1_s = meta_net(fea_u11, labels_un_l)
                _, meta_output_u2_s = meta_net(fea_u12, labels_un_l)
                pseudo_label_u = torch.max(eval_loss_u, dim=1)[1]

                meta_reg_u1 = CEloss(meta_output_u1_s, pseudo_label_u.type_as(labels_un_l))
                meta_reg_u2 = CEloss(meta_output_u2_s, pseudo_label_u.type_as(labels_un_l))

                _, meta_output_x1_s = meta_net(fea_x1, labels_x_l)
                _, meta_output_x2_s = meta_net(fea_x2, labels_x_l)
                pseudo_label_x = torch.max(eval_loss_x, dim=1)[1]
                meta_reg_x1 = CEloss(meta_output_x1_s, pseudo_label_x.type_as(labels_x_l))
                meta_reg_x2 = CEloss(meta_output_x2_s, pseudo_label_x.type_as(labels_x_l))

                meta_reg = (meta_reg_u1 + meta_reg_u2 + meta_reg_x1 + meta_reg_x2)/4

                tf_writer.add_scalar('MetaNet/{}/train_loss'.format(model_name), meta_reg.detach(), global_meta_iter)

                meta_optimizer.zero_grad()
                meta_reg.backward()
                meta_optimizer.step()

            else:
                # Masked binary clean/noisy meta objective.
                labels_un_l = labels_un_l.view(-1)
                if inputs_u.size(0) > 0:
                    meta_pred_u1, meta_output_u1_s = meta_net(fea_u11, labels_un_l)
                    meta_output_u1 = torch.sigmoid(meta_output_u1_s)

                    meta_pred_u2, meta_output_u2_s = meta_net(fea_u12, labels_un_l)
                    meta_output_u2 = torch.sigmoid(meta_output_u2_s)

                    # After `use_meta_label` epochs, eval_loss_u replaces w_u as the clean score.
                    if args.use_meta_label < epoch and args.use_meta_label > 0:
                        l_u = eval_loss_u.type_as(w_u)
                    else:
                        l_u = w_u
                    # Only confidently-noisy unlabeled samples supervise the meta net, with target 0.
                    select_idx_u = l_u.view(-1) < min(args.meta_thd, args.p_threshold)
                    l_u = torch.zeros_like(w_u).cuda()
                    if args.adapt_thd:
                        # Entries whose co-guessed probability exceeds the per-class threshold become positives.
                        selected_u = (pu > perclass_thd.view(-1)).nonzero()

                        meta_pred_u_pos1 = meta_output_u1[selected_u[:,0], selected_u[:,1]]
                        meta_pred_u_pos2 = meta_output_u2[selected_u[:,0], selected_u[:,1]]
                        l_u_pos = torch.ones_like(meta_pred_u_pos1).type_as(l_u)
                        meta_pred_u1 = torch.cat([meta_pred_u1.view(-1), meta_pred_u_pos1.view(-1)])
                        meta_pred_u2 = torch.cat([meta_pred_u2.view(-1), meta_pred_u_pos2.view(-1)])
                        l_u = torch.cat([l_u.view(-1), l_u_pos.view(-1)])
                        select_idx_u = torch.cat([select_idx_u, l_u_pos])

                    meta_reg_u1 = meta_bce(meta_pred_u1.view(-1), l_u.view(-1))
                    meta_reg_u2 = meta_bce(meta_pred_u2.view(-1), l_u.view(-1))

                labels_x_l = labels_x_l.view(-1)
                if inputs_x.size(0) > 0:
                    meta_pred_x1, meta_output_x1_s = meta_net(fea_x1, labels_x_l)
                    meta_output_x1 = torch.sigmoid(meta_output_x1_s)

                    meta_pred_x2, meta_output_x2_s = meta_net(fea_x2, labels_x_l)
                    meta_output_x2 = torch.sigmoid(meta_output_x2_s)

                    if args.use_meta_label < epoch and args.use_meta_label > 0:
                        l_x = eval_loss_x.type_as(w_x)
                    else:
                        l_x = w_x

                    # Only confidently-clean labeled samples supervise the meta net.
                    select_idx_x = l_x.view(-1) > max(1-args.meta_thd, 1-args.p_threshold)

                    # Target 1 ("clean") at the given label, 0 elsewhere.
                    l_x = torch.ones_like(w_x).cuda()
                    full_t = torch.zeros(labels_x_l.size(0), args.num_class).type_as(l_x)
                    full_t = full_t.scatter(1, labels_x_l.view(-1,1), l_x.view(-1,1))

                    # Each negative entry is kept with probability neg_ratio (weight neg_weight);
                    # the label entry is always kept; unselected rows are zeroed.
                    mask = (torch.ones(l_x.size(0), args.num_class) * torch.rand(l_x.size(0), args.num_class) < args.neg_ratio).type_as(meta_output_x1) * args.neg_weight
                    mask = mask * l_x.view(-1,1)
                    mask = mask.scatter(1, labels_x_l.view(-1,1), 1)
                    mask = mask * select_idx_x.view(-1,1)

                    meta_reg_x1 = meta_bce(meta_output_x1.view(-1), full_t.view(-1))
                    meta_reg_x2 = meta_bce(meta_output_x2.view(-1), full_t.view(-1))

                # Clamp the normaliser: when nothing is selected the numerator is 0 too,
                # and an unclamped 0/0 would be NaN and corrupt meta_net's weights.
                meta_denom = (select_idx_u.sum()*2 + mask.sum()*2).clamp_min(1)
                meta_reg = ((meta_reg_u1 * select_idx_u).sum() + (meta_reg_u2 * select_idx_u).sum() + (meta_reg_x1 * mask.view(-1)).sum() + (meta_reg_x2 * mask.view(-1)).sum()) / meta_denom

                tf_writer.add_scalar('MetaNet/{}/train_loss'.format(model_name), meta_reg.detach(), global_meta_iter)

                meta_optimizer.zero_grad()
                meta_reg.backward()
                meta_optimizer.step()

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            sys.stdout.write('\r')
            sys.stdout.write('Clothing1M | Epoch [%3d/%3d] Iter[%3d/%3d]\t  Labeled loss: %.4f'
                    %(epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item()))

        else:
            print("!!!!! Warning! All clean samples are rejected!!!")
            if log is not None:
                log.write("!!!!! Warning! All clean samples are rejected!!!\n")

        sys.stdout.flush()

    return global_iter, global_meta_iter
        
    
def warmup(net,optimizer,dataloader, net2=None):
    """Run one warm-up epoch of cross-entropy training on ``net``.

    The objective is ``CEloss`` plus ``conf_penalty``, the mean negative
    entropy of the softmax, which discourages over-confident predictions.

    Args:
        net: model to train; switched to train mode here.
        optimizer: optimizer over ``net``'s parameters.
        dataloader: yields ``(inputs, labels, path, inputs2)`` batches; only
            ``inputs`` and ``labels`` are used.
        net2: unused. It is kept so existing callers still work; it belonged
            to a contrastive variant that was permanently disabled.
    """
    net.train()
    for batch_idx, (inputs, labels, _path, _inputs2) in enumerate(dataloader):
        # The second view is not needed, so it is not copied to the GPU.
        inputs, labels = inputs.cuda(), labels.cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = CEloss(outputs, labels)
        penalty = conf_penalty(outputs)
        L = loss + penalty
        L.backward()
        optimizer.step()

        sys.stdout.write('\r')
        sys.stdout.write('|Warm-up: Iter[%3d/%3d]\t CE-loss: %.4f  Conf-Penalty: %.4f'
                %(batch_idx+1, args.num_batches, loss.item(), penalty.item()))
        sys.stdout.flush()
    
def val(net,val_loader,k):
    """Evaluate ``net`` on the validation loader and checkpoint it if it is the best so far.

    Args:
        net: classifier to evaluate.
        val_loader: yields ``(inputs, targets)`` batches.
        k: 1-based network id; indexes ``best_acc[k-1]`` and names the checkpoint.

    Returns:
        ``(acc, pre, recall)``: overall accuracy in percent, plus per-class
        precision and recall as fractions. The value is 0.0 for a class that
        was never predicted (precision) or has no samples (recall).
    """
    net.eval()
    num_class = args.num_class
    correct = 0
    total = 0
    perclass_num = [0] * num_class   # ground-truth samples per class
    perclass_hit = [0] * num_class   # correct predictions per class
    perclass_est = [0] * num_class   # predictions per class
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            _, predicted = torch.max(outputs, 1)

            # One host transfer per batch instead of one .item() sync per sample.
            for t, p in zip(targets.tolist(), predicted.tolist()):
                perclass_est[p] += 1
                perclass_num[t] += 1
                if t == p:
                    perclass_hit[t] += 1

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    acc = 100.*correct/total
    # Explicit zero-count guards replace the old 0.1 smoothing, which biased
    # precision; recall previously raised ZeroDivisionError for absent classes.
    pre = [hit / est if est else 0.0 for hit, est in zip(perclass_hit, perclass_est)]
    recall = [hit / num if num else 0.0 for hit, num in zip(perclass_hit, perclass_num)]

    print("\n| Validation\t Net%d  Acc: %.2f%%" %(k,acc))
    print([[p,r] for p,r in zip(pre, recall)])

    if acc > best_acc[k-1]:
        best_acc[k-1] = acc
        print('| Saving Best Net%d ...'%k)
        save_point = './checkpoint/%s_net%d.pth.tar'%(args.id,k)
        torch.save(net.state_dict(), save_point)
    return acc, pre, recall

def test(net1,net2,test_loader):
    """Report test accuracy of net1, net2 and their logit-sum ensemble.

    Returns:
        ``(acc, acc1, acc2, perclass_acc)``. The first three are overall
        accuracies in percent for the ensemble, net1 and net2.
        ``perclass_acc`` holds three per-class accuracy lists (percent) in the
        same order; a class with no test samples reports 0.0.
    """
    num_class = args.num_class

    def perclass_count(targ, pred,per_num, per_hit, idx):
        # Accumulate per-class sample counts and correct hits for predictor ``idx``.
        for t, p in zip(targ.tolist(), pred.tolist()):
            per_num[idx][t] += 1
            if t == p:
                per_hit[idx][t] += 1
        return per_num, per_hit

    net1.eval()
    net2.eval()
    correct = 0
    total = 0
    correct1 = 0
    correct2 = 0

    # Row 0: ensemble, 1: net1, 2: net2.
    perclass_num = [[0] * num_class for _ in range(3)]
    perclass_hit = [[0] * num_class for _ in range(3)]
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs1 = net1(inputs)
            outputs2 = net2(inputs)
            outputs = outputs1+outputs2
            _, predicted = torch.max(outputs, 1)
            _, predicted1 = torch.max(outputs1, 1)
            _, predicted2 = torch.max(outputs2, 1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            correct1 += predicted1.eq(targets).sum().item()
            correct2 += predicted2.eq(targets).sum().item()

            perclass_num, perclass_hit = perclass_count(targets, predicted,perclass_num, perclass_hit, 0)
            perclass_num, perclass_hit = perclass_count(targets, predicted1,perclass_num, perclass_hit, 1)
            perclass_num, perclass_hit = perclass_count(targets, predicted2,perclass_num, perclass_hit, 2)

    # Guard classes absent from the test set (previously ZeroDivisionError).
    perclass_acc = [[100.*hit/num if num else 0.0 for hit, num in zip(perclass_hit[i], perclass_num[i])] for i in range(len(perclass_num))]
    for i,Acc in enumerate(perclass_acc):
        print("Pred {}".format(i))
        print(Acc)

    acc = 100.*correct/total
    acc1 = 100.*correct1/total
    acc2 = 100.*correct2/total
    print("\n| Test Acc: %.2f, %.2f, %.2f\n" %(acc, acc1, acc2))
    return acc,acc1,acc2,perclass_acc
    
def eval_train(epoch,model, log, meta_net, meta_optimizer, global_meta_iter, net_name):
    """Score the training set with ``model``, re-estimate clean probabilities, and train ``meta_net``.

    The function makes two passes over the module-level ``eval_loader``.

    1. With gradients disabled, it collects per-sample CE losses, features and
       labels. It also collects softmax outputs (``--adapt_thd``) or averaged
       meta/classifier class distributions (``--gmm_ablation``). From these it
       estimates a clean probability per sample: ``prob_prototype`` under
       ``--gmm_ablation``, otherwise ``verbose_prob_estimate`` on the min-max
       normalised losses.
    2. It replays the cached features batch by batch to train ``meta_net``,
       recording the meta net's prediction for every sample before each update.

    Args:
        epoch: current epoch (the adaptive threshold is only estimated after epoch 5).
        model: network to evaluate (must provide ``fea_forward``).
        log: text log passed to the estimation helpers.
        meta_net / meta_optimizer: meta network and its optimizer.
        global_meta_iter: TensorBoard step counter for meta losses.
        net_name: TensorBoard tag.

    Returns:
        ``(prob, paths, per_class_thd, None, None, None, aux, meta_prob,
        global_meta_iter)``. ``aux`` is the meta logits under
        ``--gmm_ablation`` and the normalised losses otherwise.
    """
    model.eval()
    meta_net.eval()

    num_samples = args.num_batches*args.batch_size
    losses = torch.zeros(num_samples)
    labels = torch.zeros(num_samples)
    if args.gmm_ablation:
        meta_logits = torch.zeros(num_samples, args.num_class)
        meta_logits_list = []
    if args.adapt_thd:
        logits = torch.zeros(num_samples,args.num_class)
    paths = []
    fea_list = []
    label_list = []
    index_list = []
    n = 0

    # ---- Pass 1: score every training sample (no gradients) ----
    with torch.no_grad():
        for batch_idx, (inputs, targets, path) in enumerate(eval_loader):
            inputs, targets = inputs.cuda(), targets.cuda()
            fea, outputs = model.fea_forward(inputs)
            bs = inputs.size(0)

            if args.gmm_ablation:
                _, meta_logits_batch = meta_net(fea, targets)
                # Average the meta net's and the classifier's class distributions.
                meta_logits_batch = (torch.softmax(meta_logits_batch,dim=-1) + torch.softmax(outputs,dim=-1))/2
                meta_logits_batch = meta_logits_batch.cpu()
                meta_logits_list.append(meta_logits_batch)
                meta_logits[n:n + bs] = meta_logits_batch
            # Cache features/labels so pass 2 can train the meta net without re-running the backbone.
            fea_list.append(fea.cpu().numpy())
            label_list.append(targets.cpu().numpy())

            loss = CE(outputs, targets)

            # Batched device->host copies: one sync per batch rather than per sample.
            losses[n:n + bs] = loss.cpu()
            labels[n:n + bs] = targets.cpu()
            if args.adapt_thd:
                logits[n:n + bs] = torch.softmax(outputs, dim=-1).cpu()
            paths.extend(path[b] for b in range(bs))
            index_list.append(list(range(n, n + bs)))
            n += bs

            sys.stdout.write('\r')
            sys.stdout.write('| Evaluating loss Iter %3d\t' %(batch_idx))
            sys.stdout.flush()

    # Min-max normalise the losses; the clamp avoids 0/0 when all losses are equal.
    loss_min = losses.min()
    losses = (losses - loss_min) / (losses.max() - loss_min).clamp_min(1e-12)
    losses = losses.reshape(-1,1)

    if args.gmm_ablation:
        # clean probability from predicted label and prototypes
        prob = prob_prototype(meta_logits, labels)
    else:
        prob, _, _ = verbose_prob_estimate(losses, labels,  args.num_class, log, True)

    if args.adapt_thd:
        if epoch > 5:
            per_class_thd, trans_noise_idx = hardness_estimate(logits, prob> args.p_threshold, args.num_class, args.percentile, log, epoch, args.num_epochs,args.thd_decay, args.square_decay)
        else:
            per_class_thd = [1.0] * args.num_class
        per_class_thd = torch.Tensor(per_class_thd).type_as(losses).cuda()
    else:
        per_class_thd = None

    # ---- Pass 2: train the meta net on the cached features ----
    meta_net.train()
    meta_prob = torch.zeros(num_samples)
    for i in range(len(fea_list)):
        fea = torch.Tensor(fea_list[i]).cuda()
        l = torch.LongTensor(label_list[i]).cuda()

        if args.gmm_ablation:
            batch_logits = meta_logits_list[i].cuda()
            pseudo_label = torch.max(batch_logits, dim=1)[1]
            meta_pred, meta_outputs_s = meta_net(fea, l.view(-1))
            meta_loss = CEloss(meta_outputs_s, pseudo_label.type_as(l))
        else:
            batch_prob = prob[index_list[i]]
            batch_t = torch.Tensor(batch_prob).cuda()

            if args.adapt_thd:
                batch_logits = logits[index_list[i]].cuda()

            # Only confidently clean or confidently noisy samples supervise the meta net.
            select_idx = torch.logical_or(batch_t<args.meta_thd, batch_t>(1-args.meta_thd))
            # Hard clean(1)/noisy(0) target from the estimated clean probability.
            t = torch.Tensor(np.where(batch_prob > args.p_threshold, np.ones_like(batch_prob), np.zeros_like(batch_prob))).cuda()

            meta_pred, meta_outputs_s = meta_net(fea, l.view(-1))
            meta_outputs = torch.sigmoid(meta_outputs_s)

            # Target t at the given label, 0 elsewhere.
            full_t = torch.zeros(t.size(0), args.num_class).type_as(t)
            full_t = full_t.scatter(1, l.view(-1,1), t.view(-1,1))
            # Negatives of clean samples are kept with probability neg_ratio (weight neg_weight);
            # the label entry is always kept.
            mask = (torch.ones(t.size(0), args.num_class) * torch.rand(t.size(0), args.num_class) < args.neg_ratio).type_as(meta_outputs) * args.neg_weight
            mask = mask * t.view(-1,1)
            mask = mask.scatter(1, l.view(-1,1), 1)

            if args.adapt_thd:
                full_t[batch_logits > per_class_thd.view(-1)] = 1
                mask[batch_logits > per_class_thd.view(-1)] = 1

            mask = mask * select_idx.view(-1,1)

            meta_loss = meta_bce(meta_outputs.view(-1), full_t.view(-1))
            # Clamp: with no selected sample the numerator is 0 too; 0/0 would be NaN.
            meta_loss = (meta_loss * mask.view(-1)).sum() / mask.sum().clamp_min(1)

        # Record the meta prediction (before this update) in both modes; previously
        # the gmm_ablation path returned an all-zero meta_prob.
        meta_prob[index_list[i]] = meta_pred.detach().reshape(-1).cpu()

        tf_writer.add_scalar('MetaNet/{}/train_loss'.format(net_name), meta_loss.item(), global_meta_iter)
        global_meta_iter += 1

        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()

    if args.gmm_ablation:
        return prob, paths, per_class_thd, None,  None, None, meta_logits.numpy(), meta_prob.numpy(), global_meta_iter
    elif args.adapt_thd:
        return prob, paths, per_class_thd.cpu().numpy(), None,  None, None, losses.cpu().numpy(), meta_prob.numpy(), global_meta_iter
    else:
        return prob, paths, per_class_thd, None,  None, None, losses.cpu().numpy(), meta_prob.numpy(), global_meta_iter



class NCELoss(object):
    """InfoNCE-style contrastive loss.

    Every query in ``q`` is scored against every key in ``k`` by dot product.
    The scores are divided by ``temperature`` and cross-entropy is applied,
    with query ``i``'s positive taken to be key ``i``.
    """

    def __init__(self, temperature=0.07):
        self.temperature = temperature

    def __call__(self,q,k):
        scores = torch.einsum('nc,kc->nk', [q, k]) / self.temperature
        # Build the targets on q's device. The old .cuda() failed on CPU and
        # picked the wrong GPU when q lived on a non-current device.
        labels = torch.arange(q.size(0), device=q.device)
        return F.cross_entropy(scores, labels)
    
class NegEntropy(object):
    """Batch-mean negative entropy of the softmax over dim 1 (confidence penalty).

    Adding this to a loss rewards higher-entropy predictions, which
    discourages over-confident outputs.
    """
    def __call__(self,outputs):
        # log_softmax keeps log-probabilities finite. probs.log() returned
        # -inf for underflowed probabilities, and 0 * -inf = NaN.
        log_probs = F.log_softmax(outputs, dim=1)
        probs = log_probs.exp()
        return torch.mean(torch.sum(probs * log_probs, dim=1))


class MyResNet50(nn.Module):
    """torchvision ResNet-50 whose final ``fc`` layer has ``class_num`` outputs.

    ``fea_forward`` also returns the 2048-d pooled feature that feeds the
    classifier head.
    """

    def __init__(self, pretrained=True, class_num=14):
        super().__init__()
        backbone = models.resnet50(pretrained=pretrained)
        backbone.fc = nn.Linear(2048, class_num)
        self.model = backbone

    def forward(self, x):
        return self.model(x)

    def fea_forward(self, x):
        """Return ``(pooled_feature, logits)`` for the input batch ``x``."""
        net = self.model
        trunk = (net.conv1, net.bn1, net.relu, net.maxpool,
                 net.layer1, net.layer2, net.layer3, net.layer4,
                 net.avgpool)
        for stage in trunk:
            x = stage(x)
        fea = torch.flatten(x, 1)
        return fea, net.fc(fea)

        

def create_model():
    """Build a pretrained ``MyResNet50`` with ``args.num_class`` outputs on the current CUDA device."""
    return MyResNet50(pretrained=True, class_num=args.num_class).cuda()

# Records where this run's TensorBoard directory and text log live.
f = open('log_path.txt', 'w+')

# Start each run with an empty TensorBoard directory for this id. makedirs also
# creates ./checkpoint and ./checkpoint/clothing1m on a fresh checkout, where
# the old os.mkdir raised FileNotFoundError.
tb_dir = './checkpoint/clothing1m/%s' % args.id
if os.path.exists(tb_dir):
    shutil.rmtree(tb_dir)
os.makedirs(tb_dir, exist_ok=True)
tf_writer =  SummaryWriter(tb_dir)
f.write('-r ./checkpoint/clothing1m/%s\n'%args.id)

log=open('./checkpoint/%s.txt'%args.id,'w')

log.flush()
f.write('-f ./checkpoint/%s.txt\n'%args.id)


# One loader factory serves the warm-up, train, eval_train, val and test splits.
loader = dataloader.clothing_dataloader(root=args.data_path,batch_size=args.batch_size,num_workers=5,num_batches=args.num_batches)

print('| Building net')

# Two peer classifiers plus one meta network per classifier. Construction
# order (net1, net2, meta_net1, meta_net2) is kept so random initialisation
# draws from the RNG in the same sequence.
net1 = create_model()
net2 = create_model()
meta_net1 = MetaNet_Bin( 2048, args.num_class).cuda()
meta_net2 = MetaNet_Bin( 2048, args.num_class).cuda()
cudnn.benchmark = True


def _make_sgd(module, lr):
    """SGD with the momentum and weight decay shared by every model in this script."""
    return optim.SGD(module.parameters(), lr=lr, momentum=0.9, weight_decay=1e-3)


optimizer1, optimizer2 = _make_sgd(net1, args.lr), _make_sgd(net2, args.lr)
meta_optimizer1, meta_optimizer2 = _make_sgd(meta_net1, args.meta_lr), _make_sgd(meta_net2, args.meta_lr)

CE = nn.CrossEntropyLoss(reduction='none')    # per-sample losses (eval_train)
CEloss = nn.CrossEntropyLoss()                 # batch-mean loss
conf_penalty = NegEntropy()
nce_caller = NCELoss()
meta_bce = nn.BCELoss(reduction='none')        # per-element; callers apply their own mask
best_acc = [0,0]                               # best val accuracy of net1 / net2, updated by val()
global_iter_1 = global_iter_2 = 0
global_meta_iter_1 = global_meta_iter_2 = 0

for epoch in range(args.num_epochs+1):
    # Step decay: classifier and meta-net learning rates drop 10x from epoch 40.
    lr=args.lr
    meta_lr = args.meta_lr
    if epoch >= 40:
        lr /= 10
        meta_lr /= 10
    for param_group in optimizer1.param_groups:
        param_group['lr'] = lr
    for param_group in optimizer2.param_groups:
        param_group['lr'] = lr
    for param_group in meta_optimizer1.param_groups:
        param_group['lr'] = meta_lr
    for param_group in meta_optimizer2.param_groups:
        param_group['lr'] = meta_lr

    if epoch<1:     # DNNs warm up
        train_loader = loader.run('warmup')
        print('Warmup Net1')
        warmup(net1,optimizer1,train_loader,net2)
        train_loader = loader.run('warmup')
        print('\nWarmup Net2')
        warmup(net2,optimizer2,train_loader,net1)
    else:
        # prob*/meta_prob*/paths*/loss*/per_class_thd* come from the previous epoch's eval_train.
        pred1 = (prob1 > args.p_threshold)  # divide dataset
        pred2 = (prob2 > args.p_threshold)

        meta_pred1 = (meta_prob1 > args.p_threshold)
        meta_pred2 = (meta_prob2 > args.p_threshold)

        # Each entry is newline-terminated so log lines do not run together.
        log.write('Net1: labeled data -> {}, labeled ratio -> {}\n'.format(pred1.sum(), pred1.sum()/pred1.shape[0]))
        tf_writer.add_scalar('Clean_Net1',  pred1.sum()/pred1.shape[0], epoch)

        log.write('Net2: labeled data -> {}, labeled ratio -> {}\n'.format(pred2.sum(), pred2.sum()/pred2.shape[0]))
        tf_writer.add_scalar('Clean_Net2',  pred2.sum()/pred2.shape[0], epoch)

        log.write('Net1: meta labeled data -> {}, meta labeled ratio -> {}\n'.format(meta_pred1.sum(), meta_pred1.sum()/meta_pred1.shape[0]))
        tf_writer.add_scalar('Clean/Meta_Net1',  meta_pred1.sum()/meta_pred1.shape[0], epoch)

        log.write('Net2: meta labeled data -> {}, meta labeled ratio -> {}\n'.format(meta_pred2.sum(), meta_pred2.sum()/meta_pred2.shape[0]))
        tf_writer.add_scalar('Clean/Meta_Net2',  meta_pred2.sum()/meta_pred2.shape[0], epoch)

        # Co-divide: each net is trained on the split produced by its peer.
        print('\n\nTrain Net1')
        labeled_trainloader, unlabeled_trainloader = loader.run('train',pred2,prob2,paths=paths2,  eval_train_loss=loss2, epoch=epoch, meta_pred = meta_pred2,  meta_prob = meta_prob2, use_meta_label=args.use_meta_label)
        global_iter_1, global_meta_iter_1 = train(epoch,net1,net2, meta_net1,optimizer1, meta_optimizer1,labeled_trainloader, unlabeled_trainloader, per_class_thd2,log,model_name='net1',global_iter= global_iter_1, global_meta_iter= global_meta_iter_1)
        print('\nTrain Net2')
        labeled_trainloader, unlabeled_trainloader = loader.run('train',pred1,prob1,paths=paths1,  eval_train_loss=loss1,epoch=epoch, meta_pred=meta_pred1, meta_prob=meta_prob1, use_meta_label=args.use_meta_label)
        global_iter_2, global_meta_iter_2 =train(epoch,net2,net1, meta_net2,optimizer2,meta_optimizer2,labeled_trainloader, unlabeled_trainloader,per_class_thd1,log, model_name='net2',global_iter=global_iter_2, global_meta_iter=global_meta_iter_2)

    val_loader = loader.run('val') # validation
    acc1, pre_val1, recall_val1 = val(net1,val_loader,1)
    acc2, pre_val2, recall_val2 = val(net2,val_loader,2)
    tf_writer.add_scalar('Val/Net1_Acc', acc1,  epoch)
    tf_writer.add_scalar('Val/Net2_Acc', acc2,  epoch)

    log.write('Validation Epoch:%d      Acc1:%.2f  Acc2:%.2f\n'%(epoch,acc1,acc2))
    log.write("\n| Validation\t Net 1  Perclass Acc: %s" %(',\n'.join([str(v) for v in pre_val1])))
    log.write("\n| Validation\t Net 2  Perclass Acc: %s" %(',\n'.join([str(v) for v in pre_val2])))
    log.flush()

    # Re-score the training set; the results drive the next epoch's co-division.
    print('\n==== net 1 evaluate next epoch training data loss ====')
    eval_loader = loader.run('eval_train')  # global read inside eval_train
    prob1,paths1,per_class_thd1, _, _, _, loss1, meta_prob1,global_meta_iter_1 = eval_train(epoch,net1,log, meta_net=meta_net1,meta_optimizer=meta_optimizer1,net_name='net1', global_meta_iter=global_meta_iter_1)
    print('\n==== net 2 evaluate next epoch training data loss ====')
    eval_loader = loader.run('eval_train')
    prob2,paths2,per_class_thd2, _,_, _, loss2, meta_prob2,global_meta_iter_2 = eval_train(epoch,net2,log, meta_net=meta_net2, meta_optimizer=meta_optimizer2,net_name='net2', global_meta_iter=global_meta_iter_2)



# Evaluate the best-on-validation checkpoints of both nets on the test split.
test_loader = loader.run('test')
net1.load_state_dict(torch.load('./checkpoint/%s_net1.pth.tar'%args.id))
net2.load_state_dict(torch.load('./checkpoint/%s_net2.pth.tar'%args.id))
acc, acc1, acc2, perclass_acc = test(net1,net2,test_loader)

log.write('Test Accuracy:%.2f, %.2f, %.2f\n'%(acc, acc1, acc2))
log.write('Perclass Test Accuracy Enssemble:%s\n'%(', \n'.join([str(v) for v in perclass_acc[0]])))
log.write('Perclass Test Accuracy 1:%s\n'%(', \n'.join([str(v) for v in perclass_acc[1]])))
log.write('Perclass Test Accuracy 2:%s\n'%(', \n'.join([str(v) for v in perclass_acc[2]])))
log.flush()

# Flush pending TensorBoard events and release the file handles opened at startup.
tf_writer.close()
log.close()
f.close()


