import json
from typing import List, Union
from pydantic import BaseModel
from .single_layer import NeighbourhoodCosineSimilarity, LaplacianPyramid, Ring, Ring1D


class NesimConfig(BaseModel, extra="allow"):
    """
    Top-level configuration listing which layers to watch and apply a loss upon.

    Attributes:
        layer_wise_configs: one entry per watched layer. Each entry is a
            NeighbourhoodCosineSimilarity, LaplacianPyramid, Ring or Ring1D
            config describing the layer and the loss applied to it.

    Keys not declared on the model are kept as extra attributes
    (``extra="allow"``) and are included when the config is saved.
    """

    # NOTE(review): when loading, pydantic resolves this Union by trying the
    # member types; if two members accept the same keys, a saved entry could be
    # re-loaded as a different type. Consider a discriminator field -- TODO
    # confirm against the single_layer config definitions.
    layer_wise_configs: List[
        Union[NeighbourhoodCosineSimilarity, LaplacianPyramid, Ring, Ring1D]
    ]

    def save_json(self, filename: str) -> None:
        """
        Write this config to ``filename`` as JSON indented by 4 spaces.

        Args:
            filename: destination path; an existing file is overwritten.

        If serialization fails, the exception propagates before the file is
        opened, so an existing file at ``filename`` is left untouched.
        """
        # mode="json" makes pydantic convert non-JSON-native values (enums,
        # sets, paths, arbitrary extra fields, ...) into JSON-compatible ones.
        payload = json.dumps(self.model_dump(mode="json"), indent=4)
        # Serialize fully *before* opening: "w" truncates the file, so a
        # failure mid-dump would otherwise leave an empty/partial config.
        with open(filename, "w", encoding="utf-8") as file:
            file.write(payload)

    @classmethod
    def from_json(cls, filename: str) -> "NesimConfig":
        """
        Build a config from a JSON file previously written by ``save_json``.

        Args:
            filename: path to a JSON file whose top-level value is an object.

        Returns:
            A validated instance of ``cls``.

        Raises:
            TypeError: if the top-level JSON value is not an object.
            json.JSONDecodeError: if the file is not valid JSON.
            pydantic ValidationError: if the data does not match the model.
        """
        with open(filename, "r", encoding="utf-8") as file:
            json_data = json.load(file)
        # `cls(**json_data)` would also raise TypeError for non-objects, but
        # with an unhelpful message; keep the exception type, improve the text.
        if not isinstance(json_data, dict):
            raise TypeError(
                f"{filename}: expected a JSON object at the top level, "
                f"got {type(json_data).__name__}"
            )
        return cls(**json_data)
