import torch
import torch.nn as nn
import copy
import logging

from .vit import Embeddings

# A helper class that tracks how much each major component of the MoT model
#   changes between checks, measured as the absolute change in the L2 norm of
#   that component's parameters.
class MoTTracker:
    """Log how much each major component of a MoT model changed between checks.

    For the shared embedding, the shared gating matrix and, per expert, the
    attention side, the FFN side and the classifier, the tracker computes a
    scalar from the L2 norms of the component's parameters and reports
    ``abs(current - previous)`` for that scalar.

    NOTE: this is a difference of norms, not the norm of the parameter
    difference. An update that leaves a component's norm unchanged is
    therefore reported as "no change".

    Args:
        moe_gating: object exposing ``embedding`` (a module) and ``theta``
            (a tensor / parameter).
        expert_models: sequence of experts; each must expose
            ``encoder.blocks`` (iterable of blocks) and ``classifier``.
            It is iterated again on every check.
    """

    def __init__(self, moe_gating, expert_models):
        self.moe_gating = moe_gating
        self.expert_models = expert_models
        # Scalar stats from the previous check; empty until the first check.
        self.prev_state = {}

    def __flatten_params(self, module):
        """Return the trainable parameters of ``module`` as one flat CPU tensor.

        Parameters with ``requires_grad=False`` are skipped. When nothing is
        trainable (e.g. the module is frozen) a single zero is returned so
        that the resulting norm is 0.0.
        """
        # reshape(-1) instead of view(-1): view raises on non-contiguous
        # parameters, reshape copies only when it has to.
        # .cpu() is a no-op for tensors that already live on the CPU.
        flat_params = [
            p.detach().reshape(-1).cpu()
            for p in module.parameters()
            if p.requires_grad
        ]
        if not flat_params:
            return torch.tensor([0.0])
        return torch.cat(flat_params)

    def _compute_scalar_stats(self):
        """Compute the current scalar statistic of every tracked component.

        Returns:
            dict mapping component name to a float. Per-expert encoder
            entries are the sum over blocks of the sub-module norms.
        """
        def l2(module):
            return torch.norm(self.__flatten_params(module), p=2).item()

        stats = {
            'shared_embedding': l2(self.moe_gating.embedding),
            # theta is measured regardless of its requires_grad flag.
            'shared_gating': torch.norm(
                self.moe_gating.theta.detach().reshape(-1).cpu(), p=2
            ).item(),
        }

        for idx, expert in enumerate(self.expert_models):
            attn_norm = 0.0
            ffn_norm = 0.0
            for blk in expert.encoder.blocks:
                # attention / layernorm_1 are optional on a block;
                # mlp / layernorm_2 are required.
                if hasattr(blk, "attention"):
                    attn_norm += l2(blk.attention)
                if hasattr(blk, "layernorm_1"):
                    attn_norm += l2(blk.layernorm_1)
                ffn_norm += l2(blk.mlp)
                ffn_norm += l2(blk.layernorm_2)
            stats[f'expert_{idx}_encoder_attn'] = attn_norm
            stats[f'expert_{idx}_encoder_ffn'] = ffn_norm
            stats[f'expert_{idx}_classifier'] = l2(expert.classifier)

        return stats

    def _compare_with_previous(self, new_stats):
        """Return ``abs(new - previous)`` per key and remember ``new_stats``.

        On the first call there is nothing to compare against, so every
        delta is 0.0. Keys missing from the previous state compare against 0.0.
        """
        previous = self.prev_state
        # Values are plain floats, so a shallow copy fully decouples the
        # stored state from the caller's dict.
        self.prev_state = dict(new_stats)
        if not previous:
            return {key: 0.0 for key in new_stats}
        return {
            key: abs(value - previous.get(key, 0.0))
            for key, value in new_stats.items()
        }

    def check(self, simplify=True):
        """Compute current stats, diff them against the last check and log them.

        Intended to be called at the beginning of each epoch.

        Args:
            simplify: when true, deltas with magnitude below 1e-4 are
                printed as "N/A" instead of a number.
        """
        def fmt(val):
            if simplify and abs(val) < 1e-4:
                return "N/A"
            return f"{val:.4f}"

        logger = logging.getLogger()
        deltas = self._compare_with_previous(self._compute_scalar_stats())

        emb_txt = fmt(deltas.get('shared_embedding', 0.0))
        gate_txt = fmt(deltas.get('shared_gating', 0.0))
        logger.info(f"TRACKER: shared_embedding = {emb_txt}, shared_gating = {gate_txt}")

        # One "E<i>:<value>" cell per expert for each of the three rows.
        num_experts = sum("encoder_attn" in k for k in deltas)
        attn_row, ffn_row, clf_row = [], [], []
        for i in range(num_experts):
            attn_row.append(f"E{i}:{fmt(deltas.get(f'expert_{i}_encoder_attn', 0.0))}")
            ffn_row.append(f"E{i}:{fmt(deltas.get(f'expert_{i}_encoder_ffn', 0.0))}")
            clf_row.append(f"E{i}:{fmt(deltas.get(f'expert_{i}_classifier', 0.0))}")

        logger.info(f"TRACKER: encoder_attn     = {'  '.join(attn_row)}")
        logger.info(f"TRACKER: encoder_ffn      = {'  '.join(ffn_row)}")
        logger.info(f"TRACKER: classifier       = {'  '.join(clf_row)}")


