from ast import arg
from multiprocessing import reduction
import os
import os.path
import argparse
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
from sklearn.neural_network import MLPClassifier

from data.cifar import CIFAR10, CIFAR100

# from torchvision.datasets import CIFAR10, CIFAR100
# from data.cifarN import CIFAR10, CIFAR100
import datetime
import time
import logging
import time
import pickle
import json


from networks.ResNet import ResNet18, ResNet34, PreActResNet18
from common.tools import AverageMeter, getTime, predict_softmax, ProgressMeter, accuracy
from common.NoisyUtil import Train_Dataset, Semi_Labeled_Dataset, Semi_Unlabeled_Dataset
from randaugment import RandAugmentMC
from data.augmentations import Augmentation, CutoutDefault
from data.augmentation_archive import autoaug_paper_cifar10

from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, MinMaxScaler

from torch.utils.tensorboard import SummaryWriter

# ---------------------------------------------------------------------------
# Command-line configuration.
# ---------------------------------------------------------------------------
# Data / label-noise settings. --dataset is expected to be cifar10 or cifar100.
# NOTE(review): --data_percent is not read anywhere in this file.
parser = argparse.ArgumentParser(description='PyTorch CIFAR10/100-N Training')
parser.add_argument('--dataset', default='cifar100', type=str)
parser.add_argument('--mode', default='train', type=str, help='train or test')
parser.add_argument('--data_path', type=str, default='./data', help='data directory')
parser.add_argument('--data_percent', default=1, type=float, help='data number percent')
parser.add_argument('--noise_rate', type = float, help = 'corruption rate, should be less than 1', default = 0.2)
parser.add_argument('--batch_size', default=128, type=int, help='train batchsize')
# Train runs build their dataset with --train_noise_type, test runs with --test_noise_type.
parser.add_argument('--train_noise_type', default='symmetric', type=str, help='symmetric, pairflip, asymmetric, instance')
parser.add_argument('--test_noise_type', default='symmetric', type=str, help='symmetric, pairflip, asymmetric, instance')

# Optimisation. --network must be a key of network_map (resnet18, resnet34, paresnet18).
# --optim 'cos' selects CosineAnnealingLR; any other value selects MultiStepLR([150, 250]).
parser.add_argument('--network', default='paresnet18', type=str, help='used network')
parser.add_argument('--lr', '--learning_rate', default=0.02, type=float, help='initial learning rate')
parser.add_argument('--weight_decay', type=float, help='weight_decay for training', default=5e-4)
parser.add_argument('--num_epochs', default=300, type=int)
parser.add_argument('--optim', default='cos', type=str, help='step, cos')

# Semi-supervised (MixMatch) settings.
# --alpha is the Beta(alpha, alpha) MixUp parameter; MixMatch_train sets it to 0.75
#   for the second half of training.
# --lambda_u is overwritten during dataset setup (5 for CIFAR-10, 75 for CIFAR-100).
# --T1: epochs of plain noisy-label training before switching to MixMatch
#   (reset to 20 when 0).
# --T2: from this epoch on, meta_denoising returns the --sel union instead of the
#   denoiser's top-k. Despite its help text, a value of 0 is NOT replaced.
# --sel: 'none' or 'plus' (plus also keeps examples whose prediction matches the label).
parser.add_argument('--alpha', default=4, type=float, help='parameter for Beta')
parser.add_argument('--T', default=0.5, type=float, help='sharpening temperature')
parser.add_argument('--lambda_u', default=0, type=float, help='weight for unsupervised loss')
parser.add_argument('--T1', default=20, type=int, help='if 0, set in below')
parser.add_argument('--T2', default=30, type=int, help='if 0, set in below')
parser.add_argument('--sel', default='none', type=str, help='sel type')

# Meta-denoiser settings.
# --meta_type selects the recorded signal: loss, conf, diff or diff2.
# --window limits how many recent epochs of that signal are used as features.
# NOTE(review): --denoiser is not read; meta_denoising always fits a LogisticRegression.
# --resume is overwritten in test mode for cifar10/cifar100 (see the run-directory setup).
parser.add_argument('--meta_type', default='loss', type=str, help='loss, conf, diff')
parser.add_argument('--window', default=50, type=int, help='size of slide window')
parser.add_argument('--denoiser', default='mlp', type=str, help='denoiser type')
parser.add_argument('--output', default='./results/', type=str)
parser.add_argument('--resume', default=None, type=str)
parser.add_argument('--seed', default=1, type=int)
args = parser.parse_args()
print(args)


now_str = datetime.datetime.now().strftime("%Y_%m_%d_%X").replace(':', '-')

