import random
import numpy as np
from sklearn.preprocessing import label_binarize
from resnet import TwoDResNet, ConvNet, Resnet18, Resnet50
from resnet_gn import TwoDResNetGN
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from typing import Dict, List, Optional, Tuple, Union
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam, AdamW, SGD, Adamax, RAdam
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils.cmmd_utils import mmd_compute
from utils.cdisco_torch import global_conditional_distance_correlation, classical_distance_correlation, global_conditional_distance_correlation_true
from scipy.linalg import solve as scp_solve
from functools import partial

import hydra
from omegaconf import DictConfig, OmegaConf

from sklearn.metrics import balanced_accuracy_score, r2_score, roc_auc_score, mean_squared_error


class CorrelationLoss(torch.nn.Module):
    """Squared Pearson correlation between ``inp`` and ``target``.

    Means and sums are taken over all elements, so the result is a scalar in
    [0, 1]. Minimising it pushes ``inp`` towards being linearly uncorrelated
    with ``target``.
    """

    def __init__(self, eps: float = 1e-10):
        """
        Args:
            eps: added to the product of sums of squares *inside* the square
                root. If it were added outside, a zero-variance input (e.g. a
                batch containing only one target class) would hit sqrt'(0) = inf
                and produce NaN gradients.
        """
        super().__init__()
        self.eps = eps

    def forward(self, inp: torch.Tensor, target: torch.Tensor):
        # Integer labels cannot be averaged; compare them in the input's dtype.
        if not target.is_floating_point():
            target = target.to(inp.dtype)

        in_centered = inp - inp.mean()
        tar_centered = target - target.mean()

        r_numerator = torch.sum(in_centered * tar_centered)
        r_denominator = torch.sqrt(torch.sum(in_centered**2) * torch.sum(tar_centered**2) + self.eps)

        # Guard against tiny numerical overshoot beyond the valid range.
        r = torch.clamp(r_numerator / r_denominator, min=-1.0, max=1.0)
        return r**2


class CorrelationLossNegative(torch.nn.Module):
    """Negated squared Pearson correlation between ``inp`` and ``target``.

    Means and sums are taken over all elements; the result is a scalar in
    [-1, 0]. Minimising it *maximises* the squared correlation.
    """

    def __init__(self, eps: float = 1e-10):
        """
        Args:
            eps: added inside the square root of the denominator so that a
                zero-variance input does not yield sqrt'(0) = inf and NaN
                gradients.
        """
        super().__init__()
        self.eps = eps

    def forward(self, inp: torch.Tensor, target: torch.Tensor):
        # Integer labels cannot be averaged; compare them in the input's dtype.
        if not target.is_floating_point():
            target = target.to(inp.dtype)

        in_centered = inp - inp.mean()
        tar_centered = target - target.mean()

        r_numerator = torch.sum(in_centered * tar_centered)
        r_denominator = torch.sqrt(torch.sum(in_centered**2) * torch.sum(tar_centered**2) + self.eps)

        r = torch.clamp(r_numerator / r_denominator, min=-1.0, max=1.0)

        return -r**2


class PredictionHead(torch.nn.Module):
    """Three-layer MLP head: in -> 2*in -> in -> n_outputs, ReLU between layers.

    The last layer has no activation, so the head returns raw scores/logits.
    """

    def __init__(self, in_features: int, n_outputs: int):
        super().__init__()
        widened = 2 * in_features
        # Attribute names fc1/fc2/fc3 determine state_dict keys; keep them stable.
        self.fc1 = torch.nn.Linear(in_features, widened)
        self.fc2 = torch.nn.Linear(widened, in_features)
        self.fc3 = torch.nn.Linear(in_features, n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)


