import math
import os

import h5py
import numpy as np
import yaml
from torch import Tensor

from smlm.activations.io import WriterInterface


class PicassoWriter(WriterInterface):
    """Write localizations to a Picasso-compatible ``.hdf5`` file.

    Rows are appended to a resizable, gzip-compressed ``locs`` dataset whose
    record layout is ``LOCS_DTYPE``. On ``close()``, a ``.yaml`` sidecar with
    the same stem is written. It holds the frame count and the image size
    (in pixels) inferred from the written localizations.
    """

    # Record layout of the "locs" dataset (little-endian uint32 / float32 fields).
    LOCS_DTYPE = np.dtype(
        [
            ("frame", "<u4"),
            ("x", "<f4"),
            ("y", "<f4"),
            ("z", "<f4"),
            ("photons", "<f4"),
            ("lpx", "<f4"),
            ("lpy", "<f4"),
        ]
    )

    def __init__(self, filepath: str, pixel_size: Tensor):
        """Configure the writer; no file is touched until ``open()``.

        Args:
            filepath: Output path; must have a ``.hdf5`` extension
                (case-insensitive).
            pixel_size: Tensor whose first two elements are the x and y pixel
                sizes. x/y coordinates are divided by these to get pixel units.

        Raises:
            ValueError: If ``filepath`` does not end in ``.hdf5``.
        """
        if os.path.splitext(filepath)[-1].lower() != ".hdf5":
            raise ValueError(f"Picasso only supports .hdf5 files: found {filepath}")
        self.file_path = filepath
        self.pixel_size = (pixel_size[0].item(), pixel_size[1].item())
        # None means no HDF5 file is currently open (set by open()/close()).
        self.file = None
        self.ds = None

    def open(self):
        """Create (truncating) the HDF5 file and an empty, growable ``locs`` dataset.

        Also resets the row count and the running maxima used for the sidecar.
        """
        self.total_rows = 0
        self.max_frame_idx = 0
        self.max_x = 0
        self.max_y = 0

        self.file = h5py.File(self.file_path, "w")
        try:
            self.ds = self.file.create_dataset(
                "locs",
                shape=(self.total_rows,),
                maxshape=(None,),
                dtype=self.LOCS_DTYPE,
                chunks=True,
                compression="gzip",
                compression_opts=4,
            )
        except Exception:
            # Don't leak the freshly opened handle if dataset creation fails.
            self.file.close()
            self.file = None
            raise

    def close(self):
        """Write the ``.yaml`` sidecar next to the HDF5 file, then close the file.

        Does nothing if no file is open (never opened, or already closed).
        """
        if self.file is None:
            return
        try:
            yaml_path = os.path.splitext(self.file_path)[0] + ".yaml"
            # Picasso frame indices are 0-based, so the movie must contain
            # max index + 1 frames for every localization to be in range.
            n_frames = self.max_frame_idx + 1 if self.total_rows > 0 else 0
            info = {
                "Frames": n_frames,
                "Generated by": "Shot",
                "Height": math.ceil(self.max_y),
                "Width": math.ceil(self.max_x),
            }
            with open(yaml_path, "w") as yf:
                yaml.dump(info, yf, default_flow_style=False)
        finally:
            # Always release the HDF5 handle, even if writing the sidecar fails.
            self.file.close()
            self.file = None
            self.ds = None

    def _write(self, data: Tensor):
        """Append a batch of localizations to the ``locs`` dataset.

        Args:
            data: 2-D tensor. Its first five columns are read as
                (frame, x, y, z, photons) and all five are rounded to the
                nearest integer. After rounding, x and y are divided by the
                matching pixel size; frame, z and photons are stored unscaled.
                ``lpx``/``lpy`` are fixed at 1.0.
        """
        n = len(data)
        if n == 0:
            return

        # detach() so tensors that require grad can still be converted to numpy.
        # NOTE(review): x/y are rounded *before* division by the pixel size,
        # which quantizes them to whole input units; confirm this is intended.
        data = data[:, :5].detach().round().cpu().numpy()
        recs = np.zeros(n, dtype=self.LOCS_DTYPE)
        recs["frame"] = data[:, 0]
        recs["x"] = data[:, 1] / self.pixel_size[0]
        recs["lpx"] = 1.0
        recs["y"] = data[:, 2] / self.pixel_size[1]
        recs["lpy"] = 1.0
        recs["z"] = data[:, 3]
        recs["photons"] = data[:, 4]

        # Grow the dataset and write the new records at the end.
        self.ds.resize(self.total_rows + n, axis=0)
        self.ds[self.total_rows : self.total_rows + n] = recs
        self.total_rows += n

        # Track maxima (x/y in pixel units) for the YAML sidecar written in close().
        self.max_frame_idx = max(self.max_frame_idx, recs["frame"].max().item())
        self.max_x = max(self.max_x, recs["x"].max().item())
        self.max_y = max(self.max_y, recs["y"].max().item())
