# ---------------------------------------------------------------
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This file has been modified from ddrm.
#
# Source:
# https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L171
# https://github.com/bahjat-kawar/ddrm/blob/master/runners/diffusion.py#L264
# https://github.com/bahjat-kawar/ddrm/blob/master/functions/svd_replacement.py#L314
#
# The license for the original version of this file can be
# found in this directory (LICENSE_DDRM).
# The modifications to this file are subject to the same license.
# ---------------------------------------------------------------

import numpy as np
import torch
from .base import H_functions
import numpy as np
from motionblur.motionblur import Kernel
import scipy
from torch import nn
from scipy import ndimage
from scipy.interpolate import interp2d
import math
from ipdb import set_trace as debug

def cubic(x):
    """Evaluate the piecewise cubic interpolation kernel element-wise.

    The kernel is ``1.5|x|^3 - 2.5|x|^2 + 1`` for ``|x| <= 1``,
    ``-0.5|x|^3 + 2.5|x|^2 - 4|x| + 2`` for ``1 < |x| <= 2`` and zero
    elsewhere.

    Args:
        x: torch tensor of sample offsets.

    Returns:
        Tensor with the same shape as ``x`` holding the kernel values.
    """
    magnitude = torch.abs(x)
    squared = magnitude**2
    cubed = magnitude**3

    # Polynomial pieces, evaluated everywhere and masked afterwards.
    inner_poly = 1.5 * cubed - 2.5 * squared + 1
    outer_poly = -0.5 * cubed + 2.5 * squared - 4 * magnitude + 2

    # 0/1 masks selecting the interval on which each piece is valid.
    inner_mask = (magnitude <= 1).type_as(magnitude)
    outer_mask = ((magnitude > 1) * (magnitude <= 2)).type_as(magnitude)

    return inner_poly * inner_mask + outer_poly * outer_mask

