# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import bisect
import copy
import logging
import os
import torch
import torch.utils.data
import torch.distributed
from torch.utils.data.dataset import ConcatDataset

from .catalog import DatasetCatalog
from .clip_datasets.clip_img_txt_pair_tsv import CLIPImgTxtPairTSVDataset

from .transforms.build import build_clip_transforms

def config_tsv_dataset_args(cfg, dataset_file, factory_name=None, is_train=True):
    """
    Build the base keyword arguments for a TSV dataset factory.

    Arguments:
        cfg: config object; only ``cfg.DATASETS.MAX_SEQ_LENGTH`` is read here, and
            the whole object is forwarded to the dataset under the ``'args'`` key.
        dataset_file: not used by this function; kept so existing callers keep working.
        factory_name (str, optional): name of the dataset factory. ``None`` selects
            "CLIPImgTxtPairTSVDataset", the only factory this function handles.
        is_train (bool): not used by this function; kept so existing callers keep working.

    Returns:
        tuple: ``(args, tsv_dataset_name)`` where ``args`` is a new dict holding
        ``'args'`` and ``'seq_len'``, and ``tsv_dataset_name`` is the resolved
        factory name.

    Raises:
        RuntimeError: if the resolved factory name is not supported.
    """
    # Previously the name was left unbound when factory_name was None, which
    # surfaced as an UnboundLocalError; fall back to the only supported factory.
    tsv_dataset_name = "CLIPImgTxtPairTSVDataset" if factory_name is None else factory_name

    if tsv_dataset_name != "CLIPImgTxtPairTSVDataset":
        raise RuntimeError(
                "unsupported tsv dataset factory, got {}".format(tsv_dataset_name))

    # no need for extra arguments
    args = {
        'args': cfg,
        'seq_len': cfg.DATASETS.MAX_SEQ_LENGTH,  # cfg.max_seq_length
    }
    return args, tsv_dataset_name


def _text_tsv_name(image_filename):
    """Return the text tsv filename paired with an image tsv filename, or None if it is not an image tsv."""
    if not image_filename.endswith(".tsv"):
        return None
    # Order matters: "images" must be tried before its substring "image".
    for image_tag, text_tag in (("images", "text"), ("image", "text"), ("img", "caption")):
        if image_tag in image_filename:
            return image_filename.replace(image_tag, text_tag)
    return None


def _list_tsv_pairs(dpath, max_num_tsv):
    """List up to max_num_tsv (image tsv, text tsv) path pairs from a single tsv file or a folder."""
    if os.path.isfile(dpath):
        filenames = [os.path.basename(dpath)]
        data_dir = os.path.dirname(dpath)
    else:
        # sorted so that the selected files do not depend on directory listing order
        filenames = sorted(os.listdir(dpath))
        data_dir = dpath

    image_paths = []
    text_paths = []
    for filename in filenames:
        text_filename = _text_tsv_name(filename)
        if text_filename is None:
            continue
        image_paths.append(os.path.join(data_dir, filename))
        text_paths.append(os.path.join(data_dir, text_filename))
        if len(image_paths) == max_num_tsv:
            break
    return data_dir, image_paths, text_paths


