import os

import tifffile
import torch
from torch import Tensor, nn

import smlm


class Save2Tiff(nn.Module):
    """Dummy model that dumps every input sample to a TIFF file.

    Each sample ``y[i]`` passed to :meth:`forward` is written to
    ``<dirpath>/<abs(hash)>.tiff``, where the hash is produced by
    ``smlm.utils.torch.hash_tensor``. The returned predictions are built
    solely from this module's own parameters (the input only determines the
    batch size). The module presumably stands in for a real model so the
    data flowing through a pipeline can be recorded — confirm with callers.
    """

    def __init__(self, dirpath: str):
        """Create the output directory and the dummy parameters.

        Args:
            dirpath: Directory the TIFF files are written to. It must not
                exist yet.

        Raises:
            FileExistsError: If ``dirpath`` already exists (``exist_ok=False``).
        """
        super().__init__()
        self.dirpath = dirpath
        os.makedirs(self.dirpath, exist_ok=False)
        # Parameters of shape (1, 64, 1) and (1, 64, 4); forward concatenates
        # them along the last axis into a (1, 64, 5) output.
        self.p = nn.Parameter(torch.zeros((1, 64, 1), dtype=torch.float32))
        self.x = nn.Parameter(torch.zeros((1, 64, 4), dtype=torch.float32))

    def forward(self, y: Tensor) -> "Tensor | tuple[Tensor, Tensor]":
        """Save each sample of ``y`` as a TIFF file and return dummy outputs.

        Args:
            y: Batch of shape ``(bs, n_frame, h, w)``.

        Returns:
            In eval mode, a tensor ``x`` of shape ``(bs, 64, 5)`` whose last
            axis is ``[sigmoid(p), 1000 * sigmoid(x)]``, identical for every
            sample. In training mode, the tuple ``(x, bg)`` where ``bg`` is
            the per-pixel minimum of ``y`` over the frame axis, of shape
            ``(bs, h, w)``.

        Raises:
            ValueError: If ``y`` does not have exactly 4 dimensions.
        """
        if y.ndim != 4:
            raise ValueError("Expect y to have 4 dimensions: (bs, n_frame, h, w)")
        bs = y.size(0)

        # Detach / cast / move the whole batch once instead of per sample;
        # the hash is still computed on the original, unconverted sample.
        # NOTE(review): the int16 cast truncates fractional values, and values
        # outside [-32768, 32767] do not survive it — confirm the input range.
        frames = y.detach().to(dtype=torch.int16, device="cpu")
        for sample, frame in zip(y, frames):
            digest = smlm.utils.torch.hash_tensor(sample)
            # NOTE(review): abs() maps digests h and -h to the same filename.
            filename = str(abs(digest)) + ".tiff"
            tifffile.imwrite(os.path.join(self.dirpath, filename), frame.numpy())

        # Nonsense output: x depends on y only through the batch size; its
        # values come entirely from the module's parameters.
        p = torch.sigmoid(self.p)
        x = 1000.0 * torch.sigmoid(self.x)
        x = torch.cat([p, x], dim=-1)
        x = x.expand(bs, -1, -1)

        if self.training:
            bg = y.min(dim=1).values
            return x, bg
        return x
