import argparse
import json
import skimage.io
import torch
# from IPython.core.pylabtools import figsize
from numpy.ma.core import indices
from tqdm import tqdm
import numpy as np
import pandas as pd
import cv2
import imageio
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.io import imsave
from skimage.transform import warp
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
from torchmetrics.image.psnr import PeakSignalNoiseRatio as torch_psnr
from torchmetrics.image.ssim import StructuralSimilarityIndexMeasure as torch_ssim
from utils.common_utils import *
import warnings
from torchinfo import summary
from skimage import segmentation
from networks.conv_layers import *
from networks.skip import skip
from networks.unet import UNet
from siren_pytorch import SirenNet
import wandb
import os
import logging
import diff_operators


class ParseAction(argparse.Action):
    """argparse action that stores the option value as a list of ints.

    Accepted spellings of the value (all produce ``[1, 2, 3]``):

    * a list that already contains only ints: ``[1, 2, 3]``
    * one space-separated token: ``['1 2 3']``
    * one bracketed token: ``['[1, 2, 3]']`` (``'[]'`` yields ``[]``)
    * one token per number: ``['1', '2', '3']``
    * comma-separated without brackets: ``['1,2,3']``
    * a bracketed list split over several tokens: ``['[1,', '2,', '3]']``
    * a single plain string (option declared without ``nargs``): ``'1 2 3'``

    Raises:
        ValueError: if the value cannot be interpreted as a list of ints.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        print('parse action called!')
        setattr(namespace, self.dest, self._to_int_list(values))

    @staticmethod
    def _to_int_list(values):
        """Convert the raw argparse value into a list of ints (see class docstring)."""
        error = ValueError(f"Could not parse {values} to list of int")

        # Without ``nargs`` argparse hands over a bare string instead of a list.
        tokens = [values] if isinstance(values, str) else values
        if not isinstance(tokens, list):
            raise error

        # Already converted (e.g. ``type=int`` together with ``nargs``).
        # ``type(...) is int`` deliberately excludes bool.
        if all(type(item) is int for item in tokens):
            return tokens
        if not all(isinstance(item, str) for item in tokens):
            raise error

        text = ' '.join(tokens).strip()
        if text.startswith('[') and text.endswith(']'):
            text = text[1:-1]
        elif '[' in text or ']' in text:
            # Unbalanced or misplaced brackets are not a list literal.
            raise error

        try:
            # int() tolerates surrounding whitespace and signs: int(' -1 ') -> -1
            return [int(piece) for piece in text.replace(',', ' ').split()]
        except ValueError as exc:
            raise error from exc


class Logger:
    """File-like tee: every write goes to the current ``sys.stdout`` and to a log file.

    The log file is opened in append mode with UTF-8 encoding. Characters that
    a stream cannot encode are replaced instead of aborting the write.
    """

    def __init__(self, log_file):
        # Captured at construction time, so this is whatever stdout was
        # before the instance is (typically) installed in its place.
        self.terminal = sys.stdout
        self.log = open(log_file, "a", encoding="utf-8")

    @staticmethod
    def _write_lossy(stream, message):
        """Write ``message`` to ``stream``, replacing characters it cannot encode."""
        try:
            stream.write(message)
        except UnicodeEncodeError:
            # Round-trip through the stream's OWN encoding. A UTF-8 round trip
            # leaves e.g. 'é' untouched, so the retry would fail the same way
            # on an ASCII or cp1252 terminal.
            encoding = getattr(stream, 'encoding', None) or 'utf-8'
            stream.write(message.encode(encoding, 'replace').decode(encoding))

    def write(self, message):
        """Write ``message`` to the terminal and then to the log file."""
        self._write_lossy(self.terminal, message)
        self._write_lossy(self.log, message)

    def flush(self):
        """Flush both streams; encoding errors raised while flushing are ignored."""
        try:
            self.terminal.flush()
        except UnicodeEncodeError:
            pass

        try:
            self.log.flush()
        except UnicodeEncodeError:
            pass

    def isatty(self):
        """Report whether the wrapped terminal stream is a TTY."""
        return self.terminal.isatty()

    def close(self):
        """Close the log file; the wrapped terminal stream is left open."""
        self.log.close()


# Module-wide compute device: CUDA when torch reports it as available, CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)


# Setup Fourier Feature Transform function
class GaussianFourierFeatureTransform_B(torch.nn.Module):
    """
    Fourier feature mapping that uses a caller-supplied projection matrix B.

    "Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains":
       https://arxiv.org/abs/2006.10739
       https://people.eecs.berkeley.edu/~bmild/fourfeat/index.html

    B is multiplied by `scale` once at construction and is expected to map
    `num_input_channels` values to `mapping_size` values.

    Given an input of size [batches, num_input_channels, width, height],
     returns a tensor of size [batches, mapping_size*2, width, height]
     (sine features first, cosine features second along the channel axis).
    """

    def __init__(self, num_input_channels, B, mapping_size=256, scale=10):
        super().__init__()

        # Kept as a plain attribute (not a registered buffer), so forward()
        # moves it to the input's device on every call.
        self._B = B * scale
        self._mapping_size = mapping_size
        self._num_input_channels = num_input_channels

    def forward(self, x):
        n_dims = x.dim()
        assert n_dims == 4, 'Expected 4D input (got {}D input)'.format(n_dims)

        n_batch, n_chan, dim_w, dim_h = x.shape
        expected_chan = self._num_input_channels

        assert n_chan == expected_chan, \
            "Expected input to have {} channels (got {} channels)".format(expected_chan, n_chan)

        projection = self._B.to(x.device)

        # Channels last, then one row per spatial location: [(B*W*H), C].
        flat = x.permute(0, 2, 3, 1).reshape(n_batch * dim_w * dim_h, n_chan)
        projected = flat @ projection

        # Back to image layout: [(B*W*H), M] -> [B, W, H, M] -> [B, M, W, H].
        projected = projected.view(n_batch, dim_w, dim_h, self._mapping_size)
        projected = projected.permute(0, 3, 1, 2)

        angles = 2 * np.pi * projected
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)


def train_batch(wandb_run_name,
                environment: str,
                exp_name: str,
                debug: bool,
                pname: str,
                cur_batch_str: str,
                result_folder: str,
                dataset_name: str,
                batch_size: int,
                start_f: int,
                siren: bool = False,
                scale_factor: float = 0.5,
                scale_factor_str: str = '05',
                dim_in_imgen: int = 256,
                w0_first_imgen: float = 15.0,
                num_layers_imgen: int = 3,
                width_layers_imgen=None,
                w0_first_grid: float = 15.0,
                num_layers_grid: int = 2,
                width_layers_grid=None,
                bandwidth_img: int = 8,
                num_iter_initialize: int = 1000,
                lr_init: float = 1e-4,
                num_iter_optim: int = 1000,
                lr_optim: float = 1e-4,
                arch_type: str = 'height_grad', ):
    """Fit the image generator and the height-field network to one batch of frames.

    Steps, as implemented below:

    1. Locate ``data/<pname>/GT/<pname>_GT<ext>`` and load up to ``batch_size``
       frames from ``data/<pname>/img`` (sorted numerically, starting at sorted
       position ``start_f``), each resized by ``scale_factor``.
    2. Initialisation stage (``num_iter_initialize`` Adam steps at ``lr_init``):
       L1 between the generator output on the undistorted grid and every
       frame, plus L1 between the predicted sampling grid and the undistorted
       grid.
    3. Optimisation stage (``num_iter_optim`` Adam steps at ``lr_optim``): sum
       of three L1 terms between the frames (I), the generator evaluated on the
       predicted grid (I_hat) and the sharp prediction resampled through
       ``backwarp_grid`` (J_hat), all masked to grid points strictly inside
       (-1, 1). Per-step metrics are sent to ``wandb.log``.
    4. Final evaluation; figures, images, ``results.csv`` and (with ``debug``)
       intermediate dumps are written under ``<result_folder>/<cur_batch_str>``.

    Tensors and models are placed on the module-level ``device``.

    Args:
        wandb_run_name, environment, exp_name, dataset_name, arch_type:
            Not used inside this function.
        debug: Also write intermediate figures, per-frame images and offsets.
        pname: Sequence name; selects the ``data/<pname>`` folders and is used
            in output file names.
        cur_batch_str: Name of the output sub-folder for this batch.
        result_folder: Parent folder of the output sub-folder.
        batch_size: Number of frames taken from the sorted frame list.
        start_f: Sorted position of the first frame; model summaries are only
            printed when it is 0.
        siren: Build SIREN networks. With ``False`` the grid network is not
            implemented and ``NotImplementedError`` is raised.
        scale_factor: Resize factor passed to ``im_resize``.
        scale_factor_str: ``scale_factor`` spelled for use in a file name.
        dim_in_imgen: Generator input channels; half of it is the Fourier
            mapping size.
        w0_first_imgen, num_layers_imgen, width_layers_imgen: Forwarded to the
            generator ``SirenNet`` (widths default to ``[256, 256, 256]``).
        w0_first_grid, num_layers_grid, width_layers_grid: Forwarded to the
            grid ``SirenNet`` (widths default to ``[128, 128]``).
        bandwidth_img: Scale applied to the random Fourier projection matrix.
        num_iter_initialize, lr_init: Length and learning rate of stage 2.
        num_iter_optim, lr_optim: Length and learning rate of stage 3.

    Returns:
        Tuple ``(psnr, ssim, lpips, ssd, ssdg, nr, nc)`` for the final sharp
        prediction against the ground truth; ``lpips`` is a NumPy value and
        ``nr``/``nc`` are the height/width of the resized frames.

    Raises:
        AssertionError: No ground-truth file with a known extension exists.
        ValueError: No frame could be loaded from the image folder.
        NotImplementedError: ``siren`` is False.
    """
    save_model = False  # checkpoint saving is hard-wired off

    if width_layers_grid is None:
        width_layers_grid = [128, 128]
    if width_layers_imgen is None:
        width_layers_imgen = [256, 256, 256]
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    dtype = torch.float32

    warnings.filterwarnings("ignore")

    imsize = -1

    extensions = ['.jpg', '.JPG', '.png', '.ppm', '.bmp', '.pgm', '.tif']

    fname = os.path.join('data', pname, 'img')
    fgt_exists = False
    for ext in extensions:
        fgt = os.path.join('data', pname, 'GT', f'{pname}_GT{ext}')
        if os.path.exists(fgt):
            fgt_exists = True
            break
    assert fgt_exists, f"GT file not found for {pname}"

    fdepth = os.path.join('data', pname, 'depth')

    # Optional ground-truth depth. Default to None first so that an existing
    # but empty depth folder does not leave the name unbound.
    gt_npy = None
    if os.path.exists(fdepth) and os.listdir(fdepth):
        gt_npy = load_gt_npy(fdepth, batch_size)

    fresult = os.path.join(result_folder, cur_batch_str)
    os.makedirs(fresult, exist_ok=True)
    os.makedirs(os.path.join(fresult, 'final_results'), exist_ok=True)
    if debug:
        os.makedirs(os.path.join(fresult, 'imgen_init_turb_avg_out_predxy'), exist_ok=True)
        os.makedirs(os.path.join(fresult, 'train_all'), exist_ok=True)
        for batch_idx in range(batch_size):
            os.makedirs(os.path.join(fresult, 'train_all', str(batch_idx)), exist_ok=True)

    # # Load Turbulence Images

    # Load the reference GT image; keep at most three channels and replicate
    # a single channel to three.
    img_gt_rgb, img_gt_np = get_image(fgt, imsize)
    img_gt_np = img_gt_np[:3, :, :]
    dim_gt = img_gt_np.shape[0]
    if dim_gt == 1:
        img_gt_np = np.concatenate((img_gt_np, img_gt_np, img_gt_np), 0)

    images = []
    # # Load turbulence image batch

    # Frames are ordered by the integer embedded in the file name.
    if 'sim128' in pname:
        sort_key = lambda x: int(x.split('_')[-1].rsplit('.', 1)[0])
    else:
        sort_key = lambda x: int(x[-7:].rsplit('.', 1)[0])

    for target in sorted(os.listdir(fname), key=sort_key)[start_f:start_f+batch_size]:
        d = os.path.join(fname, target)
        if has_file_allowed_extension(d, extensions):
            print(d)
            rgb, imgs = get_image(d, imsize)

            imgs = pil_to_np(im_resize(rgb, scale_factor))
            dim = imgs.shape[0]
            if dim == 1:
                imgs = np.concatenate((imgs, imgs, imgs), 0)
            images.append(imgs)

    if not images:
        raise ValueError(
            f"No input frames found in {fname} for start_f={start_f}, batch_size={batch_size}")

    img_gt_rgb_resize = im_resize(img_gt_rgb, scale_factor)
    img_gt_rgb_resize.save(os.path.join(fresult, f'{pname}_{cur_batch_str}_gt_scalef_{scale_factor_str}.png'))

    images_warp_np = np.array(images)
    print(images_warp_np.shape)
    images_mean_np = np.mean(images_warp_np, axis=0)
    dim, nr, nc = images_mean_np.shape

    # Bring the GT to the resolution of the resized frames.
    if dim > 1:
        img_gt_np = cv2.resize(
            img_gt_np.transpose(1, 2, 0), dsize=(nc, nr), interpolation=cv2.INTER_AREA).transpose(2, 0, 1)
    else:
        img_gt_np = cv2.resize(
            img_gt_np.transpose(1, 2, 0), dsize=(nc, nr), interpolation=cv2.INTER_AREA)

    out_imshow = np.concatenate(
        [images_warp_np[0].transpose(1, 2, 0), images_mean_np.transpose(1, 2, 0), img_gt_np.transpose(1, 2, 0)], axis=1)
    plt.figure(figsize=(12, 3))
    plt.imshow(out_imshow)
    plt.title('Turbulence Image - Average Image - GT Image')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(fresult, f'{pname}_{cur_batch_str}_turb_avg_gt.png'))
    plt.close()

    # Generate the straight (undistorted) coordinate grids in [-1, 1]:
    # a 2-D (x, y) grid and a 3-D (x, y, t) grid with one t sample per frame.
    coords_x = np.linspace(-1, 1, nc)
    coords_y = np.linspace(-1, 1, nr)
    coords_t = np.linspace(-1, 1, batch_size)
    xy_grid = np.stack(np.meshgrid(coords_x, coords_y), -1)
    xyt_grid = np.stack(np.meshgrid(coords_x, coords_y, coords_t), -1)

    xy_grid_var = np_to_torch(xy_grid.transpose(2, 0, 1)).type(dtype).to(device)
    xyt_grid_var = np_to_torch(xyt_grid.transpose(3, 0, 1, 2)).type(dtype).to(device)

    xy_grid_batch_var = xy_grid_var.repeat(batch_size, 1, 1, 1)
    print(f'xy_grid_batch_var: {xy_grid_batch_var.shape}')
    print(f'xyt_grid_var: {xyt_grid_var.shape}')

    # Setup Image Generator
    if siren:
        model_imgen = SirenNet(
            dim_in=dim_in_imgen,  # input dimension, ex. 2d coor
            dim_hidden=width_layers_imgen,  # hidden dimension
            dim_out=3,  # output dimension, ex. rgb value
            num_layers=num_layers_imgen,  # number of layers
            image_width=nc,
            image_height=nr,
            w0_initial=w0_first_imgen,
            # different signals may require different omg_0 in the first layer - it's a hyperparameter
            siren_batchnorm=False,  # whether to use batchnorm1d in first Siren layer
            final_activation=nn.Sigmoid(),  # activation of final layer (nn.Identity() for direct output)
        ).to(device)
    else:
        model_imgen = conv_layers(dim_in_imgen, 3)

    summary_input_tensor_shape = (batch_size, dim_in_imgen, nr, nc)
    model_imgen = model_imgen.type(dtype).to(device)

    if start_f == 0:
        summary(model_imgen, input_size=summary_input_tensor_shape,
                col_names=["input_size", "output_size", "num_params", "trainable"])
    model_imgen.train()
    torch.manual_seed(0)

    compare_lpips_torch = LPIPS(normalize=True).to(device)
    compare_psnr_torch = torch_psnr(data_range=1.0).to(device)
    compare_ssim_torch = torch_ssim(data_range=1.0).to(device)

    # Random projection matrix for the Fourier features (drawn after the seed above).
    B_var = torch.randn(2, dim_in_imgen // 2)
    print(B_var.shape)

    # # Setup Height Prediction Network

    if siren:
        model_grid = SirenNet(
            dim_in=3,  # input dimension, ex. 2d coor
            dim_hidden=width_layers_grid,  # hidden dimension
            dim_out=1,  # output dimension, ex. rgb value
            num_layers=num_layers_grid,  # number of layers
            image_width=nc,
            image_height=nr,
            w0_initial=w0_first_grid,
            # different signals may require different omg_0 in the first layer - it's a hyperparameter
            siren_batchnorm=False,  # whether to use batchnorm1d in first Siren layer
            final_activation=None,  # activation of final layer (nn.Identity() for direct output)
        ).to(device)
    else:
        raise NotImplementedError("UniGrid deformer network not implemented for ReLU activation")

    if start_f == 0:
        summary(model_grid, input_size=(1, 3, nr, nc, batch_size),
                col_names=["input_size", "output_size", "num_params", "trainable"])
    # Training mode is set for every batch, not only when the summary is printed.
    model_grid.train()

    FB_img = bandwidth_img

    vec_scale = 1.1

    # The mapping only holds B_var * FB_img, so one instance serves every call.
    fourier_features = GaussianFourierFeatureTransform_B(2, B_var, dim_in_imgen // 2, FB_img)

    img_gt_batch_var = torch.from_numpy(images_warp_np).type(dtype).to(device)  # I_1 to I_k
    straight_grid_input = fourier_features(xy_grid_batch_var)

    # The (x, y, t) grid needs gradients: predict_model_grid differentiates
    # the height output with respect to it.
    grid_t_input_single_gd = xyt_grid_var.detach().clone().requires_grad_(True)

    grid_input = straight_grid_input.detach().clone()  # [batch_size, C=256, nr, nc] - gamma(G_U) x batch_size

    model_params_list = [{'params': model_grid.parameters()}]

    model_params_list.append({'params': model_imgen.parameters()})

    optimizer_init = torch.optim.Adam(model_params_list, lr=lr_init)

    num_iter_i = num_iter_initialize
    np_to_pil(images_warp_np[0]).save(os.path.join(fresult, f'{pname}_{cur_batch_str}_turb_img_frame_{0}.png'))
    np_to_pil(images_mean_np).save(os.path.join(fresult, f'{pname}_{cur_batch_str}_avg_img_BS_{batch_size}.png'))

    # ---------------------------------------------------------------- stage 2
    imgen_init_save_path = os.path.join(fresult, 'imgen_init_turb_avg_out_predxy')
    for epoch in tqdm(range(num_iter_i), file=sys.stdout):
        optimizer_init.zero_grad()

        refined_xy, h_t = predict_model_grid(grid_t_input_single_gd, model_grid, vec_scale, xy_grid_batch_var,
                                             imgen_init_save_path, epoch, debug, num_iter_initialize)

        generated = model_imgen(grid_input)  # output: batch_size X J, grid_input is batch_size X gamma(G_u)

        loss_init = torch.nn.functional.l1_loss(img_gt_batch_var, generated)  # (I_i - J)
        loss_init += torch.nn.functional.l1_loss(xy_grid_batch_var, refined_xy)  # (G_u - G_i)

        loss_init.backward()
        optimizer_init.step()

        if epoch % 100 == 0 or epoch == num_iter_i - 1:
            tqdm.write('Epoch %d, loss_init = %.03f' % (epoch, float(loss_init)))
            out_img = generated[0].detach().cpu().numpy().transpose(1, 2, 0)
            pred_xy = refined_xy[0].detach().cpu().numpy().transpose(1, 2, 0)
            out_imshow = np.concatenate(
                [images_warp_np[0].transpose(1, 2, 0), images_mean_np.transpose(1, 2, 0), out_img,
                 visualize_rgb(pred_xy)],
                axis=1)
            if debug:
                if epoch % 500 == 0:
                    visualize_water_height(h_t, epoch, save_path=imgen_init_save_path)
                plt.figure()
                plt.imshow(out_imshow)
                plt.title('Turbulence Image\tAverage Image\tPred Img\tPred xy'
                          .replace("\t", "       "))
                plt.axis('off')
                plt.tight_layout()
                plt.savefig(os.path.join(imgen_init_save_path, f'{epoch}.png'))
                plt.close()

    if debug and num_iter_initialize > 0:
        refined_xy, h_t = predict_model_grid(grid_t_input_single_gd, model_grid, vec_scale, xy_grid_batch_var,
                                             imgen_init_save_path, num_iter_initialize, debug, num_iter_initialize)
        np_to_pil(out_img).save(os.path.join(fresult, f'{pname}_{cur_batch_str}_imgen_init_final.png'))
        visualize_water_height(h_t, num_iter_i - 1, save_path=imgen_init_save_path)

    if save_model:
        torch.save(model_imgen, os.path.join(fresult, f'{pname}_ig_scale_{scale_factor}_FB_{FB_img}_{start_f}.pth'))
        torch.save(model_grid, os.path.join(fresult, f'{pname}_gd_scale_{scale_factor}_FB_{FB_img}_{start_f}.pth'))

    img_gt_np = img_gt_np.clip(0, 1)
    # GT as a 1-image batch on the device, reused by every LPIPS evaluation.
    img_gt_var = torch.from_numpy(img_gt_np).unsqueeze(0).type(dtype).to(device)

    num_iter = num_iter_optim

    # Frame whose sharp prediction is scored against the GT. It has its own
    # name so the per-frame debug loop below cannot overwrite it.
    ref_idx = 0

    loss_arr = np.zeros(num_iter)
    psnr_arr_sharp = np.zeros(num_iter)
    ssim_arr_sharp = np.zeros(num_iter)
    lpips_arr_sharp = np.zeros(num_iter)
    ssd_arr_sharp = np.zeros(num_iter)
    ssdg_arr_sharp = np.zeros(num_iter)
    psnr_arr_turb = np.zeros([num_iter, batch_size])
    ssim_arr_turb = np.zeros([num_iter, batch_size])
    lpips_arr_turb = np.zeros([num_iter, batch_size])

    torch_I_hat_k_minus_I_k = torch.zeros([num_iter, batch_size])

    del optimizer_init
    optimizer_optim = torch.optim.Adam(model_params_list, lr=lr_optim)

    # ---------------------------------------------------------------- stage 3
    for epoch in tqdm(range(num_iter), file=sys.stdout):

        optimizer_optim.zero_grad()

        refined_xy, h_t = predict_model_grid(grid_t_input_single_gd, model_grid, vec_scale, xy_grid_batch_var,
                                             os.path.join(fresult, 'train_all'), epoch, debug, num_iter_optim)
        refined_warp = refined_xy - xy_grid_batch_var  # only the displacement -  for visualization
        refined_uv = torch.cat(
            ((nc - 1.0) * refined_warp[:, 0:1, :, :] / 2, (nr - 1.0) * refined_warp[:, 1:2, :, :] / 2),
            1)  # for visualization

        # Mask for the warp field: 1 where both coordinates lie strictly inside (-1, 1).
        mask_u1 = (refined_xy[:, 0:1, :, :] > -1).float() * 1
        mask_u2 = (refined_xy[:, 0:1, :, :] < 1).float() * 1
        mask_v1 = (refined_xy[:, 1:2, :, :] > -1).float() * 1
        mask_v2 = (refined_xy[:, 1:2, :, :] < 1).float() * 1
        mask = mask_u1 * mask_u2 * mask_v1 * mask_v2

        # predict sharp image using straight grid
        sharp_imgs_predict = model_imgen(grid_input)  # ImageGenerator(gamma(G_U) x batch_size) = J x batch_size

        # predict turbulent image by resampling the sharp prediction on the predicted grid
        refined_turb_imgs = backwarp_grid(sharp_imgs_predict, refined_xy)  # J_hat_1 to J_hat_k

        # predict turbulent images by evaluating the generator on the predicted grid
        generated_turb_imgs = model_imgen(
            fourier_features(refined_xy))  # I_hat_1 to I_hat_k = ImageGen(gamma(G_1 to G_k))

        # loss function: each term is computed once and reused for logging
        term_I_hat_I = torch.nn.functional.l1_loss(generated_turb_imgs * mask,
                                                   img_gt_batch_var * mask)  # (I_hat_i - I_i) * mask
        term_J_hat_I = torch.nn.functional.l1_loss(refined_turb_imgs * mask,
                                                   img_gt_batch_var * mask)  # (J_hat_i - I_i) * mask
        term_I_hat_J_hat = torch.nn.functional.l1_loss(generated_turb_imgs * mask,
                                                       refined_turb_imgs * mask)  # (I_hat_i - J_hat_i) * mask
        loss_optim = term_I_hat_I + term_J_hat_I + term_I_hat_J_hat

        loss_I_and_I_hat = term_I_hat_I.detach().clone()
        loss_J_hat_and_I = term_J_hat_I.detach().clone()
        loss_I_hat_J_hat = term_I_hat_J_hat.detach().clone()

        sharp_ref = sharp_imgs_predict[ref_idx].detach()
        sharp_ref_np = sharp_ref.cpu().numpy()

        loss_arr[epoch] = loss_optim.detach().clone().cpu().numpy()
        psnr_arr_sharp[epoch] = compare_psnr(
            img_gt_np,
            sharp_ref_np)

        ssim_arr_sharp[epoch] = ssim(
            img_gt_np.transpose(1, 2, 0),
            sharp_ref_np.transpose(1, 2, 0),
            channel_axis=-1, data_range=1.0)

        lpips_arr_sharp[epoch] = float(compare_lpips_torch.forward(
            img_gt_var,
            sharp_ref.unsqueeze(0)))

        ssd_arr_sharp[epoch] = calculate_ssd(img_gt_np, sharp_ref_np)
        ssdg_arr_sharp[epoch] = calculate_ssdg(img_gt_np, sharp_ref_np)

        # log metrics to wandb
        wandb.log({"loss_I_and_I_hat": loss_I_and_I_hat,
                   "loss_J_hat_and_I": loss_J_hat_and_I,
                   "loss_I_hat_J_hat": loss_I_hat_J_hat,
                   "loss_optim": loss_optim,
                   "psnr_sharp": float(psnr_arr_sharp[epoch]),
                   "ssim_sharp": float(ssim_arr_sharp[epoch]),
                   "lpips_sharp": float(lpips_arr_sharp[epoch]),
                   "ssd_sharp": float(ssd_arr_sharp[epoch]),
                   "ssdg_sharp": float(ssdg_arr_sharp[epoch]),
                   })

        # Per-frame metrics between each input frame I_k and its prediction I_hat_k.
        for batch_idx in range(batch_size):
            gen_frame = generated_turb_imgs[batch_idx].detach()
            gen_frame_np = gen_frame.cpu().numpy()

            torch_I_hat_k_minus_I_k[epoch][batch_idx] = torch.nn.functional.l1_loss(
                generated_turb_imgs[batch_idx] * mask[batch_idx],
                img_gt_batch_var[batch_idx] * mask[batch_idx]).detach().clone()

            psnr_arr_turb[epoch][batch_idx] = compare_psnr(
                images_warp_np[batch_idx],
                gen_frame_np)

            ssim_arr_turb[epoch][batch_idx] = ssim(
                images_warp_np[batch_idx].transpose(1, 2, 0),
                gen_frame_np.transpose(1, 2, 0),
                channel_axis=-1, data_range=1.0)

            lpips_arr_turb[epoch][batch_idx] = float(compare_lpips_torch.forward(
                img_gt_batch_var[batch_idx].unsqueeze(0),
                gen_frame.unsqueeze(0)))

        loss_optim.backward()
        optimizer_optim.step()

        if epoch % 100 == 0:
            tqdm.write(f'Epoch {epoch}, loss_optim = {float(loss_optim):.03f},'
                       f' psnr_sharp = {float(psnr_arr_sharp[epoch]):.03f},'
                       f' psnr_turb = {float(psnr_arr_turb[epoch].mean()):.03f}'
                       f' ssim_sharp = {float(ssim_arr_sharp[epoch]):.03f},'
                       f' ssim_turb = {float(ssim_arr_turb[epoch].mean()):.03f}'
                       f' lpips_sharp = {float(lpips_arr_sharp[epoch]):.03f},'
                       f' lpips_turb = {float(lpips_arr_turb[epoch].mean()):.03f}'
                       f' ssd_sharp = {float(ssd_arr_sharp[epoch]):.03f},'
                       f' ssdg_sharp = {float(ssdg_arr_sharp[epoch]):.03f}')
            if debug:
                for frame_idx in range(batch_size):
                    path_cur_folder = os.path.join(fresult, 'train_all', str(frame_idx))
                    out_img = refined_turb_imgs[frame_idx] * mask[frame_idx]
                    out_img = out_img.detach().cpu().numpy().transpose(1, 2, 0)
                    sharp_img = sharp_imgs_predict[frame_idx].detach().cpu().numpy().transpose(1, 2, 0)
                    warp_img = refined_uv[frame_idx].detach().cpu().numpy().transpose(1, 2, 0)
                    out_target = images_warp_np[frame_idx].transpose(1, 2, 0)
                    gen_img = generated_turb_imgs[frame_idx].detach().cpu().numpy().transpose(1, 2, 0)

                    out_imshow = np.concatenate([out_target,
                                                 gen_img,
                                                 out_img,
                                                 visualize_rgb(warp_img),
                                                 sharp_img], axis=1)
                    plt.figure(figsize=(20, 5))
                    plt.imshow(np.clip(out_imshow, 0, 1))
                    plt.title(
                        f'I - input\tI_hat (pred dist) \tJ_hat (resm dist)\t deform warp \tJ rcvrd - output\n'
                        f'psnr turb: {psnr_arr_turb[epoch][frame_idx]:.03f} ssim turb: {ssim_arr_turb[epoch][frame_idx]:.03f} lpips turb: {lpips_arr_turb[epoch][frame_idx]:.03f}\n'
                        f'psnr sharp: {psnr_arr_sharp[epoch]:.03f} ssim sharp: {ssim_arr_sharp[epoch]:.03f} lpips sharp: {lpips_arr_sharp[epoch]:.03f}'.replace(
                            "\t", "       "))
                    plt.axis('off')
                    plt.tight_layout()
                    plt.savefig(os.path.join(path_cur_folder, f'img_{frame_idx}_epoch_{epoch}.png'))
                    plt.close()

                    # save J, I_hat, J_hat and the warp visualisation
                    # J
                    np_to_pil(sharp_img).save(os.path.join(path_cur_folder, f'J_img_{frame_idx}_epoch_{epoch}.png'))
                    # I_hat
                    np_to_pil(gen_img).save(
                        os.path.join(path_cur_folder, f'I_hat_img_{frame_idx}_epoch_{epoch}.png'))
                    # J_hat
                    np_to_pil(out_img).save(os.path.join(path_cur_folder, f'J_hat_img_{frame_idx}_epoch_{epoch}.png'))
                    # warp
                    np_to_pil(visualize_rgb(warp_img)).save(
                        os.path.join(path_cur_folder, f'warp_img_{frame_idx}_epoch_{epoch}.png'))
                    if epoch == 0:
                        # I
                        np_to_pil(out_target).save(os.path.join(path_cur_folder, f'I_gt_img_{frame_idx}.png'))

        if debug and (epoch % 500 == 0):
            visualize_water_height(h_t, epoch, gt_npy, save_path=os.path.join(fresult, 'train_all'))

    # NOTE: the "custon_*" chart keys below are misspelled but are runtime
    # identifiers; they are left unchanged so existing dashboards still match.
    print('creating wandb custom chart')
    wandb.log({"custon_main_loss_per_k": wandb.plot.line_series(
        xs=list(range(num_iter_optim)),
        ys=[torch_I_hat_k_minus_I_k[:, i] for i in range(batch_size)],
        keys=[f"|I{i} - I^{i}|" for i in range(batch_size)],
        title=r"Absolute difference between Ik and I^k",
        xname='Epoch',
    )})
    wandb.log({"custom_psnr": wandb.plot.line_series(
        xs=list(range(num_iter_optim)),
        ys=[psnr_arr_turb[:, i] for i in range(batch_size)],
        keys=[f"PSNR between I{i} and I^{i}" for i in range(batch_size)],
        title=r"PSNR between Ik and I^k",
        xname='Epoch',
    )})
    wandb.log({"custon_SSIM": wandb.plot.line_series(
        xs=list(range(num_iter_optim)),
        ys=[ssim_arr_turb[:, i] for i in range(batch_size)],
        keys=[f"SSIM between I{i} and I^{i}" for i in range(batch_size)],
        title=r"SSIM between Ik and I^k",
        xname='Epoch',
    )})
    wandb.log({"custom_LPIPS": wandb.plot.line_series(
        xs=list(range(num_iter_optim)),
        ys=[lpips_arr_turb[:, i] for i in range(batch_size)],
        keys=[f"LPIPS between I{i} and I^{i}" for i in range(batch_size)],
        title=r"LPIPS between Ik and I^k",
        xname='Epoch',
    )})

    print(float(loss_optim))
    plt.plot(loss_arr[:])
    plt.ylabel('loss optim')
    plt.title('loss optim')
    plt.axis('on')
    plt.tight_layout()
    plt.savefig(os.path.join(fresult, f'{pname}_{batch_size}_loss.png'))
    plt.close()

    if save_model:
        torch.save(model_grid, os.path.join(fresult, f'{pname}_gd_final_{FB_img}_{batch_size}.pth'))
        torch.save(model_imgen, os.path.join(fresult, f'{pname}_ig_final_{FB_img}_{batch_size}.pth'))

    # ------------------------------------------------------- final evaluation
    # debug=True with epoch == num_iter_optim makes predict_model_grid dump the final offsets.
    refined_xy, h_t = predict_model_grid(grid_t_input_single_gd, model_grid, vec_scale, xy_grid_batch_var,
                                         os.path.join(fresult, 'final_results'), num_iter_optim, True, num_iter_optim)
    refined_warp = refined_xy - xy_grid_batch_var
    refined_uv = torch.cat([(nc - 1.0) * refined_warp[:, 0:1, :, :] / 2, (nr - 1.0) * refined_warp[:, 1:2, :, :] / 2],
                           1)

    visualize_water_height(h_t, num_iter - 1, gt_npy,
                           save_path=os.path.join(fresult, 'final_results'))
    generated_turb_imgs = model_imgen(fourier_features(refined_xy))

    sharp_img_predict_torch = model_imgen(grid_input[0, ...].unsqueeze(0))
    sharp_img_predict = sharp_img_predict_torch.squeeze(0).detach().cpu().numpy().transpose(1, 2, 0)  # J x 1 x nr x nc

    psnr_final_sharp = compare_psnr(img_gt_np.transpose(1, 2, 0), sharp_img_predict)
    ssim_final_sharp = ssim(img_gt_np.transpose(1, 2, 0), sharp_img_predict, channel_axis=-1, data_range=1.0)
    lpips_final_sharp = compare_lpips_torch.forward(img_gt_var,
                                                    sharp_img_predict_torch)
    ssd_final_sharp = calculate_ssd(img_gt_np.transpose(1, 2, 0), sharp_img_predict)
    ssdg_final_sharp = calculate_ssdg(img_gt_np.transpose(1, 2, 0), sharp_img_predict)

    psnr_final_turb = compare_psnr_torch(
        img_gt_batch_var,
        generated_turb_imgs)
    ssim_final_turb = compare_ssim_torch(
        img_gt_batch_var,
        generated_turb_imgs)
    lpips_final_turb = compare_lpips_torch.forward(
        img_gt_batch_var,
        generated_turb_imgs)

    np_to_pil(sharp_img_predict).save(
        os.path.join(fresult, 'final_results', f'sharp_img_{start_f}_{start_f + batch_size - 1}.png'))

    print(
        f'PSNR sharp: {psnr_final_sharp}, SSIM sharp: {ssim_final_sharp}, LPIPS sharp: {lpips_final_sharp}, SSD sharp: {ssd_final_sharp}, SSDG sharp: {ssdg_final_sharp}')
    print(f'PSNR turb: {psnr_final_turb}, SSIM turb: {ssim_final_turb}, LPIPS turb: {lpips_final_turb}')

    for j in range(batch_size):
        out_img = generated_turb_imgs[j].detach().cpu().numpy().transpose(1, 2, 0)

        # Normalise the displacement field to [0, 1] for display. A constant
        # field has max 0 after the shift, so skip the division in that case.
        warp_img = refined_uv[j].detach().cpu().numpy().transpose(1, 2, 0)
        warp_img -= np.min(warp_img)
        warp_max = np.max(warp_img)
        if warp_max > 0:
            warp_img /= warp_max

        out_target = img_gt_batch_var[j].detach().cpu().numpy().transpose(1, 2, 0)
        out_imshow = np.concatenate([out_target, out_img, visualize_rgb_norm(warp_img), sharp_img_predict], axis=1)
        plt.figure(figsize=(15, 5))
        plt.suptitle(f'image - {j + start_f}')
        plt.title(
            f'Out Target\tPred grid Backwarp image\tPred grid warp\tSharp Image \n'
            f'psnr turb: {psnr_final_turb:.03f} ssim turb: {ssim_final_turb:.03f} lpips turb: {lpips_final_turb:.03f} \n'
            f'psnr sharp: {psnr_final_sharp:.03f} ssim sharp: {ssim_final_sharp:.03f} lpips sharp: {lpips_final_sharp:.03f}'
            .replace("\t", "       "))
        plt.imshow(out_imshow)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(fresult, 'final_results', f'final_for_img_{j}.png'))
        plt.close()

        np_to_pil(out_target).save(os.path.join(fresult, 'final_results', f'turb_img_gt_{j + start_f}_FB_{FB_img}.png'))
        np_to_pil(out_img).save(os.path.join(fresult, 'final_results', f'turb_img_{j + start_f}_FB_{FB_img}.png'))
        np_to_pil(visualize_rgb_norm(warp_img)).save(
            os.path.join(fresult, 'final_results', f'warp_img_{j + start_f}_FB_{FB_img}.png'))

    plt.figure(figsize=(8, 5))
    plt.plot(ssim_arr_sharp[:])
    plt.ylabel('SSIM')
    plt.xlabel('Epoch')
    plt.savefig(os.path.join(f'{fresult}', 'final_results', 'ssim_sharp_final.png'))
    plt.close()

    plt.figure(figsize=(8, 5))
    plt.plot(psnr_arr_sharp[:])
    plt.ylabel('PSNR')
    plt.xlabel('Epoch')
    plt.savefig(os.path.join(f'{fresult}', 'final_results', 'psnr_sharp_final.png'))
    plt.close()

    plt.figure(figsize=(8, 5))
    plt.plot(lpips_arr_sharp[:])
    plt.ylabel('LPIPS')
    plt.xlabel('Epoch')
    plt.savefig(os.path.join(f'{fresult}', 'final_results', 'lpips_sharp_final.png'))
    plt.close()

    # save results to csv
    with open(os.path.join(f'{fresult}', 'final_results', 'results.csv'), 'w') as f:
        f.write('Indices, PSNR, SSIM, LPIPS, Height, Width\n')
        f.write(
            f'{start_f} - {start_f + batch_size - 1}, {psnr_final_sharp}, {ssim_final_sharp}, {lpips_final_sharp}, {nr}, {nc}\n')

    return psnr_final_sharp, ssim_final_sharp, lpips_final_sharp.detach().cpu().numpy(), ssd_final_sharp, ssdg_final_sharp, nr, nc


def predict_model_grid(grid_t_input_single_gd, model_grid, vec_scale, xy_grid_batch_var,
                       offset_save_path, epoch, debug, num_iter_optim):
    """Evaluate the grid model and derive a displaced sampling grid from its gradient.

    The model output is differentiated with respect to the input coordinates through
    ``diff_operators.gradient``. The first two gradient channels, multiplied by the mean
    of the model output and by ``(1 - 1 / 1.33)``, form the offsets that are added to
    ``xy_grid_batch_var``.

    Args:
        grid_t_input_single_gd: Coordinates fed to ``model_grid`` and differentiated
            against (presumably created with ``requires_grad`` -- verify at the caller).
        model_grid: Callable network evaluated on ``grid_t_input_single_gd``.
        vec_scale: Not used by this function.
        xy_grid_batch_var: Base grid the offsets are added to (assumed to be laid out as
            [batch_size, 2, nr, nc] -- TODO confirm).
        offset_save_path: Directory receiving the ``offsets_{epoch}.npy`` dumps.
        epoch: Current iteration index.
        debug: When truthy, offsets are dumped every 100 epochs and at ``num_iter_optim``.
        num_iter_optim: Iteration index that also triggers a dump.

    Returns:
        Tuple ``(refined_xyt, h_t)``: the displaced grid and the raw model output.
    """
    height = model_grid(grid_t_input_single_gd)
    height_grad = diff_operators.gradient(height, grid_t_input_single_gd)
    # Drop the leading axis and move the last axis to the front.
    grad_batched = height_grad.squeeze(0).permute(3, 0, 1, 2)

    # NOTE(review): 1.33 looks like the refractive index of water -- confirm.
    scale = torch.mean(height.squeeze()) * (1 - 1 / 1.33)
    offsets = scale * grad_batched[:, :2, :, :]

    is_dump_epoch = epoch % 100 == 0 or epoch == num_iter_optim
    if is_dump_epoch and debug:
        dump_file = os.path.join(offset_save_path, f'offsets_{epoch}.npy')
        np.save(dump_file, offsets.detach().cpu().numpy())

    return xy_grid_batch_var + offsets, height


def write_config_dict(config, config_file_path):
    """Write ``config`` to ``config_file_path`` as a ``parameters:`` line followed by JSON.

    Args:
        config: JSON-serializable dictionary of parameters.
        config_file_path: Destination file. Missing parent directories are created; a
            bare file name (no directory part) is written to the current directory.

    Raises:
        TypeError: If ``config`` contains values ``json.dump`` cannot serialize.
    """
    parent_dir = os.path.dirname(config_file_path)
    # os.makedirs('') raises FileNotFoundError, so only create a parent when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(config_file_path, 'w') as f:
        # Header line first, then the whole dictionary as indented JSON.
        f.write('parameters:\n')
        json.dump(config, f, indent=4)


if __name__ == '__main__':
    # Script entry point: parse the CLI options, start a wandb run, set up the result
    # folder and logging, select the frame window to process and train on it.
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=10)
    parser.add_argument('--scale_factor', type=float, default=0.5)
    # store_false: config['siren'] is True unless --siren is given on the command line.
    parser.add_argument('--siren', action='store_false',
                        help='Disable the Siren architecture (it is used by default)')

    parser.add_argument('--dim_in_imgen', type=int, default=256,
                        help='Input dimension for imgen, this also decides the GRFF dimensions. Default 256')
    parser.add_argument('--w0_first_imgen', type=float, default=45.0, help='w0 for first layer')
    parser.add_argument('--num_layers_imgen', type=int, default=3,
                        help='Decide number of layers in imgen SIREN, default 3')
    parser.add_argument('--width_layers_imgen', action=ParseAction, nargs="+", default=[256, 256, 256],
                        help='Decide width of each layer in imgen SIREN, default [256, 256, 256]')

    parser.add_argument('--w0_first_grid', type=float, default=15, help='w0 for first layer')
    parser.add_argument('--num_layers_grid', type=int, default=2,
                        help='Decide number of layers in grid SIREN, default 2')
    parser.add_argument('--width_layers_grid', action=ParseAction, nargs="+", default=[128, 128],
                        help='Decide width of each layer in grid SIREN, default [128, 128]')
    parser.add_argument('--num_iter_initialize', type=int, default=700)
    parser.add_argument('--lr_init', type=float, default=1e-3,
                        help='Decide learning rate for initialization, default 1e-3')

    parser.add_argument('--num_iter_optim', type=int, default=2000)
    parser.add_argument('--lr_optim', type=float, default=1e-3, help='Decide learning rate for optim, default 1e-3')

    parser.add_argument('--environment', type=str, default='Local')
    parser.add_argument('--exp_name', type=str, default='No_exp_name')
    parser.add_argument('--debug', action='store_true', help='Print debug images, default False')
    parser.add_argument('--pname', type=str, default=None, help='Run only on sequence pname, default None')
    parser.add_argument('--bandwidth_img', type=int, default=8,
                        help='The frequency bandwidth for the turbulence field, default 8')
    parser.add_argument('--start_f', type=int, default=0, help='index of first image in the batch')

    config = vars(parser.parse_args())

    # assert that the number of layers and width of layers are the same
    assert len(config['width_layers_imgen']) == config[
        'num_layers_imgen'], "Number of layers and len of width array should be equal for imgen"
    assert len(config['width_layers_grid']) == config[
        'num_layers_grid'], "Number of layers and len of width array should be equal for grid"

    # assert pname can not be None
    assert config['pname'] is not None, "Pname must have a value"

    config['arch_type'] = 'height_grad' if config['siren'] else 'conv'
    config['scale_factor_str'] = str(config['scale_factor']).replace('.', '')

    # start a new wandb run to track this script
    run = wandb.init(
        # set the wandb project where this run will be logged
        project="SireNDIR",
        tags=[config['environment'], config['exp_name'], config['pname']],
        # track hyperparameters and run metadata
        config=config
    )
    # From here on `config` is the wandb config object of the run, not the plain dict.
    config = run.config
    print(config.pname)
    config['result_folder'] = os.path.join('result', config['arch_type'],
                                           str(config['batch_size']), config['scale_factor_str'],
                                           config['exp_name'].replace(' ', '_'), config['pname'], run.name)

    # Configure logging
    write_config_dict(config.as_dict(), os.path.join(config['result_folder'], 'config.yml'))
    path_to_log = os.path.join(config['result_folder'], 'run.log')
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(path_to_log),
                            logging.StreamHandler()
                        ])

    # Redirect stdout and stderr to the Logger
    sys.stdout = Logger(path_to_log)
    sys.stderr = Logger(path_to_log)

    psnr_list = []
    ssim_list = []
    lpips_list = []
    ssd_list = []
    ssdg_list = []
    indices_list = []
    # assumes pname has at least three '_'-separated fields; the third one names the dataset
    config['dataset_name'] = config.pname.split('_')[2].capitalize()
    extensions = ['.jpg', '.JPG', '.png', '.ppm', '.bmp', '.pgm', '.tif']
    # number of image files (any of the extensions above) available for this sequence
    dataset_len = len([f for f in os.listdir(os.path.join('data', config.pname, 'img')) if
                       f.endswith(tuple(extensions))])

    # If fewer than batch_size frames remain after start_f, shift the window back so that it
    # ends on the last frame (start_f becomes dataset_len - batch_size).
    # NOTE(review): this is negative when the sequence has fewer than batch_size frames.
    if dataset_len - config['start_f'] < config['batch_size']:
        # start_f is already tracked by wandb; changing a tracked value has to be explicit.
        config.update({'start_f': dataset_len - config['batch_size']}, allow_val_change=True)

    # inclusive index of the last frame in the window
    last_f = config['start_f'] + config['batch_size'] - 1
    config["cur_batch_str"] = f'{config["start_f"]}_{last_f}'

    print(f"Working on {config['dataset_name']}, indices {config['cur_batch_str']}")

    # Print config
    print(json.dumps(config.as_dict(), indent=4))

    batch_psnr, batch_ssim, batch_lpips, batch_ssd, batch_ssdg, height, width = train_batch(run.name,
                                                                                            **wandb.config)

    wandb.finish()

    # same inclusive "first - last" label for the console line and the bookkeeping list
    indices_label = f"{config['start_f']} - {last_f}"
    print(
        f"Indices {indices_label}: PSNR: {batch_psnr:.03f}, SSIM: {batch_ssim:.03f}, LPIPS: {batch_lpips:.03f}, SSD: {batch_ssd:.05f}, SSDG: {batch_ssdg:.05f}")
    indices_list.append(indices_label)
    psnr_list.append(batch_psnr)
    ssim_list.append(batch_ssim)
    lpips_list.append(batch_lpips)
    ssd_list.append(batch_ssd)
    ssdg_list.append(batch_ssdg)

    torch.cuda.empty_cache()

if __name__ == '__main__':
    # Build a table with one row per processed batch from the metric lists filled by the
    # entry-point code above and save it as CSV in the run's result folder.
    # Guarded because the names used here only exist when the file runs as a script;
    # without the guard, importing this module fails with a NameError.
    df = pd.DataFrame(list(zip(psnr_list, ssim_list, lpips_list, ssd_list, ssdg_list)),
                      columns=['PSNR', 'SSIM', 'LPIPS', 'SSD', 'SSDG'])
    # The same dataset label and image size apply to every row.
    df['Dataset'] = config['dataset_name'] + f'_{run.name}'
    df['Height'] = height
    df['Width'] = width
    # change order of columns to [Dataset, PSNR, SSIM, LPIPS, SSD, SSDG, Height, Width]
    df = df[['Dataset', 'PSNR', 'SSIM', 'LPIPS', 'SSD', 'SSDG', 'Height', 'Width']]

    df.to_csv(os.path.join(config['result_folder'], f"results_{config['dataset_name']}.csv"),
              index=False)