from typing import List, Literal, Union

from pydantic import BaseModel, Extra, Field

from .exponential_scale_scheduler import ExponentialDecayScale


class NeighbourhoodCosineSimilarity(BaseModel, extra=Extra.forbid):
    """
    Config for the neighbourhood cosine-similarity loss applied to one layer.

    - `layer_name`: name of layer in model, something like "model.fc1"
    - `scale`: scale by which the loss for this layer is to be multiplied. If None, then will just watch the layer's loss.
    - `loss_type`: fixed tag "neighbourhood_cossim"; any other value fails validation.

    Unknown keys are rejected (`extra=Extra.forbid`).
    """

    layer_name: str
    scale: Union[None, float]
    # loss_type holds the name, which makes the serialized JSON easier to read.
    # typing.Literal is used instead of `str` + the non-standard `Literal=True`
    # Field kwarg: pydantic never enforced that kwarg, so any string was accepted.
    loss_type: Literal["neighbourhood_cossim"] = "neighbourhood_cossim"


class LaplacianPyramid(BaseModel, extra=Extra.forbid):
    """
    Config for the Laplacian-pyramid loss applied to one layer.

    - `layer_name`: name of layer in model, something like "model.fc1"
    - `scale`: scale by which the loss for this layer is to be multiplied, either a float
      or an `ExponentialDecayScale`. If None, then will just watch the layer's loss.
    - `shrink_factor`: list of factors by which the grid is shrunk before it gets resized
      back to its original size (presumably one entry per pyramid level — TODO confirm).
    - `loss_type`: fixed tag "laplacian_pyramid"; any other value fails validation.

    Unknown keys are rejected (`extra=Extra.forbid`).
    """

    layer_name: str
    scale: Union[None, float, ExponentialDecayScale]
    shrink_factor: List[float]
    # loss_type holds the name, which makes the serialized JSON easier to read.
    # typing.Literal is used instead of `str` + the non-standard `Literal=True`
    # Field kwarg: pydantic never enforced that kwarg, so any string was accepted.
    loss_type: Literal["laplacian_pyramid"] = "laplacian_pyramid"


class Ring(BaseModel, extra=Extra.forbid):
    """
    Config for the "ring" loss applied to one layer.

    - `layer_name`: name of layer in model, something like "model.fc1"
    - `scale`: scale by which the loss for this layer is to be multiplied, either a float
      or an `ExponentialDecayScale`. NOTE(review): None presumably means "only watch the
      loss", as documented for the other configs in this module — confirm.
    - `radius_inner`, `radius_outer`: inner and outer radius of the ring. NOTE(review):
      units, inclusivity and any ordering requirement between the two are not checked
      here — confirm against the loss implementation.
    - `loss_type`: fixed tag "ring"; any other value fails validation.

    Unknown keys are rejected (`extra=Extra.forbid`).
    """

    layer_name: str
    scale: Union[None, float, ExponentialDecayScale]
    radius_inner: float
    radius_outer: float
    # loss_type holds the name, which makes the serialized JSON easier to read.
    # typing.Literal is used instead of `str` + the non-standard `Literal=True`
    # Field kwarg: pydantic never enforced that kwarg, so any string was accepted.
    loss_type: Literal["ring"] = "ring"


class Ring1D(BaseModel, extra=Extra.forbid):
    """
    Config for the "ring_1d" loss applied to one layer.

    - `layer_name`: name of layer in model, something like "model.fc1"
    - `scale`: scale by which the loss for this layer is to be multiplied, either a float
      or an `ExponentialDecayScale`. NOTE(review): None presumably means "only watch the
      loss", as documented for the other configs in this module — confirm.
    - `freq_inner`, `freq_outer`: integer inner and outer bounds of the band. NOTE(review):
      presumably frequency indices of a 1-D ring; inclusivity and ordering are not checked
      here — confirm against the loss implementation.
    - `loss_type`: fixed tag "ring_1d"; any other value fails validation.

    Unknown keys are rejected (`extra=Extra.forbid`).
    """

    layer_name: str
    scale: Union[None, float, ExponentialDecayScale]
    freq_inner: int
    freq_outer: int
    # loss_type holds the name, which makes the serialized JSON easier to read.
    # typing.Literal is used instead of `str` + the non-standard `Literal=True`
    # Field kwarg: pydantic never enforced that kwarg, so any string was accepted.
    loss_type: Literal["ring_1d"] = "ring_1d"