import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
import seaborn as sns
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import random

from utils import global_print
from models.projector import build_projector
from collections import OrderedDict
import copy

from functools import partial
from models.backbones.vit import Block


from functools import partial
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, Union, List

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final

from timm.data import (
    IMAGENET_INCEPTION_MEAN,
    IMAGENET_INCEPTION_STD,
)
from timm.layers import (
    PatchEmbed,
    Mlp,
    DropPath,
    PatchDropout,
    trunc_normal_,
    resample_patch_embed,
    resample_abs_pos_embed,
    use_fused_attn,
    get_act_layer,
    get_norm_layer,
    LayerType,
)
from timm.models._builder import build_model_with_cfg
from timm.models._features import feature_take_indices
from timm.models._manipulate import named_apply, adapt_input_conv
from models.backbones.moh_attention import MoH_ViT
from models.backbones.router import ConditionedRouter, StructureConditionedRouter, AdvancedConditionedRouter, DecoupledGroupRouter

class MlpProjector(nn.Module):
    """Two-layer MLP projection head: Linear -> activation -> Linear.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        mlp_ratio: Hidden width expressed as a multiple of ``input_dim``
            (the product is truncated to an integer).
        activation: Zero-argument callable that builds the activation module.
    """

    def __init__(self, input_dim, output_dim, mlp_ratio=4.0, activation=nn.GELU):
        super().__init__()
        width = int(input_dim * mlp_ratio)

        # Build the layers in forward order so parameter initialisation draws
        # from the RNG in the same sequence as an inline Sequential would.
        expand = nn.Linear(input_dim, width)
        nonlinearity = activation()
        project = nn.Linear(width, output_dim)

        self.net = nn.Sequential(expand, nonlinearity, project)

    def forward(self, x):
        """Project ``x`` through the MLP and return the result."""
        projected = self.net(x)
        return projected

class ConstrainedSigmoid(nn.Module):
    """Sigmoid whose output is affinely rescaled into ``(min_val, max_val)``.

    Args:
        min_val: Lower bound of the output range.
        max_val: Upper bound of the output range.

    Raises:
        ValueError: If the bounds do not satisfy
            ``0.0 <= min_val < max_val <= 1.0``.
    """

    def __init__(self, min_val: float = 0.5, max_val: float = 1.0):
        super().__init__()
        bounds_ok = (0.0 <= min_val) and (min_val < max_val) and (max_val <= 1.0)
        if not bounds_ok:
            raise ValueError(f"Range ({min_val}, {max_val}) is invalid. "
                             "Must satisfy 0.0 <= min_val < max_val <= 1.0.")

        self.min_val, self.max_val = min_val, max_val
        # Width of the target interval; the sigmoid output is scaled by it.
        self.range = max_val - min_val

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(x)`` stretched to the configured interval."""
        squashed = torch.sigmoid(x)
        return squashed * self.range + self.min_val

    def __repr__(self):
        description = f"ConstrainedSigmoid(min_val={self.min_val}, max_val={self.max_val})"
        return description

class UniversalExpert(nn.Module):
    """Feed-forward expert: Linear -> GELU -> dropout -> Linear -> dropout.

    Args:
        input_dim: Size of the last dimension of the input.
        expert_hidden_ratio: Hidden width as a multiple of ``input_dim``
            (the product is truncated to an integer).
        output_dim: Size of the output's last dimension; falls back to
            ``input_dim`` when ``None``.
        dropout: Dropout probability used after the activation and after the
            second linear layer (one shared ``nn.Dropout`` module).
    """

    def __init__(self, input_dim, expert_hidden_ratio, output_dim: int = None, dropout=0.0):
        super().__init__()
        out_features = input_dim if output_dim is None else output_dim
        inner = int(input_dim * expert_hidden_ratio)

        self.fc1 = nn.Linear(input_dim, inner)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(inner, out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run ``x`` through the expert MLP."""
        hidden = self.dropout(self.act(self.fc1(x)))
        return self.dropout(self.fc2(hidden))

class ContextAwareGate(nn.Module):
    """Per-token scalar gate computed from local features plus a global mean.

    Each token is projected by ``local_proj``; the mean over dimension 1 of
    the input is projected by ``global_proj`` and broadcast-added. The sum
    goes through GELU and a linear layer down to one value per token, which
    is squashed into ``(0.5, 1.0)`` by :meth:`activation`.

    Args:
        dim: Size of the last dimension of the input.
        mlp_ratio: Hidden width as a multiple of ``dim`` (the product is
            truncated to an integer).
    """

    def __init__(self, dim: int, mlp_ratio: float = 0.25):
        super().__init__()
        self.dim = dim
        global_print("################ Using ContextAwareGate! ##################")

        hidden_dim = int(dim * mlp_ratio)
        self.local_proj = nn.Linear(dim, hidden_dim)
        self.global_proj = nn.Linear(dim, hidden_dim)

        self.mlp = nn.Sequential(
            nn.GELU(),
            nn.Linear(hidden_dim, 1)
        )

        # A positive output bias shifts the initial pre-activation upward, so
        # the gate starts in the upper part of its (0.5, 1.0) range.
        with torch.no_grad():
            self.mlp[1].bias.fill_(2.0)

    @staticmethod
    def activation(x: torch.Tensor) -> torch.Tensor:
        """Map a pre-activation to ``(0.5, 1.0)`` via ``0.5 * sigmoid(x) + 0.5``.

        Defined as a method rather than a lambda stored on the instance so
        that the module stays picklable.
        """
        return 0.5 * torch.sigmoid(x) + 0.5

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the gate for every position along dimension 1.

        Args:
            x: Input whose dimension 1 is averaged for the global context and
                whose last dimension has size ``dim`` (presumably
                ``(batch, tokens, dim)`` -- confirm against callers).

        Returns:
            ``(g, pre_activation)``: the gate values and the raw values before
            the sigmoid, both with a trailing dimension of size 1.
        """
        global_context = torch.mean(x, dim=1, keepdim=True)
        local_feat = self.local_proj(x)
        global_feat = self.global_proj(global_context)

        # global_feat keeps a size-1 dimension 1, so it broadcasts over tokens.
        fused_feat = local_feat + global_feat
        pre_activation = self.mlp(fused_feat)
        g = self.activation(pre_activation)

        return g, pre_activation

class SmarterGate(nn.Module):
    """Per-token scalar gate computed by a small convolution over a square grid.

    The token sequence is reshaped to a square 2-D grid, passed through a
    depthwise 3x3 convolution, GELU and a 1x1 convolution down to a single
    channel, then squashed by :class:`ConstrainedSigmoid` with its default
    bounds.

    Args:
        dim: Number of channels (size of the last dimension of the input).
        patch_resolution: Accepted for interface compatibility; not used by
            this implementation.
    """

    def __init__(self, dim, patch_resolution=None):
        super().__init__()
        global_print("################ Using SmarterGate! ##################")
        self.conv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, groups=dim, bias=False),
            nn.GELU(),
            nn.Conv2d(dim, 1, kernel_size=1, bias=True)
        )
        self.activation = ConstrainedSigmoid()

        # A positive output bias shifts the initial pre-activation upward, so
        # the gate starts in the upper part of its range.
        with torch.no_grad():
            self.conv[2].bias.fill_(2.0)

    def forward(self, x):
        """Compute the gate for every token.

        Args:
            x: Tensor of shape ``(B, N, C)`` where ``N`` is a perfect square.

        Returns:
            ``(g, pre_act_flat)``, both of shape ``(B, N, 1)``: the gate and
            the raw convolution output before the activation.

        Raises:
            ValueError: If ``N`` is not a perfect square. (An explicit raise
                is used instead of ``assert`` so the check survives ``-O``.)
        """
        B, N, C = x.shape
        h = w = int(math.sqrt(N))

        if h * w != N:
            raise ValueError(
                f"SmarterGate expects a square token grid, but got N={N} tokens."
            )

        # reshape (not view) so non-contiguous inputs are accepted as well.
        x_grid = x.transpose(1, 2).reshape(B, C, h, w)
        pre_act_grid = self.conv(x_grid)
        pre_act_flat = pre_act_grid.reshape(B, 1, N).transpose(1, 2).contiguous()

        g = self.activation(pre_act_flat)
        return g, pre_act_flat

class RSL(nn.Module):
    """MLP router producing softmax expert scores and a top-k selection.

    Args:
        in_features: Size of the last dimension of the input.
        num_experts: Number of experts to score.
        alpha: Weight of the load-balancing term in the auxiliary loss; the
            auxiliary loss is only computed when ``alpha > 0`` in training.
        top_k: Number of experts selected per token (default 3, the value
            that used to be hard-coded).
        entropy_gamma: Weight of the entropy bonus subtracted from the
            auxiliary loss (default 0.1, the value that used to be
            hard-coded).

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, in_features, num_experts, alpha=0.1, top_k=3, entropy_gamma=0.1):
        super().__init__()
        # Fail fast: torch.topk would otherwise raise an opaque error on the
        # first forward pass.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}."
            )

        self.alpha = alpha
        self.num_experts = num_experts
        self.top_k = top_k
        self.entropy_gamma = entropy_gamma
        self.mlp = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.GELU(),
            nn.Linear(256, num_experts),
        )
        nn.init.xavier_normal_(self.mlp[0].weight, gain=math.sqrt(2))
        nn.init.zeros_(self.mlp[0].bias)
        nn.init.xavier_normal_(self.mlp[2].weight, gain=1.0)

    def forward(self, hidden_states):
        """Score experts for every token.

        Args:
            hidden_states: Tensor of shape ``(B, seq_len, in_features)``.

        Returns:
            Tuple ``(router_weights, aux_loss, topk_indices, topk_weights)``:
            ``router_weights`` are the float softmax scores of shape
            ``(B, seq_len, num_experts)``; ``aux_loss`` is a scalar tensor in
            training when ``alpha > 0`` and ``None`` otherwise;
            ``topk_indices`` / ``topk_weights`` have a trailing dimension of
            ``top_k``, with the weights L1-normalised over that dimension.
        """
        B, seq_len, _ = hidden_states.shape
        logits = self.mlp(hidden_states).float()
        scores = F.softmax(logits, dim=-1)

        topk_weights, topk_indices = torch.topk(scores, k=self.top_k, dim=-1)
        topk_weights = F.normalize(topk_weights, p=1, dim=-1)

        aux_loss = None
        if self.training and self.alpha > 0.0:
            # NOTE(review): scores is 3-D here, so this mean runs over the
            # batch axis only and Pi has shape (seq_len, num_experts); the
            # balance term therefore grows with seq_len. Left unchanged to
            # keep loss magnitudes stable -- confirm whether a mean over all
            # tokens was intended.
            Pi = scores.mean(dim=0)
            # Fraction of all top-k slots assigned to each expert.
            fi = F.one_hot(topk_indices.view(-1), num_classes=self.num_experts).float().mean(dim=0)

            balance = (Pi * fi).sum()
            entropy = - (scores * scores.clamp(min=1e-9).log()).sum(dim=-1).mean()

            # Penalise imbalance, reward high-entropy routing distributions.
            aux_loss = self.alpha * balance - self.entropy_gamma * entropy

        router_weights = scores.view(B, seq_len, self.num_experts).float()
        return router_weights, aux_loss, topk_indices, topk_weights

