import json
import math
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from lightning import LightningDataModule
from omegaconf import DictConfig, ListConfig
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# Set before the `import cv2` further down. Presumably required so that OpenCV
# can read the ".exr" files passed to `load_depth` -- TODO confirm against the
# OpenCV imgcodecs documentation.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import io
import json
import random

import cv2
import requests
from einops import rearrange, repeat
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

from src.models.gs.cameras import Cameras
from src.utils.point import fov2focal


def collect_fn(batch):
    """Collate a list of sample dicts, keeping ``"cameras"`` as a plain list.

    Every key except ``"cameras"`` is merged with ``default_collate``; the
    per-sample ``"cameras"`` objects are gathered into a Python list and stored
    under the same key in the result.

    Args:
        batch: list of dicts, each containing a ``"cameras"`` entry.

    Returns:
        The collated dict with an additional ``"cameras"`` list.

    Raises:
        KeyError: if a sample has no ``"cameras"`` entry.
    """
    cameras = [item["cameras"] for item in batch]
    # Build filtered copies instead of popping, so the caller's sample dicts
    # (which a dataset may cache or reuse) are left untouched.
    without_cameras = [
        {key: value for key, value in item.items() if key != "cameras"}
        for item in batch
    ]
    batch_processed = default_collate(without_cameras)
    batch_processed["cameras"] = cameras
    return batch_processed


