from __future__ import print_function, division, absolute_import
import csv
import os
import os.path
import tarfile
from six.moves.urllib.parse import urlparse

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
import random
import glob
from tqdm import tqdm
from six.moves.urllib.request import urlretrieve

# Names of the 20 VOC object categories. The position of a name in this list
# is the index of that category in the label vectors and CSV columns that the
# VOC helpers in this module produce.
object_categories = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]

# Category names for the image-folder datasets, keyed by dataset name.
# The names are used as class sub-directory names when scanning for images;
# note that the 'cars196' names are bracketed ('[0]', '[1]', ...).
object_categories_g = {
    'places205': list(map(str, range(205))),
    'flowers102': list(map(str, range(102))),
    'cars196': ['[' + str(i) + ']' for i in range(196)],
    'herba19': list(map(str, range(683))),
}

# Archives fetched by download_voc2007().
urls = dict(
    devkit='http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar',
    trainval_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
    test_images_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
    test_anno_2007='http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtestnoimgs_06-Nov-2007.tar',
)


def read_image_label(file):
    """Parse an image-set file into a ``{image_name: label}`` dictionary.

    Every non-blank line is expected to contain an image name followed by an
    integer label, separated by whitespace.  The first token is taken as the
    name and the last token as the label.

    Parameters
    ----------
    file : str
        Path of the text file to read.

    Returns
    -------
    dict
        Maps each image name to its integer label.  If a name occurs several
        times, the last occurrence wins.

    Raises
    ------
    ValueError
        If the last token of a non-blank line is not an integer.
    """
    print('[dataset] read ' + file)
    labels_by_name = dict()
    with open(file, 'r') as f:
        for line in f:
            # split() without an argument collapses runs of whitespace and
            # drops the trailing newline, so padded columns, tabs and
            # trailing spaces are all handled.
            tokens = line.split()
            if not tokens:
                # Blank line (e.g. at the end of the file): nothing to parse.
                continue
            labels_by_name[tokens[0]] = int(tokens[-1])
    return labels_by_name


def read_object_labels2(root, set, dataname):
    """Build label vectors for an image-folder dataset.

    For every category listed in ``object_categories_g[dataname]`` the
    directory ``<root>/<set>/<category>`` is searched for ``*.jpg`` files.

    Parameters
    ----------
    root : str
        Directory that contains the split directories.
    set : str
        Name of the split sub-directory.
    dataname : str
        Key into ``object_categories_g`` selecting the category names.

    Returns
    -------
    dict
        Maps each image path found to a NumPy vector with one entry per
        category: ``1`` at the index of the image's category, ``-1``
        everywhere else.

    Raises
    ------
    KeyError
        If ``dataname`` is not a key of ``object_categories_g``.
    """
    categories = object_categories_g[dataname]
    num_classes = len(categories)
    labeled_data = dict()

    for class_idx, category in enumerate(categories):
        # Directory names must be matched literally.  Without escaping, a
        # category such as '[0]' (the 'cars196' naming) is interpreted by
        # glob as a character class and the wrong directory is searched.
        class_dir = glob.escape(os.path.join(root, set, category))
        for name in glob.glob(class_dir + '/*.jpg'):
            target = np.zeros(num_classes) - 1
            target[class_idx] = 1
            labeled_data[name] = target
    return labeled_data

def read_object_labels(root, dataset, set):
    """Collect per-image label vectors from the per-category image-set files.

    For every name in ``object_categories`` the file
    ``<root>/VOCdevkit/<dataset>/ImageSets/Main/<category>_<set>.txt`` is read
    with ``read_image_label`` and its labels are stored at that category's
    index.

    Parameters
    ----------
    root : str
        Directory containing ``VOCdevkit``.
    dataset : str
        Dataset directory below ``VOCdevkit`` (e.g. ``'VOC2007'``).
    set : str
        Split name used in the file names.

    Returns
    -------
    dict
        Maps each image name to a NumPy vector with one entry per category.
        Entries for which no label was read stay ``0``.
    """
    path_labels = os.path.join(root, 'VOCdevkit', dataset, 'ImageSets', 'Main')
    labeled_data = dict()
    # The category files are named after the VOC class list.  The dataset
    # dictionary object_categories_g is keyed by dataset name and cannot be
    # indexed by class position.
    num_classes = len(object_categories)

    for i, category in enumerate(object_categories):
        file = os.path.join(path_labels, category + '_' + set + '.txt')
        for name, label in read_image_label(file).items():
            # Create the vector the first time an image is seen, whichever
            # category file that happens in.
            if name not in labeled_data:
                labeled_data[name] = np.zeros(num_classes)
            labeled_data[name][i] = label
    return labeled_data


