import numpy as np 
from scipy import linalg
import torch
import torch.nn as nn 
from einops import rearrange
from tqdm.auto import tqdm
from diffusers import UNet2DModel
from ema_pytorch import EMA
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Numpy implementation of the Frechet Distance.

    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
    and X_2 ~ N(mu_2, C_2) is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).

    Stable version by Dougal J. Sutherland.

    Params:
    -- mu1   : Mean vector of the first Gaussian (e.g. the sample mean over
               activations for generated samples).
    -- sigma1: Covariance matrix of the first Gaussian.
    -- mu2   : Mean vector of the second Gaussian (e.g. the sample mean
               precalculated on a representative data set).
    -- sigma2: Covariance matrix of the second Gaussian.
    -- eps   : Value added to the diagonal of both covariances when the
               matrix square root of their product is not finite.

    Returns:
    --   : The Frechet Distance.

    Raises:
    -- AssertionError: if the means or the covariances differ in shape.
    -- ValueError    : if the matrix square root has a diagonal whose
                       imaginary part exceeds 1e-3 in absolute value.
    """

    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)

    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    assert mu1.shape == mu2.shape, \
        'Training and test mean vectors have different lengths'
    assert sigma1.shape == sigma2.shape, \
        'Training and test covariances have different dimensions'

    diff = mu1 - mu2

    # Product might be almost singular.
    # The `disp` argument of `sqrtm` is deprecated (SciPy 1.16) and slated for
    # removal, and the error estimate it returned was never used here, so the
    # function is called without it; the returned matrix is the same.
    covmean = linalg.sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        msg = ('fid calculation produces singular product; '
               'adding %s to diagonal of cov estimates') % eps
        print(msg)
        # Regularise both covariances and retry the square root.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
            m = np.max(np.abs(covmean.imag))
            raise ValueError('Imaginary component {}'.format(m))
        covmean = covmean.real

    tr_covmean = np.trace(covmean)

    return (diff.dot(diff) + np.trace(sigma1)
            + np.trace(sigma2) - 2 * tr_covmean)
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Exponentially-decayed running average of a model's parameters.

    Each update blends the stored value with the live one as
    ``decay * averaged + (1 - decay) * current``. The bookkeeping is done by
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_,
    constructed with ``use_buffers=True``.

    Args:
        model: Module whose parameters are tracked.
        decay: Weight given to the stored average at every update.
        device: Device handed to ``AveragedModel`` (default: ``"cpu"``).
    """

    def __init__(self, model, decay, device="cpu"):
        def blend(running_value, current_value, num_averaged):
            # `num_averaged` is part of the callback signature but is not
            # needed for a constant-decay average.
            return decay * running_value + (1 - decay) * current_value

        super().__init__(model, device, blend, use_buffers=True)
def fd_preprocess(image, label, time, num_show, num_classes=10):
    """Flatten the two leading axes of a batch and one-hot encode its labels.

    Args:
        image: Array/tensor; only the last ``num_show`` entries along axis 1
            are kept before axes 0 and 1 are merged.
        label: Integer tensor; axes 0 and 1 are merged, then every entry is
            replaced by the matching row of a ``num_classes`` identity matrix.
            NOTE(review): unlike ``image`` and ``time`` it is not sliced with
            ``num_show`` -- presumably it already has the matching length;
            verify against the caller.
        time: Array/tensor treated exactly like ``image``.
        num_show: Number of trailing entries along axis 1 to keep.
        num_classes: Width of the one-hot encoding (default: 10).

    Returns:
        Tuple ``(image, label, time)`` of the processed inputs.
    """
    recent_images = image[:, -num_show:]
    recent_times = time[:, -num_show:]

    flat_image = rearrange(recent_images, "a b ...->(a b)...")
    flat_time = rearrange(recent_times, "a b ...->(a b)...")

    # One-hot encoding via row lookup in an identity matrix that lives on
    # the same device as the labels.
    flat_label = rearrange(label, "a b ... ->(a b) ...")
    identity = torch.eye(num_classes, device=flat_label.device)
    one_hot_label = identity[flat_label]

    return flat_image, one_hot_label, flat_time
