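"""
Sample new images from a pre-trained SiT and, for each sample, find the
closest reference image by pixel-space mean squared error. The samples and
their nearest reference images are saved together in a single image grid.
"""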
import torch
# Allow TF32 math for CUDA matmuls and cuDNN ops: trades some float32
# precision for speed on GPUs that support it (no effect on CPU runs).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL
from download import find_model
from models import SiT_models
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
from transport import create_transport, Sampler
import argparse
import sys
from time import time
from torchvision import transforms
import os
from PIL import Image
from torchvision.transforms.functional import to_pil_image
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

def center_crop_arr(pil_image, image_size):
    """
    Resize and center-crop a PIL image to an ``image_size`` x ``image_size`` square.

    Port of the ADM center-crop routine:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Steps: halve the image with BOX resampling while its shorter side is at
    least twice the target, rescale with BICUBIC so the shorter side becomes
    ``image_size``, then cut out the central square.

    Returns:
        A new PIL image of size ``image_size`` x ``image_size``.
    """
    img = pil_image

    # Coarse downscaling by repeated halving while the image is still large.
    while min(img.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=Image.BOX)

    # Final bicubic resize so that the shorter side equals image_size.
    ratio = image_size / min(img.size)
    target = tuple(round(side * ratio) for side in img.size)
    img = img.resize(target, resample=Image.BICUBIC)

    pixels = np.array(img)
    height, width = pixels.shape[:2]
    top = (height - image_size) // 2
    left = (width - image_size) // 2
    window = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(window)

def load_reference_images_mse(folder, device, image_size=256, random_flip=False):
    """
    Load every JPEG/PNG image in ``folder`` as one normalized tensor batch.

    Each image is converted to RGB, center-cropped to ``image_size`` with
    ``center_crop_arr``, converted to a tensor and normalized with mean 0.5 /
    std 0.5 per channel (mapping ToTensor's [0, 1] range to [-1, 1]).

    Args:
        folder: Directory to scan (non-recursive). Files are matched by a
            case-insensitive ``.jpg`` / ``.png`` / ``.jpeg`` extension.
        device: Device the returned tensor is moved to; passed straight to
            ``Tensor.to``.
        image_size: Side length of the square crop.
        random_flip: If True, apply ``RandomHorizontalFlip`` to each image.
            Defaults to False so the reference set used for nearest-neighbour
            comparison is deterministic. Earlier versions always flipped at
            random.

    Returns:
        Tensor of shape [N, 3, image_size, image_size], ordered by filename.

    Raises:
        ValueError: If ``folder`` contains no matching image file.
    """
    steps = [transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size))]
    if random_flip:
        steps.append(transforms.RandomHorizontalFlip())
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3, inplace=True),
    ]
    transform = transforms.Compose(steps)

    extensions = ('.jpg', '.png', '.jpeg')
    images = []
    # Sorted so the row order of the result does not depend on the filesystem.
    for fname in sorted(os.listdir(folder)):
        if not fname.lower().endswith(extensions):
            continue
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded copy.
        with Image.open(os.path.join(folder, fname)) as img:
            rgb = img.convert("RGB")
        images.append(transform(rgb))

    if not images:
        raise ValueError(f"No images found in folder: {folder}")

    # Stack on the host, then do a single device transfer.
    return torch.stack(images, dim=0).to(device)


def main(mode, args):
    """
    Sample images from a SiT checkpoint and pair each with its nearest reference image.

    Builds the SiT model and transport sampler described by ``args`` and draws
    one sample per entry of ``class_labels``. Each latent sample is decoded with
    the VAE loaded from the local ``sd_vae`` directory. For every decoded sample,
    the reference image with the lowest per-pixel MSE is found in
    ``args.reference_dir`` (default ``data/all_classes``). Samples (first row)
    and their nearest references (second row) are saved to ``nearest_samples.png``.

    Args:
        mode: ``"ODE"`` or ``"SDE"``; selects the sampler.
        args: Namespace holding the model, transport and sampler settings read below.

    Raises:
        ValueError: If ``mode`` is neither ``"ODE"`` nor ``"SDE"``.
    """
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    if args.ckpt is None:
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download." # remove this line when 512x512 models are available
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model (the VAE downsamples by 8, so latents are image_size // 8):
    latent_size = args.image_size // 8
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    # Auto-download a pre-trained model or load a custom SiT checkpoint from train.py:
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()  # important!
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    sample_fn = _build_sample_fn(sampler, mode, args)

    vae = AutoencoderKL.from_pretrained('sd_vae').to(device)

    # Labels to condition the model with (feel free to change):
    class_labels = [1, 1, 1, 1]

    # Create sampling noise:
    n = len(class_labels)
    z = torch.randn(n, 4, latent_size, latent_size, device=device)
    y = torch.tensor(class_labels, device=device, dtype=torch.long)

    # No classifier-free guidance here: the conditional model is called directly.
    model_kwargs = dict(y=y)

    # Sample, keep the final state of the trajectory, and decode to pixel space.
    samples = sample_fn(z, model.forward, **model_kwargs)[-1]
    samples = vae.decode(samples / 0.18215).sample

    # References must match the sample resolution for the pixel-wise MSE.
    reference_dir = getattr(args, "reference_dir", "data/all_classes")
    reference_images = load_reference_images_mse(
        reference_dir, device=None, image_size=args.image_size
    )
    reference_loader = DataLoader(
        TensorDataset(reference_images), batch_size=128, shuffle=False
    )

    nearest_images_tensor = _nearest_references(samples, reference_loader)

    # Grid layout: samples on the first row, their nearest references on the second.
    final_image = torch.cat([samples, nearest_images_tensor], dim=0)
    save_image(final_image, "nearest_samples.png", nrow=n, normalize=True, value_range=(-1, 1))


def _build_sample_fn(sampler, mode, args):
    """Return the sampler function for ``mode`` ("ODE" or "SDE") configured from ``args``."""
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            return sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        # NOTE(review): sample_ode() is called with the sampler's own defaults.
        # args.sampling_method, args.num_sampling_steps, args.atol, args.rtol and
        # args.reverse are NOT forwarded. A call passing them was left commented
        # out in an earlier version; confirm this is intended.
        return sampler.sample_ode()
    if mode == "SDE":
        return sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    raise ValueError(f"Unknown sampling mode: {mode!r} (expected 'ODE' or 'SDE')")


def _nearest_references(samples, reference_loader):
    """
    Return, for each sample, the reference image with the lowest mean squared error.

    Args:
        samples: Tensor of decoded samples, one image per leading index.
        reference_loader: Iterable of 1-element batches ``(ref_batch,)``, as
            yielded by a ``DataLoader`` over a ``TensorDataset``. Each reference
            image must have the same shape as a single sample.

    Returns:
        Tensor stacking the nearest reference for every sample, in sample order.
        On ties, the earliest reference in loader order wins.
    """
    num_samples = samples.size(0)
    best_dist = [float('inf')] * num_samples
    best_ref = [None] * num_samples

    with torch.no_grad():
        for (ref_batch,) in reference_loader:
            # Transfer each reference batch once and compare it with every sample.
            ref_batch = ref_batch.to(samples.device)
            for i, sample in enumerate(samples):
                expanded = sample.unsqueeze(0).expand_as(ref_batch)
                dists = F.mse_loss(ref_batch, expanded, reduction='none')
                dists = dists.flatten(1).mean(dim=1)  # one distance per reference
                batch_min_dist, batch_min_idx = torch.min(dists, dim=0)
                dist_value = batch_min_dist.item()
                if dist_value < best_dist[i]:
                    best_dist[i] = dist_value
                    # clone() so the stored match does not keep the whole batch alive.
                    best_ref[i] = ref_batch[batch_min_idx].clone()

    return torch.stack(best_ref, dim=0)



if __name__ == "__main__":
    mode = "ODE"

    # Hard-coded run configuration used instead of command-line parsing.
    # argparse.Namespace provides the same attribute-style access as parsed
    # CLI arguments, so main() can read args.<field> unchanged.
    # NOTE: main() does not read `vae` (it loads the local 'sd_vae' directory).
    # For a non-likelihood ODE run it also does not forward sampling_method,
    # num_sampling_steps, atol, rtol or reverse to the sampler.
    args = argparse.Namespace(
        model="SiT-B/2",
        ckpt="checkpoints/checkpoint.pt",
        vae="mse",
        image_size=256,
        num_classes=1,
        cfg_scale=0,
        num_sampling_steps=250,
        seed=6,
        path_type="Linear",
        prediction="noise",
        loss_weight="none",
        train_eps=1e-5,
        sample_eps=1e-5,
        sampling_method="dopri5",
        atol=1e-5,
        rtol=1e-5,
        reverse=False,
        likelihood=False,
    )

    main(mode, args)