def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
    """Compute 1-D cubic resampling weights and input indices for one axis.

    Args:
        in_length: number of input samples along the axis.
        out_length: number of output samples along the axis.
        scale: resize factor (output / input).
        kernel: accepted for interface compatibility; not read here (the
            cubic kernel is always used).
        kernel_width: support of the kernel in input samples.
        antialiasing: when truthy and ``scale < 1``, the kernel is widened
            by ``1 / scale`` so that it also low-pass filters.

    Returns:
        ``(weights, indices, sym_len_s, sym_len_e)`` where row ``k`` of
        ``weights`` / ``indices`` describes output sample ``k``, ``indices``
        has been shifted to address a signal padded by ``sym_len_s`` samples
        at the start, and ``sym_len_e`` is the padding needed at the end.
    """
    widen_kernel = (scale < 1) and (antialiasing)
    if widen_kernel:
        # A stretched kernel interpolates and antialiases in one pass.
        kernel_width = kernel_width / scale

    # 1-based output coordinates and their centres in input space:
    # 0.5 in output space maps to 0.5 in input space, and 0.5 + scale in
    # output space maps to 1.5 in input space.
    out_coords = torch.linspace(1, out_length, out_length)
    centers = out_coords / scale + 0.5 * (1 - 1 / scale)

    # Left-most input sample that can contribute to each output sample.
    leftmost = torch.floor(centers - kernel_width / 2)

    # Upper bound on contributing samples; a surplus column is trimmed below.
    taps = math.ceil(kernel_width) + 2

    # Row k lists the input samples feeding output sample k (broadcast add).
    tap_offsets = torch.linspace(0, taps - 1, taps).view(1, taps)
    indices = leftmost.view(out_length, 1) + tap_offsets

    # Signed distance from each contributing sample to the centre.
    offsets_from_center = centers.view(out_length, 1) - indices
    if widen_kernel:
        weights = scale * cubic(offsets_from_center * scale)
    else:
        weights = cubic(offsets_from_center)

    # Normalise so that every row sums to one.
    row_totals = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / row_totals

    # Trim based on the number of zero weights found in the first and last
    # columns only.
    zero_counts = torch.sum((weights == 0), 0)
    trimmed = taps - 2
    if not math.isclose(zero_counts[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, trimmed)
        weights = weights.narrow(1, 1, trimmed)
    if not math.isclose(zero_counts[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, trimmed)
        weights = weights.narrow(1, 0, trimmed)
    weights = weights.contiguous()
    indices = indices.contiguous()

    # Padding required before/after the signal, then re-base the indices so
    # they address the padded signal.
    pad_start = -indices.min() + 1
    pad_end = indices.max() - in_length
    indices = indices + pad_start - 1
    return weights, indices, int(pad_start), int(pad_end)



def imresize(img, scale, antialiasing=True):
    """Resize an image tensor by ``scale`` with separable cubic interpolation.

    The same scale is used for height and width. The height axis is
    resampled first, then the width axis; borders are extended by mirroring
    the image before each pass.

    Args:
        img: torch tensor, CHW or HW, values in [0, 1]. The tensor passed in
            is left untouched (a 2-D input is no longer reshaped in place).
        scale: resize factor applied to both H and W.
        antialiasing: forwarded to ``calculate_weights_indices``.

    Returns:
        ``torch.FloatTensor`` shaped CHW, or HW when the input was HW; values
        are not rounded or clipped.
    """
    need_squeeze = img.dim() == 2
    if need_squeeze:
        # Out-of-place on purpose: ``unsqueeze_`` would change the shape of
        # the caller's tensor as a side effect.
        img = img.unsqueeze(0)
    in_C, in_H, in_W = img.size()
    out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale)
    kernel_width = 4
    kernel = 'cubic'

    # Return the desired dimension order for performing the resize.  The
    # strategy is to perform the resize first along the dimension with the
    # smallest scale factor.
    # Now we do not support this.

    # get weights and indices
    weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices(
        in_H, out_H, scale, kernel, kernel_width, antialiasing)
    weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices(
        in_W, out_W, scale, kernel, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying: the image sits in the middle, mirrored rows around it
    img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W)
    img_aug.narrow(1, sym_len_Hs, in_H).copy_(img)

    img_aug.narrow(1, 0, sym_len_Hs).copy_(torch.flip(img[:, :sym_len_Hs, :], [1]))
    if sym_len_He > 0:
        # Guarded because ``img[:, -0:, :]`` would select the whole image.
        img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(
            torch.flip(img[:, -sym_len_He:, :], [1]))

    out_1 = torch.FloatTensor(in_C, out_H, in_W)
    kernel_width = weights_H.size(1)
    for i in range(out_H):
        idx = int(indices_H[i][0])
        for j in range(out_C):
            # (W, taps) @ (taps,) -> one output row for channel j
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i])

    # process W dimension
    # symmetric copying: same mirroring, now along the width axis
    out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We)
    out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1)

    out_1_aug.narrow(2, 0, sym_len_Ws).copy_(torch.flip(out_1[:, :, :sym_len_Ws], [2]))
    if sym_len_We > 0:
        out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(
            torch.flip(out_1[:, :, -sym_len_We:], [2]))

    out_2 = torch.FloatTensor(in_C, out_H, out_W)
    kernel_width = weights_W.size(1)
    for i in range(out_W):
        idx = int(indices_W[i][0])
        for j in range(out_C):
            # (H, taps) @ (taps,) -> one output column for channel j
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i])
    if need_squeeze:
        # Drop only the channel axis added above, so a size-1 H or W survives.
        out_2 = out_2.squeeze(0)
    return out_2