class ConditionedMoELayer(nn.Module):
    """Sparse mixture-of-experts FFN with an always-active shared expert.

    A router (chosen by ``router_type``) selects experts per token. Each
    selected expert processes the token and the weighted results are summed
    and added to the output of ``shared_expert``. In training a
    load-balancing auxiliary loss is returned alongside the output.

    Args:
        input_dim: Feature size of each token.
        num_experts: Number of routed experts; must be divisible by
            ``num_experts_per_group``.
        condition_dim: Size of the condition embedding, forwarded to the router.
        top_k: Number of experts selected per token.
        expert_hidden_ratio: Hidden width of each expert as a multiple of
            ``input_dim``.
        load_balancing_alpha: Scale of the load-balancing loss.
        noisy_gating: Forwarded to the router.
        num_experts_per_group: Experts are grouped contiguously in blocks of
            this size (group ``g`` owns experts ``g*size .. (g+1)*size - 1``).
        router_type: One of ``"structure_router"``, ``"no_structure_router"``,
            ``"advanced_router"``, ``"decoupled_router"``.
        base_block: Optional block whose ``mlp`` state dict initialises every
            expert and the shared expert. When ``None`` the experts keep
            their default initialisation.
        ortho_loss_weight: Target magnitude used to derive the
            ``ortho_loss_weight`` buffer.
        variance_loss_weight: Target magnitude used to derive the
            ``variance_loss_weight`` buffer.
        seq_aux: Stored on the instance; not used inside this class.
        num_tasks: Forwarded to ``DecoupledGroupRouter``.

    Raises:
        ValueError: If ``num_experts`` is not divisible by
            ``num_experts_per_group`` or ``router_type`` is unknown.
    """

    def __init__(self, input_dim, num_experts, condition_dim, top_k=1,
                 expert_hidden_ratio=4.0,
                 load_balancing_alpha=0.01, noisy_gating=True, num_experts_per_group=5,
                 router_type="no_structure_router",
                 base_block=None,
                 ortho_loss_weight=0.01,
                 variance_loss_weight=0.01,
                 seq_aux=False,
                 num_tasks=0,
                 ):
        if num_experts % num_experts_per_group != 0:
            raise ValueError("num_experts must be divisible by num_experts_per_group.")

        super().__init__()
        self.dim = input_dim
        self.num_experts = num_experts
        self.condition_dim = condition_dim
        self.top_k = top_k
        self.load_balancing_alpha = load_balancing_alpha
        self.num_experts_per_group = num_experts_per_group

        self.ortho_loss_target = ortho_loss_weight
        self.variance_loss_target = variance_loss_weight
        self.seq_aux = seq_aux

        self.register_buffer('ortho_loss_weight', torch.tensor(-1.0))
        self.register_buffer('variance_loss_weight', torch.tensor(-1.0))

        # The literals below are fixed reference magnitudes fed to
        # _compute_dynamic_weight (presumably loss values measured on an
        # earlier run -- TODO confirm their origin).
        new_weight = self._compute_dynamic_weight(torch.tensor(0.004540108144283295), self.variance_loss_target)
        self.variance_loss_weight.fill_(new_weight)

        new_weight = self._compute_dynamic_weight(torch.tensor(0.9623544216156006), self.ortho_loss_target)
        self.ortho_loss_weight.fill_(new_weight)

        if router_type == "structure_router":
            self.router = StructureConditionedRouter(input_dim, num_experts, condition_dim, noisy_gating, top_k=self.top_k)
        elif router_type == "no_structure_router":
            self.router = ConditionedRouter(input_dim, num_experts, condition_dim, noisy_gating)
        elif router_type == "advanced_router":
            self.router = AdvancedConditionedRouter(input_dim, num_experts, condition_dim, noisy_gating)
        elif router_type == "decoupled_router":
            self.router = DecoupledGroupRouter(input_dim, num_experts, num_experts_per_group,  condition_dim, noisy_gating, num_tasks=num_tasks)
        else:
            # Previously an unknown name left `router`/`router_type` unset and
            # only failed later, inside forward().
            raise ValueError(
                f"Unsupported router_type: {router_type!r}. Choose 'structure_router', "
                "'no_structure_router', 'advanced_router' or 'decoupled_router'."
            )
        self.router_type = router_type

        self.experts = nn.ModuleList(
            [UniversalExpert(input_dim, expert_hidden_ratio) for _ in range(num_experts)]
        )

        self.shared_expert = UniversalExpert(input_dim, expert_hidden_ratio)

        if base_block is not None:
            # load_state_dict copies values into the existing parameters, so
            # one state dict can safely seed every expert.
            pretrained_mlp = base_block.mlp.state_dict()
            for expert in self.experts:
                expert.load_state_dict(pretrained_mlp)
            self.shared_expert.load_state_dict(pretrained_mlp)

    def forward(self, x, condition_embedding, is_vfm_condition=False, vfm_teacher_id=None, task_id=None, return_expert_outputs=False):
        """Route tokens through the experts and add the shared expert output.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.
            condition_embedding: Passed to the router unchanged.
            is_vfm_condition: Enables teacher-conditioned routing. For the
                structure router this masks routing to the teacher's expert
                group and uses a top-k of 2.
            vfm_teacher_id: Integer tensor with one group id per sample;
                required by the structure router when ``is_vfm_condition``.
            task_id: Passed to the decoupled router; ``None`` selects an
                override top-k of 2, otherwise 1.
            return_expert_outputs: Also run every expert densely on its
                weighted input and return the stacked results.

        Returns:
            ``(output, load_balancing_loss, top_k_indices, expert_outputs)``.
            ``top_k_indices`` is ``None`` in training unless
            ``return_expert_outputs`` is set; ``expert_outputs`` (shape
            ``(batch*tokens, num_experts, dim)``) is ``None`` unless
            ``return_expert_outputs`` is set.

        Raises:
            ValueError: If ``vfm_teacher_id`` is missing in structure-router
                VFM mode, or ``self.router_type`` is not a known value.
        """
        batch_size, num_tokens, dim = x.shape
        # reshape instead of view so non-contiguous inputs are accepted.
        x_flat = x.reshape(-1, dim)
        current_top_k = self.top_k

        shared_output = self.shared_expert(x)

        gating_weights = None
        top_k_indices = None
        router_probs_flat = None

        if self.router_type == "structure_router":
            routing_mask = None
            if is_vfm_condition:
                if vfm_teacher_id is None:
                    raise ValueError("vfm_teacher_id must be provided in vfm_condition mode.")

                current_top_k = 2

                # Additive mask: 0 inside the teacher's expert group, -inf elsewhere.
                group_indices = vfm_teacher_id.unsqueeze(1)
                expert_indices = torch.arange(self.num_experts, device=x.device).unsqueeze(0)
                start_indices = group_indices * self.num_experts_per_group
                is_in_group = (expert_indices >= start_indices) & (
                        expert_indices < start_indices + self.num_experts_per_group)
                routing_mask = torch.where(is_in_group, 0.0, float('-inf'))

            _top_k_probs, _top_k_indices, _router_probs = self.router(
                x,
                condition_embedding=condition_embedding,
                is_vfm_condition=is_vfm_condition,
                routing_mask=routing_mask,
                override_top_k=current_top_k
            )

            gating_weights = _top_k_probs.view(-1, current_top_k)
            top_k_indices = _top_k_indices.view(-1, current_top_k)
            router_probs_flat = _router_probs.view(-1, self.num_experts)

        elif self.router_type in ("no_structure_router", "advanced_router"):
            # Both routers return raw logits; the top-k selection and the
            # softmax over the selected logits happen here.
            router_logits = self.router(
                x,
                condition_embedding,
                is_vfm_condition=is_vfm_condition,
                routing_mask=None
            )
            router_logits_flat = router_logits.view(-1, self.num_experts)

            top_k_logits, top_k_indices = torch.topk(router_logits_flat, self.top_k, dim=-1)
            gating_weights = F.softmax(top_k_logits, dim=-1)

            router_probs_flat = F.softmax(router_logits_flat, dim=-1)

        elif self.router_type == "decoupled_router":
            override_top_k = 2 if task_id is None else 1
            # NOTE(review): the third return value (grouped probabilities) is
            # not consumed, so router_probs_flat stays None and the
            # load-balancing loss is skipped for this router. Wire it in once
            # its shape is confirmed.
            gating_weights, top_k_indices, _probs_grouped = self.router(
                x, condition_embedding, task_id, override_top_k
            )
            if is_vfm_condition and vfm_teacher_id is not None:
                # NOTE(review): the original code built a per-group one-hot
                # mask from vfm_teacher_id here but never applied it; only the
                # renormalisation below affected the result. The dead mask was
                # dropped -- confirm whether masking was intended.
                gating_sum = gating_weights.sum(dim=-1, keepdim=True) + 1e-6
                gating_weights = gating_weights / gating_sum

        else:
            raise ValueError(f"Unsupported router_type: {self.router_type!r}.")

        # Dense (tokens, experts) weight matrix; zero where an expert was not selected.
        combine_weights = torch.zeros(
            batch_size * num_tokens, self.num_experts,
            device=x.device, dtype=x.dtype
        )

        # Cast so that a router emitting a different float dtype does not trip scatter_add_.
        combine_weights.scatter_add_(1, top_k_indices, gating_weights.to(combine_weights.dtype))

        sparse_output_flat = torch.zeros_like(x_flat)

        for i in range(self.num_experts):
            expert_weight = combine_weights[:, i]
            active_mask = expert_weight > 1e-6
            indices = torch.nonzero(active_mask).squeeze(-1)
            weights_active = expert_weight[active_mask].unsqueeze(-1)

            # NOTE(review): the gate weight scales both the expert input and
            # its output (i.e. it is applied twice). Kept as-is because
            # changing it alters model numerics -- confirm intent.
            active_outputs = self.experts[i](x_flat[active_mask] * weights_active)
            weighted_outputs = active_outputs * weights_active
            sparse_output_flat.index_add_(0, indices, weighted_outputs.to(sparse_output_flat.dtype))

        sparse_output = sparse_output_flat.view(batch_size, num_tokens, dim)

        if return_expert_outputs:
            # Dense evaluation: every expert sees every token, scaled by its
            # combine weight (zero for unselected experts).
            dispatched_input = torch.einsum('be,bd->bed', combine_weights, x_flat)
            dispatched_input = dispatched_input.transpose(0, 1)
            expert_outputs = torch.stack([self.experts[i](dispatched_input[i]) for i in range(self.num_experts)])
            expert_outputs = expert_outputs.transpose(0, 1)

        final_output = shared_output + sparse_output
        load_balancing_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        if self.training and self.load_balancing_alpha > 0 and router_probs_flat is not None:
            # Fraction of tokens dispatched to each expert times its mean router probability.
            expert_load = (combine_weights > 0).float().mean(dim=0)
            router_probs_mean = router_probs_flat.mean(dim=0)
            load_balancing_loss = self.load_balancing_alpha * self.num_experts * torch.sum( expert_load * router_probs_mean )

        if return_expert_outputs:
            return final_output, load_balancing_loss, top_k_indices, expert_outputs

        if not self.training:
            return final_output, load_balancing_loss, top_k_indices, None

        return final_output, load_balancing_loss, None, None

    def _compute_dynamic_weight(self, current_val, target_val):
        """Return the power of ten closest (in log space) to ``target_val / current_val``.

        Returns a zero tensor when ``current_val`` is below ``1e-9``.
        """
        if current_val < 1e-9:
            return torch.tensor(0.0, device=current_val.device)

        ratio = target_val / current_val
        weight = 10 ** torch.round(torch.log10(ratio))
        return weight

    def compute_ortho_loss(self, expert_outputs):
        """Penalise cosine similarity between the per-token outputs of selected experts.

        Args:
            expert_outputs: Tensor of shape ``(tokens, top_k, dim)``.

        Returns:
            Squared deviation of the cosine-similarity matrix from the
            identity, summed and divided by the number of off-diagonal
            entries; a zero tensor when ``top_k <= 1``.
        """
        if self.top_k <= 1:
            return torch.tensor(0.0, device=expert_outputs.device)

        norm_outputs = F.normalize(expert_outputs, p=2, dim=-1, eps=1e-6)
        sim_matrix = torch.bmm(norm_outputs, norm_outputs.transpose(1, 2))
        identity = torch.eye(self.top_k, device=expert_outputs.device).unsqueeze(0)
        ortho_loss = torch.sum((sim_matrix - identity) ** 2)
        count = expert_outputs.size(0) * (self.top_k * (self.top_k - 1))
        return ortho_loss / (count + 1e-6)

    def compute_weight_ortho_loss(self):
        """Sum, over expert groups, the squared off-identity Gram entries of ``fc2`` weights.

        Reads ``self.router.num_groups``, so it requires a router exposing
        that attribute. Returns the integer ``0`` if there are no groups.
        """
        loss = 0
        for i in range(self.router.num_groups):
            start = i * self.num_experts_per_group
            end = start + self.num_experts_per_group
            group_weights = torch.stack([e.fc2.weight for e in self.experts[start:end]])
            flat_weights = group_weights.view(self.num_experts_per_group, -1)
            norm_weights = F.normalize(flat_weights, p=2, dim=1)
            gram = torch.mm(norm_weights, norm_weights.t())
            eye = torch.eye(self.num_experts_per_group, device=gram.device)
            loss += torch.sum((gram - eye) ** 2)
        return loss

    def _compute_weight_ortho_loss(self):
        """Mean, over expert groups, of the mean squared off-identity Gram entries of ``fc1`` weights.

        Returns a zero tensor when the ``ortho_loss_weight`` buffer is not
        positive or groups hold a single expert.
        """
        if self.ortho_loss_weight <= 0 or self.num_experts_per_group <= 1:
            return torch.tensor(0.0, device=self.experts[0].fc1.weight.device)

        total_ortho_loss = 0.0
        num_groups = self.num_experts // self.num_experts_per_group

        for g_id in range(num_groups):
            start_idx = g_id * self.num_experts_per_group
            end_idx = start_idx + self.num_experts_per_group
            w_list = [
                self.experts[i].fc1.weight.view(1, -1)
                for i in range(start_idx, end_idx)
            ]
            w_matrix = torch.cat(w_list, dim=0)
            w_norm = F.normalize(w_matrix, p=2, dim=1)
            gram_matrix = torch.mm(w_norm, w_norm.t())
            identity = torch.eye(self.num_experts_per_group, device=gram_matrix.device)
            loss_group = torch.mean((gram_matrix - identity) ** 2)
            total_ortho_loss += loss_group

        return total_ortho_loss / num_groups

