from dataclasses import dataclass, field
from typing import Callable, Dict, Optional, List
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers.hf_argparser import HfArgumentParser


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Only `model_name_or_path` has no default and must be supplied; every other
    field is optional. Each field's `metadata["help"]` string carries its
    user-facing description.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )

    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )

    # Few-shot type
    #   - finetune: standard fine-tuning
    #   - prompt: prompt-based fine-tuning
    #   - prompt-demo: prompt-based fine-tuning with demonstrations
    #   - autoregressive: listed as a choice in the original help text
    #     NOTE(review): confirm how the training script handles this value
    # The help text now lists every value, including the default, which it
    # previously omitted.
    few_shot_type: str = field(
        default='prompt-demo',
        metadata={"help": "Few-shot learning model type. Choice: finetune, prompt, prompt-demo, autoregressive"}
    )

    # Only for BERT-type model
    random_segment: bool = field(
        default=False,
        metadata={"help": "Whether to reinitialize the token type embeddings (only for BERT)."}
    )

    # Integer 0/1 flag rather than a bool.
    use_lm_head: int = field(
        default=1,
        metadata={"help": "0/1: Whether to use lm head or use a simple linear classifier."}
    )

    # NOTE(review): the default is the same literal as few_shot_type's default,
    # which looks like a copy-paste rather than a deliberate file name --
    # confirm with the code that writes the log before changing it.
    log_file_store: str = field(
        default='prompt-demo',
        metadata={"help": "File to log results"}
    )

    # Integer 0/1 flag rather than a bool.
    use_CLS_linearhead: int = field(
        default=0,
        metadata={"help": "0/1: Whether to use [CLS] or the mask representation."}
    )

    # A value of 0. is the default; the help text does not say what scale is expected.
    l1_reg: float = field(
        default=0.,
        metadata={"help": "Apply l1 regularization on the model parameters!"}
    )
        
       
        