def shift_pixel(x, sf, upper_left=True):
    """Shift an image by ``(sf - 1) / 2`` pixels using bilinear interpolation.

    Used for super-resolution with different scale factors. Sample positions
    are clipped to the image, so values at the border are replicated.

    Args:
        x: NumPy array indexed as (rows, cols) or (rows, cols, channels).
            A 3-D array is updated in place, channel by channel, and keeps
            its dtype; a 2-D array is left untouched and a new float array
            is returned.
        sf: scale factor; the shift is ``(sf - 1) * 0.5`` pixels.
        upper_left: shift direction; ``True`` samples at ``+shift``,
            ``False`` at ``-shift``.

    Returns:
        The shifted image, same shape as ``x``.
    """
    # ``interp2d`` raises NotImplementedError from SciPy 1.14 on; a degree-1
    # RectBivariateSpline is the replacement SciPy documents for
    # regular-grid linear interpolation. Imported locally so this function
    # does not depend on edits to the module's import block.
    from scipy.interpolate import RectBivariateSpline

    h, w = x.shape[:2]
    shift = (sf - 1) * 0.5
    xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0)
    offset = shift if upper_left else -shift

    # Keep every sample position inside the image.
    x1 = np.clip(xv + offset, 0, w - 1)
    y1 = np.clip(yv + offset, 0, h - 1)

    def resample(plane):
        # RectBivariateSpline expects z shaped (len(x), len(y)), i.e. the
        # transpose of a (rows, cols) image; transpose back afterwards.
        return RectBivariateSpline(xv, yv, plane.T, kx=1, ky=1)(x1, y1).T

    if x.ndim == 2:
        x = resample(x)
    if x.ndim == 3:
        for i in range(x.shape[-1]):
            x[:, :, i] = resample(x[:, :, i])

    return x

def modcrop(img, sf):
    '''
    Crop the last two dimensions down to the nearest multiple of ``sf``.

    img: tensor image, NxCxWxH or CxWxH or WxH
    sf: scale factor

    Returns a slice of a clone of ``img``, so the input is never modified.
    '''
    size_a, size_b = img.shape[-2:]
    keep_a = size_a - size_a % sf
    keep_b = size_b - size_b % sf
    return img.clone()[..., :keep_a, :keep_b]

def circular_pad(x, pad):
    '''
    # x[N, 1, W, H] -> x[N, 1, W + 2 pad, H + 2 pad] (periodic padding)
    '''
    everything = slice(None)
    leading = slice(0, pad)
    # After the appends, the last `pad` entries are wrapped copies, so the
    # original trailing entries live at [-2 * pad, -pad).
    trailing = slice(-2 * pad, -pad)

    # First append the leading entries to the end of dims 2 and 3 ...
    for axis in (2, 3):
        picker = [everything] * 4
        picker[axis] = leading
        x = torch.cat([x, x[tuple(picker)]], dim=axis)

    # ... then prepend the original trailing entries to the front.
    for axis in (2, 3):
        picker = [everything] * 4
        picker[axis] = trailing
        x = torch.cat([x[tuple(picker)], x], dim=axis)
    return x

def build_sisr(opt, log, k, factor=2, std=1.5):
    """Build the blur + sub-sample super-resolution corruption operator.

    Args:
        opt: options object; only ``opt.device`` is read.
        log: logger; ``log.info`` is called once when the operator is built.
        k: 2-D blur kernel (NumPy array or CPU torch tensor).
        factor: currently ignored, see the note below.
        std: only reported in the log message.

    Returns:
        ``sisr(img)``, which takes a 4-D image in [-1, 1] and returns
        ``(aty, y)``, both rescaled to [-1, 1] and moved to ``opt.device``:
        ``y`` is the blurred image sub-sampled every ``factor`` pixels, and
        ``aty`` is ``y`` resized back up by ``factor``, convolved with the
        flipped kernel and pixel-shifted.
    """
    # NOTE(review): the ``factor`` argument has always been overridden with 2
    # here. That behaviour is kept so existing callers see no change; confirm
    # whether the argument is meant to be honoured.
    factor = 2
    # Report the factor actually applied (the message used to say "4x").
    log.info(f"[Corrupt] sisr ({factor}x): {std=}  ...")

    def sisr(img):
        b, c, w, h = img.shape  # also rejects inputs that are not 4-D
        img = img.to(opt.device)

        img = modcrop(img, factor)

        # [-1, 1] -> [0, 1]
        img = (img + 1) / 2
        # ``ndimage.convolve`` is the same function as the deprecated
        # ``ndimage.filters.convolve`` alias.
        blur_kernel = np.expand_dims(np.expand_dims(k, axis=0), axis=0)
        img = ndimage.convolve(img.cpu(), blur_kernel, mode='wrap')

        # Sub-sample on a regular grid starting at pixel 0.
        st = 0
        y = img[..., st::factor, st::factor]
        y = torch.tensor(y).squeeze()

        # Back-projection: upsample, convolve with the flipped kernel, shift.
        aty = imresize(y, factor, False).cpu().unsqueeze(0).numpy()
        # ``as_tensor`` lets the flip accept NumPy kernels as well as tensors.
        flipped = torch.flip(torch.as_tensor(k), [0, 1]).numpy()
        adjoint_kernel = np.expand_dims(np.expand_dims(flipped, axis=0), axis=0)
        aty = ndimage.convolve(aty, adjoint_kernel, mode='wrap')
        aty_final = shift_pixel(torch.tensor(aty).squeeze(0).permute(1, 2, 0).numpy(), factor)

        # [0, 1] -> [-1, 1] for both outputs.
        aty_out = torch.tensor(aty_final).permute(2, 0, 1).unsqueeze(0).to(opt.device) * 2 - 1
        y_out = y.unsqueeze(0).to(opt.device) * 2 - 1
        return aty_out, y_out

    return sisr