class MultiViewDataset(Dataset):
    """Dataset of conditioning images that share one camera trajectory.

    Every file found in ``root_dir`` is one sample. ``__getitem__`` loads that
    file as the conditioning image and reads the camera poses and field of view
    from ``meta_file``.
    """

    def __init__(
        self,
        root_dir: str,
        bg_color: str,
        num_frames: int,
        meta_file: Optional[str] = None,
        img_wh: Tuple[int, int] = (512, 512),
        relative_pose: bool = True,
        num_samples: int = -1,
        repeat: int = 1,
    ):
        """Index the files in ``root_dir`` and store the loading options.

        Args:
            root_dir: directory whose entries (sorted by name) are the samples.
            bg_color: ``"white"``, ``"black"``, ``"gray"``, ``"random"`` or a
                number used for all three channels; see ``get_bg_color``.
            num_frames: number of camera poses sampled from the meta file.
            meta_file: path to the JSON file read by ``__getitem__``.
            img_wh: ``(width, height)`` images are resized to.
            relative_pose: if true, poses are re-expressed relative to the
                first selected camera.
            num_samples: accepted for compatibility; not used by this class.
            repeat: factor by which ``__len__`` is multiplied.
        """
        self.root_dir = root_dir

        self.files = [
            os.path.join(self.root_dir, file)
            for file in sorted(os.listdir(self.root_dir))
        ]
        self.meta_file = meta_file
        self.img_wh = img_wh
        self.relative_pose = relative_pose
        self.num_frames = num_frames
        # Resolved once here, so "random" picks a single colour per dataset.
        self.bg_color = self.get_bg_color(bg_color)
        self.repeat = repeat

    def get_bg_color(self, bg_color):
        """Translate a background specification into a float32 RGB array.

        Args:
            bg_color: ``"white"``, ``"black"``, ``"gray"``, ``"random"`` or a
                real number that is replicated over the three channels.

        Returns:
            ``np.ndarray`` of shape ``(3,)`` and dtype ``float32``.

        Raises:
            NotImplementedError: for any other specification.
        """
        if bg_color == "white":
            return np.array([1.0, 1.0, 1.0], dtype=np.float32)
        if bg_color == "black":
            return np.array([0.0, 0.0, 0.0], dtype=np.float32)
        if bg_color == "gray":
            return np.array([0.5, 0.5, 0.5], dtype=np.float32)
        if bg_color == "random":
            # Cast so that compositing does not promote images to float64,
            # matching the dtype of every other branch.
            return np.random.rand(3).astype(np.float32)
        if isinstance(bg_color, (int, float)) and not isinstance(bg_color, bool):
            return np.array([bg_color] * 3, dtype=np.float32)
        raise NotImplementedError

    def load_json(self, json_path):
        """Read ``json_path`` and return the decoded JSON content."""
        with open(json_path, encoding="utf-8") as f:
            return json.load(f)

    def load_image(self, img_path, rescale=True, return_type="np"):
        """Load an image with PIL, resize it and composite it on the background.

        Args:
            img_path: path of the image file.
            rescale: if true, map values from ``[0, 1]`` to ``[-1, 1]``.
            return_type: ``"np"`` for a NumPy array, ``"pt"`` for a tensor.

        Returns:
            The image divided by 255 (and optionally rescaled). When the loaded
            array has four channels, the fourth is used as alpha to blend the
            first three over ``self.bg_color``.

        Raises:
            NotImplementedError: if ``return_type`` is not ``"np"`` or ``"pt"``.
        """
        # PIL is used instead of cv2 because cv2 may load the file as uint16.
        with Image.open(img_path) as im:
            img = np.array(im.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0  # [0, 1]

        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha + self.bg_color * (1 - alpha)
        if rescale:
            img = img * 2.0 - 1.0  # to -1 ~ 1

        if return_type == "np":
            return img
        if return_type == "pt":
            return torch.from_numpy(img)
        raise NotImplementedError

    def load_normal(self, normal_path, return_type="np"):
        """Load a normal map, resize it and map it to ``[-1, 1]``.

        When the loaded array has four channels, the first three are multiplied
        by the fourth (alpha) before rescaling; no background is blended in.

        Args:
            normal_path: path of the normal-map image.
            return_type: ``"np"`` for a NumPy array, ``"pt"`` for a tensor.

        Raises:
            NotImplementedError: if ``return_type`` is not ``"np"`` or ``"pt"``.
        """
        with Image.open(normal_path) as im:
            img = np.array(im.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0  # [0, 1]

        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha

        img = img * 2.0 - 1.0  # to -1 ~ 1

        if return_type == "np":
            return img
        if return_type == "pt":
            return torch.from_numpy(img)
        raise NotImplementedError

    def load_depth(self, img_path):
        """Read a depth map with OpenCV and derive a validity mask.

        Args:
            img_path: path of the file; channel index 3 is used as depth.

        Returns:
            ``(depth, mask)`` where ``depth`` has a leading axis of size one and
            ``mask`` is float32 with 1 where depth is non-zero and 0 elsewhere.

        Raises:
            FileNotFoundError: if OpenCV could not read the file.
        """
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read depth image: {img_path}")
        img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_NEAREST)
        depth = img[None, :, :, 3]
        # Pixels whose depth is exactly 0 are treated as invalid.
        mask = depth != 0
        return depth, mask.astype(np.float32)

    def __len__(self):
        """Return the number of files times ``repeat``."""
        return len(self.files) * self.repeat

    def load_instance(self, path, tunck=25):
        """Load images, depth, masks and camera data for one rendered instance.

        ``tunck`` view indices are taken evenly from 0..23 (values repeat when
        ``tunck`` exceeds 24). For index ``i`` the files
        ``<path>/<i:05d>/<i:05d>.png``, ``..._nd.exr`` and ``....json`` are read.

        Args:
            path: directory of the instance.
            tunck: number of view indices to sample.

        Returns:
            Six lists with one entry per sampled view: image tensors, depth
            tensors, mask tensors, 4x4 camera-to-world tensors, x FoVs, y FoVs.
        """
        im_list, depth_list, mask_list = [], [], []
        c2w_list, fov_x_list, fov_y_list = [], [], []

        # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
        selected_indices = np.linspace(0, 23, tunck, dtype=int)

        for i in selected_indices:
            sub_path = "{:05d}".format(i)
            full_path = os.path.join(path, sub_path)
            im = self.load_image(os.path.join(full_path, f"{sub_path}.png"))
            depth, mask = self.load_depth(
                os.path.join(full_path, f"{sub_path}_nd.exr")
            )
            json_content = self.load_json(
                os.path.join(full_path, f"{sub_path}.json")
            )
            # Columns of the pose are the camera axes followed by its origin.
            c2w = np.eye(4)
            c2w[:3, 0] = np.array(json_content["x"])
            c2w[:3, 1] = np.array(json_content["y"])
            c2w[:3, 2] = np.array(json_content["z"])
            c2w[:3, 3] = np.array(json_content["origin"])
            im_list.append(torch.from_numpy(im))
            depth_list.append(torch.from_numpy(depth))
            mask_list.append(torch.from_numpy(mask))
            c2w_list.append(torch.from_numpy(c2w))
            fov_x_list.append(json_content["x_fov"])
            fov_y_list.append(json_content["y_fov"])
        return im_list, depth_list, mask_list, c2w_list, fov_x_list, fov_y_list

    def get_cameras(self, c2w, fov_x_list, fov_y_list):
        """Build a ``Cameras`` object from poses and per-view fields of view.

        Args:
            c2w: tensor of camera-to-world matrices, one 4x4 matrix per view.
            fov_x_list: per-view horizontal FoV, passed unchanged to
                ``fov2focal``.
            fov_y_list: per-view vertical FoV, passed unchanged to
                ``fov2focal``.

        Returns:
            ``Cameras`` whose ``R``/``T`` come from the inverted poses.
        """
        w2c = torch.linalg.inv(c2w).to(torch.float)
        width = self.img_wh[0]
        fx = torch.tensor(
            [fov2focal(fov=fov_x, pixels=width) for fov_x in fov_x_list],
            dtype=torch.float32,
        )
        # NOTE(review): fy and `height` are both derived from the image width,
        # i.e. square images are assumed -- confirm before using a
        # non-square `img_wh`.
        fy = torch.tensor(
            [fov2focal(fov=fov_y, pixels=width) for fov_y in fov_y_list],
            dtype=torch.float32,
        )
        width = torch.tensor([width], dtype=torch.float32).expand(c2w.shape[0])
        height = torch.clone(width)
        cameras = Cameras(
            R=w2c[:, :3, :3],
            T=w2c[:, :3, 3],
            fx=fx,
            fy=fy,
            cx=width / 2,
            cy=height / 2,
            width=width,
            height=height,
            appearance_id=torch.zeros_like(width),
            normalized_appearance_id=torch.zeros_like(width),
            distortion_params=None,
            camera_type=torch.zeros_like(width),
        )
        return cameras

    def __getitem__(self, index):
        """Return the conditioning image and camera data for sample ``index``.

        ``index`` wraps around the file list, which is what makes ``repeat``
        work. The meta file must provide ``"camera_angle_x"`` and a
        ``"locations"`` list whose items carry a 4x4 ``"transform_matrix"``.

        Returns:
            Dict with keys ``"id"``, ``"condition_image"``, ``"intrinsics"``,
            ``"cameras"``, ``"c2w"``, ``"diffusion_images"`` (random
            placeholder values) and ``"fov"`` (``camera_angle_x`` scaled by
            ``180 / pi``).
        """
        index = index % len(self.files)
        file = self.files[index]
        im = self.load_image(file)
        im = rearrange(im, "h w c -> c h w")

        meta = self.load_json(self.meta_file)
        fov = meta["camera_angle_x"]
        c2w_list = np.array(
            [item["transform_matrix"] for item in meta["locations"]]
        )
        # `np.int` was removed from NumPy; the builtin `int` is the same dtype.
        select_indices = np.linspace(
            0, len(c2w_list) - 1, self.num_frames, dtype=int
        )
        c2w_list = c2w_list[select_indices]

        # Intrinsics in normalised image coordinates (principal point at 0.5).
        focal_length = 0.5 * 1 / np.tan(0.5 * fov)
        intrinsics = torch.from_numpy(
            np.array(
                [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]],
                dtype=np.float32,
            )
        )
        c2w_matrixs = torch.from_numpy(c2w_list)  # N x 4 x 4
        # Flip the Y and Z camera axes (original note: "to opencv").
        c2w_matrixs[:, :3, 1:3] *= -1

        num_views = c2w_matrixs.shape[0]
        intrinsics = repeat(intrinsics, "h w -> n h w", n=num_views)

        if self.relative_pose:
            src_w2c = torch.inverse(c2w_matrixs[:1])  # (1, 4, 4)
            # Distance of the first camera from the world origin.
            src_distance = c2w_matrixs[:1, :3, 3].norm(dim=-1).item()
            # Poses expressed in the first camera's frame.
            canonical_c2w = torch.matmul(src_w2c, c2w_matrixs)  # (Nv, 4, 4)
            # Flip Y/Z back and push the frame along Z by the source distance.
            shift = torch.tensor(
                [
                    [1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, src_distance],
                    [0, 0, 0, 1],
                ],
                dtype=c2w_matrixs.dtype,
            )
            c2w_matrixs = torch.matmul(shift, canonical_c2w)

        fov_x_list = [fov] * num_views
        fov_y_list = [fov] * num_views
        cameras = self.get_cameras(c2w_matrixs, fov_x_list, fov_y_list)
        fov = fov * 180 / np.pi
        fov = torch.tensor([fov] * num_views, dtype=torch.float32)
        return {
            "id": os.path.basename(file).rsplit(".", 1)[0],
            "condition_image": torch.from_numpy(im),
            "intrinsics": intrinsics,
            "cameras": cameras,
            "c2w": c2w_matrixs,
            # Random placeholder, laid out (N, C, H, W) like the loaded image.
            "diffusion_images": torch.rand(
                self.num_frames, 3, self.img_wh[1], self.img_wh[0]
            ),
            "fov": fov,
        }


