from dataclasses import dataclass, field

import glob
import os
import re
import cv2
import torch

import threestudio
from threestudio.systems.base import BaseLift3DSystem
from threestudio.utils.ops import binary_cross_entropy, dot
from threestudio.utils.typing import *

from threestudio.utils.base import update_end_if_possible


@threestudio.register("ours-system")
class OursSystem(BaseLift3DSystem):
    """Lift-3D system that also visualizes the guidance's per-view latent history.

    Training combines the guidance losses with the standard geometry
    regularizers. Validation/testing additionally decode the latents the
    guidance module keeps in ``hist_xt`` and ``hist_last_target`` (when
    present) and, at the end of each test pass, stitch all top-level images
    in the save directory into a progression video.
    """

    @dataclass
    class Config(BaseLift3DSystem.Config):
        # Number of queued views before a projection-averaging update.
        # NOTE(review): only referenced by the not-yet-implemented
        # projection-averaging TODO in ``training_step``; currently unused.
        project_every: Any = 1
        # Run the full test pass (per-view grids + videos) from validation
        # whenever ``true_global_step`` is a multiple of this value.
        run_test_every: Any = 1000

    cfg: Config

    def configure(self):
        # create geometry, material, background, renderer
        super().configure()
        # View indices waiting for a projection-averaging update
        # (see the TODO in ``training_step``); currently unused.
        self.projection_queue = []

    def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Render ``batch`` and return the renderer's output dict."""
        render_out = self.renderer(**batch)
        return {
            **render_out,
        }

    def on_fit_start(self) -> None:
        super().on_fit_start()
        # only used in training -- test-only runs never create these, so the
        # visualization code below must not assume ``self.guidance`` exists.
        self.prompt_processor = threestudio.find(self.cfg.prompt_processor_type)(
            self.cfg.prompt_processor
        )
        self.guidance = threestudio.find(self.cfg.guidance_type)(self.cfg.guidance)

    def training_step(self, batch, batch_idx):
        """Compute the total training loss for one batch.

        Every ``loss_*`` entry returned by the guidance is weighted by the
        matching ``cfg.loss.lambda_*`` value; the orientation, sparsity,
        opacity-binarization and (optional) z-variance regularizers are added
        on top.

        Returns:
            ``{"loss": total_loss}`` for Lightning to back-propagate.
        """
        # TODO(projection averaging): once ``cfg.project_every`` view indices
        # are queued in ``self.projection_queue``, render each queued view,
        # call ``self.guidance.update_dual_trajectory(...)`` on it and clear
        # the queue. Not implemented yet.

        out = self(batch)
        prompt_utils = self.prompt_processor()
        guidance_out = self.guidance(
            out["comp_rgb"], prompt_utils, **batch, rgb_as_latents=False
        )

        loss = 0.0

        for name, value in guidance_out.items():
            # Only scalar values (or non-tensors) are logged.
            if not (isinstance(value, torch.Tensor) and value.numel() > 1):
                self.log(f"train/{name}", value)
            if name.startswith("loss_"):
                loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")])

        if self.C(self.cfg.loss.lambda_orient) > 0:
            if "normal" not in out:
                raise ValueError(
                    "Normal is required for orientation loss, no normal is found in the output."
                )
            # Penalizes normals with a positive dot product against
            # ``t_dirs`` (presumably the ray directions, i.e. back-facing
            # normals). The denominator is clamped so a fully transparent
            # render yields 0 instead of 0/0 = NaN.
            loss_orient = (
                out["weights"].detach()
                * dot(out["normal"], out["t_dirs"]).clamp_min(0.0) ** 2
            ).sum() / (out["opacity"] > 0).sum().clamp_min(1)
            self.log("train/loss_orient", loss_orient)
            loss += loss_orient * self.C(self.cfg.loss.lambda_orient)

        # Smooth (always-differentiable) penalty on accumulated opacity.
        loss_sparsity = (out["opacity"] ** 2 + 0.01).sqrt().mean()
        self.log("train/loss_sparsity", loss_sparsity)
        loss += loss_sparsity * self.C(self.cfg.loss.lambda_sparsity)

        # Entropy of the opacity: pushes opacity towards 0 or 1.
        opacity_clamped = out["opacity"].clamp(1.0e-3, 1.0 - 1.0e-3)
        loss_opaque = binary_cross_entropy(opacity_clamped, opacity_clamped)
        self.log("train/loss_opaque", loss_opaque)
        loss += loss_opaque * self.C(self.cfg.loss.lambda_opaque)

        # z-variance loss proposed in HiFA: https://hifa-team.github.io/HiFA-site/
        if "z_variance" in out and "lambda_z_variance" in self.cfg.loss:
            selected = out["z_variance"][out["opacity"] > 0.5]
            # ``.mean()`` of an empty selection is NaN; fall back to the
            # (zero) sum so the loss stays finite when nothing is opaque.
            loss_z_variance = selected.mean() if selected.numel() > 0 else selected.sum()
            self.log("train/loss_z_variance", loss_z_variance)
            loss += loss_z_variance * self.C(self.cfg.loss.lambda_z_variance)

        for name, value in self.cfg.loss.items():
            self.log(f"train_params/{name}", self.C(value))

        return {"loss": loss}

    # NOTE: disabled per-batch dataset/guidance update hook, kept for reference.
    # def on_train_batch_end(self, outputs, batch, batch_idx):
    #     self.dataset = self.trainer.train_dataloader.dataset
    #     update_end_if_possible(
    #         self.dataset, self.true_current_epoch, self.true_global_step
    #     )
    #     self.do_update_step_end(self.true_current_epoch, self.true_global_step)
    #     self.guidance.update_batch()

    def _history_panels(self, view_idx: int) -> List[Dict[str, Any]]:
        """Return decoded ``[x_t, last_target]`` RGB panels for ``view_idx``.

        Returns an empty list when there is no guidance module (e.g. a
        test-only run) or it holds no history entry for this view yet, so
        visualization never crashes before the history is populated.
        """
        guidance = getattr(self, "guidance", None)
        if guidance is None:
            return []
        try:
            latents = (guidance.hist_xt[view_idx], guidance.hist_last_target[view_idx])
        except (AttributeError, KeyError, IndexError):
            return []
        return [
            {
                "type": "rgb",
                "img": guidance.decode_latents(latent)[0].permute(1, 2, 0),
                "kwargs": {"data_format": "HWC"},
            }
            for latent in latents
        ]

    def validation_step(self, batch, batch_idx):
        """Save a validation grid, and run ``test_step`` on test iterations.

        Validation sweeps the whole orbit. Every view is forwarded to
        ``test_step`` when ``true_global_step`` is a multiple of
        ``cfg.run_test_every``. Only the first view (index 0) is rendered
        into the validation grid, followed by guidance-history panels for
        three history entries (indices 0, len/3 and 2*len/3).
        """
        if self.true_global_step % self.cfg.run_test_every == 0:
            self.test_step(batch, batch_idx)

        # The validation grid itself is only produced for the first view.
        if batch["index"][0] != 0:
            return

        out = self(batch)

        hist_xt = getattr(getattr(self, "guidance", None), "hist_xt", None)
        num_entries = len(hist_xt) if hist_xt is not None else 0
        history_panels = []
        for view_idx in (0, num_entries // 3, num_entries * 2 // 3):
            history_panels += self._history_panels(view_idx)

        self.save_image_grid(
            f"it{self.true_global_step}-{batch['index'][0]}.png",
            [
                {
                    "type": "rgb",
                    "img": out["comp_rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            )
            + [
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                },
            ]
            + history_panels,
            name="validation_step",
            step=self.true_global_step,
        )

    def on_validation_epoch_end(self):
        # Mirror validation_step: finalize the test outputs on test iterations.
        if self.true_global_step % self.cfg.run_test_every == 0:
            self.on_test_epoch_end()

    def test_step(self, batch, batch_idx):
        """Save the test render of one view plus its guidance history.

        The render grid goes to ``it{step}-test/{index}.png``. When a history
        entry exists for this view, its decoded x_t and last target are saved
        to ``it{step}-test-views/{index}.png``.
        """
        out = self(batch)
        self.save_image_grid(
            f"it{self.true_global_step}-test/{batch['index'][0]}.png",
            [
                {
                    "type": "rgb",
                    "img": out["comp_rgb"][0],
                    "kwargs": {"data_format": "HWC"},
                },
            ]
            + (
                [
                    {
                        "type": "rgb",
                        "img": out["comp_normal"][0],
                        "kwargs": {"data_format": "HWC", "data_range": (0, 1)},
                    }
                ]
                if "comp_normal" in out
                else []
            )
            + [
                {
                    "type": "grayscale",
                    "img": out["opacity"][0, :, :, 0],
                    "kwargs": {"cmap": None, "data_range": (0, 1)},
                },
            ],
            name="test_step",
            step=self.true_global_step,
        )

        view_idx = batch["index"][0].item()
        history_panels = self._history_panels(view_idx)
        if history_panels:
            self.save_image_grid(
                f"it{self.true_global_step}-test-views/{view_idx}.png",
                history_panels,
                name="test_step",
                step=self.true_global_step,
            )

    def on_test_epoch_end(self):
        """Assemble the test videos and the overall progression video."""
        step = self.true_global_step
        self.save_img_sequence(
            f"it{step}-test",
            f"it{step}-test",
            r"(\d+)\.png",
            save_format="mp4",
            fps=10,
            name="test",
            step=step,
        )

        # ``test_step`` only writes this folder when guidance history exists
        # (not in test-only runs); listing a missing folder would raise.
        views_dir = f"it{step}-test-views"
        if os.path.isdir(os.path.join(self.get_save_dir(), views_dir)):
            self.save_img_sequence(
                views_dir,
                views_dir,
                r"(\d+)\.png",
                save_format="mp4",
                fps=10,
                name="test-views",  # distinct key so it doesn't collide with "test"
                step=step,
            )

        self.create_video_from_images(
            self.get_save_dir(),
            os.path.join(self.get_save_dir(), "progression_video.mp4"),
            fps=10,
        )

    def sorted_alphanumeric(self, data):
        """
        Sort function to sort the file names alphanumerically based on the number in the filename.

        Each string is split into digit / non-digit runs; digit runs compare
        numerically and the rest case-insensitively, so "it2" < "it10".
        """

        def convert(text):
            return int(text) if text.isdigit() else text.lower()

        def alphanum_key(key):
            return [convert(chunk) for chunk in re.split("([0-9]+)", key)]

        return sorted(data, key=alphanum_key)

    def create_video_from_images(self, image_folder, output_video_file, fps=1):
        """Write every ``*.png`` directly inside ``image_folder`` to a video.

        Frames are ordered with ``sorted_alphanumeric`` and stamped with their
        file name. The first readable image fixes the resolution; later images
        of a different size are resized, because ``cv2.VideoWriter`` silently
        drops mismatched frames. Unreadable images are skipped.

        Args:
            image_folder: folder to scan (non-recursive).
            output_video_file: destination video path.
            fps: frames per second of the output video.
        """
        images = self.sorted_alphanumeric(glob.glob(os.path.join(image_folder, "*.png")))
        if not images:
            print("No images found in the folder.")
            return

        writer = None
        try:
            for image in images:
                frame = cv2.imread(image)
                if frame is None:  # unreadable or partially written file
                    continue
                if writer is None:
                    height, width = frame.shape[:2]
                    # 'mp4v' selects the MPEG-4 Part 2 codec (not H.264).
                    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
                    writer = cv2.VideoWriter(output_video_file, fourcc, fps, (width, height))
                elif frame.shape[:2] != (height, width):
                    frame = cv2.resize(frame, (width, height))
                cv2.putText(
                    frame,
                    os.path.basename(image),
                    (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (255, 255, 255),
                    2,
                    cv2.LINE_AA,
                )
                writer.write(frame)
        finally:
            # Always finalize the container, even if a frame fails mid-way.
            if writer is not None:
                writer.release()