import os
import einops
import lightning
from lightning.pytorch.utilities.types import TRAIN_DATALOADERS
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
import mcubes
import trimesh
import torch.nn.functional as F
from tqdm import trange

from hypernet_core import Hypernet, SIREN_tar

from pdb import set_trace as bb
from time import time

from utils.examples.radiance_fields.mlp import MLPDensityField
from utils.model_tools import hook_fn_decorator, summary

import logging
from lpips import LPIPS

from third_party.LDMI.ldm.modules.diffusionmodules.model import Encoder
from third_party.LDMI.utils.geometry import make_coord_grid

from third_party.LDMI.ldm.data.data_converters.conversion import GridDataConverter

from third_party.LDMI.ldm.modules.distributions.distributions import (
    DiagonalGaussianDistribution,
)

from third_party.LDMI.ldm.modules.losses import LPIPSWithDiscriminator
from third_party.LDMI.ldm.modules.decoders.tokenizers.latent_tokenizer import LatentTokenizer

# Small constant used to keep logarithms away from zero.
eps = 1e-8


def img2mse(x, y):
    """Return the mean squared error between tensors ``x`` and ``y``."""
    return torch.mean((x - y) ** 2)


def img2l1(x, y):
    """Return the mean absolute error between tensors ``x`` and ``y``."""
    return torch.mean(torch.abs(x - y))


def mse2psnr(x):
    """Convert a mean-squared-error tensor ``x`` into ``-10 * log10(x)``.

    The base-10 logarithm is formed as ``ln(x) / ln(10)`` with the constant
    held in a one-element tensor on ``x``'s device, so the result broadcasts
    against that one-element tensor.
    """
    ten = torch.Tensor([10.0]).to(x.device)
    return -10.0 * torch.log(x) / torch.log(ten)


def psnr(x, y, eps=1e-8):
    """Return ``-10 * log10(mean((x - y) ** 2) + eps)``.

    ``eps`` is added to the mean squared error before the logarithm so that
    identical inputs give a finite value.
    """
    mean_sq_err = torch.mean(torch.square(x - y))
    return -10.0 * torch.log10(mean_sq_err + eps)


def create_siren_mlp(args):
    """
    Build the SIREN target network and the hypernetwork that wraps it.

    Every constructor argument is read from an attribute of ``args``; the
    keyword names below are the ones ``SIREN_tar`` and ``Hypernet`` expect.

    Returns:
        A ``(siren, model)`` tuple: the ``SIREN_tar`` instance and the
        ``Hypernet`` that was constructed with it as ``target_net``.
    """
    target_kwargs = dict(
        D=args.netdepth,
        W=args.netwidth,
        input_ch=args.input_ch,
        output_ch=args.output_ch,
        out_bias=args.out_bias,
        omega=args.omega,
    )
    siren = SIREN_tar(**target_kwargs)

    hyper_kwargs = dict(
        target_net=siren,
        ftask_dim=args.ftask_dim,
        weight_dim=args.weight_dim,
        batch_size=args.batch_size,
        deriv_hidden_dim=args.deriv_hidden_dim,
        driv_num_layers=args.driv_num_layers,
        codec_hidden_dim=args.codec_hidden_dim,
        codec_num_layers=args.codec_num_layers,
        # The config calls this ``num_ho``; Hypernet takes it as ``num_layers``.
        num_layers=args.num_ho,
        weight_split_dim=args.weight_split_dim,
        ftask_adapter_num_layers=args.ftask_adapter_num_layers,
        enable_each_layer_lr=args.enable_each_layer_lr,
        enable_weight_init=args.enable_weight_init,
        use_hyper_crossattn=args.use_hyper_crossattn,
        slice_dl_din=args.slice_dl_din,
    )
    model = Hypernet(**hyper_kwargs)

    return siren, model

