import torch, math, time, argparse, os, sys
sys.path.append('.')
import random, dataset, utils, losses, net
import numpy as np
from IPython import embed
from dataset.Inshop import Inshop_Dataset
#from net.resnet import *
from net.googlenet import *
from net.bn_inception import *
from net.target_network import *
from net.modified_vgg import Modified_VGG
from dataset import sampler
from torch.utils.data.sampler import BatchSampler
from torch.utils.data.dataloader import default_collate
from utils import check_unique_code_assign, UnlabeledDataManager, UnlabeledDataManager2, GaussianBlur, plot_TSNE
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from losses import PseudoCodewordLoss, PseudoCodewordLossSoft
from tqdm import *
import wandb
import losses_metric_learning_DPL_2
from losses_metric_learning_DPL_2 import GenerateLabelMatrix_FD_with_proxies_thresholding, N_PQ_loss, PMSE_loss_soft, PMSE_loss, BCE
from net.resnet_cifar import resnet20
from net.mobilenetv2_cifar import MobileNetV2
from net.mobilenetv2_cifar_tmp import mobilenet_v2


# Command-line interface for the training script.
# Flag names, destinations, types and defaults are what the rest of the script
# reads through `args`; only the help texts describe them.
parser = argparse.ArgumentParser(description=
    'CUB200 Proto 1, combcls with linear classifier, shared feature extractor, EMA pseudo codeword for unlabeled'
)

# ---- Export directory, experiment naming, dataset ----
parser.add_argument('--LOG_DIR', default='../logs',
    help='Path to log folder (the script currently overwrites this with ./logs/ after parsing)'
)
parser.add_argument('--project-name', default='dummy_project',
    help='Project name'
)
parser.add_argument('--name', default='dummy_exp',
    help='Experiment name'
)
parser.add_argument('--method', default='comblearn_consistency/proto2', type=str,
    help='Method tag used as a component of the log directory path'
)
parser.add_argument('--dataset', default='cifar10',
    help='Training dataset, e.g. cifar10, cifar100'
)

# ---- Model / batch / schedule ----
parser.add_argument('--embedding-size-per-codeword', default=12, type=int,
    dest='sz_embedding',
    help='Embedding size per partitioning; the backbone output size is this times --num-partitionings.'
)
parser.add_argument('--batch-size', default=256, type=int,
    dest='sz_batch',
    help='Number of samples per batch.'
)
parser.add_argument('--epochs', default=100, type=int,
    dest='nb_epochs',
    help='Number of training epochs.'
)
parser.add_argument('--gpu-id', default=2, type=int,
    help='ID of GPU that is used for training (-1 wraps the model in DataParallel instead).'
)
parser.add_argument('--workers', default=4, type=int,
    dest='nb_workers',
    help='Number of workers for dataloader.'
)
parser.add_argument('--seed', default=1, type=int,
    help='Random seed'
)
parser.add_argument('--model', default='modified_vgg',
    help='Model for training'
)
parser.add_argument('--loss', default='CombiNormSoftmax',
    help='Criterion for training'
)
parser.add_argument('--augmentation', default='basic',
    help='Training augmentation for cifar10: basic or strong'
)
parser.add_argument('--symmetric', default=1, type=int,
    help='If nonzero, average the labeled losses and the N_PQ unlabeled loss over both augmented views'
)
parser.add_argument('--unlabeled-loss', default='N_PQ',
    help='Loss for unlabeled samples: PMSE, N_PQ, PMSE_soft or BCE'
)
parser.add_argument('--metric-mode', default='NormSoftmax',
    help='Metric mode passed to the CombiNormSoftmax criterion'
)
parser.add_argument('--optimizer', default='adamw',
    help='Optimizer: sgd, adam, rmsprop or adamw'
)
parser.add_argument('--lr', default=0.0002, type=float,
    help='Learning rate of the backbone'
)
parser.add_argument('--pred_lr', default=0.02, type=float,
    help='Learning rate of the consistency predictor (the script currently overwrites this with --lr after parsing)'
)
parser.add_argument('--weight-decay', default=1e-4, type=float,
    help='Weight decay setting'
)
parser.add_argument('--lr-decay-step', default=20, type=int,
    help='Learning decay step setting'
)
parser.add_argument('--eval-step', default=1, type=int,
    help='Save a checkpoint every this many epochs'
)
parser.add_argument('--lr-decay-gamma', default=0.50, type=float,
    help='Learning decay gamma setting'
)
parser.add_argument('--IPC', type=int,
    help='Balanced sampling, images per class'
)
parser.add_argument('--warm', default=0, type=int,
    help='Warmup training epochs (warmup is not implemented; any value above 0 aborts training)'
)
parser.add_argument('--bn-freeze', default=0, type=int,
    help='Batch normalization parameter freeze'
)
parser.add_argument('--l2-norm', default=1, type=int,
    help='L2 normalization'
)
parser.add_argument('--remark', default='',
    help='Any remark (appended to the log directory name)'
)

