import lightning
import numpy as np
import torch
import logging

from load_nf2vec import NeRFsDataset


class HoNeRF2VecDataModule(lightning.LightningDataModule):
    """Lightning data module that serves ``NeRFsDataset`` splits.

    The module builds one training dataset and one "novel object" evaluation
    dataset, records the camera parameters reported by the training dataset,
    and exposes seeded ``DataLoader`` objects for the train / val / test
    stages.

    Args:
        args: Configuration namespace. The attributes read by this class are
            ``dataset_type`` (must equal ``"nerf2vec"``), ``near``, ``far``,
            ``datadir``, ``half_res``, ``white_bkgd``, ``num_cond``,
            ``batch_size``, ``novel_objvse_datadir`` (falling back to
            ``novel_cond_datadir`` when absent) and, optionally, ``seed``
            (defaults to 42).
    """

    # Number of worker processes used by the train / evaluation loaders.
    TRAIN_NUM_WORKERS = 6
    EVAL_NUM_WORKERS = 4

    def __init__(self, args) -> None:
        super().__init__()
        self.args = args
        assert args.dataset_type == "nerf2vec", f"dataset type is {args.dataset_type}"
        self.near = args.near
        self.far = args.far
        # Default seed if not provided.
        self.seed = getattr(args, 'seed', 42)

        # Fixed generator driving the training loader (shuffling order and
        # worker base seeds) so runs are reproducible.
        self._generator = torch.Generator()
        self._generator.manual_seed(self.seed)

        # Kept as an instance attribute for backward compatibility. It refers
        # to a static method rather than a closure so that it can be pickled
        # when DataLoader workers are started with the "spawn" method.
        self._worker_init_fn = self._seed_worker

    @staticmethod
    def _seed_worker(worker_id):
        """Seed NumPy's and Python's global RNGs inside a DataLoader worker.

        Args:
            worker_id: Index of the worker process, supplied by ``DataLoader``.
        """
        # Local import: ``random`` is not imported at module level.
        import random

        # NumPy requires the seed to fit in 32 bits.
        worker_seed = (torch.initial_seed() + worker_id) % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    def set_cam_param(self, dataset):
        """Store image size, focal length and a 3x3 camera matrix from ``dataset``.

        Sets ``self.H``, ``self.W`` (both cast to ``int``), ``self.hwf``
        (``[H, W, focal]``) and ``self.K``, a matrix with ``focal`` on the
        first two diagonal entries and ``(0.5 * W, 0.5 * H)`` as the centre.

        Args:
            dataset: Object exposing ``H``, ``W`` and ``focal`` attributes.
        """
        focal = dataset.focal
        H, W = int(dataset.H), int(dataset.W)

        self.H = H
        self.W = W
        self.K = np.array([[focal, 0, 0.5 * W], [0, focal, 0.5 * H], [0, 0, 1]])
        self.hwf = [H, W, focal]

    def setup(self, stage="fit"):
        """Create the train and evaluation datasets once.

        Subsequent calls are no-ops, whatever ``stage`` is passed.

        Args:
            stage: Stage name passed by the caller; not used to select which
                datasets are built.
        """
        # Both datasets are created together, so either one existing means
        # setup has already run.
        if hasattr(self, "train_dataset") or hasattr(self, "novel_objvse_dataset"):
            return

        self.train_dataset = NeRFsDataset(
            self.args.datadir,
            self.args.half_res,
            "train",
            self.args.white_bkgd,
            self.args.num_cond,
            cond_split="train",
            base_seed=self.seed,
        )

        # Prefer the dedicated option; fall back to the alternative name only
        # when the attribute is missing (other errors must propagate).
        try:
            objvse_datadir = self.args.novel_objvse_datadir
        except AttributeError:
            objvse_datadir = self.args.novel_cond_datadir

        self.novel_objvse_dataset = NeRFsDataset(
            objvse_datadir,
            self.args.half_res,
            "test",
            self.args.white_bkgd,
            self.args.num_cond,
            cond_split="test",
            base_seed=self.seed,
        )

        self.set_cam_param(self.train_dataset)

    def _eval_dataloader(self):
        """Build a non-shuffling loader over the novel-object dataset.

        Each loader gets its own generator seeded with ``self.seed``. A
        multi-worker ``DataLoader`` draws its worker base seed from the
        generator whenever an iterator is created, so sharing the training
        generator would make the training shuffle order depend on how often
        evaluation runs.
        """
        eval_generator = torch.Generator()
        eval_generator.manual_seed(self.seed)
        return torch.utils.data.DataLoader(
            self.novel_objvse_dataset,
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=self.EVAL_NUM_WORKERS,
            generator=eval_generator,
            worker_init_fn=self._worker_init_fn,
        )

    def train_dataloader(self):
        """Return the shuffling loader over the training dataset."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=self.TRAIN_NUM_WORKERS,
            generator=self._generator,
            worker_init_fn=self._worker_init_fn,
        )

    def val_dataloader(self):
        """Return a one-element list holding the novel-object loader."""
        return [self._eval_dataloader()]

    def test_dataloader(self):
        """Return a one-element list holding the novel-object loader."""
        return [self._eval_dataloader()]
