from dataclasses import dataclass, field

import glob
import os
import re
import cv2
import numpy as np
import torch
import torch.nn.functional as F

import torchvision.transforms.functional as TF
from transformers import pipeline
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
import requests

import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *

from threestudio.utils.base import update_end_if_possible


import torch

def add_channel_to_image(image):
    """
    Append one all-zero channel to a batch of channels-last images.

    Intended for turning a 2-channel image into something that can be
    displayed as RGB.

    Args:
    - image (torch.Tensor): A tensor of shape (B, H, W, 2)

    Returns:
    - torch.Tensor: A tensor of shape (B, H, W, 3); the first two channels are
      the input and the last one is zero. Dtype and device follow the input.
    """
    batch, height, width, _ = image.shape
    # new_zeros inherits dtype and device from the input tensor.
    padding = image.new_zeros((batch, height, width, 1))
    return torch.cat([image, padding], dim=-1)

def compute_hessian(image):
    """
    Compute a finite-difference 2x2 Hessian at every pixel of every channel.

    Args:
    - image (torch.Tensor): A channels-last tensor of shape (B, H, W, C).

    Returns:
    - torch.Tensor: A tensor of shape (B, H, W, C, 2, 2). Entry [0, 0] is the
      second difference along W, [1, 1] the second difference along H, and
      [0, 1] == [1, 0] the mixed central difference. The result has the same
      floating-point dtype and device as the input.
    """
    B, H, W, C = image.shape

    # Replicate the border by one pixel along H and W (the channel axis, which
    # comes first in the pad tuple, is left unpadded) so the output keeps the
    # input's spatial size.
    padded = torch.nn.functional.pad(image, (0, 0, 1, 1, 1, 1), mode='replicate')
    center = padded[:, 1:-1, 1:-1, :]

    # Second differences along W (x) and H (y), and the mixed central difference.
    dxx = padded[:, 1:-1, 2:, :] - 2 * center + padded[:, 1:-1, :-2, :]
    dyy = padded[:, 2:, 1:-1, :] - 2 * center + padded[:, :-2, 1:-1, :]
    dxy = (
        padded[:, 2:, 2:, :]
        - padded[:, 2:, :-2, :]
        - padded[:, :-2, 2:, :]
        + padded[:, :-2, :-2, :]
    ) / 4

    # Allocate with the derivatives' dtype: the default dtype would silently
    # down-cast float64 inputs (and up-cast half-precision ones) to float32.
    hessian = torch.zeros((B, H, W, C, 2, 2), dtype=dxy.dtype, device=image.device)
    hessian[..., 0, 0] = dxx
    hessian[..., 1, 1] = dyy
    hessian[..., 0, 1] = dxy
    hessian[..., 1, 0] = dxy

    return hessian

def compute_eigenvalues_vectorized(hessian_matrix):
    """
    Closed-form eigenvalues of a batch of 2x2 matrices.

    Args:
    - hessian_matrix (torch.Tensor): A tensor of shape (..., 2, 2).

    Returns:
    - torch.Tensor: A tensor of shape (..., 2) holding
      (trace + sqrt(disc)) / 2 and (trace - sqrt(disc)) / 2, in that order,
      where disc = trace**2 - 4 * det.

    The square root is taken only where disc > 0. Where disc <= 0 (repeated
    eigenvalues, round-off on a symmetric matrix, or a non-symmetric matrix
    with complex eigenvalues) the root is treated as 0, so both outputs equal
    trace / 2. This keeps values and gradients finite: a plain torch.sqrt
    yields NaN for negative input and a NaN gradient at exactly zero, which
    occurs on flat regions where the Hessian vanishes.
    """
    a = hessian_matrix[..., 0, 0]
    b = hessian_matrix[..., 0, 1]
    c = hessian_matrix[..., 1, 0]
    d = hessian_matrix[..., 1, 1]

    trace = a + d
    determinant = a * d - b * c

    # Discriminant of the characteristic polynomial.
    discriminant = trace**2 - 4 * determinant

    # Masked square root: substitute 1 under the sqrt where the discriminant is
    # not positive so the backward pass never divides by zero, then zero out
    # the result at those positions.
    positive = discriminant > 0
    safe_discriminant = torch.where(positive, discriminant, torch.ones_like(discriminant))
    root = torch.where(positive, torch.sqrt(safe_discriminant), torch.zeros_like(discriminant))

    return torch.stack(((trace + root) / 2, (trace - root) / 2), dim=-1)

