import os
import json
import math
import numpy as np
from PIL import Image
from omegaconf import DictConfig, ListConfig
from typing import Any, Tuple, Optional, List, Dict
import torch
from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2
from src.utils.point import fov2focal
from src.models.gs.cameras import Cameras
from torch.utils.data.dataloader import default_collate
from einops import rearrange, repeat
import json
from torch.utils.data.dataloader import default_collate
from petrel_client.client import Client
from tqdm import tqdm
import io
import requests
import random
# Module-level petrel storage client, created at import time. The dataset's
# load_* methods use it as a cache for files fetched over HTTP
# (contains/get/put on "gcc:s3://..." paths).
client = Client('~/petreloss.conf', enable_mc=True)

def collect_fn(batch):
    """Collate a list of sample dicts while keeping ``cameras`` uncollated.

    Each sample's ``'cameras'`` entry is popped (the input dicts are mutated),
    the remaining fields go through ``default_collate``, and the cameras are
    re-attached to the collated batch as a plain list in sample order
    (presumably because camera objects cannot be collated by default).
    """
    camera_list = []
    for sample in batch:
        camera_list.append(sample.pop('cameras'))
    collated = default_collate(batch)
    collated['cameras'] = camera_list
    return collated

class MultiViewDataset(Dataset):
    """Multi-view rendering dataset laid out like G-Objaverse.

    Each entry of ``instance_file`` is ``"<dictionary_id>/<instance_id>"`` and
    is joined onto ``root_dir`` (a local directory or an ``http`` URL prefix).
    An instance directory holds per-view sub-directories ``{i:05d}/`` with:

    * ``{i:05d}.png``    -- RGB(A) image,
    * ``{i:05d}_nd.exr`` -- EXR whose 4th channel is read as depth,
    * ``{i:05d}.json``   -- camera axes ``x``/``y``/``z``, ``origin``,
      ``x_fov`` and ``y_fov``.

    Remote (``http``) files are cached in petrel storage under
    ``_CACHE_PREFIX``; on a cache miss they are downloaded and written there.

    Args:
        root_dir: Prefix for all instance paths; must end with ``"/"``.
        instance_file: Path to a JSON list of instance ids.
        bg_color: ``"white"``, ``"black"``, ``"gray"``, ``"random"`` or a
            number; used to composite RGBA images (see ``get_bg_color``).
        num_frames: Number of views loaded per sample.
        img_wh: Output ``(width, height)`` of images, depths and masks.
        num_samples: Keep only the first ``num_samples`` sorted instances.
            ``None`` or a negative value (the default ``-1``) keeps all.
        relative_pose: If True, poses are re-expressed relative to the first
            returned view (see ``__getitem__``).
        repeat: Multiplier on the reported dataset length.
    """

    # Storage prefix under which remotely fetched files are cached.
    _CACHE_PREFIX = "gcc:s3://wenhao/gobjaverse_280k"
    # Per-request HTTP timeout in seconds, so a stalled download cannot hang
    # a data-loader worker forever.
    _HTTP_TIMEOUT = 60
    # Maximum number of instances tried by __getitem__ before giving up.
    _MAX_LOAD_RETRIES = 100

    def __init__(
        self,
        root_dir: str,
        instance_file: str,
        bg_color: str,
        num_frames: int,
        img_wh: Tuple[int, int],
        num_samples: Optional[int] = -1,
        relative_pose: bool = True,
        repeat: int = 1,
    ):
        assert root_dir.endswith("/"), f"root_dir must end with '/': {root_dir!r}"
        self.root_dir = root_dir
        self.bg_color = self.get_bg_color(bg_color)
        self.img_wh = img_wh
        self.num_frames = num_frames
        self.relative_pose = relative_pose
        self.repeat = repeat

        with open(instance_file, "r") as f:
            instances = json.load(f)  # ["<dictionary_id>/<instance_id>", ...]
        # Sort numerically by dictionary_id, then by instance_id.
        instances = sorted(
            instances, key=lambda x: (int(x.split("/")[0]), int(x.split("/")[1]))
        )
        sorted_instances = [os.path.join(root_dir, x) for x in instances]

        # A None or negative num_samples (default -1) means "use everything";
        # slicing with [:-1] would silently drop the last instance.
        if num_samples is not None and num_samples >= 0:
            sorted_instances = sorted_instances[:num_samples]
        self.sorted_instances = sorted_instances

    def get_bg_color(self, bg_color):
        """Return the compositing background as a float32 RGB array in [0, 1].

        Accepts ``"white"``, ``"black"``, ``"gray"``, ``"random"`` or a number
        (replicated to three channels).

        NOTE: this is called once from ``__init__``, so ``"random"`` yields a
        single random color shared by every image of the dataset.

        Raises:
            NotImplementedError: for any other value.
        """
        presets = {"white": 1.0, "black": 0.0, "gray": 0.5}
        if isinstance(bg_color, str):
            if bg_color == "random":
                # Cast so every option yields the same dtype (float32).
                return np.random.rand(3).astype(np.float32)
            if bg_color in presets:
                return np.full(3, presets[bg_color], dtype=np.float32)
        elif isinstance(bg_color, (int, float)) and not isinstance(bg_color, bool):
            return np.full(3, bg_color, dtype=np.float32)
        raise NotImplementedError(f"Unsupported bg_color: {bg_color!r}")

    def _fetch_remote_bytes(self, url):
        """Return the raw bytes for ``url``, using the petrel cache if possible.

        On a cache miss the file is downloaded from the same URL with
        ``campos_512_v4/`` inserted before its last two path components, and
        the bytes are written to the cache. HTTP errors raise (instead of an
        error page being cached permanently).
        """
        relative = url.replace(self.root_dir, "")
        cache_path = os.path.join(self._CACHE_PREFIX, relative)
        if client.contains(cache_path):
            return client.get(cache_path)

        base_name = "/".join(url.split("/")[-2:])
        download_url = url.replace(base_name, "campos_512_v4/" + base_name)
        response = requests.get(download_url, timeout=self._HTTP_TIMEOUT)
        response.raise_for_status()
        data = response.content
        client.put(cache_path, io.BytesIO(data))
        return data

    @staticmethod
    def _to_return_type(img, return_type):
        """Return ``img`` as-is for ``"np"`` or as a torch tensor for ``"pt"``."""
        if return_type == "np":
            return img
        if return_type == "pt":
            return torch.from_numpy(img)
        raise NotImplementedError

    def load_image(self, img_path, rescale=True, return_type="np"):
        """Load an image resized to ``img_wh`` as float32 (H, W, C).

        RGBA images are alpha-composited over ``self.bg_color`` to RGB.
        Values are in [0, 1], or in [-1, 1] when ``rescale`` is True.
        """
        # PIL rather than cv2: cv2 with IMREAD_UNCHANGED may yield uint16,
        # while the division by 255 below assumes 8-bit data.
        if img_path.startswith("http"):
            im = Image.open(io.BytesIO(self._fetch_remote_bytes(img_path)))
        else:
            im = Image.open(img_path)
        with im:
            img = np.array(im.resize(self.img_wh))

        img = img.astype(np.float32) / 255.0  # [0, 1]
        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha + self.bg_color * (1 - alpha)
        if rescale:
            img = img * 2.0 - 1.0  # to -1 ~ 1
        return self._to_return_type(img, return_type)

    def load_normal(self, normal_path, return_type="np"):
        """Load a local normal map resized to ``img_wh``, mapped to [-1, 1].

        For RGBA input the RGB channels are multiplied by alpha (no background
        compositing), so fully transparent pixels end up at -1.
        """
        with Image.open(normal_path) as im:
            img = np.array(im.resize(self.img_wh))
        img = img.astype(np.float32) / 255.0  # [0, 1]
        if img.shape[-1] == 4:
            alpha = img[..., 3:4]
            img = img[..., :3] * alpha
        img = img * 2.0 - 1.0  # to -1 ~ 1
        return self._to_return_type(img, return_type)

    def load_depth(self, img_path):
        """Load depth from the 4th channel of an EXR file.

        Returns:
            depth: array of shape (1, H, W), H/W from ``img_wh``, in the
                dtype produced by OpenCV's decoder.
            mask: float32 array of the same shape; 1 where depth != 0.

        Raises:
            IOError: if the file cannot be decoded.
        """
        if img_path.startswith("http"):
            buf = np.frombuffer(self._fetch_remote_bytes(img_path), dtype=np.uint8)
            img = cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)
        else:
            img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise IOError(f"Failed to decode depth file: {img_path}")

        # Nearest-neighbour avoids blending depth across object boundaries.
        img = cv2.resize(img, self.img_wh, interpolation=cv2.INTER_NEAREST)
        depth = img[None, :, :, 3]
        # Pixels whose depth is exactly 0 are treated as invalid.
        mask = (depth != 0).astype(np.float32)
        return depth, mask

    def load_json(self, json_path):
        """Load a JSON file from a local path or an ``http`` URL (cached)."""
        if json_path.startswith("http"):
            return json.load(io.BytesIO(self._fetch_remote_bytes(json_path)))
        with open(json_path, "r") as f:
            return json.load(f)

    def __len__(self):
        return len(self.sorted_instances) * self.repeat

    def load_instance(self, path, tunck=25):
        """Load ``tunck`` views of the instance at ``path``.

        View indices are spaced evenly over 0..23 (inclusive) and truncated
        to int.

        Returns:
            Six lists with one entry per view: image tensors (H, W, 3),
            depth tensors (1, H, W), mask tensors (1, H, W), float64 c2w
            tensors (4, 4) built from the JSON ``x``/``y``/``z``/``origin``
            columns, and the JSON ``x_fov`` / ``y_fov`` values (treated as
            radians downstream).
        """
        im_list, depth_list, mask_list = [], [], []
        c2w_list, fov_x_list, fov_y_list = [], [], []

        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        selected_indices = np.linspace(0, 23, tunck, dtype=int)

        for i in selected_indices:
            sub_path = "{:05d}".format(i)
            full_path = os.path.join(path, sub_path)
            im = self.load_image(os.path.join(full_path, f"{sub_path}.png"))
            depth, mask = self.load_depth(os.path.join(full_path, f"{sub_path}_nd.exr"))
            json_content = self.load_json(os.path.join(full_path, f"{sub_path}.json"))

            c2w = np.eye(4)
            c2w[:3, 0] = np.array(json_content['x'])
            c2w[:3, 1] = np.array(json_content['y'])
            c2w[:3, 2] = np.array(json_content['z'])
            c2w[:3, 3] = np.array(json_content['origin'])

            im_list.append(torch.from_numpy(im))
            depth_list.append(torch.from_numpy(depth))
            mask_list.append(torch.from_numpy(mask))
            c2w_list.append(torch.from_numpy(c2w))
            fov_x_list.append(json_content['x_fov'])
            fov_y_list.append(json_content['y_fov'])
        return im_list, depth_list, mask_list, c2w_list, fov_x_list, fov_y_list

    def get_cameras(self, c2w, fov_x_list, fov_y_list):
        """Build a batched ``Cameras`` from (N, 4, 4) c2w matrices and FoVs."""
        num_views = c2w.shape[0]
        width, height = self.img_wh
        w2c = torch.linalg.inv(c2w).to(torch.float)
        fx = torch.tensor(
            [fov2focal(fov=fov_x, pixels=width) for fov_x in fov_x_list],
            dtype=torch.float32)
        # The vertical focal length depends on the image height (the original
        # used the width, which is only correct for square images).
        fy = torch.tensor(
            [fov2focal(fov=fov_y, pixels=height) for fov_y in fov_y_list],
            dtype=torch.float32)
        widths = torch.tensor([float(width)] * num_views, dtype=torch.float32)
        heights = torch.tensor([float(height)] * num_views, dtype=torch.float32)
        return Cameras(
            R=w2c[:, :3, :3],
            T=w2c[:, :3, 3],
            fx=fx,
            fy=fy,
            cx=widths / 2,
            cy=heights / 2,
            width=widths,
            height=heights,
            appearance_id=torch.zeros_like(widths),
            normalized_appearance_id=torch.zeros_like(widths),
            distortion_params=None,
            camera_type=torch.zeros_like(widths),
        )

    def __getitem__(self, index):
        """Return one multi-view sample.

        Broken instances are skipped by retrying a random instance, up to
        ``_MAX_LOAD_RETRIES`` attempts. The loaded views are cyclically
        rotated so that a uniformly random one becomes view 0.

        Returns:
            dict with ``condition_image`` (3, H, W), ``diffusion_images``
            (Nv, 3, H, W), ``depths`` (Nv, 1, H, W), ``cameras``, ``masks``
            (Nv, 1, H, W), ``prompt`` (""), ``c2w`` (Nv, 4, 4) float32,
            ``intrinsics`` (Nv, 3, 3) normalized by image size, ``fov`` (Nv,)
            in degrees, and ``instance_path``.

        Raises:
            RuntimeError: if no instance could be loaded within the retry budget.
        """
        index = index % len(self.sorted_instances)
        last_error = None
        for _ in range(self._MAX_LOAD_RETRIES):
            instance_path = self.sorted_instances[index]
            try:
                (im_list, depth_list, mask_list,
                 c2w_list, fov_x_list, fov_y_list) = self.load_instance(instance_path, self.num_frames)
                break
            except Exception as e:
                # Best effort: skip broken/unreachable instances.
                print(e)
                print("Error loading instance, retrying...")
                last_error = e
                index = np.random.randint(len(self.sorted_instances))
        else:
            raise RuntimeError(
                f"Failed to load an instance after {self._MAX_LOAD_RETRIES} attempts"
            ) from last_error

        im_tensors = torch.stack(im_list, dim=0)  # (Nv, H, W, 3)
        depth_tensors = torch.stack(depth_list, dim=0)  # (Nv, 1, H, W)
        c2w_matrixs = torch.stack(c2w_list, dim=0).to(torch.float32)
        masks = torch.stack(mask_list, dim=0)  # (Nv, 1, H, W)

        # Cyclic shift by a uniformly random offset. randrange excludes the
        # upper bound; randint(0, n) made the identity shift twice as likely.
        num_views = im_tensors.shape[0]
        start_index = random.randrange(num_views)

        def rotate(t):
            return torch.cat([t[start_index:], t[:start_index]], dim=0)

        im_tensors = rotate(im_tensors)
        depth_tensors = rotate(depth_tensors)
        c2w_matrixs = rotate(c2w_matrixs)
        masks = rotate(masks)
        fov_x_list = fov_x_list[start_index:] + fov_x_list[:start_index]
        fov_y_list = fov_y_list[start_index:] + fov_y_list[:start_index]

        # Intrinsics from the x-FoV (radians) of the condition view, i.e. the
        # first view *after* rotation; normalized so the principal point is 0.5.
        fov = fov_x_list[0]
        focal_length = 0.5 * 1 / np.tan(0.5 * fov)
        intrinsics = torch.from_numpy(np.array(
            [
                [focal_length, 0, 0.5],
                [0, focal_length, 0.5],
                [0, 0, 1],
            ],
            dtype=np.float32,
        ))
        intrinsics = repeat(intrinsics, "i j -> n i j", n=num_views)

        if self.relative_pose:
            # Express all poses in the first view's camera frame (view 0
            # becomes identity), then flip y/z and move the cameras back by
            # the first camera's distance from the world origin along z.
            src_w2c = torch.inverse(c2w_matrixs[:1])  # (1, 4, 4)
            src_distance = float(c2w_matrixs[0, :3, 3].norm())
            canonical_c2w = torch.matmul(src_w2c, c2w_matrixs)  # (Nv, 4, 4)
            shift = torch.tensor(
                [
                    [1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, src_distance],
                    [0, 0, 0, 1],
                ],
                dtype=torch.float32,
            )
            c2w_matrixs = torch.matmul(shift, canonical_c2w)

        cameras = self.get_cameras(c2w_matrixs, fov_x_list, fov_y_list)
        im_tensors = rearrange(im_tensors, "m h w c -> m c h w")
        fov_deg = torch.tensor([fov * 180 / np.pi] * num_views, dtype=torch.float32)
        return {
            "condition_image": im_tensors[0],
            "diffusion_images": im_tensors,
            "depths": depth_tensors,
            "cameras": cameras,
            "masks": masks,
            "prompt": "",
            "c2w": c2w_matrixs,
            "intrinsics": intrinsics,
            "fov": fov_deg,
            "instance_path": instance_path,
        }


