# Copyright (c) 2024 Boston Dynamics AI Institute LLC. All rights reserved.

from typing import Any, Optional, Union, Dict, Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import OmegaConf

from .backbones import build_backbone
from .feature_translators import build_feature_translator
from .utils import handle_feature_output
import lerobot.common.policies.vfmoe.models.vision_transformer as models
import copy
from .reconstruction_head import CNNDecoder

class RobotVisionFM(nn.Module):
    """Robot Vision Foundation Model (temporary name).

    A backbone network whose token features are mapped by a feature translator onto the
    feature spaces of one or more target foundation models.

    Attributes:
        backbone (str | nn.Module): backbone network. Defaults to "facebook/deit-small-patch16-224".
        pretrained (bool): whether to use pretrained weights. Defaults to False.
        translator (str | nn.Module): feature translator module. Defaults to "lconv".
        target_feature_sizes (Optional[dict[str, torch.Size | tuple[int, ...]]]):
            a dict to hold target feature sizes. The translator is only built when it is non-empty.
        translator_kwargs (Optional[dict[str, Any]]): other keyword arguments to the translator
            (an OmegaConf config or a plain mapping).
        target_loss_weights (Optional[dict[str, float]]):
            per-target weights to balance loss from different target models. If not specified, use even weights.
        checkpoint_path: (Optional[str]): filename of pretrained weights. It is NOT loaded by ``__init__``;
            pass it to ``load_pretrained_weights``.
        feature_reduce_method: (Optional[str]): how to reduce the feature in downstream applications.
        image_size (int): input image size forwarded to the backbone builder.
    """

    def __init__(
        self,
        backbone: Union[str, nn.Module] = "facebook/deit-small-patch16-224",
        pretrained: bool = False,
        translator: Union[str, nn.Module] = "lconv",
        target_feature_sizes: Optional[Dict[str, Union[torch.Size, Tuple[int, ...]]]] = None,
        translator_kwargs: Optional[Dict[str, Any]] = None,
        target_loss_weights: Optional[Dict[str, float]] = None,
        checkpoint_path: Optional[str] = None,
        feature_reduce_method: Optional[str] = None,
        image_size: int = 224,
        **kwargs: Any
    ) -> None:
        super().__init__()

        self.target_feature_sizes = target_feature_sizes
        self.preprocessor = None
        self.pretrained = pretrained

        # backbone; any extra kwargs are forwarded to the backbone builder
        self.image_size = image_size
        self.backbone: nn.Module = build_backbone(backbone, pretrained, image_size=image_size, **kwargs)
        self.final_spatial = getattr(self.backbone, "final_spatial", None)

        # handle output feature (feature reduce)
        self.feature_reduce_method = feature_reduce_method
        # the mere presence of a `no_cls` attribute marks a backbone without a CLS token
        self.no_cls = hasattr(self.backbone, "no_cls")
        self.num_reg_tokens = getattr(self.backbone, "num_reg_tokens", 0)

        # translator (only built when there are targets to translate to)
        backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True)
        if self.target_feature_sizes:
            if translator_kwargs is None:
                translator_kwargs = {}
            elif OmegaConf.is_config(translator_kwargs):
                translator_kwargs = OmegaConf.to_container(translator_kwargs)
            else:
                # plain mapping: copy so the caller's object is not mutated below
                translator_kwargs = dict(translator_kwargs)
            translator_kwargs["backbone_feature_size"] = backbone_feature_size
            translator_kwargs["target_feature_sizes"] = target_feature_sizes
            self.translator = build_feature_translator(translator, **translator_kwargs)

        # loss
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.SmoothL1Loss()
        self.cos_loss = nn.CosineEmbeddingLoss()
        # label "1" (= similar) for CosineEmbeddingLoss; repeated to the batch size in get_loss
        self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False)
        self.target_loss_weights = target_loss_weights

    def load_pretrained_weights(self, checkpoint_path: str):
        """Load pretrained weights, keeping only keys that exist in this model's state dict.

        Args:
            checkpoint_path (str): path to checkpoint / weight. A falsy value is a no-op.
        """
        if not checkpoint_path:
            return
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        weights_dict = torch.load(checkpoint_path, map_location="cpu")
        # build the key set once instead of once per checkpoint entry
        own_keys = self.state_dict().keys()
        pretrained_dict = {k: v for k, v in weights_dict.items() if k in own_keys}
        self.load_state_dict(pretrained_dict, strict=False)

    def freeze_translator(self) -> None:
        """Freeze the feature translator (disable gradients for all its parameters)."""
        for param in self.translator.parameters():
            param.requires_grad = False

    def forward_feature(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward RVFM feature only (before translators).

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            torch.Tensor: RVFM feature.
        """
        feature = self.backbone(x, **kwargs)
        # [B, 1+H*W+N, C] if including both CLS and register tokens.
        # [B, 1+H*W, C] for standard model (N=0).
        # [B, H*W, C] for model without CLS.
        return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens)

    def forward(self, x: torch.Tensor, target_model_names: Optional[List[str]] = None, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Forward pass of Robot Vision Foundation Model.

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            target_model_names (Optional[list[str]]): names of the target foundation models.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            dict[str, torch.Tensor]: features that match to each foundation model.
                Each feature is in [B, (H*W), C] or [B, C].
        """
        x = self.backbone(x, **kwargs)
        if self.num_reg_tokens > 0:
            # register tokens sit at the end of the token axis; drop them
            x = x[:, :-self.num_reg_tokens]  # [B, (1)+H*W, C]
        features = self.translator(x, target_model_names, backbone_no_cls=self.no_cls)  # each is [B, H*W, C] or [B, C]
        return features

    def _target_loss_weight(self, name: str, num_targets: int) -> Any:
        """Return the loss weight for target ``name``.

        Uses ``target_loss_weights[name]`` when weights are configured as a mapping, the scalar
        itself when a single number is configured, and an even ``1 / num_targets`` otherwise.
        """
        weights = self.target_loss_weights
        if not weights:
            return 1.0 / num_targets
        if isinstance(weights, (int, float)):
            return weights
        return weights[name]

    def get_loss(self, pred_features: Dict[str, torch.Tensor], y: Dict[str, torch.Tensor]) -> Dict[str, Any]:
        """Get loss terms given predictions and targets.

        Args:
            pred_features (dict[str, torch.Tensor]): predictions, keyed by target model name.
            y (dict[str, torch.Tensor]): targets; must contain every key of ``pred_features``.

        Returns:
            dict[str, Any]: weighted ``mse_loss`` / ``l1_loss``, evenly averaged ``cos_loss`` (tensors),
                and per-model python floats under ``*_losses_per_model``.
        """
        mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0
        mse_losses_per_model = {}
        cos_losses_per_model = {}
        l1_losses_per_model = {}
        num_targets = len(pred_features)

        for t in pred_features:
            pred = pred_features[t]
            target = y[t]
            weight = self._target_loss_weight(t, num_targets)

            mse_loss = self.mse_loss(pred, target)
            l1_loss = self.l1_loss(pred, target)

            # cos loss on per-sample flattened, L2-normalized features
            pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2)
            target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2)
            cos_labels = self.cos_target.repeat(pred.size(0)).to(pred.device)
            cos_loss = self.cos_loss(pred_norm, target_norm, cos_labels)

            mse_loss_avg += mse_loss * weight
            cos_loss_avg += cos_loss / num_targets  # balance cos by default for meaningful eval
            l1_loss_avg += l1_loss * weight

            mse_losses_per_model[t] = mse_loss.item()
            cos_losses_per_model[t] = cos_loss.item()
            l1_losses_per_model[t] = l1_loss.item()

        return {
            "mse_loss": mse_loss_avg,
            "cos_loss": cos_loss_avg,
            "l1_loss": l1_loss_avg,
            "mse_losses_per_model": mse_losses_per_model,
            "cos_losses_per_model": cos_losses_per_model,
            "l1_losses_per_model": l1_losses_per_model,
        }
        
        
        

