# file: prism/evaluation/metrics.py
from abc import ABC, abstractmethod

import numpy as np
from scipy.stats import entropy
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import adjusted_rand_score, mutual_info_score, normalized_mutual_info_score, silhouette_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC
import torch
from torchmetrics.image import FrechetInceptionDistance, LearnedPerceptualImagePatchSimilarity, StructuralSimilarityIndexMeasure

from prism.core.base_objects import BaseMetric
from prism.core.registry import METRICS


class _TorchmetricsWrapper(BaseMetric, ABC):
    """Adapter exposing a torchmetrics metric through the ``BaseMetric`` interface.

    Subclasses build the wrapped metric in ``_create_calculator`` and feed it one
    batch at a time in ``_update_calculator``. Batches passed to ``update`` /
    ``calculate`` are fed into the same calculator; ``compute`` returns the
    result as a Python number and resets the calculator.
    """

    def __init__(self, config):
        super().__init__(config)
        self.calculator = self._create_calculator()
        self._device_set = False
        # Device the calculator was last moved to (None until first use).
        self._device = None

    @abstractmethod
    def _create_calculator(self):
        """Return the torchmetrics metric instance to wrap."""
        raise NotImplementedError

    @abstractmethod
    def _update_calculator(self, **kwargs):
        """Feed one batch, passed as ``x_rec=`` and ``data=``, into the calculator."""
        raise NotImplementedError

    def _ensure_device(self, device):
        # Move the calculator whenever a different device is requested. Tracking only a
        # "moved once" flag left the metric state on the first device and made any later
        # call with tensors on another device fail with a device-mismatch error.
        if not self._device_set or device != self._device:
            self.calculator.to(device)
            self._device = device
            self._device_set = True

    def _update_from_batch(self, batch, device):
        """Move ``batch['x_rec']`` and ``batch['data']`` to ``device`` and update the calculator."""
        x_rec_batch = batch['x_rec'].to(device)
        data_batch = batch['data'].to(device)
        self._update_calculator(x_rec=x_rec_batch, data=data_batch)

    def update(self, device, **kwargs):
        """Accumulate a single batch given as ``x_rec=`` / ``data=`` keyword tensors.

        Raises:
            ValueError: if ``x_rec`` or ``data`` is missing.
        """
        self._ensure_device(device)
        if 'x_rec' not in kwargs or 'data' not in kwargs:
            raise ValueError("_TorchmetricsWrapper.update() requires 'x_rec' and 'data'.")
        self._update_from_batch(kwargs, device)

    def compute(self):
        """Return the accumulated metric value (``.item()``) and reset the calculator."""
        result = self.calculator.compute().item()
        self.calculator.reset()
        return result

    def calculate(self, device, all_outputs=None, **kwargs):
        """Accumulate the given batch(es) and return the computed metric.

        Args:
            device: Device the calculator and batch tensors are moved to.
            all_outputs: Optional iterable of mappings, each with ``'x_rec'`` and
                ``'data'`` tensors (epoch-end use). Takes precedence over kwargs.
            **kwargs: ``x_rec`` and ``data`` tensors for a single batch (batch-end use).

        Raises:
            ValueError: if neither ``all_outputs`` nor both ``x_rec`` and ``data`` are given.
        """
        self._ensure_device(device)

        if all_outputs is not None:
            batches = all_outputs
        elif 'x_rec' in kwargs and 'data' in kwargs:
            batches = [kwargs]
        else:
            raise ValueError(
                "_TorchmetricsWrapper.calculate() requires either 'all_outputs' (for epoch-end) "
                "or both 'x_rec' and 'data' (for batch-end) to be provided in arguments."
            )

        for batch in batches:
            self._update_from_batch(batch, device)

        return self.compute()


def _discretize_continuous_variable(variable, num_bins):
    min_val, max_val = np.min(variable), np.max(variable)
    if np.allclose(min_val, max_val):
        return np.zeros_like(variable, dtype=np.int64)
    bins = np.linspace(start=min_val, stop=max_val, num=num_bins + 1)
    return np.digitize(x=variable, bins=bins, right=False)


