"""
Objaverse dataset loader following the same interface as nerf_synthetic.py
"""

import json
import os
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import imageio

from .utils import Rays
from pdb import set_trace as bb

def _to_rgba(image: np.ndarray) -> np.ndarray:
    """Return ``image`` with a trailing channel axis of size 4 (RGBA).

    Grayscale inputs (2-D, or 1/2 channels) have their first channel replicated
    to RGB; inputs without an alpha channel get a fully opaque alpha of 255,
    matching the ``uint8`` storage used by ``SubjectLoader``. Inputs that
    already have 4 channels are returned unchanged.
    """
    if image.ndim == 2:
        image = image[..., None]
    if image.shape[-1] in (1, 2):
        # Grayscale (optionally with alpha): replicate luminance into R, G, B.
        image = np.concatenate(
            [np.repeat(image[..., :1], 3, axis=-1), image[..., 1:]], axis=-1
        )
    if image.shape[-1] == 3:
        opaque = np.full(image.shape[:-1] + (1,), 255, dtype=image.dtype)
        image = np.concatenate([image, opaque], axis=-1)
    return image


def _load_renderings(root_fp: str, subject_id: str, split: str):
    """Load the renderings and camera poses of one Objaverse object.

    Reads ``<root_fp>/<subject_id>/transforms_<split>.json`` and, for each
    listed frame, the image ``<root_fp>/<subject_id>/<split>/<name>`` where
    ``<name>`` is the last ``/``-component of the frame's ``file_path`` with
    ``.png`` appended unless it already ends in ``.png``. The ``"val"`` split
    is served from the ``test`` files. A relative ``root_fp`` is resolved
    against the directory two levels above this file.

    Args:
        root_fp: Dataset root directory.
        subject_id: Object directory name under ``root_fp``.
        split: Split name (``"val"`` is mapped to ``"test"``).

    Returns:
        images: Array of shape ``(N, H, W, 4)``; RGB/grayscale files are
            padded to RGBA (opaque alpha).
        camtoworlds: Array stacked from each frame's ``transform_matrix``.
        focal: Focal length in pixels, from ``meta["focal"]`` when present,
            otherwise derived from ``camera_angle_x`` and the image width.

    Raises:
        ValueError: If the transforms file lists no frames.
    """
    if split == "val":
        split = "test"
    if not os.path.isabs(root_fp):
        # allow relative path
        root_fp = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "..", "..", root_fp
        )

    data_dir = os.path.join(root_fp, subject_id)
    meta_path = os.path.join(data_dir, f"transforms_{split}.json")
    with open(meta_path, "r") as fp:
        meta = json.load(fp)

    frames = meta["frames"]
    if not frames:
        raise ValueError(f"No frames listed in {meta_path}")

    split_dir = os.path.join(data_dir, split)
    images = []
    camtoworlds = []
    for frame in frames:
        fname = frame["file_path"].split("/")[-1]
        if not fname.lower().endswith(".png"):
            fname += ".png"
        image = np.asarray(imageio.imread(os.path.join(split_dir, fname)))
        images.append(_to_rgba(image))
        camtoworlds.append(frame["transform_matrix"])

    images = np.stack(images, axis=0)
    camtoworlds = np.stack(camtoworlds, axis=0)

    w = images.shape[2]
    if "focal" in meta:
        focal = float(meta["focal"])
    else:
        focal = 0.5 * w / np.tan(0.5 * float(meta["camera_angle_x"]))

    return images, camtoworlds, focal


