from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np
import torch


@dataclass(frozen=True)
class Theta:
    """A feature-weight vector paired with a human-readable name.

    Instances are built by ``uniform_theta`` (name ``"uniform"``) and
    ``load_theta`` (name set to the checkpoint path).

    NOTE(review): ``frozen=True`` only prevents rebinding the attributes; the
    ``weights`` array itself is still mutable in place. Because ``weights`` is
    an ndarray, the dataclass-generated ``__eq__`` compares it element-wise and
    may raise ``ValueError`` for multi-element arrays. The generated
    ``__hash__`` raises ``TypeError`` (ndarrays are unhashable). Avoid using
    Theta as a dict key or comparing instances with ``==``.
    """

    weights: np.ndarray  # (F,) per-feature weights
    name: str  # label identifying where the weights came from


def uniform_theta(num_features: int) -> Theta:
    """Return a Theta that weights every feature equally.

    Args:
        num_features: number of features F; converted with ``int()``.

    Returns:
        A Theta named ``"uniform"`` whose float32 weights have shape ``(F,)``
        and are all ``1 / F``, so they sum to 1 up to float32 rounding.

    Raises:
        ValueError: if ``num_features`` is not a positive integer.
    """
    n = int(num_features)
    if n <= 0:
        raise ValueError(
            f"num_features must be a positive integer, got {num_features!r}"
        )
    # Use the same integer for both the shape and the weight so the weights
    # always sum to 1, even if a non-integral value was passed in.
    w = np.full((n,), 1.0 / n, dtype=np.float32)
    return Theta(weights=w, name="uniform")


def _to_float32_array(value: object) -> np.ndarray:
    """Convert a torch tensor or array-like into a float32 NumPy array on the CPU."""
    if isinstance(value, torch.Tensor):
        # detach(): tensors that require grad cannot be converted to NumPy.
        # float(): NumPy has no bfloat16, so upcast in torch before converting.
        value = value.detach().cpu().float()
    return np.asarray(value).astype(np.float32)


def load_theta(
    path: str | Path, *, expected_num_features: Optional[int] = None
) -> Theta:
    """Load a feature-weight vector saved with ``torch.save`` and wrap it in a Theta.

    Accepted checkpoint layouts:

    * ``{"state_dict": {"linear.weight": W, ...}, ...}`` with ``W`` of shape
      ``(1, F)``. This is the preferred format and matches the original repo.
    * A bare state dict ``{"linear.weight": W, ...}`` with ``W`` of shape ``(1, F)``.
    * A raw tensor/array of shape ``(F,)`` or ``(1, F)``.

    Args:
        path: checkpoint file to load.
        expected_num_features: if given, the loaded vector must have exactly
            this many entries.

    Returns:
        A Theta with float32 weights of shape ``(F,)``, named ``str(path)``.

    Raises:
        ValueError: if the checkpoint layout or weight shape is not recognized,
            or the dimension does not match ``expected_num_features``.
    """
    p = Path(path)
    # SECURITY NOTE: torch.load unpickles the file, so only load trusted
    # checkpoints. weights_only is left at torch's default because full
    # training checkpoints may contain non-tensor objects.
    ckpt = torch.load(str(p), map_location="cpu")

    if isinstance(ckpt, dict) and ("state_dict" in ckpt or "linear.weight" in ckpt):
        # "state_dict" wins when present; otherwise the dict itself is a state dict.
        sd = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        if not isinstance(sd, dict) or "linear.weight" not in sd:
            raise ValueError(f"Missing state_dict['linear.weight'] in {p}")
        w = _to_float32_array(sd["linear.weight"])
        if w.ndim != 2 or int(w.shape[0]) != 1:
            raise ValueError(
                f"Expected linear.weight shape (1,F), got {w.shape} in {p}"
            )
        w = w[0]
    elif isinstance(ckpt, dict):
        raise ValueError(
            f"Unrecognized checkpoint dict in {p}: expected key 'state_dict' or "
            f"'linear.weight', got keys {sorted(str(k) for k in ckpt)}"
        )
    else:
        # Fallback: the user provided a raw tensor/array.
        w = _to_float32_array(ckpt)
        if w.ndim == 2 and int(w.shape[0]) == 1:
            w = w[0]
        if w.ndim != 1:
            raise ValueError(
                f"Expected raw theta of shape (F,) or (1,F), got {w.shape} in {p}"
            )

    num_features = int(w.shape[0])
    if expected_num_features is not None and num_features != int(
        expected_num_features
    ):
        raise ValueError(
            f"theta dimension mismatch: expected {expected_num_features}, got {num_features}"
        )

    return Theta(weights=w, name=str(p))