# ---- Combinatorial learning setting ----
parser.add_argument('--num-partitionings', default=24, type=int,
    help='number of meta-class sets'
)
parser.add_argument('--num-partitions', default=4, type=int,
    help='number of meta-class (forced to 4 after parsing when --seen_rate is 0.7)'
)
parser.add_argument('--q', default=50, type=int,
    help='number of sub-dimensions (selects the partitioning file)'
)
parser.add_argument('--overlap', default=0, type=int,
    help='number of overlapped sub-dimensions (selects the partitioning file for bn_inception)'
)
parser.add_argument('--exp-name', default='kmeans', type=str,
    help='Partitioning experiment name (selects the partitioning file for bn_inception)'
)
parser.add_argument('--meta-temperature', default=0.2, type=float,
    help='temperature for meta-classifiers'
)
parser.add_argument('--orig-temperature', default=0.2, type=float,
    help='temperature for original classifiers'
)
parser.add_argument('--proxy-lr', default=0.02, type=float,
    help='Learning rate of the criterion proxies (CombiNormSoftmax)'
)

# ---- Loss weights ----
parser.add_argument('--beta', default=0.3, type=float,
    help='Weight of the meta loss (l_loss)'
)
parser.add_argument('--alpha', default=0.3, type=float,
    help='Weight of the base classification loss (cls_loss); also forwarded to Proxy_Anchor as alpha'
)
parser.add_argument('--lam', default=0.1, type=float,
    help='Weight of the unlabeled sample loss (u_loss)'
)
parser.add_argument('--lam2', default=0.00, type=float,
    help='Additional unlabeled sample loss weight (not referenced by the training loop of this script)'
)
parser.add_argument('--lam3', default=7.5, type=float,
    help='Weight of the consistency loss'
)

# ---- Pseudo-label generation for unlabeled samples ----
parser.add_argument('--p-threshold', default=0.80, type=float,
    help='Positive threshold used when generating the label matrix for unlabeled samples'
)
parser.add_argument('--n-threshold', default=0.7, type=float,
    help='Negative threshold used when generating the label matrix for unlabeled samples'
)
parser.add_argument('--nb-global-proxies', default=5, type=int,
    help='Number of global proxies of the criterion (the script currently overwrites this with int(seen_rate * 10))'
)
parser.add_argument('--k', default=1, type=int,
    help='k forwarded to the CombiNormSoftmax criterion'
)
parser.add_argument('--nb-fold', default=0, type=int,
    help='fold idx for unseen classes'
)
parser.add_argument('--topk', default=1, type=int,
    help='k used when generating the label matrix for unlabeled samples'
)
parser.add_argument('--init', default='imagenet', choices=['imagenet', 'ft_imgnet', 'ft_rotnet'],
    help='Use self supervised pretrained model'
)
parser.add_argument('--seen-portion', default=0.5, type=float,
    help='labeled example portion'
)
parser.add_argument('--seen_rate', default=0.7, type=float,
    help='Fraction of classes treated as seen (determines the number of training classes)'
)

# Parse the command line, then apply the hard-wired overrides of this experiment:
# the proxy count follows the seen rate, and the predictor shares the backbone LR.
args = parser.parse_args()
args.nb_global_proxies = int(args.seen_rate * 10)
if args.seen_rate == 0.7:
    args.num_partitions = 4
args.pred_lr = args.lr

# Seed every random number generator the script touches and make cuDNN deterministic.
seed = args.seed
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)
cudnn.benchmark = False
cudnn.deterministic = True

# gpu_id == -1 means "do not pin a device" (the model is wrapped in DataParallel later).
if args.gpu_id != -1:
    torch.cuda.set_device(args.gpu_id)

# Directory for logs and checkpoints.
# NOTE: the value given via --LOG_DIR is overwritten here, so output always
# goes under ./logs/ relative to the working directory.
args.LOG_DIR = './logs/'
#LOG_DIR = args.LOG_DIR + '/{}/{}'.format(args.project_name, args.name)
LOG_DIR = args.LOG_DIR + '/logs_{}/{}/{}_{}_embedding{}_part_{}_{}_orig_temp_{:.4f}_meta_temp_{:.4f}_lr{}_lam{:.4f}_lam3_{:.4f}_batch{}{}/fold_{}/portion_{}/seed_{}'.format(
    args.dataset, args.method, args.model, args.loss, args.sz_embedding,
    args.num_partitionings, args.num_partitions,
    args.orig_temperature, args.meta_temperature, args.lr, args.lam,
    args.lam3, args.sz_batch, args.remark, args.nb_fold, args.seen_portion, args.seed)

# exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
# when two runs create the same directory at the same time.
os.makedirs(LOG_DIR, exist_ok=True)

# Wandb Initialization
wandb.init(project=args.dataset + '_CIFAR_PC_metric_proto2_F_%d' % (args.nb_fold), notes=LOG_DIR)
#wandb.init(project=args.project_name , notes=LOG_DIR)
wandb.config.update(args)

# Dataset Loader and Sampler
# Builds the labeled training set (trn_dataset), the unlabeled training set
# (u_trn_dataset, the 'gallery' split with training transforms) and the manager
# that serves unlabeled batches during training.
data_root = './data/'
if args.dataset == 'cub200':
    trn_dataset = dataset.Proto1CUB200(
        root=data_root + args.dataset,
        mode='train',
        transform=dataset.utils.make_transform(
            is_train=True,
            is_inception=(args.model == 'bn_inception')
        ),
        seen_rate=args.seen_rate
    )
    u_trn_dataset = dataset.Proto1CUB200(
        root=data_root + args.dataset,
        mode='gallery',
        transform=dataset.utils.make_transform(
            is_train=True,
            is_inception=(args.model == 'bn_inception')
        ))
    unlabeled_datamanager = UnlabeledDataManager(args, u_trn_dataset)

elif args.dataset == 'cifar10':
    if args.augmentation == 'basic':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ])
    elif args.augmentation == 'strong':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
        ])
    else:
        # Previously an unknown value fell through and crashed later with a
        # NameError on transform_train.
        raise ValueError(
            "Unsupported --augmentation '{}' for cifar10 (expected 'basic' or 'strong')".format(args.augmentation))
    trn_dataset = dataset.Proto2HardCIFAR10_2(
        root=data_root + args.dataset,
        split='train',
        transform=transform_train,
        nb_fold=args.nb_fold,
        seen_portion=args.seen_portion,
        seen_rate=args.seen_rate
    )
    u_trn_dataset = dataset.Proto2HardCIFAR10_2(
        root=data_root + args.dataset,
        split='gallery',
        transform=transform_train,
        nb_fold=args.nb_fold,
        seen_portion=args.seen_portion,
        seen_rate=args.seen_rate
    )
    unlabeled_datamanager = UnlabeledDataManager2(args, u_trn_dataset)

elif args.dataset == 'cifar100':
    transform_train = transforms.Compose([
        # transforms.RandomRotation(10),  # optionally add RandomRotation
        transforms.RandomCrop(32, padding=4),
        # resize 256_comb_coteach_OpenNN_CIFAR -> random_crop 224 ==> crop 32, padding 4
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    trn_dataset = dataset.Proto2CIFAR100(
        root=data_root + args.dataset,
        split='train',
        transform=transform_train,
        nb_fold=args.nb_fold,
    )
    u_trn_dataset = dataset.Proto2CIFAR100(
        root=data_root + args.dataset,
        split='gallery',
        transform=transform_train,
        nb_fold=args.nb_fold)
    unlabeled_datamanager = UnlabeledDataManager(args, u_trn_dataset)
else:
    # An explicit exception instead of `assert False`, which is removed under `python -O`.
    raise ValueError(
        "Unsupported --dataset '{}' (expected 'cub200', 'cifar10' or 'cifar100')".format(args.dataset))

# Training loader for the labeled set: plain shuffling by default, or
# class-balanced batches when --IPC (images per class) is given.
train_loader_common = dict(num_workers=args.nb_workers, pin_memory=True)
if not args.IPC:
    dl_tr = torch.utils.data.DataLoader(
        trn_dataset,
        batch_size=args.sz_batch,
        shuffle=True,
        drop_last=False,
        **train_loader_common
    )
else:
    balanced_sampler = sampler.BalancedSampler(
        trn_dataset, batch_size=args.sz_batch, images_per_class=args.IPC)
    batch_sampler = BatchSampler(balanced_sampler, batch_size=args.sz_batch, drop_last=True)
    dl_tr = torch.utils.data.DataLoader(
        trn_dataset,
        batch_sampler=batch_sampler,
        **train_loader_common
    )
    print('Balanced Sampling')

# Evaluation data: query and gallery datasets with deterministic (no-augmentation)
# transforms, plus their loaders. Only the CIFAR datasets are wired up here.
if args.dataset not in ('cifar10', 'cifar100'):
    # An explicit exception instead of `assert False`, which is removed under `python -O`.
    raise ValueError(
        "No query/gallery datasets are defined for --dataset '{}' (expected 'cifar10' or 'cifar100')".format(args.dataset))

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
])

