# Copyright (c) 2024, [NAME] [NAME]. All rights reserved.

# This work is licensed under APACHE LICENSE, VERSION 2.0
# You should have received a copy of the license along with this
# work. If not, see [URL]

# This file has been modified from the original located at:
# [URL]

# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import scipy.linalg
from . import sid_metric_utils as metric_utils


from functools import partial

import dnnlib



#----------------------------------------------------------------------------

def compute_fid_and_clip(opts, max_real, num_gen, batch_size=64, compute_clip=False):
    """Compute the FID of a generator against a dataset, optionally with CLIP scores.

    Real-data statistics come from ``opts.data_stat`` when it is provided
    (a mapping with ``'mu'`` and ``'sigma'`` entries). Otherwise they are
    computed with ``metric_utils.compute_feature_stats_for_dataset``. In the
    latter case, when ``opts.run_dir`` is set, rank 0 caches them to
    ``<run_dir>/<dataset class_name>.pkl`` on a best-effort basis.

    Args:
        opts: Metric options object. Reads ``batch_size`` (optional),
            ``data_stat``, ``run_dir``, ``rank``, ``dataset_kwargs.class_name``,
            and, only when ``compute_clip`` is true, ``clip_score_fn`` and
            ``metric_open_clip_path``.
        max_real: Maximum number of real items used for the dataset statistics.
        num_gen: Number of generated items used for the generator statistics.
        batch_size: Feature-extraction batch size; overridden by
            ``opts.batch_size`` when that attribute exists and is not None.
        compute_clip: Whether to also compute the open-CLIP and CLIP scores.

    Returns:
        On rank 0: ``fid`` (float) when ``compute_clip`` is false, otherwise
        ``(fid, open_clip_score, clip_score)`` as floats. On any other rank
        the same shape is returned with every value set to NaN.
    """
    # Direct TorchScript translation of [URL]
    detector_url = '[URL]'
    detector_kwargs = dict(return_features=True)  # Return raw features before the softmax layer.

    if getattr(opts, 'batch_size', None) is not None:
        batch_size = opts.batch_size

    if opts.data_stat is not None:
        # Use the precomputed dataset stats to save computation time.
        mu_real = opts.data_stat['mu']
        sigma_real = opts.data_stat['sigma']
    else:
        mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset(
            opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
            rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real,
            batch_size=batch_size).get_mean_cov()
        print('mu sigma computation finished')

        # Only one process writes the cache file.
        if opts.run_dir is not None and opts.rank == 0:
            class_name = opts.dataset_kwargs.class_name
            local_musigma_path = os.path.join(opts.run_dir, f'{class_name}.pkl')
            _save_mu_sigma(local_musigma_path, mu_real, sigma_real)

    if compute_clip:
        stats, open_clip_score, clip_score = metric_utils.compute_feature_stats_for_generator(
            opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
            rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen,
            batch_size=batch_size, compute_clip=compute_clip,
            clip_score_fn=opts.clip_score_fn,
            open_clip_detector_url=opts.metric_open_clip_path)
        mu_gen, sigma_gen = stats.get_mean_cov()
    else:
        mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator(
            opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs,
            rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen,
            batch_size=batch_size).get_mean_cov()
        open_clip_score = float('nan')
        clip_score = float('nan')

    # Non-zero ranks took part in the feature computation above but do not
    # report a result; they return NaN placeholders of the same shape.
    if opts.rank != 0:
        nan = float('nan')
        return (nan, nan, nan) if compute_clip else nan

    fid = _frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real)
    if compute_clip:
        return fid, float(open_clip_score), float(clip_score)
    return fid


def _frechet_distance(mu_gen, sigma_gen, mu_real, sigma_real):
    """Return ||mu_gen - mu_real||^2 + Tr(sigma_gen + sigma_real - 2*sqrtm(sigma_gen @ sigma_real)) as a float."""
    m = np.square(mu_gen - mu_real).sum()
    # sqrtm may return a complex matrix due to numerical error; only the
    # real part of the final value is kept.
    s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False)  # pylint: disable=no-member
    return float(np.real(m + np.trace(sigma_gen + sigma_real - s * 2)))


def _save_mu_sigma(path, mu, sigma):
    """Best-effort pickle of ``dict(mu=mu, sigma=sigma)`` to *path*; never overwrites an existing file."""
    import contextlib
    import pickle

    if os.path.exists(path):
        print("mu sigma file already exists:", path)
        return

    # Write to a temporary file and rename it into place, so an interrupted or
    # failed write never leaves a truncated pickle at *path* (which the
    # existence check above would then treat as a valid cache forever).
    tmp_path = f'{path}.tmp'
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(dict(mu=mu, sigma=sigma), f)
        os.replace(tmp_path, path)
        print('mu sigma save finished, return')
    except (OSError, pickle.PickleError, TypeError) as e:
        # Caching is an optimisation only; report the failure and carry on.
        print(f"Error saving mu sigma: {e}")
        with contextlib.suppress(OSError):
            os.remove(tmp_path)

#----------------------------------------------------------------------------