def build_dataset(cfg, transforms, dataset_catalog, is_train=True, is_aux=False):
    """
    Arguments:
        cfg: config file.
        transforms (callable): transforms to apply to each (image, target) sample
        dataset_catalog (DatasetCatalog): contains the information on how to construct a dataset.
            It is not used by this function; kept for interface compatibility.
        is_train (bool): whether to setup the dataset for training or testing
        is_aux (bool): when training, read the AUX dataset entries of the config
            instead of the TRAIN ones. Ignored when is_train is False.

    Returns:
        tuple: (datasets, precomputed_tokens, dataset_classes). For testing,
        datasets is the list of all built datasets. For training, it is a
        one-element list holding the single dataset (or a ConcatDataset of all of
        them), or (None, None, None) is returned when no dataset is configured.

    Raises:
        RuntimeError: if the dataset/factory config entries are not lists or tuples,
            if a dataset has an unsupported factory, no matching path entry, or the
            configured text tokenizer is not a bpe tokenizer.
    """
    ds_cfg = cfg.DATASETS
    if not is_train:
        dataset_list, factory_list, path_list = ds_cfg.TEST, ds_cfg.FACTORY_TEST, ds_cfg.PATH_TEST
    elif is_aux:
        dataset_list, factory_list, path_list = ds_cfg.AUX, ds_cfg.FACTORY_AUX, ds_cfg.PATH_AUX
    else:
        dataset_list, factory_list, path_list = ds_cfg.TRAIN, ds_cfg.FACTORY_TRAIN, ds_cfg.PATH_TRAIN

    if not isinstance(dataset_list, (list, tuple)):
        raise RuntimeError(
                "dataset_list should be a list of strings, got {}".format(dataset_list))
    if not isinstance(factory_list, (list, tuple)):
        raise RuntimeError(
                "factory_list should be a list of strings, got {}".format(factory_list))

    datasets = []
    target_offset = 0
    max_num_tsv = 20  # maximum tsv files to load within a given folder
    for i, dataset_name in enumerate(dataset_list):
        factory_name = factory_list[i] if i < len(factory_list) else None

        # Fail with a clear message instead of the UnboundLocalError that an
        # unknown factory, tokenizer or missing path entry used to trigger.
        if factory_name != "CLIPImgTxtPairTSVDataset":
            raise RuntimeError(
                    "unsupported dataset factory {} for dataset {}".format(factory_name, dataset_name))
        if not isinstance(path_list, (list, tuple)) or i >= len(path_list):
            raise RuntimeError(
                    "path_list should be a list with one path per dataset, got {}".format(path_list))
        if "bpe" not in cfg.INPUT.TEXT_TOKENIZER:
            raise RuntimeError(
                    "unsupported text tokenizer, got {}".format(cfg.INPUT.TEXT_TOKENIZER))

        # several datasets can be merged into one entry with '+'
        dataset_names_merged = dataset_name.split('+')
        path_lists_merged = path_list[i].split('+')

        assert len(dataset_names_merged) == len(path_lists_merged), "number of datasets must match that of dataset paths"

        # The base arguments only depend on cfg and the factory name, so they are
        # built once per entry rather than once per merged sub-dataset.
        args, _ = config_tsv_dataset_args(
            cfg, dataset_name, factory_name, is_train
        )
        factory = CLIPImgTxtPairTSVDataset

        image_tsv_list = []
        text_tsv_list = []
        dataset_name_list = []
        map_files = []

        for dname, dpath in zip(dataset_names_merged, path_lists_merged):
            data_dir, image_paths, text_paths = _list_tsv_pairs(dpath, max_num_tsv)
            image_tsv_list += image_paths
            text_tsv_list += text_paths
            dataset_name_list += [dname] * len(image_paths)

            # only imagenet22k comes with a label map file, located next to its tsv files
            if dname == "imagenet22k":
                map_file = os.path.join(data_dir, 'darknet_data_imagenet.labels.list')
            else:
                map_file = None
            map_files += [map_file] * len(image_paths)

        args["image_tsv_file"] = image_tsv_list
        args["text_tsv_file"] = text_tsv_list
        args["dataset_name"] = dataset_name_list
        args["map_file"] = map_files
        args["filtered_datasets"] = cfg.DATASETS.FILTERED_CLASSIFICATION_DATASETS

        print("number of image tsv files: ", len(image_tsv_list))
        print("number of text tsv fies: ", len(text_tsv_list))

        args["is_train"] = is_train
        args["transforms"] = transforms
        args["target_offset"] = target_offset

        # imported lazily so the module can be imported without detectron2's tokenizer
        from detectron2.data.datasets.clip_prompt_utils import SimpleTokenizer as _Tokenizer
        args["tokenizer_type"] = "bpe"
        args["tokenizer"] = _Tokenizer()

        # make dataset from factory
        datasets.append(factory(**args))

    # gather optional per-dataset metadata exposed by the built datasets
    precomputed_tokens = {}
    dataset_classes = {}
    for dataset in datasets:
        if hasattr(dataset, "input_ids_all_classes"):
            precomputed_tokens["imagenet"] = \
                [dataset.input_ids_all_classes, dataset.input_mask_all_classes, dataset.segment_ids_all_classes]
        if hasattr(dataset, "classnames"):
            if isinstance(dataset.classnames, dict):
                dataset_classes.update(dataset.classnames)
            else:
                dataset_classes[dataset.dataset_name] = dataset.classnames

    # for testing, return a list of datasets
    if not is_train:
        return datasets, precomputed_tokens, dataset_classes

    if not datasets:
        return None, None, None

    # for training, concatenate all datasets into a single one
    dataset = ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
    return [dataset], precomputed_tokens, dataset_classes


def make_clip_dataset(cfg, is_train=True, is_aux=False, transforms=None):
    """
    Build the CLIP datasets described by the config.

    Arguments:
        cfg: config file.
        is_train (bool): whether to setup the datasets for training or testing
        is_aux (bool): forwarded to build_dataset.
        transforms (callable, optional): sample transforms; when None they are
            created with build_clip_transforms(cfg, is_train).

    Returns:
        tuple: (datasets, precomputed_tokens, dataset_classes) as produced by
        build_dataset, or (None, None, None) when no dataset was built.
    """
    # only build the default transforms when the caller did not supply any
    transforms = build_clip_transforms(cfg, is_train) if transforms is None else transforms
    print("data transforms: ")
    print(transforms)

    built = build_dataset(cfg, transforms, DatasetCatalog, is_train, is_aux)
    datasets, precomputed_tokens, dataset_classes = built

    # an empty list or None both mean there is nothing to return
    if datasets:
        return datasets, precomputed_tokens, dataset_classes
    return None, None, None