# Each run writes into a sub-directory of --output whose name encodes the key
# hyper-parameters, so a test run can find the matching training artifacts.
if args.mode == 'train':
    args.output = os.path.join(args.output, '_'.join(map(str, (
        args.dataset, args.mode, args.train_noise_type, args.noise_rate, args.T1,
        args.T2, args.meta_type, args.window, args.network, args.num_epochs))))

    # meta_denoising() pickles one denoiser per epoch into this directory.
    args.denoiser_output = os.path.join(args.output, 'denoiser')
    os.makedirs(args.denoiser_output, exist_ok=True)

elif args.mode == 'test':
    # A test run on one CIFAR variant loads the denoisers trained on the other one.
    source_dataset = {'cifar10': 'cifar100', 'cifar100': 'cifar10'}.get(args.dataset)
    if source_dataset is not None:
        args.resume = os.path.join(args.output, '_'.join(map(str, (
            source_dataset, 'train', args.train_noise_type, args.noise_rate, args.T1,
            args.T2, args.meta_type, args.window, args.network, args.num_epochs))))

    args.output = os.path.join(args.output, '_'.join(map(str, (
        args.dataset, args.mode, args.test_noise_type, args.train_noise_type,
        args.noise_rate, args.T1, args.T2, args.meta_type, args.window,
        args.network, args.num_epochs))))


print(args.output)
print(args.resume)

os.makedirs(args.output, exist_ok=True)

# File logger for the run plus a TensorBoard writer in the same directory.
logging.basicConfig(
    level=logging.INFO,
    filename=os.path.join(args.output, 'log.log'),
    datefmt='%Y/%m/%d %H:%M:%S',
    format='%(asctime)s - %(name)s - %(levelname)s - %(lineno)d - %(module)s - %(message)s',
)
logger = logging.getLogger(__name__)

summary_writer = SummaryWriter(log_dir=args.output)

def get_raw_dict(args):
    """Return the attribute dictionary backing an ``argparse.Namespace``.

    The returned dict is the namespace's own ``__dict__`` (not a copy).

    Example::

        with open(path, 'w') as f:
            json.dump(get_raw_dict(args), f, indent=2)

    Raises:
        NotImplementedError: if ``args`` is not an ``argparse.Namespace``.
    """
    if not isinstance(args, argparse.Namespace):
        raise NotImplementedError("Unknown type {}".format(type(args)))
    return vars(args)

# Save the complete run configuration next to the logs.
path = os.path.join(args.output, "config.json")
with open(path, 'w') as f:
    f.write(json.dumps(get_raw_dict(args), indent=2))

# Backbone constructors selectable with --network.
network_map = {
    'resnet18': ResNet18,
    'resnet34': ResNet34,
    'paresnet18': PreActResNet18,
}
Net = network_map[args.network]

if args.seed is not None:
    # Seed Python, NumPy and torch (CPU and every CUDA device) generators.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # Prefer deterministic cuDNN kernels over autotuned (benchmark) ones.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def linear_rampup(current, warm_up=20, rampup_length=16):
    """Return the unsupervised-loss weight for the (fractional) epoch ``current``.

    The weight is 0 until ``warm_up``, grows linearly over the next
    ``rampup_length`` epochs and then stays at ``args.lambda_u``.
    """
    progress = (current - warm_up) / rampup_length
    progress = min(max(progress, 0.0), 1.0)
    return args.lambda_u * float(progress)