class ConditionedMoETransformerBlock(nn.Module):
    """Transformer block whose FFN is replaced by a :class:`ConditionedMoELayer`.

    The attention path, norms, layer-scale and drop-path modules are deep
    copies of the corresponding attributes of ``base_block``. The return
    tuple always has five entries; slots that have no meaning for an ungated
    block are filled with placeholders.

    Args:
        input_dim: Token feature size, forwarded to the MoE layer.
        num_heads: Accepted but not used by this class.
        condition_dim: Forwarded to the MoE layer.
        num_moe_experts: Number of routed experts.
        moe_top_k: Experts selected per token.
        mlp_ratio: Accepted but not used by this class.
        expert_hidden_ratio: Forwarded to the MoE layer.
        qkv_bias: Accepted but not used by this class.
        noisy_gating: Forwarded to the MoE layer.
        attn_drop: Accepted but not used by this class.
        proj_drop: Accepted but not used by this class.
        base_block: Block providing ``norm1``, ``attn``, ``ls1``,
            ``drop_path1``, ``norm2``, ``ls2``, ``drop_path2`` (and, via the
            MoE layer, ``mlp``).
        router_type: Forwarded to the MoE layer.
        task_training: Accepted but not used by this class.
        num_tasks: Forwarded to the MoE layer.
    """

    def __init__(self,
                 input_dim,
                 num_heads,
                 condition_dim,
                 num_moe_experts,
                 moe_top_k=3,
                 mlp_ratio=4.0,
                 expert_hidden_ratio=4.0,
                 qkv_bias=True, noisy_gating=True,
                 attn_drop=0., proj_drop=0.,
                 base_block=None,
                 router_type="no_structure_router",
                 task_training=False,
                 num_tasks=0,
                 ):
        super().__init__()

        global_print("Initializing ConditionedMoETransformerBlock (API-Compatible Version).")

        # Clone the reusable sub-modules; the tuple order fixes the module
        # registration order (attention side first, then the FFN-side wrappers).
        cloned_attrs = ("norm1", "attn", "ls1", "drop_path1", "norm2", "ls2", "drop_path2")
        for attr in cloned_attrs:
            setattr(self, attr, copy.deepcopy(getattr(base_block, attr)))

        moe_kwargs = dict(
            input_dim=input_dim,
            num_experts=num_moe_experts,
            condition_dim=condition_dim,
            top_k=moe_top_k,
            expert_hidden_ratio=expert_hidden_ratio,
            noisy_gating=noisy_gating,
            router_type=router_type,
            base_block=base_block,
            num_tasks=num_tasks,
        )
        self.moe_ffn_layer = ConditionedMoELayer(**moe_kwargs)

    def forward(self, x: torch.Tensor, condition_embedding, is_vfm_condition=False, vfm_teacher_id=None,
                return_expert_outputs=False, task_id=None):
        """Apply attention and the MoE feed-forward with residual connections.

        Returns:
            A 5-tuple whose first three entries are always the block output,
            the MoE auxiliary loss and an all-zero gate of shape
            ``(batch, tokens, 1)``. The last two entries depend on the mode:

            * ``return_expert_outputs=True``: detached top-k indices (or
              ``None``) and the expert outputs from the MoE layer.
            * eval mode: detached top-k indices (or ``None``) and ``None``.
            * training: a zero scalar gate loss and a gate pre-activation
              placeholder filled with ``-inf``.

        NOTE(review): the ``-inf`` placeholder becomes ``inf`` if a consumer
        squares and averages it -- verify how callers use this slot.
        """
        attn_branch = self.attn(self.norm1(x))
        x_attn = x + self.drop_path1(self.ls1(attn_branch))

        moe_output, aux_loss, top_k_indices, all_expert_outputs = self.moe_ffn_layer(
            self.norm2(x_attn),
            condition_embedding,
            is_vfm_condition=is_vfm_condition,
            vfm_teacher_id=vfm_teacher_id,
            return_expert_outputs=return_expert_outputs,
            task_id=task_id,
        )

        x_ffn = x_attn + self.drop_path2(self.ls2(moe_output))

        batch, tokens = x.shape[0], x.shape[1]
        zero_gate = torch.zeros(batch, tokens, 1, device=x.device, dtype=x.dtype)

        if return_expert_outputs or not self.training:
            routed = None if top_k_indices is None else top_k_indices.detach()
            extras = all_expert_outputs if return_expert_outputs else None
            return x_ffn, aux_loss, zero_gate.detach(), routed, extras

        zero_gate_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        neg_inf_pre_act = torch.full_like(zero_gate, float('-inf'))
        return x_ffn, aux_loss, zero_gate, zero_gate_loss, neg_inf_pre_act

