from random import sample
from torch.utils.data import DataLoader
import torch
from torchvision import datasets, transforms
import os
import pandas as pd
from . import dataset_DCL
from . import config
# import dataset_DCL
# import config
import argparse

def parse_args(argv=None):
    """Parse the DCL command-line options.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with attributes ``dataset``, ``backbone``,
        ``batch_size``, ``num_workers``, ``version``, ``resume``,
        ``resize_resolution``, ``crop_resolution``, ``save_suffix``,
        ``acc_report`` and ``swap_num``.

    Options this parser does not define are ignored instead of being fatal.
    The ``get_*`` loaders in this module call ``parse_args()`` internally.
    When they run inside a host script that has its own flags, strict
    parsing would make argparse exit with "unrecognized arguments".
    """
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset',
                        default=None, type=str)
    parser.add_argument('--backbone', dest='backbone',
                        default='resnet50', type=str)
    parser.add_argument('--b', dest='batch_size',
                        default=16, type=int)
    parser.add_argument('--nw', dest='num_workers',
                        default=16, type=int)
    parser.add_argument('--ver', dest='version',
                        default='test', type=str)
    parser.add_argument('--save', dest='resume',
                        default=None, type=str)
    parser.add_argument('--size', dest='resize_resolution',
                        default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution',
                        default=448, type=int)
    parser.add_argument('--ss', dest='save_suffix',
                        default=None, type=str)
    parser.add_argument('--acc_report', dest='acc_report',
                        action='store_true')
    parser.add_argument('--swap_num', default=[7, 7],
                        nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify a range')
    # parse_known_args returns (namespace, leftover_args). The leftovers
    # belong to whatever program imported this module, so drop them.
    args, _unknown = parser.parse_known_args(argv)
    return args

def _build_test_loader(dataset_name, batch_size, swap_num, num_workers=8):
    """Build the un-shuffled test-split DataLoader for one DCL dataset.

    Args:
        dataset_name: Dataset key assigned to ``args.data`` and
            ``args.dataset`` before ``config.LoadConfig(args, 'test')`` is
            called (``'CUB'``, ``'STCAR'`` or ``'AIR'`` in this module).
        batch_size: Samples per batch.
        swap_num: Two-element list passed as the third argument of
            ``config.load_data_transformers`` (e.g. ``[7, 7]``).
        num_workers: Number of DataLoader worker processes.

    Returns:
        A ``DataLoader`` over the annotations in
        ``<Config.anno_root>/test.txt``. It does not shuffle and collates
        batches with ``dataset_DCL.collate_fn4test``.
    """
    # Start from the CLI namespace so any other options LoadConfig may read
    # are kept, then force the dataset this loader is for.
    args = parse_args()
    # NOTE(review): both attributes are set because it is unclear from here
    # which one config.LoadConfig reads; `data` is not a parser dest.
    args.data = dataset_name
    args.dataset = dataset_name
    Config = config.LoadConfig(args, 'test')

    # test.txt holds one space-separated "<image name> <label>" row per line.
    anno = pd.read_csv(
        os.path.join(Config.anno_root, 'test.txt'),
        sep=" ",
        header=None,
        names=['ImageName', 'label'],
    )

    # 512/448 match the --size/--crop defaults (presumably resize and crop
    # resolution). They are fixed here and do not follow those CLI values.
    transformers = config.load_data_transformers(512, 448, swap_num)
    data_set = dataset_DCL.dataset(
        Config,
        anno=anno,
        swap=transformers["None"],
        totensor=transformers['test_totensor'],
        test=True,
    )

    return DataLoader(
        data_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=dataset_DCL.collate_fn4test,
    )


def get_cub(batch_size, **kwargs):
    """Return the 'CUB' test DataLoader (swap grid ``[7, 7]``).

    ``num_workers`` may be passed as a keyword (default 8). Any other
    keyword arguments are accepted for interface compatibility and ignored.
    """
    return _build_test_loader('CUB', batch_size, [7, 7],
                              num_workers=kwargs.get('num_workers', 8))


def get_car(batch_size, **kwargs):
    """Return the 'STCAR' test DataLoader (swap grid ``[7, 7]``).

    ``num_workers`` may be passed as a keyword (default 8). Any other
    keyword arguments are accepted for interface compatibility and ignored.
    """
    return _build_test_loader('STCAR', batch_size, [7, 7],
                              num_workers=kwargs.get('num_workers', 8))


def get_air(batch_size, **kwargs):
    """Return the 'AIR' test DataLoader (swap grid ``[2, 2]``).

    ``num_workers`` may be passed as a keyword (default 8). Any other
    keyword arguments are accepted for interface compatibility and ignored.
    """
    return _build_test_loader('AIR', batch_size, [2, 2],
                              num_workers=kwargs.get('num_workers', 8))


if __name__ == '__main__':
    # Smoke test: build the CUB test loader with two samples per batch and
    # print the shape of each batch's images and labels, in that order.
    # NOTE(review): assumes collate_fn4test yields (images, labels) pairs
    # whose items both expose `.shape` -- confirm against dataset_DCL.
    cub_loader = get_cub(batch_size=2)
    for batch in cub_loader:
        batch_images, batch_labels = batch
        for tensor in (batch_images, batch_labels):
            print(tensor.shape)