# Lookup table from a timestep interval ``(low, high)`` to a scalar weight.
# ``ts_to_weight_func`` treats every key as the half-open range
# ``low <= ts < high``. Intervals are 20 steps wide up to 580; the final
# entry covers 580..1000 with weight 0.0.
# NOTE(review): the origin of these values is not visible in this file --
# presumably measured empirically; confirm before editing them.
ts_to_weight={(0, 20): 0.9989637305699481,
 (20, 40): 0.9978563772775991,
 (40, 60): 0.9968186638388123,
 (60, 80): 0.9989754098360656,
 (80, 100): 0.9913544668587896,
 (100, 120): 0.9890981169474727,
 (120, 140): 0.9790209790209791,
 (140, 160): 0.9649298597194389,
 (160, 180): 0.9424603174603174,
 (180, 200): 0.898854961832061,
 (200, 220): 0.8636363636363636,
 (220, 240): 0.8340163934426229,
 (240, 260): 0.7723735408560312,
 (260, 280): 0.7222787385554426,
 (280, 300): 0.6730401529636711,
 (300, 320): 0.6171428571428571,
 (320, 340): 0.5633367662203913,
 (340, 360): 0.5381443298969072,
 (360, 380): 0.4561904761904762,
 (380, 400): 0.44123711340206184,
 (400, 420): 0.4027370478983382,
 (420, 440): 0.33987603305785125,
 (440, 460): 0.33794466403162055,
 (460, 480): 0.3035532994923858,
 (480, 500): 0.2912912912912913,
 (500, 520): 0.26429980276134124,
 (520, 540): 0.2553816046966732,
 (540, 560): 0.2308457711442786,
 (560, 580): 0.22422422422422422,
 (580, 1000): 0.0,}
def ts_to_weight_func(ts):
    """Return the weight of the ``ts_to_weight`` interval containing ``ts``.

    Args:
        ts: Timestep as a plain number or a single-element tensor (tensors
            are converted with ``.item()``).

    Returns:
        The weight stored for the first interval with ``low <= ts < high``,
        or the integer ``0`` when no interval matches.
    """
    step = ts.item() if torch.is_tensor(ts) else ts
    for (low, high), weight in ts_to_weight.items():
        if low <= step < high:
            return weight
    return 0
def Euclidean_2(fi, power=2):
    """
    Square root of the mean of ``fi ** power`` along the last axis.

    With the default ``power=2`` this is the root-mean-square of every row.

    Args:
        fi (torch.Tensor): 2-D input tensor; any other rank fails when the
            shape is unpacked.
        power (int): Exponent applied element-wise before summing
            (default: 2). The outer square root is applied whatever the
            exponent is.

    Returns:
        torch.Tensor: One value per row of ``fi``.
    """
    _, width = fi.shape
    powered = fi ** power
    row_totals = torch.sum(powered, dim=-1)
    row_means = row_totals / width
    return torch.sqrt(row_means)
def Euclidean(*args, **kwargs):
    """
    Row-wise root-mean-power distance, dispatched on the number of tensors.

    * Two positional tensors ``a, b``: returns
      ``sqrt(sum((a - b) ** power, dim=-1) / a.shape[1])``.
    * One positional tensor: delegates to ``Euclidean_2``.
    * Any other number of positional arguments: returns ``None``.

    Args:
        a (torch.Tensor): First 2-D tensor.
        b (torch.Tensor): Optional second tensor, subtracted from ``a``.
        power (int): Keyword-only exponent applied to the difference
            (default: 2).

    Returns:
        torch.Tensor: One value per row, or ``None`` for an unsupported
        number of positional arguments.
    """
    exponent = kwargs.get("power", 2)
    n_inputs = len(args)

    if n_inputs == 1:
        return Euclidean_2(args[0], exponent)

    if n_inputs == 2:
        first, second = args
        _, width = first.shape
        gap = first - second
        return torch.sqrt(torch.sum(gap ** exponent, dim=-1) / width)

    # Unsupported arity: same result as falling off the end of the function.
    return None
