import numpy as np
import pickle
import h5py
import os
from tqdm import tqdm
import torch
import argparse
import random
from skimage.transform import resize


def convert_idx_masks_to_bool(masks):
    """Split an index map into one boolean mask per distinct value.

    Args:
        masks: tensor of shape (1, img_dim1, img_dim2) holding one index
            per pixel. The leading dimension must be 1 so that it can be
            expanded to the number of distinct values.

    Returns:
        Bool tensor of shape (num_masks, img_dim1, img_dim2), where
        ``num_masks`` is the number of distinct values in ``masks``.
        Slice ``k`` is True wherever ``masks`` equals the k-th smallest
        distinct value.
    """
    labels, _ = torch.sort(torch.unique(masks))
    num_masks = labels.shape[0]
    # Repeat the single index map once per label (a view, no copy) ...
    tiled = masks.expand(num_masks, masks.shape[1], masks.shape[2])
    # ... and compare every copy against its own label.
    return tiled == labels.view(-1, 1, 1)


def compress_masks(data, resize_to=None, min_size=0):
    """Flatten a stack of binary masks into a single integer label map.

    Masks are processed from largest to smallest area. Each mask only
    claims the pixels that no larger mask has claimed already, so the
    result is a non-overlapping segmentation.

    Args:
        data: array-like of shape (num_masks, img_dim1, img_dim2). Entries
            are expected to be binary (non-zero means "inside the mask").
        resize_to: if not None, every mask is first resized to
            (resize_to, resize_to) and re-binarised with a 0.5 threshold.
        min_size: minimum number of still-unclaimed pixels a mask must
            contribute to be kept.

    Returns:
        numpy int32 array of shape (img_dim1, img_dim2), or
        (resize_to, resize_to) when resizing. Kept masks are labelled
        0..K-1 in order of decreasing area; pixels covered by no kept mask
        receive the label K.

    Raises:
        ValueError: if ``data`` is not 3-dimensional.
    """
    if len(data.shape) != 3:
        raise ValueError(
            'expected masks of shape (num_masks, img_dim1, img_dim2), '
            'got shape {}'.format(tuple(data.shape)))

    data = torch.tensor(data).int()

    if resize_to is not None:
        data = torch.stack([
            (torch.tensor(resize(mask,
                                 (resize_to, resize_to),
                                 preserve_range=True)) > 0.5).int()
            for mask in data.cpu().numpy()
        ])

    # Largest masks first, so that they get first pick of the pixels.
    areas = data.sum(dim=-1).sum(dim=-1)
    _, order = torch.sort(areas, descending=True)
    data = data[order]

    # 0 means "unclaimed" while building; kept masks are numbered from 1.
    label_map = torch.zeros(data.shape[-2], data.shape[-1])

    # A mask with no unclaimed pixels left cannot show up in the label map,
    # so it must never consume a label id, even when min_size is 0.
    # Otherwise the ids would have gaps and max() + 1 would overcount.
    required_pixels = max(min_size, 1)

    next_label = 1
    for mask in data:
        visible = mask.bool() & ~label_map.bool()
        if torch.sum(visible) >= required_pixels:
            label_map[visible] = next_label
            next_label += 1

    # Shift kept masks to 0..K-1; unclaimed pixels become -1 and are then
    # moved to the label right after the last kept mask.
    label_map = (label_map - 1).int()
    label_map[label_map == -1] = torch.max(label_map) + 1

    return label_map.numpy()

def parse_args():
    """Build the command-line parser for the mask-compression script.

    Despite its name, this function does not parse anything itself: it
    returns the configured parser, and the caller invokes ``.parse_args()``
    on it.

    Returns:
        argparse.ArgumentParser: parser exposing ``--input-dir``,
        ``--output-dir``, ``--class-names-filepath``, ``--seed``,
        ``--flat``, ``--resize-to`` and ``--min-size``.
    """
    parser = argparse.ArgumentParser()

    # paths and info
    parser.add_argument('--input-dir', type=str,
                        default='../datasets/flowers/sam_seg',
                        help='input dir containing the pickled masks')
    # The option keeps its historical name, but the value is the path of
    # the HDF5 file that gets written, not a directory.
    parser.add_argument('--output-dir', type=str,
                        default='../datasets/flowers/sam_seg_compressed.h5',
                        help='output HDF5 file path')
    parser.add_argument('--class-names-filepath', type=str,
                        default=None,
                        help='class names filepath '
                             '(one class sub-directory per line)')
    parser.add_argument('--seed', type=int,
                        default=42,
                        help='random seed (-1 disables seeding)')
    parser.add_argument('--flat',
                        action='store_true',
                        help='if set, the input dir holds the .pkl files '
                             'directly instead of class sub-directories')
    parser.add_argument('--resize-to', type=int,
                        default=None,
                        help='if set, resize every mask to a square of '
                             'this side length before compressing')
    parser.add_argument('--min-size', type=int,
                        default=0,
                        help='minimum number of visible pixels a mask '
                             'needs in order to be kept')

    return parser
    

