import argparse
import torch
import torch.backends.cudnn as cudnn
from torchvision import models
from data_aug.contrastive_learning_dataset import ContrastiveLearningDataset
from models.resnet_simclr import ResNetSimCLR
from simclr import SimCLR

import optimizers

from utils import ModelEMA, ModelAverage


# Sorted names of every lowercase, non-dunder, callable attribute exposed by
# torchvision.models; used below as the accepted values for --arch.
model_names = sorted(
    attr_name
    for attr_name, attr_value in vars(models).items()
    if callable(attr_value)
    and attr_name.islower()
    and not attr_name.startswith("__")
)

# Command-line interface for the SimCLR training script.
parser = argparse.ArgumentParser(description='PyTorch SimCLR')

# --- data and backbone -------------------------------------------------------
parser.add_argument('-data', metavar='DIR', default='./datasets',
                    help='path to dataset')
parser.add_argument('-dataset-name', default='stl10',
                    help='dataset name', choices=['stl10', 'cifar10'])
# `%(default)s` lets argparse insert the real default, so the help text cannot
# drift away from the `default=` value again.
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: %(default)s)')
parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                    help='number of data loading workers (default: %(default)s)')

# --- optimisation schedule ---------------------------------------------------
parser.add_argument('--epochs', default=200, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=1.0, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--wd', '--weight-decay', default=0.0, type=float,
                    metavar='W', help='weight decay (default: 0.0)',
                    dest='weight_decay')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')

# --- hardware ----------------------------------------------------------------
parser.add_argument('--disable-cuda', action='store_true',
                    help='Disable CUDA')
parser.add_argument('--fp16-precision', action='store_true',
                    help='Whether or not to use 16-bit precision GPU training.')

# --- SimCLR hyper-parameters -------------------------------------------------
parser.add_argument('--out_dim', default=128, type=int,
                    help='feature dimension (default: 128)')
parser.add_argument('--log-every-n-steps', default=100, type=int,
                    help='Log every n steps')
parser.add_argument('--temperature', default=0.07, type=float,
                    help='softmax temperature (default: 0.07)')
parser.add_argument('--n-views', default=2, type=int, metavar='N',
                    help='Number of views for contrastive learning training.')
parser.add_argument('--gpu-index', default=0, type=int, help='Gpu index.')

# Restrict --optim to the names main() knows how to build, so a typo is
# rejected at parse time instead of surfacing later as an unbound variable.
parser.add_argument('--optim', type=str, default='adam',
                    choices=['sgd', 'sps', 'dog', 'dasgd', 'adam', 'cocob',
                             'ldog', 'daadam', 'prodigy', 'pssps', 'psdasgd'],
                    help='optimizer to train with (default: %(default)s)')


def main():
    """Parse the command line, build the training objects and run SimCLR.

    Steps, in order: validate the arguments, choose the device, build the
    contrastive dataset and its DataLoader, create the ResNetSimCLR model and
    its weight-averaging companion, create the optimizer selected by
    ``--optim`` (plus a cosine learning-rate scheduler where applicable), and
    finally hand everything to ``SimCLR.train``.

    Raises:
        SystemExit: via ``parser.error`` when ``--n-views`` is not 2 or when
            ``--optim`` names an optimizer this function cannot build.
    """
    args = parser.parse_args()

    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would silently let unsupported values through.
    if args.n_views != 2:
        parser.error("Only two view training is supported. Please use --n-views 2.")

    # Optimizers from the project `optimizers` module. They are all built with
    # a fixed step-size argument of 1.0 (args.lr is not used for them). Only
    # attribute names are stored so the lookup happens for the selected
    # optimizer alone.
    fixed_lr_optimizers = {
        'sps': 'Sps',
        'dog': 'DoG',
        'dasgd': 'DAdaptSGD',
        'cocob': 'COCOB',
        'ldog': 'LDoG',
        'daadam': 'DAdaptAdam',
        'prodigy': 'Prodigy',
        'pssps': 'PSSps',
        'psdasgd': 'PSDASGD',
    }
    # Stock PyTorch optimizers; these honour --lr.
    torch_optimizers = {
        'sgd': torch.optim.SGD,
        'adam': torch.optim.Adam,
    }

    # Fail fast, before the dataset is loaded and the model is built. An
    # unknown name previously fell through every branch and crashed later
    # with an unbound `optimizer` variable.
    if args.optim not in torch_optimizers and args.optim not in fixed_lr_optimizers:
        supported = ', '.join(sorted([*torch_optimizers, *fixed_lr_optimizers]))
        parser.error(f"unsupported --optim {args.optim!r}; choose from: {supported}")

    # check if gpu training is available
    if not args.disable_cuda and torch.cuda.is_available():
        args.device = torch.device('cuda')
    else:
        args.device = torch.device('cpu')
        args.gpu_index = -1

    dataset = ContrastiveLearningDataset(args.data)

    train_dataset = dataset.get_dataset(args.dataset_name, args.n_views)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim).to(args.device)

    # Weight-averaging companion handed to SimCLR: ModelAverage for dog/ldog,
    # otherwise ModelEMA (0.99 is presumably the EMA decay -- confirm in utils).
    if args.optim in ['dog', 'ldog']:
        model_ema = ModelAverage(model)
    else:
        model_ema = ModelEMA(model, 0.99)

    if args.optim in torch_optimizers:
        optimizer = torch_optimizers[args.optim](
            model.parameters(), args.lr, weight_decay=args.weight_decay)
    else:
        optimizer_cls = getattr(optimizers, fixed_lr_optimizers[args.optim])
        optimizer = optimizer_cls(model.parameters(), 1.0, weight_decay=args.weight_decay)

    # These optimizers run without a learning-rate schedule; every other one
    # gets cosine annealing over the full number of epochs.
    if args.optim in ['sps', 'dog', 'cocob', 'ldog', 'pssps']:
        scheduler = None
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=0,
                                                               last_epoch=-1)

    # It's a no-op if the 'gpu_index' argument is a negative integer or None.
    # NOTE(review): the model is moved to args.device (plain 'cuda', no index)
    # before this context is entered, so --gpu-index may not decide where the
    # model is first placed -- confirm against SimCLR's own device handling.
    with torch.cuda.device(args.gpu_index):
        simclr = SimCLR(model=model, model_ema=model_ema, optimizer=optimizer, scheduler=scheduler, args=args)
        simclr.train(train_loader)


# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
