import re
import json
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class Experiment:
    """Immutable record bundling the identifiers, paths and settings of one run.

    ``frozen=True`` blocks attribute reassignment after construction, but the
    ``train_config`` dict itself can still be mutated in place. Because the
    generated ``__hash__`` hashes every field, calling ``hash()`` on an
    instance raises ``TypeError`` while ``train_config`` holds a dict.

    No validation or derivation happens here: every field has no default and
    must be supplied by the caller (in declaration order when positional).
    """

    # --- Identification ---
    # NOTE(review): presumably the values built by get_exp_name, get_run_name
    # and get_timestamp in this module -- confirm against the caller.
    experiment_name: str
    run_name: str
    timestamp: str

    # --- Filesystem locations ---
    # NOTE(review): experiment_dir is a Path while the data paths are plain
    # strings; presumably experiment_dir comes from get_exp_path -- confirm.
    experiment_dir: Path
    train_data_path: str
    val_data_path: str
    test_data_path: str

    # --- Problem / dataset labels ---
    # NOTE(review): problem_name presumably comes from get_problem_name -- confirm.
    problem_name: str
    dataset_name: str

    # Random seed for the run (how it is consumed is not visible here).
    seed: int

    # --- Weights & Biases metadata and training settings ---
    # NOTE(review): assumed to be passed through to wandb by the caller -- confirm.
    wandb_project: str
    wandb_group: str
    wandb_run_name: str
    # Free-form training configuration; schema is not defined in this file.
    train_config: dict[str, Any]

def get_problem_name(train_data_path: str) -> str:
    """Return the path segment that directly follows the first usable ``data/``.

    The search is purely textual: any occurrence of the substring ``data/``
    counts, including one that ends a longer directory name (``mydata/``).
    An occurrence immediately followed by another ``/`` is skipped.

    Args:
        train_data_path: Forward-slash separated path to inspect.

    Returns:
        The run of non-``/`` characters right after ``data/``.

    Raises:
        ValueError: If no ``data/`` followed by a non-empty segment is found.
    """
    found = re.search(r"(?<=data\/)([^\/]+)", train_data_path)
    if found is not None:
        return found.group(1)
    raise ValueError(f"Could not extract problem name from path: {train_data_path}")

def get_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD_HHMMSS``.

    The value comes from a naive ``datetime.now()`` call, so it reflects the
    local clock and carries no timezone information.
    """
    # A non-empty format spec on a datetime is forwarded to strftime().
    return f"{datetime.now():%Y%m%d_%H%M%S}"

def get_exp_name(problem_name: str, ot_type: str, clean_or_noisy: str, sinkhorn_or_wgan: str, full_rollout_or_anchored_on_ot: str) -> str:
    """Build an experiment name by joining the five labels with underscores.

    The labels appear in the result in the same order as the parameters.
    No validation or normalisation is applied to any of them.

    Returns:
        ``"<problem_name>_<ot_type>_<clean_or_noisy>_<sinkhorn_or_wgan>_<full_rollout_or_anchored_on_ot>"``.
    """
    labels = (
        problem_name,
        ot_type,
        clean_or_noisy,
        sinkhorn_or_wgan,
        full_rollout_or_anchored_on_ot,
    )
    # format(x) is what f-string interpolation applies to each value.
    return "_".join(format(label) for label in labels)

def get_run_name(epochs: int, lambda_ot: float, noise_level: float, seed: int) -> str:
    """Build a run name that encodes the given hyperparameters.

    Each value is prefixed with its label and rendered with default
    formatting, so floats appear exactly as Python prints them
    (e.g. ``0.1`` or ``1e-05``).

    Returns:
        ``"epochs<epochs>_lambda<lambda_ot>_noise<noise_level>_seed<seed>"``.
    """
    pieces = (
        f"epochs{epochs}",
        f"lambda{lambda_ot}",
        f"noise{noise_level}",
        f"seed{seed}",
    )
    return "_".join(pieces)

def get_exp_path(exp_name: str, run_name: str, timestamp: str) -> Path:
    """Return the relative output directory for one run.

    The layout is ``outputs/<exp_name>/<timestamp>_<run_name>``. The path is
    only constructed here; nothing is created on disk.

    Args:
        exp_name: Name of the experiment sub-directory.
        run_name: Run label, placed after the timestamp in the leaf directory.
        timestamp: String placed first in the leaf directory name.
    """
    leaf = timestamp + "_" + run_name
    return Path("outputs").joinpath(exp_name, leaf)

def get_dataset_config(train_data_path: str) -> dict[str, Any]:
    """Parse the file at ``train_data_path`` as JSON and return the result.

    The file is read in binary mode, leaving encoding detection to the
    ``json`` module.

    Args:
        train_data_path: Location of the JSON file to read.

    Returns:
        The decoded JSON document. It is annotated as a dict, but the top-level
        type is whatever the file contains; no check is made here.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    with open(train_data_path, 'rb') as handle:
        raw = handle.read()
    return json.loads(raw)