import json
import math
import os
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from lightning import LightningDataModule
from omegaconf import DictConfig, ListConfig
from PIL import Image
from torch.utils.data import DataLoader, Dataset

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
import json
from io import BytesIO

import cv2
import requests
from einops import rearrange, repeat
from torch.utils.data.dataloader import default_collate

from src.models.gs.cameras import Cameras
from src.utils.point import fov2focal


def collect_fn(batch):
    """Collate a list of sample dicts while keeping ``cameras`` un-collated.

    Every entry except ``"cameras"`` is merged with ``default_collate``.
    The per-sample ``"cameras"`` objects are gathered into a plain Python
    list instead, because they are custom objects that ``default_collate``
    is not asked to handle here.

    The input sample dicts are left unmodified (the previous implementation
    popped ``"cameras"`` out of them).

    Args:
        batch: List of sample dicts, each containing a ``"cameras"`` key.

    Returns:
        The collated dict, with ``"cameras"`` set to the list of per-sample
        camera objects in batch order.

    Raises:
        KeyError: If a sample has no ``"cameras"`` entry.
    """
    cameras = [item["cameras"] for item in batch]
    # Collate shallow copies without the cameras entry so callers' dicts
    # are not mutated as a side effect.
    collatable = [
        {key: value for key, value in item.items() if key != "cameras"}
        for item in batch
    ]
    batch_processed = default_collate(collatable)
    batch_processed["cameras"] = cameras
    return batch_processed


