import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from tqdm.autonotebook import tqdm

from sensor import *

T = 0.5
spp = 8
sensor_size = (128, 256)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_rgb(path, size=None):
    """Load an image file as a batched, channel-first float32 tensor on ``device``.

    Args:
        path: Image file path readable by ``imageio.imread``.
        size: Optional ``(H, W)``; when given, the image is bilinearly resized to it.

    Returns:
        A tensor of shape (1, C, H, W) with C <= 3 (any alpha channel is dropped),
        pixel values divided by 255 (i.e. in [0, 1] for 8-bit images).
    """
    arr = imageio.imread(path) / 255.
    # HWC -> CHW, add a batch dim, keep at most the first three (RGB) channels.
    tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).float().to(device)[:, :3, :, :]
    if size is not None:
        tensor = F.interpolate(tensor, size=size, mode='bilinear', align_corners=False)
    return tensor


img1 = _load_rgb('notebooks/bird.png')
# Resized so it can be concatenated with img1 along the batch dimension.
img2 = _load_rgb('notebooks/pepper.png', size=img1.shape[-2:])
# NOTE(review): img3/img4/img5 all come from the same file, so it is decoded once and
# cloned. If different examples were intended, load the other files here instead.
img3 = _load_rgb('notebooks/cityscapes_example1.png')
img4 = img3.clone()
img5 = img3.clone()

# Batch of two images used as the input radiance by the tests below.
#img = torch.cat((img1, img2), dim=0)
img = torch.cat((img3, img4), dim=0)

# Deformation under test; uncomment one of the alternatives to switch.
#deformation = PowerDeformation(1.0)
deformation = Monotone_Linear_Spline_Deformation(12)
#deformation = IndependentAnisotropicHalfNormalTunableSigmoidDeformation(1.0)
#deformation = VectorFieldDeformation(CubicBumpVectorField(1.0))

def _test():
    """Simulate the sensor once and print the gradient of its mean w.r.t. ``t``.

    Backpropagates the mean sensor intensity to the deformation parameters
    and prints the resulting gradient.

    Returns:
        The simulated sensor image as a NumPy array (detached, on CPU).
    """
    params = torch.tensor([0.5] + [1.0] * 7, requires_grad=True, device=device)
    rendered = simulate(img, deformation, params, sensor_size, spp)

    # Reduce to a scalar so backward() needs no explicit gradient argument.
    objective = rendered.mean()
    objective.backward()
    print(params.grad)

    return rendered.detach().cpu().numpy()

def _testFiniteDifference(dt=1e-2, index=0):
    """Numerically estimate d(sensor image)/d t[index] with a central difference.

    The central difference (f(t + dt) - f(t - dt)) / (2 dt) has O(dt^2)
    truncation error, versus O(dt) for a one-sided difference. That makes it
    a better reference for the analytic gradient, at the same cost of two
    ``simulate`` calls.

    Args:
        dt: Half-width of the finite-difference step.
        index: Which entry of the parameter vector ``t`` to perturb.

    Returns:
        The per-pixel derivative image as a NumPy array (on CPU).
    """
    t = torch.tensor([0.5] + [1.0] * 7, device=device)
    t_plus = t.clone()
    t_plus[index] += dt
    t_minus = t.clone()
    t_minus[index] -= dt

    # NOTE(review): if `simulate` samples randomly (spp = samples per pixel), the two
    # calls use different samples and the difference includes noise / (2 dt).
    # 4x spp is used here to reduce that noise. Consider seeding identically per call.
    with torch.no_grad():
        img_plus = simulate(img, deformation, t_plus, sensor_size, 4 * spp)
        img_minus = simulate(img, deformation, t_minus, sensor_size, 4 * spp)

    d_img = (img_plus - img_minus) / (2 * dt)

    print(d_img.mean(), d_img.var())

    return d_img.detach().cpu().numpy()