class MultiViewDataModule(LightningDataModule):
    """Lightning data module wrapping pre-built train/val/test/real datasets.

    All loaders use ``collect_fn`` as ``collate_fn`` and share ``num_workers``
    and ``pin_memory``.
    """

    def __init__(
        self,
        train_dataset: Optional[Dataset[Any]] = None,
        val_dataset: Optional[Dataset[Any]] = None,
        test_dataset: Optional[Dataset[Any]] = None,
        train_batch_size: int = 1,
        val_batch_size: int = 1,
        test_batch_size: int = 1,
        num_workers: Optional[int] = None,
        pin_memory: bool = True,
        real_dataset=None,
        real_batch_size=-1,
    ):
        """Store the datasets and loader settings.

        Args:
            train_dataset: dataset for ``train_dataloader`` (may be ``None``).
            val_dataset: a dataset, or a ``ListConfig``/``DictConfig`` of
                datasets, for ``val_dataloader``.
            test_dataset: same as ``val_dataset`` but for ``test_dataloader``.
            train_batch_size: batch size of the training loader.
            val_batch_size: batch size of every validation loader.
            test_batch_size: batch size of every test loader.
            num_workers: worker count for all loaders; ``None`` selects
                ``train_batch_size * 2``.
            pin_memory: forwarded to every ``DataLoader``.
            real_dataset: dataset for ``real_dataloader``.
            real_batch_size: batch size for ``real_dataloader``.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train = train_dataset
        self.data_val = val_dataset
        self.data_test = test_dataset
        self.data_real = real_dataset

        # Compare against None so that an explicit 0 (load in the main
        # process) is honoured rather than replaced by the default.
        self.num_workers = (
            num_workers if num_workers is not None else train_batch_size * 2
        )

    def prepare_data(self) -> None:
        """No-op; nothing is downloaded or prepared here."""
        # TODO: check if data is available
        pass

    def _dataloader(
        self, dataset: Dataset, batch_size: int, shuffle: bool
    ) -> DataLoader[Any]:
        """Create a ``DataLoader`` with this module's shared settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collect_fn,
        )

    def _eval_dataloaders(self, datasets, batch_size: int):
        """Build unshuffled loader(s) for a dataset or a config of datasets.

        A ``ListConfig`` or ``DictConfig`` yields a list with one loader per
        contained dataset (dict keys are dropped); anything else is treated as
        a single dataset.
        """
        if isinstance(datasets, ListConfig):
            return [
                self._dataloader(dataset, batch_size, False)
                for dataset in datasets
            ]
        if isinstance(datasets, DictConfig):
            return [
                self._dataloader(dataset, batch_size, False)
                for _, dataset in datasets.items()
            ]
        return self._dataloader(datasets, batch_size, False)

    def real_dataloader(self):
        """Return an unshuffled loader over ``real_dataset``."""
        return self._dataloader(
            self.data_real, self.hparams.real_batch_size, False
        )

    def train_dataloader(self) -> DataLoader[Any]:
        """Return the shuffled training loader, or ``None`` without a dataset."""
        if self.data_train is None:
            return None
        return self._dataloader(
            self.data_train, self.hparams.train_batch_size, True
        )

    def val_dataloader(self) -> DataLoader[Any]:
        """Return the validation loader, or a list of loaders for a config."""
        return self._eval_dataloaders(
            self.data_val, self.hparams.val_batch_size
        )

    def test_dataloader(self) -> DataLoader[Any]:
        """Return the test loader, or a list of loaders for a config."""
        return self._eval_dataloaders(
            self.data_test, self.hparams.test_batch_size
        )


