import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import pytorch_lightning as pl
import os

class MagnitudeLayer(pl.LightningModule):
    '''
    The simple 4 stage magnitude layer.

    Every pixel of a (batch, channels, height, width) input becomes one point:
    its (row, column) position scaled by ``l_grid`` concatenated with all of its
    channel values scaled by ``l_pixel``. The four stages are
      1. pairwise ``p``-norm distances between points (``torch.cdist``),
      2. the similarity matrix ``Z = exp(-d)``,
      3. "inverting" ``Z``,
      4. taking the row sums of ``Z^{-1}``.
    Stages 3-4 are computed together as the solution ``w`` of ``Z w = 1``.

    Because all channels of a pixel form a single point, the output has shape
    (batch, 1, height, width).
    '''
    def __init__(self,p=1,l_grid=1.,l_pixel=1.):
        super().__init__()
        self.grid = None
        # (height, width, device) that the cached ``grid`` was built for.
        self._grid_key = None
        self.p = p
        self.l_grid = l_grid
        self.l_pixel = l_pixel
    def forward(self,x):
        self._refresh_grid(x)
        return self._magnitude_vec(x)
    def _refresh_grid(self,x):
        '''(Re)build the cached coordinate grid when the spatial size or device of ``x`` changes.'''
        key = (x.shape[2], x.shape[3], x.device)
        if self.grid is None or getattr(self, '_grid_key', None) != key:
            self.grid = self._generate_grid(x.shape[2],x.shape[3]).to(x.device)
            self._grid_key = key
    def _generate_grid(self,x_pixel,y_pixel):
        '''Return the (x_pixel*y_pixel, 2) tensor of pixel coordinates in row-major order.'''
        xx = torch.linspace(0,x_pixel-1,x_pixel)
        yy = torch.linspace(0,y_pixel-1,y_pixel)
        # Same ordering as stacking an 'ij' meshgrid and flattening it,
        # without triggering the meshgrid ``indexing`` deprecation warning.
        return torch.cartesian_prod(xx,yy)
    def _magnitude_vec(self,x):
        '''Return the per-pixel weights for a (B, C, H, W) batch as a (B, 1, H, W) tensor.'''
        b, c, h, w = x.shape
        coords = self.grid.unsqueeze(0).expand(b, -1, -1)
        # (B, C, H, W) -> (B, H*W, C): one row of channel values per pixel.
        # reshape (not view) so non-contiguous inputs are accepted.
        pixels = x.permute(0,2,3,1).reshape(b, h*w, c)
        points = torch.cat([self.l_grid*coords, self.l_pixel*pixels], dim=2)
        similarity = torch.exp(-torch.cdist(points,points,p=self.p))
        # Row sums of Z^{-1} equal the solution of Z w = 1; solving avoids
        # forming the explicit inverse (cheaper and numerically more stable).
        ones = torch.ones_like(similarity[..., :1])
        weights = torch.linalg.solve(similarity, ones)
        return weights.reshape(b, 1, h, w)

