import numpy as np
import torch
import torch.nn.functional as F
import os
import imageio
import shutil
from my_utils import *
from datasets import *
import sys
from torch.utils.data.sampler import SubsetRandomSampler
import matplotlib.pyplot as plt
from skimage.transform import radon, iradon


# Number of samples shown in every mosaic and the side length of the square
# grid they are arranged in (25 samples -> 5 x 5 tiles).
_SAMPLE_NUMBER = 25
_NGRID = int(np.sqrt(_SAMPLE_NUMBER))

# One result-line template per supported subset.  The first five fields are
# always (subset, ep, total epochs, elapsed time, loss); the remaining fields
# are the SNR values returned by the matching ``_evaluate_*`` helper.
# NOTE: these strings are written to results.txt and are kept exactly as they
# were (including irregular spacing/separators) so existing logs stay comparable.
_RESULT_TEMPLATES = {
    'test': 'Set: {}, ep: {}/{} | time: {:.0f} | supercnn_loss {:.6f} | SNR_interpolate_1_2  {:.1f} | SNR_recon_1_2 {:.1f}  SNR_interpolate_1_4  {:.1f} | SNR_recon_1_4 {:.1f} | SNR_interpolate_1_8  {:.1f} | SNR_recon_1_8 {:.1f}',
    'ood': 'Set: {}, ep: {}/{} | time: {:.0f} | supercnn_loss {:.6f} | SNR_interpolate_64_128  {:.1f} | SNR_recon_64_128 {:.1f}',
    'generative': 'Set: {}, ep: {}/{} | time: {:.0f} | supercnn_loss {:.6f}| SNR_recon_64 {:.1f}',
    'CT': 'Set: {}, ep: {}/{} | time: {:.0f} | supercnn_loss {:.6f}| SNR_fbp {:.1f} | SNR_fbp_f2_inter {:.1f}| SNR_fbp_f2 {:.1f}',
}


def _tile_grid(batch, height, width, channels):
    """Arrange ``_NGRID * _NGRID`` channel-last samples into one mosaic scaled by 255.

    ``batch`` must contain exactly ``_NGRID * _NGRID`` samples of
    ``height * width * channels`` values each; otherwise the reshape raises.
    The result has shape ``(_NGRID * height, _NGRID * width, channels)`` and is
    NOT clipped or cast.
    """
    tiles = batch.reshape(_NGRID, _NGRID, height, width, channels)
    # Swap the grid-column axis with the pixel-row axis so that flattening
    # yields rows of side-by-side tiles.
    return tiles.swapaxes(1, 2).reshape(_NGRID * height, -1, channels) * 255.0


def _write_grid(batch, height, width, channels, path):
    """Tile ``batch`` (see ``_tile_grid``), clip to [0, 255] and save it as uint8 at ``path``."""
    mosaic = _tile_grid(batch, height, width, channels)
    imageio.imwrite(path, mosaic.clip(0, 255).astype(np.uint8))


def _batched_coords(grid, batch_size, device):
    """Flatten a coordinate grid to ``(1, N, 2)`` and expand it to ``batch_size`` on ``device``."""
    flat = grid.reshape(-1, 2)
    return torch.unsqueeze(flat, dim=0).expand(batch_size, -1, -1).to(device)


def _log_result(exp_path, line):
    """Print ``line`` and append it (plus a newline) to ``<exp_path>/results.txt``."""
    print(line)
    with open(os.path.join(exp_path, 'results.txt'), 'a') as file:
        file.write(line)
        file.write('\n')


