import os
import einops
import lightning
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import mcubes
import trimesh
import torch.nn.functional as F
from tqdm import trange

from hypernet_core import Hypernet, SIREN_tar, TokenCompressor, DownsampleConv, MLP

from pdb import set_trace as bb
from time import time

from utils.examples.radiance_fields.mlp import MLPDensityField
from utils.model_tools import hook_fn_decorator, summary

import logging
from lpips import LPIPS

# Import the Encoder and the coordinate-grid helper from LDMI
from third_party.LDMI.ldm.modules.diffusionmodules.model import Encoder
from third_party.LDMI.utils.geometry import make_coord_grid

# Import ERA5Converter from conversion.py
from third_party.LDMI.ldm.data.data_converters.conversion import ERA5Converter

# Import DiagonalGaussianDistribution for IPVAE-style encoding
from third_party.LDMI.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

# Import the VQ-LPIPS discriminator loss, the vector quantizer and the latent tokenizer
from third_party.LDMI.ldm.modules.losses.vqperceptual import VQLPIPSWithDiscriminator
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from third_party.LDMI.ldm.modules.decoders.tokenizers.latent_tokenizer import LatentTokenizer

eps = 1e-8


def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    return torch.mean((x - y) ** 2)


def img2l1(x, y):
    """Return the mean absolute error between tensors ``x`` and ``y``."""
    return torch.mean(torch.abs(x - y))


def mse2psnr(x):
    """Convert an MSE tensor ``x`` to PSNR in dB: ``-10 * ln(x) / ln(10)``.

    The divisor is a one-element tensor placed on ``x``'s device, so the
    result is broadcast against shape ``(1,)`` (a 0-dim input yields a
    one-element output).
    """
    ten = torch.Tensor([10.0]).to(x.device)
    return -10.0 * torch.log(x) / torch.log(ten)


def psnr(x, y, eps=1e-8):
    """Return the PSNR in dB between ``x`` and ``y``.

    ``eps`` is added to the mean squared error before the logarithm so that
    identical inputs give a finite value.
    """
    mean_sq_err = torch.mean(torch.square(x - y))
    return -10.0 * torch.log10(mean_sq_err + eps)


def create_siren_mlp(args):
    """
    Build the SIREN target network and the Hypernet that wraps it.

    Every hyper-parameter is read from an attribute of ``args``; the
    attribute names are the ones listed in the two keyword tables below.

    Returns:
        The tuple ``(siren, hypernet)``: the ``SIREN_tar`` instance and the
        ``Hypernet`` constructed with it as ``target_net``.
    """
    # Keyword arguments for the target (SIREN) network.
    target_kwargs = {
        "D": args.netdepth,
        "W": args.netwidth,
        "input_ch": args.input_ch,
        "output_ch": args.output_ch,
        "out_bias": args.out_bias,
        "omega": args.omega,
    }
    target_net = SIREN_tar(**target_kwargs)

    # Keyword arguments for the hypernetwork; keyword spellings (including
    # ``driv_num_layers``) are those of the Hypernet constructor.
    hyper_kwargs = {
        "ftask_dim": args.ftask_dim,
        "weight_dim": args.weight_dim,
        "batch_size": args.batch_size,
        "deriv_hidden_dim": args.deriv_hidden_dim,
        "driv_num_layers": args.driv_num_layers,
        "codec_hidden_dim": args.codec_hidden_dim,
        "codec_num_layers": args.codec_num_layers,
        "num_layers": args.num_ho,
        "ftask_adapter_num_layers": args.ftask_adapter_num_layers,
        "enable_each_layer_lr": args.enable_each_layer_lr,
        "enable_weight_init": args.enable_weight_init,
        "use_hyper_crossattn": args.use_hyper_crossattn,
        "slice_dl_din": args.slice_dl_din,
    }
    hypernet = Hypernet(target_net=target_net, **hyper_kwargs)

    return target_net, hypernet