def _calculate_mig(z_latents, labels, num_bins):
    """Mutual Information Gap of ``z_latents`` with respect to one discrete label factor.

    Each latent dimension is discretized into ``num_bins`` equal-width bins. Its
    mutual information with ``labels`` is estimated, and the gap between the
    largest and second-largest MI is divided by the label entropy H(labels).

    Args:
        z_latents: Tensor with ``.cpu().numpy()``; dims after the first are
            flattened into latent dimensions.
        labels: Tensor of discrete labels aligned with ``z_latents``' first dim.
        num_bins: Number of bins used to discretize each latent dimension.

    Returns:
        MIG as a Python float. Returns 0.0 when the labels have (near) zero
        entropy or the latent subspace has no dimensions.
    """
    z_np = z_latents.cpu().numpy()
    labels_np = labels.cpu().numpy()

    if z_np.ndim > 2:
        z_np = z_np.reshape(z_np.shape[0], -1)
    num_latents = z_np.shape[1]

    # mutual_info_score(y, y) == H(y), in the same unit (nats) as the MI scores below.
    entropy_labels = mutual_info_score(labels_true=labels_np, labels_pred=labels_np)
    if np.isclose(entropy_labels, 0) or num_latents == 0:
        # Constant labels make the normalization undefined, and an empty subspace has
        # no MI to rank (previously this raised IndexError on sorted_mi[0]).
        return 0.0

    mi_scores = np.array([
        mutual_info_score(
            labels_true=_discretize_continuous_variable(z_np[:, i], num_bins),
            labels_pred=labels_np,
        )
        for i in range(num_latents)
    ])

    sorted_mi = np.sort(mi_scores)[::-1]
    mi_max = sorted_mi[0]
    mi_second_max = sorted_mi[1] if num_latents > 1 else 0.0

    return float((mi_max - mi_second_max) / entropy_labels)


def _normalized_entropy(p, axis=1):
    eps = 1e-8
    p = np.clip(p, eps, 1.0)
    unnorm_entropy = entropy(p, axis=axis)
    log_k = np.log(p.shape[axis])
    return unnorm_entropy / log_k


def _fit_dci_predictors(z_np, ground_truth_factors, n_estimators, max_depth):
    """Fit one random forest per factor on ``z_np``; return (importance_matrix, informativeness).

    Factors whose name contains ``'target_id'`` are treated as categorical: they use a
    classifier, with a stratified split when more than one class is present. All other
    factors use a regressor. ``importance_matrix`` has shape (num_latents, num_factors)
    and holds each model's ``feature_importances_``. ``informativeness`` holds each
    model's held-out ``score`` (accuracy for classifiers, R^2 for regressors).
    """
    num_factors = len(ground_truth_factors)
    importance_matrix = np.zeros((z_np.shape[1], num_factors))
    informativeness_scores = np.zeros(num_factors)

    for i, (name, labels) in enumerate(ground_truth_factors.items()):
        is_categorical = 'target_id' in name
        stratify_labels = labels if is_categorical and len(np.unique(labels)) > 1 else None
        z_train, z_test, labels_train, labels_test = train_test_split(
            z_np, labels, test_size=0.2, random_state=42, stratify=stratify_labels
        )

        model_cls = RandomForestClassifier if is_categorical else RandomForestRegressor
        model = model_cls(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1, random_state=42)
        model.fit(z_train, labels_train)
        informativeness_scores[i] = model.score(z_test, labels_test)
        importance_matrix[:, i] = model.feature_importances_

    return importance_matrix, informativeness_scores


def _dci_specificity(probabilities, axis):
    """Return ``1 - normalized entropy`` of each slice of ``probabilities`` along ``axis``.

    A slice with a single outcome has no uncertainty and scores 1.0. This avoids the
    0 / log(1) = NaN that normalizing by log(1) would otherwise produce.
    """
    if probabilities.shape[axis] <= 1:
        return np.ones_like(probabilities.sum(axis=axis))
    return 1.0 - _normalized_entropy(probabilities, axis=axis)


