import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
import os
import matplotlib.pyplot as plt
from my_utils import *
from tqdm import tqdm
import imageio


import os
import re
from typing import List, Optional, Tuple, Union
import sys
sys.path.append('stylegan3')

import click
import dnnlib
import numpy as np
import PIL.Image
import torch
import legacy
from autodecoder import grid_sample_customized




class temp_var(torch.nn.Module):
    """Minimal module that owns a single trainable tensor.

    Wrapping the tensor in an ``nn.Module`` lets callers move it with
    ``.to(device)`` and hand it to an optimizer via ``.parameters()``.

    Attributes:
        x: ``torch.nn.Parameter`` constructed from ``x_init``.
    """

    def __init__(self, x_init):
        super().__init__()
        # Equivalent to ``self.x = Parameter(...)``: Module.__setattr__ routes
        # Parameter assignments through register_parameter.
        self.register_parameter('x', torch.nn.Parameter(x_init))




def forward_operator(x, coords, model):
    """Evaluate ``model`` at ``coords`` and differentiate the result w.r.t. ``coords``.

    Args:
        x: second argument passed to the model (the image / decoded image in
            the callers in this file).
        coords: first argument passed to the model; the derivative is taken
            with respect to this tensor, so it presumably must be part of the
            autograd graph (requires_grad) — TODO confirm against ``gradient``.
        model: callable invoked as ``model(coords, x)``.

    Returns:
        Tuple ``(out_grad, out)``: the result of ``gradient(out, coords)`` and
        the raw model output ``out``.
    """
    prediction = model(coords, x)
    return gradient(prediction, coords, grad_outputs=None), prediction




def _coord_grid(image_size, b, device, requires_grad=False):
    """Return ``get_mgrid(image_size)`` as a ``(b, N, 2)`` coordinate tensor on ``device``.

    With ``requires_grad=True`` the grid is a fresh leaf tensor that tracks
    gradients, so derivatives of a model output can be taken with respect to it.
    """
    coords = get_mgrid(image_size).reshape(-1, 2)
    coords = torch.unsqueeze(coords, dim=0).expand(b, -1, -1).to(device)
    if requires_grad:
        coords = coords.clone().detach().requires_grad_(True)
    return coords


def _make_output_dir(exp_path, name):
    """Create ``<exp_path>/Generated_samples/<name>`` if missing and return its path."""
    path = os.path.join(exp_path, 'Generated_samples', name)
    os.makedirs(path, exist_ok=True)
    return path


def _tile(arr, ngrid, image_size, channels):
    """Arrange ``ngrid * ngrid`` images of size ``image_size`` x ``image_size`` x ``channels`` into one mosaic."""
    grid = arr.reshape(ngrid, ngrid, image_size, image_size, channels).swapaxes(1, 2)
    return grid.reshape(ngrid * image_size, -1, channels)


def _gradient_magnitude(y):
    """Return the L2 norm of ``y`` over dim 2 as a NumPy array (``y`` must be detached)."""
    return torch.sqrt(torch.sum(torch.pow(y, 2), dim=2)).cpu().numpy()


def _save_gray(values, ngrid, image_size, path):
    """Save single-channel ``values`` (reshaped to ``(-1, image_size, image_size, 1)``) as a grayscale mosaic."""
    values = np.reshape(values, [-1, image_size, image_size, 1])
    plt.imsave(path, _tile(values, ngrid, image_size, 1)[:, :, 0], cmap='gray')


def _save_reconstruction(img, model, b, c, ngrid, image_size, device, path):
    """Render ``model`` on the full grid for ``img`` via ``batch_sampling`` and save it as an 8-bit image.

    Each of the ``b`` reconstructions is min-max normalised to [0, 255]; a
    constant reconstruction maps to 0 instead of producing NaNs (0 / 0).
    """
    coords = _coord_grid(image_size, b, device)
    recon = np.reshape(batch_sampling(img, coords, c, model), [-1, image_size, image_size, c])
    recon = recon.reshape(b, -1)
    recon = recon - recon.min(axis=-1, keepdims=True)
    peak = recon.max(axis=-1, keepdims=True)
    recon = recon / np.where(peak > 0, peak, 1.0) * 255.0
    recon = recon.reshape(-1, image_size, image_size, c)
    mosaic = _tile(recon, ngrid, image_size, c).clip(0, 255).astype(np.uint8)
    imageio.imwrite(path, mosaic)