class GatedMoETransformerBlock(nn.Module):
    """Transformer block that blends a standard FFN with a MoE FFN via a learned gate.

    The block computes ``g * standard_ffn(h) + (1 - g) * moe(h)`` where ``g``
    is a per-token scalar in ``(gate_min_val, gate_max_val)`` produced by a
    linear layer followed by :class:`ConstrainedSigmoid`.

    Args:
        input_dim: Token feature size.
        base_block: Block providing ``norm1``, ``attn``, ``ls1``,
            ``drop_path1``, ``norm2``, ``mlp``, ``ls2`` and ``drop_path2``;
            all are deep-copied.
        condition_dim: Forwarded to the MoE layer.
        num_moe_experts: Number of routed experts.
        moe_top_k: Experts selected per token.
        expert_hidden_ratio: Forwarded to the MoE layer.
        noisy_gating: Forwarded to the MoE layer.
        router_type: Forwarded to the MoE layer.
        gate_min_val: Lower bound of the gate.
        gate_max_val: Upper bound of the gate.
        num_tasks: Forwarded to the MoE layer.
    """

    def __init__(
            self,
            input_dim: int,
            base_block: nn.Module,
            condition_dim: int,
            num_moe_experts: int,
            moe_top_k: int,
            expert_hidden_ratio: float = 4.0,
            noisy_gating: bool = True,
            router_type="no_structure_router",
            gate_min_val: float = 0.5,
            gate_max_val: float = 1.0,
            num_tasks=0,
    ) -> None:
        super().__init__()

        global_print("Initializing GatedMoETransformerBlock with Gated MoE FFN.")
        self.norm1 = copy.deepcopy(base_block.norm1)
        self.attn = copy.deepcopy(base_block.attn)
        self.ls1 = copy.deepcopy(base_block.ls1)
        self.drop_path1 = copy.deepcopy(base_block.drop_path1)

        self.norm2 = copy.deepcopy(base_block.norm2)
        self.standard_ffn = copy.deepcopy(base_block.mlp)

        self.moe_ffn_layer = ConditionedMoELayer(
            input_dim=input_dim,
            num_experts=num_moe_experts,
            condition_dim=condition_dim,
            top_k=moe_top_k,
            expert_hidden_ratio=expert_hidden_ratio,
            noisy_gating=noisy_gating,
            router_type=router_type,
            base_block=base_block,
            num_tasks=num_tasks
        )

        self.gate_linear = nn.Linear(input_dim, 1)
        self.gate_activation = ConstrainedSigmoid(min_val=gate_min_val, max_val=gate_max_val)

        # A positive bias shifts the initial pre-activation upward, so the
        # gate starts near its upper bound and favours the standard FFN.
        with torch.no_grad():
            self.gate_linear.bias.fill_(2.0)

        self.ls2 = copy.deepcopy(base_block.ls2)
        self.drop_path2 = copy.deepcopy(base_block.drop_path2)

    def forward(self, x: torch.Tensor, condition_embedding, is_vfm_condition=False, vfm_teacher_id=None, return_expert_outputs=False, task_id=None):
        """Apply attention and the gated FFN with residual connections.

        Returns:
            A 5-tuple whose first two entries are always the block output and
            the MoE auxiliary loss. The remaining entries depend on the mode:

            * ``return_expert_outputs=True``: detached gate, detached top-k
              indices (or ``None``), expert outputs from the MoE layer.
            * eval mode: detached gate, detached top-k indices (or ``None``),
              ``None``.
            * training: gate, a zero scalar gate loss, gate pre-activation.
        """
        x_attn = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        ffn_input = self.norm2(x_attn)
        standard_output = self.standard_ffn(ffn_input)

        moe_output, aux_loss, top_k_indices, all_expert_outputs = self.moe_ffn_layer(
            ffn_input,
            condition_embedding,
            is_vfm_condition=is_vfm_condition,
            vfm_teacher_id=vfm_teacher_id,
            return_expert_outputs=return_expert_outputs,
            task_id=task_id
        )

        gate_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        gate_pre_activation = self.gate_linear(ffn_input)
        g = self.gate_activation(gate_pre_activation)

        fused_output = g * standard_output + (1.0 - g) * moe_output
        x_ffn = x_attn + self.drop_path2(self.ls2(fused_output))

        if return_expert_outputs or not self.training:
            # Guard against a missing index tensor, matching
            # ConditionedMoETransformerBlock.
            top_k_indices_detached = top_k_indices.detach() if top_k_indices is not None else None
            expert_outputs = all_expert_outputs if return_expert_outputs else None
            return x_ffn, aux_loss, g.detach(), top_k_indices_detached, expert_outputs

        return x_ffn, aux_loss, g, gate_loss, gate_pre_activation

    def gate_diversity_loss_sampled(self, ffn_input: torch.Tensor, g: torch.Tensor, num_samples: int):
        """Estimate a similarity-weighted gate-difference loss from random token pairs.

        For sampled pairs ``(i, j)`` with cosine similarity ``S`` and squared
        gate difference ``D`` the loss is the mean of ``S*D - (1 - S)*D``.

        Args:
            ffn_input: Tensor of shape ``(B, N, C)``.
            g: Gate values of shape ``(B, N, 1)``.
            num_samples: Upper bound on the number of pairs drawn per sample;
                capped at ``N * (N - 1)``.

        Returns:
            A scalar tensor. When ``N <= 1`` no pair exists and a zero tensor
            is returned (always a tensor, so callers see a consistent type).
        """
        B, N, C = ffn_input.shape
        if N <= 1:
            return ffn_input.new_zeros(())

        M = min(num_samples, N * (N - 1))
        ffn_input_norm = F.normalize(ffn_input, p=2, dim=-1,eps=1e-6)
        rand_indices_i = torch.randint(0, N, (B, M), device=ffn_input.device)
        rand_indices_j = torch.randint(0, N, (B, M), device=ffn_input.device)
        # Avoid pairing a token with itself: shift colliding j to the next index.
        same_indices = (rand_indices_i == rand_indices_j)
        rand_indices_j[same_indices] = (rand_indices_j[same_indices] + 1) % N

        idx_i = rand_indices_i.unsqueeze(-1).expand(-1, -1, C)
        idx_j = rand_indices_j.unsqueeze(-1).expand(-1, -1, C)

        tokens_i = torch.gather(ffn_input_norm, 1, idx_i)
        tokens_j = torch.gather(ffn_input_norm, 1, idx_j)

        g_i = torch.gather(g, 1, rand_indices_i.unsqueeze(-1))
        g_j = torch.gather(g, 1, rand_indices_j.unsqueeze(-1))

        S_pairs = (tokens_i * tokens_j).sum(dim=-1)
        D_pairs = (g_i - g_j).squeeze(-1).pow(2)
        loss = (S_pairs * D_pairs - (1 - S_pairs) * D_pairs).mean()

        return loss

