import os
import logging
import time
import glob
import yaml
import numpy as np
import tqdm
import torch
import argparse
from torch import nn
import sys
sys.path.insert(0,'./')
import torchvision.utils as tvu
from guided_diffusion.models import Model
import random
from inversion_utils import *
from utils import *
from math import log10, sqrt

# Read the task-level settings for the inpainting run. Later parts of the
# script read the 'seed', 'lr', 'Denoising_steps' and 'Optimization_steps'
# entries from this mapping.
with open('configs/inpaint.yml', 'r') as task_file:
    task_config = yaml.safe_load(task_file)

# Echo every top-level setting so the configuration in use is visible in the
# console output of the run.
for name, value in task_config.items():
    print(name, ':', value)

### Reproducibility
# Print tensors without scientific notation.
torch.set_printoptions(sci_mode=False)
# ensure_reproducibility comes from one of the star imports; presumably it
# seeds the random number generators with the configured seed -- confirm.
ensure_reproducibility(task_config['seed'])



### Parse the model config file and load the denoising model
with open( "celeba_hq.yml", "r") as model_cfg_file:
    config1 = yaml.safe_load(model_cfg_file)
# dict2namespace (star-imported helper) wraps the parsed YAML so that settings
# are read as attributes, e.g. config.diffusion.beta_start.
config = dict2namespace(config1)
model = Model(config)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info("Using device: {}".format(device))
config.device = device

# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints from a trusted source.
ckpt = "checkpoints/celeba_hq.ckpt"
state_dict = torch.load(ckpt, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Freeze every network weight so no gradients are tracked for the model's
# parameters, then wrap the model in DataParallel.
model.requires_grad_(False)
model = torch.nn.DataParallel(model)

### Define the DDIM scheduler
# The beta schedule parameters come from the model config; the number of
# denoising steps comes from the task config.
diffusion_cfg = config.diffusion
scheduler = DDIMScheduler(
    beta_start=diffusion_cfg.beta_start,
    beta_end=diffusion_cfg.beta_end,
    beta_schedule=diffusion_cfg.beta_schedule,
)
scheduler.set_timesteps(task_config['Denoising_steps'])




# Load the reference image, resize it to 256x256, rescale 8-bit pixel values
# from [0, 255] to [-1, 1] and lay it out as a (1, C, H, W) float32 tensor.
# `Image` is provided by one of the star imports (presumably PIL.Image --
# confirm).
#
# The explicit RGB conversion guarantees exactly three channels: a grayscale
# file would otherwise yield a 2-D array and make permute(2, 0, 1) fail, and
# an RGBA file would yield a 4-channel target that no longer matches the
# 3-channel prediction it is compared against later. Images that are already
# RGB keep the same pixel values.
init_image = Image.open("imgs/00205.png").convert("RGB").resize((256, 256))
img_np = np.array(init_image).astype(np.float32) / 255 * 2 - 1
img = torch.tensor(img_np).permute(2,0,1).unsqueeze(0)



# The latent is the only tensor handed to the optimizer. It is drawn from a
# standard normal with the channel count and spatial size declared in the
# model config, then moved to the target device.
latent_shape = (1, config.model.in_channels, config.data.image_size, config.data.image_size)
latent = nn.Parameter(torch.randn(*latent_shape).to(device))

# Mean-squared error is the reconstruction loss (nn.L1Loss() was the
# alternative noted by the original author).
l2_loss = nn.MSELoss()

# A single Adam parameter group holding the latent, using the learning rate
# from the task config.
latent_group = {'params': [latent], 'lr': task_config['lr']}
optimizer = torch.optim.Adam([latent_group])



# Binary inpainting mask over the 256x256 image: 1 marks pixels that are kept
# from the reference image, 0 marks the hole to be filled in. The hole is the
# centred 80x80 square covering rows and columns 88..167; it is zeroed with a
# single slice assignment instead of a per-pixel Python loop.
mask = np.ones((256, 256))
mask[128-40:128+40, 128-40:128+40] = 0.


# Mask as a (1, 1, 256, 256) tensor so it broadcasts over the batch and
# channel dimensions. It is moved with .to(device) rather than a hard-coded
# .cuda() so the CPU fallback selected when the model was loaded keeps
# working on machines without a GPU.
# Note: the mask keeps NumPy's default float64 dtype, so products with
# float32 tensors are promoted to float64.
t_mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0).to(device)



# Bookkeeping scalars; nothing in the visible part of the script reads them
# after this point.
last_loss = 100.
count = 0.
best_x_t = 0.
best_psnr = 0.
last_psnr=0.
# Timesteps are used in the order the scheduler provides them (the reversed
# order via .flip(0) was left disabled by the original author).
timesteps = scheduler.timesteps#.flip(0)

# `import tqdm` at the top of the file binds the tqdm *module*, which is not
# callable; a bare tqdm(...) call only works if one of the star imports
# happens to re-export the class. Import the progress bar explicitly so the
# loop does not depend on that.
from tqdm import tqdm as progress_bar

# The target never changes, so move it to the device and mask it once instead
# of on every optimisation step. .to(device) replaces the hard-coded .cuda()
# calls so the CPU fallback chosen at model-load time keeps working.
target = img.to(device)
masked_target = target * t_mask

# Optimise the latent so that the image produced by running all scheduler
# timesteps from it matches the reference image on the unmasked pixels.
for k in range(task_config['Optimization_steps']):
    optimizer.zero_grad()

    # Start the trajectory from the latent and step through every timestep.
    x_t = latent
    for t in progress_bar(timesteps):
        t1 = (torch.ones(1) * t).to(device)
        # The model forward pass runs under no_grad, so the gradient reaches
        # `latent` only through the arithmetic inside scheduler.step.
        # NOTE(review): this looks deliberate (it avoids backpropagating
        # through the network) -- confirm before changing.
        with torch.no_grad():
            noise_pred = model(x_t, t1)
        # Keep only the first three output channels.
        noise_pred = noise_pred[:, :3]
        x_t = scheduler.step(noise_pred, t, x_t, return_dict=True, use_clipped_model_output=True).prev_sample

    # The loss is computed only where the mask is 1; the hole is left
    # unconstrained.
    loss = l2_loss(x_t*t_mask, masked_target)
    loss.backward()
    optimizer.step()
    print(k, 'loss:', loss.item())

    # Overwrite the preview after every step. Side by side: the current
    # output, the reference, the masked reference, and the composite that
    # keeps reference pixels outside the hole and generated pixels inside it.
    # `process` and `Image` come from the star imports; process(x, 0)
    # presumably returns an image array for batch item 0 -- confirm.
    panels = [
        process(x_t, 0),
        process(target, 0),
        process(masked_target, 0),
        process(masked_target + x_t * (1.-t_mask), 0),
    ]
    Image.fromarray(np.concatenate(panels, 1)).save('inpainted.png')


