from __future__ import print_function
import os
import numpy as np
import os.path as osp
import random
import xml.etree.ElementTree as ET
import pickle
# from utils import *
ospj = osp.join
ospeu = osp.expanduser

def may_make_dir(path):
    """Create directory `path` (including parents) if it does not exist yet.

    Args:
        path: a dir, or result of `os.path.dirname(os.path.abspath(file_path))`.
            `None` and `''` are treated as "nothing to create" and ignored.
    Note:
        `os.path.exists('')` returns `False`, while `os.path.exists('.')` returns `True`!
        That is why `''` must be filtered out explicitly. Also note that the
        tempting `path is None or ''` is wrong: `''` is evaluated as its own
        (always falsy) operand, not compared against `path`.
    """
    nothing_to_create = path is None or path == ''
    if nothing_to_create or os.path.exists(path):
        return
    os.makedirs(path)


def save_pickle(obj, path):
    """Pickle `obj` to `path` using protocol 2.

    The parent directory of `path` is created first if it is missing.
    """
    parent_dir = osp.dirname(osp.abspath(path))
    may_make_dir(parent_dir)
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=2)


def _read_name_id_list(list_path):
    """Parse a list file whose lines are ``<image_name> <id>``.

    Fields are split on any whitespace; blank lines are skipped.

    Returns:
        (names, ids): two parallel lists of strings.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """
    names = []
    ids = []
    with open(list_path, 'r') as f:
        for lineno, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines, e.g. a trailing newline at EOF.
                continue
            if len(fields) < 2:
                raise ValueError('{}:{}: expected "<image_name> <id>", got {!r}'.format(
                    list_path, lineno, line))
            names.append(fields[0])
            ids.append(fields[1])
    return names, ids


def veri_partitions(train_path, paths_dic, seed=1993):
    """Build the train/test partition dict and pickle it to disk.

    Args:
        train_path: list file of ``<image_name> <id>`` lines for the train split.
        paths_dic: dict with ``'test_path'`` (list file in the same format for
            the test split) and ``'partition_path'`` (output pickle path).
        seed: seed for the shuffle that decides which integer label each
            distinct train id receives.

    The saved dict has keys ``train_im_names``, ``train_im_ids``,
    ``train_ids2labels``, ``test_im_names``, ``test_im_ids`` and
    ``test_ids2labels``. ``test_ids2labels`` is the very same mapping object
    as ``train_ids2labels``, i.e. test ids are labelled with the train mapping.
    """
    test_path = paths_dic['test_path']
    partition_path = paths_dic['partition_path']

    ######## create train ########
    train_im_names, train_im_ids = _read_name_id_list(train_path)

    # Sort before shuffling: iteration order of a set of strings can change
    # between interpreter runs (hash randomization), which would make the
    # label assignment irreproducible even with a fixed seed.
    tmp_train_im_ids = sorted(set(train_im_ids))
    # Dedicated seeded RNG: reproducible, and the global `random` state is
    # neither consumed nor modified.
    random.Random(seed).shuffle(tmp_train_im_ids)

    # Earlier hard-coded alternatives for tmp_train_im_ids, kept for reference:
    # tmp_train_im_ids = [
    #     '68', '56', '78', '8', '23', '84', '90', '65', '74', '76', '40', '89', '3', '92', '55', '9', '26', '80',
    #     '43', '38', '58', '70', '77', '1', '85', '19', '17', '50', '28', '53', '13', '81', '45', '82', '6', '59',
    #     '83', '16', '15', '44', '91', '41', '72', '60', '79', '52', '20', '10', '31', '54', '37', '95', '14', '71',
    #     '96', '98', '97', '2', '64', '66', '42', '22', '35', '86', '24', '34', '87', '21', '99', '0', '88', '27',
    #     '18', '94', '11', '12', '47', '25', '30', '46', '62', '69', '36', '61', '7', '63', '75', '5', '32', '4',
    #     '51', '48', '73', '93', '39', '67', '29', '49', '57', '33'
    # ]
    # tmp_train_im_ids = [
    #     26, 31, 37, 86, 76, 14, 88, 48, 71, 67, 13, 84, 58, 40, 75, 94, 20, 57,
    #     61, 80, 90, 24, 50, 29, 54, 44, 78, 53, 52, 16, 49, 9, 69, 23, 74, 38,
    #     8, 59, 66, 72, 39, 51, 30, 89, 99, 22, 32, 77, 36, 43, 2, 21, 68, 96,
    #     81, 63, 42, 19, 5, 85, 45, 56, 41, 62, 10, 91, 98, 92, 27, 97, 83, 47,
    #     65, 46, 6, 70, 55, 60, 93, 33, 12, 0, 28, 35, 25, 7, 18, 73, 15, 11,
    #     82, 4, 79, 1, 95, 17, 87, 64, 3, 34
    # ]
    # tmp_train_im_ids = [str(i) for i in tmp_train_im_ids]

    # Print the resulting order so the id -> label assignment is logged.
    print(tmp_train_im_ids)
    # The id at position i of the shuffled list gets label i.
    train_ids2labels = {im_id: label for label, im_id in enumerate(tmp_train_im_ids)}

    ######## create test ########
    test_im_names, test_im_ids = _read_name_id_list(test_path)
    # NOTE(review): test ids are looked up in train_ids2labels; this assumes
    # every test id also occurs in the train list — confirm for new datasets.

    new_partitions = {
        'train_im_names': train_im_names,
        'train_im_ids': train_im_ids,
        'train_ids2labels': train_ids2labels,
        'test_im_names': test_im_names,
        'test_im_ids': test_im_ids,
        # Same object as the train mapping (shared label space).
        'test_ids2labels': train_ids2labels,
    }

    save_pickle(new_partitions, partition_path)
    print('Partition file saved to {}'.format(partition_path))


if __name__ == '__main__':
    # Directory that contains the `cifar-100` folder ('' = current directory).
    data_root_path = ''
    # Every input list and the output pickle live inside this folder.
    dataset_path = osp.join(data_root_path, 'cifar-100')
    train_path = osp.join(dataset_path, 'train_list.txt')
    test_path = osp.join(dataset_path, 'test_list.txt')
    partition_path = osp.join(dataset_path, 'cifar-100.pkl')
    paths_dic = {
        'test_path': test_path,
        'partition_path': partition_path,
    }
    veri_partitions(train_path, paths_dic)


# Presumably a train-id order recorded from an earlier run of veri_partitions;
# the id at position i of such an order is assigned label i.
# [68, 56, 78, 8, 23, 84, 90, 65, 74, 76, 40, 89, 3, 92, 55, 9, 26, 80, 43, 38, 58, 70, 77, 1, 85, 19, 17, 50, 28, 53, 13, 81, 45, 82, 6, 59, 83, 16, 15, 44, 91, 41, 72, 60, 79, 52, 20, 10, 31, 54, 37, 95, 14, 71, 96, 98, 97, 2, 64, 66, 42, 22, 35, 86, 24, 34, 87, 21, 99, 0, 88, 27, 18, 94, 11, 12, 47, 25, 30, 46, 62, 69, 36, 61, 7, 63, 75, 5, 32, 4, 51, 48, 73, 93, 39, 67, 29, 49, 57, 33]