import os
import json
import math
import numpy as np
from PIL import Image
from omegaconf import DictConfig, ListConfig
from typing import Any, Tuple, Optional, List, Dict
import torch
from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2
from src.utils.point import fov2focal
from src.models.gs.cameras import Cameras
from torch.utils.data.dataloader import default_collate
from einops import rearrange, repeat
import json
from torch.utils.data.dataloader import default_collate
from petrel_client.client import Client
from tqdm import tqdm
import io
import requests
import random
# Module-level petrel storage client, created once at import time from the
# config in ~/petreloss.conf. MultiViewDataset uses it (contains/get/put) as a
# cache for files whose paths start with "http", keyed under "gcc:s3://..." URIs.
# NOTE(review): importing this module therefore requires petrel_client and a
# valid config even when only local files are used.
client = Client('~/petreloss.conf', enable_mc=True)

def collect_fn(batch):
    """Collate a list of sample dicts, keeping ``cameras`` as a plain list.

    Every key except ``cameras`` is merged with ``default_collate``. The
    per-sample ``cameras`` objects are gathered into a list in batch order,
    because they are not tensors that ``default_collate`` can stack.

    The input dicts are left unmodified.

    Args:
        batch: list of sample dicts, each containing a ``cameras`` entry.

    Returns:
        dict: the collated batch, with ``batch["cameras"]`` holding the list of
        per-sample camera objects.

    Raises:
        KeyError: if a sample has no ``cameras`` entry.
    """
    cameras = [item['cameras'] for item in batch]
    # Build shallow copies without 'cameras' rather than popping the key, so
    # datasets that cache or reuse their item dicts are not silently mutated.
    rest = [{key: value for key, value in item.items() if key != 'cameras'} for item in batch]
    batch_processed = default_collate(rest)
    batch_processed['cameras'] = cameras
    return batch_processed