def train(model, train_loader, optimizer, ceriation, epoch, logger):
    """Run one epoch of plain supervised training on the (noisy) labels.

    Each batch of ``train_loader`` is unpacked as ``(_, images, labels, indexes)``.
    The model is trained on the second element. NOTE(review): presumably one of
    the augmented views produced by data.cifar; confirm which one.

    Args:
        model: network to train (moved batches are sent to CUDA).
        train_loader: loader over the training set.
        optimizer: optimizer stepping ``model``'s parameters.
        ceriation: loss function applied to ``(logits, labels)``.
        epoch: current epoch, only used in the progress prefix.
        logger: accepted for interface compatibility (only used in the
            commented-out log line).

    Returns:
        ``(avg_loss, avg_top1)``, with ``avg_top1`` as a Python float.
    """
    batch_time = AverageMeter('Time', ':6.2f')
    data_time = AverageMeter('Data', ':6.2f')
    losses = AverageMeter('Loss', ':6.2f')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1],
        prefix="Train Epoch: [{}]".format(epoch))

    model.train()

    end = time.time()
    for i, (_, images, labels, indexes) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images = images.cuda()
        labels = labels.cuda()
        # Weight the running averages by the number of samples in this batch.
        # (``images[0].size(0)`` would be the first dimension of a single
        # sample, i.e. its channel count, not the batch size.)
        batch_size = images.size(0)

        logist = model(images)
        loss = ceriation(logist, labels)

        acc1, _ = accuracy(logist, labels, topk=(1, 5))
        losses.update(loss.item(), batch_size)
        top1.update(acc1[0], batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

    # progress.display(0)
    # logger.info('[%d], loss: %.4f, test acc: %.4f'%(epoch, losses.avg, top1.avg))

    return losses.avg, top1.avg.to("cpu", torch.float).item()

def evaluate(model, eva_loader, ceriation, prefix, ignore=-1):
    """Compute the average loss and top-1 accuracy of ``model`` on ``eva_loader``.

    Batches are unpacked as ``(images, labels, _)``. Gradients are disabled and
    the model is put in eval mode.

    Args:
        model: network to evaluate.
        eva_loader: evaluation loader.
        ceriation: loss function applied to ``(logits, labels)``.
        prefix: if non-empty, a timestamped line with this prefix and the
            rounded accuracy is printed.
        ignore: accepted for interface compatibility; not used.

    Returns:
        ``(avg_loss, avg_top1)``, with ``avg_top1`` as a Python float.
    """
    losses = AverageMeter('Loss', ':3.2f')
    top1 = AverageMeter('Acc@1', ':3.2f')
    model.eval()

    with torch.no_grad():
        for images, labels, _ in eva_loader:
            images = images.cuda()
            labels = labels.cuda()
            # Sample-weighted averages: the final batch of an evaluation loader
            # without drop_last is usually smaller, so weight by the real batch
            # size (``images[0].size(0)`` is a per-sample dimension, not the batch).
            batch_size = images.size(0)

            logist = model(images)

            loss = ceriation(logist, labels)
            acc1, _ = accuracy(logist, labels, topk=(1, 5))

            losses.update(loss.item(), batch_size)
            top1.update(acc1[0], batch_size)

    if prefix != "":
        print(getTime(), prefix, round(top1.avg.item(), 2))

    return losses.avg, top1.avg.to("cpu", torch.float).item()

def FixMatch_train(epoch, net, optimizer, labeled_trainloader, unlabeled_trainloader, class_weights):
    """Train ``net`` for one epoch with a FixMatch-style objective.

    Labeled batches yield ``(x_weak, x_strong, targets)`` and unlabeled batches
    ``(u_weak, u_strong)``.

    - Supervised term: class-weighted cross-entropy on the strong labeled view.
    - Unsupervised term: class-weighted cross-entropy between the strong
      unlabeled view and hard pseudo-labels taken from the temperature-scaled
      weak view. It keeps only examples whose confidence reaches the threshold
      and is scaled by ``linear_rampup``.

    Both loaders are cycled so that exactly ``int(50000 / args.batch_size)``
    steps are taken.

    Returns:
        ``(loss_avg, lx_avg, lu_avg)`` running averages for the epoch.
    """
    net.train()

    losses = AverageMeter('Loss', ':6.2f')
    losses_lx = AverageMeter('Loss_Lx', ':6.2f')
    losses_lu = AverageMeter('Loss_Lu', ':6.5f')

    # ``--threshold`` is not registered on this script's argument parser, so
    # reading ``args.threshold`` directly would raise AttributeError. Fall back
    # to 0.95, the confidence threshold used by FixMatch.
    threshold = getattr(args, 'threshold', 0.95)

    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)
    # Fixed number of steps per epoch (50000 = CIFAR training-set size).
    num_iter = int(50000 / args.batch_size)
    for batch_idx in range(num_iter):
        # Builtin next(): recent PyTorch DataLoader iterators have no ``.next()``.
        try:
            inputs_x_w, inputs_x_s, targets = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x_w, inputs_x_s, targets = next(labeled_train_iter)

        try:
            inputs_u_w, inputs_u_s = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u_w, inputs_u_s = next(unlabeled_train_iter)

        n_x = inputs_x_w.size(0)
        n_u = inputs_u_w.size(0)

        # One forward pass over [labeled strong | unlabeled weak | unlabeled strong].
        inputs = torch.cat((inputs_x_s, inputs_u_w, inputs_u_s)).cuda()
        targets = torch.zeros(n_x, args.num_class).scatter_(1, targets.view(-1, 1), 1)
        targets = targets.cuda()

        logits = net(inputs)

        logits_x_s = logits[:n_x]
        logits_u_w = logits[n_x:n_x + n_u]
        logits_u_s = logits[n_x + n_u:]

        del logits

        Lx_mean = -torch.mean(F.log_softmax(logits_x_s, dim=1) * targets, 0)
        Lx = torch.sum(Lx_mean * class_weights)

        # Pseudo-labels from the weak view; no gradient flows through them.
        pseudo_label = torch.softmax(logits_u_w.detach() / args.T, dim=-1)
        max_probs, targets_u = torch.max(pseudo_label, dim=-1)
        mask = max_probs.ge(threshold).float()

        targets_u = torch.zeros(n_u, args.num_class).cuda().scatter_(1, targets_u.view(-1, 1), 1)
        Lu_mean = -torch.mean(F.log_softmax(logits_u_s, dim=1) * targets_u * mask.unsqueeze(1), 0)
        Lu = torch.sum(Lu_mean * class_weights)

        loss = Lx + linear_rampup(epoch + batch_idx / num_iter, args.T1) * Lu

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses_lx.update(Lx.item(), n_x)
        losses_lu.update(Lu.item(), n_u)
        losses.update(loss.item(), n_x + n_u)

    return losses.avg, losses_lx.avg, losses_lu.avg