def _load_label_map(input_path, resize_to, min_size):
    """Load one pickled mask file and flatten it into an integer label map."""
    # NOTE: pickle.load can execute arbitrary code embedded in the file;
    # only run this script on mask files from a trusted source.
    with open(input_path, 'rb') as pickle_file:
        data = pickle.load(pickle_file)

    if len(data.shape) == 2:
        # A 2-D array is treated as an index map and split into one boolean
        # mask per distinct value before compression.
        data_tensor = torch.tensor(data).unsqueeze(0)
        data = convert_idx_masks_to_bool(data_tensor).cpu().numpy()

    if len(data) > 0:
        return compress_masks(data, resize_to=resize_to, min_size=min_size)
    # No masks at all: hand the empty array back unchanged.
    return data


def _convert_directory(src_dir, h5_target, resize_to, min_size):
    """Write every .pkl file of src_dir as a gzip dataset under h5_target.

    Returns:
        (mask_total, file_count): the summed ``max() + 1`` of all label
        maps written and the number of files converted.
    """
    suffix = '.pkl'
    mask_total = 0.
    file_count = 0

    for filename in tqdm(os.listdir(src_dir)):
        if not filename.endswith(suffix):
            continue

        # Strip only the trailing extension; str.replace would also remove
        # a '.pkl' occurring in the middle of the name.
        dataset_name = filename[:-len(suffix)]
        mapped_array = _load_label_map(os.path.join(src_dir, filename),
                                       resize_to, min_size)

        # An empty array has no maximum; it contributes zero masks.
        if len(mapped_array) > 0:
            mask_total += mapped_array.max() + 1
        file_count += 1

        h5_target.create_dataset(dataset_name, data=mapped_array,
                                 compression='gzip', compression_opts=9)

    return mask_total, file_count


def main():
    """Compress every pickled mask file of the input dir into one HDF5 file.

    Without ``--flat``, each class sub-directory of the input dir becomes an
    HDF5 group holding one dataset per ``.pkl`` file. With ``--flat``, the
    ``.pkl`` files sit directly in the input dir and the datasets are
    created at the root of the HDF5 file. The average of ``max() + 1`` over
    all written label maps is printed at the end.
    """
    parser = parse_args()
    args = parser.parse_args()

    print('\n---argparser---:')
    for arg in vars(args):
        value = getattr(args, arg)
        # Report the type of the value; the attribute name is always a str.
        print(arg, value, '\t', type(value))

    # Set the seed for reproducibility
    if args.seed != -1:
        # Torch RNG
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        # Python RNG
        np.random.seed(args.seed)
        random.seed(args.seed)

    input_dir = args.input_dir
    # Despite the option name, this is the path of the HDF5 file to write.
    output_dir = args.output_dir

    total_masks = 0.
    total_count = 0

    with h5py.File(output_dir, 'w') as h5_file:
        if not args.flat:
            if args.class_names_filepath is None:
                dirnames = sorted(os.listdir(input_dir))
            else:
                with open(args.class_names_filepath, 'rt') as input_file:
                    # The first character of every stripped line is dropped
                    # (presumably a leading separator - TODO confirm).
                    dirnames = [line.strip()[1:] for line in input_file]

            for subdir in tqdm(dirnames):
                group = h5_file.create_group(subdir)
                masks, files = _convert_directory(
                    os.path.join(input_dir, subdir), group,
                    args.resize_to, args.min_size)
                total_masks += masks
                total_count += files
        else:
            total_masks, total_count = _convert_directory(
                input_dir, h5_file, args.resize_to, args.min_size)

    if total_count > 0:
        print('avg number of masks', total_masks / total_count)
    else:
        # Avoid a ZeroDivisionError when nothing was converted.
        print('no .pkl files found in', input_dir)


# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()