def _calculate_dci(latents, ground_truth_factors, n_estimators, max_depth):
    """Compute DCI (disentanglement, completeness, informativeness) for ``latents``.

    Args:
        latents: Tensor with ``.cpu().numpy()``; dims after the first are flattened.
        ground_truth_factors: Ordered mapping of factor name -> per-sample labels.
        n_estimators, max_depth: Random-forest hyper-parameters.

    Returns:
        dict with float ``disentanglement``, ``completeness`` and ``informativeness``.

    Disentanglement follows Eastwood & Williams (2018). Each latent's specificity is
    weighted by that latent's share of the total importance, so unused latents do
    not drag the score towards zero. (Previously an unweighted mean was used, in
    which every inactive latent counted as 0.) Completeness is the mean over
    factors; a factor with no importance scores 0. Informativeness is the mean
    held-out model score.
    """
    z_np = latents.cpu().numpy()
    if z_np.ndim > 2:
        z_np = z_np.reshape(z_np.shape[0], -1)

    importance_matrix, informativeness_scores = _fit_dci_predictors(
        z_np, ground_truth_factors, n_estimators, max_depth
    )

    row_sums = importance_matrix.sum(axis=1)
    active_rows = row_sums > 1e-12
    disentanglement = 0.0
    if np.any(active_rows):
        active_sums = row_sums[active_rows]
        norm_rows = importance_matrix[active_rows] / active_sums[:, np.newaxis]
        weights = active_sums / active_sums.sum()
        disentanglement = float(np.sum(_dci_specificity(norm_rows, axis=1) * weights))

    col_sums = importance_matrix.sum(axis=0)
    active_cols = col_sums > 1e-12
    completeness_per_factor = np.zeros_like(col_sums)
    if np.any(active_cols):
        norm_cols = importance_matrix[:, active_cols] / col_sums[active_cols]
        completeness_per_factor[active_cols] = _dci_specificity(norm_cols, axis=0)
    completeness = float(np.mean(completeness_per_factor))

    informativeness = float(np.mean(informativeness_scores))

    return {
        'disentanglement': disentanglement,
        'completeness': completeness,
        'informativeness': informativeness
    }


def _calculate_sap(latents, ground_truth_factors):
    """Compute the SAP (Separated Attribute Predictability) score of ``latents``.

    For every factor, each latent dimension is scored on its own, after
    standardizing on the 80% training split:
      * factors whose name contains ``'target_id'``: held-out accuracy of a
        one-feature ``LinearSVC``;
      * other factors: held-out R^2 of a one-feature linear regression, floored at 0.
    A factor's SAP is the gap between its best and second-best latent score (the
    runner-up counts as 0.0 when there is a single latent). The result is the
    mean gap over all factors, returned as a float.
    """
    z = latents.cpu().numpy()
    if z.ndim > 2:
        z = z.reshape(z.shape[0], -1)

    _, n_latents = z.shape
    scores = np.zeros((n_latents, len(ground_truth_factors)))

    for col, (factor_name, y) in enumerate(ground_truth_factors.items()):
        categorical = 'target_id' in factor_name
        strata = y if categorical and len(np.unique(y)) > 1 else None
        z_fit, z_eval, y_fit, y_eval = train_test_split(
            z, y, test_size=0.2, random_state=42, stratify=strata
        )

        standardizer = StandardScaler()
        z_fit = standardizer.fit_transform(z_fit)
        z_eval = standardizer.transform(z_eval)

        for row in range(n_latents):
            feature_fit = z_fit[:, row:row + 1]
            feature_eval = z_eval[:, row:row + 1]
            if categorical:
                clf = LinearSVC(C=0.01, max_iter=1000, tol=1e-4, random_state=42, dual='auto')
                clf.fit(feature_fit, y_fit)
                scores[row, col] = clf.score(feature_eval, y_eval)
            else:
                reg = LinearRegression()
                reg.fit(feature_fit, y_fit)
                scores[row, col] = max(0.0, reg.score(feature_eval, y_eval))

    gaps = []
    for factor_scores in scores.T:
        ranked = np.sort(factor_scores)[::-1]
        runner_up = ranked[1] if n_latents > 1 else 0.0
        gaps.append(ranked[0] - runner_up)

    return float(np.mean(gaps))


def _calculate_linear_probe_accuracy(z_train, y_train, z_val, y_val):
    """Return the validation accuracy, in percent, of a linear SVM probe.

    Features are standardized with statistics fit on ``z_train`` only. The probe
    is a ``LinearSVC`` with a strong regularizer (C=0.01).
    """
    standardizer = StandardScaler().fit(z_train)
    svm = LinearSVC(C=0.01, max_iter=1000, tol=1e-4, random_state=42, dual='auto')
    svm.fit(standardizer.transform(z_train), y_train)
    accuracy = svm.score(standardizer.transform(z_val), y_val)
    return accuracy * 100