def _append_result(folder, i, loss, snr):
    """Append one ``iter | Loss | snr`` line to ``<folder>/results.txt``."""
    with open(os.path.join(folder, 'results.txt'), 'a') as file:
        file.write('iter: {:.0f}| Loss: {:.2f}| snr: {:.2f}'.format(i, loss, snr))
        file.write('\n')


def poisson_first_order(exp_path, model, test_loader, train_loader=None,
                        generative_model=None, aeder=None, flow=None):
    """Recover test image ``sample_number`` by fitting ``model`` outputs at random pixels.

    The ground-truth image is taken from the first batch of ``test_loader``
    (presumably stored as ``(8*image_size)**2`` pixels x ``c`` channels per
    sample — TODO confirm) and downsampled with antialiased bilinear
    interpolation. Every step samples ``batch_pixels`` random pixel
    coordinates and runs one Adam step. Three modes are selected by
    ``generative_model``:

    * ``None``: optimise the pixels of an image directly. The image is
      initialised to the mean of the first (up to) 10 training batches. The
      loss is the MSE between ``gradient(model(coords, x), coords)`` and the
      same quantity for the ground-truth image.
    * ``'stylegan'``: optimise StyleGAN2 latent codes loaded from a local
      pickle. The loss is the MSE between model outputs and ground-truth
      pixel values. Reconstructions go to
      ``<exp_path>/Generated_samples/PDE3``.
    * anything else: optimise a latent that is mapped through
      ``flow.sample_me`` and ``aeder.decoder``. The loss is the MSE between
      model gradients and reflect-padded image gradients at 2x resolution,
      plus ``lam`` times ``flow.forward_kld``. During steps (1500, 3500) and
      after 4500 the decoder is fine-tuned instead of the latent. Outputs go
      to ``<exp_path>/Generated_samples/PDE2_<sample_number>``.

    Args:
        exp_path: root directory for generated images and ``results.txt``.
        model: coordinate model called as ``model(coords, image)``; must have
            a ``linear1`` layer, whose weight device is used for everything.
        test_loader: iterable whose first batch supplies the ground truth.
        train_loader: iterable of ``(_, _, image)`` triples; required when
            ``generative_model`` is ``None``.
        generative_model: ``None``, ``'stylegan'``, or any other value to use
            the flow/autoencoder prior.
        aeder: autoencoder with a ``decoder``; required for the flow mode.
        flow: normalizing flow providing ``q0``, ``sample_me`` and
            ``forward_kld``; required for the flow mode.

    Raises:
        ValueError: if the loader or prior needed by the selected mode is
            missing or empty.
    """
    n_steps = 20000
    b = 1
    ngrid = 1
    image_size = 128
    learning_rate = 1e-1
    batch_pixels = 512
    lam = 0
    sample_number = 21

    device = model.linear1.weight.data.device
    mse_loss = F.mse_loss

    images_8k = next(iter(test_loader)).to(device)[sample_number:sample_number + 1]
    c = images_8k.shape[2]
    images_8k = images_8k.reshape(-1, 8 * image_size, 8 * image_size, c).permute(0, 3, 1, 2)
    images = F.interpolate(images_8k, size=image_size, antialias=True, mode='bilinear')

    if generative_model is None:
        if train_loader is None:
            raise ValueError('train_loader is required when generative_model is None')

        # Initial guess: mean of the first (up to) 10 downsampled training batches.
        x_init = 0.0
        cnt = 0
        for _, _, image in train_loader:
            image = image.reshape(-1, 8 * image_size, 8 * image_size, c).permute(0, 3, 1, 2)
            image = F.interpolate(image, size=image_size, antialias=True, mode='bilinear')
            x_init += torch.mean(image, dim=0)
            cnt += 1
            if cnt > 9:
                break
        if cnt == 0:
            raise ValueError('train_loader yielded no batches')
        # Divide by the number of batches actually seen, not a fixed 10.
        x_init /= cnt
        print(x_init.shape)
        x_init = torch.unsqueeze(x_init, dim=0).expand(b, -1, -1, -1)
        print(x_init.shape)

        t_coords = _coord_grid(image_size, b, device, requires_grad=True)
        print(images.shape, t_coords.shape)

        # Targets: spatial gradients of the model output for the true image.
        out_true = model(t_coords, images)
        y = gradient(out_true, t_coords).detach()
        out_true = out_true.detach().cpu().numpy()

        x_var = temp_var(x_init).to(device)
        optimizer = Adam(x_var.parameters(), lr=learning_rate)

        with tqdm(total=n_steps) as pbar:
            for _ in range(n_steps):
                coords = _coord_grid(image_size, b, device, requires_grad=True)
                pixels = np.random.randint(low=0, high=image_size**2, size=batch_pixels)
                batch_coords = coords[:, pixels]
                batch_y = y[:, pixels]
                batch_out_true = out_true[:, pixels]

                optimizer.zero_grad()
                y_hat, out_hat = forward_operator(x_var.x, batch_coords, model)
                loss = mse_loss(batch_y, y_hat)
                loss.backward()
                optimizer.step()

                snr = SNR(batch_out_true, out_hat.detach().cpu().numpy())
                pbar.set_description('Loss: {:.2f}| snr: {:.2f}'.format(loss, snr))
                pbar.update(1)

    elif generative_model == 'stylegan':
        network_pkl = 'stylegan3/stylegan2-celebahq-256x256.pkl'
        noise_mode = 'const'
        truncation_psi = 1

        print('Loading networks from "%s"...' % network_pkl)
        with dnnlib.util.open_url(network_pkl) as f:
            G = legacy.load_network_pkl(f)['G_ema'].to(device)

        label = torch.zeros([1, G.c_dim], device=device)
        # NOTE(review): 25 float64 latents are optimised while b == ngrid**2 == 1;
        # the reconstruction tiling below assumes one image — confirm intended.
        z_init = torch.from_numpy(np.random.randn(25, G.z_dim)).to(device)
        z_var = temp_var(z_init).to(device)
        optimizer = Adam(z_var.parameters(), lr=learning_rate)

        image_path_pde = _make_output_dir(exp_path, 'PDE3')

        images_2k = images
        out_true = images_2k.permute(0, 2, 3, 1).reshape(b, -1, c).detach()

        # Reference derivatives of the reflect-padded ground truth; only saved for inspection.
        t_coords = _coord_grid(image_size, b, device, requires_grad=True)
        t_coords = t_coords.reshape((b, image_size, image_size, 2))
        t_coords_grad = 2 * torch.flip(t_coords, dims=[3])
        img_tmp = grid_sample_customized(images_2k, t_coords_grad, pad='reflect')
        y = gradient(img_tmp, t_coords, grad_outputs=None)
        y = y.reshape(b, -1, 2).detach()
        # Tile the whole batch; indexing with sample_number overflowed a batch of size b.
        _save_gray(_gradient_magnitude(y), ngrid, image_size,
                   os.path.join(image_path_pde, 'derivatives_true.png'))

        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                coords = _coord_grid(image_size, b, device, requires_grad=True)
                pixels = np.random.randint(low=0, high=image_size**2, size=batch_pixels)
                batch_coords = coords[:, pixels]
                batch_out_true = out_true[:, pixels]

                optimizer.zero_grad()
                img = G(z_var.x, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
                # Downsample, then map G's output range to [0, 1] via x / 2 + 0.5.
                img = (F.interpolate(img, size=image_size, antialias=True, mode='bilinear') / 2) + 0.5

                _, out_hat = forward_operator(img, batch_coords, model)
                loss = mse_loss(batch_out_true, out_hat)
                loss.backward()
                optimizer.step()

                snr = SNR(batch_out_true.detach().cpu().numpy(), out_hat.detach().cpu().numpy())

                if i % 100 == 0:
                    _save_reconstruction(img, model, b, c, ngrid, image_size, device,
                                         os.path.join(image_path_pde, '_%d_recon.png' % (i,)))
                    _append_result(image_path_pde, i, loss, snr)

                pbar.set_description('Loss: {:.2f}| snr: {:.2f}'.format(loss, snr))
                pbar.update(1)

    else:
        if flow is None or aeder is None:
            raise ValueError('flow and aeder are required for the flow/autoencoder prior')

        image_path_pde = _make_output_dir(exp_path, 'PDE2_' + str(sample_number))

        # Start the latent at the base distribution's mean.
        z_init = flow.q0(b)[0]
        mean_flow = flow.q0.loc
        z_init = z_init * 0.0 + mean_flow
        image_size = 2 * image_size

        images_2k = F.interpolate(images_8k, size=image_size, antialias=True, mode='bilinear')
        out_true = images_2k.permute(0, 2, 3, 1).reshape(b, -1, c).detach()

        images_2k_np = images_2k.permute(0, 2, 3, 1).detach().cpu().numpy()
        gt_write = (_tile(images_2k_np, ngrid, image_size, c) * 255.0).clip(0, 255).astype(np.uint8)
        imageio.imwrite(os.path.join(image_path_pde, 'gt.png'), gt_write)

        # Targets: derivatives of the reflect-padded ground truth w.r.t. the coordinates.
        t_coords = _coord_grid(image_size, b, device, requires_grad=True)
        t_coords_grad = 2 * torch.flip(t_coords.reshape((b, image_size, image_size, 2)), dims=[3])
        img_tmp = grid_sample_customized(images_2k, t_coords_grad, pad='reflect')
        img_tmp = img_tmp.permute(0, 2, 3, 1).reshape(b, -1, c)
        y = gradient(img_tmp, t_coords, grad_outputs=None).detach()
        print(y.mean(), y.std())
        _save_gray(_gradient_magnitude(y), ngrid, image_size,
                   os.path.join(image_path_pde, 'derivatives_gt.png'))

        z_var = temp_var(z_init).to(device)
        optimizer = Adam(z_var.parameters(), lr=learning_rate)
        optimizer_cnn = Adam(aeder.parameters(), lr=1e-5)

        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                coords = _coord_grid(image_size, b, device, requires_grad=True)
                pixels = np.random.randint(low=0, high=image_size**2, size=batch_pixels)
                batch_coords = coords[:, pixels]
                batch_y = y[:, pixels]
                batch_out_true = out_true[:, pixels]

                # Alternate: fine-tune the decoder during (1500, 3500) and after 4500,
                # otherwise optimise the latent.
                train_decoder = 1500 < i < 3500 or i > 4500
                active = optimizer_cnn if train_decoder else optimizer
                active.zero_grad()

                z_tilde = flow.sample_me(z_var.x)
                img = aeder.decoder(z_tilde)
                y_hat, out_hat = forward_operator(img, batch_coords, model)
                reg = flow.forward_kld(z_tilde)
                loss = mse_loss(batch_y, y_hat) + lam * reg
                loss.backward()
                active.step()

                snr = SNR(batch_out_true.detach().cpu().numpy(), out_hat.detach().cpu().numpy())

                if i % 100 == 0:
                    _save_reconstruction(img, model, b, c, ngrid, image_size, device,
                                         os.path.join(image_path_pde, '%d_recon.png' % (i,)))

                    grad_coords = _coord_grid(image_size, b, device, requires_grad=True)
                    _save_gray(batch_grad(img, grad_coords, c, model), ngrid, image_size,
                               os.path.join(image_path_pde, '%d_recon_grad.png' % (i,)))

                    _append_result(image_path_pde, i, loss, snr)

                pbar.set_description('Loss: {:.2f}| snr: {:.2f}'.format(loss, snr))
                pbar.update(1)




