import pandas as pd
import numpy as np
import tarfile
from PIL import Image
import tqdm
import json as pyjson

class Shard(object):
    """Reader for a tar shard of image / caption / metadata samples.

    Each sample is stored as archive members sharing one prefix:

    - image: ``{prefix}.jpg``
    - caption: ``{prefix}.txt``
    - metadata: ``{prefix}.json``

    The archive is opened in the constructor and stays open until
    :meth:`close` is called; the instance can also be used as a context
    manager (``with Shard(path) as shard: ...``).

    NOTE: Pillow's ``Image.open`` reads pixel data lazily, so images returned
    by the ``read_*`` methods should be loaded (``img.load()``) before the
    shard is closed.

    Attributes:
        tar: The open ``tarfile.TarFile``.
        prefix_list: Unique sample prefixes, in the order they first appear
            in the archive.
    """

    def __init__(self, shard_path):
        """Open the shard and index its members.

        Args:
            shard_path: Path of the tar archive, passed to ``tarfile.open``.
        """
        self.tar = tarfile.open(shard_path)
        # Directory entries are not samples; skip them.
        members = [m for m in self.tar.getmembers() if not m.isdir()]
        # Name -> TarInfo index so each read is a dict lookup instead of the
        # linear scan TarFile.getmember() performs. Later duplicates win,
        # which matches getmember() returning the last occurrence.
        self._members = {m.name: m for m in members}
        # dict.fromkeys de-duplicates while keeping archive order, so the
        # list is deterministic and reads walk the archive front to back.
        self.prefix_list = list(
            dict.fromkeys(self._prefix_of(m.name) for m in members)
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __len__(self):
        """Return the number of distinct sample prefixes."""
        return len(self.prefix_list)

    def close(self):
        """Close the underlying tar archive."""
        self.tar.close()

    @staticmethod
    def _prefix_of(name):
        """Return *name* cut at the first dot of its final path component."""
        # Only the basename is split, so dots in directory names
        # ("./a.jpg", "v1.0/a.jpg") do not truncate the prefix.
        head, sep, tail = name.rpartition('/')
        return head + sep + tail.split('.')[0]

    def _open_member(self, name):
        """Return a file object for the archive member called *name*."""
        try:
            member = self._members[name]
        except KeyError:
            # Same exception type and message as TarFile.getmember().
            raise KeyError(f"filename {name!r} not found") from None
        return self.tar.extractfile(member)

    def _read_filtered_triplets(self, keep, desc):
        """Read the triplets whose metadata ``uid`` satisfies *keep*."""
        img_list = []
        txt_list = []
        json_list = []

        for prefix in tqdm.tqdm(self.prefix_list, desc=desc):
            # Parse the metadata first so rejected samples never touch the
            # image or caption members.
            meta = pyjson.load(self._open_member(f'{prefix}.json'))
            if not keep(meta['uid']):
                continue
            img_list.append(Image.open(self._open_member(prefix + '.jpg')))
            txt_list.append(self._open_member(prefix + '.txt').read())
            json_list.append(meta)

        return img_list, txt_list, json_list

    def read_triplet(self, prefix):
        """Read the image, caption and metadata of one sample.

        Args:
            prefix: Sample prefix, normally an element of ``prefix_list``.

        Returns:
            ``(img, txt, json)``: the image opened with ``Image.open``, the
            caption as raw bytes and the metadata as raw (unparsed) bytes.

        Raises:
            KeyError: If one of the three members is missing.
        """
        img = Image.open(self._open_member(prefix + '.jpg'))
        txt = self._open_member(prefix + '.txt').read()
        meta = self._open_member(prefix + '.json').read()
        return (img, txt, meta)

    def read_pair(self, prefix):
        """Read the image and caption of one sample.

        Args:
            prefix: Sample prefix, normally an element of ``prefix_list``.

        Returns:
            ``(img, txt)``: the image opened with ``Image.open`` and the
            caption as raw bytes.

        Raises:
            KeyError: If the image or caption member is missing.
        """
        img = Image.open(self._open_member(prefix + '.jpg'))
        txt = self._open_member(prefix + '.txt').read()
        return (img, txt)

    def read_all_pairs(self):
        """Read the (image, caption) pair of every prefix in ``prefix_list``.

        TODO: optimize this function with multiprocessing, if needed

        Returns:
            ``(img_list, txt_list)``: two parallel lists ordered like
            ``prefix_list``.
        """
        img_list = []
        txt_list = []
        for prefix in tqdm.tqdm(self.prefix_list, desc="Reading all pairs in a shard..."):
            img, txt = self.read_pair(prefix)
            img_list.append(img)
            txt_list.append(txt)

        return img_list, txt_list

    def read_all_triplets(self):
        """Read the triplet of every prefix in ``prefix_list``.

        TODO: optimize this function with multiprocessing, if needed

        Returns:
            ``(img_list, txt_list, json_list)``: three parallel lists ordered
            like ``prefix_list``. Metadata entries are raw bytes, as returned
            by :meth:`read_triplet`.
        """
        img_list = []
        txt_list = []
        json_list = []

        for prefix in tqdm.tqdm(self.prefix_list, desc="Reading all triplets in a shard..."):
            img, txt, meta = self.read_triplet(prefix)
            img_list.append(img)
            txt_list.append(txt)
            json_list.append(meta)

        return img_list, txt_list, json_list

    def read_all_triplets_with_blacklist_filter(self, blacklist):
        """Read every triplet whose metadata ``uid`` is NOT in *blacklist*.

        Args:
            blacklist: Iterable of uids to skip.

        Returns:
            ``(img_list, txt_list, json_list)``: three parallel lists.
            Unlike :meth:`read_all_triplets`, metadata entries here are the
            parsed JSON objects, not raw bytes.
        """
        blacklist = set(blacklist)
        return self._read_filtered_triplets(
            lambda uid: uid not in blacklist,
            "Reading blacklist-filtered triplets in a shard...",
        )

    def read_all_triplets_with_whitelist_filter(self, whitelist):
        """Read every triplet whose metadata ``uid`` is in *whitelist*.

        Args:
            whitelist: Iterable of uids to keep.

        Returns:
            ``(img_list, txt_list, json_list)``: three parallel lists.
            Unlike :meth:`read_all_triplets`, metadata entries here are the
            parsed JSON objects, not raw bytes.
        """
        whitelist = set(whitelist)
        return self._read_filtered_triplets(
            lambda uid: uid in whitelist,
            "Reading whitelist-filtered triplets in a shard...",
        )

def process_uids(uids, sort=True):
    """Convert hexadecimal uid strings into pairs of unsigned 64-bit integers.

    Characters 0-15 of each uid are parsed as the first integer and
    characters 16-31 as the second; anything past character 32 is ignored.

    Args:
        uids: Iterable of hexadecimal strings.
        sort: If true, sort the resulting array in place before returning
            it. Defaults to True.

    Returns:
        A one-dimensional NumPy structured array of dtype ``"u8,u8"`` with
        one record per uid.

    Raises:
        ValueError: If either half of a uid is empty or not valid hex.
    """
    pair_dtype = np.dtype("u8,u8")

    halves = []
    for uid in uids:
        high = int(uid[:16], 16)
        low = int(uid[16:32], 16)
        halves.append((high, low))

    result = np.array(halves, pair_dtype)
    if sort:
        result.sort()
    return result

def inverse_process_uids(processed_uids):
    """Convert pairs of integers back into 32-character hexadecimal strings.

    Each half is zero-padded to 16 lowercase hex digits, so leading zeros
    are kept and every result is exactly 32 characters long.

    Args:
        processed_uids: Iterable of two-element records ``(high, low)``,
            e.g. a ``"u8,u8"`` structured array or a list of tuples.

    Returns:
        A list of hexadecimal strings, one per record, in input order.
    """
    # int() normalises NumPy integer scalars so the format spec applies;
    # "016x" pads to the fixed width that bare hex() would drop.
    return [f"{int(uid[0]):016x}{int(uid[1]):016x}" for uid in processed_uids]