class PredictionHeadLinear(torch.nn.Module):
    """Single affine layer mapping features to ``n_outputs`` scores (no activation)."""

    def __init__(self, in_features: int, n_outputs: int):
        super().__init__()
        # Attribute name "fc" determines the state_dict key; keep it stable.
        self.fc = torch.nn.Linear(in_features=in_features, out_features=n_outputs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.fc(x)
        return logits


class MetaDataPredictionAbstract(pl.LightningModule):
    """Base module pairing an image encoder with a prediction head for one target.

    Subclasses implement the step and epoch-end hooks. This class builds the
    encoder and head, selects the task loss, computes metrics and configures the
    optimizer plus cosine LR schedule. Manual optimization is used, so subclasses
    drive ``self.optimizers()`` / ``manual_backward`` themselves.
    """

    def __init__(
        self,
        resnet_cfg: dict,
        target: str,
        class_weights: Optional[Union[np.ndarray, torch.Tensor]] = None,
        encoder_type: str = "resnet",
        pred_head: str = "linear",
        optimizer_cfg: Optional[DictConfig] = None,
        protected_attributes: Optional[List[str]] = None, # just for consistency with yaml files, not needed in base class
    ):
        """
        Args:
            resnet_cfg: keyword arguments forwarded to the encoder constructor;
                its ``n_outputs`` entry (default 1) also sizes the prediction head.
            target: batch key of the prediction target (e.g. "label").
            class_weights: optional weights; converted to a float32 tensor and
                handed to ``get_loss_function`` by subclasses.
            encoder_type: "resnet", "resnet_gn", "convnet", "resnet_18" or "resnet_50".
            pred_head: "linear" for a single linear layer; any other value selects
                the MLP ``PredictionHead``.
            optimizer_cfg: optimizer settings: ``name`` selects the optimizer,
                ``lr_end`` is the cosine schedule's minimum LR, and every other
                key is passed to the optimizer constructor.
            protected_attributes: accepted for config compatibility; unused here.

        Raises:
            ValueError: if ``encoder_type`` is not supported.
        """
        super().__init__()
        encoders = {
            "resnet": TwoDResNet,
            "resnet_gn": TwoDResNetGN,
            "convnet": ConvNet,
            "resnet_18": Resnet18,
            "resnet_50": Resnet50,
        }
        if encoder_type not in encoders:
            raise ValueError(f"Unsupported encoder type: {encoder_type}")
        self.predictor = encoders[encoder_type](**resnet_cfg)

        self.optimizer_cfg = optimizer_cfg
        self.n_outputs = resnet_cfg.get("n_outputs", 1)  # default to 1 output if not specified

        # Choose between a linear head or a non-linear one.
        prediction_head = PredictionHeadLinear if pred_head == "linear" else PredictionHead

        self.classifier_task = prediction_head(
            self.predictor.feature_size,
            self.n_outputs
        )

        self.target = target  # should be "label" in our case

        # We use manual optimization.
        self.automatic_optimization = False

        self.class_weights = (torch.as_tensor(class_weights, dtype=torch.float32)
                              if class_weights is not None else None)

        self.save_hyperparameters(ignore=["predictor"])

    def get_loss_function(self, target: str, weights: Optional[torch.Tensor] = None):
        """Return the task loss callable for ``target``.

        "label" uses ``BCEWithLogitsLoss(weight=weights)``; "cf", "cf_std" and
        "label_c" use MSE; "label_cat" and "label_cat_ordered" use cross-entropy.
        ``weights`` is only used for "label".

        Raises:
            ValueError: for any other target.
        """
        # Here we assume that "label" is binary.
        if target in ["label"]:
            return BCEWithLogitsLoss(weight=weights)
        elif target in ["cf", "cf_std", "label_c"]:
            return F.mse_loss
        elif target in ["label_cat", "label_cat_ordered"]:
            return F.cross_entropy
        else:
            raise ValueError(f"Unsupported target type: {target}")

    def forward(self, x) -> torch.Tensor:
        # Index 1 of the encoder output is used as the feature vector for the head.
        return self.classifier_task(self.predictor(x)[1])

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def calculate_log_metrics(
        self,
        y: torch.Tensor,
        y_hat: torch.Tensor,
        target: str,
        prefix: str = "val"
    ):
        """Log target-specific metrics under ``prefix`` and return the task loss.

        Metrics come from ``calculate_metrics`` so logged and test-time values
        are computed identically. Each metric is logged as ``{prefix}/{name}``,
        and the loss as ``{prefix}/loss``.

        Raises:
            ValueError: for unsupported targets (raised by ``calculate_metrics``).
        """
        metrics = self.calculate_metrics(y, y_hat, target)
        for name, value in metrics.items():
            self.log(f"{prefix}/{name}", value)

        if target == "label":
            loss = F.binary_cross_entropy_with_logits(y_hat, y)
        elif target == "label_c":
            loss = F.mse_loss(y_hat, y)
        else:  # "label_cat" / "label_cat_ordered" (others rejected above)
            loss = F.cross_entropy(y_hat, y)

        self.log(f"{prefix}/loss", loss)
        return loss

    @staticmethod
    def _multiclass_roc_auc(y_np: np.ndarray, probs: np.ndarray) -> float:
        """One-vs-rest ROC-AUC from class probabilities ``probs`` (N x C)."""
        n_classes = probs.shape[1]
        if n_classes == 2:
            # label_binarize yields a single column for two classes, which
            # roc_auc_score cannot pair with a two-column score matrix; score the
            # positive class directly instead.
            return roc_auc_score(y_np, probs[:, 1])
        y_bin = label_binarize(y_np, classes=list(range(n_classes)))
        return roc_auc_score(y_bin, probs, multi_class="ovr")

    def calculate_metrics(self, y: torch.Tensor, y_hat: torch.Tensor, target: str) -> dict:
        """Compute evaluation metrics for ``target`` from labels and raw outputs.

        Returns:
            "label": {"label/bacc", "label/roc_auc"} (logits thresholded at 0);
            "label_c": {"label_c/mse", "label_c/r2"};
            "label_cat"/"label_cat_ordered": {"label_cat/bacc", "label_cat/roc_auc"}.

        Raises:
            ValueError: for any other target.
        """
        y_np = y.cpu().numpy()
        y_hat_np = y_hat.cpu().numpy()

        if target == "label":
            # Binary classification
            y_pred = (y_hat_np > 0).astype(float)
            bacc = balanced_accuracy_score(y_np, y_pred)
            roc_auc = roc_auc_score(y_np, torch.sigmoid(y_hat).cpu().numpy())
            return {"label/bacc": bacc, "label/roc_auc": roc_auc}

        elif target == "label_c":
            mse = mean_squared_error(y_np, y_hat_np)
            r2 = r2_score(y_np, y_hat_np)
            return {"label_c/mse": mse, "label_c/r2": r2}

        elif target in ["label_cat", "label_cat_ordered"]:
            # Multiclass classification
            y_pred = torch.argmax(y_hat, dim=1).cpu().numpy()
            bacc = balanced_accuracy_score(y_np, y_pred)
            probs = F.softmax(y_hat, dim=1).cpu().numpy()
            roc_auc = self._multiclass_roc_auc(y_np, probs)
            return {"label_cat/bacc": bacc, "label_cat/roc_auc": roc_auc}

        else:
            raise ValueError(f"Unsupported target type: {target}")

    def configure_optimizers(self):
        """Build the configured optimizer over encoder + task head, with cosine annealing.

        Raises:
            ValueError: if ``optimizer_cfg`` is missing or names an unknown optimizer.
        """
        if self.optimizer_cfg is None:
            raise ValueError("optimizer_cfg must be provided to configure optimizers")

        optimizer_map = {
            "adam": Adam,
            "adamw": AdamW,
            "sgd": SGD,
            "adamax": Adamax,
            "radam": RAdam
        }

        optimizer_cls = optimizer_map.get(self.optimizer_cfg.name)
        if optimizer_cls is None:
            raise ValueError(f"Unsupported optimizer type: {self.optimizer_cfg.name}")

        # Everything except the selector ("name") and the schedule floor ("lr_end")
        # is an optimizer constructor argument.
        optimizer_params = {k: v for k, v in self.optimizer_cfg.items() if k not in ["name", "lr_end"]}

        optimizer = optimizer_cls(list(self.predictor.parameters()) + list(self.classifier_task.parameters()), **optimizer_params)
        scheduler = CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs, eta_min=self.optimizer_cfg.lr_end)
        return [optimizer], [scheduler]

    def on_validation_epoch_end(self):
        raise NotImplementedError

    def on_test_epoch_end(self):
        raise NotImplementedError


class MetaDataPrediction(MetaDataPredictionAbstract):
    """Plain supervised predictor: minimise the task loss on ``self.target``.

    Validation/test outputs are gathered on CPU per batch and turned into metrics
    at epoch end. LR schedulers are stepped once per validation epoch, except
    during the sanity check.
    """

    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.loss = self.get_loss_function(self.target, self.class_weights)

        # Lists of (y, y_hat) CPU tensor pairs, cleared at each epoch end.
        self.validation_outputs = []
        self.test_outputs = []

    def forward(self, x):
        return self.classifier_task(self.predictor(x)[1])

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        optimizer.zero_grad()

        x: torch.Tensor = batch["img"]  # our new data uses key "img"
        y: torch.Tensor = batch[self.target]  # target is "label"

        # squeeze(-1) drops only the singleton output dim; a plain squeeze() would
        # also drop the batch dim for a batch of one sample.
        y_hat: torch.Tensor = self.forward(x).squeeze(-1)
        loss = self.loss(y_hat, y)

        # l2 regularize y_hat outputs
        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        self.manual_backward(loss + l2_penalty * 0.1)
        optimizer.step()

        # Logged value is the task loss without the L2 penalty.
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]

        # Keep the batch dim so epoch-end torch.cat works for size-1 batches.
        y_hat = self.forward(x).squeeze(-1)
        loss = self.loss(y_hat, y)

        self.validation_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))
        return loss

    def test_step(self, batch, batch_idx):
        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]

        y_hat = self.forward(x).squeeze(-1)
        loss = self.loss(y_hat, y)

        self.test_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))
        return loss

    def on_validation_epoch_end(self):
        outputs, self.validation_outputs = self.validation_outputs, []
        # torch.cat on an empty list raises; skip metrics if no batch was seen.
        if outputs:
            y = torch.cat([y for y, _ in outputs])
            y_hat = torch.cat([y_hat for _, y_hat in outputs])
            self.calculate_log_metrics(y, y_hat, self.target)

        if not self.trainer.sanity_checking:
            # lr_schedulers() may return a single scheduler, a list, or None.
            schedulers = self.lr_schedulers()
            if schedulers is None:
                schedulers = []
            elif not isinstance(schedulers, list):
                schedulers = [schedulers]
            for scheduler in schedulers:
                scheduler.step()

    def on_test_epoch_end(self):
        y = torch.cat([y for y, _ in self.test_outputs])
        y_hat = torch.cat([y_hat for _, y_hat in self.test_outputs])
        metrics = self.calculate_metrics(y, y_hat, self.target)
        self.test_outputs = []
        self.final_metrics = metrics
        return metrics


