import torch
import torch.nn as nn
from attention import Attention, NewAttention
from language_model import WordEmbedding, QuestionEmbedding
from classifier import SimpleClassifier
from fc import FCNet
import numpy as np


import torch
import torch.nn as nn
import torch.nn.functional as F

# Small multilayer perceptron used by the question and visual branches of the CF-VQA models below.
class MLP(nn.Module):
    """Stack of ``Linear -> ReLU -> Dropout`` blocks followed by a final ``Linear``.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dims: Iterable of hidden-layer widths, applied in order. It may be
            empty, in which case the MLP is a single ``nn.Linear``.
        output_dim: Size of the last dimension of the output.
        dropout: Dropout probability applied after every hidden ReLU. The default
            of 0.1 matches the previously hard-coded rate. ``Dropout`` layers are
            always created, even for ``dropout=0.0``, so ``self.net`` indices (and
            therefore ``state_dict`` keys) do not depend on the rate.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.1):
        super().__init__()
        layers = []
        current_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(current_dim, h_dim))
            layers.append(nn.ReLU())
            # Dropout inside the MLP (active in train() mode) also leaves room for MC Dropout.
            layers.append(nn.Dropout(dropout))
            current_dim = h_dim
        layers.append(nn.Linear(current_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (the linear layers act on the last dimension)."""
        return self.net(x)

