# file: user_extensions/baselines/fader_networks/metrics.py
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

from prism.core.base_objects import BaseMetric
from prism.core.registry import METRICS
from prism.evaluation.metrics import _calculate_dci, _calculate_sap


@METRICS.register("AttributeInvarianceProbe")
class AttributeInvarianceProbe(BaseMetric):
    """Linear-probe accuracy for predicting ``y_targets`` from the latent codes.

    The latents are flattened per sample and split 70/30 into train/validation
    sets (stratified by label when possible). They are standardized with
    statistics from the training split only, then a ``LinearSVC`` is fit on the
    training split. The validation accuracy is reported in percent.
    """

    def calculate(self, z_full, y_targets, **kwargs):
        """Train a linear probe on the latents and report held-out accuracy.

        Args:
            z_full: Tensor of latent codes. The first dimension indexes samples;
                all remaining dimensions are flattened into one feature axis.
            y_targets: Tensor of labels, one per sample (presumably a 1-D class
                index tensor — TODO confirm against caller).
            **kwargs: Unused here; accepted so all metrics can be called with a
                shared keyword set.

        Returns:
            ``{"accuracy": float}`` — validation accuracy in percent (0-100).

        Raises:
            ValueError: Propagated from scikit-learn when the data cannot be
                split at all, or when the training split contains one class only.
        """
        # detach() lets this accept tensors that are still attached to an
        # autograd graph; .numpy() refuses those.
        z_np = z_full.detach().cpu().numpy().reshape(z_full.shape[0], -1)
        y_np = y_targets.detach().cpu().numpy()

        indices = np.arange(len(z_np))
        try:
            train_indices, val_indices = train_test_split(
                indices, test_size=0.3, random_state=42, stratify=y_np
            )
        except ValueError:
            # Stratification fails when some class has fewer than 2 samples, or
            # when a split is too small to hold every class. Fall back to a
            # plain seeded split instead of aborting the whole evaluation.
            train_indices, val_indices = train_test_split(
                indices, test_size=0.3, random_state=42
            )
        z_train, y_train = z_np[train_indices], y_np[train_indices]
        z_val, y_val = z_np[val_indices], y_np[val_indices]

        # Fit scaling statistics on the training split only to avoid leaking
        # validation information into the probe.
        scaler = StandardScaler()
        z_train_scaled = scaler.fit_transform(z_train)
        z_val_scaled = scaler.transform(z_val)

        # Strongly regularized (C=0.01) linear probe with a fixed seed so the
        # score is reproducible across runs.
        probe = LinearSVC(C=0.01, max_iter=1000, tol=1e-4, random_state=42, dual='auto')
        probe.fit(z_train_scaled, y_train)

        accuracy = probe.score(z_val_scaled, y_val) * 100
        return {"accuracy": accuracy}


@METRICS.register("BaselineDCI")
class BaselineDCIMetric(BaseMetric):
    """DCI score of the full latent code against the available style factors."""

    def calculate(self, z_full, y_targets, y_style, style_feature_map, **kwargs):
        """Compute DCI via ``_calculate_dci`` using style columns as factors.

        Args:
            z_full: Latent codes, passed through to ``_calculate_dci`` unchanged.
            y_targets: Unused by this metric.
            y_style: Tensor of style labels, indexed as ``[:, column]``; may be None.
            style_feature_map: Mapping from factor name to column index in
                ``y_style``; may be None.
            **kwargs: Unused.

        Returns:
            Whatever ``_calculate_dci`` returns, or ``{}`` when no
            ground-truth factors are available.
        """
        # Without style labels or a column map there is nothing to score.
        if y_style is None or style_feature_map is None:
            return {}

        style_columns = y_style.cpu().numpy()
        factors = {
            factor_name: style_columns[:, column]
            for factor_name, column in style_feature_map.items()
        }
        if not factors:
            return {}

        # Estimator hyperparameters are read from the shared evaluation config.
        settings = self.config.evaluation.metric_settings
        return _calculate_dci(
            latents=z_full,
            ground_truth_factors=factors,
            n_estimators=settings.dci_estimators,
            max_depth=settings.dci_max_depth,
        )


@METRICS.register("BaselineSAP")
class BaselineSAPMetric(BaseMetric):
    """SAP score of the full latent code against the available style factors."""

    def calculate(self, z_full, y_targets, y_style, style_feature_map, **kwargs):
        """Compute SAP via ``_calculate_sap`` using style columns as factors.

        Args:
            z_full: Latent codes, passed through to ``_calculate_sap`` unchanged.
            y_targets: Unused by this metric.
            y_style: Tensor of style labels, indexed as ``[:, column]``; may be None.
            style_feature_map: Mapping from factor name to column index in
                ``y_style``; may be None.
            **kwargs: Unused.

        Returns:
            Whatever ``_calculate_sap`` returns, or ``0.0`` when no
            ground-truth factors are available.
        """
        # NOTE(review): this returns 0.0 when no factors exist, whereas the DCI
        # metric in this module returns {}; confirm callers treat 0.0 here as
        # "not computed" rather than as a genuine SAP score of zero.
        if y_style is None or style_feature_map is None:
            return 0.0

        style_columns = y_style.cpu().numpy()
        factors = {
            factor_name: style_columns[:, column]
            for factor_name, column in style_feature_map.items()
        }
        if not factors:
            return 0.0

        return _calculate_sap(latents=z_full, ground_truth_factors=factors)