class GDROPredictor(MetaDataPrediction):
    """Group-DRO predictor: each step minimises the worst per-group loss in the batch.

    Groups are unique (target, protected attribute) pairs when ``conditional``
    is True, otherwise unique protected-attribute values.
    """

    def __init__(
        self,
        protected_attributes: Union[str, List[str]],  # e.g. "b" or ["b"]
        conditional: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.protected_attribute = protected_attributes
        # Batch key of the protected attribute. A plain string is used as-is:
        # indexing a multi-character string with [0] would yield only its first
        # character and fail the batch lookup.
        self._protected_key = (
            protected_attributes if isinstance(protected_attributes, str) else protected_attributes[0]
        )
        self.loss_fn = self.get_loss_function(self.target, self.class_weights)
        self.conditional = conditional

    def _get_group_ids(self, y, b):
        """Return (per-sample group ids, unique group keys) for the batch."""
        # Flatten and cast both to long so they can be stacked as group keys.
        y = y.view(-1).long()
        b = b.view(-1).long()
        stacked = torch.stack([y, b], dim=1) if self.conditional else b.view(-1, 1)  # shape: [batch_size, 2] or [batch_size, 1]

        # Unique (target, bias) pairs -> group ids
        unique_pairs, group_ids = torch.unique(stacked, dim=0, return_inverse=True)
        return group_ids, unique_pairs

    def _compute_group_loss(self, y_hat, y, b):
        """Return (worst-group loss, per-group losses, group keys)."""
        group_ids, unique_pairs = self._get_group_ids(y, b)
        # Ids come from torch.unique over this same batch, so every group in
        # range(len(unique_pairs)) has at least one member.
        group_losses = torch.stack([
            self.loss_fn(y_hat[group_ids == gid], y[group_ids == gid])
            for gid in range(len(unique_pairs))
        ])
        worst_group_loss = torch.max(group_losses)
        return worst_group_loss, group_losses, unique_pairs

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        optimizer.zero_grad()

        x = batch["img"]
        y = batch[self.target]
        b = batch[self._protected_key]

        # squeeze(-1) keeps the batch dim for a batch of one sample.
        y_hat = self.forward(x).squeeze(-1)
        worst_group_loss, group_losses, group_keys = self._compute_group_loss(y_hat, y, b)

        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        self.manual_backward(worst_group_loss + l2_penalty * 0.1)
        optimizer.step()

        self.log("train/gdro_loss", worst_group_loss.detach(), prog_bar=True)

        # Optional: log each group loss for interpretability
        for key, loss in zip(group_keys, group_losses):
            self.log(f"train/group_loss/{key.tolist()}", loss.detach())

        return worst_group_loss


class AdversarialPredictor(MetaDataPrediction):
    """Adversarial debiasing with one auxiliary head per protected attribute.

    Each training step (manual optimization) runs three phases:
      1. update encoder + task head on the task loss;
      2. update each protected-attribute head on no-grad encoder features;
      3. update the encoder alone to *increase* the protected-attribute losses,
         per target class when ``conditional`` is True.
    """

    def __init__(
        self,
        protected_attribute_tuples: list[tuple[str, str, int]], # e.g. [("attr1", "continuous", 1), ("attr2", "categorical", 3), ("attr3", "binary", 1)]
        protected_pred_head: str = "linear",
        lambda_protected: float = 1.0,
        lambda_unlearn: float = 1.0,
        conditional: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # The task loss self.loss is already set by MetaDataPrediction.__init__.
        self.lambda_protected = lambda_protected
        self.lambda_unlearn = lambda_unlearn
        self.conditional = conditional

        # note: as of python 3.7 and higher, dicts maintain insertion order
        # (see: https://discuss.python.org/t/store-set-items-in-an-orderly-manner-on-a-first-come-first-serve-basis/67911/8)
        # The order of protected_attribute_tuples is relied on below: heads,
        # optimizers and schedulers are all built in this order and zipped together.
        self.protected_attribute_tuples = protected_attribute_tuples
        self.protected_attributes = [a for a, _, _ in protected_attribute_tuples]

        # Choose prediction head for protected attributes.
        prediction_head = PredictionHeadLinear if protected_pred_head == "linear" else PredictionHead

        self.protected_losses = {a: self.get_loss_function_protected(t) for a, t, _ in protected_attribute_tuples}

        self.pred_heads_protected = torch.nn.ModuleDict(
            {
                k: prediction_head(
                    self.predictor.feature_size,
                    n_out,
                )
                for k, _, n_out in protected_attribute_tuples
            }
        )

        self.save_hyperparameters(ignore=["predictor", "pred_heads"])

    def get_loss_function_protected(self, protected_type: str):
        """Loss for a protected attribute: MSE, cross-entropy or BCE-with-logits.

        Raises:
            ValueError: for an unknown ``protected_type``.
        """
        if protected_type == "continuous":
            return F.mse_loss
        elif protected_type == "categorical":
            return F.cross_entropy
        elif protected_type == "binary":
            return BCEWithLogitsLoss()
        else:
            raise ValueError(f"Unsupported protected attribute type: {protected_type}")

    def forward(self, x):
        return self.classifier_task(self.predictor(x)[1])

    def configure_optimizers(self):
        """Return [full, features, *protected] Adam optimizers with matching cosine schedulers.

        The order must match the unpacking in ``training_step``.
        """
        optimizer_full = Adam(list(self.predictor.parameters()) + list(self.classifier_task.parameters()), lr=self.optimizer_cfg.lr)
        optimizer_protected = [Adam(p.parameters(), lr=self.optimizer_cfg.lr) for p in self.pred_heads_protected.values()]
        optimizer_features = Adam(self.predictor.parameters(), lr=self.optimizer_cfg.lr)

        scheduler_full = CosineAnnealingLR(optimizer_full, T_max=self.trainer.max_epochs, eta_min=self.optimizer_cfg.lr_end)
        scheduler_protected = [CosineAnnealingLR(o, T_max=self.trainer.max_epochs, eta_min=self.optimizer_cfg.lr_end) for o in optimizer_protected]
        scheduler_features = CosineAnnealingLR(optimizer_features, T_max=self.trainer.max_epochs, eta_min=self.optimizer_cfg.lr_end)

        return [optimizer_full, optimizer_features, *optimizer_protected], [scheduler_full, scheduler_features, *scheduler_protected]

    def split_batch(self, batch) -> list:
        """
        Dynamically splits batch by unique values of the target.
        Returns list of sub-batches, one per unique target class, or ``[batch]``
        when ``conditional`` is False.
        """
        if not self.conditional:
            return [batch]

        y = batch[self.target]
        # Every value comes from y.unique(), so each mask selects >= 1 sample.
        return [{k: v[y == val] for k, v in batch.items()} for val in y.unique()]

    def training_step(self, batch, batch_idx):
        optimizer_full, optimizer_features, *optimizer_protected = self.optimizers()
        optimizer_full.zero_grad()
        optimizer_features.zero_grad()
        for o in optimizer_protected:
            o.zero_grad()

        # Phase 1: main task on encoder + task head.
        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]

        # squeeze(-1) keeps the batch dim for a batch of one sample.
        y_hat: torch.Tensor = self.forward(x).squeeze(-1)
        loss = self.loss(y_hat, y)

        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        self.manual_backward(loss + l2_penalty * 0.1)
        optimizer_full.step()
        optimizer_full.zero_grad()

        self.log("train/loss", loss)

        # Phase 2: fit protected heads on encoder features without gradients,
        # recomputed after the phase-1 update.
        with torch.no_grad():
            features = self.predictor(x)[1]

        for protected_attribute, o in zip(self.protected_attributes, optimizer_protected):
            # Squeeze the target like the prediction (as in phase 3 and validation)
            # so [B, 1]-shaped attributes neither error nor broadcast to [B, B].
            y_protected = batch[protected_attribute].squeeze()
            y_hat_protected = self.pred_heads_protected[protected_attribute](features).squeeze()
            loss_protected = self.protected_losses[protected_attribute](y_hat_protected, y_protected) * self.lambda_protected
            self.manual_backward(loss_protected)
            self.log(f"train/loss_{protected_attribute}", loss_protected.detach().cpu())
            o.step()
            o.zero_grad()

        # Phase 3: unlearn bias attributes from the backbone by maximising the
        # protected losses (negated), averaged over attributes.
        for b in self.split_batch(batch):
            if not b or len(b["img"]) == 0:
                continue
            features = self.predictor(b["img"])[1]
            total_loss = 0.0
            for protected_attribute in self.protected_attributes:
                y_protected = b[protected_attribute].squeeze()
                y_hat_protected = self.pred_heads_protected[protected_attribute](features).squeeze()
                total_loss += - self.protected_losses[protected_attribute](y_hat_protected, y_protected) * self.lambda_unlearn
            total_loss = total_loss / len(self.protected_attribute_tuples)
            self.manual_backward(total_loss)
            optimizer_features.step()
            optimizer_features.zero_grad()
        return loss

    def validation_step(self, batch, batch_idx):
        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]
        features = self.predictor(x)[1]
        # Keep the batch dim so epoch-end torch.cat works for size-1 batches.
        y_hat = self.classifier_task(features).squeeze(-1)
        loss = self.loss(y_hat, y)

        self.validation_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))

        for i, b in enumerate(self.split_batch(batch)):
            # Encode each sub-batch once, not once per protected attribute.
            features_b = self.predictor(b["img"])[1]
            for protected_attribute in self.protected_attributes:
                y_protected = b[protected_attribute].squeeze()
                y_hat_protected = self.pred_heads_protected[protected_attribute](features_b).squeeze()
                loss_protected = self.protected_losses[protected_attribute](y_hat_protected, y_protected)
                self.log(f"val/loss_{protected_attribute}_{i}", loss_protected.detach().cpu())
        return loss


