import torch
import numpy as np
import pickle
from PIL import Image



def flatten_feats(t):
    """Flatten every dimension after the first: ``(N, ...) -> (N, prod(...))``."""
    return t.reshape(t.size(0), -1)


@torch.no_grad()
def compute_kid_grad_poly3(real_latent, gen_latent):
    """Gradient of the (biased) MMD^2 estimate w.r.t. the newest generated latent.

    Uses the cubic polynomial kernel from KID,
        k(x, y) = ((x^T y) / D + 1) ^ 3,
    with the estimate
        MMD^2 = mean_{i,j} k(y_i, y_j) - 2 * mean_{i,j} k(x_i, y_j) + const,
    differentiated with respect to the last generated sample ``y_k = Y[-1]``.

    Args:
        real_latent: tensor whose first dimension indexes the Nr real samples.
        gen_latent: tensor whose first dimension indexes the Ng generated
            samples; the last entry is the one the gradient is taken for.

    Returns:
        Tensor with the same shape as ``gen_latent[-1]``.
    """
    Y = flatten_feats(gen_latent)
    # Work in the generated latents' dtype/device so mixed precision inputs
    # (e.g. fp16 real latents vs fp32 sampler latents) do not break matmul.
    X = flatten_feats(real_latent).to(device=Y.device, dtype=Y.dtype)

    Nr, D = X.size()
    Ng = Y.size(0)
    yk = Y[-1]

    # Uniform weights 1/Nr over the real samples; clamp avoids 0-division if Nr == 0.
    W = torch.ones(Nr, device=X.device, dtype=X.dtype)
    W = W / W.sum().clamp_min(1e-9)

    # Kernel "inner" terms s = (a^T y_k) / D + 1 for every generated / real sample.
    s_yy_k = Y.matmul(yk) / D + 1.0
    s_xy_k = X.matmul(yk) / D + 1.0

    # dk/dy_k = 3 * s^2 * a / D, so these are the per-sample coefficients.
    c_yy = 3.0 * (s_yy_k * s_yy_k)
    c_xy = 3.0 * (s_xy_k * s_xy_k) * W

    # The factor 2 accounts for y_k appearing in both arguments of the gen-gen sum.
    grad_gen_g = (2.0 / (Ng * Ng * D)) * (Y.t().matmul(c_yy))
    grad_real_g = (1.0 / (Ng * D)) * (X.t().matmul(c_xy))

    gk = grad_gen_g - 2.0 * grad_real_g

    return gk.view_as(gen_latent[-1])


def get_all_latents(self, F, f):
    """Stack the stored latents ``F`` and the new latent(s) ``f`` along dim 0.

    ``self`` is not used. Presumably the method-style signature is there so the
    function can be moved into a sampler class unchanged.
    """
    stacked = torch.cat([F, f], dim=0)
    return stacked


@torch.no_grad()
def mmd_guidance(img, gen_latent, real_latent, ts, last_step, guidance_alpha):
    """Compute the MMD guidance term for the current update of x_t.

    Called after each update of x_t (x_t -> x_{t-1}). The current sample is
    mapped to latent space with ``latent_sampler(img, ts)``.

    - When earlier samples are stored in ``gen_latent``, the KID-kernel MMD
      gradient for the current latent is computed.
    - That gradient is rescaled to have (approximately) the L2 norm of ``img``
      and is then multiplied by ``guidance_alpha``.
    - At the last step the current latent is appended to ``gen_latent``, so it
      takes part in the MMD computation of later samples.

    Args:
        img: current sample x_t.
        gen_latent: latents of previously finished samples, or ``None`` for
            the first sample.
        real_latent: precomputed latents of the reference (real) images.
        ts: timestep forwarded to ``latent_sampler``.
        last_step: index of the current step; compared with ``NUM_STEPS - 1``.
        guidance_alpha: guidance scale applied to the normalized gradient.

    Returns:
        ``(grads, gen_latent)``. ``grads`` is all zeros when ``gen_latent`` was
        ``None``. ``gen_latent`` is possibly extended with the current latent.
    """
    # NOTE(review): `latent_sampler` and `NUM_STEPS` are not defined in this
    # snippet; they must be supplied by the module this is pasted into.
    latent = latent_sampler(img, ts)

    if gen_latent is None:
        # First sample: nothing generated yet to compare against, so no guidance.
        grads = torch.zeros_like(img)
    else:
        # get_all_latents has an unused `self` parameter, hence the explicit None.
        gen_l = get_all_latents(None, gen_latent, latent)
        grads = compute_kid_grad_poly3(real_latent, gen_l)

        # Rescale to the norm of `img` so `guidance_alpha` acts as a relative step.
        grads = grads / (grads.norm(2) + 1e-5) * img.norm(2)
        grads = grads * guidance_alpha

    if last_step == NUM_STEPS - 1:
        if gen_latent is None:
            gen_latent = latent
        else:
            gen_latent = get_all_latents(None, gen_latent, latent)

    return grads, gen_latent



