import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from clip_w_local import clip
from clip_w_local.simple_tokenizer import SimpleTokenizer as _Tokenizer
from trainers.models import TextEncoder


class MomentumLossBalancer:
    """Momentum-based loss balancing with dynamic thresholding for NPT.

    Keeps an exponential moving average (EMA) of the global, patch and margin
    loss magnitudes. After ``warmup_steps`` accepted updates, the patch and
    margin weights are rescaled by ``1 + alpha * (mean / ema - 1)``. Components
    that are small relative to the mean of the three EMAs get a larger weight;
    large components get a smaller one. The scaling factor is clamped to
    ``[min_weight_factor, max_weight_factor]``.
    """

    def __init__(self, momentum_beta=0.9, adaptation_alpha=0.1,
                 min_weight_factor=0.1, max_weight_factor=3.0, warmup_steps=50):
        """
        Initialize momentum-based loss balancer

        Args:
            momentum_beta: Momentum factor for EMA smoothing (default: 0.9)
            adaptation_alpha: Adaptation strength for weight adjustment (default: 0.1)
            min_weight_factor: Minimum weight scaling factor (default: 0.1)
            max_weight_factor: Maximum weight scaling factor (default: 3.0)
            warmup_steps: Number of accepted EMA updates before applying
                adaptive weighting (default: 50)
        """
        self.momentum_beta = momentum_beta
        self.adaptation_alpha = adaptation_alpha
        self.min_weight_factor = min_weight_factor
        self.max_weight_factor = max_weight_factor
        self.warmup_steps = warmup_steps

        # EMA tracking: None until the first finite update arrives.
        self.ema_global = None
        self.ema_patch = None
        self.ema_margin = None
        # Number of accepted (finite) EMA updates.
        self.step_count = 0

    @staticmethod
    def _as_float(value):
        """Convert a scalar loss (one-element tensor, NumPy scalar or number) to float."""
        return float(value.item() if hasattr(value, "item") else value)

    def update_emas(self, loss_global, loss_patch, loss_margin):
        """Update the exponential moving averages of the three loss components.

        Args:
            loss_global: Global loss (one-element tensor or plain number).
            loss_patch: Patch loss (one-element tensor or plain number).
            loss_margin: Margin loss (one-element tensor or plain number).

        Only the values are used; the EMAs are plain Python floats, so no
        gradient flows through them.

        If any of the three values is NaN or infinite, the step is skipped:
        the EMAs and ``step_count`` stay unchanged. Without this guard, a single
        non-finite loss would make the EMAs non-finite for the rest of
        training. ``get_adaptive_weights`` would then silently clamp its
        factors to ``max_weight_factor`` on every call.
        """
        values = [self._as_float(v) for v in (loss_global, loss_patch, loss_margin)]
        if not all(math.isfinite(v) for v in values):
            return
        new_global, new_patch, new_margin = values

        if self.ema_global is None:
            # The first finite observation initializes the EMAs directly.
            self.ema_global = new_global
            self.ema_patch = new_patch
            self.ema_margin = new_margin
        else:
            # Standard EMA: ema <- beta * ema + (1 - beta) * value
            beta = self.momentum_beta
            self.ema_global = beta * self.ema_global + (1 - beta) * new_global
            self.ema_patch = beta * self.ema_patch + (1 - beta) * new_patch
            self.ema_margin = beta * self.ema_margin + (1 - beta) * new_margin

        self.step_count += 1

    def get_adaptive_weights(self, base_lambda_patch, base_lambda_margin):
        """
        Compute adaptive weights based on EMA ratios

        Args:
            base_lambda_patch: Base weight for patch loss
            base_lambda_margin: Base weight for margin loss

        Returns:
            tuple: (adaptive_lambda_patch, adaptive_lambda_margin). The base
            weights are returned unchanged during warmup, before any EMA
            exists, or when the mean EMA is (near) zero.
        """
        # During warmup, use base weights
        if self.step_count < self.warmup_steps or self.ema_global is None:
            return base_lambda_patch, base_lambda_margin

        # Compute mean loss magnitude for normalization
        mean_loss = (self.ema_global + self.ema_patch + self.ema_margin) / 3.0

        # Avoid division by zero
        if mean_loss < 1e-8:
            return base_lambda_patch, base_lambda_margin

        # Relative ratios: ratio > 1 means the component is larger than the
        # mean loss; ratio < 1 means it is smaller.
        ratio_patch = self.ema_patch / mean_loss
        ratio_margin = self.ema_margin / mean_loss

        # Adaptive adjustment: increase the weight if the loss is relatively small.
        # Formula: adaptive_lambda = base_lambda * (1 + alpha * (1/ratio - 1))
        adjustment_patch = 1.0 + self.adaptation_alpha * (1.0 / max(ratio_patch, 1e-8) - 1.0)
        adjustment_margin = 1.0 + self.adaptation_alpha * (1.0 / max(ratio_margin, 1e-8) - 1.0)

        # Apply thresholding to prevent extreme adjustments
        adjustment_patch = max(self.min_weight_factor, min(self.max_weight_factor, adjustment_patch))
        adjustment_margin = max(self.min_weight_factor, min(self.max_weight_factor, adjustment_margin))

        # Compute adaptive weights
        adaptive_lambda_patch = base_lambda_patch * adjustment_patch
        adaptive_lambda_margin = base_lambda_margin * adjustment_margin

        return adaptive_lambda_patch, adaptive_lambda_margin

    def get_stats(self):
        """Get current EMA stats for logging"""
        return {
            'ema_global': self.ema_global,
            'ema_patch': self.ema_patch,
            'ema_margin': self.ema_margin,
            'step_count': self.step_count
        }