# def fmult(img, sr_filter='bicubic', factor=4, image_size=256):
#     sr_bicubic = build_sr_bicubic(factor, img.device, image_size)
#     b, c, w, h = img.shape
#     _img = sr_bicubic.H(img).reshape(b, c, w // factor, h // factor)
#     return _img

def fmult(img, k, std=1.5, factor=2, image_size=256):
    """Apply the forward operator: blur with ``k``, then sub-sample.

    Args:
        img: 4-D torch tensor in [-1, 1].
        k: 2-D blur kernel (NumPy array or torch tensor).
        std: unused; kept for interface compatibility.
        factor: sub-sampling stride along the last two dimensions.
        image_size: unused; kept for interface compatibility.

    Returns:
        Tensor on ``img``'s device, in [-1, 1]: the blurred image sampled
        every ``factor`` pixels, with singleton dimensions squeezed and a
        leading dimension of size 1 added.
    """
    device_name = img.device
    # [-1, 1] -> [0, 1]
    img = (img + 1) / 2

    # Hand SciPy plain NumPy arrays rather than torch tensors.
    kernel = k.detach().cpu().numpy() if torch.is_tensor(k) else np.asarray(k)
    kernel = kernel[np.newaxis, np.newaxis]

    # ``ndimage.convolve`` is the same function as the deprecated
    # ``ndimage.filters.convolve`` alias.
    blurred = ndimage.convolve(img.detach().cpu().numpy(), kernel, mode='wrap')

    # Sub-sample on a regular grid starting at pixel 0.
    st = 0
    y = blurred[..., st::factor, st::factor]
    y = torch.tensor(y).squeeze().to(device_name)
    # [0, 1] -> [-1, 1]
    return y.unsqueeze(0) * 2 - 1

def ftran(img, k, std=1.5, factor=2, image_size=256):
    """Apply the back-projection: upsample, filter with the flipped kernel, shift.

    Args:
        img: torch tensor in [-1, 1]; after squeezing it must be CHW.
        k: 2-D blur kernel (NumPy array or torch tensor).
        std: unused; kept for interface compatibility.
        factor: upsampling factor passed to ``imresize`` and ``shift_pixel``.
        image_size: unused; kept for interface compatibility.

    Returns:
        Tensor of shape (1, C, H * factor, W * factor) in [-1, 1], on the
        same device as ``img``.
    """
    device_name = img.device
    # [-1, 1] -> [0, 1]
    img = (img + 1) / 2

    # The resize, convolution and shift below run on CPU / NumPy data, so
    # the input is brought to the CPU first and the result is moved back to
    # the input's device at the end.
    low_res = img.detach().cpu().squeeze()
    aty = imresize(low_res, factor, True).unsqueeze(0).numpy()

    # ``as_tensor`` lets the flip accept NumPy kernels as well as tensors.
    flipped = torch.flip(torch.as_tensor(k).detach().cpu(), [0, 1]).numpy()
    # ``ndimage.convolve`` is the same function as the deprecated
    # ``ndimage.filters.convolve`` alias.
    aty = ndimage.convolve(aty, flipped[np.newaxis, np.newaxis], mode='wrap')

    # CHW -> HWC for shift_pixel, then back to 1xCxHxW.
    aty_final = shift_pixel(torch.tensor(aty).squeeze().permute(1, 2, 0).numpy(), factor)
    aty_final = torch.tensor(aty_final)
    # [0, 1] -> [-1, 1]
    return aty_final.permute(2, 0, 1).unsqueeze(0).to(device_name) * 2 - 1