def _can_stratify_split(labels, test_size):
    """Return True when a stratified train/test split of ``labels`` is feasible.

    Stratifying with ``train_test_split`` raises ``ValueError`` in two cases: a class
    has fewer than two members, or either split would hold fewer samples than there
    are classes. This mirrors those checks, using n_test = ceil(test_size * n).
    """
    classes, counts = np.unique(labels, return_counts=True)
    if counts.size == 0 or counts.min() < 2:
        return False
    n_test = int(np.ceil(test_size * len(labels)))
    n_train = len(labels) - n_test
    return min(n_train, n_test) >= len(classes)


def calculate_probe_gap(z_full, y_targets, config):
    """Compare linear-probe accuracy on the target vs. non-target latent subspaces.

    Samples are split 70/30 into probe-train / probe-validation sets. The split is
    stratified by ``y_targets`` whenever that is feasible, and otherwise falls back
    to an unstratified random split instead of raising. A linear SVM probe is then
    fit on each subspace to predict ``y_targets``.

    Args:
        z_full: Latent tensor (``.cpu().numpy()`` is called on it). Columns are
            selected with ``config.model.latent_space.{target,nontarget}_slice_{start,stop}``;
            any trailing dimensions are flattened into features.
        y_targets: Class-label tensor aligned with ``z_full``'s first dimension.
        config: Configuration object providing the latent slice bounds.

    Returns:
        dict with ``probe_target`` and ``probe_nontarget`` accuracies (percent)
        and ``probe_gap`` = ``probe_target - probe_nontarget``.
    """
    z_np = z_full.cpu().numpy()
    y_np = y_targets.cpu().numpy()

    test_size = 0.3
    stratify = y_np if _can_stratify_split(y_np, test_size) else None
    indices = np.arange(len(z_np))
    train_indices, val_indices = train_test_split(
        indices, test_size=test_size, random_state=42, stratify=stratify
    )
    z_train, y_train = z_np[train_indices], y_np[train_indices]
    z_val, y_val = z_np[val_indices], y_np[val_indices]

    latent_cfg = config.model.latent_space
    target_slice = slice(latent_cfg.target_slice_start, latent_cfg.target_slice_stop)
    nontarget_slice = slice(latent_cfg.nontarget_slice_start, latent_cfg.nontarget_slice_stop)

    def _flatten_columns(z, sl):
        # Select the subspace columns and flatten any trailing dims into a feature vector.
        return z[:, sl].reshape(z.shape[0], -1)

    acc_target = _calculate_linear_probe_accuracy(
        _flatten_columns(z_train, target_slice), y_train,
        _flatten_columns(z_val, target_slice), y_val
    )
    acc_nontarget = _calculate_linear_probe_accuracy(
        _flatten_columns(z_train, nontarget_slice), y_train,
        _flatten_columns(z_val, nontarget_slice), y_val
    )
    return {
        'probe_target': acc_target,
        'probe_nontarget': acc_nontarget,
        'probe_gap': acc_target - acc_nontarget
    }


def _run_kmeans_and_get_labels(latents, n_clusters):
    """Cluster ``latents`` with k-means and return ``(features, cluster_labels)``.

    ``features`` is the numpy copy of ``latents``, with dims after the first
    flattened, exactly as it was passed to ``KMeans``. This lets callers reuse it,
    e.g. for silhouette scoring.
    """
    features = latents.cpu().numpy()
    if features.ndim > 2:
        features = features.reshape(len(features), -1)

    clusterer = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    return features, clusterer.fit_predict(features)


def _get_residual_subspace(config, z_full):
    latent_cfg = config.model.latent_space
    nontarget_slice = slice(latent_cfg.nontarget_slice_start, latent_cfg.nontarget_slice_stop)
    return z_full[:, nontarget_slice]


def _get_ground_truth_factors_for_z0(y_targets, y_style, style_feature_map):
    ground_truth_factors = {}

    if y_style is not None and style_feature_map is not None:
        style_np = y_style.cpu().numpy()
        for name, index in style_feature_map.items():
            ground_truth_factors[name] = style_np[:, index]

    ground_truth_factors['target_id_leakage'] = y_targets.cpu().numpy()

    return ground_truth_factors


@METRICS.register("ssim")
class SSIMMetric(_TorchmetricsWrapper):
    """Structural similarity between reconstructions (``x_rec``) and inputs (``data``).

    The SSIM ``data_range`` comes from ``config.evaluation.metric_settings.ssim_data_range``.
    """

    def _create_calculator(self):
        settings = self.config.evaluation.metric_settings
        return StructuralSimilarityIndexMeasure(data_range=settings.ssim_data_range)

    def _update_calculator(self, x_rec, data, **kwargs):
        # Reconstructions are the predictions; original inputs are the reference.
        self.calculator.update(preds=x_rec, target=data)


