import math
from dataclasses import dataclass
from typing import Optional

import torch

from ..classes import MLP, Hyperparameters, ModelInterface
from .decoder import Decoder


@dataclass
class ModelHyperparameters(Hyperparameters):
    """Dataclass of hyperparameters, presumably consumed when building ``Model``.

    Extends the project's ``Hyperparameters`` base. Because this is a
    ``@dataclass``, field order defines the positional order of the generated
    ``__init__``; append new fields rather than reordering existing ones.
    """

    # NOTE(review): presumably the number of data items (rows of the latent
    # table); -1 looks like an "unset" sentinel — confirm with callers.
    num_data: int = -1
    # NOTE(review): presumably the width of each latent code (second dim of
    # the latent table) — confirm where this value is consumed.
    z_dim: int = 256


class Model(ModelInterface):
    """Decoder conditioned on trainable latent codes looked up by data index.

    The latent codes live in the parameter table ``z_map``; :meth:`forward`
    selects the rows addressed by ``data_indices`` and passes them, together
    with the inputs ``x``, to ``decoder``.
    """

    def __init__(self, decoder: Decoder, initial_z: Optional[torch.Tensor] = None):
        """
        Args:
            decoder: module invoked as ``decoder(x, z)`` by :meth:`decode`.
            initial_z: initial latent table. When given, it is wrapped in a
                trainable ``torch.nn.Parameter`` registered as ``z_map``.
                ``nn.Parameter`` wraps the tensor rather than copying it, so
                in-place updates to ``z_map`` are visible through the caller's
                reference. When ``None``, no ``z_map`` is registered here and
                it must be assigned elsewhere before :meth:`z` is called.
        """
        super().__init__()
        self.decoder = decoder
        if initial_z is not None:
            self.z_map = torch.nn.Parameter(initial_z)

    def decode(self, x: torch.Tensor, z: torch.Tensor):
        """Run the decoder on ``x`` with one latent code broadcast per row.

        Args:
            x: decoder inputs; only ``x.shape[:2]`` is read here (presumably
                ``(batch, num_points, ...)`` — confirm against ``Decoder``).
            z: 2-D latent codes of shape ``(B, D)``, where ``B`` equals
                ``x.shape[0]`` or is 1 (``expand`` only broadcasts size-1 dims).

        Returns:
            Whatever ``self.decoder(x, z)`` returns.
        """
        # (B, D) -> (B, 1, D) -> (x.shape[0], x.shape[1], D). expand is a
        # zero-copy view, so every point in a row shares the same code.
        z = z.unsqueeze(1).expand(x.shape[:2] + (z.shape[1], ))
        return self.decoder(x, z)

    def z(self, data_indices: torch.Tensor) -> torch.Tensor:
        """Return the rows of ``z_map`` selected by ``data_indices``.

        Raises:
            AttributeError: if no ``z_map`` has been set (e.g. the model was
                built with ``initial_z=None`` and nothing assigned it later).
        """
        try:
            z_map = self.z_map
        except AttributeError:
            # Same exception type as before, but with an actionable message
            # instead of nn.Module's generic "has no attribute" error.
            raise AttributeError(
                f"{type(self).__name__} has no latent table 'z_map'; pass "
                "initial_z to the constructor or assign z_map before calling z()"
            ) from None
        return z_map[data_indices]

    def forward(self, x: torch.Tensor, data_indices: torch.Tensor):
        """Look up the latent codes for ``data_indices`` and decode ``x`` with them."""
        return self.decode(x, self.z(data_indices))
