import os
import einops
import lightning
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import mcubes
import trimesh
import torch.nn.functional as F
from tqdm import trange

from hypernet_core import Hypernet, SIREN_tar
from dataloader.shapenet_voxel_datamodule import ShapeNetVoxelDataModule
from third_party.LDMI.utils.geometry import make_coord_grid
from pdb import set_trace as bb
from time import time

from utils.examples.radiance_fields.mlp import MLPDensityField
from utils.model_tools import hook_fn_decorator, summary

import logging
from lpips import LPIPS

# Import existing Conv3DEncoder from LDMI
from third_party.LDMI.ldm.modules.encoders.conv3d_encoder import Conv3DEncoder


# Import DiagonalGaussianDistribution for IPVAE-style encoding
from third_party.LDMI.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

# Import BernoulliLoss for voxel occupancy prediction
from third_party.LDMI.ldm.modules.losses.occupancy_loss import BernoulliLoss
from third_party.LDMI.ldm.modules.decoders.tokenizers.latent_tokenizer import LatentTokenizer
eps = 1e-8


def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    diff = x - y
    return torch.mean(diff ** 2)


def img2l1(x, y):
    """Return the mean absolute error between tensors ``x`` and ``y``."""
    diff = x - y
    return torch.mean(torch.abs(diff))


def mse2psnr(x):
    """Convert an MSE tensor to PSNR in dB, i.e. ``-10 * log10(x)``.

    The change of base divides by a one-element tensor, so a 0-d input
    produces a result of shape ``(1,)``.
    """
    ln_ten = torch.log(torch.Tensor([10.0]).to(x.device))
    return -10.0 * torch.log(x) / ln_ten


def psnr(x, y, eps=1e-8):
    """Return the PSNR in dB between ``x`` and ``y``.

    ``eps`` is added to the MSE so identical inputs do not hit ``log10(0)``.
    """
    mse = torch.mean(torch.square(x - y))
    return -10.0 * torch.log10(mse + eps)


def create_siren_mlp(args):
    """Build the SIREN target network and the hypernetwork that drives it.

    Args:
        args: configuration namespace; every attribute read below must exist
            on it (an AttributeError is raised otherwise).

    Returns:
        ``(siren, model)``: the ``SIREN_tar`` instance and the ``Hypernet``
        constructed around it (the same ``siren`` object is passed in as
        ``target_net``).
    """
    siren_cfg = dict(
        D=args.netdepth,
        W=args.netwidth,
        input_ch=args.input_ch,
        output_ch=args.output_ch,
        out_bias=args.out_bias,
        omega=args.omega,
    )
    siren = SIREN_tar(**siren_cfg)

    hypernet_cfg = dict(
        ftask_dim=args.ftask_dim,
        weight_dim=args.weight_dim,
        batch_size=args.batch_size,
        deriv_hidden_dim=args.deriv_hidden_dim,
        driv_num_layers=args.driv_num_layers,
        codec_hidden_dim=args.codec_hidden_dim,
        codec_num_layers=args.codec_num_layers,
        # The config names this field ``num_ho``; Hypernet takes ``num_layers``.
        num_layers=args.num_ho,
        weight_split_dim=args.weight_split_dim,
        ftask_adapter_num_layers=args.ftask_adapter_num_layers,
        enable_each_layer_lr=args.enable_each_layer_lr,
        enable_weight_init=args.enable_weight_init,
        use_hyper_crossattn=args.use_hyper_crossattn,
        slice_dl_din=args.slice_dl_din,
    )
    hypernet = Hypernet(target_net=siren, **hypernet_cfg)

    return siren, hypernet

