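"""
Like image_sample.py, but uses a noisy image classifier to guide the sampling
process towards more realistic images of the requested classes. The guidance
strength is set with --classifier_guidance (0 disables guidance).
"""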

import argparse
import os

import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
import matplotlib.pyplot as plt 
from guided_diffusion import dist_util, logger
from guided_diffusion.script_util import (
    NUM_CLASSES,
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
    create_gaussian_diffusion
)
from Classifer_Model import EncoderUnet
import torch 
from train_config import BaseConfig as config
from train_config import OneSampleConfig,BaseConfig,Generate_ClassCond_Config,Generate_OneSample_Config
from diffusers import UNet2DModel
from utils import load_inference_model     
from create_models import create_classifier_models
# Module-level default configuration, created at import time.
# NOTE(review): main() builds its own config; this binding is kept only in case
# other code imports `config` from this module.
config = Generate_OneSample_Config()


def _parse_args():
    """Parse and validate the command-line options controlling generation."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target_dataset",
        type=str,
        default="celebahq-synthetic",
        # The default must also be a valid choice so it can be passed explicitly.
        choices=[
            "cifar10",
            "celebahq",
            "celebahq-128",
            "celebahq-synthetic",
            "Cifar10-synthetic",
            "lsun_church_outdoor",
            "lsun_tower",
            "lsun_kitchen",
            "lsun_classroom",
            "lsun_conference",
            "lsun_dining_room",
            "lsun_restaurant",
            "lsun_library",
            "lsun_museum",
        ],
        help="the target dataset to generate images",
    )
    parser.add_argument("--image_generate_nums", type=int, default=5000,
                        help="total number of images to generate")
    parser.add_argument("--classifier_guidance", type=float, default=0,
                        help="classifier-guidance scale (0 disables guidance)")
    # Accepted for command-line compatibility; not used by this script.
    parser.add_argument("--start_train_epoch", type=int, default=0)
    parser.add_argument("--ddim", type=int, default=1,
                        help="non-zero: DDIM with 100 respaced steps; 0: 1000-step ancestral sampling")
    parser.add_argument("--class_cond", type=int, default=1,
                        help="non-zero: class-conditional config; 0: one-sample config")
    parser.add_argument("--class_id", type=int, default=-1,
                        help="-1 samples labels uniformly over all classes; otherwise every label is this class")
    parser.add_argument("--generate_choose", type=int, default=1500)
    parser.add_argument("--classifier_choose", type=int, default=902)
    parser.add_argument("--batch_size", type=int, default=24,
                        help="number of images sampled per diffusion call")
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")
    return args


def _build_config(args):
    """Select the generation config and apply the batch size.

    Any non-zero ``--class_cond`` (including -1) selects the class-conditional
    config; 0 selects the one-sample config.
    """
    if args.class_cond:
        cfg = Generate_ClassCond_Config(
            dataset_name=args.target_dataset,
            classifier_choose=args.classifier_choose,
            generate_choose=args.generate_choose,
        )
    else:
        cfg = Generate_OneSample_Config()
    cfg.batch_size = args.batch_size
    return cfg


def _class_range(class_id, num_classes):
    """Return the half-open label range ``(low, high)`` that labels are drawn from.

    Raises:
        ValueError: if ``class_id`` is neither -1 nor a valid class index.
    """
    if class_id == -1:
        return 0, num_classes
    if not 0 <= class_id < num_classes:
        raise ValueError(
            f"class_id must be -1 or in [0, {num_classes}), got {class_id}"
        )
    return class_id, class_id + 1


def _build_classifier(cfg):
    """Create the classifier described by ``cfg``, load its weights, and put it in eval mode.

    Raises:
        ValueError: if ``cfg.classifier_type`` is not a supported type.
    """
    if cfg.classifier_type == "ResNet":
        classifier = create_classifier_models(
            cfg.classifier_type, cfg.num_classes, layers=cfg.dim_mults
        )
    elif cfg.classifier_type == "EncoderUnet":
        classifier = create_classifier_models(
            cfg.classifier_type, cfg.num_classes, dim=cfg.unet_dim, dim_mults=cfg.dim_mults
        )
    else:
        raise ValueError(f"unsupported classifier_type: {cfg.classifier_type!r}")
    # Number of trainable parameters.
    print(sum(p.numel() for p in classifier.parameters() if p.requires_grad))
    # Load onto the CPU first so checkpoints saved from a GPU also load on
    # CPU-only hosts; load_state_dict copies them into the module's parameters.
    state_dict = torch.load(cfg.inversion_model_dir, map_location="cpu")
    classifier.load_state_dict(state_dict)
    classifier = classifier.to(cfg.torch_device)
    classifier.eval()
    return classifier


def _make_cond_fn(classifier, scale):
    """Return a classifier-guidance ``cond_fn``, or ``None`` when ``scale`` is 0.

    A zero scale multiplies the guidance gradient by zero, so returning ``None``
    (no conditioning) avoids a classifier forward+backward pass per sampling step.
    """
    if scale == 0:
        return None

    def cond_fn(x, t, y=None):
        # Gradient of log p(y | x_t) with respect to x_t, scaled by `scale`.
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * scale

    return cond_fn


def _make_model_fn(model):
    """Wrap ``model`` so it always receives the class labels ``y``."""
    def model_fn(x, t, y=None):
        assert y is not None
        return model(x, t, y)

    return model_fn


def main():
    """Generate ``--image_generate_nums`` images and save them as PNG files."""
    from tqdm.auto import tqdm

    args = _parse_args()
    if args.ddim:
        diffusion = create_gaussian_diffusion(timestep_respacing="ddim100")
        sample_fn = diffusion.ddim_sample_loop
    else:
        diffusion = create_gaussian_diffusion(steps=1000)
        sample_fn = diffusion.p_sample_loop

    cfg = _build_config(args)
    low, high = _class_range(args.class_id, cfg.num_classes)
    model = load_inference_model(cfg)
    classifier = _build_classifier(cfg)
    print(cfg.batch_size)

    cond_fn = _make_cond_fn(classifier, args.classifier_guidance)
    model_fn = _make_model_fn(model)

    image_save_dir = (
        f"./output_image/{args.target_dataset}"
        f"-class_diffusion-{args.classifier_guidance}"
        f"-ddim-{args.ddim}"
        f"-classcond-{args.class_cond}"
        f"-classid-{args.class_id}"
        f"-generate_choose-{args.generate_choose}"
        f"-classifier_choose-{args.classifier_choose}"
    )
    os.makedirs(image_save_dir, exist_ok=True)

    if args.ddim:
        print("use ddim")
    image_shape = (3, cfg.image_size, cfg.image_size)
    remaining = max(args.image_generate_nums, 0)
    # Ceiling division: exactly enough batches to produce the requested count.
    num_iterations = (remaining + cfg.batch_size - 1) // cfg.batch_size
    bar = tqdm(range(1, num_iterations + 1))
    for step in bar:
        # The final batch is shrunk so exactly image_generate_nums images are produced.
        batch_size = min(cfg.batch_size, remaining)
        classes = th.randint(low=low, high=high, size=(batch_size,), device=cfg.device)
        sample = sample_fn(
            model_fn,
            (batch_size, *image_shape),
            clip_denoised=True,
            model_kwargs={"y": classes},
            cond_fn=cond_fn,
            device=cfg.device,
            progress=True,
        )
        # Map samples from [-1, 1] to uint8 [0, 255] and move channels last (NHWC).
        images = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8)
        images = images.permute(0, 2, 3, 1).contiguous().cpu().numpy()
        for i, (image, label) in enumerate(zip(images, classes.tolist())):
            file_name = f"{step}-{i}-class-{label}.png"
            plt.imsave(os.path.join(image_save_dir, file_name), image)
        remaining -= batch_size
        bar.set_description(f"saved {args.image_generate_nums - remaining} images")


if __name__ == "__main__":
    main()