"""
Convolutional mesh-free generator
"""
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
import numpy as np
import matplotlib.pyplot as plt


def squeeze(x, f):
    """Space-to-depth: fold every f x f spatial patch into the channel axis.

    Args:
        x: tensor of shape (b, nch, N1, N2); N1 and N2 must be divisible by f.
        f: patch side length.

    Returns:
        Tensor of shape (b, nch * f * f, N1 // f, N2 // f). Output channel
        ``(row_in_patch * f + col_in_patch) * nch + ch`` holds input channel
        ``ch`` at that position inside the patch.
    """
    channels_last = x.permute(0, 2, 3, 1)
    batch, rows, cols, chans = channels_last.shape
    patches = channels_last.reshape(batch, rows // f, f, cols // f, f, chans)
    # Move the two in-patch axes next to the channel axis so they merge into it.
    patches = patches.permute(0, 1, 3, 2, 4, 5)
    folded = patches.reshape(batch, rows // f, cols // f, chans * f * f)
    return folded.permute(0, 3, 1, 2)


def rev_squeeze(x, f):
    """Depth-to-space: undo ``squeeze`` by unfolding channels into f x f patches.

    Args:
        x: tensor of shape (b, nch, N1, N2); nch must be divisible by f**2.
        f: patch side length.

    Returns:
        Tensor of shape (b, nch // f**2, N1 * f, N2 * f). Input channel
        ``(row_in_patch * f + col_in_patch) * (nch // f**2) + ch`` is placed
        at that position inside each output patch (the ordering ``squeeze``
        produces).
    """
    channels_last = x.permute(0, 2, 3, 1)
    batch, rows, cols, chans = channels_last.shape
    out_chans = chans // f**2
    blocks = channels_last.reshape(batch, rows, cols, f, f, out_chans)
    # Interleave the in-patch row axis with the grid rows (same for columns).
    blocks = blocks.permute(0, 1, 3, 2, 4, 5)
    unfolded = blocks.reshape(batch, rows * f, cols * f, out_chans)
    return unfolded.permute(0, 3, 1, 2)


def reflect_coords(ix, min_val, max_val):
    """Mirror coordinates back into the closed range ``[min_val, max_val]``.

    Reflection is about ``min_val`` and ``max_val`` themselves, without
    repeating the edge value. With bounds [0, 4], -1 -> 1 and 5 -> 3. Values
    further out are folded repeatedly with period ``2 * (max_val - min_val)``,
    so every result lies inside the range. A single mirror step could still
    leave such values out of bounds.

    Args:
        ix: tensor of coordinates.
        min_val: lower bound.
        max_val: upper bound (>= ``min_val``).

    Returns:
        A new tensor of the same shape; ``ix`` itself is not modified.
    """
    span = max_val - min_val
    if span == 0:
        # Degenerate range (e.g. a 1-pixel axis): only one valid coordinate.
        return torch.full_like(ix, min_val)
    # Offset within one reflection period, in [0, 2 * span).
    t = torch.remainder(ix - min_val, 2 * span)
    # Triangle wave: rises 0 -> span, then falls back to 0.
    return min_val + span - torch.abs(t - span)

def grid_sample_customized(image, grid, pad = 'reflect'):
    """Bilinearly sample ``image`` at the normalized locations in ``grid``.

    Follows the ``torch.nn.functional.grid_sample(..., mode='bilinear',
    align_corners=True)`` convention: grid value -1 maps to the centre of the
    first pixel and +1 to the centre of the last one.

    Args:
        image: tensor of shape (N, C, IH, IW).
        grid: tensor of shape (N, H, W, 2). ``grid[..., 0]`` is x (width) and
            ``grid[..., 1]`` is y (height), in normalized units.
        pad: handling of neighbour pixels that fall outside the image.
            'reflect' mirrors them back inside (via ``reflect_coords``).
            'border' clamps them to the nearest edge pixel.
            'zeros' treats them as zero-valued.

    Returns:
        Tensor of shape (N, C, H, W).

    Raises:
        ValueError: if ``pad`` is not one of the modes above. Without padding,
            out-of-range neighbours would index the wrong flattened pixel.
    """
    if pad not in ('reflect', 'border', 'zeros'):
        raise ValueError(
            f"unsupported pad mode {pad!r}; expected 'reflect', 'border' or 'zeros'")

    N, C, IH, IW = image.shape
    _, H, W, _ = grid.shape

    # Normalized [-1, 1] -> pixel units (align_corners=True convention).
    ix = ((grid[..., 0] + 1) / 2) * (IW - 1)
    iy = ((grid[..., 1] + 1) / 2) * (IH - 1)

    # Integer neighbour coordinates: x_w/x_e are the left/right columns and
    # y_n/y_s the top/bottom rows. floor() has no useful gradient.
    with torch.no_grad():
        x_w = torch.floor(ix)
        y_n = torch.floor(iy)
        x_e = x_w + 1
        y_s = y_n + 1

    # Bilinear weights from the unpadded neighbours; these carry the gradient
    # with respect to ``grid``.
    w_nw = (x_e - ix) * (y_s - iy)
    w_ne = (ix - x_w) * (y_s - iy)
    w_sw = (x_e - ix) * (iy - y_n)
    w_se = (ix - x_w) * (iy - y_n)

    if pad == 'reflect':
        x_w = reflect_coords(x_w, 0, IW - 1)
        x_e = reflect_coords(x_e, 0, IW - 1)
        y_n = reflect_coords(y_n, 0, IH - 1)
        y_s = reflect_coords(y_s, 0, IH - 1)
    else:
        if pad == 'zeros':
            # Neighbours outside the image contribute nothing.
            def _inside(v, size):
                return ((v >= 0) & (v <= size - 1)).to(w_nw.dtype)

            in_w, in_e = _inside(x_w, IW), _inside(x_e, IW)
            in_n, in_s = _inside(y_n, IH), _inside(y_s, IH)
            w_nw = w_nw * in_w * in_n
            w_ne = w_ne * in_e * in_n
            w_sw = w_sw * in_w * in_s
            w_se = w_se * in_e * in_s
        # 'border' samples the edge pixel. For 'zeros' the clamp only keeps the
        # gather in bounds; those neighbours already have zero weight.
        x_w = x_w.clamp(0, IW - 1)
        x_e = x_e.clamp(0, IW - 1)
        y_n = y_n.clamp(0, IH - 1)
        y_s = y_s.clamp(0, IH - 1)

    flat = image.reshape(N, C, IH * IW)

    def _sample(y, x):
        # Gather pixel (y, x) for every grid point, broadcast over channels.
        idx = (y * IW + x).long().reshape(N, 1, H * W).repeat(1, C, 1)
        return torch.gather(flat, 2, idx).view(N, C, H, W)

    return (_sample(y_n, x_w) * w_nw.unsqueeze(1)
            + _sample(y_n, x_e) * w_ne.unsqueeze(1)
            + _sample(y_s, x_w) * w_sw.unsqueeze(1)
            + _sample(y_s, x_e) * w_se.unsqueeze(1))




def cropper(image, coordinate , output_size):
    """Resample an ``output_size`` x ``output_size`` window around each query point.

    Args:
        image: tensor of shape (b, b_pixels, c, h, w), one image per query.
        coordinate: tensor of shape (b, b_pixels, 2) holding each window centre
            as (y, x). It is doubled to obtain the ``affine_grid`` translation
            in normalized [-1, 1] units, so +-0.5 reaches the image edge.
        output_size: side length of the returned window, in samples.

    Returns:
        Tensor of shape (b * b_pixels, c, output_size, output_size), sampled
        bicubically with reflection padding (``align_corners=False``). Window
        samples are spaced one input pixel apart along both axes.
    """
    b, b_pixels, c, h, w = image.shape

    # Build theta in the image's dtype/device. affine_grid returns a grid of
    # theta's dtype, and grid_sample requires it to match the image.
    theta = torch.zeros(b, b_pixels, 2, 3, dtype=image.dtype, device=image.device)
    # Scale: an output_size-pixel window covers output_size / w of the
    # normalized width and output_size / h of the height. Using w for x keeps
    # non-square images sampled 1:1.
    theta[:, :, 0, 0] = output_size / w
    theta[:, :, 1, 1] = output_size / h
    # Translation: column 1 of coordinate drives x, column 0 drives y.
    d_coordinate = coordinate * 2
    theta[:, :, 0, 2] = d_coordinate[:, :, 1]
    theta[:, :, 1, 2] = d_coordinate[:, :, 0]

    image = image.reshape(b * b_pixels, c, h, w)
    theta = theta.reshape(b * b_pixels, 2, 3)

    grid = F.affine_grid(theta, size=(b * b_pixels, c, output_size, output_size),
                         align_corners=False)
    return F.grid_sample(image, grid, mode='bicubic', align_corners=False,
                         padding_mode='reflection')



class SuperCNN(nn.Module):
    """Predict one value per channel at arbitrary query coordinates of an image.

    For each query, a 9x9 window is resampled around the query position
    (bicubic, reflection padding; see ``cropper``). The window goes through
    eight 2x2 convolutions with residual connections and a four-layer MLP,
    and the result is added to the window's centre sample.

    Args:
        c: number of image channels; also the number of outputs per query.
    """

    def __init__(self, c):
        super(SuperCNN, self).__init__()

        self.c = c

        # Eight 2x2 'same' convolutions. forward() max-pools after layers 3
        # and 7 (9x9 -> 4x4 -> 2x2), which is where linear1's 2*2*64 comes from.
        num_layers = [64, 64, 64, 64, 64, 64, 64, 64]
        CNNs = []
        prev_ch = self.c
        for n_out in num_layers:
            CNNs.append(nn.Conv2d(prev_ch, n_out, 2, padding='same', bias=True))
            prev_ch = n_out
        self.CNNs = nn.ModuleList(CNNs)

        self.maxpool = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(2 * 2 * 64, 64, bias=True)
        self.linear2 = nn.Linear(64, 64, bias=True)
        self.linear3 = nn.Linear(64, 64, bias=True)
        self.linear4 = nn.Linear(64, self.c, bias=True)

        # Learnable multipliers on the crop's x / y scale. A value of 1 means
        # each window side spans output_size input pixels.
        self.ws1 = nn.Parameter(torch.ones(1), requires_grad=True)
        self.ws2 = nn.Parameter(torch.ones(1), requires_grad=True)

        # Off-diagonal terms of the crop's affine transform, frozen at zero.
        self.alpha1 = nn.Parameter(torch.zeros(1), requires_grad=False)
        self.alpha2 = nn.Parameter(torch.zeros(1), requires_grad=False)

    def cropper(self, image, coordinate , output_size):
        """Resample an ``output_size`` x ``output_size`` window per query point.

        Args:
            image: tensor of shape (b, b_pixels, c, h, w), one image per query.
            coordinate: tensor of shape (b, b_pixels, 2) holding window centres
                as (y, x). The values are doubled into normalized [-1, 1]
                translations, so +-0.5 reaches the image edge.
            output_size: side length of the returned window, in samples.

        Returns:
            Tensor of shape (b * b_pixels, c, output_size, output_size).
        """
        b, b_pixels, c, h, w = image.shape

        # theta must share the image dtype: grid_sample rejects mismatched
        # input/grid dtypes.
        theta = torch.zeros(b, b_pixels, 2, 3, dtype=image.dtype, device=image.device)
        # Per-axis scale (x uses the width, y the height) times the learnable
        # multipliers, plus the frozen off-diagonal terms.
        theta[:, :, 0, 0] = (output_size / w) * self.ws1
        theta[:, :, 1, 1] = (output_size / h) * self.ws2
        theta[:, :, 0, 1] = self.alpha1
        theta[:, :, 1, 0] = self.alpha2
        # Translation: column 1 of coordinate drives x, column 0 drives y.
        d_coordinate = coordinate * 2
        theta[:, :, 0, 2] = d_coordinate[:, :, 1]
        theta[:, :, 1, 2] = d_coordinate[:, :, 0]

        image = image.reshape(b * b_pixels, c, h, w)
        theta = theta.reshape(b * b_pixels, 2, 3)

        grid = F.affine_grid(theta, size=(b * b_pixels, c, output_size, output_size),
                             align_corners=False)
        return F.grid_sample(image, grid, mode='bicubic', align_corners=False,
                             padding_mode='reflection')

    def forward(self, coordinate, x):
        """Evaluate the network at ``coordinate`` positions of image batch ``x``.

        Args:
            coordinate: tensor of shape (b, b_pixels, 2), query centres in the
                (y, x) order used by ``cropper``.
            x: image batch of shape (b, c, h, w).

        Returns:
            Tensor of shape (b, b_pixels, c).
        """
        b, b_pixels, _ = coordinate.shape
        # Share each image across its b_pixels queries (expand, no copy yet).
        x = x.unsqueeze(1).expand(-1, b_pixels, -1, -1, -1)

        w_size = 9
        x = self.cropper(x, coordinate, output_size=w_size)
        centre = w_size // 2
        mid_pix = x[:, :, centre, centre]

        last = len(self.CNNs) - 1
        for i, conv in enumerate(self.CNNs):
            x_temp = x
            x = F.relu(conv(x))
            if i % 4 == 3:
                x = self.maxpool(x)
            elif i not in (0, last):
                # Residual connection: channel count and size are unchanged here.
                x = x + x_temp

        x = torch.flatten(x, 1)

        h = F.relu(self.linear1(x))
        h = F.relu(self.linear2(h)) + h
        h = F.relu(self.linear3(h)) + h
        # Output = learned correction + the bicubic centre sample.
        x = self.linear4(h) + mid_pix
        return x.reshape(b, b_pixels, -1)




class SuperCNNv2(nn.Module):
    """Variant of ``SuperCNN`` with a fixed crop transform and optional jittered crops.

    For each query, ``N_shifts`` crops are taken, each offset by a random
    shift (the first shift is forced to zero). The crops are concatenated
    along channels and fed to a 4-layer CNN and a 2-layer MLP. The output is
    the MLP result plus the centre sample of the first crop.

    Args:
        c: number of image channels; also the number of outputs per query.
    """

    def __init__(self, c):
        super(SuperCNNv2, self).__init__()
        self.c = c
        self.N_shifts = 1
        sigma_shift = 0.03

        # Four 2x2 'same' convolutions. forward() max-pools after layers 0 and
        # 2 (9x9 -> 4x4 -> 2x2), which is where linear1's 2*2*64 comes from.
        CNNs = []
        prev_ch = self.c * self.N_shifts
        for n_out in [128, 128, 128, 64]:
            CNNs.append(nn.Conv2d(prev_ch, n_out, 2, padding='same', bias=True))
            prev_ch = n_out
        self.CNNs = nn.ModuleList(CNNs)

        self.maxpool = nn.MaxPool2d(2, 2)
        self.linear1 = nn.Linear(2 * 2 * 64, 100, bias=True)
        self.linear2 = nn.Linear(100, self.c, bias=True)

        # Per-crop coordinate jitter. This is a plain tensor, not a parameter
        # or buffer; forward() moves it to the input's device.
        self.shift = torch.randn(size=(self.N_shifts, 1, 1, 2)) * sigma_shift
        self.shift[0] = 0  # ensure the true crop is there

        self.w_size = 9
        self.crop_size = self.w_size
        x_m_x = self.crop_size / 2
        y_m_y = self.crop_size / 2

        # Fixed scale part of the crop transform; translations are added per
        # query in cropper().
        # NOTE(review): with w_size = 9 the diagonal is 4.5. The 9-sample window
        # therefore maps to about [-4, 4] in normalized units (~4x the image
        # extent, filled by reflection padding). SuperCNN.cropper uses
        # output_size / h instead; confirm which scale is intended.
        theta = torch.zeros(1, 2, 3)
        theta[:, 0, 0] = x_m_x
        theta[:, 1, 1] = y_m_y
        self.theta = nn.Parameter(theta, requires_grad=False)

    def cropper(self, image, coordinate , output_size):
        """Resample one window per query using ``self.theta`` plus a per-query shift.

        Args:
            image: tensor of shape (b, b_pixels, c, h, w), one image per query.
            coordinate: tensor of shape (b, b_pixels, 2), window centres as
                (y, x); doubled into normalized translations.
            output_size: side length of the returned window, in samples.

        Returns:
            Tensor of shape (b * b_pixels, c, output_size, output_size).
        """
        b, b_pixels, c, h, w = image.shape

        # Only the translation depends on the query position.
        d_coordinate = coordinate * 2
        theta_ = torch.zeros(b, b_pixels, 2, 3, dtype=image.dtype, device=image.device)
        theta_[:, :, 0, 2] = d_coordinate[:, :, 1]
        theta_[:, :, 1, 2] = d_coordinate[:, :, 0]
        theta_ = theta_.reshape(b * b_pixels, 2, 3)

        image = image.reshape(b * b_pixels, c, h, w)

        grid = F.affine_grid(self.theta + theta_,
                             size=(b * b_pixels, c, output_size, output_size),
                             align_corners=False)
        return F.grid_sample(image, grid, mode='bicubic', align_corners=False,
                             padding_mode='reflection')

    def forward(self, coordinate, x):
        """Evaluate the network at ``coordinate`` positions of image batch ``x``.

        Args:
            coordinate: tensor of shape (b, b_pixels, 2), query centres as (y, x).
            x: image batch of shape (b, c, h, w).

        Returns:
            Tensor of shape (b, b_pixels, c).
        """
        b, b_pixels, _ = coordinate.shape
        x = x.unsqueeze(1).expand(-1, b_pixels, -1, -1, -1)

        self.shift = self.shift.to(x.device).requires_grad_(False)
        x = torch.cat([self.cropper(x, coordinate + self.shift[i], output_size=self.w_size)
                       for i in range(self.N_shifts)], dim=1)
        centre = self.w_size // 2
        # Centre sample of the first (unshifted) crop: its first self.c channels.
        mid_pix = x[:, :self.c, centre, centre]

        last = len(self.CNNs) - 1
        for i, conv in enumerate(self.CNNs):
            x_temp = x
            x = F.relu(conv(x))
            if i % 2 == 0:
                x = self.maxpool(x)
            elif i != last:
                # Residual connection: channel count and size are unchanged here.
                x = x + x_temp

        x = torch.flatten(x, 1)
        x = F.relu(self.linear1(x))
        x = self.linear2(x) + mid_pix
        return x.reshape(b, b_pixels, -1)






class autoencoder(nn.Module):
    """Container that holds an encoder and a decoder side by side.

    Only stores the two objects; when they are ``nn.Module`` instances they
    are registered as child modules. No ``forward`` is defined here.
    """

    def __init__(self, encoder=None, decoder=None):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder


class encoder(nn.Module):
    """Convolutional encoder mapping square images to a latent vector.

    Each stage is a 3x3 conv + ReLU, a 2x2 max-pool, and a two-conv residual
    block. The first four stages also add a channel-mean skip of the input,
    folded to the stage's resolution with ``squeeze`` and tiled across
    channels. The trunk must end at 2x2 with 256 channels (``mlps[0]`` takes
    2*2*256 features). That holds for in_res 64 and 128. Any other in_res
    builds only the four base stages, which fits 32x32 inputs.

    Args:
        latent_dim: size of the output vector.
        in_res: input resolution; selects the number of stages.
        c: number of input channels.
    """

    def __init__(self, latent_dim=256, in_res=64, c=3):
        super(encoder, self).__init__()

        self.in_res = in_res
        self.c = c
        c_last = 256

        # Output width of every stage; extra stages bring larger inputs to 2x2.
        widths = [64, 128, 128, 256]
        if in_res == 64:
            widths += [c_last]
        if in_res == 128:
            widths += [256, c_last]

        stem_convs, block_convs = [], []
        in_ch = c
        for out_ch in widths:
            stem_convs.append(nn.Conv2d(in_ch, out_ch, 3, padding='same'))
            block_convs.append(nn.Conv2d(out_ch, out_ch, 3, padding='same'))
            block_convs.append(nn.Conv2d(out_ch, out_ch, 3, padding='same'))
            in_ch = out_ch

        self.CNNs = nn.ModuleList(stem_convs)
        self.CNNs_add = nn.ModuleList(block_convs)
        self.maxpool = nn.MaxPool2d(2, 2)
        self.mlps = nn.ModuleList([nn.Linear(2 * 2 * c_last, latent_dim)])

    def forward(self, x):
        """Encode images of shape (b, c, H, W) into tensors of shape (b, latent_dim)."""
        skip = x.mean(dim=1, keepdim=True)
        for stage, conv in enumerate(self.CNNs):
            pooled = self.maxpool(F.relu(conv(x)))
            if stage < 4:
                # Fold the skip down to this stage's resolution, then repeat
                # its channels to match the feature width.
                f = 2 ** (stage + 1)
                folded = squeeze(skip, f)
                pooled = pooled + folded.repeat_interleave(pooled.shape[1] // (f ** 2), dim=1)
            h = F.relu(self.CNNs_add[2 * stage](pooled))
            h = F.relu(self.CNNs_add[2 * stage + 1](h))
            x = h + pooled

        x = torch.flatten(x, 1)
        *hidden, head = self.mlps
        for layer in hidden:
            x = F.relu(layer(x))
        return head(x)


class decoder(nn.Module):
    """Convolutional decoder mapping a latent vector to a square image.

    The latent vector goes through a linear layer to 256 x 2 x 2 features.
    Each following stage is a nearest 2x upsample, a 3x3 conv + ReLU, and a
    two-conv residual block. The last stage has no residual and no final
    activation. Output is (b, c, in_res, in_res) for in_res 64 or 128. Other
    values build no conv stages, and ``forward`` then fails.

    Args:
        latent_dim: size of the input vector.
        in_res: output resolution (64 or 128).
        c: number of output channels.
    """

    def __init__(self, latent_dim=256, in_res=64, c=3):
        super(decoder, self).__init__()

        self.in_res = in_res
        self.c = c

        if in_res == 128:
            num_layers = [256, 256, 128, 128, 64, self.c]
        elif in_res == 64:
            num_layers = [256, 128, 128, 64, self.c]
        else:
            num_layers = []

        t_CNNs, CNNs = [], []
        prev_ch = 256
        last = len(num_layers) - 1
        for i, out_ch in enumerate(num_layers):
            # The output stage works at 64 channels and only its final conv
            # maps to c. The choice is keyed on position: comparing
            # out_ch == self.c would also hit hidden stages whenever c equals a
            # hidden width (e.g. c=128). That breaks the residual add in forward().
            c_inter = 64 if i == last else out_ch
            t_CNNs.append(nn.Conv2d(prev_ch, c_inter, 3, padding='same'))
            CNNs.append(nn.Conv2d(c_inter, c_inter, 3, padding='same'))
            CNNs.append(nn.Conv2d(c_inter, out_ch, 3, padding='same'))
            prev_ch = out_ch

        self.t_CNNs = nn.ModuleList(t_CNNs)
        self.CNNs = nn.ModuleList(CNNs)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        self.feature_dim = 2 * 2 * 256
        self.mlps = nn.ModuleList([nn.Linear(latent_dim, self.feature_dim)])

    def forward(self, x):
        """Decode tensors of shape (b, latent_dim) into images of shape (b, c, in_res, in_res)."""
        for mlp in self.mlps:
            x = F.relu(mlp(x))
        x = x.reshape([x.shape[0], 256, 2, 2])

        for i in range(len(self.t_CNNs) - 1):
            xr = F.relu(self.t_CNNs[i](self.upsample(x)))
            x = F.relu(self.CNNs[2 * i](xr))
            x = F.relu(self.CNNs[2 * i + 1](x))
            x = x + xr

        # Output stage: no residual and no activation on the final conv.
        x = F.relu(self.t_CNNs[-1](self.upsample(x)))
        x = F.relu(self.CNNs[-2](x))
        return self.CNNs[-1](x)












def test_SuperCNN():
    """Smoke test: round-trip random images through encoder/decoder, then query SuperCNN.

    Uses in_res=64, one of the resolutions encoder and decoder support.
    SuperCNN takes only the channel count and is evaluated on the decoded
    image batch.
    """
    torch.manual_seed(0)
    b, b_pixels, latent_dim, res, c = 10, 100, 128, 64, 3

    enc = encoder(latent_dim=latent_dim, in_res=res, c=c)
    dec = decoder(latent_dim=latent_dim, in_res=res, c=c)
    with torch.no_grad():
        z = enc(torch.rand(b, c, res, res))
        image = dec(z)
    print(z.shape, image.shape)

    net = SuperCNN(c)
    print(sum(p.numel() for p in net.parameters()), 'parameters')
    # Query positions in [-0.5, 0.5): cropper doubles them into the normalized
    # [-1, 1) range, i.e. inside the image.
    coordinate = torch.rand(b, b_pixels, 2) - 0.5
    out = net(coordinate, image)
    print(out.shape)


def test_cropper():
    """Smoke test: cut 9x9 windows out of scikit-image's astronaut image."""
    from skimage import data
    x = data.astronaut() / 255.0
    # HWC -> CHW float tensor.
    x = torch.tensor(x.transpose(2, 0, 1), dtype=torch.float32)
    print(x.shape)

    b, b_pixels, w_size = 10, 4, 9
    # cropper expects one image per query: (b, b_pixels, c, h, w).
    image = x.expand(b, b_pixels, -1, -1, -1)
    coordinate = torch.rand(b, b_pixels, 2) - 0.5

    crops = cropper(image, coordinate, output_size=w_size)
    print(crops.shape)

    
    
if __name__ == '__main__':
    # Only the SuperCNN smoke test runs by default; test_cropper() additionally
    # needs scikit-image.
    test_SuperCNN()


