import argparse
from dataclasses import dataclass, field
from typing import List, Optional
import torch

@dataclass
class ExistenceConfig:
    """Settings bundle for the existence evaluation.

    All fields have defaults, so ``ExistenceConfig()`` yields a usable
    configuration. After construction, ``__post_init__`` may lower
    ``n_slots`` when the selected device is ``"cpu"``.
    """

    # Masked language model identifier.
    model_name: str = "distilroberta-base"

    # Names of the datasets to evaluate on, plus the config/split used for each.
    datasets: List[str] = field(default_factory=lambda: ["wikitext", "openwebtext"])
    wikitext_config: str = "wikitext-2-raw-v1"
    wikitext_split: str = "test"
    openwebtext_split: str = "train"

    # Number of slots to sample per dataset, and the cap applied on CPU.
    n_slots: int = 800
    n_slots_cpu: int = 200

    # Left and right context radii.
    l_list: List[int] = field(default_factory=lambda: [0, 1, 2, 4, 8])
    r_list: List[int] = field(default_factory=lambda: [0, 1, 2, 4, 8])

    # Top-k candidates per anchor.
    k_cand: int = 50

    # Small positive constant; presumably a numerical-stability epsilon --
    # TODO confirm against the code that consumes it.
    eps: float = 1e-12

    # Random seed.
    seed: int = 42

    # Directories for cached results and saved figures.
    cache_dir: str = "./results"
    figure_dir: str = "./figures"

    # Recompute even when cached results exist; batch size for forward passes.
    force_recompute: bool = False
    batch_size: int = 16

    # "cuda" when torch reports an available GPU at construction time, else "cpu".
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")

    def __post_init__(self):
        """Cap ``n_slots`` at ``n_slots_cpu`` when the device is ``"cpu"``."""
        needs_cap = self.device == "cpu" and self.n_slots > self.n_slots_cpu
        if not needs_cap:
            return

        # Report the reduction before overwriting so the message shows the
        # originally requested value.
        print(f"[Config] No GPU detected. Reducing n_slots from {self.n_slots} to {self.n_slots_cpu}")
        self.n_slots = self.n_slots_cpu

    @property
    def l_max(self) -> int:
        """Largest entry of ``l_list``."""
        widest_left = max(self.l_list)
        return widest_left

    @property
    def r_max(self) -> int:
        """Largest entry of ``r_list``."""
        widest_right = max(self.r_list)
        return widest_right

    @property
    def min_seq_length(self) -> int:
        """Return ``l_max + r_max + 5``.

        NOTE(review): the purpose of the constant 5 is not visible in this
        file -- presumably headroom for the slot and special tokens; confirm.
        """
        margin = 5
        return self.l_max + self.r_max + margin

def parse_args(argv: Optional[List[str]] = None) -> ExistenceConfig:
    """Parse command-line options into an ``ExistenceConfig``.

    Args:
        argv: Argument strings to parse, without the program name. When
            ``None`` (the default) argparse reads ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call has always done. Passing an
            explicit list lets tests and other entry points call this
            function without touching ``sys.argv``.

    Returns:
        An ``ExistenceConfig`` built from the parsed options. Fields that
        have no command-line flag (``openwebtext_split``, ``n_slots_cpu``,
        ``eps``, ``device``) keep their dataclass defaults.

    Raises:
        SystemExit: Raised by argparse on invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="Existence evaluation: Curvature fingerprints in two-sided inference geometry",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Model selection.
    parser.add_argument(
        "--model_name",
        type=str,
        default="distilroberta-base",
        choices=["distilroberta-base", "roberta-base"],
        help="Masked language model to use"
    )

    # Dataset selection.
    parser.add_argument(
        "--datasets",
        type=str,
        nargs="+",
        default=["wikitext", "openwebtext"],
        help="Datasets to evaluate on"
    )
    parser.add_argument(
        "--wikitext_config",
        type=str,
        default="wikitext-2-raw-v1",
        choices=["wikitext-2-raw-v1", "wikitext-103-raw-v1"],
        help="Wikitext configuration (use 103 for larger runs)"
    )
    parser.add_argument(
        "--wikitext_split",
        type=str,
        default="test",
        choices=["train", "validation", "test"],
        help="Wikitext split to use"
    )

    # Sampling size.
    parser.add_argument(
        "--n_slots",
        type=int,
        default=800,
        help="Number of slots to sample per dataset"
    )

    # Context radii.
    parser.add_argument(
        "--l_list",
        type=int,
        nargs="+",
        default=[0, 1, 2, 4, 8],
        help="Left context radii"
    )
    parser.add_argument(
        "--r_list",
        type=int,
        nargs="+",
        default=[0, 1, 2, 4, 8],
        help="Right context radii"
    )

    parser.add_argument(
        "--k_cand",
        type=int,
        default=50,
        help="Top-k candidates per anchor for S_i construction"
    )

    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed"
    )

    # Output locations.
    parser.add_argument(
        "--cache_dir",
        type=str,
        default="./results",
        help="Directory for caching results"
    )
    parser.add_argument(
        "--figure_dir",
        type=str,
        default="./figures",
        help="Directory for saving figures"
    )

    # Execution controls.
    parser.add_argument(
        "--force_recompute",
        action="store_true",
        help="Force recomputation even if cached results exist"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Batch size for forward passes"
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)

    config = ExistenceConfig(
        model_name=args.model_name,
        datasets=args.datasets,
        wikitext_config=args.wikitext_config,
        wikitext_split=args.wikitext_split,
        n_slots=args.n_slots,
        l_list=args.l_list,
        r_list=args.r_list,
        k_cand=args.k_cand,
        seed=args.seed,
        cache_dir=args.cache_dir,
        figure_dir=args.figure_dir,
        force_recompute=args.force_recompute,
        batch_size=args.batch_size,
    )

    return config

def get_default_config() -> ExistenceConfig:
    """Build an ``ExistenceConfig`` using only its field defaults.

    The dataclass's ``__post_init__`` still runs, so ``n_slots`` may be
    lowered to ``n_slots_cpu`` when the resolved device is ``"cpu"``.
    """
    default_config = ExistenceConfig()
    return default_config