class cmmdRegularizedPredictor(MetaDataPrediction):
    """Task predictor regularised by a (conditional) MMD penalty on its outputs.

    The penalty is the mean Gaussian-kernel MMD between model outputs of
    protected-attribute groups, either within each target value
    (``conditional=True``) or over the whole batch.
    """

    def __init__(
        self,
        protected_attributes: list,  # e.g. ["b"]
        cmmd_lambda: float = 1.0,
        bw: float = 0.1,
        conditional: bool = True,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.cmmd_lambda = cmmd_lambda
        self.bw = bw  # passed to mmd_compute as ``gamma``
        self.protected_attributes = protected_attributes
        self.conditional = conditional

    def _pairwise_mmd(self, groups):
        """Return (sum of MMD over all unordered pairs of ``groups``, number of pairs)."""
        total = 0.0
        count = 0
        for i, first in enumerate(groups):
            for second in groups[i + 1:]:
                total = total + mmd_compute(first, second, kernel_type="gaussian", gamma=self.bw)
                count += 1
        return total, count

    def _compute_cmmd_loss(self, x, y, biases, features: Optional[torch.Tensor] = None):
        """
        Compute (Conditional) MMD loss. If `conditional=True`, only compare groups
        defined by protected attributes *within* each fixed y value. Otherwise, ignore
        y and compare groups across the whole batch.

        Assumes y and each protected attribute in `biases` are integer-encoded
        categorical/binary tensors of shape [N].

        Args:
            features: model outputs for ``x`` (``self.forward(x).squeeze(-1)``).
                Pass them when already computed to avoid a second forward pass;
                if None they are computed here.

        Returns:
            Mean pairwise MMD, or a zero tensor if no valid group pair exists.
        """
        if features is None:
            features = self.forward(x).squeeze(-1)
        device = features.device

        # Helper: build list of (within-slice) output tensors for each protected-attr group
        def features_by_bias_groups(slice_mask: torch.Tensor):
            """
            Returns a list of output tensors, one per protected-attr group with more
            than two samples, restricted to samples where slice_mask==True.
            """
            grouped_attrs = [biases[attr].view(-1).long()[slice_mask] for attr in self.protected_attributes]
            if len(grouped_attrs) == 0:
                return []  # nothing to group on

            stacked = torch.stack(grouped_attrs, dim=1)  # [n_slice, n_attrs]
            unique_groups, group_ids = torch.unique(stacked, dim=0, return_inverse=True)

            feats_slice = features[slice_mask]
            feats_list = []
            for gid in range(unique_groups.size(0)):
                mask_g = (group_ids == gid)
                if mask_g.sum() > 2:  # skip tiny groups; avoids degenerate MMD
                    feats_list.append(feats_slice[mask_g])
            return feats_list

        mmd_total = torch.tensor(0.0, device=device)
        num_pairs = 0

        if self.conditional:
            # Conditional MMD: compare groups only within each target value.
            y_long = y.view(-1).long()
            for yv in torch.unique(y_long):
                mask_y = (y_long == yv)
                # Need at least a few samples in this class to form groups
                if mask_y.sum() < 3:
                    continue
                pair_sum, pair_count = self._pairwise_mmd(features_by_bias_groups(mask_y))
                mmd_total = mmd_total + pair_sum
                num_pairs += pair_count
        else:
            # Unconditional: group by protected attributes over the whole batch
            mask_all = torch.ones(y.shape[0], dtype=torch.bool, device=device)
            pair_sum, pair_count = self._pairwise_mmd(features_by_bias_groups(mask_all))
            mmd_total = mmd_total + pair_sum
            num_pairs += pair_count

        if num_pairs == 0:
            return torch.tensor(0.0, device=device)

        return mmd_total / num_pairs

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        optimizer.zero_grad()

        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]
        biases = {attr: batch[attr] for attr in self.protected_attributes}

        # squeeze(-1) keeps the batch dim for a batch of one sample.
        y_hat: torch.Tensor = self.forward(x).squeeze(-1)
        base_loss = self.loss(y_hat, y)

        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        # Reuse y_hat rather than running the encoder a second time.
        mmd_loss = self._compute_cmmd_loss(x, y, biases, features=y_hat)
        total_loss = base_loss + self.cmmd_lambda * mmd_loss + l2_penalty * 0.1

        self.manual_backward(total_loss)
        optimizer.step()

        # Note: "train/loss" here is the full objective, including the MMD term.
        self.log("train/loss", total_loss.detach().cpu())
        self.log("train/mmd_loss", mmd_loss.detach().cpu())
        return total_loss

    def validation_step(self, batch, batch_idx):
        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]
        biases = {attr: batch[attr] for attr in self.protected_attributes}

        y_hat = self.forward(x).squeeze(-1)
        base_loss = self.loss(y_hat, y)

        self.validation_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))

        mmd_loss = self._compute_cmmd_loss(x, y, biases, features=y_hat)
        total_loss = base_loss + self.cmmd_lambda * mmd_loss

        self.log("val/loss", base_loss.detach().cpu())
        self.log("val/mmd_loss", mmd_loss.detach().cpu())
        self.log("val/full_loss", total_loss.detach().cpu())
        return total_loss