@torch.no_grad()
def precompute_real_latent(pre_process_clip, pipline, img_paths):
    """Encode the user's reference images into scaled VAE latents.

    Each image is:
    - converted to RGB;
    - resized (bicubic) to 512 and center-cropped to 512x512;
    - mapped with x -> (x - 0.5) * 2, i.e. [0, 1] -> [-1, 1];
    - encoded with ``pipline.vae``.

    The posterior mean multiplied by ``pipline.vae.config.scaling_factor`` is
    kept. All latents are concatenated along dim 0, pickled to
    ``F_M_real_all.pkl`` and returned.

    Args:
        pre_process_clip: unused; kept for interface compatibility.
        pipline: object exposing ``.vae`` with ``encode`` and
            ``config.scaling_factor``.
        img_paths: iterable of image paths, or a single path. A directory path
            is expanded to the image files it contains, sorted by name.

    Returns:
        Tensor of the concatenated latents (first dim = number of images).

    Raises:
        ValueError: if there are no images to encode.
    """
    from pathlib import Path

    # NOTE(review): `transforms` / `InterpolationMode` look like torchvision's
    # API but are not imported in this file; the host module must provide them.
    torch_preprocess = transforms.Compose([
        transforms.Resize(512, interpolation=InterpolationMode.BICUBIC, max_size=None, antialias='warn'),
        transforms.CenterCrop(size=(512, 512)),
        transforms.ToTensor(),
    ])

    if isinstance(img_paths, (str, Path)):
        # Iterating a plain string would yield single characters, not paths.
        root = Path(img_paths)
        if root.is_dir():
            image_suffixes = {'.png', '.jpg', '.jpeg', '.bmp', '.webp', '.tif', '.tiff'}
            img_paths = sorted(p for p in root.iterdir() if p.suffix.lower() in image_suffixes)
        else:
            img_paths = [root]

    latents = []
    for img_path in img_paths:
        # Context manager closes the file. convert('RGB') avoids 1- or 4-channel
        # tensors from grayscale / RGBA inputs.
        with Image.open(img_path) as img:
            image = torch_preprocess(img.convert('RGB'))

        image = (image - 0.5) * 2
        image = image.unsqueeze(0).to('cuda')
        features_ = pipline.vae.encode(image.half()).latent_dist.mean * pipline.vae.config.scaling_factor
        latents.append(features_.detach())

    if not latents:
        raise ValueError('precompute_real_latent: no images to encode')

    # Single concatenation instead of growing the tensor inside the loop.
    F_M_real = torch.cat(latents, dim=0)

    # Only written here; never unpickle this file from an untrusted source.
    with open('F_M_real_all.pkl', 'wb') as f:
        pickle.dump(F_M_real, f)
    print('F_M_real_all.pkl saved!')

    return F_M_real



if __name__ == "__main__":
    import argparse
    from pathlib import Path

    parser = argparse.ArgumentParser(description='MMD Guidance')
    parser.add_argument('--user_dir', type=str,
                        default='',
                        help="User dir path")
    args = parser.parse_args()

    # TODO: build the diffusion pipeline here. precompute_real_latent needs an
    # object exposing `.vae.encode(...)` and `.vae.config.scaling_factor`.
    pipline = None
    pre_process_clip = None  # not used by precompute_real_latent
    if pipline is None:
        raise SystemExit('Insert a diffusion pipeline (set `pipline`) before running this script.')

    # --user_dir is a directory: pass the image files inside it, not the raw string.
    image_suffixes = {'.png', '.jpg', '.jpeg', '.bmp', '.webp', '.tif', '.tiff'}
    user_images = sorted(
        p for p in Path(args.user_dir).iterdir() if p.suffix.lower() in image_suffixes
    )

    real_latent = precompute_real_latent(pre_process_clip, pipline, user_images)

    # The mmd_guidance, get_all_latents, compute_kid_grad_poly3 and flatten_feats
    # functions can be added to the samplers file of the diffusion pipeline, with
    # slight changes (if any), to produce MMD-guided samples.