class MultiViewDataset(Dataset):
    """Multi-view dataset of rendered object instances.

    The instance file is a JSON list of ``"<dictionary_id>/<instance_id>"``
    strings. Each instance directory (``root_dir/<dictionary_id>/<instance_id>``)
    is expected to hold one sub-folder per view named ``00000``, ``00001``, ...
    containing ``<view>.png`` (colour image), ``<view>_nd.exr`` (depth is read
    from channel 3) and ``<view>.json`` (camera axes ``x``/``y``/``z``,
    ``origin``, ``x_fov`` and ``y_fov``).

    Files missing on disk are fetched from ``remote_root`` when it is given.
    """

    # Maximum number of instances ``__getitem__`` tries before giving up,
    # so a completely unreadable dataset fails instead of looping forever.
    MAX_LOAD_RETRIES = 20
    # Timeout in seconds for every remote download request.
    DOWNLOAD_TIMEOUT = 60

    def __init__(
        self,
        root_dir: str,
        instance_file: str,
        bg_color: str,
        num_frames: int,
        img_wh: Tuple[int, int],
        num_samples: Optional[int] = -1,
        relative_pose: bool = True,
        repeat: int = 1,
        remote_root: Optional[str] = None,
    ):
        """Build the sorted list of instance directories.

        Args:
            root_dir: Directory that instance ids are joined onto.
            instance_file: JSON file with a list of
                ``"<dictionary_id>/<instance_id>"`` strings.
            bg_color: ``"white"``, ``"black"``, ``"gray"``, ``"random"`` or a
                number used for all three channels.
            num_frames: Number of views loaded per instance.
            img_wh: ``(width, height)`` every image is resized to.
            num_samples: Keep only the first ``num_samples`` instances.
                ``None`` or a negative value keeps all of them.
            relative_pose: If true, poses are re-expressed relative to the
                first view (see ``__getitem__``).
            repeat: Virtual repetition factor applied to the dataset length.
            remote_root: Optional base URL used to download missing files.
        """
        self.root_dir = root_dir
        self.bg_color = self.get_bg_color(bg_color)
        self.img_wh = img_wh
        self.num_frames = num_frames
        self.relative_pose = relative_pose
        self.repeat = repeat

        # Drop a single trailing slash so URLs can be built with "/" joins.
        if remote_root is not None and remote_root.endswith("/"):
            remote_root = remote_root[:-1]
        self.remote_root = remote_root

        with open(instance_file, "r") as f:
            instances = json.load(f)  # ["dictionary_id/instance_id", ...]

        def numeric_key(entry):
            # Sort by dictionary_id first, then by instance_id, numerically.
            parts = entry.split("/")
            return int(parts[0]), int(parts[1])

        sorted_instances = [
            os.path.join(root_dir, entry)
            for entry in sorted(instances, key=numeric_key)
        ]

        # A plain ``[:num_samples]`` with the default -1 would silently drop
        # the last instance; negative/None means "use everything".
        if num_samples is not None and num_samples >= 0:
            sorted_instances = sorted_instances[:num_samples]
        self.sorted_instances = sorted_instances

    def get_bg_color(self, bg_color):
        """Translate a background specification into a float32 RGB array.

        ``"random"`` draws one colour at call time; since this is called once
        from ``__init__`` the colour is fixed for the dataset's lifetime.

        Raises:
            NotImplementedError: For any unsupported specification.
        """
        if bg_color == "white":
            return np.array([1.0, 1.0, 1.0], dtype=np.float32)
        if bg_color == "black":
            return np.array([0.0, 0.0, 0.0], dtype=np.float32)
        if bg_color == "gray":
            return np.array([0.5, 0.5, 0.5], dtype=np.float32)
        if bg_color == "random":
            # Cast so compositing stays float32 like every other branch.
            return np.random.rand(3).astype(np.float32)
        if isinstance(bg_color, (int, float)) and not isinstance(bg_color, bool):
            return np.array([bg_color] * 3, dtype=np.float32)
        raise NotImplementedError(f"Unsupported bg_color: {bg_color!r}")

    def check_and_download_item(self, item_path):
        """Make sure ``item_path`` exists locally, downloading it if possible.

        The URL is derived from the last four path components of
        ``item_path`` (``<dictionary_id>/<instance_id>/<view>/<file>``).
        Only ``.png``, ``.json`` and ``.exr`` files are downloaded; any other
        extension is left untouched.

        Raises:
            FileNotFoundError: If the file is missing and no ``remote_root``
                was configured.
            requests.RequestException: If the download fails, times out or
                returns an HTTP error status.
        """
        if os.path.exists(item_path):
            return
        if self.remote_root is None:
            raise FileNotFoundError(f"{item_path} not found")

        os.makedirs(os.path.dirname(item_path), exist_ok=True)
        splits_ = item_path.split("/")
        item_id = "/".join(splits_[-4:-2])
        base_name = "/".join(splits_[-2:])
        item_url = f"{self.remote_root}/{item_id}/campos_512_v4/{base_name}"

        is_png = item_path.endswith(".png")
        if not (is_png or item_path.endswith((".json", ".exr"))):
            return

        response = requests.get(item_url, timeout=self.DOWNLOAD_TIMEOUT)
        # Fail loudly rather than caching an HTTP error page as data.
        response.raise_for_status()
        if is_png:
            # Decode and re-save through PIL, which also validates the image.
            img = Image.open(BytesIO(response.content))
            img.save(item_path)
        else:
            with open(item_path, "wb") as f:
                f.write(response.content)

    @staticmethod
    def _to_return_type(img, return_type):
        """Return ``img`` as a numpy array (``"np"``) or torch tensor (``"pt"``)."""
        if return_type == "np":
            return img
        if return_type == "pt":
            return torch.from_numpy(img)
        raise NotImplementedError

    def load_image(self, img_path, rescale=True, return_type="np"):
        """Load a colour image, composite it over the background and rescale.

        Args:
            img_path: Local image path (downloaded first when missing).
            rescale: If true, map values from ``[0, 1]`` to ``[-1, 1]``.
            return_type: ``"np"`` for a numpy array, ``"pt"`` for a tensor.

        Returns:
            The image resized to ``self.img_wh``; a 4-channel input is
            alpha-blended over ``self.bg_color`` and returned with 3 channels.
        """
        self.check_and_download_item(img_path)

        # PIL is used instead of cv2 because cv2 may load the file as uint16.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0  # [0, 1]
        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha + self.bg_color * (1 - alpha)
        if rescale:
            img = img * 2.0 - 1.0  # to -1 ~ 1
        return self._to_return_type(img, return_type)

    def load_normal(self, normal_path, return_type="np"):
        """Load a normal map from a local file and rescale it to ``[-1, 1]``.

        A 4-channel input has its first three channels multiplied by alpha
        before rescaling. Unlike ``load_image`` this does not try to download
        a missing file.
        """
        with Image.open(normal_path) as pil_img:
            img = np.array(pil_img.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0  # [0, 1]

        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha

        img = img * 2.0 - 1.0  # to -1 ~ 1
        return self._to_return_type(img, return_type)

    def load_depth(self, img_path):
        """Load the depth channel of an ``_nd.exr`` file and its validity mask.

        Returns:
            ``(depth, mask)`` where ``depth`` is channel 3 of the file with a
            leading singleton axis, and ``mask`` is a float32 array that is 1
            where ``depth`` is non-zero and 0 elsewhere.

        Raises:
            IOError: If OpenCV cannot decode the file.
        """
        self.check_and_download_item(img_path)
        img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f"Failed to read depth file: {img_path}")
        img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_NEAREST)
        depth = img[None, :, :, 3]
        # A depth of exactly 0 is treated as "no geometry".
        mask = depth != 0
        return depth, mask.astype(np.float32)

    def __len__(self):
        """Number of instances times the ``repeat`` factor."""
        return len(self.sorted_instances) * self.repeat

    def load_instance(self, path, tunck=25):
        """Load ``tunck`` views of one instance.

        View indices are ``tunck`` integers evenly spaced over ``0..23``
        (indices repeat when ``tunck`` exceeds 24).

        Args:
            path: Instance directory.
            tunck: Number of views to load.

        Returns:
            Six lists of equal length: image tensors, depth tensors, mask
            tensors, 4x4 camera-to-world tensors (float64), x field of view
            values and y field of view values.
        """
        im_list, depth_list, mask_list = [], [], []
        c2w_list, fov_x_list, fov_y_list = [], [], []
        selected_indices = np.linspace(0, 23, tunck, dtype=np.int32)
        for view_index in selected_indices:
            sub_path = "{:05d}".format(view_index)
            full_path = os.path.join(path, sub_path)

            im = self.load_image(os.path.join(full_path, f"{sub_path}.png"))
            depth, mask = self.load_depth(
                os.path.join(full_path, f"{sub_path}_nd.exr")
            )

            json_path = os.path.join(full_path, f"{sub_path}.json")
            self.check_and_download_item(json_path)
            with open(json_path, "r", encoding="utf8") as reader:
                json_content = json.load(reader)

            # Camera-to-world matrix: columns are the camera axes and origin.
            c2w = np.eye(4)
            c2w[:3, 0] = np.array(json_content["x"])
            c2w[:3, 1] = np.array(json_content["y"])
            c2w[:3, 2] = np.array(json_content["z"])
            c2w[:3, 3] = np.array(json_content["origin"])

            im_list.append(torch.from_numpy(im))
            depth_list.append(torch.from_numpy(depth))
            mask_list.append(torch.from_numpy(mask))
            c2w_list.append(torch.from_numpy(c2w))
            fov_x_list.append(json_content["x_fov"])
            fov_y_list.append(json_content["y_fov"])
        return im_list, depth_list, mask_list, c2w_list, fov_x_list, fov_y_list

    def get_cameras(self, c2w, fov_x_list, fov_y_list):
        """Build a ``Cameras`` object from camera-to-world matrices and FoVs.

        Focal lengths come from ``fov2focal`` using the image width for ``fx``
        and the image height for ``fy``; the principal point is the image
        centre. For square ``img_wh`` this matches the previous behaviour,
        which used the width for both axes.

        Args:
            c2w: ``(N, 4, 4)`` camera-to-world matrices.
            fov_x_list: Per-view horizontal field of view.
            fov_y_list: Per-view vertical field of view.
        """
        w2c = torch.linalg.inv(c2w).to(torch.float)
        width_px = self.img_wh[0]
        height_px = self.img_wh[1]
        num_views = c2w.shape[0]

        fx = torch.tensor(
            [fov2focal(fov=fov_x, pixels=width_px) for fov_x in fov_x_list],
            dtype=torch.float32,
        )
        fy = torch.tensor(
            [fov2focal(fov=fov_y, pixels=height_px) for fov_y in fov_y_list],
            dtype=torch.float32,
        )
        width = torch.full((num_views,), float(width_px), dtype=torch.float32)
        height = torch.full((num_views,), float(height_px), dtype=torch.float32)
        cameras = Cameras(
            R=w2c[:, :3, :3],
            T=w2c[:, :3, 3],
            fx=fx,
            fy=fy,
            cx=width / 2,
            cy=height / 2,
            width=width,
            height=height,
            appearance_id=torch.zeros_like(width),
            normalized_appearance_id=torch.zeros_like(width),
            distortion_params=None,
            camera_type=torch.zeros_like(width),
        )
        return cameras

    def __getitem__(self, index):
        """Return one multi-view sample.

        If loading the requested instance raises, the error is printed and a
        random other instance is tried instead, up to ``MAX_LOAD_RETRIES``
        attempts in total.

        Returns:
            Dict with keys ``condition_image`` (first view, channels first),
            ``diffusion_images`` (all views, channels first), ``depths``,
            ``cameras``, ``masks``, ``prompt`` (empty string), ``c2w``,
            ``intrinsics`` (normalised 3x3 per view), ``fov`` (degrees, per
            view) and ``instance_path``.

        Raises:
            RuntimeError: If no instance could be loaded within the retry
                budget.
        """
        num_instances = len(self.sorted_instances)
        index = index % num_instances

        last_error = None
        for _ in range(self.MAX_LOAD_RETRIES):
            instance_path = self.sorted_instances[index]
            try:
                im_list, depth_list, mask_list, c2w_list, fov_x_list, fov_y_list = (
                    self.load_instance(instance_path, self.num_frames)
                )
                break
            except Exception as e:
                # Broad on purpose: any corrupt/missing file should not kill
                # the data-loading worker; fall back to a random instance.
                print(e)
                print("Error loading instance, retrying...")
                last_error = e
                index = np.random.randint(num_instances)
        else:
            raise RuntimeError(
                f"Failed to load any instance after {self.MAX_LOAD_RETRIES} attempts"
            ) from last_error

        # Normalised pinhole intrinsics built from the first view's x FoV
        # (radians); the same matrix is reused for every view.
        fov = fov_x_list[0]
        focal_length = 0.5 * 1 / np.tan(0.5 * fov)
        intrinsics = np.array(
            [[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]],
            dtype=np.float32,
        )
        intrinsics = torch.from_numpy(intrinsics)

        # Stack the per-view lists along a new leading view axis.
        im_tensors = torch.stack(im_list, dim=0)
        depth_tensors = torch.stack(depth_list, dim=0)
        c2w_matrixs = torch.stack(c2w_list, dim=0).to(torch.float32)
        masks = torch.stack(mask_list, dim=0)
        intrinsics = repeat(intrinsics, "i j -> n i j", n=im_tensors.shape[0])

        if self.relative_pose:
            # Express every pose in the first view's camera frame ...
            src_w2c = torch.inverse(c2w_matrixs[:1])  # (1, 4, 4)
            src_distance = c2w_matrixs[:1, :3, 3].norm(dim=-1)  # (1,)
            canonical_c2w = torch.matmul(src_w2c, c2w_matrixs)  # (Nv, 4, 4)
            # ... then flip the y and z axes and translate along z by the
            # first camera's distance from the world origin.
            shift = torch.tensor(
                [
                    [1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, src_distance.item()],
                    [0, 0, 0, 1],
                ],
                dtype=torch.float32,
            )
            c2w_matrixs = torch.matmul(shift, canonical_c2w)

        cameras = self.get_cameras(c2w_matrixs, fov_x_list, fov_y_list)
        im_tensors = rearrange(im_tensors, "m h w c -> m c h w")
        condition_image = im_tensors[0]
        diffusion_images = im_tensors

        # Field of view in degrees, replicated for every view.
        fov = fov * 180 / np.pi
        fov = torch.tensor([fov] * c2w_matrixs.shape[0], dtype=torch.float32)
        return {
            "condition_image": condition_image,
            "diffusion_images": diffusion_images,
            "depths": depth_tensors,
            "cameras": cameras,
            "masks": masks,
            "prompt": "",
            "c2w": c2w_matrixs,
            "intrinsics": intrinsics,
            "fov": fov,
            "instance_path": instance_path,
        }