@dataclass
class DynamicDataTrainingArguments(DataTrainingArguments):
    """
    Arguments for dynamic training.

    Extends the imported `DataTrainingArguments` with options for prompts
    (templates and label-word mappings), demonstrations, sequence-length
    limits and GPT-style in-context settings. Every field has a default.

    Several fields are annotated with a bare type (`str`, `int`, `bool`,
    `List[str]`) but default to `None`, meaning "not set".
    NOTE(review): the annotations are intentionally left as they were because
    the argument parser derives command-line types from them -- confirm the
    parser's handling of `Optional[...]` before tightening them.
    """
    num_k: Optional[int] = field(
        default=16,
        metadata={"help": "Number of training instances per class"}
    )

    num_sample: Optional[int] = field(
        default=16,
        metadata={"help": "Number of samples (for inference) in fine-tuning with demonstrations"}
    )

    num_demo: Optional[int] = field(
        default=1,
        metadata={"help": "Number of demonstrations from each class"}
    )

    auto_demo: bool = field(
        default=True,
        metadata={"help": "Automatically generate template for using demonstrations"}
    )

    # For prompting
    template: str = field(
        default=None,
        metadata={"help": "Template"}
    )

    mapping: str = field(
        default=None,
        metadata={"help": "Label word mapping"}
    )

    template_path: str = field(
        default=None,
        metadata={"help": "Path to a txt file that stores all the templates, one per line. Do not set this when prompt_path is used"}
    )

    mapping_path: str = field(
        default=None,
        metadata={"help": "Path to a txt file that stores all the label word mappings, one per line. Do not set this when prompt_path is used"}
    )

    prompt_path: str = field(
        default=None,
        metadata={"help": "Path to a txt file that stores all the prompts (templates and mappings), one per line"}
    )

    template_id: int = field(
        default=None,
        metadata={"help": "Template id if using template_path"}
    )

    # Help text previously said "template_path" (copied from template_id).
    mapping_id: int = field(
        default=None,
        metadata={"help": "Mapping id if using mapping_path"}
    )

    prompt_id: int = field(
        default=None,
        metadata={"help": "Prompt id if using prompt_path"}
    )

    top_n_template: int = field(
        default=None,
        metadata={"help": "Use top-n template in the template path"}
    )

    # For logging
    tag: str = field(
        default='',
        metadata={"help": "Set the tag and find the result easier in the log."}
    )

    # For filtering when using demonstrations
    demo_filter: bool = field(
        default=False,
        metadata={"help": "Only use similar instances in demonstrations"}
    )

    # The percent sign is written as "%%": argparse applies %-formatting to
    # help strings, so a literal percent must be doubled. The former "\%" was
    # also an invalid string escape sequence.
    demo_filter_rate: float = field(
        default=0.5,
        metadata={"help": "Only use top-x%% similar instances in demonstrations"}
    )

    demo_filter_model: str = field(
        default=None,
        metadata={"help": "Model name for demonstration filter embeddings. Will load embeddings based on the model name."}
    )

    debug_mode: bool = field(
        default=False,
        metadata={"help": "Debug mode"}
    )

    # For max length
    double_demo: bool = field(
        default=False,
        metadata={"help": "Use double length for using demonstrations"}
    )

    first_sent_limit: int = field(
        default=None,
        metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
    )

    other_sent_limit: int = field(
        default=None,
        metadata={"help": "Limit the length of sentences other than the first sentence"}
    )

    # Defaults to None (not False), unlike the other boolean flags.
    use_full_length: bool = field(
        default=None,
        metadata={"help": "Use the full length (512)"}
    )

    max_length_per_example: int = field(
        default=None,
        metadata={"help": "Max length per example in gpt experiments on gpt!"}
    )

    max_seq_length: int = field(
        default=128,
        metadata={"help": "Maximum sequence length"}
    )

    # Arguments for gpt3 in-context experiments: not necessary for our experiments!
    gpt3_in_context_head: bool = field(
        default=False,
        metadata={"help": "GPT-3's in-context learning (context at the beginning)"}
    )

    gpt3_in_context_tail: bool = field(
        default=False,
        metadata={"help": "GPT-3's in-context learning (context at the end)"}
    )

    gpt3_in_context_num: int = field(
        default=32,
        metadata={"help": "Number of context examples"}
    )

    truncate_head: bool = field(
        default=False,
        metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
    )

    # Do not set up the following fields. They are set up automatically.
    prompt: bool = field(
        default=True,
        metadata={"help": "Whether to use prompt-based fine-tuning"}
    )

    # Help text repaired: the original string was cut off mid-phrase.
    template_list: List[str] = field(
        default=None,
        metadata={"help": "(DO NOT SET) List of templates (only initialized after the program starts)."}
    )

    autoregressive: bool = field(
        default=False,
        metadata={"help": "Whether to use GPT2 fine-tuning"}
    )

    
    

# Argument parser built over the DynamicDataTrainingArguments dataclass.
parser = HfArgumentParser(DynamicDataTrainingArguments)


# Root directory shared by every preset's data_dir.
_KSHOT_ROOT = "/data/common/lm-bff/k-shot"

# Template strings that are reused verbatim by more than one preset.
_TPL_THIS_IS = "*cls**sent_0*_This_is*mask*.*sep+*"
_TPL_IT_WAS = "*cls**sent_0*_It_was*mask*.*sep+*"
_TPL_PAIR = "*cls**sent_0**mask*,*+sent_1**sep+*"
_TPL_PAIR_QUESTION = "*cls**sent_0*?*mask*,*+sent_1**sep+*"
_TPL_MASK_FIRST = "*cls**mask*:*+sent_0**sep+*"

# Label-word mapping strings that are reused verbatim by more than one preset.
# Note that some use quoted keys ('0') and others bare integer keys (0); both
# spellings are preserved exactly.
_MAP_SENTIMENT_STR_KEYS = "{'0':'terrible','1':'great'}"
_MAP_SENTIMENT_INT_KEYS = "{0:'terrible',1:'great'}"
_MAP_NO_YES = "{'0':'No','1':'Yes'}"
_MAP_NLI_THREE_WAY = "{'contradiction':'No','entailment':'Yes','neutral':'Maybe'}"
_MAP_NLI_TWO_WAY = "{'not_entailment':'No','entailment':'Yes'}"