class Mean_Dis(nn.Module):
    """Concatenate a wrapped model's output with the flattened raw input.

    Args:
        inverion_model: Module applied to the input; its output is joined
            with ``nn.Flatten()(x)`` along the last axis, so both must agree
            on every other axis.
    """

    def __init__(self, inverion_model) -> None:
        super().__init__()
        self.inversion_model = inverion_model
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Return ``[inversion_model(x), flatten(x)]`` joined on the last axis."""
        model_features = self.inversion_model(x)
        raw_features = self.flatten(x)
        return torch.cat((model_features, raw_features), dim=-1)


def sampling(model, scheduler, latent):
    """Run the scheduler's full denoising loop and return uint8 samples.

    For every timestep in ``scheduler.timesteps`` the model is evaluated on
    the scheduler-scaled input and the scheduler advances the *unscaled*
    sample with the model's prediction.

    Args:
        model: Callable as ``model(x, timestep=t)`` returning an object with
            a ``.sample`` attribute.
        scheduler: Object exposing ``timesteps``, ``scale_model_input`` and
            ``step(...).prev_sample``.
        latent: Starting sample tensor.
            NOTE(review): it is not multiplied by
            ``scheduler.init_noise_sigma`` here; schedulers that need that
            presumably expect the caller to do it -- confirm.

    Returns:
        torch.Tensor: Final sample multiplied by 255, clamped to ``[0, 255]``
        and cast to ``torch.uint8``.
        NOTE(review): this conversion presumes the final sample lies in
        ``[0, 1]`` -- confirm against the training normalisation.
    """
    for t in tqdm(scheduler.timesteps):
        with torch.no_grad():
            # Keep the scaled tensor separate: `scheduler.step` must receive
            # the real sample, not the version rescaled for the model.
            model_input = scheduler.scale_model_input(latent, t)
            residual = model(model_input, timestep=t).sample
        latent = scheduler.step(residual, t, latent).prev_sample

    # Quick range check of the final sample before quantisation.
    print(latent.max(), latent.min(), latent.mean())
    latent = (latent * 255).clamp(0, 255).type(torch.uint8)
    return latent
import torchvision
from PIL import Image
import os
def create_new_image_dataset(file_path,image_size=128,ori_file_path="CelebA_HQ_facial_identity_dataset",new_file_path=None):
    """Mirror an image tree into a new tree of square, resized images.

    The destination of every entry is ``file_path`` with each occurrence of
    ``ori_file_path`` replaced by ``new_file_path``. If ``ori_file_path``
    does not occur in ``file_path`` the destination equals the source, so
    images would be overwritten in place.

    Args:
        file_path: Image file or directory to process. A path containing
            ``".jpg"`` or ``".png"`` anywhere is treated as an image; every
            other path is treated as a directory and walked recursively.
        image_size: Side length of the resized images (default: 128).
        ori_file_path: Path fragment naming the source tree.
        new_file_path: Path fragment naming the destination tree; defaults
            to ``ori_file_path + "_" + str(image_size)``.
    """
    if new_file_path is None:
        new_file_path = ori_file_path + "_" + str(image_size)

    target_path = file_path.replace(ori_file_path, new_file_path)

    if ".jpg" in file_path or ".png" in file_path:
        resize = torchvision.transforms.Resize((image_size, image_size))
        # Context manager closes the source file handle once it is saved.
        with Image.open(file_path) as source_image:
            resize(source_image).save(target_path)
        print(image_size)
        return

    os.makedirs(target_path, exist_ok=True)
    for entry in os.listdir(file_path):
        # Forward the tree names too; otherwise nested levels would fall
        # back to the defaults and ignore the caller's choice.
        create_new_image_dataset(
            os.path.join(file_path, entry),
            image_size=image_size,
            ori_file_path=ori_file_path,
            new_file_path=new_file_path,
        )
from create_models import create_diffusion_model
from train_config import Generate_ClassCond_Config
def load_inference_model(config:Generate_ClassCond_Config):
    """Build the model used for inference and move it to ``config.device``.

    Args:
        config: Reads ``self_train``, ``model_dir``, ``device`` and, for
            self-trained models, ``dataset_name`` and ``ema``.

    Returns:
        * ``config.self_train == False``: a ``UNet2DModel`` loaded with
          ``from_pretrained(config.model_dir)``.
        * otherwise: the result of ``create_diffusion_model``, wrapped in
          ``EMA`` when ``config.ema`` is truthy, with the state dict stored
          at ``config.model_dir`` loaded into it.
    """
    # Explicit comparison kept on purpose: a value such as None must still
    # take the self-trained branch, exactly as before.
    if config.self_train == False:  # noqa: E712
        return UNet2DModel.from_pretrained(config.model_dir).to(config.device)

    model = create_diffusion_model(config.dataset_name)
    if config.ema:
        model = EMA(model)

    # Load onto the CPU first so a checkpoint saved on another device still
    # loads; the model is moved to the target device right after.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(config.model_dir, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(config.device)
    return model