from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser

import dp_transformers

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    # Required field (no default). Dataclasses forbid a non-default field after
    # a defaulted one ("non-default argument follows default argument" raised at
    # class-definition time), so it must be declared before every defaulted field.
    data_cache_dir: str = field(metadata={"help": "The directory where the datasets are stored."})
    # Defaults to None, hence Optional.
    task_name: Optional[str] = field(default=None)
    max_seq_length: int = field(
        default=256,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    shadow_id: int = field(
        default=0, metadata={"help": "The id of the shadow model"}
    )
    ratio_change: float = field(
        default=0.01, metadata={"help": "Ratio of samples to change"}
    )
    prefix_length: int = field(
        default=10, metadata={"help": "Length of the prefix"}
    )
    topk: int = field(
        default=10, metadata={"help": "Top k tokens to change"}
    )
    prefix_type: str = field(
        default='none', metadata={"help": "Type of prefix to use"}
    )

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    Every field except `model_name_or_path` has a default, so only that one is
    required when the class is parsed by `HfArgumentParser`.
    """
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained Prefix model or model identifier from huggingface.co/models"}
    )
    lm_model_name_or_path: str = field(
        default="gpt2",
        metadata={"help": "Path to pretrained LMHead model or LMHead model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )

    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
            "with private models)."
        },
    )
    prefix: bool = field(
        default=False,
        metadata={
            "help": "Will use P-tuning v2 during training"
        }
    )
    prompt: bool = field(
        default=False,
        metadata={
            "help": "Will use prompt tuning during training"
        }
    )
    lora: bool = field(
        default=False,
        metadata={
            "help": "Will use PEFT LoRA during training"
        }
    )
    mixed: bool = field(
        default=False,
        metadata={
            "help": "mix between lora and prefix"
        }
    )
    soft_prompt: bool = field(
        default=False,
        metadata={
            "help": "Will use soft prompt tuning during training"
        }
    )
    # NOTE(review): this help text is identical to `soft_prompt`'s and looks
    # copy-pasted; confirm what `last_layer` actually toggles and update it.
    last_layer: bool = field(
        default=False,
        metadata={
            "help": "Will use soft prompt tuning during training"
        }
    )
    pre_seq_len: int = field(
        default=5,
        metadata={
            "help": "The length of prompt"
        }
    )
    prefix_projection: bool = field(
        default=True,
        metadata={
            "help": "Apply a two-layer MLP head over the prefix embeddings"
        }
    )
    prefix_hidden_size: int = field(
        default=512,
        metadata={
            "help": "The hidden size of the MLP projection head in Prefix Encoder if prefix projection is used"
        }
    )
    hidden_dropout_prob: float = field(
        default=0.0,
        metadata={
            "help": "The dropout probability used in the models"
        }
    )
    loading_4_bit: bool = field(
        default=False,
        metadata={
            "help": "if the model should be loaded in 4 bit"
        }
    )
    constant_scheduler: bool = field(
        default=False,
        metadata={
            "help": "Use the constant scheduler"
        }
    )
    lr_lora: float = field(
        default=0.0,
        metadata={
            "help": "Lora lr for mixed models"
        }
    )
    # No trailing comma after field(...). A comma would make the class
    # attribute a 1-tuple `(Field(...),)`, which dataclass would then use as
    # the literal default value instead of 8.
    lora_rank: int = field(
        default=8,
        metadata={
            "help": "Rank used for lora"
        }
    )
    train_from_scratch: bool = field(
        default=False,
        metadata={
            "help": "Train from scratch"
        }
    )


def get_args():
    """Build the combined CLI parser and return the parsed argument dataclasses.

    Returns:
        The value of ``HfArgumentParser.parse_args_into_dataclasses()`` for the
        model, data, training and privacy argument classes, in that order.
    """
    argument_classes = (
        ModelArguments,
        DataTrainingArguments,
        dp_transformers.TrainingArguments,
        dp_transformers.PrivacyArguments,
    )
    return HfArgumentParser(argument_classes).parse_args_into_dataclasses()


def print_args(model_args, data_args, training_args, privacy_args, path_to_file: str):
    """Dump every attribute of the four argument objects to a text file.

    The file is overwritten. Its layout is a ``==========`` banner line, then,
    for each group, a ``---<GROUP> ARGS---`` header followed by one
    ``        - name : value`` line per attribute in ``vars(obj)``.

    Args:
        model_args: Model argument object (e.g. a ``ModelArguments`` instance).
        data_args: Data argument object.
        training_args: Training argument object.
        privacy_args: Privacy argument object.
        path_to_file: Destination path; truncated if it already exists.
    """
    sections = (
        ("MODEL ARGS", model_args),
        ("DATA ARGS", data_args),
        ("TRAINING ARGS", training_args),
        ("PRIVACY ARGS", privacy_args),
    )
    # Mode "w" is enough: the file is only written, never read back.
    with open(path_to_file, mode="w", encoding="utf-8") as f:
        f.write("=" * 10 + "\n")
        for title, args in sections:
            f.write(f"---{title}---\n")
            # Newline-terminate each entry so every argument gets its own line.
            for k, v in vars(args).items():
                f.write(f"        - {k} : {v}\n")
