# Code/objectives/counterfactual.py
"""
Counterfactual recombination utilities.
"""

import itertools
import warnings
from typing import Callable, List, Optional, Sequence

import torch
import torch.nn.functional as F

def recombine_modules(module_realizations: Sequence[Sequence[torch.Tensor]],
                      decoders: Optional[Sequence[Optional[Callable]]] = None,
                      max_combinations: int = 1000):
    """
    Build recombined trajectories from the Cartesian product of module realizations.

    Given ``module_realizations = [U1, U2, ..., UM]`` where ``Um = [u_m_1, u_m_2, ...]``,
    each combination ``(u_1_i, u_2_j, ..., u_M_k)`` is turned into one output item.
    Combinations are taken in ``itertools.product`` order. Only the first
    ``max_combinations`` are produced.

    Args:
        module_realizations: one sequence of candidate tensors per module.
        decoders: optional per-module callables that map a module tensor to its
            contribution in observation space. If this is ``None``, or an entry is
            ``None``, the raw tensor is used. If a decoder raises, the raw tensor is
            used as a best-effort fallback and a ``RuntimeWarning`` is emitted.
        max_combinations: upper bound on the number of combinations produced.
            Values <= 0 produce an empty list.

    Returns:
        A list with one combined object per combination. Parts are combined in
        this order of preference:

        1. Summed, when their shapes allow it (additive contributions).
        2. Concatenated along the last dimension. 1-D parts are first given a
           leading dimension of size 1.
        3. Stacked along a new dimension 0, as a last resort.

        An empty ``module_realizations`` yields ``[]``.
    """
    # itertools.product() over zero iterables yields a single empty tuple, whose
    # "sum" would be the int 0 rather than a trajectory.
    if len(module_realizations) == 0:
        return []
    if decoders is None:
        decoders = [None] * len(module_realizations)

    out = []
    for i, combo in enumerate(itertools.product(*module_realizations)):
        if i >= max_combinations:
            break  # the product is lazy, so stopping here bounds the work
        parts = []
        for m_idx, part in enumerate(combo):
            dec = decoders[m_idx]
            if dec is None:
                parts.append(part)
                continue
            try:
                parts.append(dec(part))
            except Exception as err:
                # Best-effort: fall back to the raw tensor, but surface the failure.
                warnings.warn(
                    f"recombine_modules: decoder for module {m_idx} failed "
                    f"({err!r}); using the raw module tensor instead",
                    RuntimeWarning,
                    stacklevel=2,
                )
                parts.append(part)
        # Try summing contributions first (common for additive PDE modules),
        # then concatenation, then stacking.
        try:
            combined = sum(parts)
        except Exception:
            try:
                combined = torch.cat(
                    [p.unsqueeze(0) if p.dim() == 1 else p for p in parts], dim=-1
                )
            except Exception:
                combined = torch.stack(parts, dim=0)
        out.append(combined)
    return out


def filter_uncertainty(trajs, ensemble, threshold=0.1):
    """
    Keep the trajectories on which the ensemble agrees.

    For each trajectory, every model in ``ensemble`` is called on it. The
    element-wise standard deviation across the model outputs (``torch.std``
    default, i.e. the sample std) is averaged to a scalar. The trajectory is
    kept when that scalar is strictly below ``threshold``.

    A model that raises contributes an all-zero prediction shaped like the
    successful predictions. This is a pessimistic stand-in that tends to
    inflate the spread. If every model raises for a trajectory, that
    trajectory is rejected, because there is no evidence of agreement.

    With only one prediction, the sample std is NaN. NaN is never below
    ``threshold``, so the trajectory is rejected.

    Inference runs under ``torch.no_grad()`` because only the scalar spread is
    used.

    Args:
        trajs: iterable of trajectories (passed to each model unchanged).
        ensemble: sequence of callables; each maps a trajectory to a prediction
            tensor. It is iterated once per trajectory.
        threshold: maximum mean std (exclusive) for a trajectory to be kept.

    Returns:
        List of the kept trajectories (the original objects), in input order.

    Raises:
        ValueError: if ``ensemble`` is empty and ``trajs`` is non-empty.
    """
    valid = []
    with torch.no_grad():
        for traj in trajs:
            preds = []
            n_failed = 0
            for model in ensemble:
                try:
                    preds.append(model(traj))
                except Exception:
                    n_failed += 1
            if not preds:
                if n_failed == 0:
                    raise ValueError("ensemble must contain at least one model")
                # Every model failed: never treat that as "certain".
                continue
            # Zero-fill failures with the prediction shape (not the input shape),
            # so stacking works when models change shape.
            preds.extend(torch.zeros_like(preds[0]) for _ in range(n_failed))
            spread = torch.stack(preds, dim=0).std(dim=0).mean().item()
            if spread < threshold:
                valid.append(traj)
    return valid


def total_correlation_penalty(z_samples):
    """
    Independence proxy: Frobenius norm of the off-diagonal sample covariance.

    ``z_samples`` is flattened to ``[B, D]``: dim 0 indexes samples and all
    remaining dims are treated as features. The ``D x D`` covariance across the
    ``B`` samples is computed with ``torch.cov`` using its default unbiased
    correction, so ``B`` should be >= 2. The diagonal (per-feature variances)
    is removed, and the Frobenius norm of the remainder is returned.

    This is a cheap linear-decorrelation penalty, not the information-theoretic
    total correlation. A single feature (D == 1) has no off-diagonal terms and
    yields 0.

    Returns:
        0-d tensor built from differentiable torch ops on ``z_samples``.
    """
    # reshape (not view) so non-contiguous inputs, e.g. permuted latents, work.
    z_flat = z_samples.reshape(z_samples.size(0), -1)  # [B, D]
    # torch.cov squeezes size-1 dims of its result, so D == 1 would give a 0-d
    # tensor, which torch.diag rejects. Restore the [D, D] matrix shape.
    cov = torch.atleast_2d(torch.cov(z_flat.T))
    off_diag = cov - torch.diag(torch.diag(cov))
    return torch.norm(off_diag, p="fro")
