import numpy as np
import json

from torch.utils.data import Dataset, ConcatDataset, Subset
from functools import partial

from .base_class import DataArguments, LazySupervisedDataset
from TraceVLM.datasets.ln_caption import LocalNarrativeCaptionDataSet
from TraceVLM.datasets.img_ln_caption import ImgLocalNarrativeCaptionDataSet
from TraceVLM.datasets.caption_ln import CaptionLocalNarrativeDataSet
import transformers

# Module-level seed value. Nothing in this file reads it; presumably it is
# imported by other modules -- TODO confirm before changing or removing.
seed = 48
def process_dataset(txt_file):
    """Parse a multi-dataset manifest into ``{task: (data_path, image_folder)}``.

    Every non-blank line of ``txt_file`` must have the form
    ``task&data_path&image_folder``. Whitespace around each field is
    stripped. Blank lines are skipped. When a task name occurs more than
    once, the last occurrence wins.

    Args:
        txt_file: Path of the manifest file to read.

    Returns:
        A dict mapping each task name to a ``(data_path, image_folder)``
        tuple, in the order the tasks first appear in the file.

    Raises:
        ValueError: If a non-blank line does not contain exactly three
            ``&``-separated fields.
        OSError: If the file cannot be opened.
    """
    data_dict = {}
    # The context manager guarantees the handle is closed even when a
    # malformed line aborts parsing.
    with open(txt_file, "r") as manifest:
        for line_no, raw_line in enumerate(manifest, start=1):
            item = raw_line.strip()
            if not item:
                # Tolerate empty lines (e.g. a trailing newline at end of file).
                continue
            fields = item.split("&")
            if len(fields) != 3:
                raise ValueError(
                    f"{txt_file}:{line_no}: expected 'task&data_path&image_folder', "
                    f"got {item!r}"
                )
            dtype, data_path, image_folder = (field.strip() for field in fields)
            data_dict[dtype] = (data_path, image_folder)
    return data_dict