def _evaluate_test(training_mode, subset, data_loader, model, device,
                   image_size, c, k_test, out_dir, ep):
    """Evaluate super-resolution at 2x/4x/8x on one batch of the test loader.

    Returns ``(snr_interpolate_1_2, snr_recon_1_2, snr_interpolate_1_4,
    snr_recon_1_4, snr_interpolate_1_8, snr_recon_1_8)``.
    """
    factor_mode = training_mode == 'factor'

    full_res = k_test * image_size
    images_8k = next(iter(data_loader)).to(device)
    images_8k = images_8k.reshape(-1, full_res, full_res, c).permute(0, 3, 1, 2)
    batch_size = images_8k.shape[0]

    if not factor_mode:
        # Channel-last references for the direct 4x / 8x comparison below;
        # they are never read in 'factor' mode, so skip the extra copies there.
        images_4k = F.interpolate(images_8k, size=4 * image_size, antialias=True, mode='bilinear')
        references = {
            4: images_4k.permute(0, 2, 3, 1).detach().cpu().numpy(),
            8: images_8k.permute(0, 2, 3, 1).detach().cpu().numpy(),
        }

    # Low-resolution input fed to the model and to the bilinear baseline.
    images = F.interpolate(images_8k, size=image_size, antialias=True, mode='bilinear')

    scales = [2, 4, 8] if factor_mode else [2]
    snr_recon = [0, 0, 0]
    snr_interpolate = [0, 0, 0]
    images_down = images
    for i, scale in enumerate(scales):
        res = scale * image_size

        # Ground truth at this resolution.
        target = F.interpolate(images_8k, size=res, antialias=True, mode='bilinear')
        target_np = target.permute(0, 2, 3, 1).detach().cpu().numpy()
        _write_grid(target_np[:_SAMPLE_NUMBER], res, res, c,
                    os.path.join(out_dir, subset + '_%d_iter_gt_%d.png' % (ep, scale)))

        # Model reconstruction, conditioned on the previous scale's output.
        coords = _batched_coords(get_mgrid(res), batch_size, device)
        recon_np = batch_sampling(images_down, coords, c, model)
        recon_np = np.reshape(recon_np, [-1, res, res, c])
        _write_grid(recon_np[:_SAMPLE_NUMBER], res, res, c,
                    os.path.join(out_dir, subset + '_%d_iter_recon_%d.png' % (ep, scale)))

        # Bilinear baseline, always upsampled from the original low-res input.
        interpolate = F.interpolate(images, size=res, mode='bilinear')
        interpolate_np = interpolate.detach().cpu().numpy().transpose(0, 2, 3, 1)
        _write_grid(interpolate_np[:_SAMPLE_NUMBER], res, res, c,
                    os.path.join(out_dir, subset + '_%d_iter_interpolate_%d.png' % (ep, scale)))

        snr_recon[i] = SNR_rescale(target_np, recon_np)
        snr_interpolate[i] = SNR_rescale(target_np, interpolate_np)

        # The reconstruction becomes the conditioning image of the next scale.
        images_down = torch.tensor(recon_np.transpose([0, 3, 1, 2]),
                                   dtype=images_down.dtype).to(device)

    # Gradient visualisation at 2x; the coordinates need requires_grad so that
    # batch_grad can differentiate with respect to them.
    grad_res = 2 * image_size
    coords_2k = _batched_coords(get_mgrid(grad_res), batch_size, device)
    coords_2k = coords_2k.clone().detach().requires_grad_(True)
    out_grad = batch_grad(images, coords_2k, c, model)
    out_grad = np.reshape(out_grad, [-1, grad_res, grad_res, 1])
    grad_mosaic = _tile_grid(out_grad[:_SAMPLE_NUMBER], grad_res, grad_res, 1)
    plt.imsave(os.path.join(out_dir, subset + '_%d_out_grad.png' % (ep,)),
               grad_mosaic[:, :, 0], cmap='gray')
    # NOTE: the image_derivative ("true gradient") and batch_laplace
    # visualisations are currently disabled and therefore not produced.

    if factor_mode:
        snr_recon_1_8, snr_interpolate_1_8 = snr_recon[2], snr_interpolate[2]
        snr_recon_1_4, snr_interpolate_1_4 = snr_recon[1], snr_interpolate[1]
    else:
        # Outside 'factor' mode the 8x and 4x outputs are sampled directly
        # from the low-resolution input instead of being chained.
        recons, interps = {}, {}
        for factor in (8, 4):
            res = factor * image_size

            coords = _batched_coords(get_mgrid(res), batch_size, device)
            recon = batch_sampling(images, coords, c, model)
            recon = np.reshape(recon, [-1, res, res, c])
            _write_grid(recon[:_SAMPLE_NUMBER], res, res, c,
                        os.path.join(out_dir, subset + '_%d_recon_1_%d.png' % (ep, factor)))

            interp = F.interpolate(images, size=res, mode='bilinear')
            interp = interp.detach().cpu().numpy().transpose(0, 2, 3, 1)
            _write_grid(interp[:_SAMPLE_NUMBER], res, res, c,
                        os.path.join(out_dir, subset + '_%d_interpolate_1_%d.png' % (ep, factor)))

            recons[factor], interps[factor] = recon, interp

        snr_recon_1_8 = SNR_rescale(references[8], recons[8])
        snr_interpolate_1_8 = SNR_rescale(references[8], interps[8])
        snr_recon_1_4 = SNR_rescale(references[4], recons[4])
        snr_interpolate_1_4 = SNR_rescale(references[4], interps[4])

    return (snr_interpolate[0], snr_recon[0],
            snr_interpolate_1_4, snr_recon_1_4,
            snr_interpolate_1_8, snr_recon_1_8)


