from __future__ import print_function
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import random
import os, shutil
import sys
import argparse
import numpy as np
from InceptionResNetV2 import *
import torchvision.models as models
from sklearn.mixture import GaussianMixture
import dataloader_webvision as dataloader
import torchnet
import torch.multiprocessing as mp
#import multiprocess as mp
from utils import verbose_prob_estimate, MetaNet_Bin, hardness_estimate, prob_prototype
from torch.utils.tensorboard import SummaryWriter 
import pickle as pkl
from torch.autograd import Variable
import copy
# Command-line configuration for the two-network WebVision training script.
parser = argparse.ArgumentParser(description='PyTorch WebVision Parallel Training')
parser.add_argument('--batch_size', default=32, type=int, help='train batchsize') 
parser.add_argument('--lr', '--learning_rate', default=0.01, type=float, help='initial learning rate')
parser.add_argument('--meta_lr', '--meta_learning_rate', default=0.01, type=float, help='initial meta_learning rate')
parser.add_argument('--alpha', default=0.5, type=float, help='parameter for Beta')
parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
parser.add_argument('--p_threshold', default=0.5, type=float, help='clean probability threshold')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--num_epochs', default=100, type=int)
parser.add_argument('--id', default='',type=str)
# type=int: without it a seed passed on the command line stays a string,
# which torch.manual_seed() does not accept.
parser.add_argument('--seed', default=123, type=int)
parser.add_argument('--gpuid1', default=0, type=int)
parser.add_argument('--gpuid2', default=1, type=int)
parser.add_argument('--num_class', default=50, type=int)
parser.add_argument('--data_path', default='./Data/webvision/', type=str, help='path to dataset')
parser.add_argument('--meta_thd', default=0.5, type=float)
parser.add_argument('--use_meta_label', default=5, type=int)
parser.add_argument('--gmm_ablation', action='store_true', default=False)
args = parser.parse_args()

# The ablation mode disables the epoch-based switch to meta labels.
if args.gmm_ablation:
    args.use_meta_label=-1

# Only the two requested GPUs are made visible, so inside this process they
# are always addressed as cuda:0 and cuda:1 regardless of their physical ids.
os.environ["CUDA_VISIBLE_DEVICES"] = '%s,%s'%(args.gpuid1,args.gpuid2)
random.seed(args.seed)
cuda1 = torch.device('cuda:0')
cuda2 = torch.device('cuda:1')