# MixMatch Training
def MixMatch_train(epoch, net, optimizer, labeled_trainloader, unlabeled_trainloader, class_weights):
    """Train ``net`` for one epoch with MixMatch.

    Labeled batches yield ``(x, x2, targets)`` (two views per image) and
    unlabeled batches ``(u, u2)``. The steps per batch are:

    - Guess unlabeled targets by averaging the softmax of both unlabeled views
      and sharpening with temperature ``args.T``.
    - MixUp the whole batch with a permuted copy, using
      ``lambda = max(l, 1 - l)`` with ``l ~ Beta(args.alpha, args.alpha)``.
    - Apply a class-weighted soft cross-entropy to the labeled part and a
      mean-squared error on probabilities to the unlabeled part. The latter is
      scaled by ``linear_rampup``.

    Side effect: from epoch ``args.num_epochs / 2`` onward ``args.alpha`` is set
    to 0.75 on the shared ``args`` object.

    Returns:
        ``(loss_avg, lx_avg, lu_avg)`` running averages for the epoch.
    """
    net.train()
    if epoch >= args.num_epochs / 2:
        args.alpha = 0.75

    losses = AverageMeter('Loss', ':6.2f')
    losses_lx = AverageMeter('Loss_Lx', ':6.2f')
    losses_lu = AverageMeter('Loss_Lu', ':6.5f')

    labeled_train_iter = iter(labeled_trainloader)
    unlabeled_train_iter = iter(unlabeled_trainloader)
    # Fixed number of steps per epoch (50000 = CIFAR training-set size).
    num_iter = int(50000 / args.batch_size)
    for batch_idx in range(num_iter):
        # Builtin next(): recent PyTorch DataLoader iterators have no ``.next()``.
        try:
            inputs_x, inputs_x2, targets_x = next(labeled_train_iter)
        except StopIteration:
            labeled_train_iter = iter(labeled_trainloader)
            inputs_x, inputs_x2, targets_x = next(labeled_train_iter)

        try:
            inputs_u, inputs_u2 = next(unlabeled_train_iter)
        except StopIteration:
            unlabeled_train_iter = iter(unlabeled_trainloader)
            inputs_u, inputs_u2 = next(unlabeled_train_iter)

        batch_size = inputs_x.size(0)
        targets_x = torch.zeros(batch_size, args.num_class).scatter_(1, targets_x.view(-1, 1), 1)
        inputs_x, inputs_x2, targets_x = inputs_x.cuda(), inputs_x2.cuda(), targets_x.cuda()
        inputs_u, inputs_u2 = inputs_u.cuda(), inputs_u2.cuda()

        # Label guessing for the unlabeled views (no gradients).
        with torch.no_grad():
            outputs_u11 = net(inputs_u)
            outputs_u12 = net(inputs_u2)

            pu = (torch.softmax(outputs_u11, dim=1) + torch.softmax(outputs_u12, dim=1)) / 2
            ptu = pu ** (1 / args.T)  # temperature sharpening

            targets_u = ptu / ptu.sum(dim=1, keepdim=True)  # normalize
            targets_u = targets_u.detach()

        all_inputs = torch.cat([inputs_x, inputs_x2, inputs_u, inputs_u2], dim=0)
        all_targets = torch.cat([targets_x, targets_x, targets_u, targets_u], dim=0)

        idx = torch.randperm(all_inputs.size(0))
        input_a, input_b = all_inputs, all_inputs[idx]
        target_a, target_b = all_targets, all_targets[idx]

        # lambda >= 0.5 keeps each mixed sample closer to its own original, so the
        # first 2 * batch_size rows still correspond to labeled inputs.
        mixmatch_l = np.random.beta(args.alpha, args.alpha)
        mixmatch_l = max(mixmatch_l, 1 - mixmatch_l)

        mixed_input = mixmatch_l * input_a + (1 - mixmatch_l) * input_b
        mixed_target = mixmatch_l * target_a + (1 - mixmatch_l) * target_b

        logits = net(mixed_input)
        logits_x = logits[:batch_size * 2]
        logits_u = logits[batch_size * 2:]

        Lx_mean = -torch.mean(F.log_softmax(logits_x, dim=1) * mixed_target[:batch_size * 2], 0)
        Lx = torch.sum(Lx_mean * class_weights)

        probs_u = torch.softmax(logits_u, dim=1)
        Lu = torch.mean((probs_u - mixed_target[batch_size * 2:]) ** 2)
        loss = Lx + linear_rampup(epoch + batch_idx / num_iter, args.T1) * Lu

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses_lx.update(Lx.item(), batch_size * 2)
        losses_lu.update(Lu.item(), len(logits) - batch_size * 2)
        losses.update(loss.item(), len(logits))

    return losses.avg, losses_lx.avg, losses_lu.avg