if args.dataset == 'cifar10':
    eval_dataset_cls = dataset.Proto2HardCIFAR10
    eval_dataset_kwargs = dict(
        nb_fold=args.nb_fold,
        seen_portion=args.seen_portion,
        seen_rate=args.seen_rate,
    )
else:
    eval_dataset_cls = dataset.Proto2CIFAR100
    eval_dataset_kwargs = dict(nb_fold=args.nb_fold)

# Settings shared by every evaluation loader.
eval_loader_kwargs = dict(
    batch_size=args.sz_batch,
    shuffle=False,
    num_workers=args.nb_workers,
    pin_memory=True,
)

query_dataset = eval_dataset_cls(
    root=data_root + args.dataset,
    split='query',
    transform=transform_test,
    **eval_dataset_kwargs
)
dl_query = torch.utils.data.DataLoader(query_dataset, **eval_loader_kwargs)

gallery_dataset = eval_dataset_cls(
    root=data_root + args.dataset,
    split='gallery',
    transform=transform_test,
    **eval_dataset_kwargs
)
dl_gallery = torch.utils.data.DataLoader(gallery_dataset, **eval_loader_kwargs)

if args.dataset == 'cifar10':
    # Second loader over the same gallery set, kept because the original
    # script defines it for cifar10 only.
    gallery_loader = torch.utils.data.DataLoader(
        gallery_dataset,
        drop_last=False,
        **eval_loader_kwargs
    )

# Number of seen (training) classes: a fraction of 100 for cifar100, of 10 otherwise.
if args.dataset == 'cifar100':
    nb_classes = int(100 * args.seen_rate)
else:
    nb_classes = int(10 * args.seen_rate)

# Load meta-class configuration
if args.model == 'bn_inception':
    args.partitionings_path =  './data/' + '%s_partitions/%s/select_%d_overlap_%d/seen_%.2f/part_%d_%s.pth.tar' % \
                          (args.dataset, args.model, args.q, args.overlap, args.seen_rate, args.num_partitions, args.exp_name)
elif args.model in ('resnet20', 'mobilenetV2', 'modified_vgg'):
    # All three backbones read the partitionings stored under 'modified_vgg'.
    args.partitionings_path = './data/%s_partitions/%s/select_%d_overlap_%d/seen_%.2f/fold_%d/part_%d_kmeans.pth.tar' % (
        args.dataset, 'modified_vgg', args.q, 0, args.seen_rate, args.nb_fold, args.num_partitions)
else:
    # Previously an unknown model crashed on the print below with an AttributeError.
    raise ValueError(
        "No partitioning file is defined for --model '{}'".format(args.model))

print(args.partitionings_path)

partitionings_candidate = torch.load(args.partitionings_path)

# Scan consecutive windows of num_partitionings candidates and keep the first one
# that check_unique_code_assign accepts for nb_classes classes.
# NOTE(review): the window index i is bounded by len(candidates) - num_partitionings
# while the slice advances by num_partitionings per step, so late windows may be
# short or empty -- confirm the intended bound.
partitionings = None
for i in range(len(partitionings_candidate) - args.num_partitionings):
    partitionings = partitionings_candidate[i * args.num_partitionings
                                                    :(i+1)*args.num_partitionings].t()
    if check_unique_code_assign(partitionings, nb_classes):
        print('Partition sample idx is :', i)
        args.sampling_idx = i
        break
else:
    if partitionings is None:
        # No window was examined at all; fail here instead of with a NameError later.
        raise RuntimeError(
            'Not enough partitioning candidates in {} for --num-partitionings {}'.format(
                args.partitionings_path, args.num_partitionings))
    # Same fallback as before (the last examined window is used), but no longer silent.
    print('WARNING: no partitioning window with a unique code assignment was found; '
          'using the last candidate window')
sz_feature_embedding = args.sz_embedding * args.num_partitionings

# Backbone Model
# The backbone is chosen by substring match on --model (exact match for
# 'modified_vgg') and outputs sz_feature_embedding values per sample.
if 'googlenet' in args.model:
    model = googlenet(embedding_size=sz_feature_embedding, pretrained=True, is_norm=False, bn_freeze = args.bn_freeze)
elif 'bn_inception' in args.model:
    model = bn_inception(embedding_size=sz_feature_embedding, pretrained=True, is_norm=False, bn_freeze = args.bn_freeze)
