import json
import os
from pathlib import Path
import torch
import torch.nn as nn
import torchvision
import cv2
import sys
from os import path as osp

from PIL import Image
import PIL
import numpy as np

import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as pyo
from plotly.subplots import make_subplots

import cv2


def calculate_ssd(image1, image2):
    """
    Calculate the Sum of Squared Differences (SSD) between two images.

    Both images are converted to grayscale with ``cv2.COLOR_BGR2GRAY`` and the
    squared per-pixel differences are summed, then divided by the number of
    pixels (so the returned value is a per-pixel mean).

    Args:
        image1: image array; a 3-D array whose first axis has length 3 is
            treated as channel-first and transposed to channel-last.
        image2: image array with the same shape as ``image1``.

    Returns:
        The squared grayscale difference averaged over all pixels (float64).

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    if image1.shape != image2.shape:
        raise ValueError("Input images must have the same dimensions.")

    if len(image1.shape) == 3 and image1.shape[0] == 3:
        image1 = image1.transpose(1, 2, 0)
    if len(image2.shape) == 3 and image2.shape[0] == 3:
        image2 = image2.transpose(1, 2, 0)

    image1_gray = cv2.cvtColor(image1, cv2.COLOR_BGR2GRAY)
    image2_gray = cv2.cvtColor(image2, cv2.COLOR_BGR2GRAY)

    # Difference in float64: subtracting unsigned integer images directly
    # wraps around (e.g. 10 - 20 == 246 for uint8) and corrupts the metric.
    diff = image1_gray.astype(np.float64) - image2_gray.astype(np.float64)

    # Compute the SSD, normalised by the pixel count
    ssd = np.sum(diff ** 2) / (image1_gray.shape[0] * image1_gray.shape[1])
    return ssd


def calculate_ssdg(image1, image2):
    """
    Calculate the Sum of Squared Differences in Gradient (SSDG) between two images.

    Each image is converted to grayscale, its horizontal and vertical Sobel
    derivatives (3x3 kernel, float64 output) are taken, and the squared
    differences of both derivative maps are summed and divided by twice the
    number of pixels.

    Raises:
        ValueError: if the two inputs do not have the same shape.
    """
    if image1.shape != image2.shape:
        raise ValueError("Input images must have the same dimensions.")

    gradients = []
    for img in (image1, image2):
        # A leading axis of length 3 is taken to be the channel axis.
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = img.transpose(1, 2, 0)
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        dx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        dy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        gradients.append((dx, dy))

    (grad_x1, grad_y1), (grad_x2, grad_y2) = gradients

    squared_error = (grad_x1 - grad_x2) ** 2 + (grad_y1 - grad_y2) ** 2
    n_terms = grad_x1.shape[0] * grad_x1.shape[1] * 2
    return np.sum(squared_error) / n_terms


def backwarp(tenInput, tenFlow):
    """Sample `tenInput` at a regular grid displaced by `tenFlow`.

    A base grid spanning [-1, 1] along each spatial axis is built at the flow's
    resolution. The flow channels are divided by half of (width - 1) and
    (height - 1) of `tenInput` respectively, added to the grid, and the result
    is used for bilinear sampling. Samples outside the input read as zero.

    The grid is created on the same device and dtype as `tenFlow`, so the
    function works on CPU as well as on any CUDA device.

    Args:
        tenInput: 4-D tensor (N, C, H, W) to sample from.
        tenFlow: 4-D tensor (N, 2, H_out, W_out); channel 0 is scaled by the
            input width and channel 1 by the input height.

    Returns:
        The sampled tensor produced by `torch.nn.functional.grid_sample`.
    """
    n_batch, flow_h, flow_w = tenFlow.shape[0], tenFlow.shape[2], tenFlow.shape[3]

    tenHorizontal = torch.linspace(-1.0, 1.0, flow_w, dtype=tenFlow.dtype, device=tenFlow.device) \
        .view(1, 1, 1, flow_w).expand(n_batch, -1, flow_h, -1)
    tenVertical = torch.linspace(-1.0, 1.0, flow_h, dtype=tenFlow.dtype, device=tenFlow.device) \
        .view(1, 1, flow_h, 1).expand(n_batch, -1, -1, flow_w)
    tenGrid = torch.cat([tenHorizontal, tenVertical], 1)

    # Rescale the flow into the [-1, 1] coordinate system of the grid.
    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)

    # NOTE(review): the linspace(-1, 1) grid and the (size - 1) / 2 scaling look
    # like the align_corners=True convention, while grid_sample defaults to
    # align_corners=False. The default is kept to preserve existing results --
    # confirm which convention is intended.
    return torch.nn.functional.grid_sample(input=tenInput,
                                           grid=(tenGrid + tenFlow).permute(0, 2, 3, 1),
                                           mode='bilinear', padding_mode='zeros')


def backwarp_grid(tenInput, tenFlow_xy):
    """Bilinearly sample `tenInput` at the coordinates given by `tenFlow_xy`.

    `tenFlow_xy` is permuted from (N, 2, H, W) layout to (N, H, W, 2) and used
    directly as the sampling grid; out-of-range samples read as zero.
    """
    sampling_grid = tenFlow_xy.permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=sampling_grid,
        mode='bilinear',
        padding_mode='zeros',
    )


def im_resize(im, scale_factor):
    """Return `im` resized by `scale_factor` along both dimensions.

    Each entry of `im.size` is multiplied by `scale_factor` and truncated to an
    integer; the resulting pair is passed to `im.resize` in the same order.
    """
    # presumably `im` is a PIL image, whose `size` is (width, height) -- confirm
    scaled_first = int(im.size[0] * scale_factor)
    scaled_second = int(im.size[1] * scale_factor)
    return im.resize((scaled_first, scaled_second))


def visualize_rgb(warp_np):
    """Min-max normalise `warp_np` to [0, 1] and append a channel of ones.

    The input is indexed as (rows, cols, channels); the output has one extra
    trailing channel filled with 1.0.
    """
    lowest = np.amin(warp_np)
    highest = np.amax(warp_np)
    rescaled = (warp_np - lowest) / (highest - lowest)

    ones_channel = np.ones((warp_np.shape[0], warp_np.shape[1], 1))
    return np.concatenate((rescaled, ones_channel), axis=-1)


def visualize_rgb_norm(warp_np):
    """Append a channel of ones to `warp_np` without rescaling its values.

    The input is indexed as (rows, cols, channels); the output has one extra
    trailing channel filled with 1.0.
    """
    ones_channel = np.ones((warp_np.shape[0], warp_np.shape[1], 1))
    return np.concatenate((warp_np, ones_channel), axis=-1)


def load_gt_npy_with_batch_str(path, batch_str):
    """Load a range of `.npy` files from `path` and stack them along a new last axis.

    Args:
        path: directory containing `*.npy` files whose stem ends in `_<int>`.
        batch_str: string of the form "<start>_<end>"; files whose trailing
            index lies in [start + 1, end + 1] are loaded, in index order.

    Returns:
        A float32 torch tensor with the loaded arrays stacked on the last axis.
    """
    first, last = (int(token) for token in batch_str.split('_'))

    def trailing_index(npy_file):
        return int(npy_file.stem.split('_')[-1])

    ordered = sorted(Path(path).glob('*.npy'), key=trailing_index)
    selected = [npy_file for npy_file in ordered
                if first + 1 <= trailing_index(npy_file) <= last + 1]

    stacked = np.stack([np.load(npy_file) for npy_file in selected], axis=-1)
    return torch.from_numpy(stacked).float()


def load_gt_npy(path, batch_size):
    """Load the first `batch_size` arrays from `path` and stack them on a new last axis.

    Directory entries are ordered by the integer between the last `_` and the
    first following `.` of their name. Each loaded file path is printed.

    Returns:
        A float32 torch tensor with the loaded arrays stacked on the last axis.
    """
    def trailing_index(name):
        return int(name.split('_')[-1].split('.')[0])

    loaded = []
    for position, name in enumerate(sorted(os.listdir(path), key=trailing_index)):
        if not position < batch_size:
            # Positions only grow, so nothing further can qualify.
            break
        full_path = osp.join(path, name)
        print(full_path)
        loaded.append(np.load(full_path))

    stacked = np.stack(loaded, axis=-1)
    return torch.from_numpy(stacked).float()


def visualize_water_height(height_tensor,
                           epoch,
                           with_gt=None,
                           save_path=None,
                           json_path=None,
                           crange_height=None,
                           crange_gt=None,
                           global_colorscale=False):
    """Build an animated plotly surface plot of a height field over time.

    The height data (and optional ground truth) is indexed as ``[:, :, t]``; one
    animation frame is created per ``t`` together with a slider. When ground
    truth is present it is drawn in a second subplot.

    Args:
        height_tensor: torch tensor of heights; ignored when `json_path` is given.
        epoch: value shown in the figure title and used in the output file names.
        with_gt: optional torch tensor of ground-truth heights; ignored when
            `json_path` is given.
        save_path: directory for the HTML figure and the raw-data JSON. When
            None, nothing is written to disk.
        json_path: optional JSON file with keys ``'height_np'`` and ``'gt_np'``;
            when given, data is read from it and each array is rescaled to
            [-1, 1]. Output file names then use a fixed epoch of 5000.
        crange_height: optional (min, max) colour/z range for the height
            surface. Defaults to the data range.
        crange_gt: optional (min, max) starting range for the ground-truth
            surface; it is widened to include the ground-truth data range.
        global_colorscale: when True, both ranges start from [-1, 1] and take
            precedence over `crange_height` / `crange_gt`.
    """
    if json_path is not None:
        with open(json_path, 'r') as json_file:
            data = json.load(json_file)
        height_np = np.array(data['height_np'])
        height_np = 2 * (height_np - height_np.min()) / (height_np.max() - height_np.min()) - 1
        gt_np = np.array(data['gt_np']) if data['gt_np'] is not None else None
        if gt_np is not None:
            gt_np = 2 * (gt_np - gt_np.min()) / (gt_np.max() - gt_np.min()) - 1
    else:
        # Convert the tensors to numpy arrays and squeeze out single-dimensional entries
        height_np = height_tensor.cpu().detach().numpy().squeeze()
        gt_np = with_gt.cpu().detach().numpy().squeeze() if with_gt is not None else None

    # z-axis / colour limits for the height surface:
    # global scale > explicit range > data range.
    if global_colorscale:
        height_z_min, height_z_max = -1.0, 1.0
    elif crange_height is not None:
        height_z_min, height_z_max = crange_height
    else:
        height_z_min, height_z_max = height_np.min(), height_np.max()

    # z-axis / colour limits for the ground-truth surface; the starting range
    # is always widened so that it covers the ground-truth data.
    if global_colorscale:
        gt_z_min, gt_z_max = -1.0, 1.0
    elif crange_gt is not None:
        gt_z_min, gt_z_max = crange_gt
    else:
        gt_z_min, gt_z_max = float('inf'), float('-inf')
    if gt_np is not None:
        gt_z_min = min(gt_z_min, gt_np.min())
        gt_z_max = max(gt_z_max, gt_np.max())

    n_steps = height_np.shape[-1]

    frames = []
    for t in range(n_steps):
        frame_data = [go.Surface(z=height_np[:, :, t], cmin=height_z_min, cmax=height_z_max)]
        if gt_np is not None:
            frame_data.append(go.Surface(z=gt_np[:, :, t], showscale=False, cmin=gt_z_min, cmax=gt_z_max))
        frames.append(go.Frame(data=frame_data, name=str(t)))

    # Create the initial surface plot
    initial = [go.Surface(z=height_np[:, :, 0], cmin=height_z_min, cmax=height_z_max)]
    if gt_np is not None:
        initial.append(go.Surface(z=gt_np[:, :, 0], showscale=False, cmin=gt_z_min, cmax=gt_z_max))

        # Side-by-side subplots: prediction on the left, ground truth on the right
        fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'surface'}, {'type': 'surface'}]])
        fig.add_trace(initial[0], row=1, col=1)
        fig.add_trace(initial[1], row=1, col=2)
    else:
        # Create a single plot
        fig = go.Figure(data=initial)

    # Add frames to the figure
    fig.frames = frames

    # Add slider to the layout
    sliders = [dict(
        steps=[dict(method='animate',
                    args=[[str(t)], dict(mode='immediate',
                                         frame=dict(duration=100, redraw=True),
                                         transition=dict(duration=0))],
                    label=str(t)) for t in range(n_steps)],
        active=0,
        transition=dict(duration=0),
        x=0.1,
        xanchor='left',
        y=0,
        yanchor='top'
    )]

    # Update layout for better visualization
    fig.update_layout(title=f'Water surface at epoch {epoch}', autosize=True,
                      scene=dict(xaxis_title='X Axis',
                                 yaxis_title='Y Axis',
                                 zaxis_title='Height',
                                 zaxis=dict(range=[height_z_min, height_z_max])),
                      sliders=sliders)

    if gt_np is not None:
        fig.update_layout(scene2=dict(zaxis=dict(range=[gt_z_min, gt_z_max])))

    # Nothing is written to disk without a target directory.
    if save_path is None:
        return

    if json_path is not None:
        # Re-rendered JSON data is always saved under epoch 5000.
        file_epoch = 5000
        html_name = f'height_epoch_{file_epoch}.html'
        json_name = f'modded_height_epoch_{file_epoch}.json'
    else:
        html_name = f'height_epoch_{epoch}.html'
        json_name = f'height_epoch_{epoch}.json'

    # Save the interactive figure
    pyo.plot(fig, filename=osp.join(save_path, html_name), auto_open=False)

    # Save the raw data to a JSON file
    raw_data = {
        'height_np': height_np.tolist(),
        'gt_np': gt_np.tolist() if gt_np is not None else None
    }
    with open(osp.join(save_path, json_name), 'w') as json_file:
        json.dump(raw_data, json_file)



def np_to_cv2(im):
    """Convert an RGB float image to an 8-bit BGR image for OpenCV.

    `im` is a numpy image with shape (H, W, C) and range [0, 1]; it is scaled
    by 255, cast to uint8 and its channels converted from RGB to BGR order.
    """
    as_uint8 = (im * 255).astype(np.uint8)
    return cv2.cvtColor(as_uint8, cv2.COLOR_RGB2BGR)


def has_file_allowed_extension(filename, extensions):
    """Check whether a file name ends with one of the given extensions.

    The file name is lower-cased before comparison; the extensions are used
    as given.

    Args:
        filename (string): path to a file
        extensions (iterable of string): suffixes to accept

    Returns:
        bool: True if the lower-cased filename ends with any of `extensions`
    """
    lowered = filename.lower()
    for suffix in extensions:
        if lowered.endswith(suffix):
            return True
    return False


def crop_image(img, d=32):
    '''Center-crop `img` so that both entries of `img.size` are divisible by `d`.

    Each dimension is reduced to the largest multiple of `d` that fits, and the
    crop box is centred within the original image.
    '''
    size_a, size_b = img.size[0], img.size[1]
    target_a = size_a - size_a % d
    target_b = size_b - size_b % d

    left = int((size_a - target_a) / 2)
    upper = int((size_b - target_b) / 2)
    right = int((size_a + target_a) / 2)
    lower = int((size_b + target_b) / 2)

    return img.crop([left, upper, right, lower])


def get_params(opt_over, net, net_input, downsampler=None):
    '''Returns parameters that we want to optimize over.

    Args:
        opt_over: comma separated list of 'net', 'down' and/or 'input',
            e.g. "net,input" or "net"
        net: network
        net_input: torch.Tensor that stores input `z`; its `requires_grad`
            flag is switched on when 'input' is requested
        downsampler: module whose parameters are added for 'down'; required
            when 'down' is requested

    Returns:
        list with the parameters of every requested group, in the order the
        groups appear in `opt_over`.
    '''
    params = []

    for opt in opt_over.split(','):
        if opt == 'net':
            params += list(net.parameters())
        elif opt == 'down':
            assert downsampler is not None
            # Accumulate rather than assign, so groups listed before 'down'
            # (e.g. "net,down") are not discarded.
            params += list(downsampler.parameters())
        elif opt == 'input':
            net_input.requires_grad = True
            params += [net_input]
        else:
            assert False, 'what is it?'

    return params


def get_image_grid(images_np, nrow=8):
    '''Arrange a list of numpy images into one grid image.

    Every array is converted to a torch tensor, the tensors are laid out with
    `torchvision.utils.make_grid` (`nrow` images per row) and the result is
    returned as a numpy array.
    '''
    tensors = []
    for image in images_np:
        tensors.append(torch.from_numpy(image))

    grid_tensor = torchvision.utils.make_grid(tensors, nrow)
    return grid_tensor.numpy()


def plot_image_grid(images_np, nrow=8, factor=1, interpolation='lanczos'):
    """Draw images in a grid and show the figure.

    Images with fewer channels than the widest image in the list are stacked
    three times along the channel axis before the grid is built.

    Args:
        images_np: list of images, each image is np.array of size 3xHxW or 1xHxW
        nrow: how many images will be in one row
        factor: amount added to both dimensions of the plt.figure size
        interpolation: interpolation used in plt.imshow

    Returns:
        The grid image as a numpy array.
    """
    n_channels = max(image.shape[0] for image in images_np)
    assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"

    harmonised = []
    for image in images_np:
        if image.shape[0] != n_channels:
            image = np.concatenate([image] * 3, axis=0)
        harmonised.append(image)
    images_np = harmonised

    grid = get_image_grid(images_np, nrow)

    plt.figure(figsize=(len(images_np) + factor, 12 + factor))

    is_grayscale = images_np[0].shape[0] == 1
    if is_grayscale:
        plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
    else:
        plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)

    plt.show()

    return grid


def plot_image_grid_save(fname, images_np, nrow=8, factor=1, interpolation='lanczos'):
    """Draw images in a grid, save the figure to `fname` and show it.

    Images with fewer channels than the widest image in the list are stacked
    three times along the channel axis before the grid is built. The figure
    size is fixed at 8 x 3.

    Args:
        fname: destination passed to plt.savefig
        images_np: list of images, each image is np.array of size 3xHxW or 1xHxW
        nrow: how many images will be in one row
        factor: accepted for signature parity with `plot_image_grid`; not used here
        interpolation: interpolation used in plt.imshow

    Returns:
        The grid image as a numpy array.
    """
    n_channels = max(image.shape[0] for image in images_np)
    assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"

    harmonised = []
    for image in images_np:
        if image.shape[0] != n_channels:
            image = np.concatenate([image] * 3, axis=0)
        harmonised.append(image)
    images_np = harmonised

    grid = get_image_grid(images_np, nrow)

    plt.figure(figsize=(8, 3))

    is_grayscale = images_np[0].shape[0] == 1
    if is_grayscale:
        plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
    else:
        plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)

    plt.tight_layout()
    plt.savefig(fname)
    plt.show()

    return grid


def load(path):
    """Open the image at `path` with PIL; RGBA images are converted to RGB."""
    opened = Image.open(path)
    return opened.convert('RGB') if opened.mode == 'RGBA' else opened


def get_image(path, imsize=-1):
    """Load an image and resize to a specific size.

    Upscaling (requested first dimension larger than the image's) uses bicubic
    interpolation; otherwise Lanczos resampling is used.

    Args:
        path: path to image
        imsize: tuple or scalar with dimensions; -1 for `no resize`

    Returns:
        Tuple `(img, img_np)`: the (possibly resized) PIL image and its
        conversion through `pil_to_np`.
    """
    img = load(path)

    if isinstance(imsize, int):
        imsize = (imsize, imsize)

    if imsize[0] != -1 and img.size != imsize:
        if imsize[0] > img.size[0]:
            img = img.resize(imsize, Image.BICUBIC)
        else:
            # Image.ANTIALIAS was an alias of LANCZOS and no longer exists in
            # Pillow >= 10; LANCZOS is the same filter under its current name.
            img = img.resize(imsize, Image.LANCZOS)

    img_np = pil_to_np(img)

    return img, img_np


def fill_noise(x, noise_type):
    """Fill tensor `x` in place with noise.

    `noise_type` is 'u' for uniform noise (`x.uniform_()`) or 'n' for normal
    noise (`x.normal_()`); any other value trips an assertion.
    """
    if noise_type == 'n':
        x.normal_()
        return
    if noise_type == 'u':
        x.uniform_()
        return
    assert False


def get_noise_batch(batch_size, input_depth, method, spatial_size, noise_type='u', var=1. / 10):
    """Returns a torch.Tensor initialized in a specific way.

    For `method='noise'` the result has size
    (`batch_size` x `input_depth` x `spatial_size[0]` x `spatial_size[1]`).
    For `method='meshgrid'` a single 2-channel coordinate grid with values in
    [0, 1] is returned; `batch_size` is not used on that path.

    Args:
        batch_size: leading dimension of the noise tensor
        input_depth: number of channels in the tensor
        method: `noise` for filling tensor with noise; `meshgrid` for np.meshgrid
        spatial_size: spatial size of the tensor to initialize (int or pair)
        noise_type: 'u' for uniform; 'n' for normal
        var: a factor the noise is multiplied by. Basically it is standard deviation scaler.
    """
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    if method == 'noise':
        net_input = torch.zeros([batch_size, input_depth, spatial_size[0], spatial_size[1]])
        fill_noise(net_input, noise_type)
        net_input *= var
        return net_input

    if method == 'meshgrid':
        assert input_depth == 2
        x_coords = np.arange(0, spatial_size[1]) / float(spatial_size[1] - 1)
        y_coords = np.arange(0, spatial_size[0]) / float(spatial_size[0] - 1)
        X, Y = np.meshgrid(x_coords, y_coords)
        return np_to_torch(np.concatenate([X[None, :], Y[None, :]]))

    assert False


def get_noise(input_depth, method, spatial_size, noise_type='u', var=1. / 10):
    """Returns a torch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`)
    initialized in a specific way.

    Args:
        input_depth: number of channels in the tensor
        method: `noise` for filling tensor with noise; `meshgrid` for np.meshgrid
        spatial_size: spatial size of the tensor to initialize (int or pair)
        noise_type: 'u' for uniform; 'n' for normal
        var: a factor the noise is multiplied by. Basically it is standard deviation scaler.
    """
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    if method == 'noise':
        net_input = torch.zeros([1, input_depth, spatial_size[0], spatial_size[1]])
        fill_noise(net_input, noise_type)
        net_input *= var
        return net_input

    if method == 'meshgrid':
        assert input_depth == 2
        x_coords = np.arange(0, spatial_size[1]) / float(spatial_size[1] - 1)
        y_coords = np.arange(0, spatial_size[0]) / float(spatial_size[0] - 1)
        X, Y = np.meshgrid(x_coords, y_coords)
        return np_to_torch(np.concatenate([X[None, :], Y[None, :]]))

    assert False


def pil_to_np(img_PIL):
    '''Converts an image to a channel-first float32 np.array scaled by 1/255.

    A 3-D array has its last axis moved to the front; any other array gets a
    new leading axis of length 1.
    '''
    pixels = np.array(img_PIL)

    if len(pixels.shape) != 3:
        channel_first = pixels[None, ...]
    else:
        channel_first = pixels.transpose(2, 0, 1)

    return channel_first.astype(np.float32) / 255.


def np_to_pil(img_np):
    '''Converts a channel-first np.array image to a PIL image.

    Values are multiplied by 255, clipped to [0, 255] and cast to uint8. A
    single-channel array drops its leading axis, a 3-channel array is moved to
    channel-last layout; other arrays are passed on unchanged.
    '''
    as_uint8 = np.clip(img_np * 255, 0, 255).astype(np.uint8)

    leading = img_np.shape[0]
    if leading == 3:
        as_uint8 = as_uint8.transpose(1, 2, 0)
    elif leading == 1:
        as_uint8 = as_uint8[0]

    return Image.fromarray(as_uint8)


def np_to_torch(img_np):
    '''Converts a numpy.array image to a torch.Tensor with a new leading axis.

    The values are not changed; only a dimension of size 1 is added in front.
    '''
    as_tensor = torch.from_numpy(img_np)
    return as_tensor[None, :]


def torch_to_np(img_var):
    '''Converts a torch.Tensor to np.array and drops its leading axis.

    The tensor is detached and moved to the CPU, then the first entry along
    axis 0 is returned.
    '''
    as_array = img_var.detach().cpu().numpy()
    return as_array[0]


def optimize(optimizer_type, parameters, closure, LR, num_iter):
    """Runs optimization loop.

    For 'LBFGS', 100 warm-up Adam steps (learning rate 0.001) are run first,
    followed by a single LBFGS step with `max_iter=num_iter`. For 'adam',
    `num_iter` Adam steps are run. Gradients are zeroed before every call to
    `closure`.

    Args:
        optimizer_type: 'LBFGS' or 'adam'
        parameters: list of Tensors to optimize over
        closure: function, that returns loss variable
        LR: learning rate
        num_iter: number of iterations
    """
    if optimizer_type == 'adam':
        print('Starting optimization with ADAM')
        adam = torch.optim.Adam(parameters, lr=LR)

        for _ in range(num_iter):
            adam.zero_grad()
            closure()
            adam.step()
        return

    if optimizer_type == 'LBFGS':
        # Do several steps with adam first
        warmup = torch.optim.Adam(parameters, lr=0.001)
        for _ in range(100):
            warmup.zero_grad()
            closure()
            warmup.step()

        print('Starting optimization with LBFGS')

        lbfgs = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR, tolerance_grad=-1, tolerance_change=-1)

        def zeroed_closure():
            lbfgs.zero_grad()
            return closure()

        lbfgs.step(zeroed_closure)
        return

    assert False