class AdaptiveUncertaintyModule(nn.Module):
    """Estimate prediction uncertainty and map it to an adaptive modulation factor alpha.

    Supported ``method`` values:

    * ``'entropy'``: per source, the Shannon entropy of ``softmax(logits)`` divided
      by ``log(num_classes)``, so the signal lies in [0, 1].
    * ``'margin'``: per source, ``1 - (p_top1 - p_top2)``, also in [0, 1].
    * ``'ensemble_disagreement'``: several small heads each predict alpha from the
      concatenated source logits, and the variance across heads is mapped to alpha.

    For entropy and margin, the per-source signals (higher = more uncertain) are
    concatenated and fed to ``alpha_net``. Its final ``Sigmoid`` keeps alpha in (0, 1).

    Args:
        logit_dim: Size of the last dimension of every logits tensor.
        hidden_dim: Hidden width of ``alpha_net``. The ensemble heads use
            ``hidden_dim // 2`` and the variance-to-alpha net uses ``hidden_dim // 4``.
        method: ``'entropy'``, ``'margin'`` or ``'ensemble_disagreement'``.
        uncertainty_sources: Iterable (or a single string) containing any of
            ``'vq_only'``, ``'q_only'`` and ``'v_only'``. ``'v_only'`` is ignored
            when ``is_va`` is False, both when sizing the networks and in ``forward``.
        is_va: Whether a visual-only branch (and therefore ``logits_v``) exists.
        num_ensemble_heads: Number of heads for ``'ensemble_disagreement'``. It must
            be at least 2, because alpha is driven by the sample variance across heads.

    Raises:
        ValueError: For an unknown ``method``, no usable source, or fewer than two
            ensemble heads.
    """

    # Fixed order in which sources are concatenated (matches the original layout).
    _SOURCE_ORDER = ('vq_only', 'q_only', 'v_only')

    def __init__(self, logit_dim, hidden_dim=128, method='entropy',
                 uncertainty_sources=('vq_only',), is_va=True, num_ensemble_heads=5):
        super().__init__()
        if method not in ('entropy', 'margin', 'ensemble_disagreement'):
            raise ValueError(f"Unknown uncertainty method: {method}")
        if isinstance(uncertainty_sources, str):
            # A bare string would otherwise be matched by substring ('q_only' in 'vq_only').
            uncertainty_sources = [uncertainty_sources]
        self.method = method
        # Private copy, so neither the caller's list nor the default is aliased.
        self.uncertainty_sources = list(uncertainty_sources)
        self.logit_dim = logit_dim
        self.is_va = is_va

        num_sources = len(self._active_sources())
        if num_sources == 0:
            raise ValueError("No valid uncertainty sources provided.")
        if method == 'ensemble_disagreement' and num_ensemble_heads < 2:
            raise ValueError(
                "ensemble_disagreement needs at least 2 ensemble heads "
                f"(got {num_ensemble_heads}); the variance across a single head is undefined.")
        self.alpha_input_dim = num_sources

        # Maps the per-source uncertainty signals to alpha (entropy / margin methods).
        self.alpha_net = nn.Sequential(
            nn.Linear(self.alpha_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # alpha in (0, 1)
        )

        self.alpha_heads = None
        if self.method == 'ensemble_disagreement':
            # Each head consumes the concatenated logits of every active source.
            ensemble_input_dim = num_sources * self.logit_dim
            self.num_ensemble_heads = num_ensemble_heads
            self.alpha_heads = nn.ModuleList(
                nn.Sequential(
                    nn.Linear(ensemble_input_dim, hidden_dim // 2),
                    nn.ReLU(),
                    nn.Linear(hidden_dim // 2, 1),
                    nn.Sigmoid(),  # each head predicts alpha directly
                )
                for _ in range(self.num_ensemble_heads)
            )
            # Maps the (squashed) scalar variance across heads to the final alpha.
            self.disagreement_to_alpha_net = nn.Sequential(
                nn.Linear(1, hidden_dim // 4),
                nn.ReLU(),
                nn.Linear(hidden_dim // 4, 1),
                nn.Sigmoid(),
            )

    def _active_sources(self):
        """Return the configured, usable sources in the fixed order (vq, q, v)."""
        return [name for name in self._SOURCE_ORDER
                if name in self.uncertainty_sources and (name != 'v_only' or self.is_va)]

    @staticmethod
    def _reference_tensor(*logits):
        """Return the first non-None tensor (used for batch size, device and dtype)."""
        for tensor in logits:
            if tensor is not None:
                return tensor
        raise ValueError("At least one of logits_vq, logits_q, logits_v must be provided.")

    def _calculate_uncertainty(self, logits, batch_size=1):
        """Return a ``[batch, 1]`` uncertainty signal in [0, 1] (entropy / margin methods).

        Args:
            logits: ``[batch, num_classes]`` logits, or None for a missing source.
            batch_size: Batch size of the returned tensor when ``logits`` is None.
        """
        if logits is None:
            # Missing source: report maximal uncertainty. 1.0 is the upper bound of
            # both normalized entropy and 1 - margin, so it is on the same scale as
            # the real signals.
            weight = self.alpha_net[0].weight
            return torch.ones(batch_size, 1, device=weight.device, dtype=weight.dtype)

        num_classes = logits.size(-1)
        if num_classes < 2:
            # A single class carries no uncertainty (and the entropy normalizer log(1) is 0).
            return logits.new_zeros(tuple(logits.shape[:-1]) + (1,))

        probs = F.softmax(logits, dim=-1)
        if self.method == 'entropy':
            entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1, keepdim=True)
            # Divide by the maximum possible entropy, log(num_classes).
            return entropy / np.log(num_classes)
        if self.method == 'margin':
            top_probs, _ = torch.topk(probs, k=2, dim=-1)
            margins = (top_probs[..., 0] - top_probs[..., 1]).unsqueeze(-1)
            # Small margin = ambiguous prediction, so invert: higher = more uncertain.
            return 1 - margins
        raise ValueError(f"Unknown uncertainty method in _calculate_uncertainty: {self.method}")

    def _calculate_ensemble_disagreement(self, logits_vq, logits_q, logits_v):
        """Return alpha ``[batch, 1]`` from the disagreement among the ensemble heads.

        A missing (None) source is replaced by zero logits, so the heads always see
        ``num_active_sources * logit_dim`` features.
        """
        if self.alpha_heads is None:
            raise ValueError("Ensemble heads not initialized for ensemble_disagreement method.")

        logits_by_source = {'vq_only': logits_vq, 'q_only': logits_q, 'v_only': logits_v}
        reference = self._reference_tensor(logits_vq, logits_q, logits_v)
        default_logits = reference.new_zeros(reference.size(0), self.logit_dim)
        inputs = [logits_by_source[name] if logits_by_source[name] is not None else default_logits
                  for name in self._active_sources()]
        combined_logits = torch.cat(inputs, dim=-1)  # [batch, num_active_sources * logit_dim]

        # [num_heads, batch, 1]
        stacked_alphas = torch.stack([head(combined_logits) for head in self.alpha_heads], dim=0)
        # Sample (unbiased) variance across heads; num_heads >= 2 is enforced in __init__.
        disagreement_variance = stacked_alphas.var(dim=0)  # [batch, 1]
        # Heuristic squashing: the variance of values in (0, 1) is small, so it is scaled first.
        normalized_disagreement = torch.sigmoid(disagreement_variance * 10)
        return self.disagreement_to_alpha_net(normalized_disagreement)

    def forward(self, logits_vq, logits_q, logits_v):
        """Compute alpha for a batch.

        Args:
            logits_vq: ``[batch, logit_dim]`` logits of the fused (VQ) predictor.
            logits_q: ``[batch, logit_dim]`` logits of the question-only predictor.
            logits_v: ``[batch, logit_dim]`` logits of the visual-only predictor.
                Any of the three may be None. A missing active source contributes
                maximal uncertainty (entropy / margin) or zero logits (ensemble).

        Returns:
            Tensor of shape ``[batch, 1]`` with alpha in (0, 1).
        """
        if self.method == 'ensemble_disagreement':
            return self._calculate_ensemble_disagreement(logits_vq, logits_q, logits_v)

        logits_by_source = {'vq_only': logits_vq, 'q_only': logits_q, 'v_only': logits_v}
        batch_size = self._reference_tensor(logits_vq, logits_q, logits_v).size(0)
        signals = [self._calculate_uncertainty(logits_by_source[name], batch_size)
                   for name in self._active_sources()]
        combined_uncertainty = torch.cat(signals, dim=-1)  # [batch, num_active_sources]
        return self.alpha_net(combined_uncertainty)

class AdaptiveCFVQAModel(nn.Module):
    """CF-VQA model with a per-sample, uncertainty-driven counterfactual intervention.

    ``base_model`` is called with ``labels=None`` and must return a dict with
    ``"logits"``, ``"q_emb"`` and ``"v_emb"``. On top of it, the model adds a
    question-only branch (``q_1`` -> ``q_2``) and, if ``is_va``, a visual-only
    branch (``v_1`` -> ``v_2``). An ``AdaptiveUncertaintyModule`` turns the three
    sets of logits into ``alpha``. At inference time, the counterfactual term keeps
    a ``1 - alpha`` share of the VQ (and V) inputs and replaces the remaining
    ``alpha`` share with the learned scalar ``constant``.

    Args:
        base_model: Backbone returning the dict described above.
        output_size: Number of answer classes (the size of every logits tensor).
        classif_q: Keyword arguments for the question-branch ``MLP``.
        classif_v: Keyword arguments for the visual-branch ``MLP`` (used only if ``is_va``).
        is_va: Whether to build and use the visual-only branch.
        adaptive_method: ``'entropy'``, ``'margin'`` or ``'ensemble_disagreement'``.
        uncertainty_hidden_dim: Hidden width of the uncertainty module.
        alpha_reg_weight: Weight of the ``alpha * (1 - alpha)`` term that is
            subtracted from the training loss.
        alpha_uncertainty_sources: Sources used by the uncertainty module
            (any of ``'vq_only'``, ``'q_only'``, ``'v_only'``).
        num_ensemble_heads: Number of heads for ``'ensemble_disagreement'``.

    Raises:
        ValueError: If ``adaptive_method`` is unknown.
    """

    def __init__(self, base_model, output_size, classif_q, classif_v, is_va=True,
                 adaptive_method='entropy', uncertainty_hidden_dim=128, alpha_reg_weight=0.1,
                 alpha_uncertainty_sources=('vq_only',), num_ensemble_heads=5):
        super().__init__()
        self.base_model = base_model
        self.output_size = output_size
        self.is_va = is_va  # whether to use the visual branch
        if adaptive_method not in ('entropy', 'margin', 'ensemble_disagreement'):
            raise ValueError(f"Unknown adaptive method: {adaptive_method}")
        self.adaptive_method = adaptive_method
        self.alpha_reg_weight = alpha_reg_weight
        self.num_ensemble_heads = num_ensemble_heads

        # Question branch
        self.q_1 = MLP(**classif_q)
        self.q_2 = nn.Linear(output_size, output_size)

        # Visual branch (if enabled)
        if self.is_va:
            self.v_1 = MLP(**classif_v)
            self.v_2 = nn.Linear(output_size, output_size)

        # Learned value that replaces intervened-on branches.
        self.constant = nn.Parameter(torch.tensor(0.0))

        self.uncertainty_module = AdaptiveUncertaintyModule(
            output_size,
            hidden_dim=uncertainty_hidden_dim,
            method=adaptive_method,
            uncertainty_sources=alpha_uncertainty_sources,
            is_va=self.is_va,
            num_ensemble_heads=self.num_ensemble_heads,
        )

    def forward(self, v, _, q, labels, bias):
        """Run the model.

        Args:
            v, _, q: Inputs forwarded unchanged to ``base_model``.
            labels: Targets for training, or None for inference. They are assumed to
                be class indices of shape ``[batch, 1]`` (``squeeze(1)`` is applied
                before ``CrossEntropyLoss``) -- confirm against the data loader.
            bias: Unused; accepted for interface compatibility.

        Returns:
            ``(pred, loss)``. With labels, ``pred`` is the total-effect logits
            ``z_qkv`` and ``loss`` is a scalar. Without labels, ``pred`` is the
            adaptive counterfactual prediction ``z_qkv - z_q_adaptive`` and
            ``loss`` is None.
        """
        base_output = self.base_model(v, _, q, labels=None, bias=None)
        logits_vq = base_output["logits"]
        q_emb = base_output["q_emb"]
        v_emb = base_output["v_emb"]

        q_pred = self.q_1(q_emb)
        q_out = self.q_2(q_pred)
        if self.is_va:
            v_pred = self.v_1(v_emb)
            v_out = self.v_2(v_pred)
        else:
            v_pred = None
            v_out = None

        # Per-sample intervention strength.
        alpha = self.uncertainty_module(logits_vq, q_out, v_out)

        # Total effect: no intervention.
        z_qkv = self.sum_fusion(logits_vq, q_pred, v_pred, q_emb,
                                q_fact=True, k_fact=True, v_fact=True)

        if labels is None:
            # Inference: subtract the adaptively intervened counterfactual.
            k_fact_continuous = 1 - alpha
            v_fact_continuous = 1 - alpha if self.is_va else False
            z_k_adapt, z_q_adapt, z_v_adapt = self._apply_intervention(
                logits_vq, q_pred, v_pred,
                q_fact=True, k_fact=k_fact_continuous, v_fact=v_fact_continuous)
            # The intervention is already applied, so fuse with all factors True.
            z_q_adaptive = self.sum_fusion(z_k_adapt, z_q_adapt, z_v_adapt, q_emb,
                                           q_fact=True, k_fact=True, v_fact=True)
            return z_qkv - z_q_adaptive, None

        # Training. The adaptive counterfactual does not enter the loss, so it is not
        # built here. z_nde uses the full (discrete) intervention on detached inputs
        # and feeds only the KL term.
        z_k_full, z_q_full, z_v_full = self.transform(
            logits_vq.detach(),
            q_pred.detach(),
            v_pred.detach() if v_pred is not None else None,
            q_fact=True, k_fact=False, v_fact=False)
        z_nde = self.sum_fusion(z_k_full, z_q_full, z_v_full, q_emb.detach(),
                                q_fact=True, k_fact=True, v_fact=True)

        class_id = labels.squeeze(1)
        cross_entropy = nn.CrossEntropyLoss()
        fusion_loss = cross_entropy(z_qkv, class_id)
        question_loss = cross_entropy(q_out, class_id)
        vision_loss = cross_entropy(v_out, class_id) if self.is_va else 0.0

        # Cross-entropy between the frozen total-effect distribution and the NDE
        # distribution. log_softmax avoids log(0) = -inf when a softmax entry underflows.
        p_te = F.softmax(z_qkv, dim=-1).detach()
        kl_loss = -(p_te * F.log_softmax(z_nde, dim=-1)).sum(1).mean()

        # Subtracting w * alpha * (1 - alpha) rewards alpha near 0.5, which discourages
        # the uncertainty module from saturating at 0 or 1.
        # NOTE(review): alpha enters no other training term (the adaptive counterfactual
        # is only used at inference), so this regularizer is the only gradient reaching
        # the uncertainty module. Its gradient also flows back into logits_vq / q_out /
        # v_out, because alpha is computed from them without detach. Confirm this is intended.
        alpha_regularization = self.alpha_reg_weight * (alpha * (1 - alpha)).mean()

        question_loss_weight = 1.0
        vision_loss_weight = 1.0 if self.is_va else 0.0
        loss = (fusion_loss + question_loss_weight * question_loss + kl_loss
                + vision_loss_weight * vision_loss - alpha_regularization)
        return z_qkv, loss

    def sum_fusion(self, z_k, z_q, z_v, q_emb, q_fact, k_fact, v_fact):
        """Fuse branch logits by summation followed by ``log(sigmoid(.))``.

        ``q_emb`` is accepted for signature compatibility and is not used. When
        ``k_fact`` and ``v_fact`` are plain booleans, the discrete ``transform`` is
        applied first. When either is a tensor, the caller is expected to have applied
        the intervention already (see ``_apply_intervention``), and the inputs are
        fused as given.
        """
        if not isinstance(k_fact, torch.Tensor) and not isinstance(v_fact, torch.Tensor):
            z_k, z_q, z_v = self.transform(z_k, z_q, z_v, q_fact, k_fact, v_fact)
        z = z_k + z_q + z_v if self.is_va else z_k + z_q
        # log(sigmoid(z)); the epsilon keeps the log finite if sigmoid underflows to 0.
        return torch.log(torch.sigmoid(z) + 1e-12)

    def _apply_intervention(self, z_k, z_q, z_v, q_fact, k_fact, v_fact):
        """Apply a discrete or continuous intervention to the branch logits.

        With only boolean factors, this behaves exactly like ``transform`` (a False
        factor replaces that branch with the constant). A tensor factor ``f`` (e.g.
        ``[batch, 1]``, expanded to the logits' shape) interpolates
        ``f * z + (1 - f) * constant``. ``z_v`` is only modified when ``is_va``.

        Returns:
            ``(z_k, z_q, z_v)`` after the intervention.
        """
        if not (isinstance(k_fact, torch.Tensor) or isinstance(v_fact, torch.Tensor)):
            return self.transform(z_k, z_q, z_v, q_fact, k_fact, v_fact)

        if not q_fact:
            z_q = self._constant_like(z_q)
        z_k = self._intervene(z_k, k_fact)
        if self.is_va:
            z_v = self._intervene(z_v, v_fact)
        return z_k, z_q, z_v

    def _intervene(self, z, fact):
        """Blend ``z`` toward the constant by tensor ``fact``, replace it if ``fact`` is falsy, else keep it."""
        if isinstance(fact, torch.Tensor):
            fact = fact.expand_as(z)
            return fact * z + (1 - fact) * self._constant_like(z)
        if not fact:
            return self._constant_like(z)
        return z

    def _constant_like(self, z):
        """Return a tensor shaped like ``z`` filled with the learnable ``constant`` (gradients reach it)."""
        return self.constant * torch.ones_like(z)

    def transform(self, z_k, z_q, z_v, q_fact, k_fact, v_fact):
        """Discrete intervention: replace each branch whose factor is falsy with the constant.

        ``z_v`` is only considered when ``is_va``.
        """
        if not k_fact:
            z_k = self._constant_like(z_k)
        if not q_fact:
            z_q = self._constant_like(z_q)
        if self.is_va and not v_fact:
            z_v = self._constant_like(z_v)
        return z_k, z_q, z_v

class CFVQAModel(nn.Module):
    """Counterfactual VQA model with sum fusion and a fixed, discrete intervention.

    ``base_model`` is called with ``labels=None`` and must return a dict with
    ``"logits"``, ``"q_emb"`` and ``"v_emb"``. The model adds a question-only branch
    (``q_1`` -> ``q_2``) and, if ``is_va``, a visual-only branch (``v_1`` -> ``v_2``).
    At inference, it returns the total effect minus the prediction in which the VQ
    and V branches are replaced by the learned scalar ``constant``.

    Args:
        base_model: Backbone returning the dict described above.
        output_size: Number of answer classes (the size of every logits tensor).
        classif_q: Keyword arguments for the question-branch ``MLP``.
        classif_v: Keyword arguments for the visual-branch ``MLP`` (used only if ``is_va``).
        is_va: Whether to build and use the visual-only branch.
    """

    def __init__(self, base_model, output_size, classif_q, classif_v, is_va=True):
        super().__init__()
        self.base_model = base_model
        self.output_size = output_size
        self.is_va = is_va  # whether to use the visual branch

        # Question branch
        self.q_1 = MLP(**classif_q)
        self.q_2 = nn.Linear(output_size, output_size)

        # Visual branch (if enabled)
        if self.is_va:
            self.v_1 = MLP(**classif_v)
            self.v_2 = nn.Linear(output_size, output_size)

        # Learned value that replaces intervened-on branches.
        self.constant = nn.Parameter(torch.tensor(0.0))
        # Fusion is always sum_fusion; kept as an attribute so it can be swapped out.
        self.fusion = self.sum_fusion

    def sum_fusion(self, z_k, z_q, z_v, q_emb, q_fact, k_fact, v_fact):
        """Apply the discrete intervention, sum the branches, and return ``log(sigmoid(.))``.

        ``q_emb`` is accepted for signature compatibility and is not used.
        """
        z_k, z_q, z_v = self.transform(z_k, z_q, z_v, q_fact, k_fact, v_fact)
        z = z_k + z_q + z_v if self.is_va else z_k + z_q
        # log(sigmoid(z)); the epsilon keeps the log finite if sigmoid underflows to 0.
        return torch.log(torch.sigmoid(z) + 1e-12)

    def transform(self, z_k, z_q, z_v, q_fact, k_fact, v_fact):
        """Replace each branch whose factor is falsy with the learnable constant.

        ``z_v`` is only considered when ``is_va``.
        """
        if not k_fact:
            z_k = self.constant * torch.ones_like(z_k)
        if not q_fact:
            z_q = self.constant * torch.ones_like(z_q)
        if self.is_va and not v_fact:
            z_v = self.constant * torch.ones_like(z_v)
        return z_k, z_q, z_v

    def forward(self, v, _, q, labels, bias):
        """Run the model.

        Args:
            v, _, q: Inputs forwarded unchanged to ``base_model``.
            labels: Targets for training, or None for inference. They are assumed to
                be class indices of shape ``[batch, 1]`` (``squeeze(1)`` is applied
                before ``CrossEntropyLoss``) -- confirm against the data loader.
            bias: Unused; accepted for interface compatibility.

        Returns:
            ``(pred, loss)``. With labels, ``pred`` is the total-effect logits
            ``z_qkv`` and ``loss`` is a scalar. Without labels, ``pred`` is
            ``z_qkv - z_q`` and ``loss`` is None.
        """
        base_output = self.base_model(v, _, q, labels=None, bias=None)
        logits_vq = base_output["logits"]
        q_emb = base_output["q_emb"]
        v_emb = base_output["v_emb"]

        q_pred = self.q_1(q_emb)
        q_out = self.q_2(q_pred)
        if self.is_va:
            v_pred = self.v_1(v_emb)
            v_out = self.v_2(v_pred)
        else:
            v_pred = None
            v_out = None

        # Total effect (z_te): all branches kept.
        z_qkv = self.fusion(logits_vq, q_pred, v_pred, q_emb, q_fact=True, k_fact=True, v_fact=True)

        if labels is None:
            # Counterfactual prediction: the VQ and V branches are replaced by the constant.
            z_q = self.fusion(logits_vq, q_pred, v_pred, q_emb, q_fact=True, k_fact=False, v_fact=False)
            return z_qkv - z_q, None

        # Same counterfactual on detached inputs; it feeds only the KL term.
        z_nde = self.fusion(logits_vq.detach(), q_pred.detach(),
                            v_pred.detach() if v_pred is not None else None,
                            q_emb.detach(),
                            q_fact=True, k_fact=False, v_fact=False)

        class_id = labels.squeeze(1)
        cross_entropy = nn.CrossEntropyLoss()
        fusion_loss = cross_entropy(z_qkv, class_id)
        question_loss = cross_entropy(q_out, class_id)
        vision_loss = cross_entropy(v_out, class_id) if self.is_va else 0.0

        # Cross-entropy between the frozen total-effect distribution and the NDE
        # distribution. log_softmax avoids log(0) = -inf when a softmax entry underflows.
        p_te = F.softmax(z_qkv, dim=-1).detach()
        kl_loss = -(p_te * F.log_softmax(z_nde, dim=-1)).sum(1).mean()

        question_loss_weight = 1.0
        vision_loss_weight = 1.0 if self.is_va else 0.0
        loss = fusion_loss + question_loss_weight * question_loss + kl_loss + vision_loss_weight * vision_loss
        return z_qkv, loss


class BaseModel(nn.Module):
    """Attention-based VQA backbone.

    The question is embedded and encoded, the visual features are pooled with
    question-guided attention, the two projected vectors are multiplied
    element-wise, and a classifier produces answer logits.

    Args:
        w_emb: Word-embedding module applied to the question token ids.
        q_emb: Question encoder producing one vector per question.
        v_att: Attention module called as ``v_att(v, q_emb)``; its output weights ``v``.
        q_net: Projection of the question vector.
        v_net: Projection of the attended visual vector.
        classifier: Maps the joint representation to answer logits.
    """

    def __init__(self, w_emb, q_emb, v_att, q_net, v_net, classifier):
        super(BaseModel, self).__init__()
        self.w_emb = w_emb
        self.q_emb = q_emb
        self.v_att = v_att
        self.q_net = q_net
        self.v_net = v_net
        self.classifier = classifier
        # Loss used when labels are given. It is None by default and must be set by
        # the caller before calling forward() with labels.
        self.debias_loss_fn = None
        # NOTE(review): unused by forward(); removing it would change the state_dict
        # keys -- confirm before deleting.
        self.bias_lin = torch.nn.Linear(1024, 1)

    def forward(self, v, _, q, labels, bias, return_weights=False):
        """Run the backbone.

        Args:
            v: Visual features, ``[batch, num_objs, obj_dim]``.
            _: Unused; kept for interface compatibility.
            q: Question token ids, ``[batch_size, seq_length]``.
            labels: Targets passed to ``debias_loss_fn``, or None.
            bias: Passed through to ``debias_loss_fn``.
            return_weights: If True (and ``labels`` is given), return whatever
                ``debias_loss_fn(..., True)`` returns instead of the dict.

        Returns:
            A dict with ``"logits"`` (not probabilities), ``"q_emb"`` and ``"v_emb"``
            (the attended visual features). When ``labels`` is given, the dict also
            holds ``"loss"``.
        """
        w_emb = self.w_emb(q)
        q_emb = self.q_emb(w_emb)  # [batch, q_dim]

        att = self.v_att(v, q_emb)
        v_emb = (att * v).sum(1)  # attention-weighted sum over objects, [batch, v_dim]

        q_repr = self.q_net(q_emb)
        v_repr = self.v_net(v_emb)
        joint_repr = q_repr * v_repr
        logits = self.classifier(joint_repr)

        outputs = {"logits": logits, "q_emb": q_emb, "v_emb": v_emb}
        if labels is None:
            return outputs
        if return_weights:
            return self.debias_loss_fn(joint_repr, logits, bias, labels, True)
        # Return the loss with the features. The labelled path used to compute the
        # loss and then fall through, returning None.
        outputs["loss"] = self.debias_loss_fn(joint_repr, logits, bias, labels)
        return outputs


def build_baseline0(dataset, num_hid):
    """Assemble a ``BaseModel`` that uses the original ``Attention`` module.

    Args:
        dataset: Object exposing ``dictionary.ntoken``, ``v_dim`` and
            ``num_ans_candidates``.
        num_hid: Hidden size passed to the question encoder, attention,
            projections and classifier.

    Returns:
        A freshly constructed ``BaseModel``.
    """
    word_embedding = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
    question_encoder = QuestionEmbedding(300, num_hid, 1, False, 0.0)
    attention = Attention(dataset.v_dim, question_encoder.num_hid, num_hid)
    question_net = FCNet([num_hid, num_hid])
    visual_net = FCNet([dataset.v_dim, num_hid])
    answer_classifier = SimpleClassifier(
        num_hid, 2 * num_hid, dataset.num_ans_candidates, 0.5)
    return BaseModel(word_embedding, question_encoder, attention,
                     question_net, visual_net, answer_classifier)


def build_baseline0_newatt(dataset, num_hid):
    """Assemble a ``BaseModel`` that uses ``NewAttention``.

    Args:
        dataset: Object exposing ``dictionary.ntoken``, ``v_dim`` and
            ``num_ans_candidates``.
        num_hid: Hidden size passed to the question encoder, attention,
            projections and classifier.

    Returns:
        A freshly constructed ``BaseModel``.
    """
    word_embedding = WordEmbedding(dataset.dictionary.ntoken, 300, 0.0)
    question_encoder = QuestionEmbedding(300, num_hid, 1, False, 0.0)
    encoder_width = question_encoder.num_hid
    attention = NewAttention(dataset.v_dim, encoder_width, num_hid)
    question_net = FCNet([encoder_width, num_hid])
    visual_net = FCNet([dataset.v_dim, num_hid])
    answer_classifier = SimpleClassifier(
        num_hid, num_hid * 2, dataset.num_ans_candidates, 0.5)
    return BaseModel(word_embedding, question_encoder, attention,
                     question_net, visual_net, answer_classifier)