class cDiscoPredictor(MetaDataPrediction):
    """Task predictor regularised by conditional distance correlation.

    The penalty is ``cdcor_func(probs, z, y, bw)``, where ``probs`` are the model
    outputs mapped to probabilities, ``z`` is the first protected attribute and
    ``y`` the (one-hot encoded, for "label_cat") target.
    """

    def __init__(
        self,
        protected_attributes: list,  # e.g. ["cf"]
        bw: float = 1.0,
        cdcor_lambda: float = 1.0,
        method: str = "max",
        *args,
        **kwargs,
    ):
        """
        Raises:
            ValueError: if ``method`` is not "max", "mean", "standard" or "efficient".
        """
        super().__init__(*args, **kwargs)
        self.bw = bw
        self.protected_attributes = protected_attributes
        self.cdcor_lambda = cdcor_lambda
        self.method = method
        if self.method in ["max", "mean", "standard"]:
            self.cdcor_func = partial(global_conditional_distance_correlation_true, method=self.method)
        elif self.method in ["efficient"]:
            self.cdcor_func = global_conditional_distance_correlation
        else:
            raise ValueError(f"Unsupported cdisco method: {self.method}")

    def _prepare_cdcor_inputs(self, y_hat: torch.Tensor, y: torch.Tensor):
        """Return (probabilities, encoded target) fed to the cdcor penalty.

        "label_cat": softmax over classes and one-hot target; "label": sigmoid
        and unchanged target; other targets: both unchanged.
        """
        if self.target == "label_cat":
            return F.softmax(y_hat, dim=1), F.one_hot(y, num_classes=self.n_outputs).float()
        if self.target == "label":
            return torch.sigmoid(y_hat), y
        return y_hat, y

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        optimizer.zero_grad()

        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]
        z: torch.Tensor = batch[self.protected_attributes[0]]

        # squeeze(-1) keeps the batch dim for a batch of one sample.
        y_hat: torch.Tensor = self.forward(x).squeeze(-1)
        base_loss = self.loss(y_hat, y)

        # l2 regularize y_hat outputs
        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        probs, y_enc = self._prepare_cdcor_inputs(y_hat, y)
        cdcor_loss = self.cdcor_func(probs, z, y_enc, self.bw)

        loss = cdcor_loss * self.cdcor_lambda + base_loss + l2_penalty * 0.1

        self.manual_backward(loss)
        optimizer.step()

        self.log("train/loss", base_loss.detach().cpu())
        self.log("train/cdcor_loss", cdcor_loss.detach().cpu())
        self.log("train/full_loss", loss.detach().cpu())
        return loss

    def validation_step(self, batch, batch_idx):
        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]
        z: torch.Tensor = batch[self.protected_attributes[0]]  # e.g. "cf"
        y_hat = self.forward(x).squeeze(-1)
        base_loss = self.loss(y_hat, y)

        # Store raw outputs; metrics are computed from them at epoch end.
        self.validation_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))

        probs, y_enc = self._prepare_cdcor_inputs(y_hat, y)
        cdcor_loss = self.cdcor_func(probs, z, y_enc, self.bw)
        loss = cdcor_loss * self.cdcor_lambda + base_loss
        self.log("val/cdcor_loss", cdcor_loss.detach().cpu())
        self.log("val/loss", base_loss.detach().cpu())
        self.log("val/full_loss", loss.detach().cpu())
        return loss


