import os
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url
from PIL import Image
import pickle


class CUBDataset(Dataset):
    """
    Torch ``Dataset`` over the CUB_200_2011 images and the
    ``class_attr_data_10`` per-split pickle files.

    On construction the image archive and the three split pickles are
    downloaded into ``root`` if they are not already there, and the records
    of the requested split are loaded into ``self.data``.

    Each item is a tuple ``(img, attr_label, class_label)``, where ``img`` is
    the RGB PIL image (after ``transform`` when one is given) and the two
    labels are read from the record's ``'attribute_label'`` and
    ``'class_label'`` entries.
    """
    name = 'cub'

    def __init__(self, root, split, transform=None):
        """
        Args:
            root: Directory that holds (or will receive) the downloaded data.
            split: One of ``'train'``, ``'test'`` or ``'val'``.
            transform: Optional callable applied to each loaded PIL image.

        Raises:
            NotImplementedError: If ``split`` is not a known split name.
        """
        # Reject a bad split before creating folders or hitting the network.
        split_names = ('train', 'test', 'val')
        if split not in split_names:
            raise NotImplementedError(
                f"Unknown split {split!r}; expected one of {split_names}"
            )

        super().__init__()
        self.root = root

        self.pkl_urls = ['https://worksheets.codalab.org/rest/bundles/0x5b9d528d2101418b87212db92fea6683/contents/blob/class_attr_data_10/train.pkl', 
                    'https://worksheets.codalab.org/rest/bundles/0x5b9d528d2101418b87212db92fea6683/contents/blob/class_attr_data_10/test.pkl',
                    'https://worksheets.codalab.org/rest/bundles/0x5b9d528d2101418b87212db92fea6683/contents/blob/class_attr_data_10/val.pkl']
        self.cub_url = 'https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz?download=1'
        self.filename = 'CUB_200_2011.tgz'
        self.tgz_md5 = '97eceeb196236b17998738112f37df78'
        # Create folders
        os.makedirs(self.root, exist_ok=True)
        os.makedirs(os.path.join(self.root, 'class_attr_data_10'), exist_ok=True)
        # Same order as ``split_names`` and ``self.pkl_urls``.
        self.splits = ['train.pkl', 'test.pkl', 'val.pkl']
        self.pkl_file_paths = [
            os.path.join(self.root, 'class_attr_data_10', filename)
            for filename in self.splits
        ]
        # Download CUB
        self._download()
        self.data = []
        self.has_concepts = True
        # Reflect the split actually requested. The previous check looked at
        # the list of all pickle paths, which always contains train.pkl, so
        # the flag was True for every split.
        self.is_train = split == 'train'

        # NOTE(security): pickle.load executes arbitrary code from the file.
        # These pickles are fetched over the network without a checksum, so
        # only use this loader with sources you trust.
        pkl_path = self.pkl_file_paths[split_names.index(split)]
        with open(pkl_path, 'rb') as pkl_file:
            self.data.extend(pickle.load(pkl_file))

        self.transform = transform
        self.image_dir = os.path.join(root, 'CUB_200_2011', 'images')

    def _download(self):
        """Fetch and extract the image archive and fetch any missing split pickles."""
        import tarfile
        archive_path = os.path.join(self.root, self.filename)
        # Download CUB images
        if os.path.exists(archive_path):
            print('Files already downloaded and extracted.')
        else:
            print('Downloading CUB...')
            download_url(self.cub_url, self.root, self.filename, self.tgz_md5)
        # Extraction is checked separately so an archive that was downloaded
        # but never unpacked still gets extracted.
        if not os.path.exists(os.path.join(self.root, 'CUB_200_2011/')):
            print("Extracting CUB_200_2011.tgz")
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(path=self.root)
        # Download CUB attributes
        attr_dir = os.path.join(self.root, 'class_attr_data_10')
        for url, filename, path in zip(self.pkl_urls, self.splits, self.pkl_file_paths):
            if not os.path.exists(path):
                print(f"Downloading {path}")
                download_url(url, attr_dir, filename)

    def __len__(self):
        """Return the number of records loaded for the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Load one sample.

        Args:
            idx: Index into the loaded split records.

        Returns:
            Tuple ``(img, attr_label, class_label)``.

        Raises:
            ValueError: If the stored image path has no ``CUB_200_2011``
                component.
        """
        img_data = self.data[idx]
        # The stored path is re-rooted under self.image_dir: everything up to
        # and including the component after 'CUB_200_2011' is dropped
        # (presumably the 'images' folder -- confirm against the pickles).
        path_parts = img_data['img_path'].split('/')
        anchor = path_parts.index('CUB_200_2011')
        img_path = os.path.join(self.image_dir, *path_parts[anchor + 2:])
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        class_label = img_data['class_label']
        attr_label = img_data['attribute_label']

        return img, attr_label, class_label

if __name__ == "__main__":
    # Smoke test: build the training split under ./CUB, downloading the data
    # if needed. The constructor takes `root` and a required `split`; the
    # previous `base_path=` keyword raised TypeError.
    data = CUBDataset(
        root="CUB",
        split="train",
    )
