from Inversion_Image import image_reversion
import torch 
import torch.nn as nn
#import diffusers
from Inversion_Image import image_reversion
from SAM import sam
import numpy 
import torchvision
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline,DPMSolverMultistepScheduler,DDIMPipeline,DDIMScheduler
from diffusers import UNet2DModel
import torch.nn.functional as F
from diffusers import DDPMScheduler
from utils import ts_to_weight_func,fd_preprocess
from einops import rearrange
# Hugging Face Hub id of the pretrained CIFAR-10 DDPM (not referenced below;
# Model_Inversion receives the id through its `dir` argument instead).
model_id = "google/ddpm-cifar10-32"
# Number of classifier output classes (10, matching CIFAR-10).
num_classes = 10
# Default device; the __main__ block moves the classifier here.
torch_device = "cuda"
def conf_fn(latent=None, classification_target=None, t=None, inversion_model=None):
    """Classifier-guidance term used by ``sampling`` (not implemented yet).

    ``sampling`` calls ``conf_fn(latent, classification_target, t, inversion_model)``
    whenever it is given an ``inversion_model``. It adds
    ``beta * sqrt(1 - alphas_cumprod[t]) * conf_fn(...)`` to the predicted
    noise, so a real implementation must return a tensor broadcastable to
    the noise-model output.

    The previous stub took no parameters, so that call failed with a
    confusing ``TypeError``; this version accepts the arguments and fails
    with an explicit message instead.

    Raises:
        NotImplementedError: always, until the guidance term is implemented.
    """
    raise NotImplementedError(
        "conf_fn (classifier guidance) is not implemented; "
        "call sampling() with inversion_model=None"
    )
def sampling(model, scheduler, latent, beta=1, inversion_model=None, classification_target=None, torch_device="cuda", n_classes=None):
    """Run the scheduler's reverse-diffusion loop from ``latent`` and record every step.

    The noise-model call is wrapped in ``torch.no_grad()``, but
    ``scheduler.step`` runs with autograd enabled. The recorded samples
    therefore stay differentiable w.r.t. ``latent`` through the scheduler's
    update rule only.

    Args:
        model: noise-prediction network, called as ``model(x, timestep=t).sample``.
        scheduler: diffusers-style scheduler whose ``timesteps`` are already set.
        latent: starting noise, either a tensor or an ``image_reversion``
            module (its ``.image`` is used). It is scaled by
            ``scheduler.init_noise_sigma`` exactly once.
        beta: scale of the ``conf_fn`` guidance term (only used when
            ``inversion_model`` is given).
        inversion_model: optional classifier used for guidance via ``conf_fn``.
        classification_target: guidance target. An ``int`` class index is
            converted to a one-hot row of shape ``(1, n_classes)``.
        torch_device: device for the one-hot target and the returned tensors.
        n_classes: number of classes for the one-hot conversion. Defaults to
            ``config.num_classes`` if a module-level ``config`` exists (the
            ``__main__`` block defines one), otherwise to ``num_classes``.

    Returns:
        ``(all_latents, ts)``.
        ``all_latents`` stacks the sample after every step along dim 1, each
        mapped by ``x / 2 + 0.5`` and clamped to ``[0, 1]``.
        ``ts`` holds the visited timesteps, one identical row per batch
        element, with shape ``(all_latents.shape[0], num_steps)``.
    """
    if isinstance(latent, image_reversion):
        # Use the raw learnable tensor; it is scaled by init_noise_sigma once
        # below (it used to be scaled both here and there).
        latent = latent.image
    if isinstance(classification_target, int):
        if n_classes is None:
            # `config` only exists when the file runs as a script; fall back to
            # the module constant so imports of this module do not NameError.
            cfg = globals().get("config")
            n_classes = cfg.num_classes if cfg is not None else num_classes
        classification_target = torch.eye(n_classes)[classification_target].unsqueeze(0).to(torch_device)
    latent = latent * scheduler.init_noise_sigma

    frames = []
    visited = []
    for t in scheduler.timesteps:
        with torch.no_grad():
            # Scale a separate tensor for the network: scheduler.step must
            # receive the *unscaled* sample (matters for schedulers whose
            # scale_model_input is not the identity).
            model_input = scheduler.scale_model_input(latent, t)
            residual = model(model_input, timestep=t).sample
            if inversion_model is not None:
                residual = residual + beta * (1 - scheduler.alphas_cumprod[t]).sqrt() * conf_fn(model_input, classification_target, t, inversion_model)
        latent = scheduler.step(residual, t, latent).prev_sample
        visited.append(t)
        frame = latent.clone()
        if frame.dim() == 3:
            # Promote an unbatched sample to a batch of one.
            frame = frame.unsqueeze(0)
        frames.append((frame / 2 + 0.5).clamp(0, 1))

    all_latents = torch.stack(frames, dim=1).to(torch_device)
    ts = torch.tensor([int(step) for step in visited], dtype=torch.long, device=torch_device)
    ts = ts.unsqueeze(0).repeat(all_latents.shape[0], 1)
    return all_latents, ts