class MagnitudeLayerQuant(pl.LightningModule):
    '''
    A quantised version of the magnitude layer. We also minmax scale the output.

    Points are built exactly as in the simple magnitude layer: pixel position
    scaled by ``l_grid`` concatenated with all channel values scaled by
    ``l_pixel``. The entries of the inverse similarity matrix are quantised to
    ``levels`` levels before the row sums are taken. The result is then
    min-max scaled per sample.

    Input is (batch, channels, height, width). Output is (batch, 1, height, width),
    one value per pixel.
    '''
    def __init__(self,p=1,l_grid=1.,l_pixel=1.,levels=10):
        super().__init__()
        self.grid = None
        # (height, width, device) that the cached ``grid`` was built for.
        self._grid_key = None
        self.p = p
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.quant = QuantLayer(levels=levels)
        self.min_max = MinMaxLayer()
    def forward(self,x):
        self._refresh_grid(x)
        b, _, h, w = x.shape
        z = self._magnitude_vec(x)
        z = self.quant(z)
        # All channels of a pixel form one point, so there is one value per pixel.
        z = torch.sum(z,dim=-1).reshape(b, 1, h, w)
        z = self.min_max(z)
        return z
    def _refresh_grid(self,x):
        '''(Re)build the cached coordinate grid when the spatial size or device of ``x`` changes.'''
        key = (x.shape[2], x.shape[3], x.device)
        if self.grid is None or getattr(self, '_grid_key', None) != key:
            self.grid = self._generate_grid(x.shape[2],x.shape[3]).to(x.device)
            self._grid_key = key
    def _generate_grid(self,x_pixel,y_pixel):
        '''Return the (x_pixel*y_pixel, 2) tensor of pixel coordinates in row-major order.'''
        xx = torch.linspace(0,x_pixel-1,x_pixel)
        yy = torch.linspace(0,y_pixel-1,y_pixel)
        # Same ordering as stacking an 'ij' meshgrid and flattening it.
        return torch.cartesian_prod(xx,yy)
    def _magnitude_vec(self,x):
        '''Return the full inverse similarity matrix, shape (B, H*W, H*W).'''
        b, c, h, w = x.shape
        coords = self.grid.unsqueeze(0).expand(b, -1, -1)
        # (B, C, H, W) -> (B, H*W, C); reshape tolerates non-contiguous input.
        pixels = x.permute(0,2,3,1).reshape(b, h*w, c)
        points = torch.cat([self.l_grid*coords, self.l_pixel*pixels], dim=2)
        similarity = torch.exp(-torch.cdist(points,points,p=self.p))
        # The explicit inverse is required here: quantisation acts on its entries.
        return torch.inverse(similarity)

class MagnitudeLayerProduct(pl.LightningModule):
    '''
    The magnitude layer with product space metric.

    Each channel is handled independently. The distance between pixels i and j
    of a channel is
        l_grid * d_grid(i, j) + l_pixel * |v_i - v_j| ** power
    where d_grid is the ``p``-norm distance between the pixel positions. With
    ``hamming=True``, any nonzero value difference counts as 1. The row sums of
    the inverse of ``exp(-distance)`` give the output, which has the same shape
    as the (batch, channels, height, width) input.
    '''
    def __init__(self,p=1,power=1,l_grid=1.,l_pixel=1.,hamming=False):
        super().__init__()
        self.grid = None
        # (height, width, device) that the cached ``grid`` was built for.
        self._grid_key = None
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.hamming = hamming
    def forward(self,x):
        self._refresh_grid(x)
        return self._magnitude_vec(x)
    def _refresh_grid(self,x):
        '''(Re)build the cached grid-distance matrix when the spatial size or device of ``x`` changes.'''
        key = (x.shape[2], x.shape[3], x.device)
        if self.grid is None or getattr(self, '_grid_key', None) != key:
            self.grid = self._generate_grid(x.shape[2],x.shape[3]).to(x.device)
            self._grid_key = key
    def _generate_grid(self,x_pixel,y_pixel):
        '''Return the (1, N, N) matrix of ``p``-norm distances between the N = x_pixel*y_pixel pixel positions.'''
        xx = torch.linspace(0,x_pixel-1,x_pixel)
        yy = torch.linspace(0,y_pixel-1,y_pixel)
        # Row-major pixel coordinates, same ordering as a flattened 'ij' meshgrid.
        coords = torch.cartesian_prod(xx,yy)
        return torch.cdist(coords,coords,p=self.p).unsqueeze(0)
    def _magnitude_vec(self,x):
        '''Return per-channel, per-pixel weights with the same shape as ``x``.'''
        b, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        flat = x.reshape(b, c, h*w)
        # (B, C, N, N) absolute value differences within each channel.
        diff = torch.abs(flat.unsqueeze(-2)-flat.unsqueeze(-1))
        if self.hamming:
            # Out-of-place equivalent of diff[diff > 0] = 1.
            diff = torch.where(diff > 0., torch.ones_like(diff), diff)
        # The (1, N, N) grid distances broadcast over batch and channel.
        dist = self.l_grid*self.grid + self.l_pixel*torch.pow(diff,self.power)
        similarity = torch.exp(-dist)
        # Row sums of Z^{-1} equal the solution of Z w = 1.
        ones = torch.ones_like(similarity[..., :1])
        return torch.linalg.solve(similarity, ones).reshape(b, c, h, w)