class HoCelebAHQ(lightning.LightningModule):
    """Hypernetwork-based image autoencoder trained with manual optimization.

    Data flow (see ``forward``): ``voxel_encoder`` + ``quant_conv`` produce
    the moments of a diagonal Gaussian posterior; a latent drawn from it goes
    through ``post_quant_conv`` and ``tokenizer``; the resulting embedding is
    handed to ``hypernet.generate_weights`` and the hypernet is then called
    on a 2-D coordinate grid to produce the reconstruction.

    Training uses two Adam optimizers (autoencoder and discriminator) that
    are stepped by hand in ``training_step``.

    NOTE(review): many identifiers say "voxel" although the inputs handled
    here are read from ``batch["image"]`` and the grid is 2-D; the naming
    looks inherited from another model.
    """

    def __init__(self, args, data_module=None):
        """Construct all sub-modules.

        Args:
            args: Attribute-style config. Read here: ``voxel_resolution``
                (default 64), ``atten`` (default False), ``kl_weight``
                (default 1e-4), plus everything ``create_siren_mlp`` reads.
                Other methods also read ``lrate``, ``lrate_modulation``.
            data_module: Not used by this class; kept for call-site
                compatibility.
        """
        super().__init__()
        self.args = args
        self.voxel_resolution = getattr(args, "voxel_resolution", 64)
        self.target_net, self.hypernet = create_siren_mlp(args)
        # Both optimizers are stepped explicitly in training_step.
        self.automatic_optimization = False
        # Set by predict_voxels on every call; consumed by the loss.
        self.last_layer = None
        # Default sampling grid; non-persistent so it is not stored in checkpoints.
        self.register_buffer(
            'shared_coord',
            make_coord_grid((self.voxel_resolution, self.voxel_resolution), (-1, 1)),
            persistent=False,
        )

        # NOTE(review): ``resolution`` is fixed at 64 here even when
        # ``voxel_resolution`` is overridden via args — confirm this is intended.
        self.voxel_encoder = Encoder(
            double_z=True,
            z_channels=3,
            resolution=64,
            in_channels=3,
            out_ch=3,
            ch=64,
            ch_mult=[1, 2, 4],
            num_res_blocks=2,
            attn_resolutions=[],
            dropout=0.1,
        )

        self.image_key = "image"

        # 1x1 convs around the latent: 6 channels of moments in, 3-channel latent out.
        self.quant_conv = torch.nn.Conv2d(2 * 3, 2 * 3, 1)
        self.post_quant_conv = torch.nn.Conv2d(3, 3, 1)
        self.atten = getattr(args, "atten", False)
        self.tokenizer = LatentTokenizer(
            latent_dim=3,
            latent_size=16,
            patch_size=2,
            dim=192,
            n_head=4,
            head_dim=32,
            atten=self.atten,
        )
        self.kl_weight = getattr(args, "kl_weight", 1.0e-04)
        self.loss_fn = LPIPSWithDiscriminator(
            disc_start=1.0e+4,
            kl_weight=self.kl_weight,
            disc_weight=0.75,
            disc_num_layers=2,
            disc_dropout=0.2,
        )

    def on_fit_start(self):
        """Lightning hook: print/write the parameter summary once fitting starts."""
        self.get_model_summary()

    def get_model_summary(self):
        """Print a per-component parameter table and save it to the log dir.

        Runs on global rank 0 only. The table is always printed; it is
        written to ``<log_dir>/model_summary.txt`` only when the attached
        logger exposes a ``log_dir``.
        """
        if self.global_rank != 0:
            return

        # ``summary`` is treated as returning a numeric parameter count.
        model_summary_dict = {
            "main-hypernet": summary(self.hypernet),
            "main-target_net": summary(self.target_net),
            "voxel-encoder": summary(self.voxel_encoder),
        }

        def format_param_count(count):
            """Render a count as e.g. ``1.23M`` / ``4.56K`` / ``789``."""
            if count >= 1_000_000:
                return f"{count/1_000_000:.2f}M"
            if count >= 1_000:
                return f"{count/1_000:.2f}K"
            return f"{count}"

        # The total deliberately sums only these two entries (the target
        # net is listed in the table but not added to the total).
        main_components = ("main-hypernet", "voxel-encoder")
        total_params = sum(model_summary_dict[comp] for comp in main_components)

        rule = "-" * 50
        lines = [
            "Model Component Summary:",
            rule,
            f"{'Component':<25} | {'Parameter Count':>20}",
            rule,
        ]
        lines.extend(
            f"{component:<25} | {format_param_count(param_count):>20}"
            for component, param_count in model_summary_dict.items()
        )
        lines.append(rule)
        lines.append(
            f"{'Total (main components)':<25} | {format_param_count(total_params):>20}"
        )
        lines.append(rule)
        model_summary_str = "\n".join(lines) + "\n"

        print(model_summary_str)

        # No logger, or a logger without a directory: nothing to write to.
        log_dir = getattr(self.logger, "log_dir", None)
        if log_dir is None:
            return
        os.makedirs(log_dir, exist_ok=True)
        summary_path = os.path.join(log_dir, "model_summary.txt")
        with open(summary_path, "w", encoding="utf-8") as summary_file:
            summary_file.write(model_summary_str)

    def configure_optimizers(self):
        """Create the autoencoder and discriminator Adam optimizers.

        Returns:
            ``[optimizer_main, optimizer_disc]``. ``optimizer_main`` has two
            parameter groups: every parameter that is neither a target-net
            modulation parameter nor a discriminator parameter (lr
            ``args.lrate``), and the modulation parameters (lr
            ``args.lrate_modulation``). ``optimizer_disc`` holds the
            discriminator parameters (lr ``args.lrate``).
        """
        output_modulation = (
            list(self.target_net.modulation_weight_scale)
            + list(self.target_net.modulation_bias_scale)
            + list(self.target_net.modulation_bias_shift)
        )
        disc_params = list(self.loss_fn.discriminator.parameters())

        # Partition by identity so no parameter lands in two groups.
        excluded_ids = {id(p) for p in output_modulation} | {id(p) for p in disc_params}
        param_list = [p for p in self.parameters() if id(p) not in excluded_ids]

        optimizer_main = Adam(
            [
                {"params": param_list, "lr": self.args.lrate},
                {"params": output_modulation, "lr": self.args.lrate_modulation},
            ]
        )
        optimizer_disc = Adam(
            [
                {"params": disc_params, "lr": self.args.lrate},
            ]
        )
        return [optimizer_main, optimizer_disc]

    def encode(self, features, sample_posterior):
        """Encode ``features`` into a latent and its posterior.

        Args:
            features: Input batch passed straight to ``voxel_encoder``.
            sample_posterior: If truthy, draw a sample from the posterior;
                otherwise use its mode.

        Returns:
            ``(z, posterior)`` where ``posterior`` is the
            ``DiagonalGaussianDistribution`` built from the encoder moments.
        """
        h = self.voxel_encoder(features)

        # 1x1 conv over the encoder output before it is read as Gaussian moments.
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        z = posterior.sample() if sample_posterior else posterior.mode()
        return z, posterior

    def predict_voxels(self, latent_embedding, coordinates=None):
        """Generate target-net weights from the embedding and evaluate the grid.

        Side effect: stores the concatenated last-layer weight/bias tensor in
        ``self.last_layer`` for the loss.

        Args:
            latent_embedding: Embedding whose first dimension is the batch.
            coordinates: Optional grid. A 3-D ``(h, w, d)`` grid is repeated
                across the batch; anything else is passed through unchanged.
                Defaults to the ``shared_coord`` buffer.

        Returns:
            The hypernet output with its last axis moved to position 1
            (``permute(0, 3, 1, 2)``).
        """
        batch_size = latent_embedding.shape[0]
        if coordinates is None:
            coordinates = self.shared_coord
        if coordinates.dim() == 3:
            coordinates = einops.repeat(coordinates, 'h w d -> b h w d', b=batch_size)

        weights = self.hypernet.generate_weights(latent_embedding)
        # get_last_layer updates ``weights`` in place; only the joined tensor
        # is needed here.
        self.last_layer, _ = self.get_last_layer(weights)
        rgb = self.hypernet(coordinates)
        return rgb.permute(0, 3, 1, 2)

    def get_last_layer(self, weights):
        """Join the generated last-layer weight and bias into one tensor.

        The weight and bias are concatenated into ``wb`` and then split back
        out of ``wb`` and written into ``weights[-1]``, so the entries the
        hypernet holds are derived from ``wb`` in the autograd graph.

        Args:
            weights: Indexable collection of mappings keyed by
                ``layers.<idx>.weight`` / ``layers.<idx>.bias``. Entry 0 is
                read and entry -1 is written (NOTE(review): presumably the
                same mapping — confirm). The weight entry is 3-D and the bias
                entry 2-D, as required by the permutes below.

        Returns:
            ``(wb, weights)``: the concatenated tensor and the (mutated)
            input collection.
        """
        last_idx = self.target_net.get_last_layer_idx()
        weight = weights[0][f'layers.{last_idx}.weight'].permute(0, 2, 1)
        bias = weights[0][f'layers.{last_idx}.bias'].unsqueeze(-1).permute(0, 2, 1)
        wb = torch.cat([weight, bias], dim=1)
        # reverse the weight and bias for backprop
        weight_reverse, bias_reverse = wb.split([weight.shape[1], bias.shape[1]], dim=1)
        weight_reverse = weight_reverse.permute(0, 2, 1)
        bias_reverse = bias_reverse.permute(0, 2, 1).squeeze(-1)

        weights[-1][f'layers.{last_idx}.weight'] = weight_reverse
        weights[-1][f'layers.{last_idx}.bias'] = bias_reverse

        return wb, weights

    def voxel_loss(self, gt_voxels, pred_voxels, posterior, optimizer_idx, global_step, last_layer, split="train"):
        """Evaluate ``self.loss_fn`` (``LPIPSWithDiscriminator``).

        Args:
            gt_voxels: Ground-truth batch.
            pred_voxels: Reconstruction from ``forward``.
            posterior: Posterior returned by ``encode``.
            optimizer_idx: 0 for the autoencoder loss, 1 for the
                discriminator loss.
            global_step: Training step forwarded to the loss.
            last_layer: Tensor forwarded to the loss as ``last_layer``.
            split: Prefix the loss uses for its log keys.

        Returns:
            Whatever ``self.loss_fn`` returns; callers unpack it as
            ``(loss, log_dict)``.
        """
        # Forward the arguments the caller supplied rather than re-reading
        # them from ``self``.
        return self.loss_fn(
            gt_voxels,
            pred_voxels,
            posterior,
            optimizer_idx=optimizer_idx,
            global_step=global_step,
            last_layer=last_layer,
            split=split,
            cond=None,
        )

    def get_warmup_lr(self):
        """Return the learning-rate warm-up factor (constant 1.0)."""
        return 1.0

    def forward(self, input, sample_posterior=True):
        """Encode ``input`` and reconstruct it on the default grid.

        Returns:
            ``(pred_voxels, posterior)``.
        """
        # Called before every pass; NOTE(review): presumably clears state
        # cached by the previous weight generation — confirm in Hypernet.
        self.hypernet.cleanup()
        z, posterior = self.encode(input, sample_posterior)
        z = self.post_quant_conv(z)
        z = self.tokenizer(z)
        pred_voxels = self.predict_voxels(z)
        return pred_voxels, posterior

    def get_input(self, batch, k="image"):
        """Extract the model input from a batch.

        If ``batch[k]`` can be looked up, a 3-D entry gets a trailing axis
        and the result is permuted with ``(0, 3, 1, 2)`` (channels-last to
        channels-first) and cast to contiguous float. If the lookup itself
        fails (e.g. the batch is a sequence rather than a mapping),
        ``batch[0]`` is returned unchanged.
        """
        try:
            x = batch[k]
        except (KeyError, TypeError, IndexError):
            # Batch is not keyed by ``k``; fall back to its first element.
            return batch[0]

        if len(x.shape) == 3:
            x = x[..., None]
        return x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()

    def training_step(self, batch, batch_idx):
        """Run one manual-optimization step for both optimizers.

        The reconstruction is computed once; the autoencoder loss
        (``optimizer_idx=0``) is stepped first, then the discriminator loss
        (``optimizer_idx=1``) is evaluated on the same reconstruction and
        stepped.
        """
        optimizer_main, optimizer_disc = self.optimizers()
        voxels = self.get_input(batch, self.image_key)
        pred_voxels, posterior = self(voxels)

        # --- autoencoder update ---
        aeloss, log_dict_ae = self.voxel_loss(voxels, pred_voxels, posterior,
                                        optimizer_idx=0,
                                        global_step=self.global_step,
                                        last_layer=self.last_layer,
                                        split="train")
        # Zero immediately before backward so only this loss contributes to
        # the step.
        optimizer_main.zero_grad()
        self.manual_backward(aeloss)
        optimizer_main.step()
        self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=True)

        # --- discriminator update ---
        discloss, log_dict_disc = self.voxel_loss(voxels, pred_voxels, posterior,
                                        optimizer_idx=1,
                                        global_step=self.global_step,
                                        last_layer=self.last_layer,
                                        split="train")
        # NOTE(review): the autoencoder loss presumably runs the
        # discriminator forward, in which case its backward above also left
        # gradients on the discriminator parameters. Clearing them here keeps
        # them out of the discriminator step.
        optimizer_disc.zero_grad()
        self.manual_backward(discloss)
        optimizer_disc.step()
        self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True, sync_dist=True)
        self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False, sync_dist=True)

    def val_test_step(self, batch, batch_idx, mode, dataloader_idx=None):
        """Shared evaluation step for validation and test.

        NOTE(review): ``mode`` is currently unused — both validation and test
        metrics are computed with ``split="val"`` and therefore logged under
        ``val/...`` keys.

        Returns:
            The autoencoder log dictionary produced by the loss.
        """
        voxels = self.get_input(batch, self.image_key)

        pred_voxels, posterior = self(voxels)

        aeloss, log_dict_ae = self.voxel_loss(voxels, pred_voxels, posterior,
                                        optimizer_idx=0,
                                        global_step=self.global_step,
                                        last_layer=self.last_layer,
                                        split="val")
        self.log("val/rec_loss", log_dict_ae["val/rec_loss"], sync_dist=True)
        self.log_dict(log_dict_ae, sync_dist=True)

        if hasattr(self.loss_fn, 'discriminator'):
            discloss, log_dict_disc = self.voxel_loss(voxels, pred_voxels, posterior,
                                            optimizer_idx=1,
                                            global_step=self.global_step,
                                            last_layer=self.last_layer,
                                            split="val")
            self.log_dict(log_dict_disc, sync_dist=True)

        # Return the metrics themselves, not the ``self.log_dict`` method.
        return log_dict_ae

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Lightning validation hook; delegates to ``val_test_step``."""
        return self.val_test_step(
            batch, batch_idx, "val", dataloader_idx=dataloader_idx
        )

    def test_step(self, batch, batch_idx, dataloader_idx=None):
        """Lightning test hook; delegates to ``val_test_step``."""
        return self.val_test_step(
            batch, batch_idx, "test", dataloader_idx=dataloader_idx
        )

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, super_resolution=False, **kwargs):
        """Collect tensors for image logging.

        Args:
            batch: Batch to visualise.
            only_inputs: If true, return only the inputs.
            super_resolution: If true, also decode the same latent on a grid
                16x denser per side than ``voxel_resolution``.
            **kwargs: Accepted and ignored.

        Returns:
            Dict with ``"inputs"`` and, unless ``only_inputs``,
            ``"samples"`` (decoded from a standard-normal latent),
            ``"reconstructions"`` (decoded from the posterior mode) and
            optionally ``"super_reconstructions"``.
        """
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)

        log["inputs"] = x

        if only_inputs:
            return log

        # Deterministic reconstruction: use the posterior mode, not a sample.
        z_enc, posterior = self.encode(x, sample_posterior=False)
        z = self.post_quant_conv(z_enc)
        z = self.tokenizer(z)
        # Original size
        xrec = self.predict_voxels(z)

        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)

        # Unconditional sample: noise shaped like the encoder latent, pushed
        # through the same post-processing as a real latent.
        rand_latent = torch.randn_like(z_enc)
        rand_latent = self.post_quant_conv(rand_latent)
        rand_latent = self.tokenizer(rand_latent)
        log["samples"] = self.predict_voxels(rand_latent)
        log["reconstructions"] = xrec

        if super_resolution:
            # Super reconstruction
            res = 16 * self.voxel_resolution
            coord = make_coord_grid((res, res), (-1, 1)).to(x.device)
            xrec_super = self.predict_voxels(z, coordinates=coord)
            if x.shape[1] > 3:
                xrec_super = self.to_rgb(xrec_super)
            log["super_reconstructions"] = xrec_super

        return log

    def to_rgb(self, x):
        """Project a many-channel tensor to 3 channels and rescale to [-1, 1].

        Only valid when ``image_key`` is ``"segmentation"``. The random 1x1
        projection is created on first use and kept as the ``colorize``
        buffer. The rescale divides by ``x.max() - x.min()``, so a constant
        input yields non-finite values.
        """
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x