import argparse
import torchvision.transforms as transforms

####################################################################################################################
def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--flag False`` used to
    *enable* the flag. This converter accepts the usual spellings instead.

    Args:
        value: the raw command-line string (or a bool, if one is passed
            through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser(description="main")

# data condition
parser.add_argument('--dataset', type=str, default="CIFAR10", help='CIFAR10, CIFAR100')
parser.add_argument('--noise_type', type=str, default='sym', help='clean, sym, asym')
# Parsed and stored as a string, not a float.
parser.add_argument('--noisy_ratio', type=str, default="0.2", help='between 0 and 1')

# classifier condition
parser.add_argument('--class_method', type=str, default=None, help='T estimation method')

# experiment condition
parser.add_argument('--optimizer', type=str, default='SGD', help='SGD, Adam')
parser.add_argument('--seed', type=int, default=0)
# Help text previously claimed a 1e-3 default, which did not match the value.
parser.add_argument('--lr', type=float, default=0.02, help="Learning rate (Default : 0.02)")
parser.add_argument('--total_epochs', type=int, default=200, help='total training epoch')
parser.add_argument('--N', type=float, default=1.0, help='how much samples to choose?')

# etc
parser.add_argument('--set_gpu', type=str, default='0', help='gpu setting')
parser.add_argument('--data_dir', type=str, default='./data/')

# Was type=bool: argparse applies ``type`` to string defaults, so the
# "f-PML" default became True and any --mode value also became True.
parser.add_argument('--mode', type=str, default="f-PML", help='RENT, f-PML')
parser.add_argument('--div', type=str, default='Jensen-Shannon', help='')  # Jensen-Shannon
# Boolean switches: _str2bool makes "--flag False" actually yield False.
parser.add_argument('--change_var', type=_str2bool, default=False, help='')
parser.add_argument('--obj_fcn_corr', type=_str2bool, default=False, help='')
parser.add_argument('--posterior_corr', type=_str2bool, default=False, help='')

parser.add_argument('--no_pretrain', type=_str2bool, default=False, help='')

args = parser.parse_args()
####################################################################################################################

# Dataset Information
# Derive per-dataset settings from --dataset and attach them to ``args``.
if args.dataset in ('CIFAR10', 'CIFAR100'):
    # Both CIFAR variants share one pipeline. Training uses a 32px random crop
    # (4px padding) plus a horizontal flip; evaluation only converts to a
    # tensor. Both then normalise.
    # NOTE(review): CIFAR100 reuses the CIFAR10 per-channel mean/std below;
    # CIFAR100's own statistics differ slightly -- confirm this is intended.
    _cifar_mean = (0.4914, 0.4822, 0.4465)
    _cifar_std = (0.2023, 0.1994, 0.2010)

    args.n_channel = 3
    if args.dataset == 'CIFAR10':
        # One target label per true class index. Presumably the flip targets
        # for 'asym' noise (truck->automobile, bird->airplane, deer->horse,
        # cat<->dog) -- TODO confirm against the noise generator.
        args.noisy_label_list = [0, 1, 0, 5, 7, 3, 6, 7, 8, 1]
        args.n_class = 10
    else:
        # NOTE(review): no noisy_label_list is set for CIFAR100; verify that
        # code paths using it (e.g. asym noise) handle its absence.
        args.n_class = 100
    args.transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ])
    args.test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_cifar_mean, _cifar_std),
    ])
    args.batch_size = 128

else:
    # Unrecognised dataset: fill in sentinel values (-1 / None).
    args.n_channel = -1
    args.n_class = -1
    # The CIFAR branches store the training pipeline in ``args.transform``.
    # This branch used to set only ``train_transform``, which left
    # ``args.transform`` undefined. ``train_transform`` is kept for backward
    # compatibility.
    args.transform = None
    args.train_transform = None
    args.test_transform = None
    args.batch_size = None


# Clean labels carry no injected noise, so force the (string-typed) noise
# ratio to '0.0' whatever --noisy_ratio was; other noise types keep it as is.
args.noisy_ratio = '0.0' if args.noise_type == 'clean' else args.noisy_ratio