import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque


class CausalIntervention(nn.Module):
    """
    Causal Intervention Module

    Implements intervention operations in GCIB:
    1. Store causal and non-causal representations in memory
    2. Randomly combine to create intervention pairs
    3. Calculate intervention loss

    Args:
        memory_size: maximum number of entries kept in each memory bank;
            once full, the oldest entries are evicted (``deque(maxlen=...)``).
        feature_dim: width of a single causal / non-causal representation.
        intervention_strategy: stored on the instance but not read by any
            method of this class; pair sampling is always uniform random.
    """

    def __init__(self, memory_size=1000, feature_dim=128, intervention_strategy='random'):
        super().__init__()
        self.memory_size = memory_size
        self.feature_dim = feature_dim
        self.intervention_strategy = intervention_strategy

        # Bounded FIFO memories. Each entry is a numpy row, or a
        # (row, label) tuple when update_memory was called with labels.
        self.causal_memory = deque(maxlen=memory_size)
        self.non_causal_memory = deque(maxlen=memory_size)

        # Intervention classifier: [causal ; non-causal] pair -> probability.
        # It ends in Sigmoid, so its output is paired with F.binary_cross_entropy.
        self.intervention_classifier = nn.Sequential(
            nn.Linear(feature_dim * 2, feature_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Linear(feature_dim // 2, 1),
            nn.Sigmoid()
        )

        # Causal classifier: one raw logit per causal representation.
        self.causal_classifier = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Linear(feature_dim // 2, 1)
        )

        # Non-causal classifier: one raw logit per non-causal representation.
        self.non_causal_classifier = nn.Sequential(
            nn.Linear(feature_dim, feature_dim // 2),
            nn.ReLU(),
            nn.Linear(feature_dim // 2, 1)
        )

        # Initialize memory statistics
        self.memory_stats = {
            'causal_size': 0,
            'non_causal_size': 0,
            'intervention_count': 0
        }

    def update_memory(self, causal_reps, non_causal_reps, labels=None):
        """
        Append a batch of representations to the memory banks.

        Args:
            causal_reps: [batch_size, feature_dim] causal representations
            non_causal_reps: [batch_size, feature_dim] non-causal representations
            labels: [batch_size] labels (optional). When given, every entry
                is stored as a ``(representation, label)`` tuple.
        """
        batch_size = causal_reps.shape[0]

        # .copy() is required: for CPU tensors, detach().cpu().numpy() shares
        # storage with the caller's tensor, so a later in-place update of that
        # tensor would silently rewrite what is stored in memory.
        causal_np = causal_reps.detach().cpu().numpy().copy()
        non_causal_np = non_causal_reps.detach().cpu().numpy().copy()

        if labels is None:
            # Extending with a 2-D array appends one row per sample.
            causal_entries = causal_np
            non_causal_entries = non_causal_np
        else:
            labels_np = labels.detach().cpu().numpy().copy()
            causal_entries = [(causal_np[i], labels_np[i]) for i in range(batch_size)]
            non_causal_entries = [(non_causal_np[i], labels_np[i]) for i in range(batch_size)]

        self.causal_memory.extend(causal_entries)
        self.non_causal_memory.extend(non_causal_entries)

        # Update statistics
        self.memory_stats['causal_size'] = len(self.causal_memory)
        self.memory_stats['non_causal_size'] = len(self.non_causal_memory)

    @staticmethod
    def _split_entry(entry):
        """Return ``(representation, label)`` for a memory entry; label is 0 when none was stored."""
        if isinstance(entry, tuple):
            rep, label = entry
            return rep, label
        return entry, 0

    def generate_intervention_pairs(self, batch_size, device='cpu'):
        """
        Generate intervention pairs

        Intervention pair = causal representation A + non-causal representation B
        (cut dependency). A and B are drawn independently and uniformly from
        their respective memories.

        Args:
            batch_size: number of pairs to generate
            device: device for the returned tensors

        Returns:
            intervened_pairs: [batch_size, feature_dim * 2]
            intervention_labels: [batch_size] float labels taken from the causal
                part (0 for causal entries stored without a label)
        """
        if len(self.causal_memory) == 0 or len(self.non_causal_memory) == 0:
            # Memory is empty, return random data
            random_pairs = torch.randn(batch_size, self.feature_dim * 2, device=device)
            random_labels = torch.randint(0, 2, (batch_size,), device=device).float()
            return random_pairs, random_labels

        if batch_size == 0:
            # np.stack on an empty list would raise; return well-formed empty tensors.
            return (torch.empty(0, self.feature_dim * 2, device=device),
                    torch.empty(0, device=device))

        # Snapshot once: random indexing into a deque is O(n) per access.
        causal_pool = list(self.causal_memory)
        non_causal_pool = list(self.non_causal_memory)

        causal_idx = np.random.randint(0, len(causal_pool), size=batch_size)
        non_causal_idx = np.random.randint(0, len(non_causal_pool), size=batch_size)

        causal_reps, causal_labels = zip(*(self._split_entry(causal_pool[i]) for i in causal_idx))
        non_causal_reps = [self._split_entry(non_causal_pool[j])[0] for j in non_causal_idx]

        # Assemble the whole batch in numpy and move it to the device in a
        # single transfer instead of one small tensor per sample.
        pairs_np = np.concatenate([np.stack(causal_reps), np.stack(non_causal_reps)], axis=-1)
        intervened_pairs = torch.as_tensor(pairs_np, device=device)
        intervention_labels = torch.as_tensor(np.asarray(causal_labels), device=device).float()

        # Update statistics
        self.memory_stats['intervention_count'] += batch_size

        return intervened_pairs, intervention_labels

    def compute_intervention_loss(self, intervened_pairs, labels):
        """
        Compute intervention loss

        Args:
            intervened_pairs: [batch_size, feature_dim * 2] intervention pairs
            labels: [batch_size] true labels (based on causal part labels);
                integer or floating dtypes are accepted

        Returns:
            intervention_loss: scalar loss
            accuracy: intervention prediction accuracy
        """
        # squeeze(-1), not squeeze(): a bare squeeze collapses a batch of one
        # to a 0-d tensor, which then mismatches the [1]-shaped target.
        predictions = self.intervention_classifier(intervened_pairs).squeeze(-1)
        labels = labels.to(predictions.dtype)

        # Binary cross entropy on probabilities (classifier ends in Sigmoid)
        intervention_loss = F.binary_cross_entropy(predictions, labels)

        # Calculate accuracy
        pred_labels = (predictions > 0.5).float()
        accuracy = (pred_labels == labels).float().mean()

        return intervention_loss, accuracy

    def compute_causal_consistency_loss(self, causal_reps, non_causal_reps, labels):
        """
        Compute causal consistency loss

        Ensure:
        1. Causal representations can accurately predict labels
        2. Non-causal representations cannot accurately predict labels

        Args:
            causal_reps: [batch_size, feature_dim] causal representations
            non_causal_reps: [batch_size, feature_dim] non-causal representations
            labels: [batch_size] true labels; integer or floating dtypes are accepted

        Returns:
            causal_loss: causal loss
            non_causal_loss: non-causal loss
        """
        # Causal prediction (logits); BCE-with-logits needs a floating target.
        causal_pred = self.causal_classifier(causal_reps).squeeze(-1)
        causal_loss = F.binary_cross_entropy_with_logits(
            causal_pred, labels.to(causal_pred.dtype)
        )

        # Non-causal prediction (should be uninformative)
        non_causal_pred = self.non_causal_classifier(non_causal_reps).squeeze(-1)

        # Target 0.5 ("uncertain") pushes the logit towards 0. Note the loss is
        # minimised there but its minimum value is ln 2, not 0.
        non_causal_loss = F.binary_cross_entropy_with_logits(
            non_causal_pred,
            torch.full_like(labels, 0.5, dtype=non_causal_pred.dtype)
        )

        return causal_loss, non_causal_loss

    def get_memory_stats(self):
        """Return a shallow copy of the memory statistics dict."""
        return self.memory_stats.copy()

    def clear_memory(self):
        """Empty both memories and reset all statistics to zero."""
        self.causal_memory.clear()
        self.non_causal_memory.clear()
        self.memory_stats = {
            'causal_size': 0,
            'non_causal_size': 0,
            'intervention_count': 0
        }

    def forward(self, causal_reps, non_causal_reps, labels, training=True):
        """
        Forward propagation

        Args:
            causal_reps: causal representations
            non_causal_reps: non-causal representations
            labels: true labels
            training: when True, update memory and compute the intervention
                loss as well; this flag is independent of ``self.training``

        Returns:
            dict: containing various losses and statistics
        """
        if training:
            # Update memory
            self.update_memory(causal_reps, non_causal_reps, labels)

            # Generate intervention pairs from the (now non-empty) memory
            device = causal_reps.device
            intervened_pairs, intervention_labels = self.generate_intervention_pairs(
                causal_reps.shape[0], device
            )

            # Calculate intervention loss
            intervention_loss, intervention_acc = self.compute_intervention_loss(
                intervened_pairs, intervention_labels
            )

            # Calculate causal consistency loss
            causal_loss, non_causal_loss = self.compute_causal_consistency_loss(
                causal_reps, non_causal_reps, labels
            )

            return {
                'intervention_loss': intervention_loss,
                'causal_loss': causal_loss,
                'non_causal_loss': non_causal_loss,
                'intervention_accuracy': intervention_acc,
                'memory_stats': self.get_memory_stats()
            }

        # Evaluation mode: only calculate causal consistency
        causal_loss, non_causal_loss = self.compute_causal_consistency_loss(
            causal_reps, non_causal_reps, labels
        )

        return {
            'causal_loss': causal_loss,
            'non_causal_loss': non_causal_loss,
            'memory_stats': self.get_memory_stats()
        }