class SubjectLoader(torch.utils.data.Dataset):
    """Single subject data loader for training and evaluation - Objaverse format.

    Images are held as ``uint8`` RGBA tensors on ``device``. Each item is a
    dict with ``pixels`` (RGB composited over a background), ``rays`` and
    ``color_bkgd``. In training mode (``num_rays`` given and split
    ``"train"``/``"trainval"``) an item is a random batch of ``num_rays``
    rays; otherwise it is the full ray grid of image ``index``.
    """

    SPLITS = ["train", "val", "test"]

    # Default values - can be overridden. SCALE_FACTOR is reassigned per
    # instance from ``half_res``.
    SCALE_FACTOR = 1.0
    NEAR, FAR = 2, 6
    OPENGL_CAMERA = True

    def __init__(
        self,
        subject_id: str,
        root_fp: str,
        split: str,
        color_bkgd_aug: str = "black",
        num_rays: int = None,
        near: float = None,
        far: float = None,
        batch_over_images: bool = True,
        device: torch.device = torch.device("cpu"),
        half_res: bool = False,
        white_bkgd: bool = False,
    ):
        """Load one object's split and precompute its camera intrinsics.

        Args:
            subject_id: Object directory name under ``root_fp``.
            root_fp: Dataset root directory.
            split: One of ``SPLITS``.
            color_bkgd_aug: Background used for compositing: ``"white"``,
                ``"black"`` or ``"random"`` (a fresh color per item).
            num_rays: Rays per training item; ``None`` selects full-image mode.
            near: Near ray bound; defaults to ``NEAR``.
            far: Far ray bound; defaults to ``FAR``.
            batch_over_images: In training, sample rays from random images
                instead of only from image ``index``.
            device: Device for images, poses, intrinsics and the RNG.
            half_res: Downsample images by 2; focal length and principal point
                are scaled to match.
            white_bkgd: Always composite onto white, overriding
                ``color_bkgd_aug``.
        """
        super().__init__()
        assert split in self.SPLITS, f"Split {split} not in {self.SPLITS}"
        assert color_bkgd_aug in ["white", "black", "random"]

        self.split = split
        self.num_rays = num_rays
        self.near = self.NEAR if near is None else near
        self.far = self.FAR if far is None else far
        self.training = (num_rays is not None) and (split in ["train", "trainval"])
        self.color_bkgd_aug = color_bkgd_aug
        self.batch_over_images = batch_over_images
        self.half_res = half_res
        self.white_bkgd = white_bkgd

        self.images, self.camtoworlds, self.focal = _load_renderings(
            root_fp, subject_id, split
        )
        self.images = torch.from_numpy(self.images).to(torch.uint8)

        # Full (on-disk) resolution; the working resolution is this times
        # SCALE_FACTOR.
        self.HEIGHT, self.WIDTH = self.images.shape[1:3]

        self.SCALE_FACTOR = 0.5 if self.half_res else 1.0
        if self.SCALE_FACTOR != 1.0:
            # Bilinear interpolation is not supported for uint8 inputs on all
            # torch versions, so resize in float and quantize back to uint8.
            resized = F.interpolate(
                self.images.permute(0, 3, 1, 2).float(),
                scale_factor=self.SCALE_FACTOR,
                mode="bilinear",
                align_corners=False,
            )
            self.images = (
                resized.round()
                .clamp(0, 255)
                .to(torch.uint8)
                .permute(0, 2, 3, 1)
                .contiguous()
            )

        # Express the focal length at the working resolution. It is scaled
        # exactly once here; K below must not scale it again.
        self.focal = self.focal * self.SCALE_FACTOR

        self.camtoworlds = torch.from_numpy(self.camtoworlds).to(torch.float32)

        # Pinhole intrinsics at the working resolution (principal point at the
        # image center).
        self.K = torch.tensor(
            [
                [self.focal, 0, self.WIDTH / 2.0 * self.SCALE_FACTOR],
                [0, self.focal, self.HEIGHT / 2.0 * self.SCALE_FACTOR],
                [0, 0, 1],
            ],
            dtype=torch.float32,
        )

        self.images = self.images.to(device)
        self.camtoworlds = self.camtoworlds.to(device)
        self.K = self.K.to(device)

        print(f"Loaded {len(self.images)} images with shape {self.images.shape}")
        assert self.images.shape[1:3] == (
            int(self.HEIGHT * self.SCALE_FACTOR),
            int(self.WIDTH * self.SCALE_FACTOR),
        )

        # Fixed-seed RNG shared by background augmentation and ray sampling.
        self.g = torch.Generator(device=device)
        self.g.manual_seed(42)

    def __len__(self):
        return len(self.images)

    @torch.no_grad()
    def __getitem__(self, index):
        data = self.fetch_data(index)
        data = self.preprocess(data)
        return data

    def preprocess(self, data):
        """Composite RGBA samples onto a background color.

        Args:
            data: Dict with ``"rgba"`` (last dimension 4) and ``"rays"``; any
                other keys are passed through unchanged.

        Returns:
            Dict with ``"pixels"`` (RGB over the background), ``"rays"``,
            ``"color_bkgd"`` (shape ``(3,)``, the color actually composited)
            and the passthrough keys.
        """
        rgba, rays = data["rgba"], data["rays"]
        pixels, alpha = torch.split(rgba, [3, 1], dim=-1)
        device = self.images.device

        if self.color_bkgd_aug == "random":
            # Drawn even when white_bkgd overrides it below, so the RNG stream
            # shared with ray sampling is consumed the same way.
            color_bkgd = torch.rand(3, device=device, generator=self.g)
        elif self.color_bkgd_aug == "white":
            color_bkgd = torch.ones(3, device=device)
        else:  # "black" (validated in __init__)
            color_bkgd = torch.zeros(3, device=device)

        if self.white_bkgd:
            # Report the background actually used, so consumers of color_bkgd
            # blend against the same color the targets were composited on.
            color_bkgd = torch.ones(3, device=device)

        pixels = pixels * alpha + color_bkgd * (1.0 - alpha)
        return {
            "pixels": pixels,  # [n_rays, 3] or [h, w, 3]
            "rays": rays,  # [n_rays,] or [h, w]
            "color_bkgd": color_bkgd,  # [3,]
            **{k: v for k, v in data.items() if k not in ["rgba", "rays"]},
        }

    def update_num_rays(self, num_rays):
        self.num_rays = num_rays

    def fetch_data(self, index):
        """Build RGBA targets and camera rays for one item.

        In training mode, samples ``num_rays`` random pixels (from random
        images when ``batch_over_images``, otherwise from image ``index``).
        Otherwise returns every pixel of image ``index`` laid out as
        ``(h, w)``. Directions are formed from ``K`` through pixel centers and
        rotated by the camera-to-world matrix; with ``OPENGL_CAMERA`` the
        camera-space y and z axes are negated.
        """
        num_rays = self.num_rays
        device = self.images.device
        height = int(self.HEIGHT * self.SCALE_FACTOR)
        width = int(self.WIDTH * self.SCALE_FACTOR)

        if self.training:
            if self.batch_over_images:
                image_id = torch.randint(
                    0,
                    len(self.images),
                    size=(num_rays,),
                    device=device,
                    generator=self.g,
                )
            else:
                image_id = [index] * num_rays
            x = torch.randint(
                0, width, size=(num_rays,), device=device, generator=self.g
            )
            y = torch.randint(
                0, height, size=(num_rays,), device=device, generator=self.g
            )
        else:
            image_id = [index]
            x, y = torch.meshgrid(
                torch.arange(width, device=device),
                torch.arange(height, device=device),
                indexing="xy",
            )
            x = x.flatten()
            y = y.flatten()

        # A single-element image_id broadcasts against the pixel coordinates.
        rgba = self.images[image_id, y, x] / 255.0  # (n, 4)
        # Only the top-left 3x3 rotation and the last column (translation)
        # of each pose are used below.
        c2w = self.camtoworlds[image_id]

        flip = -1.0 if self.OPENGL_CAMERA else 1.0
        camera_dirs = F.pad(
            torch.stack(
                [
                    (x - self.K[0, 2] + 0.5) / self.K[0, 0],
                    (y - self.K[1, 2] + 0.5) / self.K[1, 1] * flip,
                ],
                dim=-1,
            ),
            (0, 1),
            value=flip,
        )  # [n, 3]

        # Per ray: world direction = R @ camera_dir (broadcast multiply + sum).
        directions = (camera_dirs[:, None, :] * c2w[:, :3, :3]).sum(dim=-1)
        origins = torch.broadcast_to(c2w[:, :3, -1], directions.shape)
        viewdirs = directions / torch.linalg.norm(directions, dim=-1, keepdim=True)

        lead = (num_rays,) if self.training else (height, width)
        origins = torch.reshape(origins, lead + (3,))
        viewdirs = torch.reshape(viewdirs, lead + (3,))
        rgba = torch.reshape(rgba, lead + (4,))
        rays = Rays(origins=origins, viewdirs=viewdirs)

        return {
            "rgba": rgba,  # [h, w, 4] or [num_rays, 4]
            "rays": rays,  # [h, w, 3] or [num_rays, 3]
        }