class MultiViewDataModule(LightningDataModule):
    """Lightning data module wrapping pre-built train/val/test datasets.

    Every loader uses ``collect_fn`` as its collate function. Validation and
    test datasets may be a single dataset, a ``ListConfig`` of datasets or a
    ``DictConfig`` of datasets; in the latter two cases a list of loaders is
    returned.
    """

    def __init__(
        self,
        train_dataset: Dataset[Any],
        val_dataset: Dataset[Any],
        test_dataset: Dataset[Any],
        train_batch_size: int = 1,
        val_batch_size: int = 1,
        test_batch_size: int = 1,
        num_workers: Optional[int] = None,
        pin_memory: bool = True,
        real_dataset=None,
        real_batch_size=-1,
    ):
        """Store the datasets and loader settings.

        Args:
            train_dataset: Dataset used by ``train_dataloader``.
            val_dataset: Dataset (or config collection of datasets) used by
                ``val_dataloader``.
            test_dataset: Dataset (or config collection of datasets) used by
                ``test_dataloader``.
            train_batch_size: Batch size for training.
            val_batch_size: Batch size for validation.
            test_batch_size: Batch size for testing.
            num_workers: Worker processes per loader. ``None`` falls back to
                ``train_batch_size * 2``; an explicit ``0`` is honoured and
                loads data in the main process.
            pin_memory: Forwarded to every ``DataLoader``.
            real_dataset: Optional extra dataset served by ``real_dataloader``.
            real_batch_size: Batch size for ``real_dataloader``.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train = train_dataset
        self.data_val = val_dataset
        self.data_test = test_dataset
        self.data_real = real_dataset

        # Compare against None so that num_workers=0 is not silently replaced.
        if num_workers is not None:
            self.num_workers = num_workers
        else:
            self.num_workers = train_batch_size * 2

    def prepare_data(self) -> None:
        """No-op; data availability is not checked here."""
        # TODO: check if data is available
        pass

    def _dataloader(
        self, dataset: Dataset, batch_size: int, shuffle: bool
    ) -> DataLoader[Any]:
        """Create a ``DataLoader`` with the module-wide settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collect_fn,
        )

    def _eval_dataloaders(self, datasets, batch_size: int):
        """Return unshuffled loader(s) for a dataset or a config collection.

        A ``ListConfig`` or ``DictConfig`` yields a list with one loader per
        contained dataset; anything else yields a single loader.
        """
        if isinstance(datasets, ListConfig):
            return [self._dataloader(dataset, batch_size, False) for dataset in datasets]
        if isinstance(datasets, DictConfig):
            return [
                self._dataloader(dataset, batch_size, False)
                for _, dataset in datasets.items()
            ]
        return self._dataloader(datasets, batch_size, False)

    def real_dataloader(self):
        """Return an unshuffled loader over the optional real dataset.

        Raises:
            ValueError: If no ``real_dataset`` was provided.
        """
        if self.data_real is None:
            raise ValueError("real_dataloader() requires a real_dataset")
        return self._dataloader(self.data_real, self.hparams.real_batch_size, False)

    def train_dataloader(self) -> DataLoader[Any]:
        """Return the shuffled training loader."""
        return self._dataloader(self.data_train, self.hparams.train_batch_size, True)

    def val_dataloader(self) -> DataLoader[Any]:
        """Return the validation loader, or a list of loaders for collections."""
        return self._eval_dataloaders(self.data_val, self.hparams.val_batch_size)

    def test_dataloader(self) -> DataLoader[Any]:
        """Return the test loader, or a list of loaders for collections."""
        return self._eval_dataloaders(self.data_test, self.hparams.test_batch_size)


