from typing import Optional ,Union 

from torch .utils .data import DataLoader 
import lightning as L 
from datasets import Dataset ,DatasetDict 
from ..utils import instantiate_from_config ,instantiate_from_config_hf_pretrained 


class BaseDataModule(L.LightningDataModule):
    """Lightning data module built around a config-instantiated dataset.

    The dataset is created by passing ``data_config`` to
    ``instantiate_from_config``. When no ``data_config`` is supplied, a
    config targeting ``datasets.load_dataset`` is assembled from
    ``data_path``, ``data_name`` and ``data_kwargs``.

    Subclasses can override :meth:`preprocess` to adjust the raw dataset
    before the transform is attached and the splits are selected.
    """

    def __init__(
        self,
        data_config: Optional[dict] = None,
        data_path: Optional[str] = None,
        data_name: Optional[str] = None,
        data_kwargs: Optional[dict] = None,
        tokenizer_config: Optional[dict] = None,
        transform_config: Optional[dict] = None,
        batch_size: int = 64,
        eval_batch_size: Optional[int] = None,
        num_workers: int = 0,
        val_size: Optional[float] = None,
        transform_all: bool = False,
    ):
        """Store the loader settings and build the tokenizer and transform.

        Args:
            data_config: Config passed to ``instantiate_from_config`` to
                create the dataset. When ``None``, a ``datasets.load_dataset``
                config is built from the three ``data_*`` arguments below.
            data_path: ``path`` entry of the default dataset config.
            data_name: ``name`` entry of the default dataset config.
            data_kwargs: Extra entries merged into the default config's
                ``params``; they override ``path``/``name`` on key collision.
            tokenizer_config: Config passed to
                ``instantiate_from_config_hf_pretrained``; no tokenizer is
                created when ``None``.
            transform_config: Config for the transform; defaults to
                ``ltm.data.transforms.BaseTransform``. The tokenizer is
                forwarded to it as the ``tokenizer`` keyword.
            batch_size: Batch size of the training dataloader.
            eval_batch_size: Batch size of the validation/test dataloaders;
                falls back to ``batch_size`` when ``None``.
            num_workers: Worker count for the dataloaders, also used as the
                process count when the transform is applied eagerly.
            val_size: When set, passed as ``test_size`` to
                ``train_test_split`` to carve a validation split out of the
                training data.
            transform_all: When true, apply the transform eagerly with
                ``map`` in :meth:`setup`; otherwise attach it with
                ``with_transform``.
        """
        super().__init__()
        if data_config is None:
            if data_kwargs is None:
                data_kwargs = {}
            data_config = {
                "target": "datasets.load_dataset",
                "params": {"path": data_path, "name": data_name, **data_kwargs},
            }
        self.data_config = data_config
        self.batch_size = batch_size
        self.eval_batch_size = (
            eval_batch_size if eval_batch_size is not None else batch_size
        )
        self.num_workers = num_workers
        self.val_size = val_size
        self.transform_all = transform_all
        self.tokenizer = (
            instantiate_from_config_hf_pretrained(tokenizer_config)
            if tokenizer_config is not None
            else None
        )
        if transform_config is None:
            transform_config = {"target": "ltm.data.transforms.BaseTransform"}
        self.transform = instantiate_from_config(
            transform_config, tokenizer=self.tokenizer
        )
        # Filled by `setup` with the "train" / "val" / "test" splits.
        self.datasets = {}

    def prepare_data(self):
        """Instantiate the dataset once and discard the result.

        NOTE(review): presumably this exists to trigger the download/caching
        side effect of the dataset loader before `setup` runs - confirm.
        """
        instantiate_from_config(self.data_config)

    def preprocess(
        self, dataset: Union[Dataset, DatasetDict]
    ) -> Union[Dataset, DatasetDict]:
        """Hook for subclasses; the base implementation returns ``dataset`` as is."""
        return dataset

    def setup(self, stage=None):
        """Load the dataset, attach the transform and populate ``self.datasets``.

        Args:
            stage: Accepted for the Lightning hook signature; not used here.
        """
        dataset = instantiate_from_config(self.data_config)
        dataset = self.preprocess(dataset)
        if self.transform is not None and self.transform_all:
            dataset = dataset.map(
                self.transform,
                # `datasets` rejects num_proc <= 0, so the default of
                # num_workers=0 must be translated to None (run in-process).
                num_proc=self.num_workers or None,
                desc="Applying transform",
                fn_kwargs={"batched": False},
            )
        elif self.transform is not None:
            dataset = dataset.with_transform(self.transform)

        # A bare `Dataset` is treated as the training data itself; indexing it
        # with "train" would be a column lookup rather than a split lookup.
        has_splits = isinstance(dataset, DatasetDict)
        train_split = dataset["train"] if has_splits else dataset

        if self.val_size is not None:
            # Fixed seed keeps the train/validation partition reproducible.
            parts = train_split.train_test_split(test_size=self.val_size, seed=42)
            self.datasets["train"] = parts["train"]
            self.datasets["val"] = parts["test"]
        else:
            self.datasets["train"] = train_split
            if has_splits and "validation" in dataset:
                self.datasets["val"] = dataset["validation"]

        # Only a DatasetDict supports split membership tests; on a bare
        # `Dataset` the `in` operator would iterate over its rows instead.
        if has_splits and "test" in dataset:
            self.datasets["test"] = dataset["test"]

    def train_dataloader(self):
        """Return a shuffling loader over the "train" split."""
        return DataLoader(
            self.datasets["train"],
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Return a non-shuffling loader over the "val" split.

        Raises:
            KeyError: If `setup` did not register a "val" split.
        """
        return DataLoader(
            self.datasets["val"],
            batch_size=self.eval_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

    def test_dataloader(self):
        """Return a non-shuffling loader over the "test" split.

        Raises:
            KeyError: If `setup` did not register a "test" split.
        """
        return DataLoader(
            self.datasets["test"],
            batch_size=self.eval_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )
