import argparse
import torch
import torch.nn as nn
import utils.datasets as dl
from utils.load_trained_model import load_model
import pathlib
import matplotlib as mpl
from tqdm import tqdm, trange
import numpy as np
from multiprocessing import Pool
import time
import ssl_utils as ssl
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input shapes stay fixed (main() uses a fixed img_size and batch size).
torch.backends.cudnn.benchmark = True

# Select the non-interactive 'Agg' backend BEFORE pyplot is imported so no
# display is required (e.g. on a headless compute node). Keep this order.
# NOTE(review): plt is not referenced in the visible code of this file.
mpl.use('Agg')
import matplotlib.pyplot as plt

def _select_device(gpu, base_bs):
    """Map requested GPU indices to ``(device_ids, device, batch_size)``.

    Args:
        gpu: Sequence of GPU indices (ints or numeric strings). Empty -> CPU.
        base_bs: Batch size for a single device.

    Returns:
        ``device_ids``: list of ints for ``nn.DataParallel``, or ``None`` unless
        more than one GPU was requested; ``device``: the primary
        ``torch.device``; ``batch_size``: ``base_bs`` scaled by the number of
        GPUs when running data-parallel, otherwise ``base_bs``.
    """
    if len(gpu) == 0:
        print('Warning! Computing on CPU')
        return None, torch.device('cpu'), base_bs
    if len(gpu) == 1:
        return None, torch.device('cuda:' + str(gpu[0])), base_bs
    device_ids = [int(i) for i in gpu]
    return device_ids, torch.device('cuda:' + str(device_ids[0])), base_bs * len(device_ids)


def main(gpu, model_name, model_architecture, model_checkpoint, dataset, cifar_subset, od_dataset, unlabeled_samples):
    """Run a trained model over the unlabeled CIFAR-10 subset + OD pool and save its outputs.

    The raw model outputs (one row of 10 values per sample, in loader order)
    are written to
    ``DatasetClassifications/{dataset}_{cifar_subset}_{od_dataset}_{unlabeled_samples}/{folder}.pt``,
    where ``folder`` is ``model_name`` or the built-in default model folder.

    Args:
        gpu: GPU indices; empty -> CPU, several -> ``nn.DataParallel``.
        model_name: Model folder to load; ``None`` selects the default model.
        model_architecture: Architecture name passed to ``load_model``;
            names containing ``'BiT'`` use 128px inputs, others 32px.
        model_checkpoint: Checkpoint name; only used when ``model_name`` is given.
        dataset: Base dataset; only ``'cifar10'`` is supported.
        cifar_subset: Labeled-subset size forwarded (divided by 10) to the loader.
        od_dataset: Name of the out-of-distribution dataset for the unlabeled pool.
        unlabeled_samples: Number of unlabeled samples forwarded to the loader.

    Raises:
        NotImplementedError: If ``dataset`` is not ``'cifar10'``. This is checked
            before any directory is created or any model is loaded.
    """
    # Fail fast: previously this check ran only after creating the output
    # directory and loading the model, wasting both for unsupported datasets.
    if dataset != 'cifar10':
        raise NotImplementedError()

    device_ids, device, bs = _select_device(gpu, 1024)

    path = pathlib.Path('DatasetClassifications') / f'{dataset}_{cifar_subset}_{od_dataset}_{unlabeled_samples}'
    path.mkdir(parents=True, exist_ok=True)

    if model_name is not None:
        print('Using passed model name')
        folder, checkpoint = model_name, model_checkpoint
    else:
        # Built-in default model; its checkpoint is fixed to 'best' and
        # model_checkpoint is intentionally not consulted here.
        folder, checkpoint = 'CEDA_30-08-2021_14:21:48', 'best'

    arch = model_architecture
    img_size = 128 if 'BiT' in arch else 32

    # temperature=None and load_temp=False: no temperature is loaded/applied.
    model = load_model(arch, folder, checkpoint, None, device, load_temp=False, dataset=dataset)
    # _select_device only returns device_ids when more than one GPU was requested.
    if device_ids is not None:
        model = nn.DataParallel(model, device_ids=device_ids)

    print(f'Dataset {dataset} - subset {cifar_subset} - OD Dataset {od_dataset} - {unlabeled_samples}')

    model.eval()

    print(f'{folder} - {dataset}')

    # NOTE(review): cifar_subset // 10 presumably converts the total labeled
    # count to a per-class count (CIFAR-10 has 10 classes) -- confirm in ssl_utils.
    od_loader = ssl.get_CIFAR10_subset_plus_OD('unlabeled', cifar_subset // 10, od_dataset, unlabeled_samples,
                                               batch_size=bs, augm_type='none', num_workers=8, shuffle=False,
                                               size=img_size)
    # One row of 10 outputs (CIFAR-10 classes) per sample; shuffle=False keeps
    # row i aligned with sample i of od_loader.dataset.
    model_outs = torch.zeros((len(od_loader.dataset), 10))

    idx = 0
    with torch.no_grad():
        # tqdm takes its total from len(od_loader) and closes itself at the end.
        for batch, _ in tqdm(od_loader, bar_format='{l_bar}{bar:5}{r_bar}{bar:-5b}'):
            idx_next = idx + batch.shape[0]
            model_outs[idx:idx_next] = model(batch.to(device)).cpu()
            idx = idx_next

    torch.save(model_outs, path / f'{folder}.pt')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parse arguments.', prefix_chars='-')

    # nargs='*' (was '+') lets a bare `--gpu` produce [], which reaches the CPU
    # path in main(); with '+' that path was unreachable from the command line.
    # type=int makes parsed values consistent with the int default [0].
    parser.add_argument('--gpu', '--list', nargs='*', type=int, default=[0],
                        help='GPU indices, if more than 1 parallel modules will be called; '
                             'pass --gpu without indices to compute on CPU')
    parser.add_argument('--dataset', type=str, default='cifar10')
    parser.add_argument('--cifar_subset', type=int, default=4000)
    parser.add_argument('--od_dataset', type=str, default='tinyImages_subset')
    parser.add_argument('--unlabeled_samples', type=int, default=1_000_000)

    parser.add_argument('--model_architecture', type=str, default='WideResNet28x2')
    parser.add_argument('--model_name', type=str, default=None)
    parser.add_argument('--model_checkpoint', type=str, default='best')

    hps = parser.parse_args()

    main(hps.gpu, hps.model_name, hps.model_architecture, hps.model_checkpoint,
         hps.dataset, hps.cifar_subset, hps.od_dataset, hps.unlabeled_samples)