import cv2
import numpy as np


def save_video(numpy_array, output_path, fps=30):
    """Save a NumPy array of frames as a video file.

    Args:
        numpy_array: array indexed as ``(t, c, h, w)``; it is cast with
            ``astype(np.uint8)`` (no rescaling) and transposed to
            ``(t, h, w, c)`` before the frames are written.
        output_path: path of the output video file.
        fps: frame rate of the video.
    """
    frames = numpy_array.astype(np.uint8).transpose(0, 2, 3, 1)
    _, h, w, _ = frames.shape

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    try:
        for frame in frames:
            out.write(frame)
    finally:
        # Always release the writer, even if writing a frame fails.
        out.release()


from tqdm import tqdm

from src.utils.visual import visual_camera

if __name__ == "__main__":
    # Smoke test: load one sample, visualise its cameras and dump its images.
    # Import from `lightning`, the package this module already depends on.
    from lightning import seed_everything
    from torchvision.utils import save_image

    seed_everything(42)
    dataset = MultiViewDataset(
        root_dir="data/test/images",
        meta_file="data/test/meta.json",
        bg_color="white",
        num_frames=8,
        img_wh=(512, 512),
    )

    item = dataset[1]
    visual_camera(item["c2w"], item["intrinsics"])

    # Tile all frames side by side into a single image.
    images = item["diffusion_images"]
    images = rearrange(images, "n c h w -> c h (n w)")

    # Make sure the output directory exists before writing into it.
    os.makedirs("visual", exist_ok=True)
    save_image(images, "visual/diffusion_images.png")
    save_image(item["condition_image"], "visual/condition_image.png")
