from Classifer_Model import EncoderUnet
from ResNet_Time import ResNet_Time,BasicBlock_Time
from diffusers import UNet2DModel
def create_classifier_models(model_name, num_classes=10, **kwargs):
    """Build a classifier network selected by name.

    Args:
        model_name: "ResNet" builds ``ResNet_Time`` with ``BasicBlock_Time``
            blocks; "EncoderUnet" builds ``EncoderUnet``.
        num_classes: number of output classes (passed as ``num_classes`` to
            ResNet_Time and as ``out_dim`` to EncoderUnet).
        **kwargs: forwarded unchanged to the selected model's constructor.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names
            (previously this silently returned None).
    """
    if model_name == "ResNet":
        return ResNet_Time(BasicBlock_Time, num_classes=num_classes, **kwargs)
    if model_name == "EncoderUnet":
        return EncoderUnet(out_dim=num_classes, **kwargs)
    raise ValueError(
        f"Unknown classifier model_name {model_name!r}; "
        "expected 'ResNet' or 'EncoderUnet'"
    )
def create_diffusion_model(model_name="celebahq-128"):
    """Build or load a diffusers ``UNet2DModel`` selected by name.

    Args:
        model_name: one of
            - any string containing "celebahq-128": builds a new 128x128 RGB
              UNet2DModel from the config below (no weights are loaded);
            - "celebahq-256": loads pretrained weights via ``from_pretrained``;
            - "Cifar10": loads "google/ddpm-cifar10-32" via ``from_pretrained``.

    Returns:
        A ``UNet2DModel`` instance.

    Raises:
        ValueError: if ``model_name`` matches none of the options above
            (previously this silently returned None).
    """
    # Substring match (not equality) is intentional: it is what the original
    # code did, so variant names containing "celebahq-128" map to this config.
    if "celebahq-128" in model_name:
        return UNet2DModel(
            sample_size=128,  # the target image resolution
            in_channels=3,  # the number of input channels, 3 for RGB images
            out_channels=3,  # the number of output channels
            layers_per_block=2,  # how many ResNet layers to use per UNet block
            # the number of output channels for each UNet block (one per down/up block)
            block_out_channels=(64, 64, 128, 128, 256, 256, 512, 512),
            down_block_types=(
                "DownBlock2D",  # a regular ResNet downsampling block
                "DownBlock2D",
                "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
                "DownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
                "DownBlock2D",
                "AttnDownBlock2D",
            ),
            up_block_types=(
                "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",  # a regular ResNet upsampling block
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",
                "UpBlock2D",
            ),
        )
    # from_pretrained is a classmethod: call it on the class directly instead of
    # on a throwaway default-config UNet2DModel() instance, which would allocate
    # a full, unused network first.
    if model_name == "celebahq-256":
        # NOTE(review): the Hugging Face Hub repo appears to be published as
        # "google/ddpm-ema-celebahq-256" (hyphen); confirm this underscore id
        # resolves (e.g. a local directory) before relying on it.
        return UNet2DModel.from_pretrained("google/ddpm-ema_celebahq-256")
    if model_name == "Cifar10":
        return UNet2DModel.from_pretrained("google/ddpm-cifar10-32")
    raise ValueError(
        f"Unknown diffusion model_name {model_name!r}; expected a name containing "
        "'celebahq-128', or 'celebahq-256', or 'Cifar10'"
    )
import torch 
from torchvision import datasets,transforms
def check_fined_loaded_parameters(model, dataset, batch_size=128, device="cuda"):
    """Measure top-1 classification accuracy of ``model`` on ``dataset``.

    Inference runs under ``torch.no_grad()``. If ``model`` is an
    ``nn.Module``, it is switched to eval mode for the pass, so BatchNorm and
    Dropout behave as at inference. Its previous train/eval mode is restored
    afterwards, even on error.

    Args:
        model: callable mapping a batch of inputs to per-class scores; the
            predicted class is ``argmax`` over dim 1.
        dataset: map-style dataset yielding ``(data, target)`` pairs.
        batch_size: number of samples per forward pass.
        device: device that inputs and targets are moved to; must match the
            model's device. Defaults to "cuda" (the previous hard-coded value).

    Returns:
        0-dim tensor equal to (#correct predictions) / len(dataset). It is
        also printed.

    Raises:
        ValueError: if ``dataset`` is empty (accuracy is undefined).
    """
    num_samples = len(dataset)
    if num_samples == 0:
        raise ValueError("dataset is empty; accuracy is undefined")

    # Sample order does not affect accuracy, so skip shuffling.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    # Only nn.Module has train/eval state; plain callables are used as-is.
    was_training = model.training if isinstance(model, torch.nn.Module) else None
    if was_training is not None:
        model.eval()

    correct = torch.zeros((), dtype=torch.long, device=device)
    try:
        # No autograd graph is needed for evaluation; this saves memory.
        with torch.no_grad():
            for data, target in dataloader:
                output = model(data.to(device))
                correct += (torch.argmax(output, dim=1) == target.to(device)).sum()
    finally:
        if was_training is not None:
            model.train(was_training)

    accuracy = correct / num_samples
    print(accuracy)
    return accuracy