class MagnitudeLayerProduct_grey(pl.LightningModule):
    '''
    The magnitude layer with product space metric for colour images. It produces a "greyscale" image.

    Unlike the per-channel product space layer, which treats each channel
    separately, this layer compares all channel values of two pixels jointly.
    The distance between pixels i and j is
        l_grid * d_grid(i, j) + l_pixel * d_pixel(i, j) ** power
    where both d_grid (pixel positions) and d_pixel (channel vectors) are
    ``p``-norm distances. Input is (batch, channels, height, width). Output is
    (batch, 1, height, width).
    '''
    def __init__(self,p=1,power=1,l_grid=1.,l_pixel=1.):
        super().__init__()
        self.grid = None
        # (height, width, device) that the cached ``grid`` was built for.
        self._grid_key = None
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
    def forward(self,x):
        self._refresh_grid(x)
        return self._magnitude_vec(x)
    def _refresh_grid(self,x):
        '''(Re)build the cached grid-distance matrix when the spatial size or device of ``x`` changes.'''
        key = (x.shape[2], x.shape[3], x.device)
        if self.grid is None or getattr(self, '_grid_key', None) != key:
            self.grid = self._generate_grid(x.shape[2],x.shape[3]).to(x.device)
            self._grid_key = key
    def _generate_grid(self,x_pixel,y_pixel):
        '''Return the (1, N, N) matrix of ``p``-norm distances between the N = x_pixel*y_pixel pixel positions.'''
        xx = torch.linspace(0,x_pixel-1,x_pixel)
        yy = torch.linspace(0,y_pixel-1,y_pixel)
        # Row-major pixel coordinates, same ordering as a flattened 'ij' meshgrid.
        coords = torch.cartesian_prod(xx,yy)
        return torch.cdist(coords,coords,p=self.p).unsqueeze(0)
    def _magnitude_vec(self,x):
        '''Return one weight per pixel as a (B, 1, H, W) tensor.'''
        b, c, h, w = x.shape
        # (B, C, H, W) -> (B, H*W, C); reshape tolerates non-contiguous input.
        pixels = x.permute(0,2,3,1).reshape(b, h*w, c)
        pixel_dist = torch.cdist(pixels,pixels,p=self.p)
        # Apply ``power`` to the pixel distance, as the per-channel product layer
        # does; the default power=1 leaves the distance unchanged.
        dist = self.l_grid*self.grid + self.l_pixel*torch.pow(pixel_dist,self.power)
        similarity = torch.exp(-dist)
        # Row sums of Z^{-1} equal the solution of Z w = 1.
        ones = torch.ones_like(similarity[..., :1])
        return torch.linalg.solve(similarity, ones).reshape(b, 1, h, w)