# Training
def train(epoch,net,net2,optimizer,labeled_trainloader,unlabeled_trainloader,device,model_name, global_iter=0, meta_net=None, meta_optimizer=None, global_meta_iter=0, perclass_thd=None, queue=None):
    """Run one training pass of ``net`` and ``meta_net`` over the labeled loader.

    For every labeled batch an unlabeled batch is drawn too (the unlabeled
    loader is restarted when it runs out).  ``net`` is updated on a mixed-input
    soft-label loss plus a uniform-prior penalty; ``net2`` stays in eval mode
    and only contributes predictions for the unlabeled targets.  ``meta_net``
    is then updated on features of both batches that were extracted under
    ``torch.no_grad()``.

    Args:
        epoch: current epoch (used for logging and the ``args.use_meta_label``
            switch).
        net: network being trained.
        net2: peer network, kept fixed in eval mode.
        optimizer: optimizer for ``net``.
        labeled_trainloader: loader yielding 7-tuples
            ``(x1, x2, x3, x4, label, w, eval_loss)``.
        unlabeled_trainloader: loader yielding tuples of the same layout.
        device: device every batch tensor is moved to.
        model_name: tag used in scalar keys and console output.
        global_iter: running step counter for the ``net`` scalars.
        meta_net: meta network being trained (must be given despite the
            ``None`` default; it is used unconditionally).
        meta_optimizer: optimizer for ``meta_net``.
        global_meta_iter: running step counter for the meta-net scalars.
        perclass_thd: per-class thresholds; converted to a tensor on ``device``.
        queue: receives ``(global_iter, global_meta_iter, tf_list)`` when the
            pass is finished, where ``tf_list`` holds ``[key, value, step]``
            entries for the summary writer.
    """
    
    perclass_thd = torch.Tensor(perclass_thd).to(device)
    CEloss = nn.CrossEntropyLoss()
    tf_list = []
    net.train()
    net2.eval() #fix one network and train the other
    meta_net.train()

    meta_bce = nn.BCELoss(reduction='none')

    unlabeled_train_iter = iter(unlabeled_trainloader)    
    num_iter = (len(labeled_trainloader.dataset)//args.batch_size)+1
    for batch_idx, (inputs_x, inputs_x2, inputs_x3, inputs_x4, labels_x, w_x, eval_loss_x) in enumerate(labeled_trainloader):      
        # Use the builtin next(): loader iterators no longer expose a .next()
        # method.  Only exhaustion restarts the loader; other errors propagate.
        try:
            inputs_u, inputs_u2, inputs_u3, inputs_u4, labels_un, w_u, eval_loss_u = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2, inputs_u3, inputs_u4, labels_un, w_u, eval_loss_u = next(unlabeled_train_iter)

        global_iter += 1                    
        batch_size = inputs_x.size(0)
        
        # Transform label to one-hot
        labels_x_l = labels_x.view(-1,1)
        labels_x = torch.zeros(batch_size, args.num_class).scatter_(1, labels_x.view(-1,1), 1)        
        w_x = w_x.view(-1,1).type(torch.FloatTensor) 

        labels_un_l = labels_un.view(-1,1)
        labels_un = torch.zeros(labels_un.size(0), args.num_class).scatter_(1, labels_un.view(-1,1), 1)       
        w_u = w_u.view(-1,1).type(torch.FloatTensor)

        # Move the labeled batch to the training device.
        inputs_x = inputs_x.to(device,non_blocking=True)
        inputs_x2 = inputs_x2.to(device,non_blocking=True)
        inputs_x3 = inputs_x3.to(device,non_blocking=True)
        inputs_x4 = inputs_x4.to(device,non_blocking=True)
        labels_x_l = labels_x_l.to(device,non_blocking=True)
        labels_x = labels_x.to(device,non_blocking=True)
        w_x = w_x.to(device,non_blocking=True)
        eval_loss_x = eval_loss_x.to(device,non_blocking=True)

        # Move the unlabeled batch to the training device.
        inputs_u = inputs_u.to(device,non_blocking=True)
        inputs_u2 = inputs_u2.to(device,non_blocking=True)
        inputs_u3 = inputs_u3.to(device,non_blocking=True)
        inputs_u4 = inputs_u4.to(device,non_blocking=True)
        labels_un_l = labels_un_l.to(device,non_blocking=True)
        labels_un = labels_un.to(device,non_blocking=True)
        w_u = w_u.to(device,non_blocking=True)
        eval_loss_u = eval_loss_u.to(device,non_blocking=True)

        with torch.no_grad():
            # Features for the meta network; computed without gradients so the
            # meta loss below never back-propagates into ``net``.
            fea_u11 = net.features(inputs_u3)
            fea_u12 = net.features(inputs_u4)
            fea_x1 = net.features(inputs_x3)
            fea_x2 = net.features(inputs_x4) 

            # label co-guessing of unlabeled samples
            outputs_u11 = net(inputs_u)
            outputs_u12 = net(inputs_u2)
            outputs_u21 = net2(inputs_u)
            outputs_u22 = net2(inputs_u2)            
            
            pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1) + torch.softmax(outputs_u21, dim=1) + torch.softmax(outputs_u22, dim=1)) / 4 

            ptu = pu**(1/args.T) # temperature sharpening
            
            targets_u = ptu / ptu.sum(dim=1, keepdim=True) # normalize
            targets_u = targets_u.detach()       
            
            # label refinement of labeled samples
            outputs_x = net(inputs_x)
            outputs_x2 = net(inputs_x2)            
            
            px = (torch.softmax(outputs_x, dim=1) + torch.softmax(outputs_x2, dim=1)) / 2
            px = w_x*labels_x + (1-w_x)*px              
            ptx = px**(1/args.T) # temperature sharpening 
                       
            targets_x = ptx / ptx.sum(dim=1, keepdim=True) # normalize           
            targets_x = targets_x.detach()       
        
        # mixmatch
        l = np.random.beta(args.alpha, args.alpha)        
        l = max(l, 1-l)

        all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

        idx = torch.randperm(all_inputs.size(0))

        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        # Only the first 2*batch_size rows (the two labeled views) are kept as
        # the "a" side; their mixing partners come from the whole pool.
        mixed_input = l * input_a[:batch_size*2] + (1 - l) * input_b[:batch_size*2]        
        mixed_target = l * target_a[:batch_size*2] + (1 - l) * target_b[:batch_size*2]
                
        logits = net(mixed_input)
        
        Lx = -torch.mean(torch.sum(F.log_softmax(logits, dim=1) * mixed_target, dim=1))
        
        # KL-style penalty pulling the batch-mean prediction towards uniform.
        prior = torch.ones(args.num_class)/args.num_class
        prior = prior.to(device)        
        pred_mean = torch.softmax(logits, dim=1).mean(0)
        penalty = torch.sum(prior*torch.log(prior/pred_mean))
       
        loss = Lx + penalty
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tf_list.append([model_name+'/Loss', loss.item(), global_iter])
        tf_list.append([model_name+'/Lx', Lx.item(), global_iter])
        tf_list.append([model_name+'/penalty', penalty.item(), global_iter])
        
        sys.stdout.write('\n')
        sys.stdout.write('%s |%s Epoch [%3d/%3d] Iter[%4d/%4d]\t Labeled loss: %.2f'
                %(args.id, model_name, epoch, args.num_epochs, batch_idx+1, num_iter, Lx.item()))
        sys.stdout.flush()

        del all_inputs
        del mixed_input 
        del mixed_target

        if  args.gmm_ablation:
            # Ablation: train the meta net with cross-entropy against the
            # arg-max of the per-sample ``eval_loss`` vectors from the loader.
            _, meta_output_u1_s = meta_net(fea_u11, labels_un_l)
            _, meta_output_u2_s = meta_net(fea_u12, labels_un_l) 
            pseudo_label_u = torch.max(eval_loss_u, dim=1)[1] 

            meta_reg_u1 = CEloss(meta_output_u1_s, pseudo_label_u.type_as(labels_un_l))
            meta_reg_u2 = CEloss(meta_output_u2_s, pseudo_label_u.type_as(labels_un_l)) 

            _, meta_output_x1_s = meta_net(fea_x1, labels_x_l) 
            _, meta_output_x2_s = meta_net(fea_x2, labels_x_l)
            pseudo_label_x = torch.max(eval_loss_x, dim=1)[1]
            meta_reg_x1 = CEloss(meta_output_x1_s, pseudo_label_x.type_as(labels_x_l))
            meta_reg_x2 = CEloss(meta_output_x2_s, pseudo_label_x.type_as(labels_x_l)) 

            meta_reg = (meta_reg_u1 + meta_reg_u2 + meta_reg_x1 + meta_reg_x2)/4 

            tf_list.append(['MetaNet/{}/train_loss'.format(model_name), meta_reg.item(), global_meta_iter] )
            meta_optimizer.zero_grad()
            meta_reg.backward()
            meta_optimizer.step()  
            global_meta_iter += 1 
        
        else:
            # ---- unlabeled batch: binary target 0 for the given label ----
            labels_un_l = labels_un_l.view(-1)
            
            meta_pred_u1, meta_output_u1_s = meta_net(fea_u11, labels_un_l) 
            meta_pred_u2, meta_output_u2_s = meta_net(fea_u12, labels_un_l)
            meta_output_u1 = torch.sigmoid(meta_output_u1_s)
            meta_output_u2 = torch.sigmoid(meta_output_u2_s) 
            if args.use_meta_label < epoch and args.use_meta_label > 0:
                l_u = eval_loss_u.type_as(w_u)
            else:
                l_u = w_u
            # Keep only samples whose score is confidently low.
            select_idx_u = l_u.view(-1) < min(args.meta_thd, args.p_threshold)

            l_u = torch.zeros_like(w_u).to(device)

            # Classes whose guessed probability exceeds the per-class threshold
            # are appended as extra positive (target 1) entries.
            selected_u = (targets_u > perclass_thd.view(-1)).nonzero()
            meta_pred_u_pos1 = meta_output_u1[selected_u[:,0], selected_u[:,1]]
            meta_pred_u_pos2 = meta_output_u2[selected_u[:,0], selected_u[:,1]]
            l_u_pos = torch.ones_like(meta_pred_u_pos1).type_as(l_u)
            meta_pred_u1 = torch.cat([meta_pred_u1.view(-1), meta_pred_u_pos1.view(-1)]) 
            meta_pred_u2 = torch.cat([meta_pred_u2.view(-1), meta_pred_u_pos2.view(-1)]) 
            l_u = torch.cat([l_u.view(-1), l_u_pos.view(-1)])
            select_idx_u = torch.cat([select_idx_u, l_u_pos])

            meta_reg_u1 = meta_bce(meta_pred_u1.view(-1), l_u.view(-1))
            meta_reg_u2 = meta_bce(meta_pred_u2.view(-1), l_u.view(-1))       
                
            # ---- labeled batch: binary target 1 at the given label ----
            labels_x_l = labels_x_l.view(-1)
                
            meta_pred_x1, meta_output_x1_s = meta_net(fea_x1, labels_x_l)
            meta_output_x1 = torch.sigmoid(meta_output_x1_s) 
            meta_pred_x2, meta_output_x2_s = meta_net(fea_x2, labels_x_l) 
            meta_output_x2 = torch.sigmoid(meta_output_x2_s)

            if args.use_meta_label < epoch and args.use_meta_label > 0:
                l_x = eval_loss_x.type_as(w_x)
            else:
                l_x = w_x
            # Keep only samples whose score is confidently high.
            select_idx_x = l_x.view(-1) > max(1-args.meta_thd, 1-args.p_threshold)

            l_x = torch.ones_like(w_x).to(device)

            full_t = torch.zeros(labels_x_l.size(0), args.num_class).type_as(l_x)
            full_t = full_t.scatter(1, labels_x_l.view(-1,1), l_x.view(-1,1))
            # Weight 1 on the labeled class, 0.1 on every other class,
            # and 0 for rows that were not selected.
            mask = (torch.ones(l_x.size(0), args.num_class)).type_as(meta_output_x1) * 0.1  
                        
            mask = mask * l_x.view(-1,1) 
            mask = mask.scatter(1, labels_x_l.view(-1,1), 1)
            mask = mask * select_idx_x.view(-1,1)     

            meta_reg_x1 = meta_bce(meta_output_x1.view(-1), full_t.view(-1))
                        
            meta_reg_x2 = meta_bce(meta_output_x2.view(-1), full_t.view(-1))

            # Weighted mean over all selected entries of both views.
            meta_reg = ((meta_reg_u1 * select_idx_u).sum() + (meta_reg_u2 * select_idx_u).sum() + (meta_reg_x1 * mask.view(-1)).sum() + (meta_reg_x2 * mask.view(-1)).sum() )/(select_idx_u.sum()*2 + mask.sum()*2) 
                            
            tf_list.append(['MetaNet/{}/train_loss'.format(model_name), meta_reg.item(), global_meta_iter] )
            meta_optimizer.zero_grad()
            meta_reg.backward()
            meta_optimizer.step()  
            global_meta_iter += 1

    queue.put((global_iter,global_meta_iter, tf_list))