# Module-level list of BatchNorm feature-statistics penalties. hook_fn appends
# one entry per hooked BatchNorm forward, Model_Inversion.cal_loss clears it,
# and Model_Inversion.cal_bn_loss sorts and aggregates it.
bn_loss_save = []
from einops import rearrange
from utils import fd_preprocess
from accelerate import Accelerator
from utils import ts_to_weight_func
def hook_fn(module, input, output):
    """Forward hook implementing DeepInversion's feature-distribution regulariser.

    Intended for ``nn.BatchNorm2d`` layers. It takes the per-channel mean and
    biased variance of the layer's input batch (``input[0]``, indexed as
    ``(N, C, H, W)``) and compares them with the layer's running statistics.
    It then appends
    ``||running_var - var||_2 + ||running_mean - mean||_2`` to the
    module-level ``bn_loss_save`` list. ``output`` is unused.
    """
    feats = input[0]
    channels = feats.shape[1]
    # Move channels first and flatten everything else, so the variance is
    # taken per channel over all batch and spatial positions.
    per_channel = feats.permute(1, 0, 2, 3).contiguous().view([channels, -1])
    batch_var = per_channel.var(1, unbiased=False)
    batch_mean = feats.mean([0, 2, 3])
    # Plain L2 distance between the two sets of moments; a KL divergence
    # between the per-channel Gaussians would be an alternative.
    var_gap = torch.norm(module.running_var.data - batch_var, 2)
    mean_gap = torch.norm(module.running_mean.data - batch_mean, 2)
    bn_loss_save.append(var_gap + mean_gap)

