import random
import torch
from typing import Literal
from dataclasses import dataclass, field

@dataclass()
class Feature:
    """Schema entry describing one column of a tabular dataset.

    Only the two list-valued fields have defaults (empty lists, created fresh
    per instance), so they must stay at the end of the field list.
    """

    # Column name.
    name: str
    # Free-text description of the column.
    description: str
    # Embedding(s) of the description, keyed by str — key meaning not visible here; TODO confirm.
    description_embedding: dict[str, torch.Tensor]
    # Column kind: either "real" or "categorical".
    dtype: Literal["real", "categorical"]
    # Embeddings keyed by str — presumably one per category label; verify against caller.
    categories_embedding: dict[str, torch.Tensor]
    # Category labels (empty list by default).
    categories: list[str] = field(default_factory=list)
    # Numeric range for the column — presumably [min, max] for "real" columns; TODO confirm.
    value_range: list[float] = field(default_factory=list)

@dataclass()
class TabularData:
    """A tabular dataset: a description, its column schema, and train/test rows."""

    # Free-text description of the whole dataset.
    description: str
    # Column schema — presumably one Feature per row position, in order; confirm.
    features: list[Feature]
    # Training rows (annotated as lists of floats).
    train_rows: list[list[float]]
    # Test rows (annotated as lists of floats).
    test_rows: list[list[float]]
    
@dataclass()
class Example:
    """A single sampled instance built from a tabular dataset.

    Holds the dataset description and schema, a set of few-shot rows, one
    target row, the index of the target column, and optionally the indices
    of columns treated as missing (empty list by default).
    """

    # Free-text description of the source dataset.
    description: str
    # Column schema shared by all rows below.
    features: list[Feature]
    # Context rows (annotated as lists of floats).
    fewshot_rows: list[list[float]]
    # The row this example is about.
    target_row: list[float]
    # Index of the target column — presumably the column to predict; confirm.
    target_column_id: int
    # Indices of columns marked as missing; a fresh empty list per instance.
    missing_column_ids: list[int] = field(default_factory=list)

def simple_data_collator(batch: list[Example]):
    """Pass a batch of examples through untouched.

    No stacking, padding, or tensor conversion is performed. Presumably this
    is plugged in as a data-loader collate hook — confirm at the call site.

    Args:
        batch: the examples making up one batch.

    Returns:
        ``batch`` itself — the same list object, not a copy.
    """
    collated = batch  # identity: the caller gets back the very object it passed in
    return collated
    
def parse_kshot_setting(kshot_setting: str) -> int:
    """
    Turn a k-shot setting string into an integer count.

    Supported settings:
      - "fixed::K"    → exactly K
      - "range::A:B"  → a uniformly random integer in [A, B] (both ends inclusive)

    The value is not range-checked: e.g. "fixed::-1" returns -1.

    Args:
        kshot_setting: the setting string, in one of the forms above.

    Returns:
        The k-shot count.

    Raises:
        NotImplementedError: if the setting has an unrecognized form.
        ValueError: if the numeric part cannot be parsed, or a "range::"
            setting has A > B.
    """
    if kshot_setting.startswith("fixed::"):
        return int(kshot_setting.removeprefix("fixed::"))
    if kshot_setting.startswith("range::"):
        minv, maxv = kshot_setting.removeprefix("range::").split(":")
        # randint is inclusive on both ends, matching the documented [A, B].
        return random.randint(int(minv), int(maxv))
    raise NotImplementedError(f"Unknown k-shot setting: {kshot_setting}")
import random

def parse_missing_setting(column_missing_setting: str, num_columns: int) -> int:
    """
    Turn a missing-columns setting string into an integer count of columns.

    Supported settings:
      - "fixed::3"       → exactly 3 missing columns
      - "range::1:5"     → a uniformly random integer in [1, 5] (both ends inclusive)
      - "fraction::0.2"  → 20% of num_columns, rounded to the nearest integer
                           (Python's round(): exact .5 ties go to the even number)

    Labels such as "none" or "completely_at_random" are not counts and are
    NOT accepted here; sample_missing_columns handles those itself.

    The result is not clamped: it may be negative or exceed num_columns if the
    setting asks for that, so callers should clamp it.

    Args:
        column_missing_setting: the setting string, in one of the forms above.
        num_columns: total number of columns; only used by "fraction::".

    Returns:
        The number of columns to treat as missing.

    Raises:
        NotImplementedError: if the setting has an unrecognized form.
        ValueError: if the numeric part cannot be parsed, or a "range::"
            setting has min > max.
    """
    if column_missing_setting.startswith("fixed::"):
        return int(column_missing_setting.removeprefix("fixed::"))
    if column_missing_setting.startswith("range::"):
        minv, maxv = column_missing_setting.removeprefix("range::").split(":")
        return random.randint(int(minv), int(maxv))
    if column_missing_setting.startswith("fraction::"):
        frac = float(column_missing_setting.removeprefix("fraction::"))
        # round(), not int(): int() truncates, so float error such as
        # 0.29 * 100 == 28.999999999999996 would silently drop a column.
        return round(frac * num_columns)
    raise NotImplementedError(f"Unknown missing‐columns setting: {column_missing_setting}")
    
def sample_missing_columns(num_columns: int, column_missing_setting: str) -> list[int]:
    """
    Choose which column indices to treat as missing.

    Settings:
      - "none"                  → no missing columns
      - "completely_at_random"  → each column is independently marked missing
                                  when random.random() > 0.5 (probability ≈ 0.5)
      - "fixed::K", "range::A:B", "fraction::F"
                                → a count k is obtained from parse_missing_setting,
                                  clamped to [0, num_columns], and k distinct
                                  indices are drawn without replacement

    Args:
        num_columns: number of columns to choose from (indices 0..num_columns-1).
        column_missing_setting: one of the settings above.

    Returns:
        A list of column indices. For "completely_at_random" they are in
        ascending order; for count-based settings they are in the random order
        produced by random.sample.

    Raises:
        NotImplementedError: for an unrecognized setting.
    """
    if column_missing_setting == "none":
        return []
    if column_missing_setting == "completely_at_random":
        return [i for i in range(num_columns) if random.random() > 0.5]
    if column_missing_setting.startswith(("fixed::", "range::", "fraction::")):
        # parse_missing_setting may return a negative or too-large k
        # (e.g. "fixed::99" with 5 columns), so clamp it to [0, num_columns]
        # before sampling; random.sample raises if k > population size.
        num_missing = parse_missing_setting(column_missing_setting, num_columns)
        num_missing = max(0, min(num_missing, num_columns))
        return random.sample(range(num_columns), k=num_missing)
    raise NotImplementedError(f"Unknown missing-columns setting: {column_missing_setting}")