import numpy as np
import torch

# Anomaly score based only on the raw data (the generated samples are ignored)
def anomalie_score_mean_raw_fn():
        """Build a scorer that uses only the raw input ``x``.

        Returns:
            A callable ``(generated_samples, x) -> (score, squared)`` where
            ``squared`` is ``x ** 2`` as a tensor and ``score`` is its mean
            over the last axis.
        """
        def anomalie_score_mean_raw(generated_samples, x):
                """Score ``x`` by the mean of its squared values (last axis).

                ``generated_samples`` is accepted only so this scorer shares
                the signature of the other scorers; it is not used.
                """
                # torch.tensor(<tensor>) warns and makes a redundant copy;
                # detaching gives the same no-grad, same-device, same-dtype
                # result. Non-tensor inputs (lists, ndarrays) are converted
                # exactly as before.
                if torch.is_tensor(x):
                        x_t = x.detach()
                else:
                        x_t = torch.tensor(x)
                squared_diff = x_t ** 2
                mse_mean = squared_diff.mean(-1)
                return mse_mean, squared_diff
        return anomalie_score_mean_raw

# Reconstruction-based score: mean squared error between generated samples and the input
def anomalie_score_mean_fn():
        """Build a reconstruction-error scorer based on mean squared error.

        Returns:
            A callable ``(generated_samples, x) -> (score, squared_diff)``.
        """
        def anomalie_score_mean(generated_samples, x):
                """Score by the squared reconstruction error averaged over the last axis.

                Returns:
                    A pair ``(mean of squared error over the last axis,
                    elementwise squared error)``.
                """
                residual = generated_samples - x
                per_element = residual ** 2
                return per_element.mean(-1), per_element
        return anomalie_score_mean

def anomalie_score_mean_nodiff_fn(*, verbose=False):
        """Build a scorer that averages ``generated_samples`` directly.

        No difference against ``x`` is taken. This presumably expects
        ``generated_samples`` to already hold a per-element error
        (TODO confirm with callers).

        Args:
            verbose: If True, print the shape of the score on every call
                (this was previously unconditional debug output).

        Returns:
            A callable ``(generated_samples, x) -> (score, generated_samples)``
            where ``score`` is the mean over the last axis.
        """
        def anomalie_score_mean_nodiff(generated_samples, x):
                """Return ``(generated_samples.mean(-1), generated_samples)``; ``x`` is unused."""
                mse_mean = generated_samples.mean(-1)
                if verbose:
                        print("MSE", mse_mean.shape)
                return mse_mean, generated_samples
        return anomalie_score_mean_nodiff

def anomalie_score_sum_fn(*, verbose=False):
        """Build a reconstruction-error scorer based on the sum of squared errors.

        Args:
            verbose: If True, print the first score on every call (this was
                previously unconditional debug output).

        Returns:
            A callable ``(generated_samples, x) -> (score, squared_diff)``
            where ``score`` is ``squared_diff`` summed over the last axis.
        """
        def anomalie_score_add(generated_samples, x):
                """Score by the squared reconstruction error summed over the last axis.

                Returns:
                    A pair ``(sum of squared error over the last axis,
                    elementwise squared error)``.
                """
                squared_diff = (generated_samples - x) ** 2
                sse = squared_diff.sum(-1)
                if verbose:
                        # Indexing [0] fails for a 0-d result (1-D inputs)
                        # or an empty batch, so print the whole value then.
                        first = sse[0] if sse.ndim > 0 and len(sse) > 0 else sse
                        print("MSE", first)
                return sse, squared_diff
        return anomalie_score_add