import argparse
import math
import os
import numpy as np
import torch
import torch.distributed as dist
#import torchvision
from ddp_utils import init_processes
from models import create_network
from pytorch_fid.fid_score import calculate_fid_given_paths
from sampler.karras_sample import karras_sample
from sampler.random_util import get_generator
from torch import nn
from torch.multiprocessing import Process
#from torchdiffeq import odeint_adjoint as odeint
from tqdm import tqdm
from PIL import Image
from diffusers.models import AutoencoderKL


def _str2bool(value: str) -> bool:
    """Convert a command-line string such as ``true``/``false``/``1``/``0`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by this sampling script.

    The full namespace is handed to ``create_network``, so options that this
    script does not read itself are kept as they are.
    """
    parser = argparse.ArgumentParser("direct models parameters")
    parser.add_argument("--use_origin_adm", action="store_true")
    parser.add_argument(
        "--generator",
        type=str,
        default="determ",
        help="type of seed generator",
        choices=["dummy", "determ", "determ-indiv"],
    )
    parser.add_argument("--seed", type=int, default=100, help="seed used for initialization")
    parser.add_argument("--compute_fid", action="store_true", default=False, help="whether or not compute FID")
    parser.add_argument("--compute_nfe", action="store_true", default=False, help="whether or not compute NFE")
    parser.add_argument("--measure_time", action="store_true", default=False, help="whether or not measure time")
    parser.add_argument("--epoch_id", type=int, default=1000)
    parser.add_argument("--n_sample", type=int, default=50000, help="number of sampled images")
    parser.add_argument(
        "--model_type",
        type=str,
        default="DiT-B/2",
        help="model_type",
        choices=["adm", "ncsn++", "ddpm++", "DiT-B/2", "DiT-L/2", "DiT-XL/2"],
    )
    parser.add_argument("--image_size", type=int, default=256, help="size of image")
    parser.add_argument("--f", type=int, default=8, help="downsample rate of input image by the autoencoder")
    parser.add_argument(
        "--scale_factor",
        type=float,
        default=0.18215,
        help="factor the latent is divided by before it is decoded by the autoencoder",
    )
    parser.add_argument("--num_in_channels", type=int, default=4, help="in channel image")
    parser.add_argument("--num_out_channels", type=int, default=4, help="out channel image")
    parser.add_argument("--nf", type=int, default=256, help="channel of image")
    # NOTE: store_false with default=True -- passing --centered turns the flag OFF.
    parser.add_argument("--centered", action="store_false", default=True, help="-1,1 scale")
    # type=bool would parse "False" as True; use an explicit converter instead.
    parser.add_argument("--resamp_with_conv", type=_str2bool, default=True)
    parser.add_argument("--num_res_blocks", type=int, default=2, help="number of resnet blocks per scale")
    parser.add_argument("--num_heads", type=int, default=4, help="number of head")
    parser.add_argument("--num_head_upsample", type=int, default=-1, help="number of head upsample")
    parser.add_argument("--num_head_channels", type=int, default=-1, help="number of head channels")
    parser.add_argument(
        "--attn_resolutions", nargs="+", type=int, default=(16, 8, 4), help="resolution of applying attention"
    )
    parser.add_argument("--ch_mult", nargs="+", type=int, default=(1, 2, 3, 4), help="channel mult")
    parser.add_argument("--label_dim", type=int, default=0, help="label dimension, 0 if unconditional")
    parser.add_argument("--augment_dim", type=int, default=0, help="dimension of augmented label, 0 if not used")
    parser.add_argument("--dropout", type=float, default=0.0, help="drop-out rate")
    parser.add_argument("--num_classes", type=int, default=1, help="num classes")
    parser.add_argument(
        "--label_dropout",
        type=float,
        default=0.0,
        help="Dropout probability of class labels for classifier-free guidance",
    )
    parser.add_argument("--cfg_scale", type=float, default=1.0, help="Scale for classifier-free guidance")
    parser.add_argument("--pretrained_autoencoder_ckpt", type=str, default="stabilityai/sd-vae-ft-mse")
    parser.add_argument("--output_log", type=str, default="")

    #######################################
    parser.add_argument("--num_steps", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=16, help="sample generating batch size")

    # sampling argument
    # NOTE: store_false with default=True -- passing --use_karras_samplers turns it OFF.
    parser.add_argument("--use_karras_samplers", action="store_false", default=True)
    parser.add_argument("--method", type=str, default="euler", help="solver_method")
    parser.add_argument("--step_size", type=float, default=0.01, help="step_size")
    parser.add_argument("--perturb", action="store_true", default=False)

    # input / output locations (defaults match the previously hard-coded paths)
    parser.add_argument("--ckpt_path", type=str, default="direct_models.pth", help="path of the model checkpoint")
    parser.add_argument("--output_path", type=str, default="generated.png", help="where the generated image is saved")
    return parser


def main(argv=None) -> None:
    """Sample one latent, decode it with the autoencoder and save it as an image.

    Args:
        argv: optional list of command-line tokens; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Side effects:
        Loads the checkpoint at ``--ckpt_path`` and the autoencoder named by
        ``--pretrained_autoencoder_ckpt``, then writes the image to
        ``--output_path``.
    """
    args = build_parser().parse_args(argv)

    # Fall back to the CPU instead of crashing when CUDA is unavailable.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = create_network(args).to(device)
    first_stage_model = AutoencoderKL.from_pretrained(args.pretrained_autoencoder_ckpt).to(device)
    first_stage_model.eval()

    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(args.ckpt_path, map_location=device)
    model.load_state_dict(ckpt, strict=True)
    model.eval()

    # Latent geometry follows the CLI options; the defaults give (1, 4, 32, 32).
    latent_size = args.image_size // args.f

    # Seed right before drawing the noise so --seed alone determines the sample.
    torch.manual_seed(args.seed)

    with torch.no_grad():
        x0 = torch.randn(1, args.num_in_channels, latent_size, latent_size).to(device)
        # Single forward pass: the network output is added to the noise.
        # The first argument is a vector of ones with one entry per batch item
        # (presumably the time/conditioning input -- confirm against the model).
        x1 = x0 + model(x0.new_ones([x0.shape[0]]), x0)
        fake_image = first_stage_model.decode(x1 / args.scale_factor).sample

    # Map the decoder output from [-1, 1] to [0, 1] and clip out-of-range values.
    fake_image = torch.clamp((fake_image + 1.0) / 2.0, 0, 1)

    # (C, H, W) -> (H, W, C) uint8 array for PIL; the batch holds one image.
    pixels = (fake_image[0].permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
    Image.fromarray(pixels).save(args.output_path)


if __name__ == "__main__":
    main()