class Model_Inversion:
    """Invert a classifier by optimising the starting noise of a diffusion sampler.

    A learnable latent (an ``image_reversion`` module) is passed to
    ``sampling`` together with a pretrained ``UNet2DModel`` and a DDIM
    scheduler. The generated images are scored by ``inversion_model``. The
    optimiser, AdamW or ``sam.SAM`` wrapping AdamW, only updates the
    latent's parameters. It minimises a timestep-weighted cross-entropy plus
    a BatchNorm feature-statistics penalty collected by ``hook_fn``.
    """

    def __init__(self, batch=3, channels=3, height=32, width=32, dir="google/ddpm-cifar10-32", timesteps=10, torch_device="cuda", usesam=False, epoches=50) -> None:
        """Store the run configuration and load the pretrained UNet.

        Args:
            batch, channels, height, width: shape given to ``image_reversion``
                for the learnable latent.
            dir: model id or path for ``UNet2DModel.from_pretrained``.
            timesteps: number of scheduler inference steps.
            torch_device: device for the UNet and the latent.
            usesam: if truthy, optimise with ``sam.SAM`` around AdamW.
            epoches: the optimisation loop runs ``epoches + 1`` iterations.
        """
        # Attribute name `timesetps` (sic) kept for compatibility.
        self.timesetps = timesteps
        self.dir = dir
        self.torch_device = torch_device
        self.usesam = usesam
        self.epoches = epoches
        self.channels = channels
        self.height = height
        self.width = width
        self.batch = batch
        self.model = UNet2DModel.from_pretrained(self.dir).to(self.torch_device)

    def __call__(self, inversion_model, sampling, show_epoch=50, num_show=10, lr=0.001):
        """Optimise a fresh latent against ``inversion_model`` and return the final images.

        Args:
            inversion_model: classifier called as ``inversion_model(images, ts)``.
                Forward hooks are attached to its ``nn.BatchNorm2d`` layers for
                the duration of this call only.
            sampling: function ``(model, scheduler, latents, torch_device=...)
                -> (images, ts)`` that runs the reverse diffusion.
            show_epoch: period of the (currently unreachable) visualisation branch.
            num_show: number of trailing denoising steps that are scored.
            lr: optimiser learning rate.

        Returns:
            ``images[:, -1]`` from the last loss evaluation (detached). These
            are the images produced *before* the final optimiser step.
        """
        scheduler = self._build_scheduler()
        all_up = 100  # upper bound on BN terms considered by cal_bn_loss
        latents = image_reversion(self.batch, self.channels, self.height, self.width).to(self.torch_device)
        optim = self._build_optimizer(latents, lr)
        loss_function = nn.CrossEntropyLoss(reduction="none")

        # Keep the hook handles so the hooks can be removed afterwards. Without
        # this, every call stacked another hook on the same classifier, and
        # each BN layer was counted once per previous call.
        hook_handles = [
            module.register_forward_hook(hook_fn)
            for module in inversion_model.modules()
            if isinstance(module, nn.BatchNorm2d)
        ]
        try:
            bar = tqdm(range(self.epoches + 1))
            classification_target = None
            for epoch in bar:
                classification_target, images, plot_images, prediction, loss = self.cal_loss(
                    inversion_model, sampling, num_show, scheduler, self.model,
                    all_up, latents, loss_function, classification_target,
                )
                if self.usesam:
                    # Handed to SAM's step, which presumably re-evaluates loss and
                    # gradients at the perturbed point. It is called within this
                    # iteration, so capturing loop variables is safe.
                    def closure():
                        return self.cal_loss(
                            inversion_model, sampling, num_show, scheduler, self.model,
                            all_up, latents, loss_function, classification_target,
                        )[-1]
                    optim.step(closure)
                else:
                    optim.step()
                optim.zero_grad()
                bar.set_description("loss:{}".format(loss.item()))
                # NOTE(review): `epoch % (show_epoch + 1)` is never negative, so
                # this visualisation branch never runs; it looks deliberately
                # disabled.
                if epoch % (show_epoch + 1) == -1:
                    prediction = rearrange(prediction, "(a b) ... -> a b ...", a=self.batch, b=num_show).argmax(dim=-1)
                    print("epoch:{}".format(epoch), "classification loss:{}".format(loss.item()), "bn_loss:{}".format(sum(bn_loss_save)), classification_target[:, 0])
                    n_cols = min(15, num_show)
                    for j in range(plot_images.shape[0]):
                        plt.figure(figsize=(15, 20))
                        print(prediction[j])
                        for i in range(n_cols):
                            plt.subplot(1, n_cols, i + 1)
                            plt.imshow(plot_images[j, scheduler.num_inference_steps - n_cols + i])
                        plt.show()
        finally:
            for handle in hook_handles:
                handle.remove()
        return images[:, -1]

    def _build_scheduler(self):
        """Return a DDIM scheduler built from the pretrained DDPM scheduler config."""
        # from_pretrained / from_config are classmethods, so no throwaway
        # default instances are needed. NOTE: the config is always read from
        # google/ddpm-cifar10-32, independently of self.dir.
        ddpm_scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")
        scheduler = DDIMScheduler.from_config(ddpm_scheduler.config)
        scheduler.set_timesteps(self.timesetps)
        return scheduler

    def _build_optimizer(self, latents, lr):
        """Return SAM around AdamW if ``usesam`` is truthy, else AdamW, over ``latents.parameters()``."""
        if self.usesam:
            return sam.SAM(latents.parameters(), torch.optim.AdamW, lr=lr)
        return torch.optim.AdamW(latents.parameters(), lr=lr)

    def cal_loss(self, inversion_model, sampling, num_show, scheduler, model, all_up, latents, loss_function, classification_target):
        """Sample images from ``latents``, score them, backpropagate, and return the loss.

        Calls ``loss.backward(retain_graph=True)`` itself, so gradients are
        accumulated into the leaf tensors ``loss`` depends on before returning.

        If ``classification_target`` is ``None``, pseudo-labels are taken from
        the classifier's argmax on the last denoising step, repeated
        ``num_show`` times per sample.

        Returns:
            ``(classification_target, ori_images, plot_images, prediction, loss)``.
            ``ori_images`` is a detached copy of the sampled images.
            ``plot_images`` is a channels-last numpy copy of the same images.
        """
        images, ts = sampling(model, scheduler, latents, torch_device=self.torch_device)
        ori_images = images.clone().detach()
        # One weight per denoising step, from the first row of ts (the visible
        # `sampling` repeats the same row for every sample). Keep the last
        # num_show steps and flatten to (batch * num_show,).
        step_weights = torch.stack(
            [torch.tensor(ts_to_weight_func(ts[0, i])) for i in range(ts.shape[1])], dim=0
        )
        ts_weight = step_weights.unsqueeze(0).repeat(ts.shape[0], 1)[:, -num_show:]
        ts_weight = rearrange(ts_weight, "a b ... -> (a b) ...")
        plot_images = images.clone().detach().cpu().permute(0, 1, 3, 4, 2).numpy()

        if classification_target is None:
            # argmax carries no gradient, so no graph is needed here.
            with torch.no_grad():
                classification_target = inversion_model(images[:, -1], ts[:, -1]).argmax(dim=-1).unsqueeze(1).repeat(1, num_show)
        # Reset *after* the pseudo-label pass, so only the scored forward below
        # contributes BN penalties (the first epoch used to mix in both).
        bn_loss_save.clear()
        scored_images, temp_cross_entropy, scored_ts = fd_preprocess(images, classification_target, ts, num_show=num_show)
        prediction = inversion_model(scored_images, scored_ts)
        bn_losses = self.cal_bn_loss(all_up)
        loss = (loss_function(prediction, temp_cross_entropy) * ts_weight.to(self.torch_device) + bn_losses).mean()

        loss.backward(retain_graph=True)
        return classification_target, ori_images, plot_images, prediction, loss

    def cal_bn_loss(self, all_up, losses=None):
        """Aggregate the smallest BatchNorm penalties into one scalar.

        Sorts ``losses`` in place, ascending; it defaults to the module-level
        ``bn_loss_save``. It then walks at most ``all_up`` of the smallest
        terms and keeps each one while the running sum of those already kept
        is <= 1% of the total. The running sum starts at 0, so the smallest
        term is kept whenever the total is non-negative.

        Returns:
            The mean of the kept terms divided by the largest kept term, or
            ``0`` if nothing was kept.
        """
        if losses is None:
            losses = bn_loss_save
        losses.sort()
        budget = sum(losses) * 0.01  # loop-invariant; was recomputed every iteration
        kept = []
        running = 0
        for i in range(min(all_up, len(losses))):
            if running <= budget:
                kept.append(losses[i])
                running = running + losses[i]
            else:
                break
        if not kept:
            return 0
        return torch.stack(kept).mean() / kept[-1]
