from dataclasses import dataclass, field
import json
from typing import List, Optional, Tuple, Union, Dict
import numpy as np
from transformers import HfArgumentParser

@dataclass()
class DataArguments:
    """Settings for a single dataset/task entry.

    In this module, plain dicts stored under a task id in
    ``ScriptArguments.train_data`` / ``ScriptArguments.eval_data`` are turned
    into instances of this class via ``DataArguments(**mapping)``.
    Every field has a default, so ``DataArguments()`` is valid on its own.
    """

    # Operation identifier (presumably selects the task/generator -- confirm
    # against the data-loading code).
    op: Optional[str] = field(default=None)
    # Fraction value, 1.0 by default (presumably the share of this entry in
    # the data mix -- TODO confirm).
    frac: Optional[float] = field(default=1.0)
    # Extra keyword arguments; each instance gets its own fresh empty dict.
    kwargs: Optional[dict] = field(default_factory=dict)
    eval_keys: Optional[List[str]] = field(default=None)
    tied_keys: Optional[List[List[str]]] = field(default=None)
    dataset_path: Optional[str] = field(default=None)

@dataclass
class ScriptArguments:
    """Top-level run configuration: general, model, and data settings.

    This is a plain dataclass; it is presumably populated through
    ``transformers.HfArgumentParser`` (imported by this module) -- confirm at
    the call site. ``__post_init__`` normalizes a few fields:

    * ``intermediate_size`` falls back to ``4 * hidden_size`` when ``None``.
    * ``lora_config`` is decoded with ``json.loads`` when given as a string.
    * ``train_data`` / ``eval_data`` are rebuilt as new dicts whose values are
      all ``DataArguments``; plain-dict values are converted with
      ``DataArguments(**value)``. The mappings passed in are not modified.
    """

    # General arguments
    train_algo: Optional[str] = 'SFT'
    run_name_prefix: str = ''

    # Model arguments
    use_unsloth: bool = False
    generate_max_length: Optional[int] = None
    model_id: Optional[str] = None
    from_pretrained: bool = False
    revision: Optional[str] = None
    use_lora: bool = False
    lora_config: Optional[dict] = None  # dict, or a JSON string decoded in __post_init__
    architecture: Optional[str] = 'mamba'
    rope_theta: Optional[float] = np.inf
    partial_rotary_factor: Optional[float] = 1.0
    hidden_size: Optional[int] = 768
    intermediate_size: Optional[int] = 3072  # None -> 4 * hidden_size in __post_init__
    num_attention_heads: Optional[int] = 12
    num_layers: Optional[int] = 32
    max_position_embeddings: Optional[int] = 1024
    dropout: Optional[float] = 0.0
    use_character_tokenizer: bool = True
    freeze_layers: Optional[List[int]] = None

    # Data arguments
    task_length: Optional[int] = None
    use_iterable_dataset: bool = True
    num_train: int = 1000
    num_eval: int = 100
    train_cache_loc: str = 'data/train_data'
    eval_cache_loc: str = 'data/eval_data'
    # Task id -> DataArguments; plain-dict values are converted in __post_init__.
    train_data: Dict[str, DataArguments] = field(default_factory=dict)
    eval_data: Dict[str, DataArguments] = field(default_factory=dict)
    num_workers: Optional[int] = 8
    add_special_tokens: int = 0
    padding_side: str = 'right'
    mask_prompt: bool = True

    additional_train_data: List[str] = field(default_factory=list)

    def __post_init__(self):
        """Derive defaults and coerce loosely-typed inputs after construction."""
        if self.intermediate_size is None:
            self.intermediate_size = 4 * self.hidden_size
        # lora_config may be supplied as a JSON string instead of a dict.
        if isinstance(self.lora_config, str):
            self.lora_config = json.loads(self.lora_config)
        # Rebind to freshly built dicts instead of assigning into the
        # caller's mappings, so a config dict passed in is left untouched.
        self.train_data = self._to_data_arguments(self.train_data)
        self.eval_data = self._to_data_arguments(self.eval_data)

    @staticmethod
    def _to_data_arguments(specs):
        """Return a new dict mapping each task id to a ``DataArguments``.

        Args:
            specs: Mapping of task id to either a ``DataArguments`` instance
                (kept as-is) or a mapping of its keyword arguments. ``None``
                is treated as an empty mapping.

        Returns:
            A new ``dict`` with the same keys, in the same order.

        Raises:
            TypeError: if a value's keys do not match ``DataArguments`` fields.
        """
        if specs is None:
            return {}
        return {
            task_id: spec if isinstance(spec, DataArguments) else DataArguments(**spec)
            for task_id, spec in specs.items()
        }
