
import os
from dataclasses import dataclass, field
from typing import Optional, List
from transformers import TrainingArguments


@dataclass
class ModelArguments:
    """Which pretrained model to load and how to turn its token outputs into a text embedding.

    Fields are grouped as: checkpoint / config / tokenizer locations, encoder
    weight sharing, pooling and normalization of the final embedding, and the
    floating-point dtype used by the Jax training path.
    """

    # ---- checkpoint / config / tokenizer locations -------------------------
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"),
    )
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
    )

    # ---- encoder architecture -----------------------------------------------
    # True -> query and passage encoders do not share weights.
    untie_encoder: Optional[bool] = field(
        default=False,
        metadata=dict(help="no weight sharing between qry passage encoders"),
    )

    # ---- embedding pooling / normalization -----------------------------------
    pooling_type: Optional[str] = field(
        default="avg",
        metadata=dict(
            help="pooling method (i.e., cls/avg/last) to obtain the text embedding given a sequence of tokens"
        ),
    )
    skip_l2norm: Optional[bool] = field(
        default=False,
        metadata=dict(help="whether we skip l2-normalization step to the final text embedding. (Default False)"),
    )

    # ---- Jax training --------------------------------------------------------
    dtype: Optional[str] = field(
        default="float32",
        metadata=dict(
            help=(
                "Floating-point format in which the model weights should be "
                "initialized and trained. Choose one of "
                "`[float32, float16, bfloat16]`. "
            )
        ),
    )


@dataclass
class TrainingDataArguments:
    """Where the training data lives and how each training example is assembled.

    Covers input/label parquet folders and the label matrix file, sampling of
    passages/labels per query, tokenization lengths, and options for the
    dual-encoder and supc losses.
    """

    # ---- data locations (required) -------------------------------------------
    lbl_folder: str = field(metadata=dict(help="Path to a folder of label text parquet files"))
    trn_folder: str = field(metadata=dict(help="Path to a folder of train CSR parquet files"))
    y_npz_path: str = field(metadata=dict(help="Path to the npz file of (weighted) Y.trn.npz"))

    # ---- key columns ---------------------------------------------------------
    inp_key_col: str = field(
        default="qid",
        metadata=dict(help="The key column corresponds to the text_col in input_folder"),
    )
    lbl_key_col: str = field(
        default="lid",
        metadata=dict(help="The key column corresponds to the text_col in label_folder"),
    )

    # ---- per-query sampling --------------------------------------------------
    train_group_size: int = field(
        default=1,
        metadata=dict(help="number of passages/labels used to train for each query"),
    )
    positive_passage_no_shuffle: bool = field(
        default=False,
        metadata=dict(help="always use the first positive passage/label"),
    )
    negative_passage_no_shuffle: bool = field(
        default=False,
        metadata=dict(help="always use the first n negative passages/labels for training"),
    )
    dataset_proc_num: int = field(
        default=96,
        metadata=dict(help="number of proc used in dataset preprocess"),
    )

    # ---- tokenization lengths ------------------------------------------------
    q_max_len: int = field(
        default=32,
        metadata=dict(
            help=(
                "The maximum total input sequence length after tokenization for query. "
                "Sequences longer than this will be truncated, sequences shorter will be padded."
            )
        ),
    )
    p_max_len: int = field(
        default=32,
        metadata=dict(
            help=(
                "The maximum total input sequence length after tokenization for passage. "
                "Sequences longer than this will be truncated, sequences shorter will be padded."
            )
        ),
    )

    # ---- loss options --------------------------------------------------------
    train_dual: bool = field(
        default=True,
        metadata=dict(help="Whether to train dual encoder loss."),
    )
    # Frequency cap used when sampling labels for the supc loss.
    label_freq: int = field(
        default=2000,
        metadata=dict(
            help="When sampling a label for supc loss, if its frequency is greater than label_freq, skip this label."
        ),
    )
    supc_neg_num: int = field(
        default=0,
        metadata=dict(help="number of negative queries for supc loss."),
    )
    max_label_per_query: int = field(
        default=50,
        metadata=dict(help="maximum number of labels for each query."),
    )


@dataclass
class SearcherDataArguments:
    """Data and inference options for the searcher (encoding + top-k prediction) stage.

    Required: the label matrix file, the train/test input folders and the key
    columns. Optional: the label folder, tokenization length, the Q2X/Q2Z
    mixing weight ``lamb``, preprocessing parallelism, embedding partitioning,
    and the top-k / method used at inference.
    """

    # ---- required inputs ------------------------------------------------------
    y_npz_path: str = field(
        metadata={"help": "Path to the npz file of Y.trn.npz"}
    )

    trn_folder: str = field(
        metadata={"help": "Path to the training input folder of encoding text parquet files"}
    )

    tst_folder: str = field(
        metadata={"help": "Path to the test input folder of encoding text parquet files"}
    )

    inp_key_col: str = field(
        metadata={"help": "The key column corresponds to the text_col in input_folder"}
    )

    lbl_key_col: str = field(
        metadata={"help": "The key column corresponds to the text_col in label_folder"}
    )

    # ---- optional inputs ------------------------------------------------------
    # Default is None, so the annotation must admit None (was plain ``str``).
    lbl_folder: Optional[str] = field(
        default=None, metadata={"help": "Path to a label folder of encoding text parquet files"}
    )

    text_max_len: int = field(
        default=32,
        metadata={
            "help": "The maximum total input sequence length after tokenization for query. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )

    # Mixing weight between the two inference sources: Q2X gets ``lamb``,
    # Q2Z gets ``1 - lamb``.
    lamb: float = field(
        default=0.5,
        metadata={
            "help": "The weight of Q2X. For Q2Z, its weight is (1-lamb)."
        },
    )

    dataset_proc_num: int = field(
        default=96,
        metadata={"help": "number of proc used in dataset preprocess"}
    )

    num_partitions: int = field(
        default=128,
        metadata={"help": "number of partitions to save the embeddings"}
    )

    # ---- inference ------------------------------------------------------------
    inference_topk: int = field(
        default=100,
        metadata={"help": "number of softmax topk for inference."}
    )

    inference_method: str = field(
        default="q2xz",
        metadata={"help": "Method for inference. Options: (1)q2z (2)q2x (3)q2xz"}
    )


@dataclass
class MyTrainingArguments(TrainingArguments):
    """``TrainingArguments`` extended with project-specific training options.

    Changes the parent's ``warmup_ratio`` default to 0.1. Adds a
    ``temperature`` hyper-parameter, cross-device negative sharing,
    hard-negative-mining (HNM) settings, and a switch for using queries as
    negatives.
    """

    # Re-declaring an inherited dataclass field replaces its metadata as well,
    # so the parent's help text is repeated here to keep ``--help`` informative.
    warmup_ratio: float = field(
        default=0.1,
        metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."},
    )
    # NOTE(review): presumably the similarity/softmax temperature of the loss — confirm at use site.
    temperature: float = field(default=0.05)
    negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})

    # ---- hard negative mining (HNM) -------------------------------------------
    hnm_topk: int = field(default=50, metadata={"help": "number of topk predictions from HNM step"})
    hnm_type: str = field(default="q2xz", metadata={"help": "inference method for HNM at training stage"})
    hnm_steps: int = field(
        default=10000,
        metadata={"help": "conduct hard negative mining every hnm_steps training steps"},
    )

    use_q_neg: bool = field(default=True, metadata={"help": "whether to use q as negative"})
