import torch
# Enable TF32 for CUDA matmuls and cuDNN convolutions: faster on A100-class GPUs
# at slightly reduced float32 precision.
# NOTE: the first flag below was False when this script was tested; True makes A100
# runs a lot faster, but may perturb results (e.g. FID) slightly vs. the tested setup.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# for FID
import torch
from diffusers.models import AutoencoderKL
from torchmetrics.image.fid import FrechetInceptionDistance

import argparse
from data_loaders.data_loader import CelebAHQ256
from models.model import DiT_models


def denormalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo a per-channel ``(x - mean) / std`` normalization and clamp to [0, 1].

    Args:
        tensor: image tensor whose channel axis is third from last, i.e.
            ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: per-channel means; ``len(mean)`` must equal C (or 1).
        std: per-channel standard deviations; ``len(std)`` must equal C (or 1).

    Returns:
        ``tensor * std + mean`` clamped to [0, 1], on ``tensor``'s device.
    """
    # Shape (C, 1, 1) so the stats broadcast over H, W (and N if present).
    # Using -1 instead of a hard-coded 3 supports any channel count.
    mean_t = torch.as_tensor(mean, device=tensor.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, device=tensor.device).view(-1, 1, 1)
    restored = tensor * std_t + mean_t  # reverse normalization
    return torch.clamp(restored, 0, 1)

def convert_to_uint8(images):
    """Map float images in [0, 1] to uint8 in [0, 255], rounding to nearest.

    Values outside [0, 1] are clamped. Rounding is required because a bare
    ``.to(torch.uint8)`` truncates. Truncation biases every pixel down by ~0.5.
    It also turns float error such as ``(k / 255) * 255 == k - 1e-5`` into ``k - 1``.
    """
    return (images * 255).round().clamp(0, 255).to(torch.uint8)


class FID:
    """Thin wrapper around torchmetrics' ``FrechetInceptionDistance`` (2048-d features).

    Images are forwarded unchanged to the underlying metric; the callers in this
    file pass uint8 tensors produced by ``convert_to_uint8``.
    """

    def __init__(self, device):
        self.device = device
        metric = FrechetInceptionDistance(feature=2048)
        self.fid = metric.to(device)

    def update_batch(self, images, real):
        """Feed one batch into the real (``real=True``) or generated (``real=False``) stats."""
        self.fid.update(images, real=real)

    def update(self, x, real, batch_size=50):
        """Stream ``x`` into the metric in slices of ``batch_size`` along dim 0.

        Each slice is moved to ``self.device`` just before use, so ``x`` itself may
        live on another device (e.g. the CPU).
        """
        total = x.size(0)
        for start in range(0, total, batch_size):
            stop = start + batch_size
            self.update_batch(x[start:stop].to(self.device), real=real)

    def get_fid(self):
        """Return the FID computed from everything accumulated so far."""
        metric = self.fid
        return metric.compute()


def _load_real_celeba_uint8():
    """Load CelebA-HQ 256 (train + valid splits) and return it as one uint8 tensor.

    NOTE(review): assumes ``CelebAHQ256(...).images`` holds images normalized with
    mean/std 0.5 (what ``denormalize``'s defaults undo) -- confirm in the loader.
    """
    train_dataset = CelebAHQ256(image_folder="./dataset/CelebAHQ256/", train=True)
    valid_dataset = CelebAHQ256(image_folder="./dataset/CelebAHQ256/", train=False)
    x = torch.cat([train_dataset.images, valid_dataset.images], dim=0)
    return convert_to_uint8(denormalize(x))


def _euler_sample_and_decode(model, vae, x_t, timesteps, dt_base_value):
    """Euler-integrate latents ``x_t`` from t=0 to t=1, then VAE-decode to uint8 images.

    The conditioning tensors are sized from ``x_t.shape[0]``, so a short final batch works.
    """
    n, device = x_t.shape[0], x_t.device
    delta_t = 1.0 / timesteps
    dt_base = torch.full((n,), dt_base_value, device=device)
    # Class label 0 for every sample (the model in __main__ is built with num_classes=1).
    y = torch.zeros((n,), dtype=torch.int32, device=device)
    for ti in range(timesteps):
        t_vector = torch.full((n,), ti * delta_t, device=device)
        v_pred = model(x_t, t_vector, y, dt_base)
        x_t = x_t + v_pred * delta_t
    # Undo the latent scaling before decoding; 0.18215 is the SD VAE scaling
    # factor (assumed to match the one used when the model was trained).
    images = vae.decode(x_t / 0.18215).sample
    return convert_to_uint8(denormalize(images))


def eval_model(model, args, timesteps, num_samples=50000):
    """Compute FID between CelebA-HQ 256 (train + valid) and generated images.

    Generated images come from Gaussian latent noise. The model's predicted
    velocity is Euler-integrated over ``timesteps`` steps, and the result is
    decoded with ``stabilityai/sd-vae-ft-{args.vae}``.

    Args:
        model: velocity network, called as ``model(x_t, t, y, dt_base)``.
        args: namespace providing ``device``, ``vae``, ``latent_size``,
            ``sampling_batch_size`` and ``train_type``.
        timesteps: number of Euler steps. Unless ``args.train_type == "naive"``
            (which uses log2(128)), the model is conditioned on
            ``dt_base = log2(timesteps)``.
        num_samples: number of images to generate (default 50000).

    Returns:
        The FID value as returned by torchmetrics (a 0-d tensor).

    Raises:
        ValueError: if ``timesteps < 1``.
    """
    import math  # local import: only this function needs it

    if timesteps < 1:
        raise ValueError(f"timesteps must be >= 1, got {timesteps}")

    print("Loading CelebA dataset...")
    real_images = _load_real_celeba_uint8()
    fid_calc = FID(args.device)
    fid_calc.update(real_images, real=True)
    del real_images  # free the full real-image tensor before sampling

    vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(args.device)
    model.eval()

    if args.train_type == "naive":
        print("Training a FM model, so setting timesteps to 128...")
        # Only the dt_base conditioning is pinned to log2(128); the sampler
        # below still takes `timesteps` Euler steps.
        dt_base_value = math.log2(128)
    else:
        dt_base_value = math.log2(timesteps)

    batch_size = args.sampling_batch_size
    # Ceil-divide so exactly num_samples images are generated even when
    # batch_size does not divide num_samples.
    num_batches = (num_samples + batch_size - 1) // batch_size
    latent_shape = (4, args.latent_size, args.latent_size)

    with torch.no_grad():
        for n in range(num_batches):
            if n % 100 == 0:
                print(n, end=" ", flush=True)
            cur = min(batch_size, num_samples - n * batch_size)
            # Draw noise per batch instead of pre-allocating every latent on the GPU.
            x0 = torch.randn((cur, *latent_shape), device=args.device)
            images = _euler_sample_and_decode(model, vae, x0, timesteps, dt_base_value)
            fid_calc.update_batch(images, real=False)

    print()
    return fid_calc.get_fid()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Evaluation Configuration')
    parser.add_argument('--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--checkpoint_file', type=str, help="checkpoint to load from, used during model evaluation")
    parser.add_argument('--train_type', type=str, default='ST', help="training type")
    parser.add_argument('--sampling_batch_size', type=int, default=50)
    parser.add_argument('--shortcut_steps', type=int, default=128)
    _args_ = parser.parse_args()

    # Without a checkpoint there is nothing to evaluate; fail loudly instead of
    # silently exiting.
    if not _args_.checkpoint_file:
        parser.error("--checkpoint_file is required")

    _args_.checkpointPath = f"./checkpoints/{_args_.train_type}/{_args_.checkpoint_file}"
    print(f"Evaluating for: {_args_.checkpointPath}, timesteps = {_args_.shortcut_steps}", flush=True)
    device = torch.device(f"cuda:{_args_.gpu}")

    # weights_only=False unpickles arbitrary Python objects (needed here because
    # checkpoint["args"] is an attribute namespace, presumably argparse.Namespace).
    # Only load checkpoints you trust.
    # Load to CPU so the rest of the checkpoint does not occupy GPU memory; the
    # EMA weights are copied onto `device` by load_state_dict below.
    checkpoint = torch.load(_args_.checkpointPath, map_location="cpu", weights_only=False)
    args = checkpoint["args"]
    args.sampling_batch_size = _args_.sampling_batch_size
    args.gpu = _args_.gpu
    args.device = device

    ema = DiT_models[args.model](input_size=args.latent_size, num_classes=1).to(args.device)
    ema.load_state_dict(checkpoint["ema"])
    del checkpoint  # no longer needed; free host memory before evaluation

    fids = eval_model(ema, args=args, timesteps=_args_.shortcut_steps)
    print(fids.item(), flush=True)
