"""
Evaluation module for the FeatureInversionPipeline.

Each module receives the pair of (generated, target) and returns the evaluation results as:
dict[str, list[float]]: metric_name -> list[float], element i corresponds to the i-th stimulus
"""
import numpy as np
import torch
# Small constant added to denominators to avoid division by zero.
EPSILON = 1e-8

def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Calculate batch-wise cosine similarity between two tensors with arbitrary feature shapes.

    All dimensions after the first are flattened into a single feature axis and
    the similarity is computed independently for every batch element.

    Args:
        x (torch.Tensor): Tensor of shape (batch_size, ...).
        y (torch.Tensor): Tensor of shape (batch_size, ...).
    Returns:
        torch.Tensor: Cosine similarity of shape (batch_size,).
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted
    # or sliced feature maps; it only copies the data when it has to.
    x_flat = x.reshape(x.size(0), -1)
    y_flat = y.reshape(y.size(0), -1)
    x_norm = torch.norm(x_flat, dim=1)
    y_norm = torch.norm(y_flat, dim=1)
    dot_product = torch.sum(x_flat * y_flat, dim=1)
    # EPSILON keeps the denominator non-zero when either input is all zeros
    return dot_product / (x_norm * y_norm + EPSILON)


def correlation(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Calculate batch-wise Pearson correlation coefficient between two tensors with arbitrary feature shapes.

    All dimensions after the first are flattened into a single feature axis; each
    row is mean-centered and the correlation is computed per batch element.

    Args:
        x (torch.Tensor): Tensor of shape (batch_size, ...).
        y (torch.Tensor): Tensor of shape (batch_size, ...).
    Returns:
        torch.Tensor: Pearson correlation coefficient of shape (batch_size,).
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as permuted
    # or sliced feature maps; it only copies the data when it has to.
    x_flat = x.reshape(x.size(0), -1)
    y_flat = y.reshape(y.size(0), -1)
    x_centered = x_flat - x_flat.mean(dim=1, keepdim=True)
    y_centered = y_flat - y_flat.mean(dim=1, keepdim=True)
    cov = torch.sum(x_centered * y_centered, dim=1)
    # Norms of the centered rows; the 1/n factors of covariance and standard
    # deviation cancel in the ratio, so they are omitted.
    x_std = torch.norm(x_centered, dim=1)
    y_std = torch.norm(y_centered, dim=1)
    # EPSILON keeps the denominator non-zero when a row is constant
    return cov / (x_std * y_std + EPSILON)


def cosine_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute the cosine distance (1 - cosine similarity) for each batch element.

    Args:
        x (torch.Tensor): Tensor of shape (batch_size, dim*).
        y (torch.Tensor): Tensor of shape (batch_size, dim*).

    Returns:
        torch.Tensor: Cosine distance of shape (batch_size,).
    """
    similarity = cosine_similarity(x, y)
    return 1 - similarity


def correlation_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute the correlation distance (1 - Pearson correlation) for each batch element.

    Args:
        x (torch.Tensor): Tensor of shape (batch_size, dim*).
        y (torch.Tensor): Tensor of shape (batch_size, dim*).

    Returns:
        torch.Tensor: Correlation distance of shape (batch_size,).
    """
    pearson = correlation(x, y)
    return 1 - pearson


def l2_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Calculate batch-wise L2 distance between two tensors.

    The difference ``x - y`` is reduced over every dimension except the first,
    so one distance is returned per batch element.

    Args:
        x (torch.Tensor): Tensor of shape (batch_size, dim1, dim2,...).
        y (torch.Tensor): Tensor of shape (batch_size, dim1, dim2,...).

    Returns:
        torch.Tensor: L2 distance of shape (batch_size,).
    """
    diff = x - y
    if diff.ndim == 1:
        # Without a feature axis the reduction dims below would be an empty
        # tuple, which does not produce one distance per sample
        # (NOTE(review): torch appears to reduce over all elements for an
        # empty dim -- confirm for the torch version in use).
        # A singleton feature axis keeps the result at shape (batch_size,).
        diff = diff.unsqueeze(1)
    # all dimensions except the batch dimension, taken from the (possibly
    # broadcast) difference rather than from x alone
    feature_dims = tuple(range(1, diff.ndim))
    return torch.linalg.vector_norm(diff, dim=feature_dims)


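# The legacy feature inversion pipeline passes both the reconstructed features and the
# target features to each metric. To avoid confusion, the distance metrics below ignore
# the target-features argument and compare the reconstructed features against the
# reference features supplied at construction time.
# TODO: This should be fixed in the future.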
class FeatureCosineDistance:
    """
    Metric reporting, for every reference layer, the cosine distance between
    the generated features and a fixed set of reference features.

    Args:
        reference_features (dict[str, torch.Tensor]): Reference features indexed by layer name.
            The layers listed here determine which metrics are produced.
        prefix (str, optional): String prepended to every metric name.
    """
    def __init__(self, reference_features: dict[str, torch.Tensor], prefix: str = ''):
        self.prefix = prefix
        self.reference_features = reference_features

    def __call__(self, features: dict[str, torch.Tensor], _: dict[str, torch.Tensor]):
        # metric_name -> list[float]; the second positional argument is
        # accepted but never read.
        return {
            self.prefix + f'feature_cosine_distance_{layer}': cosine_distance(
                features[layer], reference
            ).tolist()
            for layer, reference in self.reference_features.items()
        }


class FeatureCorrelationDistance:
    """
    Metric reporting, for every reference layer, the correlation distance between
    the generated features and a fixed set of reference features.

    Args:
        reference_features (dict[str, torch.Tensor]): Reference features indexed by layer name.
            The layers listed here determine which metrics are produced.
        prefix (str, optional): String prepended to every metric name.
    """
    def __init__(self, reference_features: dict[str, torch.Tensor], prefix: str = ''):
        self.prefix = prefix
        self.reference_features = reference_features

    def __call__(self, features: dict[str, torch.Tensor], _: dict[str, torch.Tensor]):
        # metric_name -> list[float]; the second positional argument is
        # accepted but never read.
        return {
            self.prefix + f'feature_correlation_distance_{layer}': correlation_distance(
                features[layer], reference
            ).tolist()
            for layer, reference in self.reference_features.items()
        }


def feature_cosine_similarity(
        features: dict[str, torch.Tensor],
        target_features: dict[str, torch.Tensor],
):
    """Calculate cosine similarity between the generated and target features for each sample.

    Args:
        features (dict[str, torch.Tensor]): Generated features indexed by the layer names.
            Each feature tensor has the shape (batch_size, *).
        target_features (dict[str, torch.Tensor]): Target features indexed by the layer names.
            Must contain every layer present in ``features``.

    Returns:
        dict[str, list[float]]: Element i of every list corresponds to the i-th sample. Keys are:
            - feature_cosine_similarity_{layer_name}: Cosine similarity for each layer and each image sample.
            - feature_cosine_similarity_mean: Mean cosine similarity across layers for each image sample.
              This list is empty when ``features`` is empty.
    """
    # metric_name -> list[float]
    sims = {}
    for layer_name, feature in features.items():
        target_feature = target_features[layer_name]
        layer_sim = torch.nn.functional.cosine_similarity(
            feature.flatten(start_dim=1),
            target_feature.flatten(start_dim=1),
            dim=1,
        )
        sims[f'feature_cosine_similarity_{layer_name}'] = layer_sim.tolist()

    # Snapshot the per-layer lists before the mean entry is added, and take the
    # sample count from them so an empty `features` dict does not crash.
    per_layer = list(sims.values())
    n_samples = len(per_layer[0]) if per_layer else 0

    # mean cosine similarity across layers; float() turns np.float64 into a
    # plain Python float so every list in the result holds the same type
    sims['feature_cosine_similarity_mean'] = [
        float(np.mean([values[i] for values in per_layer])) for i in range(n_samples)
    ]
    return sims


def batch_correlation(A, B):
    """
    Row-wise Pearson correlation coefficient between two 2-D tensors.

    Args:
        A: Tensor of shape (batch_size, dim)
        B: Tensor of shape (batch_size, dim)

    Returns:
        Tensor of shape (batch_size) whose i-th entry is the correlation
        between row i of A and row i of B.
    """
    def _center(rows):
        # subtract each row's own mean
        return rows - rows.mean(dim=1, keepdim=True)

    a_dev = _center(A)
    b_dev = _center(B)

    numerator = (a_dev * b_dev).sum(dim=1)
    # root of the summed squared deviations for each row
    a_scale = a_dev.pow(2).sum(dim=1).sqrt()
    b_scale = b_dev.pow(2).sum(dim=1).sqrt()

    # Add small value to avoid division by zero
    return numerator / (a_scale * b_scale + 1e-8)


def feature_correlation(
        features: dict[str, torch.Tensor],
        target_features: dict[str, torch.Tensor],
):
    """Calculate correlation between the generated and target features for each sample.

    Args:
        features (dict[str, torch.Tensor]): Generated features indexed by the layer names.
            Each feature tensor has the shape (batch_size, *).
        target_features (dict[str, torch.Tensor]): Target features indexed by the layer names.
            Must contain every layer present in ``features``.

    Returns:
        dict[str, list[float]]: Element i of every list corresponds to the i-th sample. Keys are:
            - feature_correlation_{layer_name}: Correlation for each layer and each sample.
            - feature_correlation_mean: Mean correlation across layers for each sample.
              This list is empty when ``features`` is empty.
    """
    sims = {}  # metric_name -> list[float]
    for layer_name, feature in features.items():
        flat_feature = feature.flatten(start_dim=1)
        flat_target = target_features[layer_name].flatten(start_dim=1)
        corr = batch_correlation(flat_feature, flat_target)
        sims[f'feature_correlation_{layer_name}'] = corr.tolist()

    # Take the sample count from the collected results instead of from the
    # loop variable, which is unbound when `features` is empty.
    per_layer = list(sims.values())
    n_samples = len(per_layer[0]) if per_layer else 0

    # mean across layers; float() turns np.float64 into a plain Python float
    sims['feature_correlation_mean'] = [
        float(np.mean([values[i] for values in per_layer])) for i in range(n_samples)
    ]
    return sims


def feature_mse(features: dict[str, torch.Tensor], target_features: dict[str, torch.Tensor]):
    """
    Evaluation metric that calculate the mean squared error between the features and the target features.

    Args:
        features (dict[str, torch.Tensor]): Generated features indexed by the layer names.
            Each feature tensor has the shape (batch_size, *).
        target_features (dict[str, torch.Tensor]): Target features indexed by the layer names.
            Must contain every layer present in ``features``.

    Returns:
        dict[str, list[float]]: Element i of every list corresponds to the i-th sample. Keys are:
            - feature_mse_{layer_name}: Mean squared error for each layer and each sample.
            - feature_mse: Mean of the layer-wise errors for each sample.
              This list is empty when ``features`` is empty.
    """
    mses = {}  # metric_name -> list[float]
    # `feature` (singular) so the loop variable does not shadow the argument
    for layer_name, feature in features.items():
        target_feature = target_features[layer_name]
        squared_error = torch.nn.functional.mse_loss(feature, target_feature, reduction='none')
        # shape (batch_size, *feature_size) -> (batch_size, feature_size)
        squared_error = squared_error.reshape(squared_error.shape[0], -1)

        # layer wise mse: (batch_size,)
        mses[f'feature_mse_{layer_name}'] = squared_error.mean(dim=1).tolist()

    # Snapshot the per-layer lists before the aggregate entry is added, and take
    # the sample count from them so an empty `features` dict does not crash.
    per_layer = list(mses.values())
    n_samples = len(per_layer[0]) if per_layer else 0

    # average the layer-wise errors for every sample
    mses['feature_mse'] = [
        torch.tensor([values[i] for values in per_layer]).mean().item()
        for i in range(n_samples)
    ]
    return mses


class FeatureCosineSimilarity:
    """
    Evaluation metric that calculate the cosine similarity between the generated features
    and the target features.

    Args:
        override_features (dict[str, torch.Tensor], optional): Override the target features with the given features.
        prefix (str, optional): Prefix for the metric names.
    """
    def __init__(
            self,
            override_features: dict[str, torch.Tensor] | None = None,
            prefix: str = '',
        ):
        self.override_features = override_features
        self.prefix = prefix

    def __call__(
            self,
            features: dict[str, torch.Tensor],
            target_features: dict[str, torch.Tensor],
        ):
        """
        Calculate the cosine similarity between the generated features and the target features.

        Args:
            features (dict[str, torch.Tensor]): Features of generated image.
            target_features (dict[str, torch.Tensor]): Target features.
        """
        if self.override_features is not None:
            target_features = self.override_features
        similarities = feature_cosine_similarity(features, target_features)
        # add prefix to the keys
        similarities = {self.prefix + k : v for k, v in similarities.items()}
        return similarities
    

class FeatureCorrelation:
    """
    Evaluation metric that calculate the correlation between the generated features
    and the target features.
    """
    def __init__(
            self,
            override_features: dict[str, torch.Tensor] | None = None,
            prefix: str = '',
        ):
        self.override_features = override_features
        self.prefix = prefix

    def __call__(
            self,
            features: dict[str, torch.Tensor],
            target_features: dict[str, torch.Tensor],
        ):
        """
        Calculate the correlation between the generated features and the target features.

        Args:
            features (dict[str, torch.Tensor]): Features of generated image.
            target_features (dict[str, torch.Tensor]): Target features.
        """
        if self.override_features is not None:
            target_features = self.override_features
        similarities = feature_correlation(features, target_features)
        similarities = {self.prefix + k: v for k, v in similarities.items()}
        return similarities
    

class FeatureMSE:
    """
    Evaluation metric that calculate the mean squared error between the generated features
    and the target features.
    """
    def __init__(
            self,
            override_features: dict[str, torch.Tensor] | None = None,
            prefix: str = '',
        ):
        self.override_features = override_features
        self.prefix = prefix

    def __call__(
            self,
            features: dict[str, torch.Tensor],
            target_features: dict[str, torch.Tensor],
        ):
        """
        Calculate the mean squared error between the generated features and the target features.

        Args:
            features (dict[str, torch.Tensor]): Features of generated image.
            target_features (dict[str, torch.Tensor]): Target features.
        """
        if self.override_features is not None:
            target_features = self.override_features
        mses = feature_mse(features, target_features)
        mses = {self.prefix + k: v for k, v in mses.items()}
        return mses


class TrueFeatureCosineSimilarity:
    """
    Callable metric measuring the cosine similarity between the reconstructed
    features and the true (noise-free) features.
    """
    def __init__(self, true_features: dict[str, torch.Tensor]):
        """
        Args:
            true_features (dict[str, torch.Tensor]): True features without noise.
        """
        self.true_features = true_features

    def __call__(
            self,
            features: dict[str, torch.Tensor],
            target_features: dict[str, torch.Tensor],
    ):
        """
        Compare the reconstructed features against the stored true features.
        The ``target_features`` argument (true + noise) is not used.
        """
        raw = feature_cosine_similarity(features, self.true_features)
        # rename the keys: prefix "true_" to the keys
        return {f'true_{name}': values for name, values in raw.items()}


class TrueFeatureCorrelation:
    """
    Callable metric measuring the correlation between the reconstructed
    features and the true (noise-free) features.
    """
    def __init__(self, true_features: dict[str, torch.Tensor]):
        """
        Args:
            true_features (dict[str, torch.Tensor]): True features without noise.
        """
        self.true_features = true_features

    def __call__(
            self,
            features: dict[str, torch.Tensor],
            target_features: dict[str, torch.Tensor],
    ):
        """
        Compare the reconstructed features against the stored true features.
        The ``target_features`` argument (true + noise) is not used.
        """
        raw = feature_correlation(features, self.true_features)
        # rename the keys: prefix "true_" to the keys
        return {f'true_{name}': values for name, values in raw.items()}


class TrueFeatureMSE:
    """
    Callable metric measuring the mean squared error between the reconstructed
    features and the true (noise-free) features.
    """
    def __init__(self, true_features: dict[str, torch.Tensor]):
        """
        Args:
            true_features (dict[str, torch.Tensor]): True features without noise.
        """
        self.true_features = true_features

    def __call__(self, features: dict[str, torch.Tensor], target_features: dict[str, torch.Tensor]):
        """
        Compare the reconstructed features against the stored true features.
        The ``target_features`` argument is not used.
        """
        raw = feature_mse(features, self.true_features)
        # rename the keys: prefix "true_" to the keys
        return {f'true_{name}': values for name, values in raw.items()}


class PixelCorrelation:
    """
    Evaluation metric for pixel correlation.

    For every image, the Pearson correlation between the flattened true image
    and the flattened generated image is reported.
    """
    def __init__(self, true_images: torch.Tensor, domain=None):
        """
        Args:
            true_images (torch.Tensor): true images in tensor with shape (batch_size, 3, H, W), values in [0, 1],
                which is the common domain.
            domain (optional): Object whose ``receive`` method converts the generated images
                to the common domain. If None, the generated images are used as given.
        """
        self.true_images = true_images
        self.domain = domain

    def __call__(self, generated_images):
        """
        Args:
            generated_images (torch.Tensor): Generated images, one per true image.

        Returns:
            dict[str, list[float]]: ``{'pixel_correlation': [...]}`` with one value per true image.
        """
        batch_size = self.true_images.shape[0]
        if self.domain is not None:
            # convert the generated images to the common domain
            generated_images = self.domain.receive(generated_images)

        corrs = []  # list[float], correlation for each image
        for i in range(batch_size):
            # detach first: .numpy() raises on tensors that require grad
            tru_img = self.true_images[i].detach().cpu().numpy()
            gen_img = generated_images[i].detach().cpu().numpy()
            corr = np.corrcoef(tru_img.flatten(), gen_img.flatten())[0, 1]
            # plain Python float rather than np.float64
            corrs.append(float(corr))
        return {'pixel_correlation': corrs}


class PixelCosineSimilarity:
    """
    Evaluation metric for pixel cosine similarity.

    For every image, the cosine similarity between the flattened true image
    and the flattened generated image is reported.
    """
    def __init__(self, true_images: torch.Tensor, domain=None):
        """
        Args:
            true_images (torch.Tensor): true images in tensor with shape (batch_size, 3, H, W), values in [0, 1],
                which is the common domain.
            domain (optional): Object whose ``receive`` method converts the generated images
                to the common domain. If None, the generated images are used as given.
        """
        self.true_images = true_images
        self.domain = domain

    def __call__(self, generated_images):
        """
        Args:
            generated_images (torch.Tensor): Generated images, one per true image.

        Returns:
            dict[str, list[float]]: ``{'pixel_cosine_similarity': [...]}`` with one value per image.
        """
        if self.domain is not None:
            # convert the generated images to the common domain
            generated_images = self.domain.receive(generated_images)

        # flatten the images; the true images follow the generated images' device
        true_flat = self.true_images.reshape(self.true_images.shape[0], -1).to(generated_images.device)
        generated_flat = generated_images.reshape(generated_images.shape[0], -1)

        # calculate the cosine similarity for each sample
        cos = torch.nn.functional.cosine_similarity(true_flat, generated_flat, dim=1)

        # detach so the conversion also works when the generated images require grad
        return {'pixel_cosine_similarity': cos.detach().cpu().tolist()}


class PixelMSE:
    """
    Evaluation metric for pixel mean squared error.

    Both image batches are rescaled from [0, 1] to [0, 255] before the error
    is computed, and one value per image is reported.
    """
    def __init__(self, true_images: torch.Tensor, domain=None):
        """
        Args:
            true_images (torch.Tensor): true images in tensor with shape (batch_size, 3, H, W), values in [0, 1],
                which is the common domain.
            domain (optional): Object whose ``receive`` method converts the generated images
                to the common domain. If None, the generated images are used as given.
        """
        self.true_images = true_images
        self.domain = domain

    def __call__(self, generated_images):
        """
        Args:
            generated_images (torch.Tensor): Generated images with the same shape as the true images.

        Returns:
            dict[str, list[float]]: ``{'pixel_mse': [...]}`` with one value per image.
        """
        if self.domain is not None:
            # convert the generated images to the common domain
            generated_images = self.domain.receive(generated_images)

        # turn [0, 1] to [0, 255]; the true images are moved to the generated
        # images' device so the comparison does not fail on a device mismatch
        true_images = self.true_images.to(generated_images.device) * 255
        generated_images = generated_images * 255

        squared_error = torch.nn.functional.mse_loss(
            true_images,
            generated_images,
            reduction='none',
        )
        # average over channel, height and width -> one value per image
        per_image = squared_error.mean(dim=(1, 2, 3))
        # detach so the conversion also works when the generated images require grad
        return {'pixel_mse': per_image.detach().cpu().tolist()}


class ThresholdStopCriteria:
    """
    Stop criterion that fires once every sample's metric value has reached a threshold.

    Args:
        metric_name (str): Key of the metric to inspect in the step metrics.
        threshold (float): Value the metric is compared against (inclusive).
        mode (str): Direction of the comparison.
            'larger': the criterion is met when all values are >= threshold.
            'smaller': the criterion is met when all values are <= threshold.
    """
    def __init__(
            self,
            metric_name: str,
            threshold: float,
            mode: str = 'larger',
    ):
        self.mode = mode
        self.threshold = threshold
        self.metric_name = metric_name

    def __call__(self, step_metrics: dict[str, list[float]]) -> bool:
        """
        Decide whether optimization may stop.

        Args:
            step_metrics (dict[str, list[float]]): Metrics for each sample.
                metric_name -> list[float], element i corresponds to the i-th stimulus.

        Returns:
            bool: True when every value of the watched metric satisfies the threshold.

        Raises:
            ValueError: If ``mode`` is neither 'larger' nor 'smaller'.
        """
        # Look the metric up first so a missing metric is reported before an
        # invalid mode.
        scores = step_metrics[self.metric_name]
        limit = self.threshold

        larger_is_better = self.mode == 'larger'
        if not larger_is_better and self.mode != 'smaller':
            raise ValueError(f"Unknown mode: {self.mode}")

        if larger_is_better:
            return all(score >= limit for score in scores)
        return all(score <= limit for score in scores)