def warmup(epoch,net,optimizer,dataloader,device,whichnet):
    """Train ``net`` for one pass over ``dataloader`` with plain cross-entropy.

    Args:
        epoch: current epoch, only used in the progress line.
        net: network to train (switched to train mode).
        optimizer: optimizer stepping ``net``'s parameters.
        dataloader: loader yielding ``(inputs, labels, path)`` batches.
        device: device the batches are moved to.
        whichnet: tag printed in the progress line.
    """
    CEloss = nn.CrossEntropyLoss()

    net.train()
    # Iteration count shown in the progress line.
    num_iter = (len(dataloader.dataset)//dataloader.batch_size)+1
    for batch_idx, (inputs, labels, path) in enumerate(dataloader):      
        inputs, labels = inputs.to(device), labels.to(device,non_blocking=True) 
        optimizer.zero_grad()
        outputs = net(inputs)               
        loss = CEloss(outputs, labels)   

        loss.backward()  
        optimizer.step() 

        sys.stdout.write('\n')
        sys.stdout.write('%s |%s  Epoch [%3d/%3d] Iter[%4d/%4d]\t CE-loss: %.4f'
                %(args.id, whichnet, epoch, args.num_epochs, batch_idx+1, num_iter, loss.item()))
        sys.stdout.flush()

        
def test(epoch,net1,net2,test_loader,device,queue, meta_net1=None, meta_net2=None, is_warm=True):
    """Evaluate both networks (and, after warm-up, both meta networks).

    Top-1/top-5 accuracy is tracked for the summed outputs of the two networks
    and for each network alone.  When ``is_warm`` is False the same three
    meters are kept for the softmax outputs of the meta networks.

    The result is delivered through ``queue`` as a 6-tuple
    ``(acc, acc1, acc2, meta_acc, meta_acc1, meta_acc2)``; the last three
    entries are ``None`` during warm-up.
    """
    def _fresh_meter():
        # One top-1/top-5 accuracy meter, reset before use.
        meter = torchnet.meter.ClassErrorMeter(topk=[1,5], accuracy=True)
        meter.reset()
        return meter

    with_meta = not is_warm

    ens_meter = _fresh_meter()
    net1_meter = _fresh_meter()
    net2_meter = _fresh_meter()
    if with_meta:
        meta_ens_meter = _fresh_meter()
        meta_net1_meter = _fresh_meter()
        meta_net2_meter = _fresh_meter()

    net1.eval()
    net2.eval()
    if with_meta:
        meta_net1.eval()
        meta_net2.eval()

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device,non_blocking=True)

            # Features are kept so the meta networks can reuse them.
            fea1 = net1.features(inputs)
            fea2 = net2.features(inputs)
            outputs1 = net1.logits(fea1)
            outputs2 = net2.logits(fea2)
            outputs = outputs1+outputs2
            _, predicted = outputs.max(1)

            ens_meter.add(outputs,targets)
            net1_meter.add(outputs1, targets)
            net2_meter.add(outputs2,targets)

            if not with_meta:
                continue

            _, meta_outputs_s1 = meta_net1(fea1)
            meta_outputs1 = torch.softmax(meta_outputs_s1,dim=-1)
            _, meta_outputs_s2 = meta_net2(fea2)
            meta_outputs2 = torch.softmax(meta_outputs_s2,dim=-1)
            meta_outputs = meta_outputs1 + meta_outputs2
            meta_ens_meter.add(meta_outputs, targets)
            meta_net1_meter.add(meta_outputs1, targets)
            meta_net2_meter.add(meta_outputs2, targets)

    main_results = (ens_meter.value(), net1_meter.value(), net2_meter.value())
    if with_meta:
        meta_results = (meta_ens_meter.value(), meta_net1_meter.value(), meta_net2_meter.value())
    else:
        meta_results = (None, None, None)

    queue.put(main_results + meta_results)



