import json
from typing import Optional, List
from dataclasses import dataclass, field
from peft import LoraConfig
import transformers


@dataclass
class ModelArguments:
    """Model-architecture options: checkpoint paths, feature dimensions and projector settings.

    Every field has a default, so ``ModelArguments()`` can be built with no arguments.
    Fields that carry a ``help`` entry in their metadata expose it to metadata-aware
    parsers (e.g. HfArgumentParser).
    """

    # NOTE(review): the field name is misspelled ("enocder"); presumably kept because
    # callers / command-line flags use this exact name -- do not rename without a migration.
    # NOTE(review): the default is an LLM checkpoint name, same as llm_path; confirm that is
    # intended for the graph-encoder path.
    gt_enocder_path: str = "vicuna-7b-v1.5"
    llm_path: str = "vicuna-7b-v1.5"

    # Presumably the node-feature width and the LLM hidden width -- confirm against the model code.
    node_dim: int = 768
    llm_dim: int = 4096

    projector_hidden_act: str = field(
        default="relu", metadata={"help": "activation function of projector"}
    )
    memory_token_nums: int = field(
        default=128, metadata={"help": "number of memory tokens"}
    )
    spatial_pos_max: int = field(
        default=10, metadata={"help": "max distance between nodes"}
    )
    model_arch: str = field(
        default="vicuna", metadata={"help": "model version"}
    )


@dataclass
class DataArguments:
    """Dataset selection options for training.

    Both fields are plain strings. How they are split or parsed (for example,
    whether several datasets and matching weights may be listed in one string)
    is decided by the consuming code, which is not shown here.
    """

    datasets: str = field(metadata={"help": "training datasets"}, default="arxiv")
    # NOTE(review): presumably parallel to `datasets` (one weight per dataset) -- confirm.
    data_weights: str = field(metadata={"help": "training datasets weights"}, default="1")


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Project training options layered on top of ``transformers.TrainingArguments``.

    Adds a sequence-length cap, PEFT/LoRA settings, several boolean switches,
    separate learning rates for the projector and graph parts, and experiment
    tracking options. A few inherited fields (``output_dir``, ``bf16``) are
    redeclared here with different defaults.
    """

    model_max_length: int = field(
        default=2048,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    # Help text lives in field metadata (like the other fields) so dataclass-aware
    # parsers can show it; a bare string statement after the field is invisible to them.
    use_peft: bool = field(
        default=False,
        metadata={"help": "whether to use peft"},
    )
    # default_factory builds a fresh LoraConfig per instance, so instances never share
    # (and accidentally mutate) one config object.
    # NOTE(review): a LoraConfig cannot be meaningfully given as a command-line string;
    # presumably callers override this in code -- confirm.
    peft_config: Optional[LoraConfig] = field(
        default_factory=lambda: LoraConfig(
            r=128,
            lora_alpha=32,
            lora_dropout=0.05,
            target_modules=["q_proj", "v_proj"],
            bias='none',
            task_type='CAUSAL_LM',
        ),
    )
    # Redeclares the inherited output_dir with an empty-string default.
    output_dir: str = field(
        default='',
        metadata={"help": "output dir"}
    )
    result_dir: str = field(
        default='',
        metadata={"help": "results dir"}
    )
    # NOTE(review): the following switches are undocumented here; their meaning is
    # defined by the training/inference code that reads them.
    inference: bool = False
    # Redeclared with default True; presumably overrides the inherited transformers flag.
    bf16: bool = True
    fix_encoder: bool = False
    fix_mem: bool = False
    pure_icae: bool = False
    proj_learning_rate: float = field(
        default=1e-4,
        metadata={"help": "The initial learning rate for projector."}
    )
    graph_learning_rate: float = field(
        default=2e-3,
        metadata={"help": "The initial learning rate for graph."}
    )
    project_name: str = field(
        default='UniGTE',
        metadata={"help": "The name of wandb group."}
    )
    # NOTE(review): presumably the experiment-tracker backend name -- confirm with the caller.
    log_with: str = field(
        default='wandb',
    )