import cv2
import numpy as np


def save_video(numpy_array, output_path, fps=30):
    """Save a stack of frames to a video file using the ``mp4v`` codec.

    Args:
        numpy_array: Array of shape ``(t, c, h, w)`` where ``t`` is the number
            of frames, ``c`` the number of colour channels and ``h``/``w`` the
            frame height and width. Values are cast to ``uint8`` without any
            rescaling, so they should already be in the 0-255 range.
        output_path: Path of the video file to write.
        fps: Frame rate of the video.

    NOTE(review): frames are written as-is with no channel reordering;
    OpenCV interprets 3-channel frames as BGR, so RGB input may appear with
    red and blue swapped - confirm against callers.
    """
    # Convert to channels-last, the layout OpenCV frames use.
    frames = numpy_array.astype(np.uint8).transpose(0, 2, 3, 1)
    _, h, w, _ = frames.shape

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (w, h))
    try:
        for frame in frames:
            # The transposed view is not contiguous; hand OpenCV a dense copy.
            out.write(np.ascontiguousarray(frame))
    finally:
        # Always finalise the file, even if writing a frame fails.
        out.release()


from src.utils.visual import visual_camera

if __name__ == "__main__":
    # Manual smoke test: build the dataset with a fixed seed and print the
    # first sample (files are downloaded from the remote bucket if missing).
    from pytorch_lightning import seed_everything
    from torchvision.utils import save_image

    seed_everything(42)

    dataset_kwargs = dict(
        root_dir="data/gobjaverse",
        instance_file="data/gobjaverse/gobjaverse_280k.json",
        bg_color="white",
        img_wh=(512, 512),
        num_frames=8,
        remote_root="https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/objaverse",
    )
    dataset = MultiViewDataset(**dataset_kwargs)

    first_sample = dataset[0]
    print(first_sample)
