import os
import einops
import lightning
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import mcubes
import trimesh
import torch.nn.functional as F
from tqdm import trange

from hypernet_core import Hypernet, SIREN_tar, DownsampleConv

from pdb import set_trace as bb
from time import time

from utils.examples.radiance_fields.mlp import MLPDensityField
from utils.model_tools import hook_fn_decorator, summary

import logging
from lpips import LPIPS

# Import the existing Encoder from LDMI (used below as the voxel encoder)
from third_party.LDMI.ldm.modules.diffusionmodules.model import Encoder
from third_party.LDMI.utils.geometry import make_coord_grid

# Import ERA5Converter from conversion.py
from third_party.LDMI.ldm.data.data_converters.conversion import ERA5Converter

# Import DiagonalGaussianDistribution for IPVAE-style encoding
from third_party.LDMI.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

# Import GaussianLoss, used below as the model's loss function
from third_party.LDMI.ldm.modules.losses import GaussianLoss
from collections import defaultdict
from third_party.LDMI.ldm.modules.decoders.tokenizers.latent_tokenizer import LatentTokenizer
eps = 1e-8


def img2mse(x, y):
    """Return the mean squared error between ``x`` and ``y``."""
    return torch.mean((x - y) ** 2)


def img2l1(x, y):
    """Return the mean absolute error between ``x`` and ``y``."""
    return torch.mean(torch.abs(x - y))


def mse2psnr(x):
    """Convert an MSE tensor to PSNR in dB, i.e. ``-10 * log10(x)``.

    The base-10 logarithm is formed as ``log(x) / log(10)`` with a one-element
    tensor holding 10, so the result broadcasts against shape ``(1,)``.
    """
    ten = torch.Tensor([10.0]).to(x.device)
    return -10.0 * torch.log(x) / torch.log(ten)


def psnr(x, y, eps=1e-8):
    """Return the PSNR in dB between ``x`` and ``y``.

    ``eps`` is added to the mean squared error before the logarithm so that
    identical inputs do not produce ``log10(0)``.
    """
    return -10.0 * torch.log10(torch.mean(torch.square(x - y)) + eps)


def create_siren_mlp(args):
    """
    Build the SIREN target network and the hypernetwork wrapping it.

    Every constructor argument is read from attributes of ``args``; the
    freshly built SIREN instance is handed to ``Hypernet`` as ``target_net``.

    Returns:
        tuple: ``(siren, model)`` - the target network and the hypernetwork.
    """
    target_kwargs = {
        "D": args.netdepth,
        "W": args.netwidth,
        "input_ch": args.input_ch,
        "output_ch": args.output_ch,
        "out_bias": args.out_bias,
        "omega": args.omega,
    }
    siren = SIREN_tar(**target_kwargs)

    hyper_kwargs = {
        "ftask_dim": args.ftask_dim,
        "weight_dim": args.weight_dim,
        "batch_size": args.batch_size,
        "deriv_hidden_dim": args.deriv_hidden_dim,
        "driv_num_layers": args.driv_num_layers,
        "codec_hidden_dim": args.codec_hidden_dim,
        "codec_num_layers": args.codec_num_layers,
        # The hypernet's ``num_layers`` comes from ``args.num_ho``.
        "num_layers": args.num_ho,
        "weight_split_dim": args.weight_split_dim,
        "ftask_adapter_num_layers": args.ftask_adapter_num_layers,
        "enable_each_layer_lr": args.enable_each_layer_lr,
        "enable_weight_init": args.enable_weight_init,
        "use_hyper_crossattn": args.use_hyper_crossattn,
        "slice_dl_din": args.slice_dl_din,
    }
    model = Hypernet(target_net=siren, **hyper_kwargs)

    return siren, model