def _evaluate_ood(subset, data_loader, model, device, image_size, c, out_dir, ep):
    """Evaluate 2x super-resolution on one batch of the out-of-distribution loader.

    Returns ``(snr_interpolate_64_128, snr_recon_64_128)``.
    """
    res = 2 * image_size
    images_128 = next(iter(data_loader)).to(device)
    images_128 = images_128.reshape(-1, res, res, c).permute(0, 3, 1, 2)
    images_64 = F.interpolate(images_128, size=image_size, antialias=True, mode='bilinear')

    images_128_np = images_128.permute(0, 2, 3, 1).detach().cpu().numpy()
    _write_grid(images_128_np[:_SAMPLE_NUMBER], res, res, c,
                os.path.join(out_dir, subset + '_%d_gt_128.png' % (ep,)))

    coords = _batched_coords(get_mgrid(res), images_128.shape[0], device)
    recon_np = batch_sampling(images_64, coords, c, model)
    recon_np = np.reshape(recon_np, [-1, res, res, c])
    _write_grid(recon_np[:_SAMPLE_NUMBER], res, res, c,
                os.path.join(out_dir, subset + '_%d_recon_64_128.png' % (ep,)))

    interpolate = F.interpolate(images_64, size=res, mode='bilinear')
    interpolate_np = interpolate.detach().cpu().numpy().transpose(0, 2, 3, 1)
    _write_grid(interpolate_np[:_SAMPLE_NUMBER], res, res, c,
                os.path.join(out_dir, subset + '_%d_interpolate_64_128.png' % (ep,)))

    snr_recon = SNR_rescale(images_128_np, recon_np)
    snr_interpolate = SNR_rescale(images_128_np, interpolate_np)
    return (snr_interpolate, snr_recon)