elif 'resnet18' in args.model:
    model = Resnet18(embedding_size=sz_feature_embedding, pretrained=True, is_norm=False, bn_freeze = args.bn_freeze)
elif 'resnet20' in args.model:
    model = resnet20(num_classes=sz_feature_embedding)
elif 'mobilenetV2' in args.model:
    model = mobilenet_v2(pretrained=True, num_classes=200)
    # Replace the classifier head so the network emits the embedding.
    last_channel = model.last_channel
    model.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, sz_feature_embedding),
        )
elif 'resnet101' in args.model:
    model = Resnet101(embedding_size=sz_feature_embedding, pretrained=True, is_norm=False, bn_freeze = args.bn_freeze)
elif args.model == 'modified_vgg':
    model = Modified_VGG(embedding_size=sz_feature_embedding, pretrained=True, is_norm=False, bn_freeze = args.bn_freeze)
else:
    # Previously an unknown model surfaced as a NameError on `model.cuda()`.
    raise ValueError("Unsupported --model '{}'".format(args.model))

model = model.cuda()

# Predictor head for the consistency loss: a bottleneck MLP
# (Linear -> BatchNorm1d -> ReLU -> Linear) mapping back to the input size.
predictor_hidden = sz_feature_embedding // 2
predictor = nn.Sequential(nn.Linear(sz_feature_embedding, predictor_hidden, bias=False),
                        nn.BatchNorm1d(predictor_hidden),
                        nn.ReLU(inplace=True),
                        nn.Linear(predictor_hidden, sz_feature_embedding)).cuda()

# gpu_id == -1 selects multi-GPU training through DataParallel.
if args.gpu_id == -1:
    model = nn.DataParallel(model)

target_network = TargetNetwork(moving_average_decay=0.99)

# DML Losses
# NOTE(review): the training loop below unpacks three losses from the criterion
# and calls criterion.SoftAssignment / criterion.proxies, which only the
# 'CombiNormSoftmax' branch is visibly set up for -- confirm before using another loss.
if args.loss == 'Proxy_Anchor':
    # NOTE(review): this script's parser defines no --mrg flag, so args.mrg raises
    # AttributeError unless it is provided some other way -- confirm.
    criterion = losses.Proxy_Anchor(nb_classes = nb_classes, sz_embed = args.sz_embedding, mrg = args.mrg, alpha = args.alpha).cuda()
elif args.loss == 'Proxy_NCA':
    criterion = losses.Proxy_NCA().cuda()
elif args.loss == 'MS':
    criterion = losses.MultiSimilarityLoss().cuda()
elif args.loss == 'Contrastive':
    criterion = losses.ContrastiveLoss().cuda()
elif args.loss == 'Triplet':
    criterion = losses.TripletLoss().cuda()
elif args.loss == 'NPair':
    criterion = losses.NPairLoss().cuda()
elif args.loss == 'NormSoftmax':
    criterion = losses.NormSoftmaxLoss(nb_classes = nb_classes, sz_embed=args.sz_embedding, temperature=0.05).cuda()
elif args.loss == 'CombiNormSoftmax':
    criterion = losses_metric_learning_DPL_2.CombiNormSoftmaxLossWithLinearDisjoint(
                nb_global_proxies=args.nb_global_proxies, nb_classes=nb_classes,
                sz_embed=args.sz_embedding, num_partitions=args.num_partitions,
                num_partitionings=args.num_partitionings, meta_temperature=args.meta_temperature,
                partitionings=partitionings, orig_temperature=args.orig_temperature,
                is_norm=args.l2_norm, metric_mode=args.metric_mode, k=args.k).cuda()
    criterion.combinatorial_classifiers.set_partitionings(partitionings)
else:
    # Previously an unknown loss left `criterion` undefined and crashed later with a NameError.
    raise ValueError("Unsupported --loss '{}'".format(args.loss))