def write_object_labels_csv(file, labeled_data):
    """Write VOC label vectors to a CSV file.

    The header row is ``name`` followed by the entries of
    ``object_categories``; each following row holds an image name and its
    labels converted to ``int``.

    Parameters
    ----------
    file : str
        Path of the CSV file to create (an existing file is overwritten).
    labeled_data : dict
        Maps image names to indexable label vectors with at least one entry
        per category.
    """
    print('[dataset] write file %s' % file)
    fieldnames = ['name']
    fieldnames.extend(object_categories)

    # newline='' lets the csv module control line endings itself; without it
    # every row is followed by a blank line on platforms that translate '\n'.
    with open(file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for name, labels in labeled_data.items():
            example = {'name': name}
            # Index by position so the column count follows the category
            # list rather than a hard-coded 20.
            for i, category in enumerate(object_categories):
                example[category] = int(labels[i])
            writer.writerow(example)

def write_object_labels_csv2(file, labeled_data, dataname):
    """Write label vectors of an image-folder dataset to a CSV file.

    The header row is ``name`` followed by the category names in
    ``object_categories_g[dataname]``; each following row holds an image name
    and its labels converted to ``int``.

    Parameters
    ----------
    file : str
        Path of the CSV file to create (an existing file is overwritten).
    labeled_data : dict
        Maps image names to indexable label vectors with at least one entry
        per category.
    dataname : str
        Key into ``object_categories_g`` selecting the category names.
    """
    categories = object_categories_g[dataname]
    fieldnames = ['name']
    fieldnames.extend(categories)

    # newline='' lets the csv module control line endings itself; without it
    # every row is followed by a blank line on platforms that translate '\n'.
    with open(file, 'w', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        for name, labels in labeled_data.items():
            example = {'name': name}
            for i, category in enumerate(categories):
                # Built-in int(): the np.int alias no longer exists in
                # current NumPy releases.
                example[category] = int(labels[i])
            writer.writerow(example)

def read_object_labels_csv(file, header=True):
    """Load ``(name, labels)`` pairs from a label CSV file.

    Parameters
    ----------
    file : str
        Path of the CSV file.  The first column is the image name, the
        remaining columns are numeric labels.
    header : bool
        If true, the first row is treated as a header and skipped.

    Returns
    -------
    list
        One ``(name, labels)`` tuple per data row, where ``labels`` is a
        ``torch.float32`` tensor.  The number of label columns is taken from
        the first data row; extra columns in later rows are ignored.
    """
    images = []
    num_categories = 0
    # newline='' is the mode the csv module expects for its input files.
    with open(file, 'r', newline='') as f:
        reader = csv.reader(f)
        if header:
            # Discard the header row (no-op on an empty file).
            next(reader, None)
        for row in reader:
            if not row:
                # Blank lines yield empty rows, e.g. in files written
                # without newline='' on platforms that translate '\n'.
                continue
            if num_categories == 0:
                num_categories = len(row) - 1
            name = row[0]
            labels = np.asarray(row[1:num_categories + 1]).astype(np.float32)
            images.append((name, torch.from_numpy(labels)))
    return images


def find_images_classification(root, dataset, set):
    """Return the lines of an image-set listing file.

    Reads ``<root>/VOCdevkit/<dataset>/ImageSets/Main/<set>.txt`` and returns
    its lines exactly as stored in the file, i.e. each entry still carries
    its line terminator when one is present.
    """
    listing = os.path.join(
        root, 'VOCdevkit', dataset, 'ImageSets', 'Main', set + '.txt')
    with open(listing, 'r') as handle:
        return handle.readlines()


def _download_and_extract(url, tmpdir, root, verbose=False):
    """Download the tar archive at *url* into *tmpdir* (unless a file of that
    name is already cached there) and extract it into *root*."""
    filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(tmpdir, filename)

    if not os.path.exists(cached_file):
        # The cache directory may not exist yet when earlier steps were
        # skipped because their targets were already present.
        os.makedirs(tmpdir, exist_ok=True)
        print('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url(url, cached_file)

    if verbose:
        print('[dataset] Extracting tar file {file} to {path}'.format(file=cached_file, path=root))
    # Extract straight into root instead of chdir-ing there: the working
    # directory is process-wide state and would stay changed if extraction
    # raised.
    # NOTE(review): the archives are fetched over plain http and extractall()
    # trusts the member paths they contain - consider an extraction filter.
    with tarfile.open(cached_file, "r") as tar:
        tar.extractall(path=root)
    if verbose:
        print('[dataset] Done!')


def download_voc2007(root):
    """Download and unpack the VOC 2007 archives below *root* if missing.

    Four checks are made in turn, each fetching one archive from ``urls``
    when its marker path does not exist: the ``VOCdevkit`` directory, the
    ``VOC2007/JPEGImages`` directory, ``aeroplane_test.txt`` in
    ``VOC2007/ImageSets/Main`` and ``VOC2007/JPEGImages/000001.jpg``.
    Downloaded archives are kept in ``<root>/tmp`` and reused on later calls.

    Parameters
    ----------
    root : str
        Directory that will contain ``VOCdevkit`` (created if necessary).
    """
    path_devkit = os.path.join(root, 'VOCdevkit')
    path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages')
    tmpdir = os.path.join(root, 'tmp')

    # create directory
    os.makedirs(root, exist_ok=True)

    # development kit
    if not os.path.exists(path_devkit):
        _download_and_extract(urls['devkit'], tmpdir, root)

    # train/val images/annotations
    if not os.path.exists(path_images):
        _download_and_extract(urls['trainval_2007'], tmpdir, root)

    # NOTE(review): the annotation check below fetches 'test_images_2007' and
    # the image check fetches 'test_anno_2007'. The pairing is kept exactly
    # as it was - confirm that it is intended.

    # test annotations
    test_anno = os.path.join(path_devkit, 'VOC2007/ImageSets/Main/aeroplane_test.txt')
    if not os.path.exists(test_anno):
        _download_and_extract(urls['test_images_2007'], tmpdir, root, verbose=True)

    # test images
    test_image = os.path.join(path_devkit, 'VOC2007/JPEGImages/000001.jpg')
    if not os.path.exists(test_image):
        _download_and_extract(urls['test_anno_2007'], tmpdir, root, verbose=True)

def download_url(url, destination=None, progress_bar=True):
    """Download a URL to a local file.
    Parameters
    ----------
    url : str
        The URL to download.
    destination : str, None
        The destination of the file. If None is given the file is saved to a temporary directory.
    progress_bar : bool
        Whether to show a command-line progress bar while downloading.
    Returns
    -------
    filename : str
        The location of the downloaded file.
    Notes
    -----
    Progress bar use/example adapted from tqdm documentation: https://github.com/tqdm/tqdm
    """

    def my_hook(t):
        # urlretrieve reports the running block count; remember the previous
        # count so the bar can be advanced by the difference.
        last_b = [0]

        def inner(b=1, bsize=1, tsize=None):
            if tsize is not None:
                t.total = tsize
            if b > 0:
                t.update((b - last_b[0]) * bsize)
            last_b[0] = b

        return inner

    if progress_bar:
        with tqdm(unit='B', unit_scale=True, miniters=1, desc=url.split('/')[-1]) as t:
            filename, _ = urlretrieve(url, filename=destination, reporthook=my_hook(t))
    else:
        filename, _ = urlretrieve(url, filename=destination)

    # Hand the location back to the caller, as the docstring promises.
    return filename

class Voc2007Classification(data.Dataset):
    """VOC 2007 classification dataset backed by a cached label CSV.

    Items are ``(image, target)`` pairs.  The image is loaded from
    ``<root>/VOCdevkit/VOC2007/JPEGImages/<name>.jpg`` and converted to RGB;
    the target is the label tensor read from the CSV file.  After
    ``convert_low_shot`` has been called the dataset serves the sampled
    subset instead of the full list.
    """

    def __init__(self, root, set, transform=None, target_transform=None):
        """Download the data if needed and load (or first create) the CSV.

        ``transform`` is applied to every image and ``target_transform`` to
        every target when they are not None.
        """
        self.root = root
        self.set = set
        self.transform = transform
        self.target_transform = target_transform
        self.low_shot = False
        self.path_devkit = os.path.join(root, 'VOCdevkit')
        self.path_images = os.path.join(root, 'VOCdevkit', 'VOC2007', 'JPEGImages')

        # Fetch and unpack the archives when they are not present yet.
        download_voc2007(self.root)

        # The label CSV lives in a fixed, absolute cache directory.
        path_csv = os.path.join('/scratch/shared/beegfs/yuki/data/', 'VOC2007')
        file_csv = os.path.join(path_csv, 'classification_' + set + '.csv')

        csv_missing = not os.path.exists(file_csv)
        if csv_missing:
            if not os.path.exists(path_csv):
                os.makedirs(path_csv)
            # Build the labels from the image-set files, then cache them.
            write_object_labels_csv(
                file_csv, read_object_labels(self.root, 'VOC2007', self.set))

        self.classes = object_categories
        self.images = read_object_labels_csv(file_csv)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``."""
        entries = self.images_lowshot if self.low_shot else self.images
        path, target = entries[index]
        img = Image.open(os.path.join(self.path_images, path + '.jpg')).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of items in the list currently being served."""
        return len(self.images_lowshot if self.low_shot else self.images)

    def get_number_classes(self):
        """Return the number of classes."""
        return len(self.classes)

    def convert_low_shot(self, k):
        """Switch to a subset with at most ``k`` random images per class.

        An image counts for every class whose label entry is greater than
        zero, so an image with several such entries can be picked more than
        once.
        """
        per_class = {cls_idx: [] for cls_idx in range(len(self.classes))}
        for entry in self.images:
            positives = torch.where(entry[1] > 0)[0]
            for cls_idx in positives:
                per_class[cls_idx.item()].append(entry)

        self.images_lowshot = []
        for bucket in per_class.values():
            random.shuffle(bucket)
            self.images_lowshot.extend(bucket[:k])
        self.low_shot = True


class ImageFolderLowshot(data.Dataset):
    """Image-folder classification dataset backed by a cached label CSV.

    Items are ``(image, target)`` pairs.  The image is opened from the path
    stored in the CSV file and converted to RGB; the target is the label
    tensor read from the same row.  After ``convert_low_shot`` has been
    called the dataset serves the sampled subset instead of the full list.
    """

    def __init__(self, root, set, dataname='places205', transform=None, target_transform=None):
        """Load (or first create) the label CSV for ``dataname`` and ``set``.

        ``transform`` is applied to every image and ``target_transform`` to
        every target when they are not None.
        """
        self.root = root
        self.set = set
        self.transform = transform
        self.target_transform = target_transform
        self.low_shot = False

        # The label CSV lives in a fixed, absolute cache directory.
        path_csv = os.path.join('/scratch/shared/beegfs/yuki/data/', 'lowshot')
        file_csv = os.path.join(path_csv, 'classification_' + dataname + '_' + set + '.csv')

        csv_missing = not os.path.exists(file_csv)
        if csv_missing:
            if not os.path.exists(path_csv):
                os.makedirs(path_csv)
            # Scan the class folders for images, then cache the labels.
            write_object_labels_csv2(
                file_csv,
                read_object_labels2(self.root, self.set, dataname),
                dataname)

        self.classes = object_categories_g[dataname]
        self.images = read_object_labels_csv(file_csv)

        summary = (dataname, set, len(self.classes), len(self.images))
        print('[dataset] %s classification set=%s number of classes=%d  number of images=%d' % summary)

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair stored at ``index``."""
        entries = self.images_lowshot if self.low_shot else self.images
        path, target = entries[index]
        img = Image.open(os.path.join(path)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Return the number of items in the list currently being served."""
        return len(self.images_lowshot if self.low_shot else self.images)

    def get_number_classes(self):
        """Return the number of classes."""
        return len(self.classes)

    def convert_low_shot(self, k):
        """Switch to a subset with at most ``k`` random images per class.

        An image counts for every class whose label entry is greater than
        zero.
        """
        per_class = {cls_idx: [] for cls_idx in range(len(self.classes))}
        for entry in self.images:
            positives = torch.where(entry[1] > 0)[0]
            for cls_idx in positives:
                per_class[cls_idx.item()].append(entry)

        self.images_lowshot = []
        for bucket in per_class.values():
            random.shuffle(bucket)
            self.images_lowshot.extend(bucket[:k])
        self.low_shot = True