if __name__=="__main__":
    # Command-line driver: invert the classifier checkpoint
    # config.out_dir + "/epoch302.pth" into images and save each one as a PNG.
    import argparse
    import os
    parser = argparse.ArgumentParser()
    parser.add_argument("--target_dataset", type=str, default="Cifar-10")
    parser.add_argument("--image_generate_nums", type=int, default=200)
    parser.add_argument("--train_epoch", type=int, default=150)
    parser.add_argument("--start_train_epoch", type=int, default=0)
    parser.add_argument("--use_sam", type=int, default=0)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--num_show", type=int, default=30)
    parser.add_argument("--timesteps", type=int, default=50)
    # The loop used to start at a hard-coded extra offset of 105 (apparently to
    # resume a partial run). It is now explicit; the default keeps that behaviour.
    parser.add_argument("--skip_images", type=int, default=105,
                        help="extra image indices to skip after start_train_epoch")
    # Parse once instead of re-parsing sys.argv for every option.
    args = parser.parse_args()
    target_dataset = args.target_dataset
    image_generate_nums = args.image_generate_nums
    train_epoch = args.train_epoch
    start_train_epoch = args.start_train_epoch
    num_show = args.num_show
    lr = args.lr
    timesteps = args.timesteps
    usesam = args.use_sam
    skip_images = args.skip_images
    print(usesam)
    image_save_dir = './output_image/' + target_dataset + "-diff_class_unet" + "-" + str(train_epoch) + "-" + str(lr).split('.')[-1] + "-" + str(num_show) + "-timesteps-" + str(timesteps) + "-sam-" + str(usesam)
    os.makedirs(image_save_dir, exist_ok=True)

    inver = Model_Inversion(batch=15, channels=3, height=32, width=32, dir="google/ddpm-cifar10-32", timesteps=timesteps, torch_device="cuda", usesam=usesam, epoches=train_epoch)
    from Classifer_Model import EncoderUnet
    from train_config import BaseConfig
    # Must stay a module-level name: sampling() looks up `config` as a global.
    config = BaseConfig()
    inversion_model = EncoderUnet(dim=config.unet_dim, out_dim=config.num_classes, dim_mults=config.dim_mults)
    inversion_model.load_state_dict(torch.load(config.out_dir + "/epoch302.pth"))
    inversion_model = inversion_model.to(device=torch_device)
    # NOTE(review): the model is never switched to eval() here. Unless
    # EncoderUnet does so itself, BatchNorm layers run in training mode during
    # inversion (batch statistics; running stats updated) -- confirm intended.

    # NOTE(review): the inversion optimiser only updates the latent's
    # parameters, so these flags only make unused .grad buffers accumulate.
    # Consider freezing them once fd_preprocess is confirmed to keep the
    # images attached to the autograd graph.
    for param in inversion_model.parameters():
        param.requires_grad = True
    print("start training")
    first_num = 1 + start_train_epoch + skip_images
    last_num = image_generate_nums + start_train_epoch  # inclusive
    for num in range(first_num, last_num + 1):
        print("epoch: ", num)
        im = inver(inversion_model, sampling, lr=lr, show_epoch=train_epoch, num_show=num_show)
        for i in range(im.shape[0]):
            torchvision.utils.save_image(im[i], os.path.join(image_save_dir, str(num) + "-" + str(i) + ".png"))