class ObjaverseDataset(torch.utils.data.Dataset):
    """Multi-object Objaverse dataset: a concatenation of ``SubjectLoader``s.

    Object ids come from ``<basedir>/list.txt`` (one per line, blank lines
    ignored) when that file exists, otherwise from the sub-directories of
    ``basedir`` in sorted order. Objects that fail to load are reported and
    skipped. A global index is mapped to ``(object, local index)`` using the
    cumulative per-object lengths.
    """

    def __init__(
        self,
        basedir: str,
        split: str = "train",
        color_bkgd_aug: str = "white",
        num_rays: int = None,
        near: float = None,
        far: float = None,
        batch_over_images: bool = True,
        device: torch.device = torch.device("cpu"),
        half_res: bool = False,
        white_bkgd: bool = False,
    ):
        """Discover objects under ``basedir`` and build one loader per object.

        All arguments other than ``basedir`` are forwarded unchanged to every
        ``SubjectLoader``.
        """
        super().__init__()

        self.basedir = basedir
        self.split = split
        self.device = device

        print("Collecting object directories...")
        self.obj_dirs = self._list_objects(basedir)
        print(f"Found {len(self.obj_dirs)} objects in {basedir}")

        self.subject_loaders = {}
        self.valid_subjects = []
        for obj_id in self.obj_dirs:
            try:
                loader = SubjectLoader(
                    subject_id=obj_id,
                    root_fp=basedir,
                    split=split,
                    color_bkgd_aug=color_bkgd_aug,
                    num_rays=num_rays,
                    near=near,
                    far=far,
                    batch_over_images=batch_over_images,
                    device=device,
                    half_res=half_res,
                    white_bkgd=white_bkgd,
                )
            except Exception as e:
                # Best-effort: a broken or incomplete object is skipped rather
                # than aborting the whole dataset.
                print(f"Failed to load object {obj_id}: {e}")
                continue
            self.subject_loaders[obj_id] = loader
            self.valid_subjects.append(obj_id)

        print(f"Successfully loaded {len(self.valid_subjects)} objects")

        self.lengths = [
            len(self.subject_loaders[obj_id]) for obj_id in self.valid_subjects
        ]
        # cumulative_lengths[i] is the global index of subject i's first item.
        self.cumulative_lengths = np.cumsum([0] + self.lengths)

    @staticmethod
    def _list_objects(basedir):
        """Return object ids from ``list.txt`` or the sorted sub-directories."""
        list_path = os.path.join(basedir, "list.txt")
        if os.path.exists(list_path):
            with open(list_path, "r") as f:
                # Skip blank lines: an empty id would resolve to basedir itself.
                return [line.strip() for line in f if line.strip()]
        # Sorted so the index -> object mapping is deterministic.
        return sorted(
            p for p in os.listdir(basedir) if os.path.isdir(os.path.join(basedir, p))
        )

    def __len__(self):
        return sum(self.lengths)

    def __getitem__(self, index):
        """Return item ``index`` of the concatenated dataset.

        Negative indices count from the end.

        Raises:
            IndexError: If ``index`` is out of range.
        """
        total = int(self.cumulative_lengths[-1])
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f"index out of range for dataset of length {total}")

        # side="right" skips subjects with zero items.
        subject_idx = int(
            np.searchsorted(self.cumulative_lengths[1:], index, side="right")
        )
        local_idx = int(index - self.cumulative_lengths[subject_idx])

        subject_id = self.valid_subjects[subject_idx]
        return self.subject_loaders[subject_id][local_idx]

    def update_num_rays(self, num_rays):
        """Propagate a new rays-per-item count to every subject loader."""
        for loader in self.subject_loaders.values():
            loader.update_num_rays(num_rays)