class HoERA5(lightning.LightningModule):

    def __init__(self, args, data_module=None):
        """Assemble the encoder, hypernetwork, data converter, loss and tokenizer.

        Args:
            args: Configuration namespace. It is forwarded to
                ``create_siren_mlp`` and additionally read (with defaults) for
                ``voxel_resolution``, ``kl_weight`` and ``atten``.
            data_module: Accepted but not used inside this constructor.
        """
        super().__init__()
        self.args = args
        self.voxel_resolution = getattr(args, "voxel_resolution", [46,90])
        # Target SIREN and the hypernetwork that wraps it.
        self.target_net, self.hypernet = create_siren_mlp(args)
        # Key used by get_input() to pull the input out of a batch.
        self.image_key = "image"
        self.voxel_encoder = Encoder(
            double_z=True,
            z_channels=3,
            resolution=64,
            in_channels=1,
            out_ch=3,
            ch=32,
            ch_mult=[ 1,2,4 ],  # num_down = len(ch_mult)-1
            num_res_blocks=2,
            attn_resolutions=[ ],
        )
        # NOTE(review): data_shape is hard-coded here and does not follow
        # self.voxel_resolution - confirm whether the two should be tied.
        self.data_converter = ERA5Converter(
            data_shape=[46,90], normalize_features=False
        )
        # The coordinate buffer is only registered when the converter exposes
        # `coordinates`; forward() reads self.shared_coord unconditionally.
        if hasattr(self.data_converter, 'coordinates'):
            coords = self.data_converter.coordinates
            self.register_buffer('shared_coord', coords, persistent=False)
        
        # Add quantization layers for IPVAE-style encoding
        # 1x1 conv applied to the encoder output in encode().
        self.quant_conv = torch.nn.Conv2d(
            2 * 3, 2 * 3, 1
        )
        # 1x1 conv applied to the latent z in forward().
        self.post_quant_conv = torch.nn.Conv2d(
            3, 3, 1
        )
        self.kl_weight = getattr(args, "kl_weight", 1.0e-06)
        # self.automatic_optimization = False
        # Loss used by voxel_loss(); the class instantiated is GaussianLoss.
        self.loss_fn = GaussianLoss(kl_weight=self.kl_weight)
        self.atten = getattr(args, "atten", False)
        # Tokenizer applied to the latent after post_quant_conv in forward().
        # NOTE(review): latent_size=[11,22] presumably matches the encoder's
        # spatial output for a 46x90 input - TODO confirm.
        self.tokenizer = LatentTokenizer(latent_dim=3, latent_size=[11,22], patch_size=1, dim=104, n_head=4, head_dim=32,  atten=self.atten)

    def on_fit_start(self):
        self.get_model_summary()

    def get_model_summary(self):
        """Print a per-component parameter-count table and save it to disk.

        Only global rank 0 does any work. The table is printed to stdout and,
        when the attached logger exposes a non-``None`` ``log_dir``, also
        written to ``<log_dir>/model_summary.txt``. Without a usable log
        directory the file is skipped instead of raising.
        """
        if self.global_rank != 0:
            return

        # `summary` is used here as "parameter count of a module": its results
        # are summed and compared against numeric thresholds below.
        model_summary_dict = {
            "main-hypernet": summary(self.hypernet),
            "main-target_net": summary(self.target_net),
            "voxel-encoder": summary(self.voxel_encoder),
        }

        def format_param_count(count):
            if count >= 1_000_000:
                return f"{count/1_000_000:.2f}M"
            elif count >= 1_000:
                return f"{count/1_000:.2f}K"
            else:
                return f"{count}"

        rule = "-" * 50

        # The target net is listed but left out of the total.
        # NOTE(review): presumably because the hypernet already owns it - confirm.
        main_components = ["main-hypernet", "voxel-encoder"]
        total_params = sum(
            model_summary_dict[comp]
            for comp in main_components
            if comp in model_summary_dict
        )

        lines = [
            "Model Component Summary:",
            rule,
            f"{'Component':<25} | {'Parameter Count':>20}",
            rule,
        ]
        for component, param_count in model_summary_dict.items():
            param_str = format_param_count(param_count)
            lines.append(f"{component:<25} | {param_str:>20}")
        lines.append(rule)
        total_str = format_param_count(total_params)
        lines.append(f"{'Total (main components)':<25} | {total_str:>20}")
        lines.append(rule)
        model_summary_str = "\n".join(lines) + "\n"

        print(model_summary_str)

        # The logger may be absent, or may not define a log directory.
        log_dir = getattr(self.logger, "log_dir", None)
        if log_dir is None:
            return
        # Context manager guarantees the handle is flushed and closed.
        with open(os.path.join(log_dir, "model_summary.txt"), "w") as fh:
            fh.write(model_summary_str)
        
    def register_grad_hooks(self, module, name_prefix=""):
        param_dict = defaultdict(list)


        for name, param in module.named_parameters():
            if param.requires_grad:
                prefix = name.split('.')[0]  # 比如 opt_layers, init_weight_hypernet
                param_dict[prefix].append((name, param))

        # 每个模块只注册前5个参数
        for prefix, params in param_dict.items():
            for name, param in params[:5]:
                def make_hook(n):
                    return lambda grad: print(
                        f"{n}: grad mean={grad.mean():.4e}, std={grad.std():.4e}"
                    )
                hook_name = f"{name_prefix}.{name}"
                param.register_hook(make_hook(hook_name))


    def configure_optimizers(self):
        output_modulation = list(self.target_net.modulation_weight_scale) + list(self.target_net.modulation_bias_scale) + list(self.target_net.modulation_bias_shift)
        # output_modulation = list(self.target_net.modulation_weight_scale) + list(self.target_net.modulation_weight_shift) + list(self.target_net.modulation_bias)        
        id_set = set(id(p) for p in output_modulation)
        if hasattr(self.loss_fn, 'logvar'):
            id_set.add(id(self.loss_fn.logvar))
        param_list = list(p for p in self.parameters() if id(p) not in id_set)
        return Adam(
            [
                {
                    "params": param_list,
                    "lr": self.args.lrate,
                },
                {
                    "params": output_modulation,
                    "lr": self.args.lrate_modulation,
                },
            ]
        )

    def encode(self, features, sample_posterior):
        """Encode ``features`` into a latent and its posterior distribution.

        Args:
            features: Input passed to ``self.voxel_encoder``.
            sample_posterior: When truthy the latent is drawn with
                ``posterior.sample()``; otherwise ``posterior.mode()`` is used.

        Returns:
            tuple: ``(z, posterior)`` where ``posterior`` is the
            ``DiagonalGaussianDistribution`` built from the encoder moments.
        """
        # Encoder output goes through the 1x1 quantization conv (IPVAE-style)
        # to form the distribution moments.
        moments = self.quant_conv(self.voxel_encoder(features))
        posterior = DiagonalGaussianDistribution(moments)

        z = posterior.sample() if sample_posterior else posterior.mode()
        return z, posterior

    def apply_data_converter(self, x):
        if hasattr(self, "data_converter"):
            if hasattr(self.data_converter, "batch_to_coordinates_and_features"):
                _, features = self.data_converter.batch_to_coordinates_and_features(x)
            else:
                features = self.data_converter(x)
        return features

    def predict_voxels(self, latent_embedding, coordinates=None):
        """Predict voxel values using the target network and hypernet.

        Args:
            latent_embedding: Latent passed to ``hypernet.generate_weights``;
                its leading dimension is used as the batch size.
            coordinates: Query coordinates. Defaults to
                ``self.data_converter.coordinates``. A 3-dimensional tensor is
                given a leading batch axis and expanded to the batch size.

        Returns:
            The hypernet output rearranged from ``b h w c`` to ``b c h w``.
        """
        if coordinates is None:
            coordinates = self.data_converter.coordinates

        if coordinates.dim() == 3:
            batch = latent_embedding.shape[0]
            coordinates = coordinates.unsqueeze(0).expand(batch, -1, -1, -1)

        # The returned weights are not used here; the call is kept because the
        # hypernet is queried right after it.
        # NOTE(review): presumably generate_weights stores the weights on the
        # hypernet for that query - confirm.
        self.hypernet.generate_weights(latent_embedding)
        outputs = self.hypernet(coordinates)

        return einops.rearrange(outputs, "b h w c -> b c h w")

    def get_last_layer(self, weights):
        last_idx = self.target_net.get_last_layer_idx()
        weight = weights[0][f'layers.{last_idx}.weight'].permute(0, 2, 1)  
        bias = weights[0][f'layers.{last_idx}.bias'].unsqueeze(-1).permute(0, 2, 1)  
        wb = torch.cat([weight, bias], dim=1) 
        # reverse the weight and bias for backprop
        weight_reverse, bias_reverse = wb.split([weight.shape[1], bias.shape[1]], dim=1)
        weight_reverse = weight_reverse.permute(0, 2, 1)
        bias_reverse = bias_reverse.permute(0, 2, 1).squeeze(-1)

        weights[-1][f'layers.{last_idx}.weight'] = weight_reverse
        weights[-1][f'layers.{last_idx}.bias'] = bias_reverse
        # wb.retain_grad()
        # weight.retain_grad()
        # bias.retain_grad()
        return wb, weights

    def voxel_loss(self, gt_voxels, pred_voxels, posterior, split):
        """Use BernoulliLoss for voxel occupancy prediction"""
        # bb()
        return self.loss_fn(
            gt_voxels, 
            pred_voxels,
            posterior, 
            split
        )

    def get_warmup_lr(self):
        return 1.0

    def forward(self, input, sample_posterior=True):
        """Encode ``input`` and reconstruct it on the shared coordinate grid.

        Args:
            input: Features passed to ``encode``.
            sample_posterior: Forwarded to ``encode``; chooses between a
                sampled latent and the posterior mode.

        Returns:
            tuple: ``(pred_voxels, posterior)``.
        """
        # Clear hypernet state before a new pass.
        self.hypernet.cleanup()

        latent, posterior = self.encode(input, sample_posterior)
        tokens = self.tokenizer(self.post_quant_conv(latent))

        return self.predict_voxels(tokens, self.shared_coord), posterior
    
    def get_input(self, batch, k="image"):
        try:
            x = batch[k]
            if len(x.shape) == 3:
                x = x[..., None]
            x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        except:
            x = batch[0]
        # Old from ldm repo:
        """x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()"""
        return x
    
    def training_step(self, batch, batch_idx):
        """Run one training step and return the autoencoder loss.

        The batch input is converted with the data converter when one is
        attached. If the converter offers
        ``batch_to_coordinates_and_features``, the converted features also
        serve as the reconstruction target. The loss is logged as ``aeloss``
        and the loss function's log dict is logged per step.
        """
        data = self.get_input(batch, self.image_key)

        inputs = data
        if hasattr(self, 'data_converter'):
            inputs = self.apply_data_converter(data)
            if hasattr(self.data_converter, "batch_to_coordinates_and_features"):
                # Compare the prediction against the converted features.
                data = inputs

        pred_voxels, posterior = self(inputs)
        aeloss, log_dict_ae = self.voxel_loss(
            data, pred_voxels, posterior, split="train"
        )

        self.log(
            "aeloss",
            aeloss,
            prog_bar=True,
            logger=True,
            on_step=True,
            on_epoch=True,
            sync_dist=True,
        )
        self.log_dict(
            log_dict_ae,
            prog_bar=False,
            logger=True,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )

        return aeloss


    def val_test_step(self, batch, batch_idx, mode, dataloader_idx=None):
        """Shared evaluation step for validation and testing.

        Args:
            batch: Batch handed over by the dataloader.
            batch_idx: Index of the batch (unused).
            mode: Split name passed to the loss, ``"val"`` or ``"test"``.
            dataloader_idx: Index of the dataloader (unused).

        Returns:
            dict: The log dictionary produced by the loss function.
        """
        data = self.get_input(batch, self.image_key)

        if hasattr(self, 'data_converter'):
            inputs = self.apply_data_converter(data)
            if hasattr(self.data_converter, "batch_to_coordinates_and_features"):
                # Compare the prediction against the converted features.
                data = inputs
        else:
            inputs = data

        pred_voxels, posterior = self(inputs)

        aeloss, log_dict_ae = self.voxel_loss(data, pred_voxels, posterior, split=mode)

        # Key the reconstruction metric by the active split so that the test
        # split does not look up a "val/..." entry.
        # NOTE(review): assumes the loss prefixes its log keys with `split`,
        # as the "val/rec_loss" key implies - confirm.
        rec_key = f"{mode}/rec_loss"
        if rec_key in log_dict_ae:
            self.log(rec_key, log_dict_ae[rec_key], sync_dist=True)
        self.log_dict(log_dict_ae, sync_dist=True)

        # Return the metrics themselves, not the bound `self.log_dict` method.
        return log_dict_ae

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        return self.val_test_step(
            batch, batch_idx, "val", dataloader_idx=dataloader_idx
        )

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        return self.val_test_step(
            batch, batch_idx, "test", dataloader_idx=dataloader_idx
        )
    
    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, super_resolution=False, **kwargs):
        """Collect input and reconstruction tensors for logging.

        Args:
            batch: Batch handed over by the dataloader.
            only_inputs: When true, only ``"inputs"`` is returned.
            super_resolution: When true, additionally decode the latent on a
                coordinate grid with twice ``self.voxel_resolution`` per axis.
            **kwargs: Ignored.

        Returns:
            dict: ``"inputs"``, and unless ``only_inputs`` is set
            ``"reconstructions"`` and optionally ``"super_reconstructions"``.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)

        log["inputs"] = x

        if hasattr(self, 'data_converter'):
            # Same conversion path as the training / evaluation steps.
            x = self.apply_data_converter(x)

        if only_inputs:
            return log

        xrec, posterior = self(x)

        # Decide once, before `x` is replaced by its 3-channel projection, so
        # the super-resolution output is colorized consistently.
        colorize = x.shape[1] > 3
        if colorize:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)

        log["reconstructions"] = xrec

        if super_resolution:
            # Mirror forward(): reset the hypernet, rebuild the tokens from the
            # posterior (its mode, for a deterministic latent), then query the
            # hypernet on a denser grid.
            self.hypernet.cleanup()
            z = self.tokenizer(self.post_quant_conv(posterior.mode()))
            # NOTE(review): assumes make_coord_grid returns an (H, W, C) grid
            # in the same coordinate convention as data_converter.coordinates
            # - confirm before relying on this output.
            res = tuple(2 * int(r) for r in self.voxel_resolution)
            coord = make_coord_grid(res, (-1, 1)).to(x.device)
            xrec_super = self.predict_voxels(z, coord)
            if colorize:
                xrec_super = self.to_rgb(xrec_super)
            log["super_reconstructions"] = xrec_super

        return log


    def to_rgb(self, x):
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x