from copy import deepcopy
import itertools
import torch
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
import json
import random
from torch.utils.data import Dataset, DataLoader,Subset
import pytorch_lightning as pl
from transformers import AutoTokenizer, T5Tokenizer,RobertaTokenizer,RobertaTokenizerFast,GPT2Tokenizer, PreTrainedTokenizerFast
import numpy as np
import random
from typing import *
from common import *
from tokenizers import AddedToken
import ast

class BooleanDataset(Dataset):  # type: ignore
    """Boolean-expression simplification examples loaded from a JSON-lines file.

    Each non-blank line of ``path`` is one JSON object. The fields read here
    are ``uid``, ``tree``, ``label`` and ``stepwise_proof``, plus the
    optional ``depth``, ``target_proof`` and ``partial_proof``. Example::

        {"uid": 0, "tree": "~(~0+~0)", "label": "0",
         "stepwise_proof": ["~(1+~0)", "~(1+1)", "~1", "0"]}

    Args:
        composition: When True, ``__getitem__`` also exposes the record's
            ``depth`` (required for ``classification``, included only when
            present for the other tasks).
        path: JSON-lines file to load.
        model_name: Model name/path; the substrings ``t5``, ``roberta`` and
            ``gpt`` select the tokenizer family.
        model_type: ``'seq2seq'`` turns classification labels into text
            targets; any other value makes them int64 class ids.
        max_input_len: Maximum token length of the tokenized model inputs.
        max_output_len: Maximum token length of the tokenized targets.
        task: ``'classification'``, ``'one-shot'`` or ``'stepwise'``.
        method: ``'pretrained'`` or ``'un-pretrained'``; selects the
            tokenizer variant and the input prompt format.
        tokenizer_type: Substring-matched against ``'pretrained'``,
            ``'custom'`` and ``'character'`` to pick a tokenizer.

    Raises:
        ValueError: If no tokenizer is configured for the given
            ``model_name`` / ``method`` / ``tokenizer_type``.
    """

    def __init__(
        self,
        composition: bool,
        path: str,
        model_name: str,
        model_type: str,
        max_input_len: int,
        max_output_len: int,
        task: str,
        method: str,
        tokenizer_type: str,
    ):
        super().__init__()
        self.tokenizer = self._load_tokenizer(model_name, method, tokenizer_type)
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.task = task
        self.method = method
        self.model_type = model_type
        self.model_name = model_name
        self.composition = composition
        # preprocess() reads self.task, so load the data only after configuring.
        self.data = self.preprocess(path)

    @staticmethod
    def _load_tokenizer(model_name: str, method: str, tokenizer_type: str) -> Any:
        """Build the tokenizer selected by the model family, method and tokenizer type.

        Raises:
            ValueError: If no branch below produces a tokenizer.
        """
        tokenizer = None
        if 't5' in model_name:
            if method == 'un-pretrained':
                if 'pretrained' in tokenizer_type:
                    tokenizer = T5Tokenizer.from_pretrained(model_name)
                    tokenizer.add_tokens(AddedToken("~", single_word=True))
                elif 'custom' in tokenizer_type:
                    tokenizer = T5Tokenizer.from_pretrained(
                        "/userhomes/Boolean/src/vocab_depth2_sentencepiece.model")
                elif 'character' in tokenizer_type:
                    # NOTE(review): this character vocabulary lives under /app/input,
                    # the pretrained-method one below under /userhomes — confirm both.
                    tokenizer = T5Tokenizer.from_pretrained(
                        "/app/input/dataset/boolean2/vocab_character_sentencepiece.model")
            elif 'character' in tokenizer_type:
                tokenizer = T5Tokenizer.from_pretrained(
                    "/userhomes/Boolean/src/vocab_character_sentencepiece.model")
            else:
                tokenizer = T5Tokenizer.from_pretrained(model_name, use_fast=False)
                # "~" is registered explicitly; presumably the stock vocabulary lacks it.
                tokenizer.add_tokens(AddedToken("~", single_word=True))
        elif 'roberta' in model_name:
            if 'custom' in tokenizer_type:
                tokenizer = PreTrainedTokenizerFast(
                    tokenizer_file='/userhomes/Boolean/src/basic_BPEtokenizer.json')
            elif method == 'un-pretrained' and 'pretrained' not in tokenizer_type:
                tokenizer = RobertaTokenizer(
                    vocab_file="/app/input/dataset/boolean2/roberta_vocab2.json",
                    merges_file="/app/input/dataset/boolean2/roberta_merges2.txt")
            else:
                tokenizer = RobertaTokenizer.from_pretrained(model_name)
        elif 'gpt' in model_name:
            if 'custom' in tokenizer_type:
                # NOTE(review): no pad_token is assigned on this path, yet collate()
                # pads batches — confirm the tokenizer file defines padding.
                tokenizer = PreTrainedTokenizerFast(
                    tokenizer_file='/userhomes/Boolean/src/basic_BPEtokenizer.json')
            else:
                if method == 'un-pretrained':
                    if 'pretrained' in tokenizer_type:
                        tokenizer = AutoTokenizer.from_pretrained(model_name)
                    else:
                        tokenizer = GPT2Tokenizer(
                            vocab_file="/app/input/dataset/boolean2/gpt2_vocab.json",
                            merges_file="/app/input/dataset/boolean2/gpt2_merges.txt")
                    print(tokenizer.eos_token)
                else:
                    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
                # The EOS token doubles as the padding token for these tokenizers.
                tokenizer.pad_token = tokenizer.eos_token

        if tokenizer is None:
            raise ValueError(
                f"No tokenizer configured for model_name={model_name!r}, "
                f"method={method!r}, tokenizer_type={tokenizer_type!r}")
        return tokenizer

    def __len__(self) -> int:
        """Number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Example:
        """Return record ``idx`` as the dict consumed by ``collate``.

        Every task yields ``id``, ``tree`` and ``label``. ``classification``
        adds the raw ``stepwise_proof`` list; ``one-shot`` adds
        ``entire_proof`` (the proof steps joined with ``';'``); ``stepwise``
        adds ``entire_proof`` plus one ``target_proof`` step and the
        ``';'``-joined ``partial_proof`` preceding it. With ``composition``
        enabled, ``depth`` is added as well: mandatory for classification,
        only when present for the other tasks.

        Raises:
            ValueError: If ``self.task`` is not a supported task.
        """
        ex = self.data[idx]
        item = {"id": ex["uid"], "tree": ex["tree"], "label": ex["label"]}

        if self.task == 'classification':
            item["stepwise_proof"] = ex["stepwise_proof"]
            if self.composition:
                # Composition-mode classification records must carry a depth.
                item["depth"] = ex["depth"]
            return item

        if self.task == 'one-shot':
            entire_proof = ";".join(ex["stepwise_proof"])
            if len(entire_proof) == 1:
                # Flag records whose whole joined proof is a single character.
                print("tree", ex["tree"])
                print("tree", entire_proof)
            item["entire_proof"] = entire_proof
        elif self.task == 'stepwise':
            target_proof, partial_proof = self._pick_proof_step(ex)
            item["target_proof"] = target_proof
            item["partial_proof"] = partial_proof
            item["entire_proof"] = ";".join(ex["stepwise_proof"])
        else:
            raise ValueError(f"Unsupported task: {self.task!r}")

        if self.composition and "depth" in ex:
            item["depth"] = ex["depth"]
        return item

    @staticmethod
    def _pick_proof_step(ex: Dict[str, Any]) -> Tuple[Any, str]:
        """Return ``(target step, ';'-joined preceding steps)`` for a stepwise record.

        Uses the record's precomputed ``target_proof`` / ``partial_proof`` when
        both are present; otherwise samples a step uniformly at random from
        ``stepwise_proof`` (``random.randint`` raises ``ValueError`` if that
        list is empty).
        """
        if "target_proof" in ex and "partial_proof" in ex:
            return ex["target_proof"], ";".join(ex["partial_proof"])
        steps = ex["stepwise_proof"]
        index = random.randint(0, len(steps) - 1)
        return steps[index], ";".join(steps[:index])

    def preprocess(self, path: str) -> List[Example]:
        """Load one JSON object per non-blank line of ``path``.

        Records are shuffled in place unless the task is ``'stepwise'``, whose
        file order is kept.
        """
        with open(path, encoding="utf-8") as f:
            data = [json.loads(line) for line in f if line.strip()]
        if self.task != 'stepwise':
            random.shuffle(data)
        return data

    def collate(self, examples: List[Example]) -> Batch:
        """Tokenize a list of ``__getitem__`` outputs into one model batch.

        ``method`` selects the prompt format and ``task`` selects which fields
        become inputs and targets; see the ``_collate_*`` helpers.

        Raises:
            ValueError: For an unsupported ``method`` or ``task``, or a
                generation task whose ``model_name`` contains neither
                ``t5`` nor ``gpt2``.
        """
        if self.method not in ('pretrained', 'un-pretrained'):
            raise ValueError(f"Unsupported method: {self.method!r}")
        if self.task == 'classification':
            return self._collate_classification(examples)
        if self.task == 'one-shot':
            return self._collate_one_shot(examples)
        if self.task == 'stepwise':
            return self._collate_stepwise(examples)
        raise ValueError(f"Unsupported task: {self.task!r}")

    def _generation_family(self) -> str:
        """Return ``'t5'`` or ``'gpt2'`` for generation tasks (``t5`` is checked first)."""
        if 't5' in self.model_name:
            return 't5'
        if 'gpt2' in self.model_name:
            return 'gpt2'
        raise ValueError(
            f"Task {self.task!r} requires a T5 or GPT-2 model, got {self.model_name!r}")

    def _encode_inputs(self, texts: List[str]) -> Any:
        """Tokenize model inputs, padded to the longest and truncated to ``max_input_len``."""
        return self.tokenizer(
            texts,
            padding="longest",
            truncation="longest_first",
            max_length=self.max_input_len,
            return_tensors="pt",
        )

    def _text_target_batch(self, input_text: List[str], encoded: Any,
                           label: List[str]) -> Dict[str, Any]:
        """Tokenize text targets and assemble the batch keys shared by all text-target paths."""
        output_seq = self.tokenizer(label, padding="longest", max_length=self.max_output_len,
                                    truncation=True, return_tensors='pt')
        # Padding ids become -100, presumably the loss's ignore_index
        # (the default of torch's cross-entropy loss).
        output_seq.input_ids[output_seq.input_ids == self.tokenizer.pad_token_id] = -100
        return {
            "input_text": input_text,
            "input_seq_ids": encoded["input_ids"],
            "input_seq_mask": encoded["attention_mask"],
            "output_seq": label,
            "output_seq_ids": output_seq.input_ids,
            "output_seq_mask": output_seq.attention_mask,
        }

    def _collate_classification(self, examples: List[Example]) -> Dict[str, Any]:
        """Batch for ``classification``: the expression is the input.

        The label is a tokenized text target when ``model_type == 'seq2seq'``,
        otherwise an int64 tensor of class ids.
        """
        marker = "$" if self.method == 'pretrained' else ""
        input_text = [f"{marker}{ex['tree']}" for ex in examples]
        encoded = self._encode_inputs(input_text)
        if self.model_type == 'seq2seq':
            label = [f"{ex['label']}" for ex in examples]
            return self._text_target_batch(input_text, encoded, label)
        return {
            "input_text": input_text,
            "input_seq_ids": encoded["input_ids"],
            "input_seq_mask": encoded["attention_mask"],
            "label": torch.tensor([int(ex["label"]) for ex in examples], dtype=torch.int64),
        }

    def _collate_one_shot(self, examples: List[Example]) -> Dict[str, Any]:
        """Batch for ``one-shot``: expression in, whole ``';'``-joined proof out.

        For gpt2 the target is appended to the prompt before tokenizing, while
        the returned ``input_text`` is the prompt alone. In the un-pretrained
        format ``depth`` is forwarded when every example carries it.
        """
        family = self._generation_family()
        pretrained = self.method == 'pretrained'
        marker = "$" if pretrained else ""
        label = [f"{ex['entire_proof']}" for ex in examples]
        if family == 't5':
            input_text = [f"{marker}{ex['tree']}" for ex in examples]
            model_text = input_text
        else:
            input_text = [f"{marker}{ex['tree']}{marker}" for ex in examples]
            model_text = [prompt + target for prompt, target in zip(input_text, label)]
        batch = self._text_target_batch(input_text, self._encode_inputs(model_text), label)
        if not pretrained and all("depth" in ex for ex in examples):
            batch["depth"] = [ex["depth"] for ex in examples]
        return batch

    def _collate_stepwise(self, examples: List[Example]) -> Dict[str, Any]:
        """Batch for ``stepwise``: expression + partial proof in, next step out.

        Also returns ``raw_input_text`` (the prompt without the partial proof)
        and ``entire_proof``. For gpt2 the target is appended to the prompt,
        and the returned ``input_text`` includes it.
        """
        family = self._generation_family()
        if self.method == 'pretrained':
            raw_input_text = [
                f"$Boolean expression$ = {ex['tree']}, $partial proof$ = "
                for ex in examples]
            input_text = [
                f"$Boolean expression$ = {ex['tree']}, $partial proof$ = {ex['partial_proof']}"
                for ex in examples]
        else:
            raw_input_text = [f"${ex['tree']}$" for ex in examples]
            input_text = [f"${ex['tree']}$, ${ex['partial_proof']}$" for ex in examples]
        label = [f"{ex['target_proof']}" for ex in examples]
        if family == 'gpt2':
            input_text = [prompt + target for prompt, target in zip(input_text, label)]
        batch = self._text_target_batch(input_text, self._encode_inputs(input_text), label)
        batch["raw_input_text"] = raw_input_text
        batch["entire_proof"] = [f"{ex['entire_proof']}" for ex in examples]
        return batch

class BooleandataModule(pl.LightningDataModule):
    """Lightning data module holding train / validation / test ``BooleanDataset`` splits.

    All three splits are built eagerly in ``__init__`` from the same model,
    tokenizer and task settings; only the JSON-lines path differs. Each
    loader uses its split's own ``collate`` method as ``collate_fn``.
    """

    def __init__(
        self,
        model_name: str,
        composition: bool,
        model_type: str,
        max_input_len: int,
        max_output_len: int,
        batch_size: int,
        num_workers: int,
        task: str,
        path_train: str,
        path_val: str,
        path_test: str,
        method: str,
        tokenizer_type: str,
    ) -> None:
        super().__init__()
        # Settings shared by every split.
        self.model_name = model_name
        self.composition = composition
        self.model_type = model_type
        self.max_input_len = max_input_len
        self.max_output_len = max_output_len
        self.task = task
        self.method = method
        self.tokenizer_type = tokenizer_type
        # Loader settings.
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Per-split data files.
        self.path_train, self.path_val, self.path_test = path_train, path_val, path_test

        self.ds_train = self._build_split(self.path_train)
        self.ds_val = self._build_split(self.path_val)
        self.ds_test = self._build_split(self.path_test)

    def _build_split(self, path: str) -> "BooleanDataset":
        """Construct one ``BooleanDataset`` for ``path`` with the shared settings."""
        return BooleanDataset(
            self.composition,
            path,
            self.model_name,
            self.model_type,
            self.max_input_len,
            self.max_output_len,
            self.task,
            self.method,
            self.tokenizer_type,
        )

    def _build_loader(self, dataset: "BooleanDataset", *, shuffle: bool,
                      drop_last: bool) -> DataLoader:
        """Wrap ``dataset`` in a pinned-memory DataLoader using the module's batch/worker settings."""
        return DataLoader(
            dataset,
            self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=dataset.collate,
            pin_memory=True,
            drop_last=drop_last,
        )

    def train_dataloader(self) -> DataLoader:  # type: ignore
        """Training loader: drops the last partial batch; shuffles unless the task is 'stepwise'."""
        keep_order = self.task == 'stepwise'
        return self._build_loader(self.ds_train, shuffle=not keep_order, drop_last=True)

    def val_dataloader(self) -> DataLoader:  # type: ignore
        """Validation loader: fixed order, keeps the last partial batch."""
        return self._build_loader(self.ds_val, shuffle=False, drop_last=False)

    def test_dataloader(self) -> DataLoader:  # type: ignore
        """Test loader: fixed order, keeps the last partial batch."""
        return self._build_loader(self.ds_test, shuffle=False, drop_last=False)