def _testJacobian():
    """Compare two determinant computations from ``CubicBumpVectorField.integrate``.

    Integrates the vector field on a 512x512 grid over [-1, 1]^2 for time ``T``:
      1. with ``computeJacobian=True``, taking the determinant of the returned 2x2 Jacobian;
      2. with ``computeJacobian=False``, using the second return value directly
         (presumably a determinant computed another way; confirm against `sensor`).
    Both results and their difference are plotted, then checked with ``allclose``.

    Raises:
        AssertionError: if the two determinants differ by more than 1e-3.
    """
    def _show(tensor):
        # Plot a 2-D tensor with a colorbar.
        plt.imshow(tensor.detach().cpu().numpy())
        plt.colorbar()
        plt.show()

    coords = torch.linspace(-1, 1, 512, device=device)
    # Explicit 'ij' matches the previous default and silences the meshgrid warning.
    grid_x, grid_y = torch.meshgrid(coords, coords, indexing='ij')
    X = torch.stack((grid_x, grid_y), dim=2)

    V = CubicBumpVectorField(1.0)
    # integrate() takes a scalar time. The old code called .item() on a 2-element
    # tensor, which raises RuntimeError.
    time = float(T)

    _, jac = V.integrate(X, time, computeJacobian=True)
    det2 = jac[..., 0, 0] * jac[..., 1, 1] - jac[..., 0, 1] * jac[..., 1, 0]
    _show(det2)

    _, det = V.integrate(X, time, computeJacobian=False)
    _show(det)

    _show(det2 - det)

    assert torch.allclose(det, det2, atol=1e-3)

def _testForwardAD():
    """Compute the forward-mode derivative of the sensor image along ``t[0]``.

    Uses ``torch.autograd.forward_ad`` with a one-hot tangent on the first
    parameter, and prints the mean and variance of the resulting tangent image.

    Returns:
        The tangent (directional derivative) image as a NumPy array (on CPU).
    """
    fwd = torch.autograd.forward_ad
    primal = torch.tensor([0.5] + [1.0] * 7, device=device)
    # One-hot direction: differentiate with respect to the first parameter only.
    direction = torch.zeros_like(primal)
    direction[0] = 1.0

    with fwd.dual_level():
        dual_params = fwd.make_dual(primal, direction)
        rendered = simulate(img, deformation, dual_params, sensor_size, spp)

        derivative = fwd.unpack_dual(rendered).tangent

        print(derivative.mean(), derivative.var())
        return derivative.detach().cpu().numpy()

def _testGradients():
    """Plot the sensor image next to its analytic and finite-difference gradients.

    Left: output of ``_test``. Middle: forward-mode gradient from
    ``_testForwardAD``. Right: numeric gradient from ``_testFiniteDifference``.
    Both gradient panels show batch 0 / channel 0, clipped to [-20, 20], each
    with its own colorbar.
    """
    from mpl_toolkits.axes_grid1 import make_axes_locatable

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    sensor_img = _test()
    grad_analytic = _testForwardAD()
    grad_numeric = _testFiniteDifference()

    titles = ('Sensor image', 'Analytic gradient', 'Numeric gradient')
    for axis, title in zip(axes, titles):
        axis.set_title(title)

    axes[0].imshow(sensor_img[0].transpose((1, 2, 0)))

    for axis, grad in ((axes[1], grad_analytic), (axes[2], grad_numeric)):
        divider = make_axes_locatable(axis)
        shown = axis.imshow(grad[0][0], vmin=-20, vmax=20)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(shown, cax=cax, orientation='vertical')

    plt.show()

class TestModel(nn.Module):
    """Small wrapper model around a ``FoveatedSensor`` (used by ``_testLayer``).

    A 3->32->3 convolutional head is built (so its weights are part of
    ``parameters()``), but ``forward`` currently returns the raw sensor output
    without applying it.
    """

    def __init__(self):
        super().__init__()
        self.sensor = FoveatedSensor(sensor_size, spp=8, deform=deformation, constrain_t=False)
        head_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
        ]
        self.cnn = nn.Sequential(*head_layers)

    def forward(self, img):
        # The CNN head is bypassed for now; to enable it, return self.cnn(...) of this.
        return self.sensor(img)