class Blurkernel(nn.Module):
    """Per-channel blur for 3-channel images.

    The module reflection-pads the input by ``kernel_size // 2`` and applies
    a bias-free grouped convolution (one filter per channel), so the spatial
    size is preserved for odd kernel sizes. All three channels share the same
    2-D kernel.

    Args:
        blur_type: stored on the instance; this class does not read it when
            choosing the initial kernel, which is always Gaussian.
        kernel_size: side length of the square kernel.
        std: sigma of the Gaussian used for the initial weights.
        device: device that NumPy kernels given to ``update_weights`` are
            moved to.
    """

    def __init__(self, blur_type='gaussian', kernel_size=31, std=1.5, device=None):
        super().__init__()
        self.blur_type = blur_type
        self.kernel_size = kernel_size
        self.std = std
        self.device = device
        self.seq = nn.Sequential(
            nn.ReflectionPad2d(self.kernel_size//2),
            nn.Conv2d(3, 3, self.kernel_size, stride=1, padding=0, bias=False, groups=3)
        )

        self.weights_init()

    def forward(self, x):
        """Blur ``x`` (N x 3 x H x W)."""
        return self.seq(x)

    def _load_kernel(self, k):
        """Copy the 2-D kernel ``k`` into every parameter and remember it."""
        with torch.no_grad():
            for param in self.parameters():
                # ``copy_`` broadcasts the 2-D kernel over the filter bank
                # and casts it to the parameter's dtype.
                param.copy_(k)
        # Recorded only after a successful copy so ``get_kernel`` always
        # matches the weights in use.
        self.k = k.detach().clone()

    def weights_init(self):
        """Initialise the filters with a Gaussian of sigma ``self.std``."""
        # Filtering a centred unit impulse yields the sampled kernel.
        n = np.zeros((self.kernel_size, self.kernel_size))
        n[self.kernel_size // 2, self.kernel_size // 2] = 1
        k = scipy.ndimage.gaussian_filter(n, sigma=self.std)
        self._load_kernel(torch.from_numpy(k))

    def update_weights(self, k):
        """Replace the blur kernel with ``k`` (NumPy array or torch tensor).

        The kernel returned by ``get_kernel`` is updated as well; it used to
        keep returning the initial Gaussian after an update.
        """
        if not torch.is_tensor(k):
            k = torch.from_numpy(k).to(self.device)
        self._load_kernel(k)

    def get_kernel(self):
        """Return the 2-D kernel currently loaded into the convolution."""
        return self.k




#
# ax = ndimage.filters.convolve(im.squeeze().permute(1, 2, 0).cpu(),
#                               np.expand_dims(self.kernel_forward.squeeze().cpu().numpy(), axis=2),
#                               mode='wrap')
# ax = imresize(torch.tensor(ax).permute(2, 0, 1), 1 / self.sr_forward, False)
# atax = imresize(ax, self.sr_forward, False).permute(1, 2, 0).numpy()
#
# atax = ndimage.filters.convolve(atax,
#                                 np.expand_dims(self.kernel_forward.squeeze().numpy()[::-1, ::-1], axis=2),
#                                 mode='wrap')
# atax = shift_pixel(atax, self.sr_forward)
# atax = torch.tensor(atax).permute(2, 0, 1).unsqueeze(0)