class MagnitudeLayerProductQuant(pl.LightningModule):
    '''
    The quantized version of the product space magnitude layer.

    The per-channel product metric is built as in the unquantised layer:
    l_grid * d_grid + l_pixel * |v_i - v_j| ** power, with optional hamming
    value differences. The entries of the inverse similarity matrix are
    quantised to ``levels`` levels before the row sums are taken. The result
    is min-max scaled per (sample, channel). The output has the same shape as
    the (batch, channels, height, width) input.
    '''
    def __init__(self,p=1,power=1,l_grid=1.,l_pixel=1.,hamming=False,levels=10):
        super().__init__()
        self.grid = None
        # (height, width, device) that the cached ``grid`` was built for.
        self._grid_key = None
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.hamming = hamming
        self.quant = QuantLayer(levels=levels)
        self.min_max = MinMaxLayer()
    def forward(self,x):
        self._refresh_grid(x)
        z = self._magnitude_vec(x)
        z = self.quant(z)
        z = torch.sum(z,dim=-1).reshape(x.shape)
        z = self.min_max(z)
        return z
    def _refresh_grid(self,x):
        '''(Re)build the cached grid-distance matrix when the spatial size or device of ``x`` changes.'''
        key = (x.shape[2], x.shape[3], x.device)
        if self.grid is None or getattr(self, '_grid_key', None) != key:
            self.grid = self._generate_grid(x.shape[2],x.shape[3]).to(x.device)
            self._grid_key = key
    def _generate_grid(self,x_pixel,y_pixel):
        '''Return the (1, N, N) matrix of ``p``-norm distances between the N = x_pixel*y_pixel pixel positions.'''
        xx = torch.linspace(0,x_pixel-1,x_pixel)
        yy = torch.linspace(0,y_pixel-1,y_pixel)
        # Row-major pixel coordinates, same ordering as a flattened 'ij' meshgrid.
        coords = torch.cartesian_prod(xx,yy)
        return torch.cdist(coords,coords,p=self.p).unsqueeze(0)
    def _magnitude_vec(self,x):
        '''Return the per-channel inverse similarity matrices, shape (B, C, H*W, H*W).'''
        b, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs are accepted.
        flat = x.reshape(b, c, h*w)
        diff = torch.abs(flat.unsqueeze(-2)-flat.unsqueeze(-1))
        if self.hamming:
            # Out-of-place equivalent of diff[diff > 0] = 1.
            diff = torch.where(diff > 0., torch.ones_like(diff), diff)
        dist = self.l_grid*self.grid + self.l_pixel*torch.pow(diff,self.power)
        similarity = torch.exp(-dist)
        # The explicit inverse is required here: quantisation acts on its entries.
        return torch.inverse(similarity)

class QuantLayer(nn.Module):
    '''
    Snap values onto ``levels`` evenly spaced points in [0, 1].

    The input is scaled by the number of interior thresholds (``levels - 1``),
    rounded, scaled back and clamped to [0, 1]. ``levels=None`` makes the layer
    the identity. ``levels=1`` (zero thresholds) multiplies the input by zero.
    '''
    def __init__(self,levels=2):
        super().__init__()
        # Only the interior thresholds are counted, hence levels - 1.
        # The string 'None' marks the pass-through configuration.
        self.thresholds = 'None' if levels is None else levels-1
    def forward(self,x):
        steps = self.thresholds
        if steps == 'None':
            return x
        if steps == 0:
            return steps*x
        # Rounding (rather than explicit thresholding) is the cheaper way to snap.
        snapped = torch.round(steps*x)
        return torch.clamp(snapped/steps, 0, 1)

class MinMaxLayer(nn.Module):
    '''
    A simple minmax scaler.

    Each (sample, channel) slice is rescaled to [0, 1] using its own minimum and
    maximum, taken over all trailing dimensions. For a (B, C, H, W) tensor, that
    means every image channel is scaled independently. Slices whose values are
    all equal would give 0/0; they are mapped to 0 instead of NaN.
    '''
    def __init__(self):
        super().__init__()
    def forward(self,x):
        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, work.
        flat = x.reshape(x.shape[0],x.shape[1],-1)
        # Shape that broadcasts the per-slice statistics over all trailing dims.
        stat_shape = (x.shape[0],x.shape[1]) + (1,)*(x.dim()-2)
        mins = flat.min(-1).values.view(stat_shape)
        maxs = flat.max(-1).values.view(stat_shape)
        span = maxs-mins
        # Constant slices: divide by 1 so (x - min) == 0 stays 0 rather than NaN.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (x-mins)/span