def _plot_side_by_side(left, right):
    """Show batch element 0 of two (B, C, H, W) image tensors side by side."""
    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    ax[0].imshow(left[0].detach().cpu().numpy().transpose((1, 2, 0)))
    ax[1].imshow(right[0].detach().cpu().numpy().transpose((1, 2, 0)))
    plt.show()


def _testLayer(epochs=500, lr=0.01):
    """Fit ``TestModel``'s parameters so its output matches a target rendered with known ``t``.

    The target is rendered with ``simulate`` using the deformation's neutral
    parameters, with the first four entries overridden. The model is then
    optimized with Adam on an MSE loss, and the learned ``t`` is printed every
    epoch. The model output and target are plotted before and after training.

    Args:
        epochs: Number of optimization steps.
        lr: Adam learning rate.
    """
    model = TestModel().to(device)

    radiance = img

    # Ground-truth parameters. Assumes the deformation has at least 4 parameters.
    t = deformation.getNeutralParameter().clone().to(device)
    for i, value in enumerate((-0.5, 0.5, 2.0, 0.1)):
        t[i] = value
    # The target is a fixed regression label; no graph is needed through it.
    with torch.no_grad():
        target = simulate(radiance, deformation, t, sensor_size, spp)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Initial state, for visual comparison only; no autograd graph needed.
    with torch.no_grad():
        sensor_img = model.sensor(radiance)
    _plot_side_by_side(sensor_img, target)

    for epoch in tqdm(range(epochs)):
        optimizer.zero_grad()
        sensor_img = model(radiance)
        loss = F.mse_loss(sensor_img, target)
        loss.backward()
        optimizer.step()

        # tqdm.write prints without corrupting the progress bar (unlike print()).
        tqdm.write("Sensor t: {}".format(model.sensor.t.detach().cpu().numpy()))
        tqdm.write('Epoch: {}, Loss: {}'.format(epoch, loss.item()))

    print('Final t: {}'.format(model.sensor.t.detach().cpu().numpy()))
    with torch.no_grad():
        sensor_img = model(radiance)
    _plot_side_by_side(sensor_img, target)

def _testForwardBackwardWarp():
    """Warp a simulated sensor image back onto the input image grid and plot it.

    Renders the sensor image for ``t``, then maps every pixel position of
    the input image (normalized to [-1, 1)) through ``deformation.inverse``.
    The sensor image is sampled there with nearest-neighbour ``grid_sample``.
    The result and the y-component of the sampling grid are plotted.
    """
    h, w = img.shape[-2:]
    print(h, w)

    # NOTE(review): t has 2 entries, while the module-level deformation is
    # Monotone_Linear_Spline_Deformation(12) and other tests pass >= 4 entries. Confirm.
    t = torch.tensor([-0.5, 0.2], requires_grad=False, device=device)
    target = simulate(img, deformation, t, sensor_size, spp)

    # Normalized coordinates of each input pixel, spanning [-1, 1) with step 2/size.
    # An integer arange guarantees exactly w/h samples (a float-step arange need not).
    pixel_pos_x = torch.arange(w, device=img.device) * (2 / w) - 1
    pixel_pos_y = torch.arange(h, device=img.device) * (2 / h) - 1
    pixel_pos_x, pixel_pos_y = torch.meshgrid(pixel_pos_x, pixel_pos_y, indexing='xy')
    pixel_pos = torch.stack((pixel_pos_x, pixel_pos_y), dim=2)

    pixel_pos = deformation.inverse(pixel_pos, t)
    # One sampling grid per batch element: (h, w, 2) -> (B, h, w, 2).
    pixel_pos = pixel_pos.unsqueeze(0).repeat(target.shape[0], 1, 1, 1)

    test = F.grid_sample(target, pixel_pos, align_corners=True, mode='nearest')
    plt.imshow(test[0].detach().cpu().numpy().transpose((1, 2, 0)))
    plt.show()

    plt.imshow(pixel_pos[0].detach().cpu().numpy()[..., 1])
    plt.show()

if __name__ == '__main__':
    # Script entry point. Swap in _test, _testGradients, _testJacobian or
    # _testForwardBackwardWarp here to run a different check.
    _testLayer()