class HoImageNet(lightning.LightningModule):
    """Image auto-encoder whose decoder is a hypernetwork-driven SIREN.

    An input image is encoded by a convolutional encoder, passed through a
    1x1 convolution and a vector quantizer, projected by a second 1x1
    convolution and turned into tokens by ``LatentTokenizer``. The tokens are
    given to the hypernet (``generate_weights``), and the hypernet is then
    evaluated on a coordinate grid to produce the reconstruction.

    Training uses manual optimization with two Adam optimizers: one for the
    model and one for the discriminator owned by the loss module.
    """

    def __init__(self, args, data_module=None):
        """Build all sub-modules.

        Args:
            args: configuration namespace. Besides the attributes consumed by
                ``create_siren_mlp``, the optional attributes
                ``voxel_resolution`` (default ``[256, 256]``), ``atten``
                (default ``False``) and ``freeze_encoder`` (default ``True``)
                are read here.
            data_module: accepted for interface compatibility; not used.
        """
        super().__init__()
        self.args = args
        self.voxel_resolution = getattr(args, "voxel_resolution", [256,256])
        self.n_embed = 8192
        # Two optimizers are stepped by hand in ``training_step``.
        self.automatic_optimization = False
        self.target_net, self.hypernet = create_siren_mlp(args)
        self.image_key = "image"
        # Default query coordinates over [-1, 1]; non-persistent, so the
        # buffer is not stored in checkpoints.
        self.register_buffer('shared_coord', make_coord_grid((self.voxel_resolution[0], self.voxel_resolution[1]), (-1, 1)), persistent=False)
        # NOTE(review): ``resolution`` here and ``latent_size`` of the
        # tokenizer below are hard-coded and do not follow
        # ``voxel_resolution`` -- confirm they stay consistent when the
        # resolution is changed.
        self.voxel_encoder = Encoder(
            double_z=False,
            z_channels=3,
            resolution=256,
            in_channels=3,
            out_ch=3,
            ch=128,
            ch_mult=[ 1,2,4 ],  # num_down = len(ch_mult)-1
            num_res_blocks=2,
            attn_resolutions=[],
            dropout=0.0
        )

        # 1x1 convolutions applied right before and right after the
        # vector quantizer.
        self.quant_conv = torch.nn.Conv2d(3, 3, 1)
        self.post_quant_conv = torch.nn.Conv2d(3, 3, 1)
        self.quantize = VectorQuantizer(self.n_embed, 3, beta=0.25, remap=None, sane_index_shape=False)
        # Reconstruction / perceptual / codebook loss with a discriminator.
        self.loss_fn = VQLPIPSWithDiscriminator(
            disc_conditional=False,
            disc_in_channels=3,
            disc_num_layers=3,
            disc_start=1,
            disc_weight=0.75,
            disc_factor=0.5,
            codebook_weight=1.0,
            n_classes=self.n_embed,
            perceptual_weight=0.1
        )
        self.atten = getattr(args, "atten", False)

        self.tokenizer = LatentTokenizer(latent_dim=3, latent_size=64, patch_size=4, dim=768, n_head=4, head_dim=32,  atten=self.atten)

        self.freeze_encoder = getattr(args, "freeze_encoder", True)
        if self.freeze_encoder:
            print("#########################[HoImageNet] Freezing voxel_encoder parameters.#############################")
            # Freeze the encoder together with both 1x1 convolutions and the
            # codebook.
            frozen_modules = (
                self.voxel_encoder,
                self.quant_conv,
                self.post_quant_conv,
                self.quantize,
            )
            for module in frozen_modules:
                for p in module.parameters():
                    p.requires_grad = False

    def on_fit_start(self):
        """Print and save the parameter summary when fitting starts."""
        self.get_model_summary()

    def get_model_summary(self):
        """Print a per-component parameter-count table (rank 0 only).

        The table is also written to ``model_summary.txt`` inside the
        logger's ``log_dir`` when a logger with a log directory is available;
        otherwise the file is skipped instead of raising.
        """
        if self.global_rank != 0:
            return

        model_summary_dict = {
            "main-hypernet": summary(self.hypernet),
            "main-target_net": summary(self.target_net),
            "voxel-encoder": summary(self.voxel_encoder),
        }

        def format_param_count(count):
            if count >= 1_000_000:
                return f"{count/1_000_000:.2f}M"
            elif count >= 1_000:
                return f"{count/1_000:.2f}K"
            else:
                return f"{count}"

        rule = "-" * 50 + "\n"
        model_summary_str = "Model Component Summary:\n"
        model_summary_str += rule
        model_summary_str += f"{'Component':<25} | {'Parameter Count':>20}\n"
        model_summary_str += rule

        # Only these components are added up for the total line.
        main_components = ["main-hypernet", "voxel-encoder"]
        total_params = sum(
            model_summary_dict[comp]
            for comp in main_components
            if comp in model_summary_dict
        )

        for component, param_count in model_summary_dict.items():
            param_str = format_param_count(param_count)
            model_summary_str += f"{component:<25} | {param_str:>20}\n"

        model_summary_str += rule
        total_str = format_param_count(total_params)
        model_summary_str += f"{'Total (main components)':<25} | {total_str:>20}\n"
        model_summary_str += rule

        print(model_summary_str)

        log_dir = getattr(self.logger, "log_dir", None)
        if log_dir is None:
            return
        os.makedirs(log_dir, exist_ok=True)
        # Context manager guarantees the handle is flushed and closed.
        with open(os.path.join(log_dir, "model_summary.txt"), "w") as f:
            f.write(model_summary_str)

    def configure_optimizers(self):
        """Create the two Adam optimizers used by ``training_step``.

        Returns:
            ``[optimizer_main, optimizer_disc]``. The main optimizer has two
            parameter groups: every parameter that is neither a target-net
            modulation parameter nor a discriminator parameter
            (``args.lrate``), and the target-net modulation parameters
            (``args.lrate_modulation``). The second optimizer holds the
            discriminator parameters (``args.disc_lr``).
        """
        output_modulation = (
            list(self.target_net.modulation_weight_scale)
            + list(self.target_net.modulation_bias_scale)
            + list(self.target_net.modulation_bias_shift)
        )
        disc_params = list(self.loss_fn.discriminator.parameters())

        # Compare by identity so each parameter lands in exactly one group.
        excluded_ids = {id(p) for p in output_modulation}
        excluded_ids.update(id(p) for p in disc_params)
        param_list = [p for p in self.parameters() if id(p) not in excluded_ids]

        optimizer_main = Adam(
            [
                {"params": param_list, "lr": self.args.lrate},
                {"params": output_modulation, "lr": self.args.lrate_modulation},
            ]
        )
        optimizer_disc = Adam(
            [
                {"params": disc_params, "lr": self.args.disc_lr},
            ]
        )
        return [optimizer_main, optimizer_disc]

    def register_grad_hooks(self, module, name_prefix=""):
        """Attach a debug hook printing gradient mean/std to each trainable parameter.

        Args:
            module: module whose ``named_parameters()`` are hooked.
            name_prefix: optional label prepended (with a dot) to each
                parameter name; when empty the bare parameter name is used.
        """
        def make_hook(n):
            # Factory so that every hook captures its own name.
            return lambda grad: print(
                f"{n}: grad mean={grad.mean():.4e}, std={grad.std():.4e}"
            )

        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue
            hook_name = f"{name_prefix}.{name}" if name_prefix else name
            param.register_hook(make_hook(hook_name))

    def encode(self, features, quantize=False):
        """Encode ``features`` and optionally vector-quantize the result.

        Args:
            features: input passed to ``self.voxel_encoder``.
            quantize: when true, run the vector quantizer on the projected
                encoding.

        Returns:
            The output of ``quant_conv`` when ``quantize`` is false;
            otherwise the triple ``(quant, emb_loss, info)`` returned by the
            quantizer.
        """
        h = self.voxel_encoder(features)
        # 1x1 projection applied before quantization.
        moments = self.quant_conv(h)
        if quantize:
            quant, emb_loss, info = self.quantize(moments)
            return quant, emb_loss, info
        return moments

    def apply_data_converter(self, x):
        """Run ``x`` through ``self.data_converter`` if one is attached.

        If the converter exposes ``batch_to_coordinates_and_features`` the
        features it returns are used; otherwise the converter is called
        directly. Without a converter ``x`` is returned unchanged (this
        previously raised ``UnboundLocalError``).
        """
        converter = getattr(self, "data_converter", None)
        if converter is None:
            return x
        if hasattr(converter, "batch_to_coordinates_and_features"):
            _, features = converter.batch_to_coordinates_and_features(x)
            return features
        return converter(x)

    def predict_voxels(self, latent_embedding, coordinates=None):
        """Generate target-net weights from ``latent_embedding`` and decode.

        Args:
            latent_embedding: input to ``hypernet.generate_weights``; its
                first dimension is taken as the batch size.
            coordinates: query coordinates. Defaults to the ``shared_coord``
                buffer; a 3-D tensor is repeated along a new leading batch
                axis.

        Returns:
            The hypernet output with its last axis moved to position 1.
            Also stores the concatenated last-layer weights in
            ``self.last_layer`` as a side effect.
        """
        B = latent_embedding.shape[0]
        if coordinates is None:
            coordinates = self.shared_coord
        if len(coordinates.shape) == 3:
            # Broadcast the shared grid to every sample of the batch.
            coordinates = einops.repeat(coordinates, 'h w d -> b h w d', b=B)

        # Generate weights using hypernet
        weights = self.hypernet.generate_weights(latent_embedding)
        # ``get_last_layer`` updates the weight dicts in place; the returned
        # list is the same object, so only ``last_layer`` is kept here.
        self.last_layer, weights = self.get_last_layer(weights)
        # NOTE(review): the forward call below takes no weights argument, so
        # the hypernet presumably keeps the generated weights internally --
        # confirm against Hypernet.
        rgb = self.hypernet(coordinates, global_step=self.global_step)
        # Move the last axis to position 1 (presumably channels-first).
        predicted_voxels = rgb.permute(0, 3, 1, 2)
        return predicted_voxels

    def get_last_layer(self, weights):
        """Fuse the last layer's weight and bias into one tensor.

        The weight (3-D) and the bias (2-D, given a trailing axis) of the
        target net's last layer are transposed on their last two axes and
        concatenated along axis 1 into ``wb``. ``wb`` is then split and
        transposed back, and the results are written into ``weights[-1]`` so
        that the values used downstream are derived from ``wb``.

        Args:
            weights: sequence of dicts keyed by ``layers.<idx>.weight`` /
                ``layers.<idx>.bias``.

        Returns:
            ``(wb, weights)`` where ``weights`` is the same (mutated) object.
        """
        last_idx = self.target_net.get_last_layer_idx()
        weight_key = f'layers.{last_idx}.weight'
        bias_key = f'layers.{last_idx}.bias'

        # NOTE(review): values are read from ``weights[0]`` but written to
        # ``weights[-1]``; these differ when more than one dict is present --
        # confirm this is intended.
        weight = weights[0][weight_key].permute(0, 2, 1)
        bias = weights[0][bias_key].unsqueeze(-1).permute(0, 2, 1)
        wb = torch.cat([weight, bias], dim=1)
        # reverse the weight and bias for backprop
        weight_reverse, bias_reverse = wb.split([weight.shape[1], bias.shape[1]], dim=1)
        weight_reverse = weight_reverse.permute(0, 2, 1)
        bias_reverse = bias_reverse.permute(0, 2, 1).squeeze(-1)

        weights[-1][weight_key] = weight_reverse
        weights[-1][bias_key] = bias_reverse
        return wb, weights

    def voxel_loss(self, codebook_loss, gt_voxels, pred_voxels, optimizer_idx, global_step, last_layer, cond=None, split="train", predicted_indices=None):
        """Evaluate ``self.loss_fn`` (``VQLPIPSWithDiscriminator``).

        Args:
            codebook_loss: quantizer loss forwarded to the loss module.
            gt_voxels: ground-truth input.
            pred_voxels: reconstruction.
            optimizer_idx: which branch of the loss module to evaluate;
                ``training_step`` passes 0 for the main-optimizer step and 1
                for the discriminator step.
            global_step: step counter forwarded to the loss module.
            last_layer: tensor forwarded as ``last_layer``.
            cond: accepted for interface compatibility; ``None`` is always
                forwarded because the loss is built with
                ``disc_conditional=False``.
            split: prefix/tag forwarded to the loss module.
            predicted_indices: codebook indices forwarded to the loss module.

        Returns:
            Whatever ``self.loss_fn`` returns; callers unpack it as
            ``(loss, log_dict)``.
        """
        return self.loss_fn(
            codebook_loss,
            gt_voxels,
            pred_voxels,
            optimizer_idx=optimizer_idx,
            # Honour the caller-supplied values instead of silently
            # substituting the instance attributes.
            global_step=global_step,
            last_layer=last_layer,
            split=split,
            cond=None,
            predicted_indices=predicted_indices
        )

    def get_warmup_lr(self):
        """Return the learning-rate warm-up factor (constant 1.0)."""
        return 1.0

    def forward(self, input, return_pred_indices=False):
        """Reconstruct ``input``.

        Args:
            input: batch given to ``encode``.
            return_pred_indices: also return the third element of the
                quantizer's info tuple.

        Returns:
            ``(pred_voxels, diff)`` or ``(pred_voxels, diff, ind)`` where
            ``diff`` is the quantizer loss.
        """
        self.hypernet.cleanup()
        quant, diff, (_,_,ind) = self.encode(input, quantize=True)

        quant = self.post_quant_conv(quant)

        quant = self.tokenizer(quant)
        pred_voxels = self.predict_voxels(quant)
        if return_pred_indices:
            return pred_voxels, diff, ind
        return pred_voxels, diff

    def get_input(self, batch, k):
        """Extract the input tensor from ``batch``.

        The primary path reads ``batch[k]``, adds a trailing axis to 3-D
        data and moves the last axis to position 1. If that fails for any
        ordinary reason (e.g. ``batch`` is a sequence rather than a mapping),
        ``batch[0]`` is used as-is, without permutation.

        Returns:
            A contiguous float tensor.
        """
        try:
            x = batch[k]
            if len(x.shape) == 3:
                x = x[..., None]
            x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        except Exception:
            # Best-effort fallback kept from the original, but without
            # swallowing KeyboardInterrupt / SystemExit.
            x = batch[0].to(memory_format=torch.contiguous_format).float()
        return x

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step.

        Steps alternate on the parity of ``self.global_step``: even values
        update the main optimizer, odd values update the discriminator
        optimizer.
        """
        optimizer_main, optimizer_disc = self.optimizers()
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)

        if self.global_step % 2 == 0:
            # train the auto-encoder / hypernet
            optimizer_main.zero_grad()

            aeloss, log_dict_ae = self.voxel_loss(qloss, x, xrec,
                                            optimizer_idx=0,
                                            global_step=self.global_step,
                                            last_layer=self.last_layer,
                                            split="train",
                                            predicted_indices=ind)
            self.manual_backward(aeloss)
            optimizer_main.step()
            self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=True)

        # train discriminator
        else:
            optimizer_disc.zero_grad()
            discloss, log_dict_disc = self.voxel_loss(qloss, x, xrec,
                                            optimizer_idx=1,
                                            global_step=self.global_step,
                                            last_layer=self.last_layer,
                                            split="train")
            self.manual_backward(discloss)
            optimizer_disc.step()
            self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=True)

    def val_test_step(self, batch, batch_idx, mode, suffix="", dataloader_idx=None):
        """Shared implementation of the validation and test steps.

        Evaluates the loss module for both ``optimizer_idx`` values, logs the
        auto-encoder loss under ``"{mode}{suffix}/aeloss"`` and logs both
        metric dictionaries.

        Returns:
            The merged metric dictionary (previously the bound
            ``self.log_dict`` method was returned by mistake).
        """
        x = self.get_input(batch, self.image_key)
        xrec, qloss, ind = self(x, return_pred_indices=True)
        split = mode + suffix

        aeloss, log_dict_ae = self.voxel_loss(qloss, x, xrec,
                                        optimizer_idx=0,
                                        global_step=self.global_step,
                                        last_layer=self.last_layer,
                                        split=split,
                                        predicted_indices=ind)
        discloss, log_dict_disc = self.voxel_loss(qloss, x, xrec,
                                        optimizer_idx=1,
                                        global_step=self.global_step,
                                        last_layer=self.last_layer,
                                        split=split,
                                        predicted_indices=ind)

        self.log(f"{mode}{suffix}/aeloss", aeloss,
                   prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae, sync_dist=True)
        self.log_dict(log_dict_disc, sync_dist=True)
        return {**log_dict_ae, **log_dict_disc}

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Validation step; metrics use the ``val`` mode prefix."""
        return self.val_test_step(
            batch, batch_idx, "val", dataloader_idx=dataloader_idx
        )

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Test step; metrics use the ``test`` mode prefix."""
        return self.val_test_step(
            batch, batch_idx, "test", dataloader_idx=dataloader_idx
        )

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Collect input and reconstruction images for logging.

        Args:
            batch: batch passed to ``get_input``.
            only_inputs: return just the inputs.
            plot_ema: additionally reconstruct inside ``self.ema_scope()``.
                NOTE(review): ``ema_scope`` is not defined in this class, so
                this option requires it to be provided elsewhere.

        Returns:
            Dict with ``"inputs"`` and, unless ``only_inputs``,
            ``"reconstructions"`` (clamped to [-1, 1]) and optionally
            ``"reconstructions_ema"``.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log

        # Keep the raw tensor: ``x`` may be replaced by its colourised
        # version below, which must not be fed back into the model.
        model_input = x
        multi_channel = model_input.shape[1] > 3

        xrec, _ = self(model_input)
        xrec = torch.clamp(xrec, -1, 1)
        if multi_channel:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        log["inputs"] = x
        log["reconstructions"] = xrec
        if plot_ema:
            with self.ema_scope():
                xrec_ema, _ = self(model_input)
                if multi_channel: xrec_ema = self.to_rgb(xrec_ema)
                log["reconstructions_ema"] = xrec_ema
        return log

    def to_rgb(self, x):
        """Project a multi-channel map to 3 channels and rescale to [-1, 1].

        A random 1x1 projection is created on first use and registered as the
        ``colorize`` buffer. Only valid when ``image_key`` is
        ``"segmentation"``.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        lo, hi = x.min(), x.max()
        # Guard against a zero range (constant input), which would give NaN.
        span = torch.clamp(hi - lo, min=1e-8)
        x = 2.*(x-lo)/span - 1.
        return x