from typing import Any
from tqdm.auto import tqdm
from pytorch_fid.inception import InceptionV3
from torch.utils.data import Dataset
import numpy as np
import torch 
from scipy import linalg
class Calculate_FID:
    """Compute the Frechet Inception Distance (FID) against a fixed reference set.

    The reference ("target") dataset is embedded once at construction time using
    the 2048-dim block of pytorch_fid's ``InceptionV3``; the mean and covariance
    of those features are cached. Calling the instance on another dataset embeds
    it the same way and returns the Frechet distance between the two fitted
    Gaussians.
    """

    def mean_std(self, dataset):
        """Return the per-feature mean and the covariance matrix of ``dataset``.

        ``dataset`` is a 2-D array of shape (n_samples, n_features). Despite the
        method name, the second value is the full covariance matrix
        (``np.cov`` with ``rowvar=False``), not a standard deviation. With fewer
        than two rows ``np.cov`` yields NaNs.
        """
        return np.mean(dataset, axis=0), np.cov(dataset, rowvar=False)

    def __init__(self, target_dataset, batch_size=16, device="cuda"):
        """Build the feature extractor and cache the reference statistics.

        Args:
            target_dataset: reference data; see ``data2vec`` for accepted forms.
            batch_size: batch size used when ``target_dataset`` (or a dataset
                passed to ``__call__``) is a ``torch.utils.data.Dataset``.
            device: torch device the InceptionV3 model runs on.
        """
        self.target_dataset = target_dataset
        block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
        self.device = device
        # nn.Module instances start in training mode, where BatchNorm would use
        # per-batch statistics instead of the pretrained running statistics and
        # distort the features. pytorch_fid's own get_activations() calls
        # model.eval() for the same reason.
        self.target_model = InceptionV3([block_idx]).to(device).eval()
        self.batch_size = batch_size
        data_temp = self.data2vec(target_dataset)
        self.target_mean, self.target_cov = self.mean_std(data_temp)

    def data2vec(self, target_dataset):
        """Embed every image of ``target_dataset`` into 2048-dim Inception features.

        Args:
            target_dataset: either a ``torch.utils.data.Dataset`` (wrapped in a
                non-shuffling DataLoader with ``self.batch_size``) or any
                iterable of batches. A batch may be an image tensor or a
                list/tuple whose first element is the image tensor
                (e.g. ``(images, labels)``).

        Returns:
            np.ndarray of shape (n_samples, 2048).

        Raises:
            ValueError: if ``target_dataset`` yields no batches.
        """
        if isinstance(target_dataset, Dataset):
            target_dataset = torch.utils.data.DataLoader(
                target_dataset, batch_size=self.batch_size, shuffle=False
            )
        features = []
        # inference_mode already disables autograd; no extra no_grad is needed.
        with torch.inference_mode():
            for batch in tqdm(target_dataset):
                if isinstance(batch, (list, tuple)):
                    batch = batch[0]
                # pytorch_fid's InceptionV3 expects inputs in [0, 1] (it rescales
                # to [-1, 1] internally), so map [-1, 1] images back to [0, 1].
                # NOTE: this is a per-batch heuristic; a batch from a [-1, 1]
                # dataset that happens to contain no negative value is not rescaled.
                if batch.min().item() < 0:
                    batch = (batch + 1) / 2
                batch = batch.to(self.device)
                # The 2048-dim block is the final average pool, shaped
                # (N, 2048, 1, 1); drop the two singleton spatial dims.
                pooled = self.target_model(batch)[0]
                features.append(pooled.cpu().numpy()[:, :, 0, 0])
        if not features:
            raise ValueError(
                "target_dataset yielded no batches; cannot compute FID statistics"
            )
        return np.concatenate(features, axis=0)

    def __call__(self, target_dataset):
        """Return the FID between ``target_dataset`` and the cached reference set.

        ``target_dataset`` accepts the same forms as ``data2vec``.
        """
        features = self.data2vec(target_dataset)
        mean, cov = self.mean_std(features)
        return self.calculate_frechet_distance(
            mean, cov, self.target_mean, self.target_cov
        )

    def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Numpy implementation of the Frechet Distance.

        The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
        and X_2 ~ N(mu_2, C_2) is
                d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

        Stable version by Dougal J. Sutherland.

        Params:
        -- mu1   : Mean vector of the Inception activations of the generated
                   (evaluated) samples.
        -- sigma1: Covariance matrix of those activations.
        -- mu2   : Mean vector of the activations of the reference data set.
        -- sigma2: Covariance matrix of the reference activations.
        -- eps   : Value added to both covariance diagonals if the product
                   sigma1 @ sigma2 is too close to singular for sqrtm.

        Returns:
        --   : The Frechet Distance (a numpy float).

        Raises:
        -- AssertionError: if the mean or covariance shapes do not match.
        -- ValueError: if the matrix square root has a significant imaginary part.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        diff = mu1 - mu2

        # Product might be almost singular
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error might give slight imaginary component
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return (diff.dot(diff) + np.trace(sigma1)
                + np.trace(sigma2) - 2 * tr_covmean)