class HoShapeNet(lightning.LightningModule):
    """Voxel auto-encoder whose decoder is a hypernetwork-generated SIREN.

    ``forward`` performs these steps:

    * Encode a voxel batch with ``Conv3DEncoder`` into the moments of a
      ``DiagonalGaussianDistribution``.
    * Turn a posterior sample (or its mode) into tokens via
      ``post_quant_conv`` + ``LatentTokenizer``.
    * Let the ``Hypernet`` generate target-network weights from those tokens.
    * Evaluate the network on a coordinate grid over [-1, 1] to produce
      per-voxel outputs.

    Training and evaluation both use ``BernoulliLoss``.
    """

    def __init__(self, args, data_module=None):
        super().__init__()
        self.args = args
        self.voxel_resolution = getattr(args, "voxel_resolution", 32)
        self.target_net, self.hypernet = create_siren_mlp(args)

        # Query coordinates for the full voxel grid. Registered as a buffer so
        # it follows the module across devices; persistent=False keeps it out
        # of the state_dict (it is rebuilt on construction).
        res = self.voxel_resolution
        self.register_buffer(
            "shared_coord",
            make_coord_grid((res, res, res), (-1, 1)),
            persistent=False,
        )

        # Latent channel width shared by the encoder, the 1x1 convs and the
        # tokenizer; these must agree with each other.
        dim_z = 32

        # LDMI Conv3D encoder. It emits 2 * dim_z channels, which
        # DiagonalGaussianDistribution interprets as the posterior moments.
        self.voxel_encoder = Conv3DEncoder(
            dim_z=dim_z,
            base_channels=64,
            dropout=0.1,
        )

        # IPVAE-style 1x1 convs before and after the Gaussian bottleneck.
        self.quant_conv = torch.nn.Conv2d(2 * dim_z, 2 * dim_z, 1)
        self.post_quant_conv = torch.nn.Conv2d(dim_z, dim_z, 1)

        self.atten = getattr(args, "atten", False)
        self.tokenizer = LatentTokenizer(
            latent_dim=dim_z,
            latent_size=[8, 8],
            patch_size=1,
            dim=64,
            n_head=4,
            head_dim=48,
            atten=self.atten,
        )

        # Bernoulli occupancy loss; kl_weight scales the posterior KL term.
        kl_weight = getattr(args, "kl_weight", 0.001)
        self.loss_fn = BernoulliLoss(kl_weight=kl_weight)

    def on_fit_start(self):
        self.get_model_summary()

    def get_model_summary(self):
        """Print a parameter-count table and save it as ``model_summary.txt``.

        Only global rank 0 does any work. The file is written into the
        logger's ``log_dir``. Writing is skipped when no logger is attached or
        the logger exposes no ``log_dir``.
        """
        if self.global_rank != 0:
            return

        model_summary_dict = {
            "main-hypernet": summary(self.hypernet),
            "main-target_net": summary(self.target_net),
            "voxel-encoder": summary(self.voxel_encoder),
        }

        def format_param_count(count):
            if count >= 1_000_000:
                return f"{count/1_000_000:.2f}M"
            if count >= 1_000:
                return f"{count/1_000:.2f}K"
            return f"{count}"

        # target_net is listed but not added to the total. Presumably its
        # parameters are already counted inside the hypernet, which receives
        # it as ``target_net``. TODO confirm.
        main_components = ("main-hypernet", "voxel-encoder")
        total_params = sum(model_summary_dict[comp] for comp in main_components)

        rule = "-" * 50
        lines = [
            "Model Component Summary:",
            rule,
            f"{'Component':<25} | {'Parameter Count':>20}",
            rule,
        ]
        lines.extend(
            f"{component:<25} | {format_param_count(param_count):>20}"
            for component, param_count in model_summary_dict.items()
        )
        lines.extend(
            [
                rule,
                f"{'Total (main components)':<25} | {format_param_count(total_params):>20}",
                rule,
            ]
        )
        model_summary_str = "\n".join(lines) + "\n"

        print(model_summary_str)
        log_dir = getattr(self.logger, "log_dir", None)
        if log_dir is not None:
            with open(os.path.join(log_dir, "model_summary.txt"), "w") as f:
                f.write(model_summary_str)

    def configure_optimizers(self):
        """Create Adam with a dedicated learning rate for output modulation.

        Parameters in ``target_net.modulation_weight_scale``,
        ``modulation_bias_scale`` and ``modulation_bias_shift`` use
        ``args.lrate_modulation``. Every other parameter uses ``args.lrate``.
        """
        output_modulation = [
            *self.target_net.modulation_weight_scale,
            *self.target_net.modulation_bias_scale,
            *self.target_net.modulation_bias_shift,
        ]
        # Match by identity so a parameter is never placed in both groups.
        modulation_ids = {id(p) for p in output_modulation}
        param_list = [p for p in self.parameters() if id(p) not in modulation_ids]

        return Adam(
            [
                {"params": param_list, "lr": self.args.lrate},
                {"params": output_modulation, "lr": self.args.lrate_modulation},
            ]
        )

    def encode(self, features, sample_posterior):
        """Encode a voxel batch into a latent map and its Gaussian posterior.

        Args:
            features: input batch for ``voxel_encoder`` (presumably
                (B, 1, R, R, R) occupancy grids; TODO confirm).
            sample_posterior: draw ``z`` from the posterior if True, otherwise
                use its mode.

        Returns:
            ``(z, posterior)``.
        """
        # Encoder output is expected to be (B, 2*dim_z, H', W'), with the depth
        # axis averaged out, which is why the following conv is 2-D.
        h = self.voxel_encoder(features)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        z = posterior.sample() if sample_posterior else posterior.mode()
        return z, posterior

    def _tokenize_latent(self, z):
        """Map a posterior latent map to the tokens that condition the hypernet."""
        return self.tokenizer(self.post_quant_conv(z))

    def predict_voxels(self, latent_embedding, coordinates=None):
        """Generate target-net weights from tokens and evaluate them on a grid.

        Args:
            latent_embedding: tokens produced by the tokenizer; dim 0 is the
                batch size.
            coordinates: either a query grid shared by the whole batch, laid
                out as (h, w, l, d), or an already-batched (B, h, w, l, d)
                grid. Defaults to ``self.shared_coord``.

        Returns:
            Raw (pre-sigmoid) predictions of shape (B, 1, h, w, l).
            ``log_images`` applies ``torch.sigmoid`` to obtain occupancies.
        """
        if coordinates is None:
            coordinates = self.shared_coord
        batch_size = latent_embedding.shape[0]
        if coordinates.dim() == 4:
            coordinates = einops.repeat(
                coordinates, "h w l d -> b h w l d", b=batch_size
            )

        # Derive the output grid from the query grid so that custom (e.g.
        # super-resolution) grids reshape correctly. Non-grid layouts fall
        # back to the module's native resolution.
        if coordinates.dim() == 5:
            grid_shape = tuple(coordinates.shape[1:4])
        else:
            grid_shape = (self.voxel_resolution,) * 3

        self.hypernet.generate_weights(latent_embedding)
        densities = self.hypernet(coordinates)
        return densities.reshape(batch_size, 1, *grid_shape)

    def voxel_loss(self, gt_voxels, pred_voxels, posterior, split="train"):
        """Use BernoulliLoss for voxel occupancy prediction"""
        return self.loss_fn(gt_voxels, pred_voxels, posterior, split=split)

    def get_warmup_lr(self):
        return 1.0

    def forward(self, voxels, sample_posterior=True):
        """Reconstruct ``voxels``.

        Returns:
            ``(pred_voxels, posterior)``; ``pred_voxels`` is pre-sigmoid.
        """
        # Reset hypernet state before generating weights for this batch.
        self.hypernet.cleanup()
        z, posterior = self.encode(voxels, sample_posterior=sample_posterior)
        pred_voxels = self.predict_voxels(self._tokenize_latent(z))
        return pred_voxels, posterior

    def get_input(self, batch):
        """Return the voxel tensor, stored as the first element of ``batch``."""
        return batch[0]

    def training_step(self, batch, batch_idx):
        voxels = self.get_input(batch)
        pred_voxels, posterior = self(voxels)
        # NOTE: a weight-init warm-up loss gated on args.weight_init_steps
        # used to be wired in here; it is currently disabled.
        total_loss, log_dict = self.voxel_loss(
            voxels, pred_voxels, posterior, split="train"
        )
        # Also log the learning rate of the main (non-modulation) param group.
        log_dict["train/lr"] = self.optimizers().param_groups[0]["lr"]
        self.log("train/total_loss", total_loss, prog_bar=True, logger=True, on_step=True, on_epoch=False, sync_dist=True)
        self.log_dict(log_dict, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=True)
        return total_loss

    def val_test_step(self, batch, batch_idx, mode, dataloader_idx=None):
        """Shared body of validation/test; logs epoch-level metrics under ``mode``.

        Uses forward's default ``sample_posterior=True``, so these metrics
        are stochastic.
        """
        voxels = self.get_input(batch)
        pred_voxels, posterior = self(voxels)
        total_loss, log_dict = self.voxel_loss(
            voxels, pred_voxels, posterior, split=mode
        )
        self.log(f"{mode}/total_loss", total_loss, prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True, sync_dist=True)

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        return self.val_test_step(
            batch, batch_idx, "val", dataloader_idx=dataloader_idx
        )

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        return self.val_test_step(
            batch, batch_idx, "test", dataloader_idx=dataloader_idx
        )

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, super_reconstruction=False,
                   super_resolution=64, **kwargs):
        """Collect voxel grids for visualisation.

        Args:
            batch: dataloader batch; voxels are taken from ``batch[0]``.
            only_inputs: if True, return only the inputs and skip decoding.
            super_reconstruction: if True, also decode the same latent on a
                ``super_resolution``^3 grid.
            super_resolution: side length of the super-resolution grid.

        Returns:
            dict with ``"inputs"``, and unless ``only_inputs`` is set,
            ``"reconstructions"``, plus optionally ``"super_reconstructions"``.
            Decoded grids are sigmoid probabilities moved to CPU.
        """
        voxels = self.get_input(batch)
        log = {"inputs": voxels.cpu()}
        if only_inputs:
            return log

        pred_voxels, posterior = self(voxels, sample_posterior=False)
        log["reconstructions"] = torch.sigmoid(pred_voxels).cpu()

        if super_reconstruction:
            res = super_resolution
            # Match the device/dtype of the native grid.
            fine_grid = make_coord_grid((res, res, res), (-1, 1)).to(self.shared_coord)
            # posterior.mode() is the latent used above (sample_posterior=False);
            # decode it through the same cleanup -> tokenize -> predict path as forward.
            self.hypernet.cleanup()
            tokens = self._tokenize_latent(posterior.mode())
            super_voxels = self.predict_voxels(tokens, fine_grid)
            log["super_reconstructions"] = torch.sigmoid(super_voxels).cpu()

        return log