import pandas as pd
import numpy as np
from scipy.spatial.distance import mahalanobis
from sentence_transformers import SentenceTransformer
import scipy.linalg
import warnings

# Load the Sentence-BERT model once, at module import time. main() uses this
# module-level instance to embed both the discovered and the ground-truth
# intent names.
model = SentenceTransformer('all-MiniLM-L6-v2')

def calculate_mean_covariance(embeddings):
    """Summarise a set of embeddings by their first two moments.

    For a 2-D input, each row is treated as one observation and each column
    as one dimension (``rowvar=False``).

    Returns:
        A ``(mean_vector, covariance_matrix)`` tuple: the column-wise mean and
        the sample covariance matrix as computed by ``np.cov`` (default
        normalisation by N - 1).
    """
    return (
        np.mean(embeddings, axis=0),
        np.cov(embeddings, rowvar=False),
    )

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Fréchet distance between two Gaussians.

    For N(mu1, sigma1) and N(mu2, sigma2) this returns

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))

    i.e. the squared Fréchet distance, which is the value conventionally
    reported as an "FID" score. Here it is applied to whatever embedding
    statistics the caller passes in.

    Params:
    -- mu1   : Mean vector of the first set of embeddings.
    -- sigma1: Covariance matrix of the first set of embeddings.
    -- mu2   : Mean vector of the second set of embeddings; must have the
               same shape as mu1.
    -- sigma2: Covariance matrix of the second set of embeddings; must have
               the same shape as sigma1.
    -- eps   : Value added to the diagonal of both covariance matrices when
               the matrix square root of their product is not finite.
               Defaults to 1e-6.

    Returns:
    --   : The squared Fréchet distance as a NumPy scalar.

    Raises:
    -- AssertionError: if the mean vectors or covariance matrices differ in shape.
    -- ValueError: if the matrix square root keeps a non-negligible imaginary
       component on its diagonal.
    """
    # Promote scalars so 1-D statistics (a single embedding dimension) also work.
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
    assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"

    diff = mu1 - mu2

    # If sqrtm yields non-finite entries (e.g. the product is (near-)singular),
    # retry with a small ridge `eps` added to both diagonals.
    covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if not np.isfinite(covmean).all():
        msg = f'fid calculation produces singular product; adding {eps} to diagonal of cov estimates'
        warnings.warn(msg)
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give a slight imaginary component; drop it only
    # when the diagonal (the part that enters the trace) is essentially real.
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError(f'Imaginary component {m}')
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)

def _read_lines(path):
    """Return the stripped, non-empty lines of the UTF-8 text file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        # Skip blank lines: an empty string would otherwise be embedded as if
        # it were a real intent name and skew the statistics.
        return [line.strip() for line in f if line.strip()]


def main(dataset='banking'):
    """Score discovered intents against ground-truth intents with an SBERT-based FID.

    Reads three newline-separated text files from the working directory:
    'gt_intents.txt' (ground-truth intents), 'discovered_categories.txt'
    (discovered intents), and the dataset's known intents
    ('known_intents_clinc' when ``dataset == 'clinc'``, otherwise
    'known_intents_banking'). Known intents are removed from both lists so
    only the intents that actually had to be discovered are compared. Both
    remaining lists are embedded with the module-level ``model``, and the
    Fréchet distance between the two embedding distributions is printed.

    Params:
    -- dataset: 'clinc' selects the CLINC known-intents file; any other value
                selects the Banking one. Defaults to 'banking'.

    Raises:
    -- ValueError: if fewer than two intents remain in either list, since no
       covariance matrix can be estimated from them.
    """
    # Only the known-intents file differs between datasets.
    known_intents_path = 'known_intents_clinc' if dataset == 'clinc' else 'known_intents_banking'

    gt_intents = _read_lines('gt_intents.txt')
    discovered_intents = _read_lines('discovered_categories.txt')
    # A set makes the membership filters below O(1) per lookup.
    known_intents = set(_read_lines(known_intents_path))

    # Remove known intents so only newly discovered categories are scored.
    discovered_intents = [x for x in discovered_intents if x not in known_intents]
    real_intents_to_discover = [x for x in gt_intents if x not in known_intents]

    # np.cov needs at least two observations; fail clearly instead of
    # propagating NaN covariances into the matrix square root.
    for label, intents in (('discovered', discovered_intents),
                           ('ground-truth', real_intents_to_discover)):
        if len(intents) < 2:
            raise ValueError(
                f'need at least two {label} intents after removing known intents; '
                f'got {len(intents)}'
            )

    # Compute embeddings
    pred_embeddings = model.encode(discovered_intents, convert_to_numpy=True)
    true_embeddings = model.encode(real_intents_to_discover, convert_to_numpy=True)

    # Calculate mean and covariance for both sets of embeddings
    mean_pred, cov_pred = calculate_mean_covariance(pred_embeddings)
    mean_true, cov_true = calculate_mean_covariance(true_embeddings)

    # Compute the Fréchet Distance
    fid_score = frechet_distance(mean_pred, cov_pred, mean_true, cov_true)
    print("FID Score:", fid_score)


if __name__ == "__main__":
    main()