def compute_class_weights(labels, num_class, max_weight=3):
    """Return inverse-frequency class weights for the integer ``labels``.

    Each class that occurs gets ``mean(non-zero counts) / count``, capped at
    ``max_weight`` so rare classes cannot blow up the loss. Classes that never
    occur get weight 0 (and do not contribute to the mean).
    """
    counts = np.bincount(np.asarray(labels, dtype=int), minlength=num_class)
    weights = np.zeros(len(counts), dtype=float)
    present = counts != 0
    if present.any():
        weights[present] = np.minimum(counts[present].mean() / counts[present], max_weight)
    return weights


def update_trainloader(model, train_data, clean_targets, noisy_targets, confident_indexs, unconfident_indexs):
    """Build the labeled/unlabeled MixMatch loaders from the denoiser's split.

    Both datasets use the module-level ``transform_strong``.

    Args:
        model: unused; kept for interface compatibility.
        train_data: array of training images, indexed by example index.
        clean_targets: unused; kept for interface compatibility.
        noisy_targets: observed (possibly noisy) labels for every example.
        confident_indexs: indices treated as labeled (with their noisy labels).
        unconfident_indexs: indices treated as unlabeled.

    Returns:
        ``(labeled_trainloader, unlabeled_trainloader, class_weights)`` where
        ``class_weights`` is a CUDA FloatTensor of per-class loss weights
        computed from the confident subset's labels.
    """
    noisy_targets = np.array(noisy_targets)
    confident_dataset = Semi_Labeled_Dataset(train_data[confident_indexs], noisy_targets[confident_indexs], transform_strong)
    unconfident_dataset = Semi_Unlabeled_Dataset(train_data[unconfident_indexs], transform_strong)

    # Split each batch roughly in proportion to the two subsets, but never give
    # the unlabeled side more than half of the batch.
    num_con, num_uncon = len(confident_indexs), len(unconfident_indexs)
    if num_uncon > num_con:
        uncon_batch = int(args.batch_size / 2)
    else:
        uncon_batch = int(num_uncon / (num_con + num_uncon) * args.batch_size)
    # A tiny unlabeled subset can round down to 0, which DataLoader rejects.
    # NOTE(review): an *empty* unlabeled subset still cannot be cycled by the
    # training loop.
    uncon_batch = max(uncon_batch, 1)
    con_batch = args.batch_size - uncon_batch

    labeled_trainloader = DataLoader(dataset=confident_dataset, batch_size=con_batch, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    unlabeled_trainloader = DataLoader(dataset=unconfident_dataset, batch_size=uncon_batch, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)

    # Loss re-weighting to counter class imbalance in the confident subset.
    cw = compute_class_weights(noisy_targets[confident_indexs], args.num_class)
    class_weights = torch.FloatTensor(cw).cuda()
    return labeled_trainloader, unlabeled_trainloader, class_weights


_META_KEYS = ('loss', 'conf', 'diff', 'diff2')


def _collect_example_stats(train_loader, model):
    """Run ``model`` once over ``train_loader`` and gather per-example signals.

    Returns ``(stats, sel_idxs_pp)``. ``stats`` maps each key in ``_META_KEYS``
    to a float array shaped like the module-level ``noise_or_not``:

    - 'loss': per-example ``celoss``.
    - 'conf': softmax probability of the given label.
    - 'diff': max probability minus label probability.
    - 'diff2': best other-class probability minus label probability.

    ``sel_idxs_pp`` lists the indices whose argmax prediction equals the given label.
    """
    stats = {key: np.zeros_like(noise_or_not, dtype=float) for key in _META_KEYS}
    sel_idxs_pp = []

    model.eval()
    # Pure inference: without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for images, _, labels, indexes in train_loader:
            images = images.cuda()
            labels = labels.cuda()

            logits = model(images)
            loss = celoss(logits, labels)  # reduction='none' -> one value per example

            conf = F.softmax(logits, dim=1)
            rows = torch.arange(len(indexes), device=conf.device)
            conf_label = conf[rows, labels]
            conf_max, preds = torch.max(conf, dim=1)
            conf_other = conf.clone()
            conf_other[rows, labels] = 0.0
            conf_max2, _ = torch.max(conf_other, dim=1)

            # Vectorised scatter into the per-example arrays (one host sync per tensor).
            idx = indexes.cpu().numpy()
            stats['loss'][idx] = loss.cpu().numpy()
            stats['conf'][idx] = conf_label.cpu().numpy()
            stats['diff'][idx] = (conf_max - conf_label).cpu().numpy()
            stats['diff2'][idx] = (conf_max2 - conf_label).cpu().numpy()
            agree = (preds == labels).cpu().numpy()
            sel_idxs_pp.extend(idx[agree].tolist())

    return stats, sel_idxs_pp


def _build_meta_features(recorder, epoch, args):
    """Stack the recorded history of ``args.meta_type`` into an (examples, epochs) matrix.

    With ``args.window > 1`` the history is truncated to the last ``window``
    epochs only once ``epoch >= args.T1``. Otherwise the last ``window`` columns
    are always taken.
    """
    if args.meta_type not in _META_KEYS:
        raise ValueError("Unknown meta_type %r; expected one of %s" % (args.meta_type, ', '.join(_META_KEYS)))
    meta_fea = np.array(recorder[args.meta_type]).transpose()

    if args.window > 1:
        if epoch >= args.T1 and meta_fea.shape[1] > args.window:
            meta_fea = meta_fea[:, -args.window:]
    else:
        meta_fea = meta_fea[:, -args.window:]
    return meta_fea


def _load_denoiser(model_path):
    """Load a pickled denoiser, polling every 30 s until it can be read.

    NOTE(review): presumably the training run writes these files concurrently.
    pickle.load executes arbitrary code, so only point --resume at trusted runs.
    """
    time_start = time.time()
    while True:
        try:
            with open(model_path, 'rb') as f:
                return pickle.load(f)
        except Exception:
            # Missing or partially written file: keep waiting. Unlike a bare
            # ``except:``, this still lets KeyboardInterrupt/SystemExit through.
            waiting_time = int(time.time() - time_start)
            hours, rest = divmod(waiting_time, 3600)
            minutes, seconds = divmod(rest, 60)
            print('waiting for loading LogisticRegression model: {}h {}m {}s'.format(hours, minutes, seconds), end="")
            time.sleep(30)
            print('\r', end="", flush=True)


def _select_and_report(pred_proba, meta_fea, meta_labels, sel_idxs_pp, epoch, phase, args):
    """Rank examples, build the selection sets and report selection precision.

    Returns ``(sorted_idxs, num_remember, sel_idxs, unsel_idxs)``:

    - ``sorted_idxs`` orders examples by decreasing denoiser clean-probability.
    - ``num_remember`` is ``int((1 - args.actual_noise_rate) * N)``.
    - ``sel_idxs`` is the union of the two top-``num_remember`` rankings (plus
      ``sel_idxs_pp`` when ``args.sel == 'plus'``).
    - ``unsel_idxs`` is the complement of ``sel_idxs``.
    """
    sorted_idxs = np.argsort(-pred_proba).tolist()
    # Second ranking: ascending value of the most recent raw signal.
    sorted_idxs_plus = np.argsort(meta_fea[:, -1]).tolist()

    num_remember = int((1 - args.actual_noise_rate) * len(sorted_idxs))
    num_forget = len(sorted_idxs) - num_remember

    candidates = sorted_idxs[:num_remember] + sorted_idxs_plus[:num_remember]
    if args.sel == 'plus':
        candidates = candidates + sel_idxs_pp
    elif args.sel != 'none':
        raise ValueError("Unknown sel %r; expected 'none' or 'plus'" % args.sel)
    sel_idxs = list(set(candidates))
    unsel_idxs = list(set(range(len(sorted_idxs_plus))) - set(sel_idxs))

    # Diagnostics against the ground-truth mask (0 = clean, 1 = noisy).
    num_correct = np.sum(meta_labels[sorted_idxs[:num_remember]] == 0)
    num_correct_plus = np.sum(meta_labels[sel_idxs] == 0)
    num_detect = np.sum(meta_labels[sorted_idxs[num_remember:]] == 1)
    meta_clean_prec = num_correct / float(num_remember)
    clean_prec_plus = num_correct_plus / float(len(sel_idxs))
    meta_noise_prec = num_detect / float(num_forget)

    num_top_remember = int(0.1 * len(sorted_idxs))
    meta_clean_top_prec = np.sum(meta_labels[sorted_idxs[:num_top_remember]] == 0) / float(num_top_remember)

    template = ('[%d/%d] {0} clean prec: %.4f [%d/%d] {0} clean prec plus: %.4f '
                '[%d/%d] {0} noise prec: %.4f [%d/%d] {0} clean top prec: %.4f').format(phase)
    message = template % (epoch + 1, args.num_epochs, meta_clean_prec, num_correct, num_remember,
                          clean_prec_plus, num_correct_plus, len(sel_idxs),
                          meta_noise_prec, num_detect, num_forget, meta_clean_top_prec)
    print(message + ' ')
    logger.info(message)

    if summary_writer:
        # tensorboard logger
        summary_writer.add_scalar('clean_prec', meta_clean_prec, epoch)
        summary_writer.add_scalar('clean_prec_plus', clean_prec_plus, epoch)

    return sorted_idxs, num_remember, sel_idxs, unsel_idxs


def meta_denoising(train_loader, model, recorder, epoch, args):
    """Score every training example and split the data into confident/unconfident sets.

    Steps:

    1. Record this epoch's per-example signals ('loss', 'conf', 'diff',
       'diff2') into ``recorder``.
    2. Build features from the history of ``args.meta_type``.
    3. Obtain the denoiser:
       - train mode: fit a LogisticRegression against ``noise_or_not`` and
         pickle it to ``args.denoiser_output``;
       - test mode: wait for and load the denoiser pickled for the same epoch
         under ``args.resume``.
    4. Rank, select and log precision statistics.

    Returns:
        ``(confident_idx, unconfident_idx)`` lists of example indices. Before
        ``args.T2`` these are the denoiser's top ``num_remember`` examples and
        the rest; from ``args.T2`` on, the ``args.sel`` union selection and its
        complement.

    Raises:
        ValueError: for an unknown ``args.meta_type``, ``args.sel`` or ``args.mode``.
    """
    stats, sel_idxs_pp = _collect_example_stats(train_loader, model)
    for key in _META_KEYS:
        recorder[key].append(stats[key])

    meta_fea = _build_meta_features(recorder, epoch, args)
    print(meta_fea.shape)

    meta_labels = np.array(noise_or_not)

    if args.mode == 'train':
        denoiser = LogisticRegression(random_state=1, C=1.0, solver='lbfgs', max_iter=1000)
        denoiser.fit(meta_fea, meta_labels)
        with open(os.path.join(args.denoiser_output, 'denoiser_%s.pkl' % (epoch + 1)), 'wb') as f:
            pickle.dump(denoiser, f)
        phase = 'Train'
    elif args.mode == 'test':
        denoiser = _load_denoiser(os.path.join(args.resume, 'denoiser', 'denoiser_%s.pkl' % (epoch + 1)))
        phase = 'Test'
    else:
        raise ValueError("Unknown mode %r; expected 'train' or 'test'" % args.mode)

    # predict_proba columns follow the sorted ``classes_``; column 0 is label 0,
    # which the statistics treat as "clean".
    pred_proba = denoiser.predict_proba(meta_fea)[:, 0]

    sorted_idxs, num_remember, sel_idxs, unsel_idxs = _select_and_report(
        pred_proba, meta_fea, meta_labels, sel_idxs_pp, epoch, phase, args)

    if epoch < args.T2:
        return sorted_idxs[:num_remember], sorted_idxs[num_remember:]
    return sel_idxs, unsel_idxs


# CIFAR-N label-set aliases (not referenced elsewhere in this script).
noise_type_map = {'clean': 'clean_label', 'worst': 'worse_label', 'aggre': 'aggre_label',
                  'rand1': 'random_label1', 'rand2': 'random_label2', 'rand3': 'random_label3',
                  'clean100': 'clean_label', 'noisy': 'noisy_label'}

# Per-channel normalisation statistics for each supported dataset.
norm_map = {'cifar10': transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            'cifar100': transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))}