@METRICS.register("lpips")
class LPIPSMetric(_TorchmetricsWrapper):
    """LPIPS perceptual distance (VGG backbone) between reconstructions and inputs."""

    def _create_calculator(self):
        return LearnedPerceptualImagePatchSimilarity(net_type='vgg')

    def _preprocess_images(self, image_batch):
        """Clamp to [-1, 1] and replicate single-channel images to 3 channels.

        With its default ``normalize=False``, torchmetrics' LPIPS raises
        ``ValueError`` for inputs outside [-1, 1]. Reconstructions that overshoot
        slightly would otherwise abort the evaluation. Inputs already inside
        the range are unchanged by the clamp.
        """
        image_batch = torch.clamp(image_batch, -1.0, 1.0)
        if image_batch.shape[1] == 1:
            return image_batch.repeat(1, 3, 1, 1)
        return image_batch

    def _update_calculator(self, x_rec, data, **kwargs):
        x_rec_processed = self._preprocess_images(x_rec)
        data_processed = self._preprocess_images(data)
        self.calculator.update(img1=x_rec_processed, img2=data_processed)


@METRICS.register("fid")
class FIDMetric(_TorchmetricsWrapper):
    """Frechet Inception Distance (2048-d features) between inputs and reconstructions."""

    def _create_calculator(self):
        return FrechetInceptionDistance(feature=2048)

    def _preprocess_images_for_fid(self, image_batch):
        """Map images from [-1, 1] to uint8 [0, 255] with 3 channels.

        Values are clamped to [-1, 1] first. The float -> uint8 cast truncates.
        Single-channel batches are repeated along dim 1 to 3 channels.
        """
        scaled = (torch.clamp(image_batch, -1.0, 1.0) + 1.0) / 2.0 * 255
        as_bytes = scaled.to(torch.uint8)
        return as_bytes.repeat(1, 3, 1, 1) if as_bytes.shape[1] == 1 else as_bytes

    def _update_calculator(self, x_rec, data, **kwargs):
        # Original inputs form the "real" distribution; reconstructions the "fake" one.
        real_batch = self._preprocess_images_for_fid(data)
        fake_batch = self._preprocess_images_for_fid(x_rec)
        self.calculator.update(real_batch, real=True)
        self.calculator.update(fake_batch, real=False)


@METRICS.register("mig")
class MIGMetric(BaseMetric):
    """MIG of the target latent subspace with respect to ``y_targets``."""

    def calculate(self, z_full, y_targets, **kwargs):
        space = self.config.model.latent_space
        z_target = z_full[:, space.target_slice_start:space.target_slice_stop]
        bins = self.config.evaluation.metric_settings.mig_bins
        return _calculate_mig(z_target, y_targets, bins)


@METRICS.register("dci")
class DCIMetric(BaseMetric):
    """DCI scores of the residual (non-target) subspace against style and target factors."""

    def calculate(self, z_full, y_targets, y_style, style_feature_map, **kwargs):
        residual = _get_residual_subspace(self.config, z_full)
        factors = _get_ground_truth_factors_for_z0(y_targets, y_style, style_feature_map)
        settings = self.config.evaluation.metric_settings
        return _calculate_dci(
            latents=residual,
            ground_truth_factors=factors,
            n_estimators=settings.dci_estimators,
            max_depth=settings.dci_max_depth,
        )


@METRICS.register("sap")
class SAPMetric(BaseMetric):
    """SAP score of the residual (non-target) subspace against style and target factors."""

    def calculate(self, z_full, y_targets, y_style, style_feature_map, **kwargs):
        residual = _get_residual_subspace(self.config, z_full)
        factors = _get_ground_truth_factors_for_z0(y_targets, y_style, style_feature_map)
        sap = _calculate_sap(latents=residual, ground_truth_factors=factors)
        return sap


@METRICS.register("linear_probe")
class LinearProbeMetric(BaseMetric):
    """Linear-probe accuracies on target / non-target subspaces and their gap."""

    def calculate(self, z_full, y_targets, **kwargs):
        # Extra keyword arguments are accepted for interface compatibility and ignored.
        probe_results = calculate_probe_gap(z_full, y_targets, self.config)
        return probe_results