class DiscoPredictor(MetaDataPrediction):
    """Predictor regularised with a (conditional) distance-correlation penalty.

    The penalty is ``classical_distance_correlation`` between the model outputs
    (softmax probabilities for ``label_cat`` targets) and the first protected
    attribute. With ``conditional=True`` it is evaluated separately inside the
    groups given by the target cast to ``long`` and averaged over the groups
    chosen by ``conditional_group_strategy``.
    """

    def __init__(
        self,
        protected_attributes: list,  # e.g. ["cf"]
        cdcor_lambda: float = 1.0,
        conditional: bool = True,
        conditional_group_strategy: str = "all",  # 'all', 'random', or 'fixed'
        fixed_group_id: int = 0,  # Only used if conditional_group_strategy == 'fixed'
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.protected_attributes = protected_attributes
        self.cdcor_lambda = cdcor_lambda
        self.conditional = conditional
        self.conditional_group_strategy = conditional_group_strategy
        self.fixed_group_id = fixed_group_id

    def _get_group_ids_to_use(self, unique_ids):
        """Select which target groups contribute to the conditional penalty.

        Args:
            unique_ids: 1-D tensor of the group ids present in the batch.

        Returns:
            All ids (``'all'``), one uniformly drawn id (``'random'``), or
            ``[fixed_group_id]`` / ``[]`` depending on whether that id is
            present (``'fixed'``).

        Raises:
            ValueError: For an unknown ``conditional_group_strategy``.
        """
        if self.conditional_group_strategy == "all":
            return unique_ids
        elif self.conditional_group_strategy == "random":
            return [random.choice(unique_ids.tolist())]
        elif self.conditional_group_strategy == "fixed":
            if self.fixed_group_id in unique_ids:
                return [self.fixed_group_id]
            else:
                return []  # if fixed group not in current batch
        else:
            raise ValueError(f"Unknown strategy: {self.conditional_group_strategy}")

    def _dcor_penalty(self, y_hat, y, z):
        """Return the (group-averaged) distance-correlation penalty as a tensor.

        When no group contributes a term (e.g. the fixed group is absent from
        the batch), a zero scalar tensor is returned. Callers can then always add
        and log it; the original code left a Python float there, which crashed
        on ``.detach()``.
        """
        if not self.conditional:
            return classical_distance_correlation(y_hat, z)

        group_ids = y.view(-1).long()
        terms = []
        for gid in self._get_group_ids_to_use(group_ids.unique()):
            mask = group_ids == gid
            if mask.sum() == 0:  # defensive: the selected group may be absent
                continue
            terms.append(classical_distance_correlation(y_hat[mask], z[mask]))

        if not terms:
            return y_hat.new_zeros(())
        # Out-of-place sum/divide: never modify autograd outputs in place.
        return sum(terms) / len(terms)

    def training_step(self, batch, batch_idx):
        """Run one manually-optimised training step and return the full loss.

        Logs ``train/loss`` (base loss only), ``train/dcor_loss`` and
        ``train/full_loss``.
        """
        optimizer = self.optimizers()
        optimizer.zero_grad()

        x = batch["img"]
        y = batch[self.target]
        z = batch[self.protected_attributes[0]]
        y_hat = self.forward(x).squeeze()

        base_loss = self.loss(y_hat, y)

        # L2 penalty on the raw logits (categorical targets only), weighted by 0.1 below.
        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        # y_hat softmax if target is label_cat
        if self.target == "label_cat":
            y_hat = F.softmax(y_hat, dim=1)

        dcor_loss = self._dcor_penalty(y_hat, y, z)

        # Built out of place so base_loss (logged below) is never mutated.
        loss = dcor_loss * self.cdcor_lambda + base_loss + l2_penalty * 0.1

        self.manual_backward(loss)
        optimizer.step()

        self.log("train/loss", base_loss.detach().cpu())
        self.log("train/dcor_loss", dcor_loss.detach().cpu())
        self.log("train/full_loss", loss.detach().cpu())
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and record raw outputs for epoch-end metrics.

        Logs ``val/dcor_loss``, ``val/loss`` (base loss only, matching
        ``train/loss``) and ``val/full_loss``.
        """
        x = batch["img"]
        y = batch[self.target]
        z = batch[self.protected_attributes[0]]
        y_hat = self.forward(x).squeeze()
        self.validation_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))
        base_loss = self.loss(y_hat, y)

        # y_hat softmax if target is label_cat
        if self.target == "label_cat":
            y_hat = F.softmax(y_hat, dim=1)

        dcor_loss = self._dcor_penalty(y_hat, y, z)
        # Always defined (previously unbound when no group was selected).
        loss = dcor_loss * self.cdcor_lambda + base_loss

        self.log("val/dcor_loss", dcor_loss.detach().cpu())
        self.log("val/loss", base_loss.detach().cpu())
        self.log("val/full_loss", loss.detach().cpu())
        return loss


class HSCICPredictor(MetaDataPrediction):
    """Predictor regularised with a kernel HSCIC penalty.

    The penalty is computed by ``hscic(y_hat, A, y)``, where ``A`` is the first
    protected attribute and ``y`` the target (one-hot for ``label_cat``). All
    three inputs use Gaussian kernels with bandwidths ``sigma2_yhat``,
    ``sigma2_z`` and ``sigma2_y``.
    """

    def __init__(
        self,
        protected_attributes: list,
        hscic_lambda: float = 1.0,
        ridge_lambda: float = 0.1,
        sigma2_yhat: float = 1.0,
        sigma2_z: float = 1.0,
        sigma2_y: float = 1.0,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.protected_attributes = protected_attributes
        self.hscic_lambda = hscic_lambda
        self.ridge_lambda = ridge_lambda

        self.sigma2_yhat = sigma2_yhat
        self.sigma2_z = sigma2_z
        self.sigma2_y = sigma2_y

    def ensure_2d(self, tensor):
        """Return ``tensor`` as a floating-point 2-D ``(n, d)`` tensor.

        Non-tensors become float32 tensors. Non-floating tensors are cast with
        ``.float()``, because ``torch.cdist`` only accepts floating-point inputs;
        integer-coded attributes would otherwise crash. 1-D ``(n,)`` becomes
        ``(n, 1)`` and a 0-d scalar becomes ``(1, 1)``.

        Raises:
            ValueError: If the tensor has more than two dimensions.
        """
        if not torch.is_tensor(tensor):
            tensor = torch.tensor(tensor, dtype=torch.float32)
        elif not tensor.is_floating_point():
            tensor = tensor.float()
        if tensor.ndim < 2:
            tensor = tensor.reshape(-1, 1)
        elif tensor.ndim > 2:
            raise ValueError(f"Expected a 1D or 2D tensor, but got shape {tensor.shape}")
        return tensor

    def gaussian_kernel(self, X, Y=None, sigma2=1.0):
        """Gaussian Gram matrix ``exp(-||x_i - y_j||^2 / sigma2)`` (``Y`` defaults to ``X``)."""
        if Y is None:
            Y = X
        dist = torch.cdist(X, Y, p=2) ** 2
        return torch.exp(-dist / sigma2)

    def _kernel_matrices(self, X, Z, Y):
        """Return the Gram matrices ``(Kx, Kz, Ky)`` using the configured bandwidths."""
        X, Z, Y = self.ensure_2d(X), self.ensure_2d(Z), self.ensure_2d(Y)
        return (
            self.gaussian_kernel(X, sigma2=self.sigma2_yhat),
            self.gaussian_kernel(Z, sigma2=self.sigma2_z),
            self.gaussian_kernel(Y, sigma2=self.sigma2_y),
        )

    def _ridge_smoother(self, K):
        """Return ``(K + ridge_lambda * I)^{-1} K`` for a square Gram matrix ``K``."""
        eye = torch.eye(K.shape[0], device=K.device, dtype=K.dtype)
        return torch.linalg.solve(K + self.ridge_lambda * eye, K)

    def hscic(self, X, Z, Y):
        """Kernel HSCIC estimate between ``X`` and ``Z`` given ``Y``, averaged over the n samples.

        Uses the ridge weights ``W = (Ky + ridge_lambda * I)^{-1} Ky``.
        """
        Kx, Kz, Ky = self._kernel_matrices(X, Z, Y)
        W = self._ridge_smoother(Ky)

        term_1 = (W * ((Kx * Kz) @ W)).sum()
        WkKxWk = W * (Kx @ W)
        KzWk = Kz @ W
        term_2 = (WkKxWk * KzWk).sum()
        term_3 = (WkKxWk.sum(dim=0) * (W * KzWk).sum(dim=0)).sum()

        return (term_1 - 2 * term_2 + term_3) / Ky.shape[0]

    def fukumizu_cond(self, X, Z, Y):
        """Same three-term estimate as ``hscic``, but with ridge-smoothed X and Z kernels.

        ``Kx`` and ``Kz`` are replaced by ``(Kx + λI)^{-1} Kx`` and
        ``(Kz + λI)^{-1} Kz`` (λ = ``ridge_lambda``).
        """
        Kx, Kz, Ky = self._kernel_matrices(X, Z, Y)
        n = Ky.shape[0]

        # — three ridge-solves —
        Rx = self._ridge_smoother(Kx)
        Rz = self._ridge_smoother(Kz)
        Wy = self._ridge_smoother(Ky)

        term_1 = (Wy * ((Rx * Rz) @ Wy)).sum()

        WkKxWk = Wy * (Rx @ Wy)
        KzWk = Rz @ Wy
        term_2 = (WkKxWk * KzWk).sum()

        term_3 = (WkKxWk.sum(dim=0) * (Wy * KzWk).sum(dim=0)).sum()

        return (term_1 - 2 * term_2 + term_3) / n

    def training_step(self, batch, batch_idx):
        """Manually-optimised step: base loss + ``hscic_lambda`` * HSCIC + 0.1 * L2 (label_cat only)."""
        optimizer = self.optimizers()
        optimizer.zero_grad()

        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]
        y_hat: torch.Tensor = self.forward(x).squeeze()

        loss = self.loss(y_hat, y)

        # Assume A is the protected attribute, X is the label, Y is the prediction
        A: torch.Tensor = batch[self.protected_attributes[0]]

        # one hot encode y if it is label_cat
        if self.target == "label_cat":
            y = F.one_hot(y, num_classes=self.n_outputs).float()

        # L2 penalty on the raw logits (categorical targets only).
        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        hscic_loss = self.hscic(y_hat, A, y).to(self.device)

        total_loss = loss + self.hscic_lambda * hscic_loss + l2_penalty * 0.1

        self.manual_backward(total_loss)
        optimizer.step()

        self.log("train/total_loss", total_loss.detach().cpu())
        self.log("train/hscic_loss", hscic_loss.detach().cpu())
        self.log("train/loss", loss.detach().cpu())

        return total_loss

    def validation_step(self, batch, batch_idx):
        """Validation loss (base + ``hscic_lambda`` * HSCIC); raw outputs are stored for metrics."""
        x: torch.Tensor = batch["img"]
        y: torch.Tensor = batch[self.target]
        y_hat = self.forward(x).squeeze()
        loss = self.loss(y_hat, y)
        self.validation_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))

        A: torch.Tensor = batch[self.protected_attributes[0]]

        # one hot encode y if it is label_cat
        if self.target == "label_cat":
            y = F.one_hot(y, num_classes=self.n_outputs).float()

        hscic_loss = self.hscic(y_hat, A, y).to(self.device)
        total_loss = loss + self.hscic_lambda * hscic_loss

        self.log("val/hscic_loss", hscic_loss.detach().cpu())
        self.log("val/total_loss", total_loss.detach().cpu())
        self.log("val/loss", loss.detach().cpu())
        return total_loss


class CIRCEPredictor(MetaDataPrediction):
    """Predictor regularised with a CIRCE-style residual-kernel penalty.

    In ``on_fit_start`` the held-out split (``datamodule.heldout_dataloader``)
    is used to compute ``W_1 = (Ky + λI)^{-1}`` and ``W_2 = (Ky + λI)^{-1} Kz``.
    ``sigma2_y`` and ``λ`` are chosen from ``sigma2_list`` x ``ridge_lambdas``
    as the pair with the lowest held-out ``Kres.mean()``. During training the
    weights build a residual kernel on (z, y), which is paired with a Gaussian
    kernel on the predictions.
    """

    def __init__(
        self,
        protected_attributes: list,
        ridge_lambdas: List[float] = [1e-2, 1e-1, 1.0, 10.0, 100.0],
        sigma2_list: List[float] = [0.95, 1.0, 0.1, 0.01, 0.001],
        sigma2_yhat: float = 1.0,
        sigma2_z: float = 1.0,
        loo_cond_mean: bool = True,
        circe_lambda: float = 1.0,
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.protected_attributes = protected_attributes
        # Copies so the shared list defaults are never aliased by instances.
        self.ridge_lambdas = list(ridge_lambdas)
        self.sigma2_list = list(sigma2_list)
        # NOTE(review): stored but not read by any method of this class.
        self.loo_cond_mean = loo_cond_mean
        self.circe_lambda = circe_lambda

        self.sigma2_yhat = sigma2_yhat
        self.sigma2_z = sigma2_z
        self.sigma2_y = None  # Will be set during fit

        # Filled in on_fit_start; as buffers they follow the module across devices.
        # NOTE(review): these are persistent buffers, so a freshly constructed
        # module (buffers still None) may reject checkpoints that contain them
        # under strict loading — confirm how checkpoints are restored.
        self.register_buffer("W_1", None)
        self.register_buffer("W_2", None)
        self.register_buffer("Z_heldout", None)
        self.register_buffer("Y_heldout", None)

    def ensure_2d(self, tensor):
        """Return ``tensor`` as a floating-point 2-D ``(n, d)`` tensor.

        Non-tensors become float32 tensors. Non-floating tensors are cast with
        ``.float()`` (``torch.cdist`` requires floating-point inputs). 1-D
        ``(n,)`` becomes ``(n, 1)`` and a 0-d scalar becomes ``(1, 1)``.

        Raises:
            ValueError: If the tensor has more than two dimensions.
        """
        if not torch.is_tensor(tensor):
            tensor = torch.tensor(tensor, dtype=torch.float32)
        elif not tensor.is_floating_point():
            tensor = tensor.float()
        if tensor.ndim < 2:
            tensor = tensor.reshape(-1, 1)
        elif tensor.ndim > 2:
            raise ValueError(f"Expected a 1D or 2D tensor, but got shape {tensor.shape}")
        return tensor

    def gaussian_kernel(self, X, Y=None, sigma2=1.0):
        """Gaussian Gram matrix ``exp(-||x_i - y_j||^2 / sigma2)`` (``Y`` defaults to ``X``)."""
        if Y is None:
            Y = X
        dist = torch.cdist(X, Y, p=2) ** 2
        return torch.exp(-dist / sigma2)

    def on_fit_start(self):
        """Collect the held-out (Z, Y) data and fit the regression weights before training.

        Raises:
            RuntimeError: If the datamodule has no ``heldout_dataloader`` or no
                candidate hyper-parameter pair could be fitted.
        """
        if not hasattr(self.trainer.datamodule, 'heldout_dataloader'):
            raise RuntimeError("CIRCE needs a heldout dataloader.")

        loader = self.trainer.datamodule.heldout_dataloader()
        all_cf, all_label_c = [], []
        for batch in loader:
            all_cf.append(batch[self.protected_attributes[0]])
            all_label_c.append(batch[self.target])

        Z = torch.cat(all_cf, dim=0).float().to(self.device)
        Y = torch.cat(all_label_c, dim=0).to(self.device)

        # one hot encode Y if it is label_cat
        if self.target == "label_cat":
            Y = F.one_hot(Y, num_classes=self.n_outputs).float()

        Z = self.ensure_2d(Z)
        Y = self.ensure_2d(Y)
        self.Z_heldout = Z
        self.Y_heldout = Y

        best_sigma2, best_lambda, W_1, W_2 = self._estimate_regression_params(Y, Z)
        self.W_1 = W_1
        self.W_2 = W_2
        self.sigma2_y = best_sigma2
        print(f"CIRCE - Best params: sigma2={best_sigma2}, lambda={best_lambda}")

    def _estimate_regression_params(self, Y, Z):
        """Grid-search ``(sigma2_y, λ)`` on held-out data and return the best weights.

        For each pair, ``[W_1 | W_2] = (Ky + λI)^{-1} [I | Kz]`` is obtained with a
        single positive-definite solve. The pair with the smallest ``Kres.mean()``
        wins. Pairs whose solve fails are reported and skipped.

        Returns:
            ``(best_sigma2, best_lambda, W_1, W_2)`` with the weights as float32
            tensors on ``self.device``.

        Raises:
            RuntimeError: If no pair produced a finite, comparable loss.
        """
        n = Y.size(0)
        Kz = self.gaussian_kernel(Z, sigma2=self.sigma2_z).cpu().numpy()
        # Loop-invariant: identity and the stacked right-hand side [I | Kz].
        I = np.eye(n)
        RHS = np.concatenate([I, Kz], axis=1)

        best_loss = float("inf")
        best_sigma2 = None
        best_lambda = None
        best_W_all = None

        for sigma2 in self.sigma2_list:
            Ky = self.gaussian_kernel(Y, sigma2=sigma2).cpu().numpy()
            for reg in self.ridge_lambdas:
                try:
                    W_all = scp_solve(Ky + reg * I, RHS, assume_a='pos')

                    A = (0.5 * Ky.T @ W_all[:, n:] - Kz.T) @ W_all[:, :n] @ Ky
                    Kres = Kz + A + A.T
                    Kres = Kres * Ky
                    loss = Kres.mean()

                    print(f"Testing sigma2={sigma2}, lambda={reg}, loss={loss.item()}")

                    if loss < best_loss:
                        best_loss = loss
                        best_sigma2 = sigma2
                        best_lambda = reg
                        best_W_all = W_all
                except Exception as e:
                    # Best effort: a singular / non-PD system just skips this candidate.
                    print(f"Skipping sigma2={sigma2}, lambda={reg}: {e}")
                    continue

        if best_W_all is None:
            # Previously None weights were returned and training failed later with
            # an unrelated TypeError; fail fast with a clear message instead.
            raise RuntimeError(
                "CIRCE: no (sigma2, lambda) candidate could be fitted on the heldout data."
            )

        W_1 = torch.tensor(best_W_all[:, :n]).float().to(self.device)
        W_2 = torch.tensor(best_W_all[:, n:]).float().to(self.device)
        return best_sigma2, best_lambda, W_1, W_2

    def training_step(self, batch, batch_idx):
        """Manually-optimised step: base loss + ``circe_lambda`` * CIRCE + 0.1 * L2 (label_cat only)."""
        optimizer = self.optimizers()
        optimizer.zero_grad()

        x = batch["img"]
        y = batch[self.target]
        z = batch[self.protected_attributes[0]].float()
        y_hat = self.forward(x).squeeze()
        base_loss = self.loss(y_hat, y)

        # one hot encode y if it is label_cat
        if self.target == "label_cat":
            y = F.one_hot(y, num_classes=self.n_outputs).float()

        # L2 penalty on the raw logits (categorical targets only).
        l2_penalty = (y_hat.pow(2).sum(dim=1)).mean() if self.target == "label_cat" else 0.0

        # Compute CIRCE penalty
        penalty = self.circe_penalty(y_hat, z, y)
        loss = self.circe_lambda * penalty + base_loss + l2_penalty * 0.1

        self.manual_backward(loss)
        optimizer.step()

        self.log("train/loss", base_loss.detach().cpu())
        self.log("train/circe", penalty.detach().cpu())
        self.log("train/full_loss", loss.detach().cpu())
        return loss

    def circe_penalty(self, features, z, y):
        """Mean over distinct sample pairs of ``Kx * Kres`` for the current batch.

        ``Kres`` is the residual kernel built from the batch and held-out
        kernels with the fitted ``W_1``/``W_2``, multiplied elementwise by the
        batch ``Ky``. ``Kx`` is a Gaussian kernel on ``features``. With fewer
        than two samples there are no distinct pairs, so 0 is returned (the
        original returned NaN).
        """
        z = self.ensure_2d(z)
        y = self.ensure_2d(y)
        features = self.ensure_2d(features)

        z_all = torch.cat([z, self.Z_heldout], dim=0)
        y_all = torch.cat([y, self.Y_heldout], dim=0)

        # (n + m, n) kernels: rows = batch then held-out, columns = batch.
        Kz_all = self.gaussian_kernel(z_all, Y=z, sigma2=self.sigma2_z)
        Ky_all = self.gaussian_kernel(y_all, Y=y, sigma2=self.sigma2_y)

        n = y.shape[0]
        A = (0.5 * Ky_all[n:].T @ self.W_2 - Kz_all[n:].T) @ self.W_1 @ Ky_all[n:]
        Kres = Kz_all[:n, :n] + A + A.T
        Kres = Kres * Ky_all[:n]

        Kx = self.gaussian_kernel(features, sigma2=self.sigma2_yhat)
        if n < 2:
            return Kx.new_zeros(())
        idx = torch.triu_indices(n, n, 1, device=Kx.device)
        return (Kx * Kres)[idx[0], idx[1]].mean()

    def validation_step(self, batch, batch_idx):
        """Validation loss (base + ``circe_lambda`` * CIRCE); raw outputs are stored for metrics."""
        x = batch["img"]
        y = batch[self.target]
        y_hat = self.forward(x).squeeze()
        self.validation_outputs.append((y.detach().cpu(), y_hat.detach().cpu()))
        base_loss = self.loss(y_hat, y)
        z = batch[self.protected_attributes[0]].float()

        # one hot encode y if it is label_cat
        if self.target == "label_cat":
            y = F.one_hot(y, num_classes=self.n_outputs).float()

        penalty = self.circe_penalty(y_hat, z, y)
        loss = self.circe_lambda * penalty + base_loss
        self.log("val/total_loss", loss.detach().cpu())
        self.log("val/circe", penalty.detach().cpu())
        self.log("val/loss", base_loss.detach().cpu())
        return loss