def _evaluate_generative(subset, data_loader, model, device, image_size, c,
                         out_dir, ep, aeder):
    """Pass low-res images through ``aeder`` and resample them with ``model`` at input resolution.

    Returns ``(snr_recon_64,)``.

    Raises:
        ValueError: if ``aeder`` is None.
    """
    if aeder is None:
        # Fail before any file is written instead of crashing half-way
        # through with an AttributeError on ``None.decoder``.
        raise ValueError("subset 'generative' requires `aeder` (an object with "
                         "`encoder` and `decoder`), got None")

    full_res = 8 * image_size
    images_8k = next(iter(data_loader)).to(device)[:_SAMPLE_NUMBER]
    images_8k = images_8k.reshape(-1, full_res, full_res, c).permute(0, 3, 1, 2)
    images = F.interpolate(images_8k, size=image_size, antialias=True, mode='bilinear')

    images_np = images.permute(0, 2, 3, 1).detach().cpu().numpy()
    _write_grid(images_np[:_SAMPLE_NUMBER], image_size, image_size, c,
                os.path.join(out_dir, subset + '_%d_gt.png' % (ep,)))

    coords = _batched_coords(get_mgrid(image_size), _SAMPLE_NUMBER, device)
    image_recon = aeder.decoder(aeder.encoder(images))
    recon_np = batch_sampling(image_recon, coords, c, model)
    recon_np = np.reshape(recon_np, [-1, image_size, image_size, c])
    _write_grid(recon_np[:_SAMPLE_NUMBER], image_size, image_size, c,
                os.path.join(out_dir, subset + '_%d_recon_64.png' % (ep,)))

    return (SNR_rescale(images_np, recon_np),)


def _evaluate_ct(subset, model, device, out_dir, ep):
    """Compare ``fbp_batch`` outputs of raw, model-upsampled and bilinearly upsampled sinograms.

    Loads its own data from hard-coded dataset paths (relative to the working
    directory) with fixed sizes and a single channel.
    ``fbp_batch`` is presumably filtered back-projection -- confirm in my_utils.

    Returns ``(snr_fbp, snr_fbp_f2_inter, snr_fbp_f2)``.
    """
    image_size_1, image_size_2 = 182, 60
    c = 1
    x_size = 128
    wide = 2 * image_size_2  # second sinogram axis after 2x upsampling

    # Dataset:
    x_test_add = 'Projects/datasets/CT_dataset/new/x_test_uniform_snr_40/samples/'
    sin_test_add = 'Projects/datasets/CT_dataset/new/sin_test_uniform_snr_40/samples/'
    x_dataset = CT_sinogram(x_test_add)
    sin_dataset = CT_sinogram(sin_test_add)

    sin_loader = torch.utils.data.DataLoader(sin_dataset, batch_size=25, num_workers=8)
    gt_loader = torch.utils.data.DataLoader(x_dataset, batch_size=25, num_workers=8)

    sins = next(iter(sin_loader)).to(device)[:_SAMPLE_NUMBER]
    sins = sins.reshape(-1, image_size_1, image_size_2, c).permute(0, 3, 1, 2)

    # The factor 100.0 is applied to the ground truth and to every sinogram
    # handed to fbp_batch below (scaling convention of the dataset -- TODO confirm).
    gt = next(iter(gt_loader)).to(device)[:_SAMPLE_NUMBER] * 100.0
    gt = gt.reshape(-1, x_size, x_size, c)
    gt = gt.detach().cpu().numpy()
    _write_grid(gt, x_size, x_size, c,
                os.path.join(out_dir, subset + '_%d_gt.png' % (ep,)))

    sins_np = sins.permute(0, 2, 3, 1).detach().cpu().numpy()
    print(sins_np.max(), sins_np.min())
    fbp = fbp_batch(sins_np[:, :, :, 0] * 100.0)
    _write_grid(fbp[:_SAMPLE_NUMBER], x_size, x_size, c,
                os.path.join(out_dir, subset + '_%d_fbp.png' % (ep,)))

    # Model-upsampled sinogram (second axis doubled).
    coords = _batched_coords(get_mgrid_unbalanced(image_size_1, wide), _SAMPLE_NUMBER, device)
    sins_f2 = batch_sampling(sins, coords, c, model)
    sins_f2 = np.reshape(sins_f2, [-1, image_size_1, wide, c])
    fbp_f2 = fbp_batch(sins_f2[:, :, :, 0] * 100.0)
    _write_grid(fbp_f2[:_SAMPLE_NUMBER], x_size, x_size, c,
                os.path.join(out_dir, subset + '_%d_fbp_f2.png' % (ep,)))

    # Bilinear baseline for the same upsampling.
    sins_f2_inter = F.interpolate(sins, size=(image_size_1, wide), mode='bilinear')
    sins_f2_inter_np = sins_f2_inter.detach().cpu().numpy().transpose(0, 2, 3, 1)
    fbp_f2_inter = fbp_batch(sins_f2_inter_np[:, :, :, 0] * 100.0)
    _write_grid(fbp_f2_inter[:_SAMPLE_NUMBER], x_size, x_size, c,
                os.path.join(out_dir, subset + '_%d_fbp_f2_inter.png' % (ep,)))

    snr_fbp_f2 = SNR(gt, fbp_f2)
    snr_fbp_f2_inter = SNR(gt, fbp_f2_inter)
    snr_fbp = SNR(gt, fbp)
    return (snr_fbp, snr_fbp_f2_inter, snr_fbp_f2)