class _ClusteringMetric(BaseMetric, ABC):
    """Abstract base for clustering metrics.

    Subclasses choose which latent columns to cluster and how many clusters to fit.
    """

    @abstractmethod
    def _get_n_clusters(self):
        """Return the number of clusters to fit."""
        raise NotImplementedError

    @abstractmethod
    def _get_subspace(self, z_full):
        """Return the columns of ``z_full`` to cluster."""
        raise NotImplementedError


class _ARIMetric(_ClusteringMetric, ABC):
    """Adjusted Rand index between k-means clusters of a subspace and ``y_targets``."""

    def calculate(self, z_full, y_targets, **kwargs):
        features = self._get_subspace(z_full)
        k = self._get_n_clusters()
        truth = y_targets.cpu().numpy()

        predicted = _run_kmeans_and_get_labels(features, k)[1]
        return adjusted_rand_score(truth, predicted)


class _SilhouetteMetric(_ClusteringMetric, ABC):
    """Silhouette score of k-means clusters on a subspace (no ground-truth labels needed)."""

    def calculate(self, z_full, **kwargs):
        features = self._get_subspace(z_full)
        k = self._get_n_clusters()
        flat_features, assignments = _run_kmeans_and_get_labels(features, k)
        return silhouette_score(flat_features, assignments)


class _NMIMetric(_ClusteringMetric, ABC):
    """Normalized mutual information between k-means clusters of a subspace and ``y_targets``."""

    def calculate(self, z_full, y_targets, **kwargs):
        features = self._get_subspace(z_full)
        k = self._get_n_clusters()
        truth = y_targets.cpu().numpy()

        predicted = _run_kmeans_and_get_labels(features, k)[1]
        return normalized_mutual_info_score(truth, predicted)


@METRICS.register("identity_ari")
class IdentityARIMetric(_ARIMetric):
    """ARI on the target subspace, using ``config.data.num_classes`` clusters."""

    def _get_n_clusters(self):
        return self.config.data.num_classes

    def _get_subspace(self, z_full):
        space = self.config.model.latent_space
        start, stop = space.target_slice_start, space.target_slice_stop
        return z_full[:, start:stop]


@METRICS.register("identity_silhouette")
class IdentitySilhouetteMetric(_SilhouetteMetric):
    """Silhouette on the target subspace, using ``config.data.num_classes`` clusters."""

    def _get_n_clusters(self):
        return self.config.data.num_classes

    def _get_subspace(self, z_full):
        space = self.config.model.latent_space
        start, stop = space.target_slice_start, space.target_slice_stop
        return z_full[:, start:stop]


@METRICS.register("style_leakage_ari")
class StyleLeakageARIMetric(_ARIMetric):
    """ARI on the non-target subspace against ``y_targets`` (target-label leakage).

    Uses ``config.data.num_classes`` clusters.
    """

    def _get_n_clusters(self):
        return self.config.data.num_classes

    def _get_subspace(self, z_full):
        space = self.config.model.latent_space
        start, stop = space.nontarget_slice_start, space.nontarget_slice_stop
        return z_full[:, start:stop]


@METRICS.register("style_silhouette")
class StyleSilhouetteMetric(_SilhouetteMetric):
    """Silhouette on the non-target subspace.

    Uses ``config.evaluation.metric_settings.style_cluster_k`` clusters.
    """

    def _get_n_clusters(self):
        settings = self.config.evaluation.metric_settings
        return settings.style_cluster_k

    def _get_subspace(self, z_full):
        space = self.config.model.latent_space
        start, stop = space.nontarget_slice_start, space.nontarget_slice_stop
        return z_full[:, start:stop]


@METRICS.register("identity_nmi")
class IdentityNMIMetric(_NMIMetric):
    """NMI on the target subspace, using ``config.data.num_classes`` clusters."""

    def _get_n_clusters(self):
        return self.config.data.num_classes

    def _get_subspace(self, z_full):
        space = self.config.model.latent_space
        start, stop = space.target_slice_start, space.target_slice_stop
        return z_full[:, start:stop]


@METRICS.register("style_leakage_nmi")
class StyleLeakageNMIMetric(_NMIMetric):
    """NMI on the non-target subspace against ``y_targets`` (target-label leakage).

    Uses ``config.data.num_classes`` clusters.
    """

    def _get_n_clusters(self):
        return self.config.data.num_classes

    def _get_subspace(self, z_full):
        space = self.config.model.latent_space
        start, stop = space.nontarget_slice_start, space.nontarget_slice_stop
        return z_full[:, start:stop]