""" Disabled for now (kept for reference): optional initialization from a fine-tuned checkpoint.
if args.init == 'ft_rotnet':
    print("Initialize with ft_rotnet")
    checkpoint_dir = args.LOG_DIR + '/logs_{}/{}/{}_{}_embedding{}_part_{}_{}_orig_temp_{:.4f}_meta_temp_{:.4f}_lr{}_proxylr{:.4f}_alpha_{:.2f}_batch{}{}/fold_{}/seed_{}'.format(args.dataset, 'comblearn-seen-pretraining/proto2/self-supervised', args.model, args.loss, args.sz_embedding,
                                                                                                        args.num_partitionings, args.num_partitions,
                                                                                                        args.orig_temperature, args.meta_temperature, args.lr, args.proxy_lr,
                                                                                                        args.alpha, args.sz_batch, args.remark, args.nb_fold, 1)
    checkpoint = torch.load(os.path.join(checkpoint_dir, '%s_%s_100.pth' % (args.dataset, args.model)))
    print(model.load_state_dict(checkpoint['model_state_dict']))
    print(criterion.load_state_dict(checkpoint['loss_state_dict']))

elif args.init == 'ft_imgnet':
    print("Initialize with ft_imgnet")
    checkpoint_dir = args.LOG_DIR + '/logs_{}/{}/{}_{}_embedding{}_part_{}_{}_orig_temp_{:.4f}_meta_temp_{:.4f}_lr{}_proxylr{:.4f}_alpha_{:.2f}_batch{}{}/fold_{}/seed_{}'.format(
        args.dataset, 'comblearn-seen-pretraining/proto2', args.model, args.loss, args.sz_embedding,
        args.num_partitionings, args.num_partitions,
        args.orig_temperature, args.meta_temperature, args.lr, args.proxy_lr,
        args.alpha, args.sz_batch, args.remark, args.nb_fold, 1)
    checkpoint = torch.load(os.path.join(checkpoint_dir, '%s_%s_100.pth' % (args.dataset, args.model)))

    print(model.load_state_dict(checkpoint['model_state_dict']))
    print(criterion.load_state_dict(checkpoint['loss_state_dict']))
else:
    print("Initialize with Imgnet")
"""

# Train Parameters
# First group: backbone parameters at the base learning rate. parameters() already
# yields each parameter once, in a fixed order, so no set() round-trip is needed
# (the set made the parameter order vary from run to run).
backbone = model.module if args.gpu_id == -1 else model
param_groups = [
    {'params': list(backbone.parameters())},
]

# Extra groups: loss-specific proxies (and the predictor) with their own learning rates.
if args.loss == 'Proxy_Anchor':
    param_groups.append({'params': criterion.proxies, 'lr':float(args.lr) * 100})
elif args.loss == 'NormSoftmax':
    param_groups.append({'params': criterion.proxies, 'lr': float(args.lr) * 1})
elif args.loss == 'CombiNormSoftmax':
    param_groups.append({'params': criterion.proxies, 'lr': float(args.proxy_lr) * 1})
    param_groups.append({'params': predictor.parameters(), 'lr': float(args.pred_lr) * 1})
    #param_groups.append({'params': criterion.combinatorial_classifiers.proxies, 'lr': float(args.proxy_lr) * 1})

# Optimizer Setting
if args.optimizer == 'sgd':
    opt = torch.optim.SGD(param_groups, lr=float(args.lr), weight_decay = args.weight_decay, momentum = 0.9, nesterov=True)
elif args.optimizer == 'adam':
    opt = torch.optim.Adam(param_groups, lr=float(args.lr), betas=(0.5, 0.999), weight_decay = args.weight_decay)
elif args.optimizer == 'rmsprop':
    opt = torch.optim.RMSprop(param_groups, lr=float(args.lr), alpha=0.9, weight_decay = args.weight_decay, momentum = 0.9)
elif args.optimizer == 'adamw':
    opt = torch.optim.AdamW(param_groups, lr=float(args.lr), betas=(0.5, 0.999),  weight_decay = args.weight_decay)
else:
    # Previously an unknown optimizer left `opt` undefined and crashed on the scheduler line.
    raise ValueError(
        "Unsupported --optimizer '{}' (expected 'sgd', 'adam', 'rmsprop' or 'adamw')".format(args.optimizer))

#opt2 = torch.optim.SGD(predictor.parameters(), lr=float(args.pred_lr), momentum=0.9)
#opt2 = torch.optim.AdamW(predictor.parameters(), lr=float(args.pred_lr), betas=(0.5, 0.999),  weight_decay = args.weight_decay)

# Step decay applied to every parameter group once per epoch.
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=args.lr_decay_step, gamma = args.lr_decay_gamma)
#scheduler2 = torch.optim.lr_scheduler.StepLR(opt2, step_size=args.lr_decay_step, gamma = args.lr_decay_gamma)

# Echo the effective configuration before training starts.
print(f"Training parameters: {vars(args)}")
print(f"Training for {args.nb_epochs} epochs.")

# One independent history list per loss term; each receives one mean value per epoch.
(losses_list, cls_losses_list, l_losses_list,
 u_losses_list, nm_losses_list, cons_losses_list) = ([] for _ in range(6))

# Best-result bookkeeping.
best_mAP = best_epoch = 0