# Resolve per-dataset settings case-insensitively. Indexing norm_map with
# args.dataset directly raised KeyError for 'CIFAR10'/'CIFAR100' before the
# branches that explicitly accept those spellings were ever reached.
dataset_key = args.dataset.lower()
if dataset_key not in norm_map:
    raise ValueError('Unsupported --dataset %r; expected cifar10 or cifar100' % args.dataset)
if args.mode not in ('train', 'test'):
    raise ValueError("Unsupported --mode %r; expected 'train' or 'test'" % args.mode)

transforms_map = {
    # Augmentation(autoaug_paper_cifar10()) policy + crop/flip + Cutout(16).
    'strong': transforms.Compose([
        Augmentation(autoaug_paper_cifar10()),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm_map[dataset_key],
        CutoutDefault(16)]),
    # Crop/flip only.
    'weak': transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm_map[dataset_key]]),
}

if args.T1 == 0:
    # TODO: per-dataset default for the warm-up length.
    args.T1 = 20
transform_weak = transforms_map['weak']
transform_strong = transforms_map['strong']
print(transform_weak, transform_strong)
# Evaluation: no augmentation, same normalisation as training.
transform_test = transforms.Compose([transforms.ToTensor(), norm_map[dataset_key]])

# Training runs corrupt labels with --train_noise_type, test runs with --test_noise_type.
label_noise_type = args.train_noise_type if args.mode == 'train' else args.test_noise_type