class Condition_MoE_ViT(nn.Module):
    def __init__(
            self,
            vit_base_model,
            condition_dim,
            num_moe_experts,
            moe_top_k=1,
            expert_hidden_ratio=4.0,
            moe_layers_indices=None,
            noisy_gating=True,
            out_indices=None,
            moe_type='conditioned',
            router_type="no_structure_router",
            gate_constraint_ranges=[[0.8, 1.0],[0.6, 1.0],[0.4, 1.0],[0.2, 1.0]],
            num_tasks=0
            ):
        super().__init__()
        self.embed_dim = vit_base_model.embed_dim
        self.patch_embed = copy.deepcopy(vit_base_model.patch_embed)
        self.pos_embed = copy.deepcopy(vit_base_model.pos_embed)
        self.num_classes = vit_base_model.num_classes
        self.global_pool = vit_base_model.global_pool
        self.num_features = self.head_hidden_size = self.embed_dim

        self.num_reg_tokens = vit_base_model.num_reg_tokens
        self.has_class_token = vit_base_model.has_class_token
        self.no_embed_class = vit_base_model.no_embed_class
        self.dynamic_img_size = vit_base_model.dynamic_img_size
        self.grad_checkpointing = vit_base_model.grad_checkpointing

        if hasattr(vit_base_model, 'cls_token') and vit_base_model.cls_token is not None:
            self.cls_token = copy.deepcopy(vit_base_model.cls_token)
        else:
            self.cls_token = None
        self.pos_drop = copy.deepcopy(getattr(vit_base_model, 'pos_drop', nn.Identity()))
        self.patch_drop = copy.deepcopy(getattr(vit_base_model, 'patch_drop', nn.Identity()))
        self.norm_pre = copy.deepcopy(getattr(vit_base_model, 'norm_pre', nn.Identity()))
        self.cls_token = copy.deepcopy(getattr(vit_base_model, 'cls_token', nn.Identity()))
        self.reg_token = copy.deepcopy(getattr(vit_base_model, 'reg_token', None))
        self.pos_drop = copy.deepcopy(getattr(vit_base_model, 'pos_drop', nn.Identity()))
        self.patch_drop = copy.deepcopy(getattr(vit_base_model, 'patch_drop', nn.Identity()))

        self.out_indices = out_indices
        self.num_prefix_tokens = getattr(vit_base_model, 'num_prefix_tokens', 1 if hasattr(vit_base_model, 'cls_token') else 0)

        self.blocks = nn.ModuleList()
        depth = len(vit_base_model.blocks)
        num_heads = vit_base_model.blocks[0].attn.num_heads

        assert len(moe_layers_indices) == len(gate_constraint_ranges)
        moe_config = dict(zip(moe_layers_indices, gate_constraint_ranges))
        for i in range(depth):
            if moe_layers_indices and i in moe_layers_indices:
                if moe_type == 'conditioned':
                    self.blocks.append(ConditionedMoETransformerBlock(input_dim=self.embed_dim,
                                                                      num_heads=num_heads,
                                                                      condition_dim=condition_dim,
                                                                      num_moe_experts=num_moe_experts, moe_top_k=moe_top_k,
                                                                      expert_hidden_ratio=expert_hidden_ratio,
                                                                      noisy_gating=noisy_gating,
                                                                      base_block=vit_base_model.blocks[i],
                                                                      router_type=router_type,
                                                                      num_tasks=num_tasks,
                                                                      ))
                elif moe_type == 'gated':
                    gate_min_val, gate_max_val = moe_config[i]
                    print(f"Layer {i}: Gated MoE with gate constraints [{gate_min_val}, {gate_max_val}]")
                    self.blocks.append(GatedMoETransformerBlock(input_dim=self.embed_dim,
                                                                condition_dim=condition_dim,
                                                                num_moe_experts=num_moe_experts, moe_top_k=moe_top_k,
                                                                expert_hidden_ratio=expert_hidden_ratio,
                                                                noisy_gating=noisy_gating,
                                                                base_block=vit_base_model.blocks[i],
                                                                router_type=router_type,
                                                                gate_min_val=gate_min_val,
                                                                gate_max_val=gate_max_val,
                                                                num_tasks=num_tasks,
                                                                ))
                else:
                    raise ValueError(f"Unsupported moe_type: {moe_type}. Choose 'conditioned' or 'gated'.")
            else:
                self.blocks.append(copy.deepcopy(vit_base_model.blocks[i]))

        self.norm = copy.deepcopy(vit_base_model.norm) if hasattr(vit_base_model, 'norm') and vit_base_model.norm is not None else nn.Identity()

    def forward(self, x, condition_embedding, is_vfm_condition=False, vfm_teacher_id=None):
        """Run the whole backbone and return the final normalized tokens.

        Args:
            x: Input handed to ``self.patch_embed``.
            condition_embedding: Conditioning input passed to every block that
                is not a plain ``Block``.
            is_vfm_condition: Forwarded unchanged to the conditioned blocks.
            vfm_teacher_id: Forwarded unchanged to the conditioned blocks.

        Returns:
            A 5-tuple ``(x, aux_loss, gate_list, topk_indices_list, gate_reg)``:
            the normalized output, the sum of the per-block auxiliary losses,
            the non-``None`` gate values returned by the conditioned blocks,
            the per-block top-k indices (collected only in eval mode), and the
            sum over blocks of ``mean(pre_activation ** 2)`` (accumulated only
            in training mode, and only when the first collected pre-activation
            is not ``None``).
        """
        x = self.norm_pre(self.patch_drop(self._pos_embed(self.patch_embed(x))))

        scalar_kwargs = dict(device=x.device, dtype=x.dtype)
        aux_total = torch.tensor(0.0, **scalar_kwargs)
        gate_reg = torch.tensor(0.0, **scalar_kwargs)
        gates, routed_indices, pre_activations = [], [], []

        for layer in self.blocks:
            # Plain transformer blocks take no condition and return only tokens.
            if isinstance(layer, Block):
                x = layer(x)
                continue

            x, layer_aux, gate, layer_topk, layer_pre_act = layer(
                x,
                condition_embedding,
                is_vfm_condition=is_vfm_condition,
                vfm_teacher_id=vfm_teacher_id,
            )
            aux_total += layer_aux
            if gate is not None:
                gates.append(gate)
            # Routing indices are kept for analysis in eval mode, gate
            # pre-activations for the regularizer in training mode.
            if self.training:
                pre_activations.append(layer_pre_act)
            else:
                routed_indices.append(layer_topk)

        if self.norm is not None:
            x = self.norm(x)

        # Only the first entry is inspected for None, as a proxy for all of them.
        if pre_activations and pre_activations[0] is not None:
            for pre_act in pre_activations:
                gate_reg += torch.mean(pre_act ** 2)
        return x, aux_total, gates, routed_indices, gate_reg

    def forward_intermediate_features(self, x, condition_embedding, is_vfm_condition=False, vfm_teacher_id=None, return_expert_outputs=False, task_id=None):
        """Run the backbone and collect normalized features at ``self.out_indices``.

        Args:
            x: Input handed to ``self.patch_embed``.
            condition_embedding: Conditioning input passed to every block that
                is not a plain ``Block``.
            is_vfm_condition: Forwarded unchanged to the conditioned blocks.
            vfm_teacher_id: Forwarded unchanged to the conditioned blocks.
            return_expert_outputs: Forwarded unchanged to the conditioned blocks.
            task_id: Forwarded unchanged to the conditioned blocks.

        Returns:
            A 5-tuple ``(output, aux_loss, gate_list, topk_indices_list,
            gate_reg)``. ``output`` is the list of ``self.norm``-ed token
            tensors taken after each block whose index is in
            ``self.out_indices``. ``aux_loss`` sums the per-block auxiliary
            losses. ``gate_list`` holds the non-``None`` gate values.
            ``topk_indices_list`` is filled only in eval mode. ``gate_reg`` is,
            in training mode, the fourth value returned by the LAST conditioned
            block, and zero otherwise.
        """
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        total_aux_loss_agg = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        gate_regularization_loss = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        output = []
        gate_list = []
        topk_indices_list = []
        # Stays None when the model contains no conditioned block, so the
        # post-loop accumulation cannot hit an unbound local in training mode.
        last_gate_loss = None

        for i, blk in enumerate(self.blocks):
            if isinstance(blk, Block):
                x = blk(x)
            else:
                # The fourth return value is mode-dependent: presumably a gate
                # loss in training and top-k routing indices in eval (per the
                # original variable name) - TODO confirm against the block.
                x, aux_loss, g, gate_loss_or_top_k_indices, _ = blk(
                    x,
                    condition_embedding,
                    is_vfm_condition=is_vfm_condition,
                    vfm_teacher_id=vfm_teacher_id,
                    return_expert_outputs=return_expert_outputs,
                    task_id=task_id,
                )

                total_aux_loss_agg += aux_loss
                if g is not None:
                    gate_list.append(g)
                if self.training:
                    last_gate_loss = gate_loss_or_top_k_indices
                else:
                    topk_indices_list.append(gate_loss_or_top_k_indices)

            if i in self.out_indices:
                output.append(self.norm(x) if self.norm is not None else x)

        # NOTE(review): only the last conditioned block contributes here (this
        # matches the previous behaviour); confirm whether a sum over all
        # blocks was intended.
        if self.training and last_gate_loss is not None:
            gate_regularization_loss += last_gate_loss
        return output, total_aux_loss_agg, gate_list, topk_indices_list, gate_regularization_loss

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
            if self.pos_embed is None:
                return x.view(x.shape[0], -1, x.shape[-1])

            if self.dynamic_img_size:
                B, H, W, C = x.shape
                pos_embed = resample_abs_pos_embed(
                    self.pos_embed,
                    (H, W),
                    num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
                )
                x = x.view(B, -1, C)
            else:
                pos_embed = self.pos_embed

            to_cat = []
            if self.cls_token is not None:
                to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
            if self.reg_token is not None:
                to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

            if self.no_embed_class:
                x = x + pos_embed
                if to_cat:
                    x = torch.cat(to_cat + [x], dim=1)
            else:
                if to_cat:
                    x = torch.cat(to_cat + [x], dim=1)
                x = x + pos_embed

            return self.pos_drop(x)

