import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
from packaging import version
from ldm.modules.ema import LitEma
from contextlib import contextmanager
import copy
from einops import rearrange
from typing import List

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model_joint import *

from ldm.util import instantiate_from_config


class VQDiGANJoint(pl.LightningModule):
    """VQ autoencoder whose joint decoder is conditioned on encoded b0/b1000 inputs.

    Unless ``inference_mode`` is set, the encoder, quantizer and ``decoder`` are restored
    from ``ckpt_path`` and ``decoder_joint`` is then initialised (non-strictly) from
    ``decoder``'s state dict. Only ``decoder_joint.joint`` and every
    ``decoder_joint.decoders[i].norm_out`` / ``conv_out`` are handed to the optimizer.

    Training/validation batches are dicts with six tensor components
    (``tensor_xx`` ... ``tensor_yz``), their ``tensor_mapping_index_*`` entries,
    four conditioning inputs (``b0``, ``b1000x``, ``b1000y``, ``b1000z``) and a ``mask``
    that is passed to the loss.
    """

    # Component suffixes, in the order the components are stacked along dim=1.
    _TENSOR_SUFFIXES = ("xx", "yy", "zz", "xy", "xz", "yz")
    # Conditioning inputs, in the order encode_cond() receives them.
    _COND_KEYS = ("b0", "b1000x", "b1000y", "b1000z")
    # Class indices given to the encoder for the b0 and b1000{x,y,z} inputs respectively.
    _B0_MAPPING_INDEX = 2
    _B1000_MAPPING_INDEX = 3

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 num_directions,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False, # tell vector quantizer to return indices as bhw
                 use_ema=False,
                 inference_mode=False
                 ):
        """Build the sub-modules and, unless ``inference_mode`` is set, load weights.

        Args:
            ddconfig: Keyword arguments for the encoder/decoders. Must contain
                ``num_classes_encoder``, ``num_classes_decoder`` and ``z_channels``.
                A shallow copy is used, so the caller's mapping is left untouched.
            lossconfig: Config passed to ``instantiate_from_config`` to build ``self.loss``.
            n_embed: Codebook size of the vector quantizer.
            embed_dim: Channel count of the quantized latent.
            num_directions: Conditioning features are repeated ``num_directions // 2``
                times along the batch dimension before decoding.
            ckpt_path: Checkpoint to restore; required when ``inference_mode`` is False.
            ignore_keys: State-dict key prefixes removed before loading ``ckpt_path``.
            image_key: Stored as ``self.image_key``.
            colorize_nlabels: If given (an int), registers a random ``colorize`` buffer.
            monitor: If given, stored as ``self.monitor``.
            batch_resize_range: Stored as ``self.batch_resize_range``.
            scheduler_config: Optional config for a LambdaLR schedule.
            lr_g_factor: Accepted for config compatibility; not used in this class.
            remap: Forwarded to the vector quantizer.
            sane_index_shape: Forwarded to the vector quantizer.
            use_ema: Keep an exponential moving average of the parameters.
            inference_mode: Skip checkpoint loading (a subclass may load weights itself).

        Raises:
            AssertionError: If ``inference_mode`` is False and ``ckpt_path`` is None.
        """
        super().__init__()
        ignore_keys = [] if ignore_keys is None else ignore_keys
        self.embed_dim = embed_dim # 3
        self.n_embed = n_embed # 8192
        self.image_key = image_key # 'image'
        self.num_directions = num_directions

        # pop() below would otherwise strip keys from the caller's config and make it
        # unusable for a second instantiation.
        ddconfig = dict(ddconfig)
        num_classes_encoder = ddconfig.pop("num_classes_encoder")
        num_classes_decoder = ddconfig.pop("num_classes_decoder")
        ddconfig["num_classes"] = num_classes_encoder
        self.encoder = MultiClass_DiEncoder(**ddconfig)
        ddconfig["num_classes"] = num_classes_decoder
        self.decoder = MultiClass_DiDecoderWithResidual(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv3d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv3d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        # Built with the decoder's num_classes (set above), like self.decoder.
        self.decoder_joint = MultiClass_DiDecoderWithResidual_Joint(**ddconfig)

        # The EMA is created after decoder_joint: LitEma appears to register shadows only for
        # the parameters that exist when it is constructed (see ldm/modules/ema.py), and
        # decoder_joint holds the only parameters this class optimizes.
        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        if not inference_mode:
            assert ckpt_path is not None, "ckpt_path is required"
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

            # Seed decoder_joint from the freshly restored per-class decoder.
            state_dict = self.decoder.state_dict()
            missing, unexpected = self.decoder_joint.load_state_dict(state_dict, strict=False)
            print(f"Restored from {ckpt_path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
            if len(missing) > 0:
                print(f"Missing Keys: {missing}")
                print(f"Unexpected Keys: {unexpected}")

        self.scheduler_config = scheduler_config

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights (no-op when ``use_ema`` is False).

        Args:
            context: Optional label; when given, the swap and restore are printed.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        """Non-strictly load ``path['state_dict']`` into this module.

        Keys starting with any prefix in ``ignore_keys`` are dropped first. Missing keys
        belonging to ``decoder_joint`` are expected (it is seeded separately) and are not
        reported.

        NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        prefixes = tuple(ignore_keys) if ignore_keys is not None else ()
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # startswith(tuple) tests every prefix at once, so a key matching several
            # prefixes is deleted exactly once (the old nested loop raised KeyError).
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        missing = [n for n in missing if 'decoder_joint' not in n]
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA after every training batch when enabled."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x, tensor_mapping_index, ret_feature=False):
        """Encode ``x``, project with ``quant_conv`` and vector-quantize.

        Returns:
            ``(quant, emb_loss, info)``, plus the encoder's feature list as a fourth
            element when ``ret_feature`` is True.
        """
        phi_list = None
        if ret_feature:
            h, phi_list = self.encoder(x, tensor_mapping_index, ret_feature)
        else:
            h = self.encoder(x, tensor_mapping_index)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        if ret_feature:
            return quant, emb_loss, info, phi_list
        return quant, emb_loss, info

    def encode_to_prequant(self, x, tensor_mapping_index=None):
        """Return the ``quant_conv`` projection of the encoding, before quantization.

        ``tensor_mapping_index`` is forwarded to the encoder when given; ``encode`` always
        passes it, so the encoder presumably requires it.
        """
        if tensor_mapping_index is None:
            h = self.encoder(x)
        else:
            h = self.encoder(x, tensor_mapping_index)
        return self.quant_conv(h)

    def encode_cond(self, b0, b1000x, b1000y, b1000z, b0_mapping_index, b1000x_mapping_index, b1000y_mapping_index, b1000z_mapping_index):
        """Encode the four conditioning inputs in one pass and split their features.

        The inputs are concatenated along dim 0 in the order b0, b1000x, b1000y, b1000z,
        so each encoder feature is sliced back into ``batch_size`` rows per input.

        Returns:
            dict with keys ``phi_b0_list``, ``phi_b1000x_list``, ``phi_b1000y_list`` and
            ``phi_b1000z_list``, each a list with one tensor per encoder feature.
        """
        batch_inputs = torch.cat([b0, b1000x, b1000y, b1000z], dim=0)
        batch_mapping_indices = torch.cat([b0_mapping_index, b1000x_mapping_index, b1000y_mapping_index, b1000z_mapping_index], dim=0)
        _, _, _, batch_phi_list = self.encode(batch_inputs, batch_mapping_indices, ret_feature=True)

        batch_size = b0.shape[0]
        return dict(
            phi_b0_list=[phi[:batch_size] for phi in batch_phi_list],
            phi_b1000x_list=[phi[batch_size:2 * batch_size] for phi in batch_phi_list],
            phi_b1000y_list=[phi[2 * batch_size:3 * batch_size] for phi in batch_phi_list],
            phi_b1000z_list=[phi[3 * batch_size:] for phi in batch_phi_list],
        )

    def decode(self, quant, tensor_mapping_index, cond_dict):
        """Apply ``post_quant_conv`` and run the conditioned joint decoder."""
        quant = self.post_quant_conv(quant)
        return self.decoder_joint(quant, y=tensor_mapping_index, cond_dict=cond_dict)

    def decode_code(self, code_b, tensor_mapping_index=None, cond_dict=None):
        """Decode codebook indices.

        ``decode`` needs a mapping index and conditioning features, so they are accepted
        here and forwarded (previously this always raised ``TypeError``).

        NOTE(review): taming's VectorQuantizer2 may not define ``embed_code`` (it exposes
        ``get_codebook_entry``) — verify before relying on this method.
        """
        quant_b = self.quantize.embed_code(code_b)
        return self.decode(quant_b, tensor_mapping_index, cond_dict)

    def forward(self, input, tensor_mapping_index, cond_dict, return_pred_indices=False):
        """Reconstruct ``input``, flattening its first two axes into the batch axis.

        Args:
            input: Tensor of rank 6 (``b a c d h w``), merged to ``(b a) c d h w``.
            tensor_mapping_index: Tensor of shape ``(b, a)``, merged to ``(b a)``.
            cond_dict: Conditioning features forwarded to the joint decoder.
            return_pred_indices: Also return the quantizer indices.

        Returns:
            ``(dec, diff, phi_list)`` or ``(dec, diff, ind, phi_list)``.
        """
        input = rearrange(input, 'b a c d h w -> (b a) c d h w')
        tensor_mapping_index = rearrange(tensor_mapping_index, 'b a -> (b a)')
        quant, diff, (_, _, ind), phi_list = self.encode(input, tensor_mapping_index, ret_feature=True)
        dec = self.decode(quant, tensor_mapping_index, cond_dict)
        if return_pred_indices:
            return dec, diff, ind, phi_list
        return dec, diff, phi_list

    def get_input(self, batch, k):
        """Return ``batch[k]``, appending a trailing axis to rank-3 tensors."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        return x

    def _get_tensor_inputs(self, batch):
        """Stack the six tensor components and their mapping indices along dim=1."""
        x = torch.stack([self.get_input(batch, f"tensor_{s}") for s in self._TENSOR_SUFFIXES], dim=1)
        x_tensor_mapping_index = torch.stack(
            [self.get_input(batch, f"tensor_mapping_index_{s}") for s in self._TENSOR_SUFFIXES], dim=1)
        return x, x_tensor_mapping_index

    def _get_cond_inputs(self, batch):
        """Return the (b0, b1000x, b1000y, b1000z) conditioning tensors of ``batch``."""
        return tuple(self.get_input(batch, k) for k in self._COND_KEYS)

    def _build_cond_dict(self, cond_inputs, x_tensor_mapping_index):
        """Encode the conditioning inputs and repeat every feature for the decoder batch.

        Each feature is repeated ``num_directions // 2`` times along dim 0 with
        ``repeat_interleave``, i.e. the rows of one sample stay adjacent.
        """
        template = x_tensor_mapping_index[:, 0]
        b0_mapping_index = torch.ones_like(template) * self._B0_MAPPING_INDEX
        b1000_mapping_index = torch.ones_like(template) * self._B1000_MAPPING_INDEX
        cond_dict = self.encode_cond(
            *cond_inputs,
            b0_mapping_index, b1000_mapping_index, b1000_mapping_index, b1000_mapping_index,
        )
        repeats = self.num_directions // 2
        return {
            name: [phi.repeat_interleave(repeats, dim=0) for phi in phis]
            for name, phis in cond_dict.items()
        }

    def training_step(self, batch, batch_idx):
        """Reconstruct all six components and return the autoencoder loss."""
        x, x_tensor_mapping_index = self._get_tensor_inputs(batch)
        cond_dict = self._build_cond_dict(self._get_cond_inputs(batch), x_tensor_mapping_index)
        xrec, _, _ = self(x, x_tensor_mapping_index, cond_dict)

        mask = self.get_input(batch, 'mask')
        aeloss, log_dict_ae = self.loss(x, xrec, mask, self.global_step, split="train")

        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Validate with the current weights and again inside ``ema_scope`` (``_ema`` suffix)."""
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log validation losses under the ``val{suffix}`` prefix.

        Returns:
            The loss dict passed to ``self.log_dict`` (previously the bound method
            ``self.log_dict`` itself was returned by mistake).
        """
        x, x_tensor_mapping_index = self._get_tensor_inputs(batch)
        cond_dict = self._build_cond_dict(self._get_cond_inputs(batch), x_tensor_mapping_index)
        xrec, _, _ = self(x, x_tensor_mapping_index, cond_dict)

        mask = self.get_input(batch, 'mask')
        aeloss, log_dict_ae = self.loss(x, xrec, mask, self.global_step, split="val"+suffix)

        rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
        self.log(f"val{suffix}/rec_loss", rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # rec_loss was already logged above; newer Lightning rejects logging a key twice.
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            del log_dict_ae[f"val{suffix}/rec_loss"]
        self.log_dict(log_dict_ae)
        return log_dict_ae

    def configure_optimizers(self):
        """Adam over the joint module and each decoder's output layers, plus optional LambdaLR.

        ``self.learning_rate`` is not assigned in this class; presumably the training
        script sets it before fitting — TODO confirm.
        """
        lr = self.learning_rate
        print("lr", lr)

        params = list(self.decoder_joint.joint.parameters())
        for decoder in self.decoder_joint.decoders:
            params += list(decoder.norm_out.parameters())
            params += list(decoder.conv_out.parameters())
        opt_ae = torch.optim.Adam(params, lr=lr, betas=(0.5, 0.9))

        if self.scheduler_config is None:
            return [opt_ae], []

        scheduler = instantiate_from_config(self.scheduler_config)
        print("Setting up LambdaLR scheduler...")
        scheduler = [
            {
                'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                'interval': 'step',
                'frequency': 1
            }
        ]
        return [opt_ae], scheduler

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return inputs (and, unless ``only_inputs``, reconstructions and conditions) for logging.

        ``inputs`` and ``reconstructions`` have axis 2 squeezed (a no-op unless it has size 1).
        """
        log = dict()
        x, x_tensor_mapping_index = self._get_tensor_inputs(batch)

        if only_inputs:
            # Squeezed like the full path below, so both paths return the same layout.
            log["inputs"] = x.squeeze(2)
            return log

        cond_inputs = self._get_cond_inputs(batch)
        cond_dict = self._build_cond_dict(cond_inputs, x_tensor_mapping_index)
        xrec, _, _ = self(x, x_tensor_mapping_index, cond_dict)

        log["inputs"] = x.squeeze(2)
        log["reconstructions"] = xrec.squeeze(2)
        # Key spelling kept as-is: downstream image loggers may depend on it.
        log["condtion"] = torch.cat(cond_inputs, dim=1)
        return log
        
    
class VQDiGANJointInterface(VQDiGANJoint):
    """Inference wrapper around ``VQDiGANJoint`` that loads one full checkpoint.

    The parent is built with ``inference_mode=True`` (no pretrained-decoder seeding), and
    all weights, including ``decoder_joint``, are then loaded from ``ckpt_path``. Encoding
    stops after ``quant_conv``; quantization happens in the decode methods.
    """

    def __init__(self, **kwargs):
        """Build the model from ``kwargs`` and load ``kwargs['ckpt_path']``.

        Raises:
            KeyError: If ``ckpt_path`` is not provided.
        """
        ckpt_path = kwargs.pop("ckpt_path")
        super().__init__(**kwargs, inference_mode=True)
        print(f"Loading checkpoint from {ckpt_path}")
        self.init_from_ckpt(ckpt_path, ignore_keys=[])

    def init_from_ckpt(self, path, ignore_keys=None):
        """Non-strictly load ``path['state_dict']``, dropping keys with an ignored prefix.

        Unlike the parent, missing ``decoder_joint`` keys are reported, because this
        checkpoint is expected to contain them.

        NOTE(review): ``torch.load`` unpickles the file; only load trusted checkpoints.
        """
        prefixes = tuple(ignore_keys) if ignore_keys is not None else ()
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd.keys()):
            # One test against all prefixes, so no key is deleted twice.
            if k.startswith(prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if len(missing) > 0:
            print(f"Missing Keys: {missing}")
            print(f"Unexpected Keys: {unexpected}")

    def encode(self, x, tensor_mapping_index, ret_feature=False):
        """Encode and project with ``quant_conv``; no quantization.

        Returns:
            ``h``, or ``(h, phi_list)`` when ``ret_feature`` is True.
        """
        phi_list = None
        if ret_feature:
            h, phi_list = self.encoder(x, tensor_mapping_index, ret_feature)
        else:
            h = self.encoder(x, tensor_mapping_index)
        h = self.quant_conv(h)
        if ret_feature:
            return h, phi_list
        return h

    def encode_cond(self, b_x, tensor_mapping_index):
        """Encode the four conditioning volumes stacked on dim 1 of ``b_x``.

        Args:
            b_x: Rank-6 tensor whose dim 1 has exactly four entries, in the order
                b0, b1000x, b1000y, b1000z (``unbind`` into four names enforces this).
            tensor_mapping_index: Per-sample tensor used only as a shape/dtype template
                for the mapping indices (2 for b0, 3 for the b1000 inputs).

        Returns:
            ``(h, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list)``, where
            ``h`` has shape ``(batch, 4, C, *spatial)``.
        """
        batch_size, num_directions, num_channels, height, width, depth = b_x.shape
        b0_x, b1000x_x, b1000y_x, b1000z_x = b_x.unbind(dim=1)
        batch_inputs = torch.cat([b0_x, b1000x_x, b1000y_x, b1000z_x], dim=0)
        b0_mapping_index = torch.ones_like(tensor_mapping_index) * 2
        b1000x_mapping_index = torch.ones_like(tensor_mapping_index) * 3
        b1000y_mapping_index = torch.ones_like(tensor_mapping_index) * 3
        b1000z_mapping_index = torch.ones_like(tensor_mapping_index) * 3
        batch_mapping_indices = torch.cat([b0_mapping_index, b1000x_mapping_index, b1000y_mapping_index, b1000z_mapping_index], dim=0)
        h, batch_phi_list = self.encode(batch_inputs, batch_mapping_indices, ret_feature=True)
        # Rows of h are direction-major (row = d * batch_size + b) because of the cat above,
        # so split as (direction, batch) and swap. A direct reshape to (batch, direction)
        # mixed samples whenever batch_size > 1.
        h = h.reshape(num_directions, batch_size, *h.shape[1:]).transpose(0, 1).contiguous()

        phi_b0_list = [phi[:batch_size] for phi in batch_phi_list]
        phi_b1000x_list = [phi[batch_size:2*batch_size] for phi in batch_phi_list]
        phi_b1000y_list = [phi[2*batch_size:3*batch_size] for phi in batch_phi_list]
        phi_b1000z_list = [phi[3*batch_size:] for phi in batch_phi_list]

        return h, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list

    def encode_multidirection(self, x, tensor_mapping_index, ret_feature=True, full=True):
        """Encode conditioning volumes (and, if ``full``, tensor components) to pre-quant latents.

        With ``full=True``, ``x[:, :4]`` are the conditioning volumes and ``x[:, 4:]`` the
        tensor components; the latents are concatenated on dim 1. With ``full=False``,
        ``x`` holds only the four conditioning volumes.

        Returns:
            ``z`` (rank 6), or ``(z, cond_dict)`` when ``ret_feature`` is True.
        """
        if full:
            b_x = x[:, :4]
            b_tensor_mapping_index = tensor_mapping_index[:, 0]
            b_z, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list = self.encode_cond(b_x, b_tensor_mapping_index)
            cond_dict = dict(
                phi_b0_list = phi_b0_list,
                phi_b1000x_list = phi_b1000x_list,
                phi_b1000y_list = phi_b1000y_list,
                phi_b1000z_list = phi_b1000z_list,
            )

            tensor_x = x[:, 4:]
            tensor_mapping_index = tensor_mapping_index[:, 4:]
            batch_size, num_directions, num_channels, height, width, depth = tensor_x.shape
            tensor_x = tensor_x.reshape(batch_size * num_directions, num_channels, height, width, depth)
            tensor_mapping_index = tensor_mapping_index.reshape(batch_size * num_directions)
            tensor_z = self.encoder(tensor_x, tensor_mapping_index)
            tensor_z = self.quant_conv(tensor_z)
            tensor_z = tensor_z.reshape(batch_size, num_directions, -1, *tensor_z.shape[-3:])

            z = torch.cat([b_z, tensor_z], dim=1)
            if ret_feature:
                return z, cond_dict
            return z
        else:
            tensor_mapping_index = tensor_mapping_index[:, 0]
            # encode_cond() already applies quant_conv (via encode()); applying it again
            # was redundant and failed on the rank-6 result.
            h, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list = self.encode_cond(x, tensor_mapping_index)
            cond_dict = dict(
                phi_b0_list = phi_b0_list,
                phi_b1000x_list = phi_b1000x_list,
                phi_b1000y_list = phi_b1000y_list,
                phi_b1000z_list = phi_b1000z_list,
            )
            if ret_feature:
                return h, cond_dict
            return h

    def decode(self, h, tensor_mapping_index, cond_dict, force_not_quantize=False):
        """Quantize (unless ``force_not_quantize``) and decode with the per-class ``decoder``."""
        # also go through quantization layer
        if not force_not_quantize:
            quant, _, _ = self.quantize(h)
        else:
            quant = h
        quant = self.post_quant_conv(quant)
        return self.decoder(quant, y=tensor_mapping_index, cond_dict=cond_dict)

    @staticmethod
    def _broadcast_cond_feature(phi, target_rows, batch_size, repeats):
        """Return ``phi`` with ``target_rows`` rows on dim 0 for the joint decoder.

        A single-row feature is broadcast with ``expand`` (no copy); a per-sample feature
        (``batch_size`` rows) is ``repeat_interleave``d ``repeats`` times, matching the
        layout used during training. Any other size falls back to ``expand``, which
        raises on a mismatch.
        """
        rows = phi.shape[0]
        if rows == target_rows:
            return phi
        if rows == 1:
            return phi.expand(target_rows, *phi.shape[1:])
        if rows == batch_size and batch_size * repeats == target_rows:
            return phi.repeat_interleave(repeats, dim=0)
        return phi.expand(target_rows, *phi.shape[1:])

    def decode_multidirection(self, h, tensor_mapping_index, cond_dict, force_not_quantize=False):
        """Decode a rank-6 latent ``(batch, directions, C, *spatial)`` with ``decoder_joint``.

        ``cond_dict`` is not modified; a resized copy of its features is passed to the
        decoder instead. Each feature gets ``batch * directions // 2`` rows.
        """
        batch_size, num_directions, num_channels, height, width, depth = h.shape
        h = h.reshape(batch_size * num_directions, num_channels, height, width, depth)
        tensor_mapping_index = tensor_mapping_index.reshape(batch_size * num_directions)
        target_rows = batch_size * num_directions // 2
        repeats = num_directions // 2
        cond_dict = {
            name: [self._broadcast_cond_feature(phi, target_rows, batch_size, repeats) for phi in phis]
            for name, phis in cond_dict.items()
        }

        if not force_not_quantize:
            quant, _, _ = self.quantize(h)
        else:
            quant = h

        quant = self.post_quant_conv(quant)
        return self.decoder_joint(quant, y=tensor_mapping_index, cond_dict=cond_dict)

    def enable_gradient_checkpointing(self):
        """Delegate gradient checkpointing to ``decoder_joint``."""
        self.decoder_joint.enable_gradient_checkpointing()

class FinetuneDecoderJoint(pl.callbacks.Callback):
    """
    A PyTorch Lightning callback that restricts training to the joint parts of ``decoder_joint``.

    It works in two steps plus an optional re-check:
    1.  **Freeze:** at the start of ``fit`` it sets ``requires_grad=False`` on every parameter
        of the LightningModule.
    2.  **Make trainable:** it then sets ``requires_grad=True`` on the parameters of
        ``decoder_joint.joint`` and on the ``norm_out`` / ``conv_out`` parameters of every
        module in ``decoder_joint.decoders``.
    3.  **Re-check:** at the start of epoch ``unfreeze_at_epoch`` it sets
        ``requires_grad=True`` on the same parameters again (defensive; a no-op when
        nothing else changed them).

    The selected parameters are the same ones ``VQDiGANJoint.configure_optimizers`` gives
    to the optimizer. The callback edits ``requires_grad`` directly instead of using
    Lightning's ``BaseFinetuning``.

    Args:
        unfreeze_at_epoch (int): Epoch at which the re-check runs. Must be non-negative.
            Defaults to 0.
        verbose (bool): If True, print progress messages. Defaults to False.

    Raises:
        ValueError: If ``unfreeze_at_epoch`` is negative.
    """
    def __init__(self, unfreeze_at_epoch: int = 0, verbose: bool = False):
        super().__init__()
        if unfreeze_at_epoch < 0:
            raise ValueError(f"unfreeze_at_epoch must be a non-negative integer, but got {unfreeze_at_epoch}")
        self.unfreeze_at_epoch = unfreeze_at_epoch
        self.verbose = verbose

    def _get_finetune_params(self, pl_module: 'pl.LightningModule') -> List[torch.Tensor]:
        """Collect the parameters to fine-tune, skipping absent (or None) sub-modules.

        Order: ``decoder_joint.joint`` first, then for each decoder its ``norm_out`` followed
        by its ``conv_out``. Empty tensors are dropped.
        """
        decoder_joint = getattr(pl_module, 'decoder_joint', None)
        if decoder_joint is None:
            return []

        params = []
        joint = getattr(decoder_joint, 'joint', None)
        if joint is not None:
            params.extend(joint.parameters())

        for decoder in getattr(decoder_joint, 'decoders', None) or ():
            for layer_name in ('norm_out', 'conv_out'):
                layer = getattr(decoder, layer_name, None)
                if layer is not None:
                    params.extend(layer.parameters())

        return [p for p in params if isinstance(p, torch.Tensor) and p.numel() > 0]

    def on_fit_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """Freeze every parameter, then re-enable gradients for the fine-tune set."""
        if self.verbose:
            print("✅ FinetuneDecoderJoint: Setting up finetuning...")

        for param in pl_module.parameters():
            param.requires_grad_(False)

        finetune_params = self._get_finetune_params(pl_module)
        for param in finetune_params:
            param.requires_grad_(True)

        if self.verbose:
            print(f"✅ FinetuneDecoderJoint: Frozen all parameters, made {len(finetune_params)} parameters trainable.")

    def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
        """At epoch ``unfreeze_at_epoch``, re-enable gradients for the fine-tune set.

        When ``verbose``, the printed count is the number of those parameters that were
        already trainable before this re-check.
        """
        if trainer.current_epoch != self.unfreeze_at_epoch:
            return

        finetune_params = self._get_finetune_params(pl_module)
        if self.verbose:
            trainable_count = sum(1 for param in finetune_params if param.requires_grad)
            print(f"✅ FinetuneDecoderJoint: At epoch {trainer.current_epoch}, {trainable_count} parameters are trainable.")

        for param in finetune_params:
            param.requires_grad_(True)
                
                