if dataset_key == 'cifar10':
    args.num_class = 10
    train_set = CIFAR10(root=args.data_path, download=False, train=True, transform_weak=transform_weak, transform_strong=transform_strong, noise_type=label_noise_type, noise_rate=args.noise_rate, random_state=args.seed)
    test_set = CIFAR10(root=args.data_path, download=False, train=False, transform_weak=transform_test, noise_type=None, noise_rate=None, random_state=args.seed)
    args.lambda_u = 5
else:
    args.num_class = 100
    train_set = CIFAR100(root=args.data_path, download=True, train=True, transform_weak=transform_weak, transform_strong=transform_strong, noise_type=label_noise_type, noise_rate=args.noise_rate, random_state=args.seed)
    test_set = CIFAR100(root=args.data_path, download=True, train=False, transform_weak=transform_test, noise_type=None, noise_rate=None, random_state=args.seed)
    args.lambda_u = 75

# Raw images plus clean and noisy labels (consumed by update_trainloader).
data = train_set.train_data
clean_labels = train_set.train_labels
noisy_labels = train_set.train_noisy_labels

# Ground-truth noise mask; its mean is the actual fraction of corrupted labels.
noise_or_not = train_set.noise_or_not
args.actual_noise_rate = noise_or_not.sum()/len(noise_or_not)
print('Actual noise_rate: %.4f'%(args.actual_noise_rate))