class Fat_Block(nn.Module):
    """Transformer block cloned from ``base_block`` but with a re-sized MLP.

    Norms, attention, layer-scale and drop-path modules are deep-copied from
    ``base_block``. Only the MLP is rebuilt, with hidden width
    ``int(dim * mlp_ratio)``, and initialized from the base MLP by replicating
    its hidden units.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
        base_block=None,
    ) -> None:
        """Build the block.

        Args:
            dim: Token width; input size of the new MLP.
            mlp_ratio: Hidden width of the new MLP as a multiple of ``dim``.
            proj_drop: Dropout rate handed to the new MLP.
            act_layer: Activation class handed to the new MLP.
            mlp_layer: Class used to build the new MLP.
            base_block: Required. Source block providing ``norm1``, ``attn``,
                ``ls1``, ``drop_path1``, ``norm2``, ``mlp``, ``ls2`` and
                ``drop_path2``.

        ``num_heads``, ``qkv_bias``, ``qk_norm``, ``attn_drop``,
        ``init_values``, ``drop_path`` and ``norm_layer`` are accepted for
        signature compatibility but not used: the corresponding modules are
        copied from ``base_block``.
        """
        super().__init__()
        self.norm1 = copy.deepcopy(base_block.norm1)
        self.attn = copy.deepcopy(base_block.attn)
        self.ls1 = copy.deepcopy(base_block.ls1)
        self.drop_path1 = copy.deepcopy(base_block.drop_path1)
        self.norm2 = copy.deepcopy(base_block.norm2)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = copy.deepcopy(base_block.ls2)
        self.drop_path2 = copy.deepcopy(base_block.drop_path2)

        self._init_fat_mlp(base_block.mlp, self.mlp)

    def _init_fat_mlp(self, src_mlp, dst_mlp):
        """Initialize ``dst_mlp`` from ``src_mlp`` by replicating hidden units.

        Destination hidden unit ``j`` is a copy of source unit
        ``j % src_width``. Each outgoing ``fc2`` column is divided by the
        number of copies of its source unit, so the widened MLP computes the
        same function as the source whenever the destination is at least as
        wide. For an exact integer multiple this equals tiling the weights and
        dividing ``fc2`` by the repeat factor. Widths that are not an integer
        multiple (including a narrower destination, which simply keeps the
        first units) no longer fail with a shape mismatch.
        """
        dst_h_dim = dst_mlp.fc1.out_features
        src_h_dim = src_mlp.fc1.out_features

        repeat = dst_h_dim // src_h_dim
        if dst_h_dim % src_h_dim != 0:
            print("Warning: Fat MLP width is not an integer multiple. Initialization might be suboptimal.")

        with torch.no_grad():
            w1 = src_mlp.fc1.weight
            w2 = src_mlp.fc2.weight

            # Source unit feeding each destination unit, and how many
            # destination copies every source unit received.
            unit_index = torch.arange(dst_h_dim, device=w1.device) % src_h_dim
            copies = torch.bincount(unit_index, minlength=src_h_dim)
            column_scale = copies[unit_index].to(w2.dtype)

            dst_mlp.fc1.weight.copy_(w1[unit_index])
            dst_mlp.fc2.weight.copy_(w2[:, unit_index] / column_scale)

            # A missing source bias is equivalent to a zero bias.
            if dst_mlp.fc1.bias is not None:
                if src_mlp.fc1.bias is not None:
                    dst_mlp.fc1.bias.copy_(src_mlp.fc1.bias[unit_index])
                else:
                    dst_mlp.fc1.bias.zero_()
            if dst_mlp.fc2.bias is not None:
                if src_mlp.fc2.bias is not None:
                    dst_mlp.fc2.bias.copy_(src_mlp.fc2.bias)
                else:
                    dst_mlp.fc2.bias.zero_()

        print(f"Fat MLP initialized from Base MLP using Net2Net tiling (Repeat: {repeat}x).")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the attention and MLP sub-layers, each as a residual branch."""
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x