# Definition of our Gating module.
class ViTGating(nn.Module):
    """Gating module: embeds the input and picks one of ``M`` slots per sample.

    The input is embedded with ``Embeddings``; every token embedding is
    projected onto the ``M`` rows of ``theta`` and the projections are summed
    over the token dimension, giving one score per sample and per slot.
    Presumably each slot corresponds to one expert (``M`` is read from
    ``config['n_expert']``).

    Config keys read here: ``n_expert``, ``hidden_size``,
    ``initializer_range`` (``Embeddings`` may read more).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.M = config['n_expert']
        self.d = config["hidden_size"]
        self.embedding = Embeddings(config)
        # Gating matrix, one row of size d per slot.
        self.theta = nn.Parameter(torch.empty(self.M, self.d))
        nn.init.normal_(self.theta, mean=0.0, std=0.02)
        # Re-initialises the Embeddings sub-module(s); theta is not touched
        # because _init_weights only acts on Embeddings instances.
        self.apply(self._init_weights)

    def forward(self, x, r=None):
        """Embed ``x`` and compute the routing decision.

        Args:
            x: input passed unchanged to the embedding module.
            r: optional tensor added to the scores before the argmax only
                (presumably routing noise supplied by the caller). It is
                cropped to ``[:B, :M]``, so it may be larger than needed.
                When ``None`` the argmax is taken over the raw scores.

        Returns:
            Tuple ``(X, m, pi)``: the embeddings ``[B, L, d]``, the selected
            index per sample ``[B]``, and the softmax of the unperturbed
            scores ``[B, M]``.
        """
        X = self.embedding(x)  # [B, L, d]
        h = (X @ self.theta.T).sum(dim=1)  # [B, M]
        if r is None:
            # The default r=None used to crash on the slice below;
            # route on the raw scores instead.
            routed = h
        else:
            routed = h + r[:X.shape[0], :self.M]
        m = torch.argmax(routed, dim=1)
        # The probabilities intentionally ignore r.
        pi = torch.softmax(h, dim=-1)
        return X, m, pi

    def _init_weights(self, module):
        """Truncated-normal init for position embeddings and CLS token.

        Applied to every sub-module via ``nn.Module.apply``; only
        ``Embeddings`` instances are modified. The init is done in float32
        and cast back to the parameter's original dtype.
        """
        if isinstance(module, Embeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config["initializer_range"],
            ).to(module.position_embeddings.dtype)

            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config["initializer_range"],
            ).to(module.cls_token.dtype)
    