import os
from dataclasses import dataclass, field
from typing import Optional, List
from transformers import TrainingArguments


@dataclass
class ModelArguments:
    """Arguments selecting the pretrained model/config/tokenizer and model-side options.

    Fields carrying ``metadata={"help": ...}`` follow the
    ``transformers.HfArgumentParser`` convention (presumably how this class is
    parsed -- confirm against the training entry point).
    """

    # Required: no default is given, so it must always be supplied.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}
    )
    normalize: bool = field(default=False)
    pooling: str = field(default='cls')
    # Annotated Optional because the default is None.
    attn_type: Optional[str] = field(default=None)
    use_reweighting_loss: bool = field(default=False)
    bitfit: bool = field(default=False)
    # Annotated Optional because the default is None.
    rdrop_weight: Optional[float] = field(default=None)
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": "Floating-point format in which the model weights should be initialized and trained. Choose one "
                    "of `[float32, float16, bfloat16]`."
        },
    )


@dataclass
class DataArguments:
    """Arguments describing the training data and how input text is formatted.

    Every field has a default, so ``DataArguments()`` is valid. Fields whose
    default is ``None`` are annotated ``Optional``.
    """

    train_dir: Optional[str] = field(
        default=None, metadata={"help": "Path to train directory"}
    )

    data_config: str = field(default="config/data_config.json")

    mix_coefficient: float = field(default=0.0)
    buffer_size: int = field(default=10000)
    length_config: Optional[str] = field(default=None)
    default_length: Optional[int] = field(default=None)
    query_column: Optional[str] = field(
        default="question",
        metadata={"help": "The name of the column in the datasets containing the questions."},
    )
    doc_column: Optional[str] = field(
        default="passage",
        metadata={"help": "The name of the column in the datasets containing the passages."},
    )

    max_len: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization for document. Sequences longer "
                    "than this will be truncated, sequences shorter will be padded."
        },
    )
    data_cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the data downloaded from huggingface"}
    )

    add_prompt: bool = field(
        default=False,
        metadata={"help": "Prepend simple prompt to the text, e.g. 'query: this is a query', 'doc: this is a doc'."},
    )

    instruction: Optional[str] = field(default=None, metadata={"help": "Instruction text."})
    doc_instruction: Optional[str] = field(default=None, metadata={"help": "Instruction text for documents."})

    add_instruction: bool = field(
        default=False, metadata={"help": "Prepend detailed instructions for the data."}
    )

    mask_instruction_pooling: bool = field(
        default=True, metadata={"help": "Whether or not to mask instruction tokens during pooling."}
    )

    neg_per_ins: int = field(default=8, metadata={"help": "Number of negatives per instance."})
    finetune_data_path: Optional[str] = field(default=None, metadata={"help": "Path to the json file for finetuning."})

    finetune_data_config: str = field(default='config/ft_data_config.yaml', metadata={"help": "Finetuning data config file."})
    # Each special-token field previously shared the copy-pasted help "boq token".
    boq_token: Optional[str] = field(default=None, metadata={"help": "boq token"})
    bod_token: Optional[str] = field(default=None, metadata={"help": "bod token"})
    eod_token: Optional[str] = field(default=None, metadata={"help": "eod token"})
    random_neg: int = field(default=0)


@dataclass
class EmbeddingTrainingArguments(TrainingArguments):
    """``transformers.TrainingArguments`` extended with embedding-training options.

    ``warmup_ratio`` overrides the parent default. The remaining fields add
    contrastive-loss, gradient-cache, LoRA and temperature settings. Fields
    whose default is ``None`` are annotated ``Optional``.
    """

    warmup_ratio: float = field(default=0.1)
    negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
    contrastive_warmup: bool = field(default=False, metadata={"help": "disable negative sharing within warmup steps"})

    grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"})
    gc_q_chunk_size: int = field(default=4)
    gc_d_chunk_size: int = field(default=32)

    temperature: Optional[float] = field(default=None)
    t_warmup: bool = field(default=False, metadata={"help": "Linear temperature warmup."})
    full_contrastive_loss: bool = field(default=True)
    use_norm_loss: bool = field(default=False)
    loss_scale: float = field(default=-1., metadata={"help": "loss scale, -1 will use world_size"})
    # These two previously carried the copy-pasted help of `t_warmup`.
    use_lora: bool = field(default=True, metadata={"help": "Whether to use LoRA."})
    train_emb: bool = field(default=True, metadata={"help": "Whether to train embeddings."})
    continue_train: bool = field(default=False, metadata={"help": "Continue train"})
    loss_method: Optional[str] = field(default=None)
    kl_weight: Optional[float] = field(default=None)
    reference_model: Optional[str] = field(default=None)

    train_type: Optional[str] = None

    diy_cl_weight_ratio: float = field(default=1.0)

    token_list_path: Optional[str] = None
    token_num: Optional[int] = None

    freeze_classifier: bool = field(default=True)
    random_init: bool = field(default=False)

    limit_ratio: float = field(default=1e-4)
    sft_temperature: float = field(default=5e-2)



@dataclass
class LoraArguments:
    """LoRA adapter settings; every field has a default."""

    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    # default_factory gives every instance its own list, so mutating one
    # instance's target list never leaks into another.
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["c_attn", "c_proj", "w1", "w2"]
    )
    lora_weight_path: str = ""
    lora_bias: str = "none"
    q_lora: bool = False
    # Annotated Optional because the default is None.
    exclude_modules: Optional[str] = None