def eval_train(eval_loader, model, device, net_name,queue, log, meta_net, meta_optimizer, global_meta_iter, epoch ): 
    """Score the training set with ``model`` and update ``meta_net`` on it.

    Phase 1 (no gradients): run ``model`` over ``eval_loader``, recording the
    per-sample cross-entropy loss, label, softmax output and cached features.
    Phase 2: turn the losses into per-sample probabilities (GMM-based via
    ``verbose_prob_estimate``, or prototype-based via ``prob_prototype`` in
    ablation mode) and, after epoch 5, per-class thresholds.
    Phase 3: train ``meta_net`` for one pass over the cached features.

    Args:
        eval_loader: loader yielding ``(inputs, targets, index)`` batches.
        model: network providing ``features`` and ``logits``.
        device: device used for all computation.
        net_name: tag used in console output and scalar keys.
        queue: receives ``(prob, losses_or_meta_logits, meta_prob,
            global_meta_iter, tf_list, per_class_thd)``; every array-like
            entry is sent as numpy / CPU data.
        log: path of the statistics log file, opened here in append mode.
        meta_net: meta network to train.
        meta_optimizer: optimizer for ``meta_net``.
        global_meta_iter: running step counter for meta-net scalars.
        epoch: current epoch.
    """
    log=open(log,'a') 
    tf_list = []

    CE = nn.CrossEntropyLoss(reduction='none')
    model.eval()
    num_iter = (len(eval_loader.dataset)//eval_loader.batch_size)+1
    losses = torch.zeros(len(eval_loader.dataset)) 
    labels = torch.zeros(len(eval_loader.dataset))
    
    logits = torch.zeros(len(eval_loader.dataset),args.num_class)

    meta_net.eval()
    fea_list = []
    label_list = []
    index_list = []
    n=0  

    if args.gmm_ablation:
        meta_logits = torch.zeros(len(eval_loader.dataset), args.num_class) 
        meta_logits_list = []

    with torch.no_grad():
        for batch_idx, (inputs, targets, index) in enumerate(eval_loader):
            inputs, targets = inputs.to(device), targets.to(device,non_blocking=True) 
            fea = model.features(inputs) 
            outputs = model.logits(fea)

            if args.gmm_ablation:
                _, meta_logits_batch = meta_net(fea, targets) 

                # Average of meta-net and classifier class probabilities.
                meta_logits_batch = (torch.softmax(meta_logits_batch,dim=-1) + torch.softmax(outputs,dim=-1))/2 
                meta_logits_list.append(meta_logits_batch.cpu())

            # Cache features and labels for the meta-net pass below.
            fea_list.append(fea.cpu().numpy())
            label_list.append(targets.cpu().numpy())

            loss = CE(outputs, targets)  
            tmp_index_list = []
            for b in range(inputs.size(0)):
                # NOTE: the original stored ``meta_pred[b]`` into ``meta_prob``
                # here, but ``meta_pred`` does not exist yet at this point
                # (NameError) and ``meta_prob`` is rebuilt from scratch before
                # the meta-net pass, so that store has been dropped.
                if args.gmm_ablation:
                    meta_logits[n] = meta_logits_batch[b].cpu()
                    
                losses[index[b]]=loss[b]
                labels[index[b]] = targets[b]
                
                # NOTE(review): ``logits`` is filled in visiting order (n) while
                # ``losses``/``labels`` use the dataset index; presumably the
                # eval loader is not shuffled so both agree -- confirm.
                logits[n] = torch.softmax(outputs[b],dim=-1).cpu() 
                tmp_index_list.append(n)
                n += 1  
            index_list.append(tmp_index_list)       
            sys.stdout.write('\n')
            sys.stdout.write('|%s Evaluating loss Iter[%3d/%3d]\t' %(net_name,batch_idx,num_iter)) 
            sys.stdout.flush()    
                                    
    # Min-max normalise the losses to [0, 1].
    losses = (losses-losses.min())/(losses.max()-losses.min())    
    losses = losses.reshape(-1,1)
    if args.gmm_ablation:
        #get prob from predicted label and prototypes
        prob = prob_prototype(meta_logits, labels)
    else:
        # fit a two-component GMM to the loss
        # TODO: check whether the saved GMM model parameters get overwritten.
        prob, all_in_one_gmm, perclass_gmm = verbose_prob_estimate(losses, labels,  args.num_class, log, False)

    # Per-class thresholds: estimated after epoch 5, otherwise 1.0.
    if epoch > 5:
        per_class_thd, _ = hardness_estimate(logits, prob> args.p_threshold, args.num_class, 0, log, epoch)
    else:
        per_class_thd = [1.0] * args.num_class
    per_class_thd = torch.Tensor(per_class_thd).type_as(losses).to(device)

    # ---- one training pass of the meta network over the cached features ----
    meta_net.train()
    meta_bce = nn.BCELoss(reduction='none')
    CEloss = nn.CrossEntropyLoss()
    meta_prob = torch.zeros(len(eval_loader.dataset))   
    for i in range(len(fea_list)):
        fea = torch.Tensor(fea_list[i]).to(device, non_blocking=True)
        l = torch.LongTensor(label_list[i])
        l = l.to(device, non_blocking=True) # learn embedding

        if args.gmm_ablation:
            # Use ``device`` rather than the default GPU so this also works
            # for the network that lives on the second device.
            batch_logits = meta_logits_list[i].to(device)
            pseudo_label = torch.max(batch_logits, dim=1)[1]
            meta_pred, meta_outputs_s = meta_net(fea, l.view(-1))
            meta_loss = CEloss(meta_outputs_s, pseudo_label.type_as(l))                 
        else:
            batch_t = torch.Tensor(prob[index_list[i]]).to(device)

            batch_logits = logits[index_list[i]].to(device)
            
            # Only samples with a confident (very low or very high) probability
            # contribute to the meta loss.
            select_idx = torch.logical_or(batch_t<args.meta_thd, batch_t>(1-args.meta_thd))

            # Hard 0/1 targets obtained by thresholding the probabilities.
            h_t = np.where(prob[index_list[i]] > args.p_threshold, np.ones_like(prob[index_list[i]]), np.zeros_like(prob[index_list[i]]))
            t = torch.Tensor(h_t).to(device) 

            fea, l, t = Variable(fea), Variable(l), Variable(t)

            meta_pred, meta_outputs_s = meta_net(fea, l.view(-1)) 
            meta_outputs = torch.sigmoid(meta_outputs_s)

            for b,j in enumerate( index_list[i]):
                meta_prob[j] = meta_pred[b].item()

            # Nothing confident in this batch: skip the update entirely.
            if not select_idx.sum().item() > 0:
                continue

            full_t = torch.zeros(t.size(0), args.num_class).type_as(t)
            full_t = full_t.scatter(1, l.view(-1,1), t.view(-1,1))
            # Weight 1 on the labeled class; 0.1 on the other classes of
            # samples with target 1, 0 on the other classes otherwise.
            mask = (torch.ones(t.size(0), args.num_class)).type_as(meta_outputs) * 0.1 
                
            mask = mask * t.view(-1,1) 
            mask = mask.scatter(1, l.view(-1,1), 1)
            
            # Classes predicted above their per-class threshold are treated as
            # additional positives with full weight.
            full_t[batch_logits > per_class_thd.view(-1)] = 1
            mask[batch_logits > per_class_thd.view(-1)] = 1 

            mask = mask * select_idx.view(-1,1)

            meta_loss = meta_bce(meta_outputs.view(-1), full_t.view(-1))
            meta_loss = (meta_loss * mask.view(-1)).sum()/mask.sum()  
        
        tf_list.append(['MetaNet/{}/train_loss'.format(net_name), meta_loss.item(), global_meta_iter])
        global_meta_iter += 1

        meta_optimizer.zero_grad()
        meta_loss.backward()
        meta_optimizer.step()

    log.close()
    
    # Thresholds are always sent as a CPU numpy array so the parent process
    # receives the same type in both modes.
    per_class_thd_np = per_class_thd.detach().cpu().numpy()
    if args.gmm_ablation:
        queue.put((prob, meta_logits.numpy(), meta_prob.numpy(), global_meta_iter, tf_list, per_class_thd_np))
    else:
        queue.put((prob, losses.cpu().numpy(), meta_prob.numpy(), global_meta_iter, tf_list, per_class_thd_np))
     

def linear_rampup(current, warm_up, rampup_length=16):
    """Return ``args.lambda_u`` scaled by a linear ramp clipped to [0, 1].

    The ramp is 0 up to ``warm_up`` and reaches 1 after ``rampup_length``
    further steps.
    """
    progress = (current-warm_up) / rampup_length
    ramp = np.clip(progress, 0.0, 1.0)
    return args.lambda_u*float(ramp)

class SemiLoss(object):
    """Callable returning the labeled loss, unlabeled loss and ramp-up weight."""

    def __call__(self, outputs_x, targets_x, outputs_u, targets_u, epoch, warm_up):
        # Soft-target cross-entropy on the labeled logits.
        log_probs_x = F.log_softmax(outputs_x, dim=1)
        Lx = -torch.mean(torch.sum(log_probs_x * targets_x, dim=1))

        # Mean squared error between predicted and target distributions.
        probs_u = torch.softmax(outputs_u, dim=1)
        squared_diff = (probs_u - targets_u)**2
        Lu = torch.mean(squared_diff)

        weight_u = linear_rampup(epoch,warm_up)
        return Lx, Lu, weight_u

class NegEntropy(object):
    """Callable returning the batch mean of ``sum(p * log p)`` over classes."""

    def __call__(self,outputs):
        probs = torch.softmax(outputs, dim=1)
        per_sample = torch.sum(probs.log()*probs, dim=1)
        return torch.mean(per_sample)

class MyResNet50(nn.Module):
    """torchvision ResNet-50 with a replaced classifier head.

    Besides the usual ``forward`` it exposes the two halves separately:
    ``features`` (everything up to the flattened pooled output) and
    ``logits`` (the final linear layer).
    """

    def __init__(self, pretrained=False, class_num=50):
        super(MyResNet50,self).__init__()
        backbone = models.resnet50(pretrained=pretrained)
        # Swap the stock classifier for one with ``class_num`` outputs.
        backbone.fc = nn.Linear(2048,class_num)
        self.model = backbone

    def forward(self,x):
        return self.model(x)

    def logits(self, features):
        """Apply the classifier head to already-extracted features."""
        return self.model.fc(features)

    def features(self,x):
        """Run the backbone up to (and including) pooling and flatten."""
        backbone = self.model
        stages = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
            backbone.avgpool,
        )
        for stage in stages:
            x = stage(x)
        return torch.flatten(x, 1)