def _preset(task_name, subdir, template, mapping, **overrides):
    """Build one preset: num_k=64 with data read from "<root>/<subdir>/64-0".

    Any extra keyword arguments are forwarded unchanged to
    DynamicDataTrainingArguments.
    """
    return DynamicDataTrainingArguments(
        task_name=task_name,
        data_dir=_KSHOT_ROOT + "/" + subdir + "/64-0",
        num_k=64,
        template=template,
        mapping=mapping,
        **overrides
    )


# Per-task presets, keyed by task label. Key spelling/casing is kept as-is
# because lookups elsewhere may depend on it.
tasks_config = {
    "cola": _preset(
        "cola", "CoLA", _TPL_THIS_IS,
        "{'0':'incorrect','1':'correct'}",
    ),
    "SST-2": _preset(
        "sst-2", "SST-2", _TPL_IT_WAS, _MAP_SENTIMENT_STR_KEYS,
    ),
    "imdb": _preset(
        "imdb", "imdb", _TPL_IT_WAS, _MAP_SENTIMENT_STR_KEYS,
        first_sent_limit=110,
        double_demo=True,
    ),
    "MRPC": _preset(
        "mrpc", "MRPC", _TPL_PAIR, _MAP_NO_YES,
    ),
    "QQP": _preset(
        "qqp", "QQP", _TPL_PAIR, _MAP_NO_YES,
        num_sample=4,
    ),
    "sts-b": _preset(
        "sts-b", "STS-B", _TPL_PAIR, _MAP_NO_YES,
    ),
    "MNLI": _preset(
        "mnli", "MNLI", _TPL_PAIR_QUESTION, _MAP_NLI_THREE_WAY,
        first_sent_limit=110,
        max_seq_length=256,
        num_sample=4,
    ),
    "anli": _preset(
        "anli", "ANLI", _TPL_PAIR_QUESTION, _MAP_NLI_THREE_WAY,
        max_seq_length=256,
        num_sample=4,
    ),
    "SNLI": _preset(
        "snli", "SNLI", _TPL_PAIR_QUESTION, _MAP_NLI_THREE_WAY,
        max_seq_length=256,
        num_sample=4,
    ),
    "QNLI": _preset(
        "qnli", "QNLI", _TPL_PAIR_QUESTION, _MAP_NLI_TWO_WAY,
    ),
    "RTE": _preset(
        "rte", "RTE", _TPL_PAIR_QUESTION, _MAP_NLI_TWO_WAY,
        max_seq_length=256,
        first_sent_limit=240,
    ),
    "mr": _preset(
        "mr", "mr", _TPL_IT_WAS, _MAP_SENTIMENT_INT_KEYS,
        first_sent_limit=110,
        other_sent_limit=50,
        double_demo=True,
    ),
    "sst-5": _preset(
        "sst-5", "SST-5", _TPL_IT_WAS,
        "{0:'terrible',1:'bad',2:'okay',3:'good',4:'great'}",
        first_sent_limit=110,
        other_sent_limit=20,
        double_demo=True,
    ),
    "subj": _preset(
        "subj", "subj", _TPL_THIS_IS,
        "{0:'subjective',1:'objective'}",
        first_sent_limit=110,
        other_sent_limit=50,
        double_demo=True,
    ),
    "trec": _preset(
        "trec", "trec", _TPL_MASK_FIRST,
        "{0:'Description',1:'Entity',2:'Expression',3:'Human',4:'Location',5:'Number'}",
        first_sent_limit=110,
        double_demo=True,
    ),
    "ag_news": _preset(
        "ag_news", "AG_NEWS", _TPL_MASK_FIRST,
        "{'0':'World','1':'Sports','2':'Business','3':'Tech'}",
        first_sent_limit=110,
        double_demo=True,
    ),
    "cr": _preset(
        "cr", "cr", _TPL_IT_WAS, _MAP_SENTIMENT_INT_KEYS,
        first_sent_limit=110,
        other_sent_limit=50,
        double_demo=True,
    ),
    "mpqa": _preset(
        "mpqa", "mpqa", _TPL_IT_WAS, _MAP_SENTIMENT_INT_KEYS,
        first_sent_limit=110,
        double_demo=True,
    ),
}

# Second name bound to the very same dict object as tasks_config.
data_args = tasks_config