class Fat_ViT(nn.Module):
    """ViT backbone cloned from a base model, with chosen blocks given a wider MLP.

    Every sub-module is deep-copied from ``vit_base_model``. Blocks whose
    index is listed in ``fat_layers_indices`` are rebuilt as ``Fat_Block``
    with ``fat_mlp_ratio``. The forward methods return the same 5-tuple layout
    as the conditioned backbones, with zero losses and ``None`` for the gate
    and top-k slots, so the two can be swapped.
    """

    def __init__(
            self,
            vit_base_model,
            fat_mlp_ratio=4.0,
            fat_layers_indices=None,
            out_indices=None,
    ):
        """Clone ``vit_base_model`` and widen the selected blocks.

        Args:
            vit_base_model: Source model; must expose the attributes read below.
            fat_mlp_ratio: MLP ratio used for the widened blocks.
            fat_layers_indices: Block indices to widen; ``None`` or empty
                widens nothing.
            out_indices: Block indices whose output is returned by
                ``forward_intermediate_features``; ``None`` selects every
                block.
        """
        super().__init__()

        self.embed_dim = vit_base_model.embed_dim
        self.patch_embed = copy.deepcopy(vit_base_model.patch_embed)
        self.pos_embed = copy.deepcopy(vit_base_model.pos_embed)
        self.num_classes = vit_base_model.num_classes
        self.global_pool = vit_base_model.global_pool
        self.num_features = self.head_hidden_size = self.embed_dim
        self.num_reg_tokens = vit_base_model.num_reg_tokens
        self.has_class_token = vit_base_model.has_class_token
        self.no_embed_class = vit_base_model.no_embed_class
        self.dynamic_img_size = vit_base_model.dynamic_img_size
        self.grad_checkpointing = vit_base_model.grad_checkpointing
        self.cls_token = copy.deepcopy(getattr(vit_base_model, 'cls_token', None))
        self.pos_drop = copy.deepcopy(getattr(vit_base_model, 'pos_drop', nn.Identity()))
        self.patch_drop = copy.deepcopy(getattr(vit_base_model, 'patch_drop', nn.Identity()))
        self.norm_pre = copy.deepcopy(getattr(vit_base_model, 'norm_pre', nn.Identity()))
        self.reg_token = copy.deepcopy(getattr(vit_base_model, 'reg_token', None))
        # Fallback counts a class token only when one actually exists; a
        # ``cls_token`` attribute set to None must not count as a prefix token.
        has_cls_token = getattr(vit_base_model, 'cls_token', None) is not None
        self.num_prefix_tokens = getattr(vit_base_model, 'num_prefix_tokens',
                                         1 if has_cls_token else 0)
        self.out_indices = out_indices

        self.blocks = nn.ModuleList()

        for i, base_block in enumerate(vit_base_model.blocks):
            num_heads = base_block.attn.num_heads
            original_mlp_ratio = base_block.mlp.mlp_ratio if hasattr(base_block.mlp, 'mlp_ratio') else 4.0

            if fat_layers_indices and i in fat_layers_indices:
                print(f"Creating a FAT block at index {i} with mlp_ratio={fat_mlp_ratio}")

                self.blocks.append(
                    Fat_Block(
                        dim=self.embed_dim,
                        num_heads=num_heads,
                        mlp_ratio=fat_mlp_ratio,
                        base_block=base_block,
                    )
                )
            else:
                print(f"Creating a REGULAR block at index {i} with mlp_ratio={original_mlp_ratio}")
                self.blocks.append(copy.deepcopy(base_block))

        base_norm = getattr(vit_base_model, 'norm', None)
        self.norm = copy.deepcopy(base_norm) if base_norm is not None else nn.Identity()

    def forward(self, x, condition_embedding=None, is_vfm_condition=True, vfm_teacher_id=None):
        """Run all blocks and return the final normalized tokens.

        ``condition_embedding``, ``is_vfm_condition`` and ``vfm_teacher_id``
        are accepted for interface compatibility and ignored.

        Returns:
            ``(x, zero, None, None, zero)`` where ``zero`` is a scalar zero
            tensor on ``x``'s device and dtype.
        """
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x, torch.tensor(0.0, device=x.device, dtype=x.dtype), None, None, torch.tensor(0.0, device=x.device, dtype=x.dtype)

    def forward_intermediate_features(self, x, condition_embedding=None, is_vfm_condition = False, vfm_teacher_id = None, return_expert_outputs = False, task_id = None):
        """Run all blocks and collect normalized features at ``self.out_indices``.

        All arguments other than ``x`` are accepted for interface
        compatibility and ignored. When ``self.out_indices`` is ``None`` the
        output of every block is returned instead of raising a ``TypeError``.

        Returns:
            ``(features, zero, None, None, zero)`` where ``features`` is the
            list of ``self.norm``-ed token tensors of the selected blocks.
        """
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)

        selected = range(len(self.blocks)) if self.out_indices is None else self.out_indices

        output = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in selected:
                feature_out = self.norm(x) if self.norm is not None else x
                output.append(feature_out)

        return output, torch.tensor(0.0, device=x.device, dtype=x.dtype), None, None, torch.tensor(0.0, device=x.device, dtype=x.dtype)

    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings and prepend the class/register tokens.

        Without a positional embedding the input is only flattened to
        ``(batch, -1, channels)``. With ``dynamic_img_size`` the embedding is
        resampled to the input's ``(H, W)`` grid. ``no_embed_class`` decides
        whether the embedding is added before (patch tokens only) or after
        (all tokens) the prefix tokens are concatenated.
        """
        if self.pos_embed is None:
            return x.view(x.shape[0], -1, x.shape[-1])

        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed

        to_cat = []
        if self.cls_token is not None:
            to_cat.append(self.cls_token.expand(x.shape[0], -1, -1))
        if self.reg_token is not None:
            to_cat.append(self.reg_token.expand(x.shape[0], -1, -1))

        if self.no_embed_class:
            x = x + pos_embed
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
        else:
            if to_cat:
                x = torch.cat(to_cat + [x], dim=1)
            x = x + pos_embed

        return self.pos_drop(x)

class Condition_MoE_SAK(nn.Module):
    def __init__(
            self,
            img_size,
            vit_name,
            tea_dims: dict = None,
            tasks: list = None,
            tasks_dict: dict = None,
            condition_dim = None,
            num_moe_experts = 15,
            moe_top_k=9,
            moe_layers_indices=None,
            noisy_gating=True,
            task_head_configs: dict = None,
            vfm_projection_configs: dict = None,
            freeze_vit: bool = False,
            freeze: bool = False,
            lora_config: dict = None,
            *args,
            **kwargs,
    ):
        """Build the backbone, condition embeddings and VFM projection heads.

        Args:
            img_size: ``(height, width)``; divided by 16 to get ``fea_size``.
            vit_name: ``"vit_small"``, ``"vit_base"`` or ``"vit_large"``.
            tea_dims: Maps a teacher id (string) to its feature width. One
                projection head per teacher and per output index is created.
                ``None`` creates no heads.
            tasks: Task names; one condition embedding row per task.
            condition_dim: Width of the condition embeddings; defaults to the
                backbone's ``embed_dim``.
            num_moe_experts, moe_top_k, moe_layers_indices, noisy_gating:
                Forwarded to the selected backbone wrapper.
            freeze_vit: Freeze every backbone parameter.
            lora_config: Keyword arguments for ``peft.LoraConfig``; used only
                when ``freeze_vit`` is false.
            **kwargs: Reads ``out_indices_cfg_for_task``, ``vit_type`` and
                ``vfm_projector_type`` (required when used), the optional
                ``vit_checkpoint_path`` and ``vfm_p_drop``, and, for the
                ``conditioned_moe`` backbone, ``expert_hidden_ratio``,
                ``moe_type``, ``router_type`` and ``gate_constraint_ranges``.

        ``tasks_dict``, ``task_head_configs``, ``vfm_projection_configs``,
        ``freeze`` and ``*args`` are accepted but not used here.

        Raises:
            NotImplementedError: For an unknown ``vit_name`` or ``vit_type``.
        """
        super().__init__()

        # Public backbone name -> (factory in models.backbones.vit, key into
        # the per-size out_indices configuration).
        vit_variants = {
            "vit_small": ("vit_small_patch16_384", "small"),
            "vit_base": ("vit_base_patch16_384", "base"),
            "vit_large": ("vit_large_patch16_384", "large"),
        }
        if vit_name not in vit_variants:
            raise NotImplementedError("vit_name not supported")
        factory_name, size_key = vit_variants[vit_name]

        import importlib
        vit_factories = importlib.import_module("models.backbones.vit")
        vit_base_model_instance = getattr(vit_factories, factory_name)(img_size=img_size, pretrained=True)
        self.out_indices = kwargs["out_indices_cfg_for_task"][size_key]

        # A missing key is treated like an explicit None: no checkpoint.
        vit_checkpoint_path = kwargs.get("vit_checkpoint_path")
        if vit_checkpoint_path is not None:
            global_print("Loading ViT checkpoint from {}".format(vit_checkpoint_path))
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            state_dict = torch.load(vit_checkpoint_path, map_location='cpu')
            state_dict = {k.replace("vit.", ""): v for k, v in state_dict.items()}
            vit_base_model_instance.load_state_dict(state_dict, strict=False)

        self.embed_dim = vit_base_model_instance.embed_dim
        self.fea_size = (img_size[0] // 16, img_size[1] // 16)
        if condition_dim is None:
            self.condition_dim = vit_base_model_instance.embed_dim
        else:
            self.condition_dim = condition_dim

        num_tasks = len(tasks) if tasks else 0
        num_vfm_teachers = len(tea_dims) if tea_dims else 0
        self.tasks = tasks
        self.num_tasks = num_tasks
        self.num_vfm_teachers = num_vfm_teachers
        self.task_condition_embeddings = nn.Embedding(num_tasks, self.condition_dim)
        self.task_composer_weights = nn.Parameter(torch.randn(num_tasks, num_vfm_teachers) * 0.02)
        self.vfm_condition_embeddings = nn.Embedding(num_vfm_teachers, self.condition_dim)

        # Hand-picked task -> teacher priors. Each row has three entries, so
        # they only apply with exactly three teachers; with any other teacher
        # count the random initialization above is kept instead of failing on
        # a shape mismatch.
        composer_presets = {
            5: [
                [2.0, -1.0, -1.0],
                [2.0, -1.0, -1.0],
                [-1.0, -1.0, 2.0],
                [-1.0, -1.0, 2.0],
                [-1.0, 2.0, -1.0],
            ],
            4: [
                [2.0, -1.0, -1.0],
                [-1.0, -1.0, 2.0],
                [-1.0, -1.0, 2.0],
                [2.0, -1.0, -1.0],
            ],
        }
        preset_rows = composer_presets.get(num_tasks)
        if preset_rows is not None and num_vfm_teachers == 3:
            with torch.no_grad():
                self.task_composer_weights.copy_(torch.tensor(preset_rows))

        self.task_free_condition_embeddings = nn.Embedding(1, self.condition_dim)
        self.vfm_p_drop = kwargs.get("vfm_p_drop", 0.3)

        vit_type = kwargs["vit_type"]
        if vit_type == "moh":
            global_print(f"Initializing MoH (Mixture-of-Heads) Backbone with Top-K={moe_top_k}")
            self.backbone = MoH_ViT(
                vit_base_model=vit_base_model_instance,
                num_selected_heads=moe_top_k,
                moh_layers_indices=moe_layers_indices,
                noisy_gating=noisy_gating,
                out_indices=self.out_indices
            )
        elif vit_type == "conditioned_moe":
            self.backbone = Condition_MoE_ViT(
                vit_base_model=vit_base_model_instance,
                condition_dim=self.condition_dim,
                num_moe_experts=num_moe_experts,
                moe_top_k=moe_top_k,
                expert_hidden_ratio=kwargs["expert_hidden_ratio"],
                moe_layers_indices=moe_layers_indices,
                noisy_gating=noisy_gating,
                out_indices=self.out_indices,
                moe_type=kwargs["moe_type"],
                router_type=kwargs["router_type"],
                gate_constraint_ranges=kwargs["gate_constraint_ranges"],
                num_tasks=self.num_tasks
            )
        elif vit_type == "fat_vit":
            self.backbone = Fat_ViT(
                vit_base_model=vit_base_model_instance,
                fat_mlp_ratio=20.0,
                fat_layers_indices=moe_layers_indices,
                out_indices=self.out_indices,
            )
        else:
            # The message previously named moe_type although vit_type is tested.
            raise NotImplementedError(f"vit_type '{vit_type}' not supported")

        # The wrapper deep-copied what it needs; release the template model.
        del vit_base_model_instance
        torch.cuda.empty_cache()

        if freeze_vit:
            for param in self.backbone.parameters():
                param.requires_grad = False
        elif lora_config:
            from peft import LoraConfig, get_peft_model
            self.backbone = get_peft_model(self.backbone, LoraConfig(**lora_config))
            if tasks:
                global_print("Full fine-tune ViT and LoRA in stage 2")
                for param in self.backbone.parameters():
                    param.requires_grad = True
            self.backbone.print_trainable_parameters()

        self.task_heads = None
        self.vfm_projection_heads = nn.ModuleDict()

        # ``tea_dims`` may be None (see num_vfm_teachers above): build no heads.
        for tea_no in (tea_dims or {}):
            # NOTE(review): a zero width stops head creation for ALL remaining
            # teachers, not only this one - confirm ``break`` (vs ``continue``)
            # is intended.
            if tea_dims[tea_no] == 0:
                break
            p_type = kwargs["vfm_projector_type"]
            if p_type == "tp":
                global_print("################# VFM Projector type is TP ##################")
                for l_ind in range(len(self.out_indices)):
                    self.vfm_projection_heads[tea_no + "_" + str(l_ind)] = build_projector(
                        input_dim=self.embed_dim,
                        output_dim=tea_dims[tea_no],
                        extra_args=None
                    )
            elif p_type == "mlp":
                global_print("################# VFM Projector type is MLP ##################")
                for l_ind in range(len(self.out_indices)):
                    self.vfm_projection_heads[tea_no + "_" + str(l_ind)] = MlpProjector(
                        input_dim=self.embed_dim,
                        output_dim=tea_dims[tea_no],
                        mlp_ratio=1.0
                    )
            else:
                global_print("################# VFM Projector type is linear ##################")
                for l_ind in range(len(self.out_indices)):
                    self.vfm_projection_heads[tea_no + "_" + str(l_ind)] = nn.Linear(self.embed_dim, tea_dims[tea_no])
        self._init_weights_custom()

    def _init_weights_custom(self, m=None):
        if m is None:
            if hasattr(self, 'task_condition_embeddings'): self.task_condition_embeddings.apply(self._init_weights_custom)
            if hasattr(self, 'vfm_condition_embeddings'): self.vfm_condition_embeddings.apply(self._init_weights_custom)
            if hasattr(self, 'task_free_condition_embeddings'): self.task_free_condition_embeddings.apply(self._init_weights_custom)
            if hasattr(self, 'vfm_free_condition_embeddings'): self.vfm_free_condition_embeddings.apply(self._init_weights_custom)
            if hasattr(self, 'task_heads') and self.task_heads: self.task_heads.apply(self._init_weights_custom)
            if hasattr(self, 'vfm_projection_heads') and self.vfm_projection_heads: self.vfm_projection_heads.apply(self._init_weights_custom)
            return

        if isinstance(m, nn.Linear):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Embedding):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def get_task_condition_embedding(self, task_ids=None):
        if task_ids is not None:
            return self.task_condition_embeddings(task_ids)
        else:
            raise ValueError("必须提供 task_ids 作为条件。")

    def get_vfm_condition_embedding(self, vfm_teacher_ids, training=False):
        return self.vfm_condition_embeddings(vfm_teacher_ids)

    def forward(self, batch, vfm_training=True, task_training=False):
        if isinstance(batch, list) or isinstance(batch, tuple):
            images = batch[0]
            vfm_ids_from_batch = batch[1] if len(batch) > 1 else None
        elif isinstance(batch, dict):
            images = batch["image"]
            vfm_ids_from_batch = batch.get("vfm_teacher_id", None)
        else:
            images = batch
            vfm_ids_from_batch = None

        batch_size = images.shape[0]
        outputs = {}
        H, W = self.fea_size

        outputs["aux_loss"] = torch.tensor(0.0, device=images.device, dtype=images.dtype)
        outputs["gate_regularization_loss"] = torch.tensor(0.0, device=images.device, dtype=images.dtype)
        outputs["gate_entropy_loss"] = torch.tensor(0.0, device=images.device, dtype=images.dtype)

        if vfm_training:
            if not hasattr(self, 'vfm_projection_heads') or not self.vfm_projection_heads:
                raise ValueError("vfm_training is True, but no vfm_projection_heads found.")
            if vfm_ids_from_batch is None:
                raise ValueError("vfm_training is True, but vfm_teacher_ids are missing.")

            vfm_teacher_ids = vfm_ids_from_batch.to(images.device)
            vfm_condition_embedding = self.get_vfm_condition_embedding(vfm_teacher_ids, training=vfm_training)
            final_features_all_tokens, aux_loss, gate_list, _, gate_regularization_loss = self.backbone.forward_intermediate_features(images, vfm_condition_embedding, is_vfm_condition=True, vfm_teacher_id=vfm_teacher_ids)

            outputs["aux_loss"] += aux_loss
            outputs["gate_regularization_loss"] += gate_regularization_loss

            vfm_student_projections_output = OrderedDict()
            unique_teacher_ids_in_batch = torch.unique(vfm_teacher_ids)
            for teacher_id_tensor in unique_teacher_ids_in_batch:
                teacher_id_str = str(teacher_id_tensor.item())
                mask = (vfm_teacher_ids == teacher_id_tensor)
                masked_features_all_tokens = []
                for l_ind in range(len(self.out_indices)):
                    masked_features_all_tokens.append(final_features_all_tokens[l_ind][mask])

                if masked_features_all_tokens[0].shape[0] == 0:
                    continue

                patch_tokens_projected = []
                for l_ind in range(len(self.out_indices)):
                    projected_feature_tokens_layer = self.vfm_projection_heads[teacher_id_str+"_"+str(l_ind)](masked_features_all_tokens[l_ind])
                    num_prefix_tokens = self.backbone.num_prefix_tokens
                    patch_tokens_projected.append(projected_feature_tokens_layer[:, num_prefix_tokens:])

                H_feat, W_feat = self.fea_size
                current_teacher_batch_size = patch_tokens_projected[0].shape[0]

                reshaped_feature_map = []
                for l_ind in range(len(self.out_indices)):
                    reshaped_feature_map.append(patch_tokens_projected[l_ind].reshape(current_teacher_batch_size, H_feat, W_feat, -1).permute(0, 3, 1, 2).contiguous())

                vfm_student_projections_output[teacher_id_str] = reshaped_feature_map

            outputs["vfm_student_projections"] = vfm_student_projections_output

        if task_training:
            feature_for_tasks = OrderedDict()
            for i in range(len(self.tasks)):
                task_ids = torch.tensor([i], device=images.device, dtype=torch.long)
                task_ids = task_ids.expand(batch_size)
                task_ids_embedding = self.get_task_condition_embedding(task_ids=task_ids)
                final_features_all_tokens, total_aux_loss_task, _, _, gate_regularization_loss  = self.backbone.forward_intermediate_features(images, task_ids_embedding, task_id=task_ids)

                for j, feat in enumerate(final_features_all_tokens):
                    patch_tokens = final_features_all_tokens[j][:, self.backbone.num_prefix_tokens:]
                    current_teacher_batch_size = patch_tokens.shape[0]
                    H_feat, W_feat = self.fea_size
                    final_features_all_tokens[j] = patch_tokens.reshape(current_teacher_batch_size, H_feat, W_feat,
                                                                        -1).permute(0, 3, 1, 2).contiguous()

                feature_for_tasks[self.tasks[i]] = final_features_all_tokens
                outputs["aux_loss"] += total_aux_loss_task / len(self.tasks)
                outputs["gate_regularization_loss"] += gate_regularization_loss
            outputs["feature_for_tasks"] = feature_for_tasks

        elif not vfm_training and not task_training:
            global_print("Default/Inference behavior: No specific task or VFM training active.")
            vfm_condition_embedding = self.get_vfm_condition_embedding(vfm_ids_from_batch.to(images.device), training=vfm_training)
            final_features_all_tokens, aux_loss, gate_list, topk_indices_list = self.backbone.forward_intermediate_features(images, vfm_condition_embedding, is_vfm_condition=False)
            outputs["aux_loss"] = aux_loss

            if self.backbone.num_prefix_tokens > 0:
                feature_for_output = final_features_all_tokens[:, 0]
            else:
                patch_tokens_output = final_features_all_tokens[:, self.backbone.num_prefix_tokens:]
                feature_for_output = torch.mean(patch_tokens_output, dim=1)
            outputs["features"] = feature_for_output

        return outputs

    @torch.no_grad()
    def run_analysis_forward(self, image_batch, task_name):
        self.eval()
        images = image_batch.to(next(self.parameters()).device)
        batch_size = images.shape[0]

        if task_name not in self.tasks:
            raise ValueError(f"Task '{task_name}' not found.")
        task_idx = self.tasks.index(task_name)
        task_ids = torch.tensor([task_idx] * batch_size, device=images.device, dtype=torch.long)
        condition_embedding = self.get_task_condition_embedding(task_ids=task_ids)

        final_features_all_tokens, _, gate_list, topk_indices_list = \
            self.backbone.forward_intermediate_features(images, condition_embedding, is_vfm_condition=False)

        processed_features = []
        H_feat, W_feat = self.fea_size
        for feat_map in final_features_all_tokens:
            patch_tokens = feat_map[:, self.backbone.num_prefix_tokens:]
            b = patch_tokens.shape[0]
            reshaped_map = patch_tokens.reshape(b, H_feat, W_feat, -1).permute(0, 3, 1, 2).contiguous()
            processed_features.append(reshaped_map)

        return processed_features, gate_list, topk_indices_list

    def generate_gate_activation_visuals(self, image_path, task_names_to_compare, patch_size=16):
        """Save a side-by-side heatmap of the last gate values for several tasks.

        The image at ``image_path`` is resized to 384x384, normalized and run
        through ``run_analysis_forward`` once per task. The last entry of each
        returned gate list is reshaped to the patch grid, upsampled bicubically
        to the original image size and drawn over the image. The figure is
        written to ``gate_activation_heatmap_comparison.png`` in the current
        working directory and is left open.

        Args:
            image_path: Path of the image to analyse.
            task_names_to_compare: Task names, one panel each.
            patch_size: Patch edge length used to derive the gate grid size.
        """
        from torchvision import transforms

        preprocess = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        original_img = Image.open(image_path).convert("RGB")
        image_tensor = preprocess(original_img).unsqueeze(0)

        # Task name -> gate values of the last entry in the returned gate list.
        last_gates = {
            name: self.run_analysis_forward(image_tensor, name)[1][-1]
            for name in task_names_to_compare
        }

        panel_count = len(task_names_to_compare) + 1
        fig, axes = plt.subplots(1, panel_count, figsize=(6 * panel_count, 6))

        source_ax = axes[0]
        source_ax.imshow(original_img)
        source_ax.set_title("Original Image")
        source_ax.axis('off')

        grid_h = image_tensor.shape[2] // patch_size
        grid_w = image_tensor.shape[3] // patch_size
        # PIL reports (width, height); interpolate expects (height, width).
        target_size = original_img.size[::-1]

        for ax, task_name in zip(axes[1:], task_names_to_compare):
            gate_map = last_gates[task_name].reshape(grid_h, grid_w).cpu().numpy()
            gate_map_resized = F.interpolate(
                torch.tensor(gate_map)[None, None],
                size=target_size,
                mode='bicubic',
                align_corners=False
            ).squeeze().numpy()

            ax.imshow(original_img)
            im = ax.imshow(gate_map_resized, cmap='magma', alpha=0.7, vmin=0, vmax=1)
            ax.set_title(f"Gate for '{task_name}'")
            ax.axis('off')

        # One shared colorbar, taken from the last task panel drawn.
        fig.colorbar(im, ax=axes.ravel().tolist(), orientation='vertical', fraction=0.05, pad=0.02)
        plt.suptitle("Gate Activation Heatmap", fontsize=16)
        plt.savefig("gate_activation_heatmap_comparison.png", dpi=300)

    def generate_conditional_affinity_heatmap(self, task_dataloaders, vfm_groups, gate_threshold=0.5):
        """Plot, per task, how often each expert is selected where the gate is low.

        For every batch the last gate values and last top-k routing indices
        from ``run_analysis_forward`` are used. Expert selections are counted
        only at positions whose gate value is below ``gate_threshold``, then
        normalized to fractions per task and shown as an annotated heatmap.

        Args:
            task_dataloaders: Maps a task name to an iterable of batches whose
                first element is the image tensor.
            vfm_groups: Accepted but not used.
            gate_threshold: Positions with a gate value below this are counted.
        """
        num_experts = self.backbone.blocks[-1].moe_ffn_layer.num_experts
        affinity_results = {}

        for task_name, loader in task_dataloaders.items():
            expert_counts = torch.zeros(num_experts, dtype=torch.long)

            for batch in tqdm(loader, desc=f"Analyzing affinity for {task_name}"):
                images = batch[0]
                _, gate_list, topk_indices_list = self.run_analysis_forward(images, task_name)

                gate_values = gate_list[-1]
                router_indices = topk_indices_list[-1]

                moe_needed_mask = (gate_values < gate_threshold).squeeze(-1)
                selected_indices = router_indices[moe_needed_mask]

                if selected_indices.numel() > 0:
                    counts = torch.bincount(selected_indices.flatten(), minlength=num_experts)
                    expert_counts += counts.cpu()

            total_selections = expert_counts.sum()
            if total_selections > 0:
                shares = expert_counts.float() / total_selections
            else:
                # No position fell below the threshold: report all zeros.
                shares = torch.zeros(num_experts)
            # Plain floats so pandas builds a numeric frame, not tensor objects.
            affinity_results[task_name] = shares.tolist()

        # One row per task, one column per expert.
        df_percent = pd.DataFrame.from_dict(
            affinity_results,
            orient='index',
            columns=[f'E{i}' for i in range(num_experts)],
        )

        plt.figure(figsize=(16, len(task_dataloaders) * 0.9))
        ax = sns.heatmap(df_percent, annot=True, fmt=".2%", cmap="viridis", linewidths=.5)
        plt.title(f"Conditional Expert Affinity (Gate < {gate_threshold})", fontsize=16)
        plt.xlabel("Experts")
        plt.ylabel("Downstream Tasks")
        plt.tight_layout()
        plt.show()