def create_model(device):
    """Build a ``MyResNet50`` with ``args.num_class`` outputs on ``device``."""
    network = MyResNet50(class_num = args.num_class)
    return network.to(device)



# Script entry point: alternates, per epoch, between (1) warm-up or co-training
# of the two networks, (2) evaluation on the WebVision and ImageNet validation
# sets and (3) re-scoring of the training set, each step running as two
# processes (one per GPU).
if __name__ == "__main__":
    
    mp.set_start_method('spawn')
    # int(): the seed may still be a string when it came from the command line.
    torch.manual_seed(int(args.seed))
    torch.cuda.manual_seed_all(int(args.seed))    

    # makedirs also creates ./checkpoint when it does not exist yet.
    os.makedirs('./checkpoint/webvision', exist_ok=True)

    # Start every run from an empty run directory.
    run_dir = './checkpoint/webvision/%s'%args.id
    if os.path.exists(run_dir):
        shutil.rmtree(run_dir)
    os.makedirs(run_dir, exist_ok=True)

    tf_writer =  SummaryWriter(run_dir)

    stats_log=open('./checkpoint/%s'%(args.id)+'_stats.txt','w') 
    test_log=open('./checkpoint/%s'%(args.id)+'_acc.txt','w')     

    # Record where this run writes its outputs; closed right away so the
    # content is flushed even if training is interrupted later.
    with open('log_path.txt', 'w+') as f:
        f.write('-r ./checkpoint/webvision/%s\n'%args.id)
        f.write('-f ./checkpoint/%s'%(args.id)+'_stats.txt\n')
        f.write('-f ./checkpoint/%s'%(args.id)+'_acc.txt\n')  
        
    warm_up=1

    loader = dataloader.webvision_dataloader(batch_size=args.batch_size,num_workers=8,root_dir=args.data_path,log=stats_log, num_class=args.num_class)

    print('| Building net')
    
    # Each network has a clone on the other GPU so that both GPUs can run the
    # two-network ensemble without cross-device traffic.
    net1 = create_model(cuda1)
    net2 = create_model(cuda2)
    
    net1_clone = create_model(cuda2)
    net2_clone = create_model(cuda1)
    
    # ``--my_resnet`` is not declared on the parser, so reading it directly
    # raised AttributeError.  create_model() always builds MyResNet50, whose
    # features are 2048-d, hence the True default.
    if getattr(args, 'my_resnet', True):
        fea_dim = 2048
    else:
        fea_dim = 1536

    meta_net1 = MetaNet_Bin( fea_dim, args.num_class)
    meta_net1_clone = MetaNet_Bin( fea_dim, args.num_class) 
    meta_net1 = meta_net1.to(cuda1)
    meta_net1_clone = meta_net1_clone.to(cuda2)


    meta_net2 = MetaNet_Bin( fea_dim, args.num_class)
    meta_net2_clone = MetaNet_Bin( fea_dim, args.num_class)
    meta_net2 = meta_net2.to(cuda2)
    meta_net2_clone = meta_net2_clone.to(cuda1)


    cudnn.benchmark = True
    
    optimizer1 = optim.SGD(net1.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
    optimizer2 = optim.SGD(net2.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)

    meta_optimizer1 = optim.SGD(meta_net1.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=5e-4)
    meta_optimizer2 = optim.SGD(meta_net2.parameters(), lr=args.meta_lr, momentum=0.9, weight_decay=5e-4)

    web_valloader = loader.run('test')
    imagenet_valloader = loader.run('imagenet')   
    
    best_web_acc = 0
    best_web_epoch = 0

    global_iter_1 = 0
    global_iter_2 = 0
    global_meta_iter_1 = 0
    global_meta_iter_2 = 0

    # ``--debug`` is not declared on the parser either; treat it as off.
    debug = getattr(args, 'debug', False)

    for epoch in range(args.num_epochs+1):   
        # Step schedule: both learning rates drop by 10x from epoch 50 on.
        lr=args.lr
        meta_lr = args.meta_lr
        if epoch >= 50:
            lr /= 10  
            meta_lr /= 10    
        for param_group in optimizer1.param_groups:
            param_group['lr'] = lr       
        for param_group in optimizer2.param_groups:
            param_group['lr'] = lr              

        for param_group in meta_optimizer1.param_groups:
            param_group['lr'] = meta_lr 
        for param_group in meta_optimizer2.param_groups:
            param_group['lr'] = meta_lr 

        if epoch<warm_up and not debug:
            warmup_trainloader1 = loader.run('warmup')
            warmup_trainloader2 = loader.run('warmup')
            p1 = mp.Process(target=warmup, args=(epoch,net1,optimizer1,warmup_trainloader1,cuda1,'net1'))                      
            p2 = mp.Process(target=warmup, args=(epoch,net2,optimizer2,warmup_trainloader2,cuda2,'net2'))
            p1.start() 
            p2.start()        
        else:                
            # prob*/meta_prob*/loss*/perclass_thd* come from the eval_train
            # step at the end of the previous epoch.
            pred1 = (prob1 > args.p_threshold)      
            pred2 = (prob2 > args.p_threshold)      

            meta_pred1 = (meta_prob1 > args.p_threshold)      
            meta_pred2 = (meta_prob2 > args.p_threshold)

            agree1 = pred1 == meta_pred1
            agree2 = pred2 == meta_pred2


            # co-divide: each network trains on the split produced by its peer
            labeled_trainloader1, unlabeled_trainloader1 = loader.run('train',pred2,prob2,eval_train_loss=loss2, epoch=epoch, meta_pred = meta_pred2,  meta_prob = meta_prob2, use_meta_label=args.use_meta_label)
            labeled_trainloader2, unlabeled_trainloader2 = loader.run('train',pred1,prob1,eval_train_loss=loss1,epoch=epoch, meta_pred=meta_pred1, meta_prob=meta_prob1, use_meta_label=args.use_meta_label)

            q1 = mp.Queue()
            q2 = mp.Queue()

            p1 = mp.Process(target=train, args=(epoch,net1,net2_clone,optimizer1,labeled_trainloader1, unlabeled_trainloader1,cuda1,'net1', global_iter_1, meta_net1, meta_optimizer1, global_meta_iter_1, perclass_thd2, q1))                             
            p2 = mp.Process(target=train, args=(epoch,net2,net1_clone,optimizer2,labeled_trainloader2, unlabeled_trainloader2,cuda2,'net2', global_iter_2, meta_net2, meta_optimizer2, global_meta_iter_2, perclass_thd1, q2))
            p1.start()  
            p2.start() 
            # Results are read before join() so the workers can flush their
            # queues and exit.
            global_iter_1, global_meta_iter_1, tf_log1 = q1.get()
            global_iter_2, global_meta_iter_2, tf_log2 = q2.get()          

            for key, value, counter in tf_log1:
                tf_writer.add_scalar(key, value, counter)
            
            for key, value, counter in tf_log2:
                tf_writer.add_scalar(key, value, counter) 

        p1.join()
        p2.join()
    
        # Refresh the cross-GPU clones with the freshly trained weights.
        net1_clone.load_state_dict(net1.state_dict())
        net2_clone.load_state_dict(net2.state_dict())
        meta_net1_clone.load_state_dict(meta_net1.state_dict())
        meta_net2_clone.load_state_dict(meta_net2.state_dict()) 

        q1 = mp.Queue()
        q2 = mp.Queue()
        is_warm = epoch<warm_up

        p1 = mp.Process(target=test, args=(epoch,net1,net2_clone,web_valloader,cuda1,q1,meta_net1, meta_net2_clone, is_warm))                
        p2 = mp.Process(target=test, args=(epoch,net1_clone,net2,imagenet_valloader,cuda2,q2, meta_net1_clone, meta_net2, is_warm))
        
        p1.start()   
        p2.start()
        
        web_acc, web_acc1, web_acc2, web_meta_acc, web_meta_acc1, web_meta_acc2 = q1.get()
        imagenet_acc, imagenet_acc1, imagenet_acc2, imagenet_meta_acc, imagenet_meta_acc1, imagenet_meta_acc2 = q2.get()
        
        p1.join()
        p2.join()        

        if not is_warm:
            tf_writer.add_scalar('Val/Meta_Net1_Acc', web_meta_acc1[0],  epoch)  
            tf_writer.add_scalar('Val/Meta_Net2_Acc', web_meta_acc2[0],  epoch)
            tf_writer.add_scalar('Val/Meta_Acc', web_meta_acc[0],  epoch)

            print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n"%(epoch,web_meta_acc[0],web_meta_acc[1],imagenet_meta_acc[0],imagenet_meta_acc[1]))

            print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%), %.2f%% (%.2f%%), %.2f%% (%.2f%%) \t\n"%(epoch,web_meta_acc[0],web_meta_acc[1],web_meta_acc1[0],web_meta_acc1[1],web_meta_acc2[0],web_meta_acc2[1]))     
            

        tf_writer.add_scalar('Val/Net1_Acc', web_acc1[0],  epoch)  
        tf_writer.add_scalar('Val/Net2_Acc', web_acc2[0],  epoch)
        tf_writer.add_scalar('Val/Acc', web_acc[0],  epoch) 
        
        # Track the best WebVision top-1 accuracy of the ensemble.
        if web_acc[0] > best_web_acc:
            best_web_acc = web_acc[0]
            best_web_epoch = epoch
        print("\n| Current Best Web Epoch #%d\t Accuracy: %.2f%%\n" %(best_web_epoch,best_web_acc))  
        test_log.write('Current Best Web Epoch:%d   Accuracy:%.2f\n'%(best_web_epoch,best_web_acc))
        test_log.flush()
        
        tf_writer.add_scalar('Imagenet/Acc', imagenet_acc[0],  epoch) 
        print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n"%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))  
        test_log.write('Epoch:%d \t WebVision Acc: %.2f%% (%.2f%%) \t ImageNet Acc: %.2f%% (%.2f%%)\n'%(epoch,web_acc[0],web_acc[1],imagenet_acc[0],imagenet_acc[1]))
        print("\n| Test Epoch #%d\t WebVision Acc: %.2f%% (%.2f%%), %.2f%% (%.2f%%), %.2f%% (%.2f%%) \t\n"%(epoch,web_acc[0],web_acc[1],web_acc1[0],web_acc1[1],web_acc2[0],web_acc2[1]))  
        test_log.write('Epoch:%d \t WebVision Acc: %.2f%% (%.2f%%),%.2f%% (%.2f%%),%.2f%% (%.2f%%) \t \n'%(epoch,web_acc[0],web_acc[1],web_acc1[0],web_acc1[1],web_acc2[0],web_acc2[1]))
        test_log.flush()

        ''' 
        pth1 = './checkpoint/webvision/{}_net1_epoch{}.pth.tar'.format(args.id, epoch)
        pth2 = './checkpoint/webvision/{}_net1_epoch{}.pth.tar'.format(args.id, epoch)

        torch.save(net1.state_dict(), pth1)
        torch.save(net2.state_dict(), pth2) 
        '''

        # Re-score the training set with both networks; the results drive the
        # labeled/unlabeled split of the next epoch.
        eval_loader1 = loader.run('eval_train')          
        eval_loader2 = loader.run('eval_train')       
        q1 = mp.Queue()
        q2 = mp.Queue()
        p1 = mp.Process(target=eval_train, args=(eval_loader1,net1,cuda1,'net1', q1, './checkpoint/%s'%(args.id)+'_stats.txt', meta_net1, meta_optimizer1, global_meta_iter_1, epoch))                
        p2 = mp.Process(target=eval_train, args=(eval_loader2,net2,cuda2,'net2',q2, './checkpoint/%s'%(args.id)+'_stats.txt', meta_net2, meta_optimizer2, global_meta_iter_2, epoch))
        
        p1.start()   
        p2.start()
        
        prob1, loss1, meta_prob1, global_meta_iter_1, tf_log1, perclass_thd1  = q1.get()
        prob2, loss2, meta_prob2, global_meta_iter_2, tf_log2, perclass_thd2 = q2.get()
        
        for key, value, counter in tf_log1:
            tf_writer.add_scalar(key, value, counter)
        
        for key, value, counter in tf_log2:
            tf_writer.add_scalar(key, value, counter)

        p1.join()
        p2.join()

    # Release the log files and flush pending TensorBoard events.
    test_log.close()
    stats_log.close()
    tf_writer.close()