class MultiViewDataModule(LightningDataModule):
    """LightningDataModule wrapping train/val/test (and optional "real") datasets.

    Every loader uses ``collect_fn`` so per-sample ``cameras`` stay a list.
    ``val_dataset`` / ``test_dataset`` may be a single dataset or an OmegaConf
    ``ListConfig`` / ``DictConfig`` of datasets; in the latter case a list of
    loaders (one per dataset) is returned.
    """

    def __init__(
        self,
        train_dataset: Dataset[Any],
        val_dataset: Dataset[Any],
        test_dataset: Dataset[Any],
        train_batch_size: int = 1,
        val_batch_size: int = 1,
        test_batch_size: int = 1,
        num_workers: Optional[int] = None,
        pin_memory: bool = True,
        real_dataset = None,
        real_batch_size = -1,
    ):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.data_train = train_dataset
        self.data_val = val_dataset
        self.data_test = test_dataset
        self.data_real = real_dataset

        # Fall back only when num_workers was not given: an explicit 0
        # ("load in the main process") must be honoured, not treated as unset.
        self.num_workers = num_workers if num_workers is not None else train_batch_size * 2

    def prepare_data(self) -> None:
        # TODO: check if data is available
        pass

    def _dataloader(self, dataset: Dataset, batch_size: int, shuffle: bool) -> DataLoader[Any]:
        """Build a DataLoader with this module's shared worker/pinning settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=self.hparams.pin_memory,
            collate_fn=collect_fn
        )

    def _eval_dataloaders(self, datasets, batch_size: int):
        """Return one unshuffled loader, or a list of them for List/DictConfig input."""
        if isinstance(datasets, ListConfig):
            return [self._dataloader(dataset, batch_size, False) for dataset in datasets]
        if isinstance(datasets, DictConfig):
            return [self._dataloader(dataset, batch_size, False) for _, dataset in datasets.items()]
        return self._dataloader(datasets, batch_size, False)

    def real_dataloader(self):
        return self._dataloader(self.data_real, self.hparams.real_batch_size, False)

    def train_dataloader(self) -> DataLoader[Any]:
        return self._dataloader(self.data_train, self.hparams.train_batch_size, True)

    def val_dataloader(self) -> DataLoader[Any]:
        return self._eval_dataloaders(self.data_val, self.hparams.val_batch_size)

    def test_dataloader(self) -> DataLoader[Any]:
        return self._eval_dataloaders(self.data_test, self.hparams.test_batch_size)
import cv2
import numpy as np

def save_video(numpy_array, output_path, fps=30):
    """Write a stack of frames to a video file with OpenCV (``mp4v`` codec).

    Args:
        numpy_array: Frames of shape (t, c, h, w), channels first; they are
            transposed to (t, h, w, c) before writing. Values are cast with
            ``astype(np.uint8)``, so they are expected to already be in
            [0, 255]. NOTE: OpenCV treats 3-channel frames as BGR.
        output_path: Destination video file path.
        fps: Frame rate of the output video.

    Raises:
        IOError: if OpenCV cannot open a writer for ``output_path``.
    """
    frames = numpy_array.astype(np.uint8).transpose(0, 2, 3, 1)
    t, h, w, c = frames.shape

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # Single-channel frames need isColor=False; colour frames keep the default.
    writer = cv2.VideoWriter(output_path, fourcc, fps, (w, h), c != 1)
    if not writer.isOpened():
        raise IOError(f"Could not open video writer for {output_path}")
    try:
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()

from src.utils.visual import visual_camera
from tqdm import tqdm
if __name__ == "__main__":
    # Smoke test: load one sample from the remote mirror, visualize its
    # cameras, and save all of its views side by side.
    from torchvision.utils import save_image
    from pytorch_lightning import seed_everything
    from src.utils.visual import visual_camera

    seed_everything(42)
    dataset = MultiViewDataset(
        root_dir="https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/aigc3d/objaverse/",
        instance_file="data/gobjaverse/gobjaverse_280k.json",
        bg_color="white",
        num_frames=8,
        img_wh=(512, 512),
    )
    item = dataset[0]
    visual_camera(item["c2w"], item["intrinsics"])

    # Tile all views horizontally: (n, c, h, w) -> (c, h, n * w).
    images = rearrange(item["diffusion_images"], "n c h w -> c h (n w)")
    # The dataset returns images in [-1, 1] (load_image rescales by default),
    # while save_image clamps to [0, 1]; map back before saving.
    images = images * 0.5 + 0.5
    os.makedirs("visual", exist_ok=True)
    save_image(images, "visual/diffusion_images.png")