class VisionFMMoE(nn.Module):
    """Vision Foundation Model with a task-conditioned Mixture-of-Experts (MoE) pool.

    The image is encoded by a backbone, passed through a stack of MoE blocks (``expert_pool``)
    conditioned on a task id, and decoded by the translator head of the corresponding target
    model. Each target model name is a "task"; ``task_map`` maps names to task ids.

    Attributes:
        backbone (str | nn.Module): backbone network. Defaults to "facebook/deit-small-patch16-224".
        pretrained (bool): whether to use pretrained weights. Defaults to False.
        translator (str | nn.Module): feature translator module. Defaults to "lconv".
        target_feature_sizes (Optional[dict[str, torch.Size | tuple[int, ...]]]):
            a dict to hold target feature sizes. The translator is only built when it is non-empty.
        translator_kwargs (Optional[dict[str, Any]]): other keyword arguments to the translator
            (an OmegaConf config or a plain mapping).
        target_loss_weights (Optional[dict[str, float]]):
            per-target weights to balance loss from different target models. If not specified, use even weights.
        checkpoint_path: (Optional[str]): filename of pretrained weights. It is NOT loaded by ``__init__``;
            pass it to ``load_pretrained_weights``.
        feature_reduce_method: (Optional[str]): how to reduce the feature in downstream applications.
        image_size (int): input image size, forwarded to the backbone builder and reconstruction heads.
        moe_cfg (DictConfig): OmegaConf config of one MoE block. ``moe_cfg.model`` names the block class
            in the ``vision_transformer`` module; all other keys except ``name`` are passed to it.
    """

    def __init__(
        self,
        backbone:  Union[str, nn.Module] = "facebook/deit-small-patch16-224",
        pretrained: bool = False,
        translator: Union[str, nn.Module] = "lconv",
        target_feature_sizes: Optional[Dict[str, Union[torch.Size, Tuple[int, ...]]]] = None,
        translator_kwargs: Optional[Dict[str, Any]] = None,
        target_loss_weights: Optional[Dict[str, float]] = None,
        checkpoint_path: Optional[str] = None,
        feature_reduce_method: Optional[str] = None,
        image_size: int = 224,
        moe_cfg: Any = None,
        **kwargs: Any
    ) -> None:
        super().__init__()
        if moe_cfg is None:
            # fail before building the (potentially expensive) backbone
            raise ValueError("VisionFMMoE requires `moe_cfg` to build the expert pool.")

        self.target_feature_sizes = target_feature_sizes
        self.preprocessor = None
        self.pretrained = pretrained

        # backbone; any extra kwargs are forwarded to the backbone builder
        self.image_size = image_size
        self.backbone: nn.Module = build_backbone(backbone, pretrained, image_size=image_size, **kwargs)
        self.final_spatial = getattr(self.backbone, "final_spatial", None)

        # handle output feature (feature reduce)
        self.feature_reduce_method = feature_reduce_method
        # the mere presence of a `no_cls` attribute marks a backbone without a CLS token
        self.no_cls = hasattr(self.backbone, "no_cls")
        self.num_reg_tokens = getattr(self.backbone, "num_reg_tokens", 0)

        # translator (only built when there are targets to translate to)
        backbone_feature_size = self.backbone.get_feature_size(keep_spatial=True)
        self.backbone_feature_size = backbone_feature_size
        if self.target_feature_sizes:
            if translator_kwargs is None:
                translator_kwargs = {}
            elif OmegaConf.is_config(translator_kwargs):
                translator_kwargs = OmegaConf.to_container(translator_kwargs)
            else:
                # plain mapping: copy so the caller's object is not mutated below
                translator_kwargs = dict(translator_kwargs)
            translator_kwargs["backbone_feature_size"] = backbone_feature_size
            translator_kwargs["target_feature_sizes"] = target_feature_sizes
            self.translator = build_feature_translator(translator, **translator_kwargs)

        # MoE pool: a stack of 3 task-conditioned blocks whose class is chosen by `moe_cfg.model`.
        block_cls = getattr(models, moe_cfg.model)
        block_cfg_dict = OmegaConf.to_container(moe_cfg)
        # `name` / `model` describe the config itself and are not block constructor arguments
        block_cfg_dict.pop('name', None)
        block_cfg_dict.pop('model', None)
        block_cfg_dict['dim'] = backbone_feature_size[0]
        block_cfg_dict['ffn_hidden_dim'] = int(backbone_feature_size[0] * 4.0)
        self.expert_pool = nn.ModuleList([block_cls(**block_cfg_dict) for _ in range(3)])

        # loss
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.SmoothL1Loss()
        self.cos_loss = nn.CosineEmbeddingLoss()
        # label "1" (= similar) for CosineEmbeddingLoss; repeated to the batch size in get_loss
        self.cos_target = torch.ones((1), dtype=torch.int, requires_grad=False)
        self.target_loss_weights = target_loss_weights

        # task map: target model name -> task id consumed by the MoE blocks
        self.task_map = {
            'google/vit-huge-patch14-224-in21k': 0,
            'facebook/dinov2-large': 1,
            'openai/clip-vit-large-patch14': 2,
        }

    def load_pretrained_weights(self, checkpoint_path: str):
        """Load pretrained weights, keeping only keys that exist in this model's state dict.

        Args:
            checkpoint_path (str): path to checkpoint / weight. A falsy value is a no-op.
        """
        if not checkpoint_path:
            return
        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        weights_dict = torch.load(checkpoint_path, map_location="cpu")
        # build the key set once instead of once per checkpoint entry
        own_keys = self.state_dict().keys()
        pretrained_dict = {k: v for k, v in weights_dict.items() if k in own_keys}
        self.load_state_dict(pretrained_dict, strict=False)

    def freeze_translator(self) -> None:
        """Freeze the feature translator (disable gradients for all its parameters)."""
        for param in self.translator.parameters():
            param.requires_grad = False

    def _get_task_id(self, task_name: str) -> int:
        """Return the task id registered for ``task_name``.

        Raises:
            ValueError: if ``task_name`` is not in ``task_map``.
        """
        if task_name not in self.task_map:
            raise ValueError(f"Task name {task_name} not found in task map.\nTask map: {self.task_map}")
        return self.task_map[task_name]

    def _run_experts(self, feature: torch.Tensor, task_id: int) -> Tuple[torch.Tensor, Any]:
        """Pass ``feature`` through every block of ``expert_pool`` for ``task_id``.

        Two block output conventions are supported:
        ``(feature, attn_loss, attn_probs, ffn_loss, ffn_probs)`` and ``(feature, aux_loss, probs)``.
        Auxiliary losses that are ``None`` are skipped.

        Returns:
            tuple: (output feature, sum of auxiliary losses; ``0`` if no block reported one).

        Raises:
            ValueError: if a block returns a tuple of any other length.
        """
        aux_total = 0
        for expert in self.expert_pool:
            # call each block exactly once; noisy gating makes repeated calls non-deterministic
            output = expert(feature, task_id)
            if len(output) == 5:
                feature, attn_loss, _, ffn_loss, _ = output
                block_losses = (attn_loss, ffn_loss)
            elif len(output) == 3:
                feature, aux_loss, _ = output
                block_losses = (aux_loss,)
            else:
                raise ValueError(f"Unexpected expert output of length {len(output)}; expected 3 or 5.")
            for block_loss in block_losses:
                if block_loss is not None:
                    aux_total = aux_total + block_loss
        return feature, aux_total

    def forward_backbone(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward the backbone only and return its raw token output (register tokens kept).

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            torch.Tensor: backbone feature.
                [B, 1+H*W+N, C] if including both CLS and register tokens.
                [B, 1+H*W, C] for standard model (N=0).
                [B, H*W, C] for model without CLS.
        """
        return self.backbone(x, **kwargs)

    def forward_backbone_patch(self, x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """Forward the backbone only and post-process with ``handle_feature_output``.

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            torch.Tensor: backbone feature with ``num_reg_tokens`` tokens discarded.
        """
        feature = self.backbone(x, **kwargs)
        # [B, 1+H*W+N, C] if including both CLS and register tokens.
        # [B, 1+H*W, C] for standard model (N=0).
        # [B, H*W, C] for model without CLS.
        return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens)

    def forward_expert(self, x: torch.Tensor, task_name: str, **kwargs: Any) -> Tuple[torch.Tensor, Any]:
        """Run already-computed backbone tokens ``x`` through the expert pool for ``task_name``.

        Args:
            x (torch.Tensor): backbone output tokens (e.g. from ``forward_backbone``).
            task_name (str): a key of ``task_map``.
            kwargs (Any): unused; accepted for call-site compatibility.

        Returns:
            tuple: (expert feature post-processed by ``handle_feature_output``, summed auxiliary loss).

        Raises:
            ValueError: if ``task_name`` is not registered.
        """
        task_id = self._get_task_id(task_name)
        feature, mi_loss = self._run_experts(x, task_id)
        return handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens), mi_loss

    def forward_feature(self, x: torch.Tensor, task_name: str, return_pre = False,**kwargs: Any) -> Union[torch.Tensor, Tuple[Any, Any]]:
        """Forward the backbone and expert pool for one task (before translators).

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            task_name (str): a key of ``task_map``.
            return_pre (bool): also return the backbone feature before the expert pool.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            The expert feature post-processed by ``handle_feature_output``; with ``return_pre``,
            a tuple ``(expert feature, backbone feature)`` processed the same way.

        Raises:
            ValueError: if ``task_name`` is not registered.
        """
        task_id = self._get_task_id(task_name)
        pre_feature = self.backbone(x, **kwargs)
        feature = pre_feature
        for expert in self.expert_pool:
            # only the feature (first element) is used here; auxiliary losses are discarded
            feature = expert(feature, task_id)[0]

        # [B, 1+H*W+N, C] if including both CLS and register tokens.
        # [B, 1+H*W, C] for standard model (N=0).
        # [B, H*W, C] for model without CLS.
        post_feature = handle_feature_output(feature, num_discard_tokens=self.num_reg_tokens)
        if return_pre:
            return post_feature, handle_feature_output(pre_feature, num_discard_tokens=self.num_reg_tokens)
        return post_feature

    def forward(self, x: torch.Tensor, target_model_names: Optional[List[str]] = None, reconstruction_enabled: bool = False, **kwargs: Any) -> Tuple[Dict[str, torch.Tensor], Any]:
        """Forward pass: backbone, then per-target expert pool and translator head.

        Args:
            x (torch.Tensor): input image. By default it accepts images
                in shape [B, H, W, C] or [B, C, H, W], pixel range [0,255], torch.uint8.
            target_model_names (list[str]): names of the target foundation models (keys of ``task_map``).
            reconstruction_enabled (bool): also decode the backbone tokens with the
                ``'reconstruction'`` translator head, stored under ``features['reconstruction']``.
            kwargs (Any): kwargs including mainly those for huggingface preprocessor:
                `do_resize` (bool) defaults to True.
                `interpolate_pos_encoding` (Optional[bool]) defaults to None.
                `do_rescale` (bool) defaults to True.
                `do_normalize` (bool) defaults to True.

        Returns:
            tuple: (dict of features per target model, each in [B, (H*W), C] or [B, C];
                auxiliary MoE loss summed over all targets and blocks).

        Raises:
            ValueError: if ``target_model_names`` is None or contains an unregistered name.
        """
        if target_model_names is None:
            raise ValueError("`target_model_names` must list the target models to compute features for.")

        x = self.backbone(x, **kwargs)
        if self.num_reg_tokens > 0:
            # register tokens sit at the end of the token axis; drop them
            x = x[:, :-self.num_reg_tokens]  # [B, (1)+H*W, C]

        features = {}
        mi_loss = 0
        for model_name in target_model_names:
            task_id = self._get_task_id(model_name)
            # every target starts from the same backbone tokens
            feature, aux_loss = self._run_experts(x, task_id)
            mi_loss = mi_loss + aux_loss
            features.update(self.translator(feature, [model_name], backbone_no_cls=self.no_cls))  # each is [B, H*W, C] or [B, C]

        if reconstruction_enabled:
            features['reconstruction'] = self.translator.translator_heads['reconstruction'](x)

        return features, mi_loss

    def _target_loss_weight(self, name: str, num_targets: int) -> Any:
        """Return the loss weight for target ``name``.

        Uses ``target_loss_weights[name]`` when weights are configured as a mapping, the scalar
        itself when a single number is configured, and an even ``1 / num_targets`` otherwise.
        """
        weights = self.target_loss_weights
        if not weights:
            return 1.0 / num_targets
        if isinstance(weights, (int, float)):
            return weights
        return weights[name]

    def _reconstruction_loss(self, pred_images: torch.Tensor, images: torch.Tensor) -> Dict[str, Any]:
        """Compute the image reconstruction loss terms.

        ``images`` is permuted with (0, 3, 1, 2) and divided by 255, so it is assumed to be
        channel-last [B, H, W, C] in the [0, 255] range (TODO confirm at call sites).
        """
        target_images = images.permute(0, 3, 1, 2) / 255.0

        rec_mse_loss = self.mse_loss(pred_images, target_images)
        rec_l1_loss = self.l1_loss(pred_images, target_images)
        rec_ssim_loss = ssim_loss(pred_images, target_images)
        rec_total_loss = 1.0 * rec_mse_loss + 0.5 * rec_l1_loss + 0.1 * rec_ssim_loss

        return {'rec/mse_loss': rec_mse_loss, 'rec/l1_loss': rec_l1_loss, 'rec/ssim_loss': rec_ssim_loss, 'rec/total_loss': rec_total_loss}

    def get_loss(self, pred_features: Dict[str, torch.Tensor], y: Dict[str, torch.Tensor], images: torch.Tensor, reconstruction_enabled: bool = False) -> Dict[str, Any]:
        """Get loss terms given predictions and targets.

        Args:
            pred_features (dict[str, torch.Tensor]): predictions; must contain every key of ``y``
                (plus ``'reconstruction'`` when ``reconstruction_enabled``).
            y (dict[str, torch.Tensor]): feature targets, keyed by target model name.
            images (torch.Tensor): input images; only used when ``reconstruction_enabled``.
            reconstruction_enabled (bool): add the ``rec/*`` reconstruction loss terms.

        Returns:
            dict[str, Any]: weighted ``mse_loss`` / ``l1_loss``, evenly averaged ``cos_loss`` (tensors),
                per-model python floats under ``*_losses_per_model``, and ``rec/*`` terms if enabled.
        """
        mse_loss_avg, cos_loss_avg, l1_loss_avg = 0, 0, 0
        mse_losses_per_model = {}
        cos_losses_per_model = {}
        l1_losses_per_model = {}
        # normalize by the number of feature targets; pred_features may also hold 'reconstruction'
        num_targets = len(y)

        for t in y:
            pred = pred_features[t]
            target = y[t]
            weight = self._target_loss_weight(t, num_targets)

            mse_loss = self.mse_loss(pred, target)
            l1_loss = self.l1_loss(pred, target)

            # cos loss on per-sample flattened, L2-normalized features
            pred_norm = F.normalize(pred.flatten(start_dim=1), dim=1, p=2)
            target_norm = F.normalize(target.flatten(start_dim=1), dim=1, p=2)
            cos_labels = self.cos_target.repeat(pred.size(0)).to(pred.device)
            cos_loss = self.cos_loss(pred_norm, target_norm, cos_labels)

            mse_loss_avg += mse_loss * weight
            cos_loss_avg += cos_loss / num_targets  # balance cos by default for meaningful eval
            l1_loss_avg += l1_loss * weight

            mse_losses_per_model[t] = mse_loss.item()
            cos_losses_per_model[t] = cos_loss.item()
            l1_losses_per_model[t] = l1_loss.item()

        losses = {
            "mse_loss": mse_loss_avg,
            "cos_loss": cos_loss_avg,
            "l1_loss": l1_loss_avg,
            "mse_losses_per_model": mse_losses_per_model,
            "cos_losses_per_model": cos_losses_per_model,
            "l1_losses_per_model": l1_losses_per_model,
        }
        if reconstruction_enabled:
            losses.update(self._reconstruction_loss(pred_features['reconstruction'], images))
        return losses

    def add_moe_experts(self, n_new_tasks, n_new_experts, new_task_names: List[str], freeze_old=True):
        """Register new tasks and add experts for them to every MoE block.

        Freezes the backbone, the existing expert pool and the translator heads of all
        currently registered tasks before delegating to each block's ``add_moe_experts``.

        Args:
            n_new_tasks: number of new tasks, forwarded to each block.
            n_new_experts: number of new experts, forwarded to each block.
            new_task_names: names of the new tasks; assigned consecutive ids after the existing ones.
            freeze_old (bool): forwarded to each block.
        """
        for param in self.backbone.parameters():
            param.requires_grad = False

        for param in self.expert_pool.parameters():
            param.requires_grad = False

        for old_task in self.task_map:
            for param in self.translator.translator_heads[old_task].parameters():
                param.requires_grad = False

        old_task_num = len(self.task_map)
        self.task_map.update({task_name: old_task_num + i for i, task_name in enumerate(new_task_names)})

        for expert in self.expert_pool:
            expert.add_moe_experts(n_new_tasks, n_new_experts, freeze_old=freeze_old)

    def freeze_old(self, new_task_names: List[str]):
        """Freeze parameters for old tasks and register ``new_task_names``.

        Only expert-pool parameters whose name contains ``'new'`` stay trainable; the backbone
        and the translator heads of all currently registered tasks are frozen. Only one new
        task is supported.
        """
        for param in self.backbone.parameters():
            param.requires_grad = False

        for key, param in self.expert_pool.named_parameters():
            param.requires_grad = 'new' in key

        for old_task in list(self.task_map):
            for param in self.translator.translator_heads[old_task].parameters():
                param.requires_grad = False

        old_task_num = len(self.task_map)
        self.task_map.update({task_name: old_task_num + i for i, task_name in enumerate(new_task_names)})

    def merge_experts(self):
        """
        Merges new_weight into weight and removes new_weight.
        Should be called after loading the checkpoint of continual training.
        Then you can save it as a new checkpoint. the parameter 'new_weight' will be removed.
        """
        for expert in self.expert_pool:
            expert.merge_experts()

    def add_reconstruction_translator_head(self, decoder_type='transformer', train_reconstruction_only=False):
        """Add a ``'reconstruction'`` translator head that decodes backbone tokens to images.

        Args:
            decoder_type (str): ``'cnn'`` or ``'transformer'``.
            train_reconstruction_only (bool): freeze the backbone, expert pool and all existing
                translator heads, so only the new head is trained.

        Raises:
            ValueError: for an unknown ``decoder_type``, or a transformer decoder with a
                backbone width other than 192, 384 or 768. Nothing is modified in that case.
        """
        dim = self.backbone_feature_size[0]
        patch_size = 16
        # transformer decoder (depth, num_heads) per backbone width
        transformer_sizes = {192: (4, 4), 384: (5, 8), 768: (6, 8)}
        if decoder_type not in ('cnn', 'transformer'):
            raise ValueError(f"Unknown decoder_type {decoder_type!r}; expected 'cnn' or 'transformer'.")
        if decoder_type == 'transformer' and dim not in transformer_sizes:
            raise ValueError(f"No transformer decoder size for backbone dim {dim}; supported: {sorted(transformer_sizes)}.")

        if train_reconstruction_only:
            for param in self.backbone.parameters():
                param.requires_grad = False
            for param in self.expert_pool.parameters():
                param.requires_grad = False
            # frozen before the new head is added, so the new head stays trainable
            for param in self.translator.translator_heads.parameters():
                param.requires_grad = False

        task_name = 'reconstruction'
        self.task_map[task_name] = len(self.task_map)
        if decoder_type == 'cnn':
            head = CNNDecoder(in_dim=dim, img_size=self.image_size, patch_size=patch_size, out_channels=3)
        else:
            depth, num_heads = transformer_sizes[dim]
            head = models.TransformerDecoder(
                in_dim=dim,
                # 196 for the default 224px input with 16px patches
                num_patches=(self.image_size // patch_size) ** 2,
                patch_size=patch_size,
                if_cls_token=True,
                depth=depth,
                num_heads=num_heads,
            )
        self.translator.translator_heads.update({task_name: head})

    def add_downstream_tasks(self, n_new_tasks, n_new_experts, new_task_names: List[str], freeze_old=True, unfreeze_norm=False, noising_gating=False, topk=-1):
        """Freeze the whole model, register new downstream tasks and add experts for them.

        Args:
            n_new_tasks: number of new tasks, forwarded to each block.
            n_new_experts: number of new experts, forwarded to each block.
            new_task_names (list[str]): names of the new tasks; assigned consecutive ids.
            freeze_old (bool): forwarded to each block.
            unfreeze_norm (bool): call ``unfreeze_norm()`` on every block after freezing.
            noising_gating (bool): forwarded to each block as ``noisy_gating``.
            topk (int): forwarded to each block.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unfreeze_norm:
            for expert in self.expert_pool:
                expert.unfreeze_norm()

        old_task_num = len(self.task_map)
        self.task_map.update({task_name: old_task_num + i for i, task_name in enumerate(new_task_names)})

        for expert in self.expert_pool:
            expert.add_moe_experts(n_new_tasks, n_new_experts, freeze_old=freeze_old, noisy_gating=noising_gating, topk=topk)
        

def ssim_loss(img1, img2, window_size=11, size_average=True):
    """Return ``1 - SSIM`` between two image batches, using a uniform (box) window.

    Local means, variances and covariance are computed with ``F.avg_pool2d`` (stride 1,
    zero padding ``window_size // 2``). With PyTorch's default ``count_include_pad=True``,
    border windows average in the padded zeros. The constants ``C1 = 0.01**2`` and
    ``C2 = 0.03**2`` are the standard SSIM constants for a dynamic range of 1, so the
    inputs are expected to be scaled to [0, 1].

    Args:
        img1 (torch.Tensor): first image batch, [B, C, H, W].
        img2 (torch.Tensor): second image batch, same shape as ``img1``.
        window_size (int): side length of the averaging window.
        size_average (bool): if True (default), return a scalar averaged over everything;
            otherwise return one loss per batch element (shape [B]).

    Returns:
        torch.Tensor: the SSIM loss (0 for identical images).
    """
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    pad = window_size // 2

    mu1 = F.avg_pool2d(img1, window_size, stride=1, padding=pad)
    mu2 = F.avg_pool2d(img2, window_size, stride=1, padding=pad)

    # E[x^2] - E[x]^2 style local (co)variances
    sigma1_sq = F.avg_pool2d(img1 * img1, window_size, stride=1, padding=pad) - mu1 ** 2
    sigma2_sq = F.avg_pool2d(img2 * img2, window_size, stride=1, padding=pad) - mu2 ** 2
    sigma12 = F.avg_pool2d(img1 * img2, window_size, stride=1, padding=pad) - mu1 * mu2

    ssim_map = ((2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)) / ((mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2))
    if size_average:
        return 1 - ssim_map.mean()
    # per-sample: average over every non-batch dimension
    return 1 - ssim_map.flatten(start_dim=1).mean(dim=1)