# Cosine similarity used by the consistency loss. It holds no state, so it is
# built once here instead of on every iteration.
criterion2 = nn.CosineSimilarity(dim=1).cuda()

# Main training loop. Each iteration combines:
#   * labeled losses from the criterion (meta loss l_loss, base loss cls_loss, logged nm_loss),
#   * an unlabeled loss driven by a label matrix generated from soft assignments,
#   * a predictor-based consistency loss between the two augmented views.
for epoch in range(0, args.nb_epochs):
    model.train()
    criterion.train()
    bn_freeze = args.bn_freeze
    if bn_freeze:
        # Keep the backbone's BatchNorm2d layers in eval mode while the rest trains.
        modules = model.model.modules() if args.gpu_id != -1 else model.module.model.modules()
        for bn_module in modules:
            if isinstance(bn_module, nn.BatchNorm2d):
                bn_module.eval()

    losses_per_epoch = []
    cls_losses_per_epoch = []
    l_losses_per_epoch = []
    u_losses_per_epoch = []
    nm_losses_per_epoch = []
    cons_losses_per_epoch = []

    # Warmup: Train only new params, helps stabilize learning.
    # Not implemented in this script; an explicit exception replaces `assert False`,
    # which would be skipped silently under `python -O`.
    if args.warm >= epoch + 1:
        raise NotImplementedError('Warmup training (--warm > 0) is not implemented')

    pbar = tqdm(enumerate(dl_tr))

    for batch_idx, (x, x2, y) in pbar:

        # calculate loss for labeled samples
        # NOTE(review): .squeeze() removes every size-1 dimension, which would include
        # the batch dimension of a single-sample batch -- confirm batches never shrink to 1.
        y = y.squeeze().cuda()
        m, m2 = model(x.squeeze().cuda()), model(x2.squeeze().cuda())
        # l_loss: meta loss, cls_loss: base loss
        l_loss, cls_loss, nm_loss = criterion(m, y)

        if args.symmetric:
            l_loss2, cls_loss2, nm_loss2 = criterion(m2, y)
            l_loss = (l_loss + l_loss2)/2
            cls_loss = (cls_loss + cls_loss2)/2
            # Average of both views, like the two lines above. The previous `+=`
            # added the average on top of nm_loss (in place) instead of replacing it.
            nm_loss = (nm_loss + nm_loss2)/2
        # calculate loss for unlabeled samples
        x_u, x_u2, y_u = unlabeled_datamanager.next_unlabeled_train()
        m_u, m_u2 = model(x_u.squeeze().cuda()), model(x_u2.squeeze().cuda())

        if args.l2_norm:
            # L2-normalize each per-partitioning sub-embedding, then flatten back.
            m = F.normalize(m.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
            m = m.view(-1, args.num_partitionings * args.sz_embedding)
            m2 = F.normalize(m2.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
            m2 = m2.view(-1, args.num_partitionings * args.sz_embedding)

            m_u = F.normalize(m_u.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
            m_u = m_u.view(-1, args.num_partitionings * args.sz_embedding)
            m_u2 = F.normalize(m_u2.view(-1, args.num_partitionings, args.sz_embedding), dim=2)
            m_u2 = m_u2.view(-1, args.num_partitionings * args.sz_embedding)

        # Soft-assignment descriptors of every view against the criterion's proxies.
        descriptor_l = criterion.SoftAssignment(criterion.proxies, m, args.num_partitionings, zeta=20)
        descriptor_l2 = criterion.SoftAssignment(criterion.proxies, m2, args.num_partitionings, zeta=20)
        descriptor_u = criterion.SoftAssignment(criterion.proxies, m_u, args.num_partitionings, zeta=20)
        descriptor_u2 = criterion.SoftAssignment(criterion.proxies, m_u2, args.num_partitionings, zeta=20)

        # The label matrices are targets only, so they are generated without gradients.
        with torch.no_grad():
            u_label_matrix, mask, nb_positive = GenerateLabelMatrix_FD_with_proxies_thresholding(
                                                    descriptor_u.detach().clone(), descriptor_l.clone().detach(),
                                                    k=args.topk, positive_threshold=args.p_threshold,
                                                    negative_threshold=args.n_threshold, return_nb_positive=True)
            if args.symmetric :
                u_label_matrix2, mask2, nb_positive2 = GenerateLabelMatrix_FD_with_proxies_thresholding(
                                                    descriptor_u2.detach().clone(), descriptor_l2.clone().detach(),
                                                    k=args.topk, positive_threshold=args.p_threshold,
                                                    negative_threshold=args.n_threshold, return_nb_positive=True)
        # default : N_PQ (the only unlabeled loss that uses the second view)
        if args.unlabeled_loss == 'N_PQ':
            u_loss = N_PQ_loss(u_label_matrix.cuda(), m_u, descriptor_u, args.num_partitionings, descriptor_l, mask=None)
            if args.symmetric:
                u_loss2 = N_PQ_loss(u_label_matrix2.cuda(), m_u2, descriptor_u2, args.num_partitionings, descriptor_l2, mask=None)
                u_loss = (u_loss + u_loss2)/2
        elif args.unlabeled_loss == 'PMSE':
            u_loss = PMSE_loss(u_label_matrix.cuda(), m_u, descriptor_u, args.num_partitionings, descriptor_l)
        elif args.unlabeled_loss == 'PMSE_soft':
            u_loss = PMSE_loss_soft(u_label_matrix.cuda(), m_u, descriptor_u, args.num_partitionings, descriptor_l)
        elif args.unlabeled_loss == 'BCE':
            u_loss = BCE(u_label_matrix.cuda(), descriptor_u, descriptor_l, nb_positive)
        else:
            raise ValueError("Unsupported --unlabeled-loss '{}'".format(args.unlabeled_loss))

        # Consistency loss: the predictor output of one view should align (cosine)
        # with the detached descriptor of the other view, for labeled and unlabeled data.
        z_l1, z_l2 = descriptor_l.clone(), descriptor_l2.clone()
        z_u1, z_u2 = descriptor_u.clone(), descriptor_u2.clone()
        p_l1, p_l2 = predictor(z_l1), predictor(z_l2)
        p_u1, p_u2 = predictor(z_u1), predictor(z_u2)

        cons_loss = -(criterion2(p_l1, z_l2.detach()).mean() + criterion2(p_l2, z_l1.detach()).mean() + \
                      criterion2(p_u1, z_u2.detach()).mean() + criterion2(p_u2, z_u1.detach()).mean()) * 0.25

        # nm_loss is logged only; it is not part of the optimized objective.
        loss = args.beta * l_loss + args.alpha * cls_loss + args.lam3 * cons_loss + args.lam * u_loss

        opt.zero_grad()
        #opt2.zero_grad()
        loss.backward()

        #criterion.combinatorial_classifiers.rescale_grad()
        losses_per_epoch.append(loss.data.cpu().numpy())
        cls_losses_per_epoch.append(cls_loss.data.cpu().numpy())
        l_losses_per_epoch.append(l_loss.data.cpu().numpy())
        u_losses_per_epoch.append(u_loss.data.cpu().numpy())
        nm_losses_per_epoch.append(nm_loss.data.cpu().numpy())
        cons_losses_per_epoch.append(cons_loss.data.cpu().numpy())
        opt.step()
        #opt2.step()

        #target_network.update_moving_average(model, criterion.combinatorial_classifiers)

        pbar.set_description(
            'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.3f} | L_Loss: {:.3f} | Cls_Loss: {:.3f} | \
                                             U_Loss: {:.3f} | NM_Loss: {:.3f} | Con_Loss: {:.3f}' \
                .format(
                epoch, batch_idx + 1, len(dl_tr),
                100. * batch_idx / len(dl_tr),
                loss.item(),
                l_loss.item(),
                cls_loss.item(),
                u_loss.item(),
                nm_loss.item(),
                cons_loss.item()
            ))

    # Per-epoch means go to the history lists and to wandb.
    losses_list.append(np.mean(losses_per_epoch))
    cls_losses_list.append(np.mean(cls_losses_per_epoch))
    l_losses_list.append(np.mean(l_losses_per_epoch))
    u_losses_list.append(np.mean(u_losses_per_epoch))
    nm_losses_list.append(np.mean(nm_losses_per_epoch))
    cons_losses_list.append(np.mean(cons_losses_per_epoch))
    wandb.log({'loss': losses_list[-1],
               'l_loss': l_losses_list[-1],
               'u_loss': u_losses_list[-1],
               'cls_loss': cls_losses_list[-1],
               'nm_loss': nm_losses_list[-1],
               'cons_loss': cons_losses_list[-1]}, step=epoch)
    scheduler.step()
    #scheduler2.step()

    # Checkpoint the backbone and criterion every eval_step epochs
    # (the predictor head is not part of the checkpoint).
    if((epoch + 1 ) % args.eval_step == 0):

        torch.save({'model_state_dict': model.state_dict(), 'loss_state_dict': criterion.state_dict()},
                   '{}/{}_{}_{}.pth'.format(LOG_DIR, args.dataset, args.model, epoch + 1)
                   )

    


