import os
import logging
import time
import glob
import yaml
import numpy as np
import tqdm
import torch
import argparse
from torch import nn
import sys
sys.path.insert(0,'./')
import torchvision.utils as tvu
from guided_diffusion.models import Model
import random
from inversion_utils import *
from utils import *
from math import log10, sqrt

# Load the task settings for the super-resolution run.
with open('configs/sr.yml', 'r') as task_file:
    task_config = yaml.safe_load(task_file)

# Echo every setting so the configuration of this run is visible on stdout.
for name, value in task_config.items():
    print(name, ':', value)

### Reproducibility
# Print tensors in plain (non-scientific) notation and pass the configured
# seed to the project's ensure_reproducibility helper.
torch.set_printoptions(sci_mode=False)
ensure_reproducibility(task_config['seed'])


### Parse the model config file and load the denoising model
with open("celeba_hq.yml", "r") as f:
    config1 = yaml.safe_load(f)
# dict2namespace (project helper) turns the parsed YAML into attribute-style
# access, e.g. config.diffusion.beta_start.
config = dict2namespace(config1)
model = Model(config)
ckpt = "checkpoints/celeba_hq.ckpt"

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# The root logger defaults to WARNING, so without explicit configuration the
# INFO record below would be dropped. basicConfig() does nothing when the root
# logger already has handlers, so an existing logging setup is left alone.
logging.basicConfig(level=logging.INFO)
logging.info("Using device: %s", device)
config.device = device

# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(torch.load(ckpt, map_location=device))
model.to(device)
model.eval()
# The network is used as a fixed denoiser: none of its weights are trained.
for param in model.parameters():
    param.requires_grad = False
model = torch.nn.DataParallel(model)

### Define the DDIM scheduler
# The beta schedule settings come from the model config; the number of
# denoising steps comes from the task config.
diffusion_cfg = config.diffusion
scheduler = DDIMScheduler(
    beta_start=diffusion_cfg.beta_start,
    beta_end=diffusion_cfg.beta_end,
    beta_schedule=diffusion_cfg.beta_schedule,
)
scheduler.set_timesteps(task_config['Denoising_steps'])




# Load the reference image, force three RGB channels (so RGBA / palette /
# grayscale PNGs do not break the HWC -> CHW permute below), resize it to
# 256x256 and rescale pixel values from [0, 255] to [-1, 1].
init_image = Image.open("imgs/00205.png").convert("RGB").resize((256, 256))
img_np = np.array(init_image).astype(np.float32) / 255 * 2 - 1
# HWC -> CHW, plus a leading batch dimension of 1.
img = torch.tensor(img_np).permute(2, 0, 1).unsqueeze(0)

# Degradation operator: average-pool down to (256 // ratio) pixels per side.
# AdaptiveAvgPool2d has no learnable parameters, so there is nothing to freeze.
# Use `device` (not a hard-coded .cuda()) so the CPU fallback keeps working.
low_res_side = 256 // task_config['downsampling_ratio']
downsampling_op = torch.nn.AdaptiveAvgPool2d((low_res_side, low_res_side)).to(device)

# Low-resolution observation that the optimisation below tries to reproduce.
downsampled = downsampling_op(img.to(device))


# Denoising timesteps, in the order the scheduler produced them.
timesteps = scheduler.timesteps

# Trainable latent: Gaussian noise with a batch of one and the channel count /
# image size taken from the model config. It is the only tensor handed to the
# optimiser. The noise is sampled first and then moved to `device`.
latent_shape = (
    1,
    config.model.in_channels,
    config.data.image_size,
    config.data.image_size,
)
latent = torch.nn.Parameter(torch.randn(latent_shape).to(device))

# Mean-squared-error objective and an Adam optimiser over the latent, with the
# learning rate from the task config.
l2_loss = nn.MSELoss()
optimizer = torch.optim.Adam(
    [
        {
            'params': latent,
            'lr': task_config['lr'],
        }
    ]
)

# Bookkeeping values; nothing in the visible part of this script reads them.
last_loss = 100.
count = 0.
best_x_t = 0.
best_psnr = 0.
last_psnr = 0.

# Bind the progress-bar callable explicitly: the file-level `import tqdm` binds
# the *module*, which is not callable, so calling `tqdm(...)` only works if a
# star-import happens to rebind the name to the class.
from tqdm import tqdm as progress_bar

# Write the preview image every SAVE_EVERY optimisation steps.
SAVE_EVERY = 1

for k in range(task_config['Optimization_steps']):
    optimizer.zero_grad()

    # Run the full denoising trajectory, starting from the trainable latent.
    x_t = latent
    for t in progress_bar(timesteps):
        # Timestep as a 1-element tensor on the same device as the model.
        t1 = (torch.ones(1) * t).to(device)
        # The network forward pass is excluded from autograd; the scheduler
        # update below is not, so any gradient reaching `latent` comes from the
        # scheduler arithmetic only (assuming scheduler.step is differentiable
        # w.r.t. its sample argument -- TODO confirm).
        with torch.no_grad():
            noise_pred = model(x_t, t1)
        # Keep only the first three output channels.
        noise_pred = noise_pred[:, :3]
        x_t = scheduler.step(
            noise_pred,
            t,
            x_t,
            return_dict=True,
            use_clipped_model_output=True,
        ).prev_sample

    # Data-fidelity loss: the downsampled reconstruction should match the
    # low-resolution observation.
    loss = l2_loss(downsampling_op(x_t), downsampled)
    loss.backward()
    optimizer.step()
    print(k, 'loss:', loss.item())

    if k % SAVE_EVERY == 0:
        # Current reconstruction and the reference image, concatenated along
        # axis 1, overwrite the same preview file each time.
        preview = np.concatenate([process(x_t, 0), process(img.to(device), 0)], 1)
        Image.fromarray(preview).save('super_resolution.png')


