import argparse
import os
import torch
import torch.optim as optim
from thop import profile, clever_format
from torch.utils.data import DataLoader
from tqdm import tqdm

import utils
from model import Model, ModelMNIST, ModelImageNet


parser = argparse.ArgumentParser(description='Train SimCLR')
# NOTE: apart from --batch_size, --dataset and --saved_feature_file, these
# hyper-parameters only build the run name used to locate the trained
# checkpoint, so they must match the values used at training time.
parser.add_argument('--feature_dim', default=300, type=int, help='Feature dim for latent vector')
parser.add_argument('--out_dim', default=1024, type=int, help='Out dim for latent vector')
parser.add_argument('--temperature', default=0.5, type=float, help='Temperature used in softmax')
parser.add_argument('--k', default=3, type=int, help='Top k most similar images used to predict the label')
parser.add_argument('--batch_size', default=512, type=int, help='Number of images in each mini-batch')
parser.add_argument('--epochs', default=100, type=int, help='Number of sweeps over the dataset to train')
# Previously this option was parsed but ignored; it now selects the output
# file, falling back to the historical results/<run name>.dat location.
parser.add_argument('--saved_feature_file', type=str, default=None,
                    help='Path to save extracted features (default: results/<run name>.dat)')
# `choices` turns an unsupported dataset into a clear argparse error instead
# of a NameError on the undefined model further down.
parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'mnist', 'imagenet'])
# args parse
args = parser.parse_args()
feature_dim, out_dim, temperature, k = args.feature_dim, args.out_dim, args.temperature, args.k
batch_size, epochs = args.batch_size, args.epochs

# Run name: locates the trained checkpoint and names the default output file.
save_name_pre = '{}_{}_{}_{}_{}_{}_{}'.format(args.dataset, feature_dim, out_dim, temperature, k, batch_size, epochs)

# Data and model setup. Both splits use the test-time transform and are read
# in a fixed order (shuffle=False); the model is profiled on one dummy input
# at the dataset's resolution.
if args.dataset == 'cifar10':
    memory_data = utils.CIFAR10Pair(root='../data/CIFAR10/', train=True, transform=utils.test_transform, download=False)
    memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    test_data = utils.CIFAR10Pair(root='../data/CIFAR10/', train=False, transform=utils.test_transform, download=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    model = Model(feature_dim, out_dim).cuda()
    flops, params = profile(model, inputs=(torch.randn(1, 3, 32, 32).cuda(),))
elif args.dataset == 'mnist':
    memory_data = utils.MNISTPair(root='../data/MNIST/', train=True, transform=utils.mnist_test_transform, download=False)
    memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    test_data = utils.MNISTPair(root='../data/MNIST/', train=False, transform=utils.mnist_test_transform, download=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    model = ModelMNIST(feature_dim).cuda()
    flops, params = profile(model, inputs=(torch.randn(1, 1, 28, 28).cuda(),))
else:  # 'imagenet' -- argparse `choices` rules out any other value
    memory_data = utils.ImageNetPair(os.path.join('../data/ImageNet', 'train'), utils.imagenet_test_transform)
    memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

    test_data = utils.ImageNetPair(os.path.join('../data/ImageNet', 'val'), utils.imagenet_test_transform)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True)

    model = ModelImageNet(feature_dim).cuda()
    flops, params = profile(model, inputs=(torch.randn(1, 3, 224, 224).cuda(),))

# Load on the CPU first so a checkpoint saved on any GPU index can be read;
# load_state_dict copies the tensors into the (already CUDA) model.
model.load_state_dict(torch.load('results/{}_model.pth'.format(save_name_pre), map_location='cpu'))
model.eval()

flops, params = clever_format([flops, params])
print('# Model Params: {} FLOPs: {}'.format(params, flops))


def _extract_features(net, loader, desc):
    """Run ``net`` over every batch of ``loader`` and collect its feature output.

    Each batch is unpacked as ``(data, _, target)`` and only ``data`` is fed to
    the model; ``net`` is expected to return ``(feature, out)`` and only
    ``feature`` is kept. Labels come from ``loader.dataset.targets`` (the
    loader is not shuffled, so their order matches the features).

    Returns:
        (features, labels): CPU tensors; features are concatenated along dim 0.
    """
    batches = []
    with torch.no_grad():
        for data, _, _target in tqdm(loader, desc=desc):
            feature, _out = net(data.cuda(non_blocking=True))
            # Move each batch off the GPU so the whole feature bank never has
            # to fit in GPU memory (it is saved from the CPU anyway).
            batches.append(feature.cpu())
    return torch.cat(batches, dim=0), torch.tensor(loader.dataset.targets)


train_features, train_labels = _extract_features(model, memory_loader, 'Feature extracting')
test_features, test_labels = _extract_features(model, test_loader, 'Test feature extracting')

out_path = args.saved_feature_file or 'results/{}.dat'.format(save_name_pre)
print("Train: {}, Test: {}".format(train_features.shape, test_features.shape))
print("saving features ...")
torch.save({'trX': train_features.numpy(), 'trY': train_labels.numpy(),
            'teX': test_features.numpy(), 'teY': test_labels.numpy()}, out_path)
