"""
Lightning data module for CRBench: builds train/val/test datasets from task seeds.
"""

import os 
from typing import Optional 
from functools import partial 

import lightning as L 
from torch .utils .data import DataLoader 

from datasets import Dataset ,DatasetDict 
from ..utils import instantiate_from_config ,instantiate_from_config_hf_pretrained 


def get_instance(example, task):
    """Populate ``example`` with the task instance generated from its seed.

    Intended as a ``Dataset.map`` callback: ``task.get_instance`` is called
    with ``example["seed"]``, and the resulting object's ``input`` and
    ``verifier_hint`` attributes are copied into ``example`` under the same
    keys.

    Args:
        example: Mutable mapping containing at least a ``"seed"`` entry; it
            is updated in place.
        task: Object exposing ``get_instance(seed)``.

    Returns:
        The same ``example`` object, now with ``"input"`` and
        ``"verifier_hint"`` set.
    """
    generated = task.get_instance(example["seed"])
    for attr_name in ("input", "verifier_hint"):
        example[attr_name] = getattr(generated, attr_name)
    return example


class CRDataModule(L.LightningDataModule):
    """Lightning data module serving CRBench task instances.

    Each example is produced by ``get_instance`` from an integer seed. Seeds
    ``0 .. num_train + num_val + num_test - 1`` are generated together and
    split contiguously into ``train``, ``val`` and ``test`` (in that order),
    so no seed appears in more than one split.

    Args:
        crbench_task_config: Config for ``instantiate_from_config``; the
            resulting task must provide ``get_instance(seed)`` returning an
            object with ``input`` and ``verifier_hint`` attributes.
        tokenizer_config: Config for ``instantiate_from_config_hf_pretrained``.
        transform_config: Optional config for ``instantiate_from_config``,
            instantiated with ``tokenizer=self.tokenizer``. The resulting
            callable is attached to every split via ``Dataset.with_transform``
            and additionally receives the keyword arguments ``task`` and
            ``do_eval`` (``True`` only for the test split).
        num_train: Number of training examples.
        num_val: Number of validation examples.
        num_test: Number of test examples.
        batch_size: Batch size of the training dataloader.
        eval_batch_size: Batch size of the val/test dataloaders; defaults to
            ``batch_size``.
        num_workers: ``num_workers`` passed to every ``DataLoader``.
    """

    def __init__(
        self,
        crbench_task_config: dict,
        tokenizer_config: dict,
        transform_config: Optional[dict] = None,
        num_train: int = 200,
        num_val: int = 1,
        num_test: int = 1_000,
        batch_size: int = 1,
        eval_batch_size: Optional[int] = None,
        num_workers: Optional[int] = 0,
    ):
        super().__init__()
        self.task = instantiate_from_config(crbench_task_config)
        self.tokenizer = instantiate_from_config_hf_pretrained(tokenizer_config)
        self.transform = (
            instantiate_from_config(transform_config, tokenizer=self.tokenizer)
            if transform_config
            else None
        )
        self.num_train = num_train
        self.num_val = num_val
        self.num_test = num_test
        self.batch_size = batch_size
        self.eval_batch_size = (
            eval_batch_size if eval_batch_size is not None else batch_size
        )
        self.num_workers = num_workers
        self.datasets = DatasetDict()
        # Flipped by ``setup`` once the splits exist, so later stages reuse them.
        self._is_setup = False

    def setup(self, stage: Optional[str] = None):
        """Generate all splits on the first call and attach the transform.

        Lightning calls ``setup`` once per stage (fit, validate, test, ...).
        Because every split is built in one pass, subsequent calls return
        immediately instead of regenerating every instance.

        Args:
            stage: Lightning stage name. Unused: all splits are built together.
        """
        if self._is_setup:
            return

        total = self.num_train + self.num_val + self.num_test

        # os.cpu_count() may return None, in which case ``map`` runs in-process.
        # Never request more worker processes than there are rows.
        num_proc = os.cpu_count()
        if num_proc is not None:
            num_proc = max(1, min(num_proc, total))

        dataset = Dataset.from_dict({"seed": list(range(total))})
        dataset = dataset.map(
            get_instance, num_proc=num_proc, fn_kwargs={"task": self.task}
        )

        # Contiguous, non-overlapping seed ranges: [train | val | test].
        train_end = self.num_train
        val_end = train_end + self.num_val
        self.datasets["train"] = dataset.select(range(0, train_end))
        self.datasets["val"] = dataset.select(range(train_end, val_end))
        self.datasets["test"] = dataset.select(range(val_end, total))

        if self.transform is not None:
            for split in list(self.datasets):
                # Only the test split is flagged for evaluation-mode transforms.
                self.datasets[split] = self.datasets[split].with_transform(
                    partial(self.transform, task=self.task, do_eval=split == "test")
                )

        self._is_setup = True

    def _make_dataloader(self, split: str, batch_size: int, shuffle: bool):
        """Wrap ``self.datasets[split]`` in a ``DataLoader`` with shared settings."""
        return DataLoader(
            self.datasets[split],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the train split using ``batch_size``."""
        return self._make_dataloader("train", self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an unshuffled loader over the val split using ``eval_batch_size``."""
        return self._make_dataloader("val", self.eval_batch_size, shuffle=False)

    def test_dataloader(self):
        """Return an unshuffled loader over the test split using ``eval_batch_size``."""
        return self._make_dataloader("test", self.eval_batch_size, shuffle=False)
