import os
from dataclasses import dataclass, field
from typing import Optional, List
from transformers import TrainingArguments
from dataclasses import dataclass

@dataclass
class ModelArguments:
    """Options selecting which model(s) to load and how to configure them.

    Only ``model_name_or_path`` is required; every other field has a default.
    The ``metadata["help"]`` strings are the per-field descriptions.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    residual_encoder_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    residual_decoder_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    residual_num_layer: int = field(default=4)
    residual_num_head: int = field(default=8)

    target_model_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to pretrained reranker target model"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )

    # modeling
    untie_encoder: bool = field(
        default=False,
        metadata={"help": "no weight sharing between qry passage encoders"}
    )

    # parameter efficient methods (declared exactly once)
    param_efficient: Optional[str] = field(
        default=None,
        metadata={"help": "Param efficient method used in model training"}
    )

    # out projection
    add_pooler: bool = field(default=False)
    projection_in_dim: int = field(default=768)
    projection_out_dim: int = field(default=768)

    use_t5_decoder: bool = field(
        default=False,
        metadata={"help": "Use t5 decoder"}
    )

    fix_gpt: bool = field(
        default=False,
        metadata={"help": "fix backbone"}
    )
    no_use_gpt: bool = field(
        default=False,
        metadata={"help": "no use gpt --> vanilla cross-attn"}
    )
    top_gpt: bool = field(
        default=False,
        metadata={"help": "top gpt --> top vanilla cross-attn on gpt"}
    )

    # cosine similarity; None leaves the option unset
    cosine_scale: Optional[float] = field(
        default=None,
        metadata={"help": "use cosine similarity func: temperature"}
    )

    # from which gpt layer; None leaves the option unset
    bottom_layer_num: Optional[int] = field(default=None)
    top_layer_num: Optional[int] = field(default=None)
        

@dataclass
class DataArguments:
    """Options describing where data comes from and how it is tokenized/encoded.

    ``__post_init__`` derives a few extra attributes:

    * ``dataset_name`` / ``dataset_language``: a three-part ``a/b/c`` name
      drops its last component; a ``name:lang`` suffix is split off into
      ``dataset_language`` (otherwise ``'default'``). With no dataset name,
      ``dataset_name`` becomes ``'json'``.
    * ``train_path`` / ``eval_path``: sorted paths of the files ending in
      ``json``/``jsonl`` inside ``train_dir`` / ``eval_dir``, or ``None``.
    * ``train_cache_dir`` / ``eval_cache_dir``: overwritten with
      ``<dir>/cache`` whenever the corresponding directory is given.
    """

    train_dir: Optional[str] = field(
        default=None, metadata={"help": "Path to train directory"}
    )
    eval_dir: Optional[str] = field(
        default=None, metadata={"help": "Path to eval directory"}
    )
    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "huggingface dataset name"}
    )
    passage_field_separator: str = field(default=' ')
    dataset_proc_num: int = field(
        default=12, metadata={"help": "number of proc used in dataset preprocess"}
    )
    train_n_passages: int = field(default=8)
    positive_passage_no_shuffle: bool = field(
        default=False, metadata={"help": "always use the first positive passage"})
    negative_passage_no_shuffle: bool = field(
        default=False, metadata={"help": "always use the first negative passages"})

    encode_in_path: Optional[List[str]] = field(default=None, metadata={"help": "Path to data to encode"})
    encoded_save_path: Optional[str] = field(default=None, metadata={"help": "where to save the encode"})
    encode_is_qry: bool = field(default=False)
    encode_num_shard: int = field(default=1)
    encode_shard_index: int = field(default=0)

    q_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    p_max_len: int = field(
        default=128,
        metadata={
            "help": "The maximum total input sequence length after tokenization for passage. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    train_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the train data downloaded from huggingface, if None, repeated download"}
    )

    eval_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the eval data downloaded from huggingface, if None, repeated download"}
    )

    # split load
    split_load_data: bool = field(default=False)

    # Few-shot training
    split_dataset_stg: Optional[str] = field(default=None, metadata={"help": "Base/Novel class qid splitation"})

    # beir dataset
    qry_template: Optional[str] = field(default=None, metadata={"help": "beir inference template"})
    psg_template: Optional[str] = field(default=None, metadata={"help": "beir inference template"})
    sub_split_num: int = field(default=1)

    def __post_init__(self):
        """Normalize the dataset name and resolve the train/eval file lists."""
        if self.dataset_name is not None:
            info = self.dataset_name.split('/')
            # For a three-part name the last component is dropped (earlier,
            # now-removed code kept it as a dataset split); otherwise the name
            # is left unchanged.
            if len(info) == 3:
                self.dataset_name = "/".join(info[:-1])
            self.dataset_language = 'default'
            if ':' in self.dataset_name:
                self.dataset_name, self.dataset_language = self.dataset_name.split(':')
        else:
            self.dataset_name = 'json'
            self.dataset_language = 'default'

        if self.train_dir is not None:
            # The cache always lives next to the data; this overrides any
            # user-supplied train_cache_dir.
            self.train_cache_dir = os.path.join(self.train_dir, "cache")
            self.train_path = self._list_json_files(self.train_dir)
        else:
            self.train_path = None

        if self.eval_dir is not None:
            # Same convention as the train cache; overrides eval_cache_dir.
            self.eval_cache_dir = os.path.join(self.eval_dir, "cache")
            self.eval_path = self._list_json_files(self.eval_dir)
        else:
            self.eval_path = None

    @staticmethod
    def _list_json_files(directory):
        """Return paths of entries in *directory* ending in ``json``/``jsonl``.

        ``os.listdir`` returns entries in arbitrary order, so the names are
        sorted to keep the resulting file list deterministic.
        """
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if name.endswith(('jsonl', 'json'))
        ]


@dataclass
class DenseTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with dense-retrieval specific options.

    Overrides the inherited ``warmup_ratio`` default to 0.1 and adds options
    for cross-device negatives, encoding, gradient caching, early stopping,
    few-shot training and ranking output.
    """

    warmup_ratio: float = field(default=0.1)
    negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})

    grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"})
    gc_q_chunk_size: int = field(default=4)
    gc_p_chunk_size: int = field(default=32)

    # Early stopping / logging; -1 presumably means "disabled" — confirm with the trainer.
    early_stop_step: int = field(default=-1)
    early_stop_epoch: int = field(default=-1)
    tensorboard: bool = field(default=False)

    # Few-shot training
    fewshot_extends: Optional[List[int]] = field(default=None, metadata={"help": "Few-shot number"})

    freeze_encoder_name: Optional[str] = field(default=None, metadata={"help": "lm_q: freeze query encoder; lm_p: freeze passage encoder"})

    # GPT-related encoding / ranking options
    query_reps: Optional[str] = field(
        default=None,
        metadata={"help": "encoding query vectors"}
    )
    passage_reps: Optional[str] = field(
        default=None,
        metadata={"help": "encoding passage vectors"}
    )
    depth: int = field(default=10)

    save_ranking_to: Optional[str] = field(
        default=None,
        metadata={"help": "save trec path"}
    )
    

    
# parser.add_argument('--list-type-nargs', type=list, nargs='+')