import argparse
import os
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any

import argparse
import os
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any

import argparse
import os
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any

@dataclass
class ExperimentConfig:
    """Settings for a single experiment run, as produced by ``parse_args``.

    Constructing an instance has filesystem side effects: ``__post_init__``
    fills in a default ``output_dir`` when none is given and then creates
    both ``output_dir`` and ``dataset_cache_dir`` if they do not exist.
    """

    experiment_name: str
    model_name: str
    # NOTE(review): parse_args defaults this to 0.75, so it looks like a
    # fraction in [0, 1] rather than a 0-100 percentage -- confirm with the
    # data-loading code.
    sample_percentage: float
    max_test_examples: int
    epochs: int
    batch_size: int
    learning_rate: float
    gradient_accumulation_steps: int
    save_steps: int
    save_total_limit: int
    logging_steps: int
    unfreeze_pct: float
    num_train_epochs: int
    test_only: bool = False
    # None means "derive from experiment_name" (resolved in __post_init__).
    output_dir: Optional[str] = None
    load_model_path: Optional[str] = None
    dataset_cache_dir: str = "./cached_datasets"
    freeze_layers: bool = False

    def __post_init__(self) -> None:
        """Resolve the default output directory and create required dirs."""
        if self.output_dir is None:
            self.output_dir = f"./results/{self.experiment_name}"

        # exist_ok=True so re-running an experiment with the same name
        # (or sharing a cache dir) does not fail.
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.dataset_cache_dir, exist_ok=True)


def parse_args() -> ExperimentConfig:

    parser = argparse.ArgumentParser()
    
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--experiment_name', type=str, required=True)
    parser.add_argument('--load_model', type=str, default=None) # for pretrained model only
    parser.add_argument('--dataset_cache_dir', type=str, default='./cached_datasets')
    parser.add_argument('--output_dir', type=str, default=None)
    parser.add_argument('--test_only', type=bool, default=False)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=2)
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=3e-5)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--save_steps", type=int, default=1000)
    parser.add_argument("--save_total_limit", type=int, default=2)
    parser.add_argument("--logging_steps", type=int, default=20)
    parser.add_argument("--max_test_examples", type=int, default=100)
    parser.add_argument("--sample_percentage", type=float, default=0.75)
    parser.add_argument("--freeze_layers", type=bool, default=False) 
    parser.add_argument("--unfreeze_pct", type=float, default=0.25)
    parser.add_argument("--num_train_epochs", type=int, default=8)
        
    args = parser.parse_args()

    config = ExperimentConfig(
        experiment_name=args.experiment_name,
        model_name=args.model_name,
        output_dir=args.output_dir, 
        model_ckpt_dir=args.model_ckpt_dir,
        load_model_path=args.load_model,
        cache_dir=args.cache_dir,
        sample_percentage=args.sample_percentage,
        max_test_examples=args.max_test_examples,
        test_only=args.test_only,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        logging_steps=args.logging_steps,
        unfreeze_pct=args.unfreeze_pct,
        freeze_layers=args.freeze_layers,
        num_train_epochs = args.epochs,
    )
    
    return config