"""Represents information about parameters."""
import dataclasses
from typing import Optional, Tuple

import torch


###############################################################################


@dataclasses.dataclass(frozen=True, eq=True)
class ParameterInfo:
    """Information about a particular parameter within a model.

    Intended to be saved alongside data tied to parameters, e.g. Fishers.

    Attributes:
        shape: Shape of the parameter, one entry per dimension.
        name: Optional name of the parameter.
    """

    shape: Tuple[int, ...]

    name: Optional[str] = None

    def to_json(self) -> dict:
        """Return a dict with ``shape`` and ``name`` for JSON encoding.

        ``shape`` is returned as the stored tuple. The ``json`` module encodes
        tuples as lists, and :meth:`from_json` converts it back to a tuple.
        """
        return {'shape': self.shape, 'name': self.name}

    def n_elements(self) -> int:
        """Return the total number of elements, i.e. the product of ``shape``.

        An empty shape (a scalar parameter) has exactly one element.
        """
        count = 1
        for dim in self.shape:
            count *= dim
        return count

    @classmethod
    def from_parameter(cls, p: torch.nn.Parameter, *, name: Optional[str] = None):
        """Build a ``ParameterInfo`` from the shape of ``p``.

        Args:
            p: Parameter whose ``shape`` is recorded.
            name: Optional name to attach.
        """
        return cls(shape=tuple(p.shape), name=name)

    @classmethod
    def from_json(cls, json_obj) -> 'ParameterInfo':
        """Inverse of :meth:`to_json`.

        Args:
            json_obj: Mapping with a required ``'shape'`` sequence and an
                optional ``'name'``.

        Raises:
            KeyError: If ``'shape'`` is missing.
        """
        # ``name`` is optional on the dataclass, so a missing key falls back to
        # the field default (None) instead of raising KeyError.
        return cls(shape=tuple(json_obj['shape']), name=json_obj.get('name'))