def evaluator(training_mode, subset,data_loader, model, device, image_size, c,
    k_test, image_path_reconstructions, exp_path, ep,
    t1, t2, epochs_supercnn, loss_supercnn_epoch, aeder = None):
    """Run one qualitative/quantitative evaluation pass and log the result.

    Depending on ``subset`` a batch is reconstructed with ``model``, compared
    against a bilinear baseline, 5 x 5 image mosaics are written to
    ``image_path_reconstructions`` and one line with the SNR values is printed
    and appended to ``<exp_path>/results.txt``.

    Args:
        training_mode: ``'factor'`` chains 2x -> 4x -> 8x reconstructions in
            the ``'test'`` subset; any other value samples each scale directly
            from the low-resolution input.
        subset: ``'test'``, ``'ood'``, ``'generative'`` or ``'CT'``. Any other
            value makes the call a silent no-op.
        data_loader: iterable whose first batch is evaluated (ignored for
            ``'CT'``, which builds its own loaders). The first batch must hold
            at least 25 samples, otherwise building the 5 x 5 mosaic fails.
        model: network handed to ``batch_sampling`` / ``batch_grad``.
        device: torch device the batch and coordinates are moved to.
        image_size: side length of the low-resolution input.
        c: number of image channels (``'CT'`` always uses 1).
        k_test: resolution factor of the test data relative to ``image_size``
            (only used by ``'test'``).
        image_path_reconstructions: directory receiving the PNG files.
        exp_path: directory containing ``results.txt``.
        ep: current epoch, used in file names and the log line.
        t1, t2: start/end timestamps; ``t2 - t1`` is logged as elapsed time.
        epochs_supercnn: total number of epochs (logged only).
        loss_supercnn_epoch: loss value to log.
        aeder: object exposing ``encoder`` and ``decoder``; required for
            ``'generative'`` only.

    Raises:
        ValueError: if ``subset == 'generative'`` and ``aeder`` is None.
    """
    if subset == 'test':
        metrics = _evaluate_test(training_mode, subset, data_loader, model, device,
                                 image_size, c, k_test, image_path_reconstructions, ep)
    elif subset == 'ood':
        metrics = _evaluate_ood(subset, data_loader, model, device, image_size, c,
                                image_path_reconstructions, ep)
    elif subset == 'generative':
        metrics = _evaluate_generative(subset, data_loader, model, device, image_size, c,
                                       image_path_reconstructions, ep, aeder)
    elif subset == 'CT':
        metrics = _evaluate_ct(subset, model, device, image_path_reconstructions, ep)
    else:
        # Unknown subsets have always been ignored; keep that contract.
        return

    line = _RESULT_TEMPLATES[subset].format(
        subset, ep, epochs_supercnn, t2 - t1, loss_supercnn_epoch, *metrics)
    _log_result(exp_path, line)