class MultiConcatDataset(Dataset):
    """Concatenation of several task-specific datasets listed in one manifest.

    No builder/registry is used: every task-specific dataset class is
    created with the same set of initialization arguments, and a class
    simply ignores the ones it does not need. Instead of the single
    ``data_args.data_path`` string, the paths come from a dict keyed by the
    dataset's task type (see ``dataset_dict``).

    The task type of a manifest entry is the part of its name before the
    first underscore and must be one of the keys of ``map_func``
    (``LAZY``, ``LNC``, ``ImgLNC``, ``CaptionLNC``).
    """

    _repr_indent = 4

    def __init__(self, multi_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_args: DataArguments,
                 debug=False
                 ):
        """Build every dataset named in the manifest and concatenate them.

        Args:
            multi_path: Manifest file parsed by ``process_dataset``; each
                entry yields ``task -> (data_path, image_folder)``.
            tokenizer: Forwarded unchanged to every sub-dataset.
            data_args: Forwarded unchanged to every sub-dataset. Its
                ``data_path`` and ``image_folder`` must both be ``None``.
            debug: Forwarded to the LNC-family datasets.

        Raises:
            AssertionError: If ``data_args`` still carries a single data
                path or image folder.
            KeyError: If a manifest entry uses an unknown task type.
        """
        assert data_args.data_path is None and data_args.image_folder is None, "If use multi paths, please do not input single path and folder."
        dataset_dict = process_dataset(multi_path)
        self.dataset_dict = dataset_dict

        # NOTE(review): these template paths are relative, so they presumably
        # resolve against the current working directory -- confirm.
        lnc_template = "TraceVLM/datasets/templates/LNC_Bin.json"
        img_lnc_template = "TraceVLM/datasets/templates/LNC_IMG_Bin.json"
        caption_lnc_template = "TraceVLM/datasets/templates/LNC_CAPTION_Bin.json"

        coco_instance_json = "/storage-root/datasets/yangfan/coco2017/annotations/instances_train2017.json"

        # Keyword arguments shared by the three local-narrative dataset classes.
        narrative_kwargs = dict(
            tokenizer=tokenizer,
            data_args=data_args,
            split="<SPL>",
            debug=debug,
            coco_instance_json=coco_instance_json,
        )
        self.map_func = {
            "LAZY": partial(LazySupervisedDataset, tokenizer=tokenizer, data_args=data_args),
            "LNC": partial(LocalNarrativeCaptionDataSet, template_file=lnc_template, **narrative_kwargs),
            "ImgLNC": partial(ImgLocalNarrativeCaptionDataSet, template_file=img_lnc_template, **narrative_kwargs),
            "CaptionLNC": partial(CaptionLocalNarrativeDataSet, template_file=caption_lnc_template, **narrative_kwargs),
        }

        # Fail fast, before any data is loaded, when the manifest names a
        # task type that has no registered dataset class.
        for tp in dataset_dict:
            task_prefix = tp.split("_")[0]
            if task_prefix not in self.map_func:
                raise KeyError(
                    f"Unknown task type {task_prefix!r} in manifest entry {tp!r}; "
                    f"expected one of {sorted(self.map_func)}"
                )

        datasets = []
        lnc_datasets_data = {}  # Stores list_data_dict for LNC tasks

        # First pass: build the LNC datasets so that their loaded records can
        # be shared with the derived tasks below.
        # process_dataset stores each value as (data_path, image_folder); the
        # tuple is unpacked in that same order so each keyword argument
        # receives the matching path.
        for tp, (data_path, image_folder) in dataset_dict.items():
            if tp.split("_")[0] != "LNC":
                continue
            dataset = self.map_func["LNC"](data_path=data_path, image_folder=image_folder)
            datasets.append(dataset)
            lnc_datasets_data[tp] = dataset.list_data_dict

        # Second pass: build every other dataset, reusing LNC records when an
        # LNC entry with the same name suffix exists.
        for tp, (data_path, image_folder) in dataset_dict.items():
            task_prefix, _, suffix = tp.partition("_")
            if task_prefix == "LNC":
                continue  # Already built in the first pass.

            factory = self.map_func[task_prefix]
            base_lnc_task = "LNC_" + suffix
            if task_prefix in ("ImgLNC", "CaptionLNC") and base_lnc_task in lnc_datasets_data:
                # Hand over the already-loaded records instead of a data path.
                dataset = factory(
                    data_path=None,
                    image_folder=image_folder,
                    list_data_dict=lnc_datasets_data[base_lnc_task],
                )
            else:
                dataset = factory(data_path=data_path, image_folder=image_folder)
            datasets.append(dataset)

        # Flatten the per-sample metadata of all sub-datasets, in the same
        # order in which ConcatDataset indexes their samples.
        self.modality_lengths = []
        self.modality_types = []
        self.lengths = []
        for dataset in datasets:
            self.modality_types.extend(dataset.modality_types)
            self.modality_lengths.extend(dataset.modality_lengths)
            self.lengths.extend(dataset.lengths)
        self.concat_dataset = ConcatDataset(datasets)

    def __len__(self):
        """Return the total number of samples across all sub-datasets."""
        return len(self.concat_dataset)

    def __getitem__(self, index):
        """Return the sample at ``index`` of the concatenated dataset."""
        return self.concat_dataset[index]

    def __repr__(self) -> str:
        """Return a multi-line summary: total size, then each sub-dataset's repr."""
        head = "Dataset " + self.__class__.__name__
        body = [
            f"Number of datapoints: {self.__len__()}",
        ]
        subsets = self.concat_dataset.datasets
        for i, ds in enumerate(subsets):
            body.append(f"Subset {i + 1}/{len(subsets)}")
            body += ds.__repr__().splitlines()
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return "\n".join(lines)