class MultiViewDataset(Dataset):
    """Multi-view dataset: one condition image plus a camera trajectory per item.

    ``__getitem__`` loads one image from ``root_dir`` as the condition image. It
    then builds ``num_frames`` cameras from the poses stored in ``meta_file``.

    ``load_instance`` separately reads per-view RGB, depth and pose data from a
    directory containing ``<path>/<view>/<view>.png``, ``<view>_nd.exr`` and
    ``<view>.json`` files.

    Paths starting with ``"http"`` are read through the module-level petrel
    ``client`` cache and downloaded on a cache miss.
    """

    # Cache root under which remote ("http") files are stored via `client`.
    _REMOTE_CACHE_ROOT = "gcc:s3://wenhao/gobjaverse_280k"
    # Prefix inserted before "<view>/<file>" to form the actual download URL.
    _REMOTE_SUBDIR = "campos_512_v4/"
    # Seconds to wait on a download before failing instead of hanging forever.
    _REMOTE_TIMEOUT = 60

    def __init__(
        self,
        root_dir: str,
        bg_color: str,
        num_frames: int,
        meta_file: str = None,
        img_wh: Tuple[int, int] = (512, 512),
        relative_pose: bool = True,
        num_samples: int = -1,
        repeat: int = 1
    ):
        """
        Args:
            root_dir: directory whose entries are used as condition images.
            bg_color: background spec for alpha compositing (see ``get_bg_color``).
            num_frames: number of poses sampled uniformly from ``meta_file``.
            meta_file: JSON file with ``camera_angle_x`` and a ``locations`` list
                of ``{"transform_matrix": 4x4}`` entries. It is required by
                ``__getitem__``.
            img_wh: (width, height) that images are resized to.
            relative_pose: re-express poses relative to the first sampled view.
            num_samples: accepted for config compatibility; currently unused.
            repeat: how many times the file list is repeated in ``__len__``.
        """
        self.root_dir = root_dir
        # Sorted so the index -> file mapping is stable (os.listdir order is arbitrary).
        self.files = sorted(os.path.join(self.root_dir, name) for name in os.listdir(self.root_dir))
        self.meta_file = meta_file
        self.img_wh = img_wh
        self.relative_pose = relative_pose
        self.num_frames = num_frames
        self.bg_color = self.get_bg_color(bg_color)
        self.repeat = repeat
        # Parsed contents of meta_file, filled lazily by _load_meta().
        self._meta = None

    def get_bg_color(self, bg_color):
        """Convert a background spec into an RGB float32 array with values in [0, 1].

        Args:
            bg_color: "white", "black", "gray", "random", or a number used for
                all three channels.

        Raises:
            NotImplementedError: for any other spec.
        """
        if bg_color == "white":
            return np.array([1.0, 1.0, 1.0], dtype=np.float32)
        if bg_color == "black":
            return np.array([0.0, 0.0, 0.0], dtype=np.float32)
        if bg_color == "gray":
            return np.array([0.5, 0.5, 0.5], dtype=np.float32)
        if bg_color == "random":
            # Drawn once, when called from __init__, not per sample. The result
            # is cast to float32 so composited images keep the same dtype as
            # with the other specs.
            return np.random.rand(3).astype(np.float32)
        if isinstance(bg_color, (int, float)) and not isinstance(bg_color, bool):
            return np.full(3, bg_color, dtype=np.float32)
        raise NotImplementedError(f"Unsupported bg_color: {bg_color!r}")

    @staticmethod
    def _convert(img, return_type):
        """Return ``img`` unchanged for "np", or as a torch tensor for "pt"."""
        if return_type == "np":
            return img
        if return_type == "pt":
            return torch.from_numpy(img)
        raise NotImplementedError(f"Unsupported return_type: {return_type!r}")

    def _fetch_remote_bytes(self, url):
        """Return the raw bytes of a remote ``http`` file, going through the petrel cache.

        The cache key is ``url`` with ``self.root_dir`` removed, placed under
        ``_REMOTE_CACHE_ROOT``. On a cache miss, the file is downloaded from
        ``url`` with ``_REMOTE_SUBDIR`` inserted before its last two path
        components, and the bytes are written back to the cache.

        Raises:
            requests.HTTPError: if the download returns an error status.
        """
        # Strip leading "/" so os.path.join does not discard the cache root
        # (os.path.join drops all earlier parts when a later one is absolute).
        rel_path = url.replace(self.root_dir, "").lstrip("/")
        cache_path = os.path.join(self._REMOTE_CACHE_ROOT, rel_path)
        if client.contains(cache_path):
            return client.get(cache_path)
        base_name = "/".join(url.split("/")[-2:])
        download_url = url.replace(base_name, self._REMOTE_SUBDIR + base_name)
        response = requests.get(download_url, timeout=self._REMOTE_TIMEOUT)
        # Fail loudly rather than caching an HTTP error page as file content.
        response.raise_for_status()
        data = response.content
        client.put(cache_path, io.BytesIO(data))
        return data

    def load_image(self, img_path, rescale=True, return_type="np"):
        """Load an image, composite any alpha channel onto ``self.bg_color``, and normalize.

        PIL is used instead of cv2.imread, which may return uint16 data. The
        code divides by 255, so it assumes 8-bit images.

        Args:
            img_path: local path, or a path starting with "http" that is fetched
                through the remote cache.
            rescale: if True, map values from [0, 1] to [-1, 1].
            return_type: "np" for a numpy array, "pt" for a torch tensor.

        Returns:
            Image resized to ``self.img_wh``. For colour inputs the shape is
            (H, W, C); RGBA inputs are reduced to RGB.
        """
        if img_path.startswith("http"):
            source = io.BytesIO(self._fetch_remote_bytes(img_path))
        else:
            source = img_path
        with Image.open(source) as im:
            img = np.array(im.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0  # [0, 1]

        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha + self.bg_color * (1 - alpha)
        if rescale:
            img = img * 2.0 - 1.0  # to [-1, 1]
        return self._convert(img, return_type)

    def load_normal(self, normal_path, return_type="np"):
        """Load a normal map as values in [-1, 1], resized to ``self.img_wh``.

        For RGBA inputs, the RGB channels are multiplied by alpha before
        rescaling, so fully transparent pixels end up at -1.
        """
        with Image.open(normal_path) as im:
            img = np.array(im.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0  # [0, 1]

        if img.shape[-1] == 4:
            img = img[..., :3] * img[..., 3:4]

        img = img * 2.0 - 1.0  # to [-1, 1]
        return self._convert(img, return_type)

    def load_depth(self, img_path):
        """Load depth and a validity mask from a multi-channel (e.g. ``*_nd.exr``) file.

        Channel 3 is taken as depth. Pixels whose depth is exactly 0 are marked
        invalid.

        Returns:
            tuple: ``(depth, mask)``, both of shape (1, H, W). ``mask`` is float32,
            with 1.0 where depth != 0.

        Raises:
            FileNotFoundError: if the file cannot be read or decoded.
        """
        if img_path.startswith("http"):
            buffer = np.frombuffer(self._fetch_remote_bytes(img_path), dtype=np.uint8)
            img = cv2.imdecode(buffer, cv2.IMREAD_UNCHANGED)
        else:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2 returns None instead of raising; surface a clear error here.
            raise FileNotFoundError(f"Could not read or decode depth file: {img_path}")

        # Nearest-neighbour resampling so depth values are not blended across edges.
        img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_NEAREST)
        depth = img[None, :, :, 3]
        mask = depth != 0
        return depth, mask.astype(np.float32)

    def load_json(self, json_path):
        """Load a JSON file from a local path or an ``http`` path (via the remote cache)."""
        if json_path.startswith("http"):
            return json.loads(self._fetch_remote_bytes(json_path))
        with open(json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __len__(self):
        return len(self.files) * self.repeat

    def load_instance(self, path, tunck=25):
        """Load ``tunck`` views of one object, sampled uniformly from views 00000-00023.

        Each view directory ``<path>/<view>`` must contain ``<view>.png``,
        ``<view>_nd.exr`` and ``<view>.json``. The JSON provides ``x``, ``y``,
        ``z``, ``origin``, ``x_fov`` and ``y_fov``.

        Returns:
            tuple: ``(images, depths, masks, c2ws, fov_xs, fov_ys)`` lists. The
            images, depths, masks and c2ws are torch tensors; each c2w is a 4x4
            float64 matrix.
        """
        im_list, depth_list, mask_list = [], [], []
        c2w_list, fov_x_list, fov_y_list = [], [], []
        # NOTE(review): indices span only 0..23. With tunck > 24 (e.g. the
        # default 25), integer truncation of linspace repeats some views.
        # The builtin int replaces `np.int`, which was removed in NumPy 1.24.
        selected_indices = np.linspace(0, 23, tunck, dtype=int)

        for i in selected_indices:
            view = "{:05d}".format(i)
            view_dir = os.path.join(path, view)
            im = self.load_image(os.path.join(view_dir, f"{view}.png"))
            depth, mask = self.load_depth(os.path.join(view_dir, f"{view}_nd.exr"))
            meta = self.load_json(os.path.join(view_dir, f"{view}.json"))
            # Columns 0-2 are the camera x/y/z axes and column 3 is its origin.
            c2w = np.eye(4)
            for col, key in enumerate(("x", "y", "z", "origin")):
                c2w[:3, col] = np.array(meta[key])
            im_list.append(torch.from_numpy(im))
            depth_list.append(torch.from_numpy(depth))
            mask_list.append(torch.from_numpy(mask))
            c2w_list.append(torch.from_numpy(c2w))
            fov_x_list.append(meta['x_fov'])
            fov_y_list.append(meta['y_fov'])
        return im_list, depth_list, mask_list, c2w_list, fov_x_list, fov_y_list

    def get_cameras(self, c2w, fov_x_list, fov_y_list):
        """Build a batched ``Cameras`` object from camera-to-world matrices.

        Args:
            c2w: (N, 4, 4) camera-to-world matrices.
            fov_x_list: N horizontal fields of view, passed to ``fov2focal``.
                Presumably in radians; confirm against ``fov2focal``.
            fov_y_list: N vertical fields of view, passed to ``fov2focal``.
        """
        w2c = torch.linalg.inv(c2w).to(torch.float)
        num_views = c2w.shape[0]
        img_w, img_h = self.img_wh
        fx = torch.tensor(
            [fov2focal(fov=fov_x, pixels=img_w) for fov_x in fov_x_list],
            dtype=torch.float32)
        # Vertical focal length and height come from the image height. This is
        # identical to the width-based values whenever img_wh is square.
        fy = torch.tensor(
            [fov2focal(fov=fov_y, pixels=img_h) for fov_y in fov_y_list],
            dtype=torch.float32)
        width = torch.full((num_views,), float(img_w), dtype=torch.float32)
        height = torch.full((num_views,), float(img_h), dtype=torch.float32)
        return Cameras(
            R=w2c[:, :3, :3],
            T=w2c[:, :3, 3],
            fx=fx,
            fy=fy,
            cx=width / 2,
            cy=height / 2,
            width=width,
            height=height,
            appearance_id=torch.zeros_like(width),
            normalized_appearance_id=torch.zeros_like(width),
            distortion_params=None,
            camera_type=torch.zeros_like(width),
        )

    def _load_meta(self):
        """Parse ``self.meta_file`` once and reuse the result on later calls."""
        if getattr(self, "_meta", None) is None:
            with open(self.meta_file, "r", encoding="utf-8") as f:
                self._meta = json.load(f)
        return self._meta

    def __getitem__(self, index):
        file = self.files[index % len(self.files)]
        im = rearrange(self.load_image(file), "h w c -> c h w")

        meta = self._load_meta()
        fov = meta['camera_angle_x']
        all_c2w = np.array(
            [loc['transform_matrix'] for loc in meta['locations']], dtype=np.float64)
        frame_ids = np.linspace(0, len(all_c2w) - 1, self.num_frames, dtype=int)
        c2w_matrixs = torch.from_numpy(all_c2w[frame_ids])  # (N, 4, 4)
        # Negate the camera y and z axes (conversion to the OpenCV convention).
        c2w_matrixs[:, :3, 1:3] *= -1

        # Normalized intrinsics: focal length in units of image size, principal point at 0.5.
        focal_length = 0.5 / np.tan(0.5 * fov)
        intrinsics = torch.from_numpy(np.array(
            [
                [focal_length, 0, 0.5],
                [0, focal_length, 0.5],
                [0, 0, 1],
            ],
            dtype=np.float32,
        ))
        num_views = c2w_matrixs.shape[0]
        intrinsics = repeat(intrinsics, "h w -> n h w", n=num_views)

        if self.relative_pose:
            # Express every pose in the frame of the first sampled view, then
            # apply `shift`. After that the first camera has rotation
            # diag(1, -1, -1) and origin (0, 0, src_distance).
            src_w2c = torch.inverse(c2w_matrixs[:1])  # (1, 4, 4)
            src_distance = c2w_matrixs[:1, :3, 3].norm(dim=-1)  # (1,)
            canonical_c2w = torch.matmul(src_w2c, c2w_matrixs)  # (N, 4, 4)
            shift = torch.eye(4, dtype=c2w_matrixs.dtype)
            shift[1, 1] = -1
            shift[2, 2] = -1
            shift[2, 3] = src_distance[0]
            c2w_matrixs = torch.matmul(shift, canonical_c2w)

        fov_x_list = [fov] * num_views
        fov_y_list = [fov] * num_views
        cameras = self.get_cameras(c2w_matrixs, fov_x_list, fov_y_list)
        fov_deg = torch.tensor([fov * 180 / np.pi] * num_views, dtype=torch.float32)
        img_w, img_h = self.img_wh
        return {
            "condition_image": torch.from_numpy(im),
            "intrinsics": intrinsics,
            "cameras": cameras,
            "c2w": c2w_matrixs,
            # Placeholder values in [0, 1), shaped (num_frames, 3, H, W).
            "diffusion_images": torch.rand(self.num_frames, 3, img_h, img_w),
            "fov": fov_deg,
        }

class MultiViewDataModule(LightningDataModule):
    """Lightning data module wrapping pre-built train/val/test/real datasets.

    Every dataloader uses ``collect_fn``, so per-sample ``cameras`` objects are
    kept as a list instead of being collated. Validation and test datasets may
    each be a single dataset or a list/dict of datasets. A list/dict yields one
    dataloader per dataset.
    """

    def __init__(
        self,
        train_dataset: Dataset[Any] = None,
        val_dataset: Dataset[Any] = None,
        test_dataset: Dataset[Any] = None,
        train_batch_size: int = 1,
        val_batch_size: int = 1,
        test_batch_size: int = 1,
        num_workers: Optional[int] = None,
        pin_memory: bool = True,
        real_dataset = None,
        real_batch_size = -1,
    ):
        """
        Args:
            num_workers: dataloader worker count. ``None`` falls back to
                ``2 * train_batch_size``. An explicit ``0`` (load in the main
                process) is respected.
            real_batch_size: batch size for ``real_dataloader``. The default of
                -1 is not a valid DataLoader batch size, so set it before calling
                ``real_dataloader``.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train = train_dataset
        self.data_val = val_dataset
        self.data_test = test_dataset
        self.data_real = real_dataset

        self.num_workers = num_workers if num_workers is not None else train_batch_size * 2

    def prepare_data(self) -> None:
        # TODO: check if data is available
        pass

    def _dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool) -> DataLoader[Any]:
        """Build a DataLoader using the module's shared worker/pinning/collate settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collect_fn
        )

    def _eval_dataloaders(self, datasets, batch_size: int):
        """Return one unshuffled DataLoader, or a list of them for a list/dict of datasets."""
        if isinstance(datasets, (list, tuple, ListConfig)):
            return [self._dataloader(dataset, batch_size, False) for dataset in datasets]
        if isinstance(datasets, (dict, DictConfig)):
            return [self._dataloader(dataset, batch_size, False) for dataset in datasets.values()]
        return self._dataloader(datasets, batch_size, False)

    def real_dataloader(self):
        return self._dataloader(self.data_real, self.hparams.real_batch_size, False)

    def train_dataloader(self) -> DataLoader[Any]:
        if self.data_train is None:
            return None
        return self._dataloader(self.data_train, self.hparams.train_batch_size, True)

    def val_dataloader(self) -> DataLoader[Any]:
        return self._eval_dataloaders(self.data_val, self.hparams.val_batch_size)

    def test_dataloader(self) -> DataLoader[Any]:
        return self._eval_dataloaders(self.data_test, self.hparams.test_batch_size)
import cv2
import numpy as np

def save_video(numpy_array, output_path, fps=30):
    """Write a stack of frames to a video file with OpenCV's ``mp4v`` codec.

    Args:
        numpy_array: frames laid out as (t, c, h, w). They are transposed to
            (t, h, w, c) for OpenCV. Values are cast with ``astype(np.uint8)``,
            so they should already be in [0, 255]. OpenCV's VideoWriter expects
            BGR channel order, presumably with c == 3.
        output_path: path of the output video file.
        fps: frame rate of the video.

    Raises:
        OSError: if OpenCV cannot open a video writer for ``output_path``.
    """
    frames = np.ascontiguousarray(numpy_array.astype(np.uint8).transpose(0, 2, 3, 1))
    _, height, width, _ = frames.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    # VideoWriter does not raise on failure; an unopened writer silently drops frames.
    if not writer.isOpened():
        raise OSError(f"Could not open video writer for {output_path!r}")
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()

from src.utils.visual import visual_camera
from tqdm import tqdm
if __name__ == "__main__":
    # Smoke test: load one item, plot its cameras and dump the images to visual/.
    from torchvision.utils import save_image
    from lightning.pytorch import seed_everything

    seed_everything(42)
    dataset = MultiViewDataset(
        root_dir="data/test/images",
        meta_file="data/test/meta.json",
        bg_color="white",
        num_frames=8,
        img_wh=(512, 512))
    item = dataset[1]
    visual_camera(item["c2w"], item["intrinsics"])

    # save_image does not create missing directories.
    os.makedirs("visual", exist_ok=True)
    # Tile all views side by side: (n, c, h, w) -> (c, h, n * w).
    images = rearrange(item["diffusion_images"], "n c h w -> c h (n w)")
    save_image(images, "visual/diffusion_images.png")
    # condition_image is rescaled to [-1, 1]; save_image expects [0, 1].
    save_image((item["condition_image"] + 1) / 2, "visual/condition_image.png")