criterion = nn.CrossEntropyLoss().cuda()
train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True,
                          num_workers=4, pin_memory=True, drop_last=True)
test_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, shuffle=False,
                         num_workers=4, pin_memory=True)

model = Net(num_classes=args.num_class).cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
# 'cos' anneals down to 1% of the initial learning rate over the whole run.
scheduler = (CosineAnnealingLR(optimizer, args.num_epochs, args.lr / 100)
             if args.optim == 'cos'
             else MultiStepLR(optimizer, milestones=[150, 250], gamma=0.1))

best_test_acc = 0
# Unreduced per-example loss read by meta_denoising(); label -1 is ignored.
celoss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=-1).cuda()

# Per-epoch history of the per-example signals collected by meta_denoising().
recorder = {'loss': [], 'conf': [], 'diff': [], 'diff2': []}
for epoch in range(args.num_epochs):
    if epoch < args.T1:
        # Warm-up phase: ordinary cross-entropy on the noisy labels.
        train(model, train_loader, optimizer, criterion, epoch, logger)
        phase = 'Noisy'
    else:
        # Semi-supervised phase: split the data using the previous epoch's
        # denoiser output, then run MixMatch (FixMatch_train is an alternative).
        labeled_trainloader, unlabeled_trainloader, class_weights = update_trainloader(
            model, data, clean_labels, noisy_labels, meta_confident_idx, meta_unconfident_idx)
        MixMatch_train(epoch, model, optimizer, labeled_trainloader, unlabeled_trainloader, class_weights)
        phase = 'Semi'

    _, test_acc = evaluate(model, test_loader, criterion, "")
    best_test_acc = max(best_test_acc, test_acc)

    message = '[%d/%d] %s Training Test Acc: %.4f Best Acc: %.4f' % (
        epoch + 1, args.num_epochs, phase, test_acc, best_test_acc)
    print(message)
    logger.info(message)

    meta_confident_idx, meta_unconfident_idx = meta_denoising(train_loader, model, recorder, epoch, args)

    if summary_writer:
        # tensorboard logger
        summary_writer.add_scalar('test_acc', test_acc, epoch)

    scheduler.step()