class NPTPromptLearner(nn.Module):
    """Nuisance-Prompt Tuning prompt learner with a dedicated nuisance prompt.

    Builds ``n_cls`` class prompts (``"<ctx> <classname>."``) plus one extra
    nuisance prompt (``"<ctx> background."``). The nuisance prompt's context
    vectors are a separate learnable parameter. ``forward`` returns the
    embedded token sequences for all ``n_cls + 1`` prompts, with the nuisance
    prompt in the last row.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        n_cls = len(classnames)
        n_ctx = cfg.TRAINER.LOCOOP.N_CTX
        ctx_init = cfg.TRAINER.LOCOOP.CTX_INIT
        dtype = clip_model.dtype
        ctx_dim = clip_model.ln_final.weight.shape[0]
        clip_imsize = clip_model.visual.input_resolution
        cfg_imsize = cfg.INPUT.SIZE[0]
        assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})"

        if ctx_init:
            # use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt).type(dtype)
            # Skip the SOS token; take the embeddings of the n_ctx init words.
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            # random initialization
            if cfg.TRAINER.LOCOOP.CSC:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
            nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f'Initial context: "{prompt_prefix}"')
        print(f"Number of context words (tokens): {n_ctx}")

        # Class context vectors
        self.ctx = nn.Parameter(ctx_vectors)  # to be optimized

        # NPT: separate nuisance context vectors (always randomly initialized,
        # shared shape (n_ctx, ctx_dim) regardless of CSC).
        print("Initializing nuisance context vector")
        nuisance_ctx_vector = torch.empty(n_ctx, ctx_dim, dtype=dtype)
        nn.init.normal_(nuisance_ctx_vector, std=0.02)
        self.nuisance_ctx = nn.Parameter(nuisance_ctx_vector)  # to be optimized

        # Build the BPE tokenizer once. Constructing it loads the vocabulary,
        # so doing that per class name (or per forward pass) is wasteful.
        tokenizer = _Tokenizer()

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # NPT: the nuisance prompt is appended after all class prompts.
        nuisance_name = "background"
        nuisance_prompt = prompt_prefix + " " + nuisance_name + "."
        all_prompts = prompts + [nuisance_prompt]

        tokenized_prompts = torch.cat([clip.tokenize(p) for p in all_prompts])
        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts  # torch.Tensor
        self.name_lens = name_lens  # class prompts only (nuisance excluded)
        # Token length of the nuisance word, cached so forward() never needs a tokenizer.
        self.nuisance_name_len = len(tokenizer.encode(nuisance_name))
        self.class_token_position = cfg.TRAINER.LOCOOP.CLASS_TOKEN_POSITION

    def forward(self):
        """Assemble the embedded prompts for all class prompts plus the nuisance prompt.

        Returns:
            Tensor of shape (n_cls + 1, seq_len, dim); the last row is the
            nuisance prompt.

        Raises:
            ValueError: if ``class_token_position`` is not "end", "middle" or "front".
        """
        ctx = self.ctx
        if ctx.dim() == 2:
            # Generic context: share the same vectors across all classes.
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        # NPT: append the nuisance context as one extra row.
        nuisance_ctx = self.nuisance_ctx.unsqueeze(0)  # (1, n_ctx, dim)
        all_ctx = torch.cat([ctx, nuisance_ctx], dim=0)  # (n_cls + 1, n_ctx, dim)

        prefix = self.token_prefix  # (n_cls + 1, 1, dim)
        suffix = self.token_suffix  # (n_cls + 1, *, dim)
        position = self.class_token_position

        if position == "end":
            return torch.cat([prefix, all_ctx, suffix], dim=1)

        if position not in ("middle", "front"):
            raise ValueError(
                f"Unknown class_token_position {position!r}; "
                "expected 'end', 'middle' or 'front'"
            )

        # Class name token lengths, with the nuisance word's length last.
        name_lens = list(self.name_lens) + [self.nuisance_name_len]
        half_n_ctx = self.n_ctx // 2
        prompts = []
        for i, name_len in enumerate(name_lens):
            prefix_i = prefix[i : i + 1, :, :]
            class_i = suffix[i : i + 1, :name_len, :]
            suffix_i = suffix[i : i + 1, name_len:, :]
            ctx_i = all_ctx[i : i + 1, :, :]
            if position == "middle":
                # SOS, first half of ctx, class tokens, second half of ctx, rest.
                pieces = [
                    prefix_i,
                    ctx_i[:, :half_n_ctx, :],
                    class_i,
                    ctx_i[:, half_n_ctx:, :],
                    suffix_i,
                ]
            else:
                # "front": SOS, class tokens, ctx, rest.
                pieces = [prefix_i, class_i, ctx_i, suffix_i]
            prompts.append(torch.cat(pieces, dim=1))
        return torch.cat(prompts, dim=0)

    def get_class_prompts(self):
        """Get only class prompts (without nuisance)"""
        prompts = self.forward()
        return prompts[:-1]  # Return all except last (nuisance) prompt

    def get_nuisance_prompt(self):
        """Get only nuisance prompt"""
        prompts = self.forward()
        return prompts[-1:]  # Return last (nuisance) prompt


def _fallback_num_patches(visual_model, image_input):
    """Best-effort number of patch tokens, used for the uniform fallback weights."""
    pos_embed = getattr(visual_model, "positional_embedding", None)
    if isinstance(pos_embed, torch.Tensor) and pos_embed.dim() == 2:
        # One positional row per token: class token + grid**2 patches.
        return pos_embed.shape[0] - 1
    # Otherwise assume 16x16 patches (ViT-B/16), e.g. 224x224 -> 14*14 = 196.
    patch_size = 16
    return (image_input.size(-2) // patch_size) * (image_input.size(-1) // patch_size)


def extract_attention_weights(visual_model, image_input):
    """
    Extract attention-based background weights from a CLIP Vision Transformer.

    Args:
        visual_model: CLIP visual model (VisionTransformer)
        image_input: Input image tensor [batch, 3, H, W]

    Returns:
        background_weights: Background weights for patches [batch, num_patches],
        computed as ``1 - softmax(cls_to_patch_attention)``. On failure, a
        uniform tensor of 0.5 is returned, with ``num_patches`` derived from the
        model's positional embedding so it matches the patch count.
    """
    try:
        # Match the model's parameter dtype (CLIP weights are often fp16). A
        # mismatch would otherwise raise in conv1 and silently hit the fallback.
        image_input_cast = image_input.to(visual_model.conv1.weight.dtype)

        # Forward through patch projection
        x = visual_model.conv1(image_input_cast)  # [batch, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # [batch, width, grid**2]
        x = x.permute(0, 2, 1)  # [batch, grid**2, width]

        # Add class token
        class_embedding = visual_model.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([class_embedding, x], dim=1)  # [batch, grid**2 + 1, width]

        # Add positional embedding
        x = x + visual_model.positional_embedding.to(x.dtype)
        x = visual_model.ln_pre(x)

        # Transformer input is sequence-first
        x = x.permute(1, 0, 2)  # [seq_len, batch, width]

        # NOTE(review): only the LAST residual block's attention is applied, and
        # it is applied directly to the ln_pre output. The preceding resblocks
        # are not run, so this is not the last-layer attention of a full forward
        # pass. Confirm this is intended.
        last_layer = visual_model.transformer.resblocks[-1]

        # `attention_weight` is defined in the project's CLIP variant; its exact
        # output format is not visible here, hence the shape handling below.
        attention_weights = last_layer.attention_weight(x)

        if attention_weights.dim() == 3:  # [batch, seq_len, seq_len]
            cls_to_patches = attention_weights[:, 0, 1:]  # [batch, num_patches]
        elif attention_weights.dim() == 2:  # [seq_len, seq_len]
            cls_to_patches = attention_weights[0, 1:].unsqueeze(0)  # [1, num_patches]
            cls_to_patches = cls_to_patches.expand(x.size(1), -1)  # [batch, num_patches]
        else:
            # Unrecognized format - return uniform weights
            batch_size = image_input.size(0)
            num_patches = x.size(0) - 1  # subtract class token
            return torch.ones(batch_size, num_patches, device=image_input.device) * 0.5

        # NOTE(review): if attention_weight already returns softmax
        # probabilities, this second softmax flattens them to roughly
        # 1/num_patches. Confirm.
        relevance = torch.softmax(cls_to_patches, dim=-1)

        # Background weights: 1 - relevance
        background_weights = 1.0 - relevance

        return background_weights

    except Exception as e:
        # Deliberate best-effort: never break training, fall back to uniform weights.
        print(f"Warning: Attention extraction failed: {e}")
        batch_size = image_input.size(0)
        num_patches = _fallback_num_patches(visual_model, image_input)
        return torch.ones(batch_size, num_patches, device=image_input.device) * 0.5


class NPTCustomCLIP(nn.Module):
    """NPT Custom CLIP with nuisance prompt and attention-based weighting.

    The text side produces ``n_cls`` class prompts followed by one nuisance
    prompt (last row). Global logits use the class prompts only. Local (patch)
    logits use all prompts, so index ``n_cls`` is the nuisance/background class.
    """

    def __init__(self, cfg, classnames, clip_model):
        super().__init__()
        self.prompt_learner = NPTPromptLearner(cfg, classnames, clip_model)
        self.tokenized_prompts = self.prompt_learner.tokenized_prompts
        self.image_encoder = clip_model.visual
        self.text_encoder = TextEncoder(clip_model)
        self.logit_scale = clip_model.logit_scale
        self.dtype = clip_model.dtype
        self.n_cls = len(classnames)

        # NPT hyperparameters
        self.lambda_patch = getattr(cfg, 'lambda_patch', 0.25)
        self.lambda_margin = getattr(cfg, 'lambda_margin', 0.25)
        self.margin = getattr(cfg, 'margin', 0.2)

        # Variance-aware attention regularization and entropy maximization hyperparameters
        self.lambda_var = getattr(cfg, 'lambda_var', 0.1)
        self.lambda_entropy = getattr(cfg, 'lambda_entropy', 0.05)
        self.epsilon = getattr(cfg, 'epsilon', 1e-8)

        # Momentum-based loss balancer
        momentum_beta = getattr(cfg, 'momentum_beta', 0.9)
        adaptation_alpha = getattr(cfg, 'adaptation_alpha', 0.1)
        min_weight_factor = getattr(cfg, 'min_weight_factor', 0.1)
        max_weight_factor = getattr(cfg, 'max_weight_factor', 3.0)
        warmup_steps = getattr(cfg, 'warmup_steps', 50)

        self.loss_balancer = MomentumLossBalancer(
            momentum_beta=momentum_beta,
            adaptation_alpha=adaptation_alpha,
            min_weight_factor=min_weight_factor,
            max_weight_factor=max_weight_factor,
            warmup_steps=warmup_steps
        )

    def forward(self, image):
        """Return (class logits, local logits incl. nuisance, normalized text features)."""
        # The image encoder returns both global and local (patch) features.
        image_features, local_image_features = self.image_encoder(image.type(self.dtype))

        # Get all prompts (class + nuisance)
        prompts = self.prompt_learner()
        tokenized_prompts = self.tokenized_prompts
        text_features = self.text_encoder(prompts, tokenized_prompts)

        # Normalize features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        local_image_features = local_image_features / local_image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = self.logit_scale.exp()

        # Global logits (for class prompts only)
        class_text_features = text_features[:-1]  # Exclude nuisance prompt
        logits = logit_scale * image_features @ class_text_features.t()

        # Local logits (for all prompts including nuisance)
        logits_local = logit_scale * local_image_features @ text_features.t()

        return logits, logits_local, text_features

    def compute_variance_loss(self, attention_weights):
        """
        Compute variance regularization loss to encourage discriminative attention patterns

        Args:
            attention_weights: Background weights for patches [batch, num_patches]

        Returns:
            var_loss: Variance regularization loss (scalar tensor, clamped to
            [0, 100]). Zero when ``attention_weights`` is None or has fewer
            than two patches (variance is undefined there).
        """
        if attention_weights is None:
            return torch.tensor(0.0, device=next(self.parameters()).device)

        # torch.var is unbiased by default, so a single patch gives NaN, which
        # would poison the total loss.
        if attention_weights.size(1) < 2:
            return torch.tensor(0.0, device=attention_weights.device)

        # Compute variance across patch dimension for each image
        attention_var = torch.var(attention_weights, dim=1)  # [batch]

        # Clip variance to prevent extreme values
        attention_var = torch.clamp(attention_var, min=self.epsilon, max=10.0)

        # Inverse-variance penalty, offset so a variance equal to max_var costs
        # nothing. Lower variance gives a larger penalty.
        # NOTE(review): for weights in [0, 1] the variance is typically well
        # below 1.0, so this offset rarely brings the penalty to zero.
        max_var = 1.0
        var_loss = torch.mean(1.0 / (attention_var + self.epsilon) - 1.0 / (max_var + self.epsilon))

        # Keep the loss finite and non-negative
        var_loss = torch.clamp(var_loss, min=0.0, max=100.0)

        return var_loss

    def compute_entropy_loss(self, text_features):
        """
        Compute entropy maximization loss for nuisance prompt similarities

        Args:
            text_features: Text features for all prompts [num_prompts, dim]

        Returns:
            entropy_loss: Negative entropy of softmax(nuisance . class^T),
            clamped to [-10, 10] (scalar tensor)
        """
        if text_features.size(0) <= 1:
            return torch.tensor(0.0, device=text_features.device)

        # Get nuisance and class features
        nuisance_features = text_features[-1:]  # Last is nuisance [1, dim]
        class_features = text_features[:-1]     # All except last [num_classes, dim]

        # Compute similarities between nuisance and class prompts
        nuisance_similarities = nuisance_features @ class_features.t()  # [1, num_classes]
        nuisance_similarities = nuisance_similarities.squeeze(0)  # [num_classes]

        # Convert similarities to probabilities
        temperature = 1.0
        probs = F.softmax(nuisance_similarities / temperature, dim=-1)  # [num_classes]

        # Clamp probabilities to prevent log(0)
        probs = torch.clamp(probs, min=self.epsilon, max=1.0 - self.epsilon)

        # Compute entropy: H(p) = -sum(p * log(p))
        log_probs = torch.log(probs)
        entropy = -torch.sum(probs * log_probs)

        # Maximize entropy by minimizing negative entropy
        entropy_loss = torch.clamp(-entropy, min=-10.0, max=10.0)

        return entropy_loss

    def compute_npt_loss(self, image_features, local_features, text_features, labels, attention_weights=None):
        """
        Compute NPT loss with global, patch, margin, variance, and entropy components using momentum-based balancing

        Args:
            image_features: Global image features [batch, dim]
            local_features: Local patch features [batch, num_patches, dim]
                (need not be contiguous)
            text_features: Text features for all prompts [num_prompts, dim]
            labels: Ground truth labels [batch]
            attention_weights: Background weights for patches [batch, num_patches]

        Returns:
            loss_dict: Dictionary with individual loss components and adaptive weights

        Side effects:
            Updates ``self.loss_balancer``'s EMAs on every call.
        """
        batch_size = image_features.size(0)
        logit_scale = self.logit_scale.exp()

        # 1. Global loss: Standard cross-entropy on class prompts only
        class_text_features = text_features[:-1]  # Exclude nuisance
        global_logits = logit_scale * image_features @ class_text_features.t()
        loss_global = F.cross_entropy(global_logits, labels)

        # 2. Patch loss: Attention-weighted background classification
        loss_patch = torch.tensor(0.0, device=image_features.device)
        if attention_weights is not None:
            num_patches = local_features.size(1)
            # reshape (not view) so non-contiguous local features are accepted
            local_flat = local_features.reshape(-1, local_features.size(-1))  # [batch*patches, dim]

            # Compute similarities to all prompts (including nuisance)
            patch_logits = logit_scale * local_flat @ text_features.t()  # [batch*patches, num_prompts]
            patch_logits = patch_logits.reshape(batch_size, num_patches, -1)  # [batch, patches, num_prompts]

            # Every patch targets the nuisance class, which sits at index n_cls
            background_labels = torch.full((batch_size, num_patches), self.n_cls,
                                           device=labels.device, dtype=labels.dtype)

            # Per-patch cross-entropy
            patch_losses = F.cross_entropy(patch_logits.reshape(-1, patch_logits.size(-1)),
                                           background_labels.reshape(-1), reduction='none')
            patch_losses = patch_losses.reshape(batch_size, num_patches)  # [batch, patches]

            # Weight by attention (background) weights
            weighted_losses = patch_losses * attention_weights
            loss_patch = weighted_losses.mean()

        # 3. Margin loss: Repel nuisance from class prompts
        nuisance_features = text_features[-1:]  # Last is nuisance
        class_features = text_features[:-1]     # All except last

        # Cosine similarities (features assumed L2-normalized) between nuisance and each class prompt
        nuisance_class_sims = nuisance_features @ class_features.t()  # [1, num_classes]
        nuisance_class_sims = nuisance_class_sims.squeeze(0)  # [num_classes]

        # Margin loss: max(0, sim - margin)
        margin_violations = torch.clamp(nuisance_class_sims - self.margin, min=0)
        loss_margin = margin_violations.mean()

        # 4. Variance regularization loss: Encourage discriminative attention patterns
        loss_var = self.compute_variance_loss(attention_weights)

        # 5. Entropy maximization loss: Encourage diverse nuisance prompt modeling
        loss_entropy = self.compute_entropy_loss(text_features)

        # Update momentum-based loss balancer
        self.loss_balancer.update_emas(loss_global, loss_patch, loss_margin)

        # Get adaptive weights
        adaptive_lambda_patch, adaptive_lambda_margin = self.loss_balancer.get_adaptive_weights(
            self.lambda_patch, self.lambda_margin
        )

        # Total loss with adaptive weighting and regularization
        loss_total = (loss_global +
                      adaptive_lambda_patch * loss_patch +
                      adaptive_lambda_margin * loss_margin +
                      self.lambda_var * loss_var +
                      self.lambda_entropy * loss_entropy)

        return {
            'loss_total': loss_total,
            'loss_global': loss_global,
            'loss_patch': loss_patch,
            'loss_margin': loss_margin,
            'loss_var': loss_var,
            'loss_entropy': loss_entropy,
            'adaptive_lambda_patch': adaptive_lambda_patch,
            'adaptive_lambda_margin': adaptive_lambda_margin,
            'base_lambda_patch': self.lambda_patch,
            'base_lambda_margin': self.lambda_margin,
            'lambda_var': self.lambda_var,
            'lambda_entropy': self.lambda_entropy
        }