def get_img_eigenvalues(depth_matrix):
    """
    Return the per-pixel Hessian eigenvalues of the first channel of an image.

    Args:
    - depth_matrix (torch.Tensor): A channels-last tensor of shape (B, H, W, C).

    Returns:
    - torch.Tensor: A tensor of shape (B, H, W, 2) with the two eigenvalues
      computed for channel 0; other channels are discarded.
    """
    per_channel = compute_eigenvalues_vectorized(compute_hessian(depth_matrix))
    # per_channel is [B, H, W, C, 2]; keep channel 0 only.
    return per_channel[..., 0, :]


@threestudio.register("ours-inversion-system")
class OursInversionSystem(BaseLift3DSystem):
    """
    Lift-3D training system driven by a guidance model plus geometric regularisers.

    The training loss is the weighted sum of every ``loss_*`` entry returned by
    the guidance module and of the optional orientation, convexity, Hessian
    convexity, sparsity, opacity, z-variance and monocular-depth terms computed
    in ``training_step``. Validation renders a diagnostic image grid and
    periodically runs the test rendering.
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        # NOTE(review): not read anywhere in this class - presumably consumed elsewhere.
        project_every: Any = 1
        # The test rendering is triggered from validation whenever
        # true_global_step is a multiple of this value.
        run_test_every: Any = 1000
        # When > 0, the regulariser losses are multiplied by
        # guidance grad_norm / rescale_additional_losses (detached).
        rescale_additional_losses: float = -1
        # Take the square root of (opacity**2 + 0.01) in the sparsity loss.
        sqrt_sparcity_loss: bool = True
        # The sparsity loss is relu(mean - threshold); below the threshold it is zero.
        threshold_sparcity_loss: float = 0.0
        # Side length of the grid the normals / depth are downscaled to for the convexity losses.
        convexity_res: int = 8

    cfg: Config

    def configure(self):
        """Build the base components and, if a depth loss is configured, the depth estimator."""
        # create geometry, material, background, renderer
        super().configure()
        self.projection_queue = []

        if "lambda_depth" in self.cfg.loss:
            # Monocular depth estimator used by predict_normalized_depth.
            self.depth_pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Render the batch and return the renderer's output dictionary."""
        render_out = self.renderer(**batch)
        return {
            **render_out,
        }

    @torch.no_grad()
    def predict_normalized_depth(self, target_image, ref_depth):
        """
        Estimate a depth map for an image and normalise it by its maximum.

        Args:
            target_image: Batch of images indexed as [B, C, H, W]; only the
                first image is used. Values are multiplied by 255 and cast to
                uint8, so they are presumably expected in [0, 1] - confirm
                against the guidance's ``decode_latents``.
            ref_depth: Tensor whose ``shape[-3:-1]`` gives the output spatial
                size and whose device receives the result. Its values are not
                used.

        Returns:
            A detached tensor of shape [1, H, W, C] (channels last) on
            ``ref_depth.device``, divided by its own maximum.
        """
        # CHW -> HWC numpy array for PIL.
        target_image = target_image[0].permute(1, 2, 0).cpu().numpy()
        target_image = Image.fromarray(np.uint8(target_image * 255))
        predicted_depth = self.depth_pipe(target_image)["depth"]
        predicted_depth = TF.to_tensor(predicted_depth).unsqueeze(0)
        predicted_depth = F.interpolate(
            predicted_depth,
            size=ref_depth.shape[-3:-1],
            mode='bilinear',
            align_corners=False,
        ).permute(0, 2, 3, 1)
        # Follow the reference tensor's device instead of assuming CUDA.
        predicted_depth = predicted_depth.to(ref_depth.device)

        # Only max-normalisation is applied; no min-max alignment to ref_depth.
        predicted_depth = predicted_depth / predicted_depth.max()
        return predicted_depth.detach()

    def on_fit_start(self) -> None:
        """Instantiate the prompt processor and guidance module (training only)."""
        super().on_fit_start()
        # only used in training
        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

    def training_step(self, batch, batch_idx):
        """
        Render the batch, query the guidance and accumulate the total loss.

        Guidance ``loss_*`` entries are weighted by the matching ``lambda_*``
        config value. All other terms are additionally multiplied by
        ``loss_scale`` (see ``Config.rescale_additional_losses``).

        Returns:
            ``{"loss": total_loss}``.
        """
        out = self(batch)
        prompt_utils = self.prompt_processor()
        guidance_out = self.guidance(
            out["comp_rgb"], prompt_utils, **batch, rgb_as_latents=False
        )

        loss = 0.0

        loss_scale = 1.
        if self.cfg.rescale_additional_losses > 0.:
            loss_scale = (guidance_out["grad_norm"] / self.cfg.rescale_additional_losses).detach()

        for name, value in guidance_out.items():
            # Only scalar-like values can be logged; skip multi-element tensors.
            if not (isinstance(value, torch.Tensor) and value.numel() > 1):
                self.log(f"train/{name}", value)
            if name.startswith("loss_"):
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])

        if self.C(self.cfg.loss.lambda_orient) > 0:
            if "normal" not in out:
                raise ValueError(
                    "Normal is required for orientation loss, no normal is found in the output."
                )
            loss_orient = (
                out["weights"].detach()
                * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
            ).sum() / (out["opacity"] > 0).sum()
            self.log("train/loss_orient", loss_orient)
            loss += loss_scale * loss_orient * self.C(self.cfg.loss.lambda_orient)

        if ("lambda_convex" in self.cfg.loss) and (self.C(self.cfg.loss.lambda_convex) > 1e-6):
            downscaled_norms = F.interpolate(
                out["comp_normal"].permute(0, 3, 1, 2),
                [self.cfg.convexity_res, self.cfg.convexity_res],
                mode='bilinear',
                align_corners=False,
            ).permute(0, 2, 3, 1)

            # Left-right: pair every column with its right-hand neighbour.
            right_normals = downscaled_norms[:, :, 1:, :]  # drop the first column
            left_normals = downscaled_norms[:, :, :-1, :]  # drop the last column

            h_cross_product = torch.cross(left_normals, right_normals, dim=-1)
            h_sine_of_angle = h_cross_product[..., 2]

            # Up-down: pair every row with the row below it.
            up_normals = downscaled_norms[:, :-1, :, :]
            down_normals = downscaled_norms[:, 1:, :, :]

            v_cross_product = torch.cross(down_normals, up_normals, dim=-1)
            v_sine_of_angle = v_cross_product[..., 2]

            # The loss is the negated mean z-component of both cross products.
            loss_convexity = - (h_sine_of_angle.mean() + v_sine_of_angle.mean())
            self.log("train/loss_convexity", loss_convexity)
            loss += loss_scale * loss_convexity * self.C(self.cfg.loss.lambda_convex)

        if ("lambda_convex_hess" in self.cfg.loss) and (self.C(self.cfg.loss.lambda_convex_hess) > 1e-6):
            downscaled_depth = F.interpolate(
                out["depth_d"].permute(0, 3, 1, 2),
                [self.cfg.convexity_res, self.cfg.convexity_res],
                mode='bilinear',
                align_corners=False,
            ).permute(0, 2, 3, 1)
            depth_eigenvalues = get_img_eigenvalues(downscaled_depth)  # B, H, W, 2

            # B, H, W, 1 - weights the eigenvalues by the rendered opacity.
            downscaled_opacity = F.interpolate(
                out["opacity"].permute(0, 3, 1, 2),
                [self.cfg.convexity_res, self.cfg.convexity_res],
                mode='bilinear',
                align_corners=False,
            ).permute(0, 2, 3, 1)
            internal_eigenvalues = depth_eigenvalues * downscaled_opacity

            loss_convex_hess = -internal_eigenvalues.mean()
            self.log("train/loss_convexity_hess", loss_convex_hess)
            loss += loss_scale * loss_convex_hess * self.C(self.cfg.loss.lambda_convex_hess)

        loss_sparsity_initial = (out["opacity"] ** 2 + 0.01)
        if self.cfg.sqrt_sparcity_loss:
            loss_sparsity_sqrt = loss_sparsity_initial.sqrt()
        else:
            loss_sparsity_sqrt = loss_sparsity_initial
        loss_sparsity = F.relu(loss_sparsity_sqrt.mean() - self.cfg.threshold_sparcity_loss)
        self.log("train/loss_sparsity", loss_sparsity)
        loss += loss_scale * loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        # Binary cross entropy of the (clamped) opacity against itself.
        opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
        self.log("train/loss_opaque", loss_opaque)
        loss += loss_scale * loss_opaque * self.C(self.cfg.loss.lambda_opaque)

        # z-variance loss proposed in HiFA: https://hifa-team.github.io/HiFA-site/
        if "z_variance" in out and "lambda_z_variance" in self.cfg.loss:
            loss_z_variance = out["z_variance"][out["opacity"] > 0.5].mean()
            self.log("train/loss_z_variance", loss_z_variance)
            loss += loss_scale * loss_z_variance * self.C(self.cfg.loss.lambda_z_variance)

        if "depth_z" in out and "lambda_depth" in self.cfg.loss:
            # Compare max-normalised rendered depth with the max-normalised
            # depth estimated from the guidance's decoded target image.
            rescaled_nerf_depth = out["depth_z"] / out["depth_z"].max()
            predicted_depth = self.predict_normalized_depth(
                self.guidance.decode_latents(guidance_out["target_x0"]), rescaled_nerf_depth
            )

            loss_depth = F.mse_loss(rescaled_nerf_depth, predicted_depth)
            self.log("train/loss_depth", loss_depth)
            loss += loss_scale * loss_depth * self.C(self.cfg.loss.lambda_depth)

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        self.log("train/total_loss", loss)

        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        """
        Save a diagnostic image grid for the first validation view.

        When ``true_global_step`` is a multiple of ``run_test_every`` the test
        rendering is run for every view first. The grid itself is produced
        only for the view whose ``batch['index'][0]`` is 0.
        """
        if self.true_global_step % self.cfg.run_test_every == 0:
            self.test_step(batch, batch_idx)

        if batch['index'][0] != 0:
            return  # We sample the whole orbit on validation, but if its not test - run only the first view

        out = self(batch)

        with torch.no_grad():
            pred_x0_latent, noisy_latent = self.guidance(
                out["comp_rgb"], self.prompt_processor(), **batch, rgb_as_latents=False, test_call=True
            )

        # Decode each latent once and reuse the result for every panel.
        decoded_x0 = self.guidance.decode_latents(pred_x0_latent)
        decoded_noisy = self.guidance.decode_latents(noisy_latent)

        # Stays None when no depth is rendered or no depth loss is configured,
        # in which case the predicted-depth panel is omitted.
        predicted_depth = None
        if "depth_z" in out and "lambda_depth" in self.cfg.loss:
            rescaled_nerf_depth = out["depth_z"] / out["depth_z"].max()
            predicted_depth = self.predict_normalized_depth(decoded_x0, rescaled_nerf_depth)

        panels = [
            {
                "type": "rgb",
                "img": out["comp_rgb"][0],
                "kwargs": {"data_format": "HWC"},
            },
        ]
        if "comp_normal" in out:
            panels.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )
        if "depth_d" in out:
            panels.append(
                {
                    "type": "grayscale",
                    "img": out["depth_d"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                }
            )
            # Opacity-weighted Hessian eigenvalues of the depth, shown in the
            # first two colour channels.
            panels.append(
                {
                    "type": "rgb",
                    "img": add_channel_to_image(get_img_eigenvalues(out["depth_d"]) * out["opacity"])[0],
                    "kwargs": {"data_format": "HWC", "data_range": (-.1, .1)},
                }
            )
        panels.append(
            {
                "type": "grayscale",
                "img": out["opacity"][0, :, :, 0],
                "kwargs": {"cmap": None, "data_range": (0, 1)},
            }
        )
        panels.append(
            {
                "type": "rgb",
                "img": decoded_noisy[0].permute(1, 2, 0),
                "kwargs": {"data_format": "HWC"},
            }
        )
        panels.append(
            {
                "type": "rgb",
                "img": decoded_x0[0].permute(1, 2, 0),
                "kwargs": {"data_format": "HWC"},
            }
        )
        # Difference between the guidance's prediction and the current render.
        panels.append(
            {
                "type": "rgb",
                "img": decoded_x0[0].permute(1, 2, 0) - out["comp_rgb"][0],
                "kwargs": {"data_format": "HWC"},
            }
        )
        if predicted_depth is not None:
            panels.append(
                {
                    "type": "grayscale",
                    "img": predicted_depth[0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                }
            )

        self.save_image_grid(
            f"it{self.true_global_step}-{batch['index'][0]}.png",
            panels,
            name="validation_step",
            step=self.true_global_step,
        )

    def on_validation_epoch_end(self):
        """Assemble the test video on the steps where validation also ran the test rendering."""
        if self.true_global_step % self.cfg.run_test_every == 0:
            self.on_test_epoch_end()

    def test_step(self, batch, batch_idx):
        """Render one view and save RGB, optional normals and opacity as an image grid."""
        out = self(batch)

        panels = [
            {
                "type": "rgb",
                "img": out["comp_rgb"][0],
                "kwargs": {"data_format": "HWC"},
            },
        ]
        if "comp_normal" in out:
            panels.append(
                {
                    "type": "rgb",
                    "img": out["comp_normal"][0],
                    "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                }
            )
        panels.append(
            {
                "type": "grayscale",
                "img": out["opacity"][0, :, :, 0],
                "kwargs": {"cmap": None, "data_range": (0, 1)},
            }
        )

        self.save_image_grid(
            f"it{self.true_global_step}-test/{batch['index'][0]}.png",
            panels,
            name="test_step",
            step=self.true_global_step,
        )

    def on_test_epoch_end(self):
        """Combine the images saved by ``test_step`` for the current step into an mp4."""
        self.save_img_sequence(
            f"it{self.true_global_step}-test",
            f"it{self.true_global_step}-test",
            r"(\d+)\.png",
            save_format="mp4",
            fps=10,
            name="test",
            step=self.true_global_step,
        )

    def sorted_alphanumeric(self, data):
        """
        Sort file names so that embedded numbers compare numerically.

        Each name is split into digit and non-digit runs; digit runs are
        compared as integers and the remaining text case-insensitively.
        """
        def alphanum_key(key):
            return [
                int(chunk) if chunk.isdigit() else chunk.lower()
                for chunk in re.split('([0-9]+)', key)
            ]

        return sorted(data, key=alphanum_key)

    def create_video_from_images(self, image_folder, output_video_file, fps=1):
        """
        Write every ``*.png`` in a folder, in alphanumeric order, to an mp4 file.

        The file name is drawn onto each frame. The frame size is taken from
        the first image. Images that cannot be read are skipped, and nothing
        is written if the first image is unreadable or the folder is empty.

        Args:
            image_folder: Directory searched (non-recursively) for PNG files.
            output_video_file: Path of the video to create.
            fps: Frames per second of the output video.
        """
        images = self.sorted_alphanumeric(glob.glob(os.path.join(image_folder, '*.png')))
        if not images:
            print("No images found in the folder.")
            return

        # cv2.imread returns None instead of raising when a file cannot be decoded.
        first_frame = cv2.imread(images[0])
        if first_frame is None:
            print(f"Could not read image: {images[0]}")
            return
        height, width = first_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # MPEG-4 Part 2 codec
        writer = cv2.VideoWriter(output_video_file, fourcc, fps, (width, height))

        try:
            for image in images:
                frame = cv2.imread(image)
                if frame is None:
                    continue
                filename = os.path.basename(image)
                cv2.putText(frame, filename, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
                writer.write(frame)
        finally:
            # Always finalise the container, even if a frame fails to process.
            writer.release()