import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import pytorch_lightning as pl
import os

class MagLeNet(pl.LightningModule):
    """LeNet-style 10-class classifier preceded by a magnitude transform.

    Pipeline: MagnitudeLayerProduct -> MinMaxLayer (per-channel rescale) ->
    conv(1->10, k=5) -> maxpool(2) -> relu -> conv(10->20, k=5) -> maxpool(2)
    -> relu -> flatten -> fc(320->50) -> relu -> fc(50->10).

    The conv stack expects single-channel 28x28 inputs: that is the size for
    which the flattened features have the 320 elements ``fc1`` requires
    (28 -> 24 -> 12 -> 8 -> 4, and 20 * 4 * 4 = 320).

    Args:
        p, power, l_grid, l_pixel, hamming: forwarded unchanged to
            MagnitudeLayerProduct.
    """

    def __init__(self, p=1, power=1, l_grid=1., l_pixel=1., hamming=False):
        super().__init__()
        self.mag_layer = MagnitudeLayerProduct(
            p=p, power=power, l_grid=l_grid, l_pixel=l_pixel, hamming=hamming
        )
        self.min_max_layer = MinMaxLayer()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def _logits(self, x, dropout):
        """Run the shared network pipeline and return raw (unnormalised) logits.

        Args:
            x: input image batch.
            dropout: when True, apply ``conv2_drop`` after conv2 and functional
                dropout after fc1 (both only active while ``self.training``).
        """
        x = self.mag_layer(x)
        x = self.min_max_layer(x)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = self.conv2(x)
        if dropout:
            x = self.conv2_drop(x)
        x = F.relu(F.max_pool2d(x, 2))
        # flatten(1) keeps the batch dimension intact; a wrong input size now
        # raises at fc1 instead of being silently regrouped across samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        if dropout:
            x = F.dropout(x, training=self.training)
        return self.fc2(x)

    def forward(self, x):
        """Return raw class logits of shape (batch, 10), without dropout."""
        return self._logits(x, dropout=False)

    def training_step(self, batch, batch_idx):
        """Compute and log the NLL loss of the log-softmax of the logits (with dropout)."""
        x, y = batch
        log_probs = F.log_softmax(self._logits(x, dropout=True), dim=1)
        loss = F.nll_loss(log_probs, y)
        self.log('train_loss', loss)
        return loss

    def _predict_batch(self, batch):
        """Return the true labels and predicted class indices for one batch."""
        x, y = batch
        # argmax of the logits equals argmax of their log-softmax (a monotone
        # per-row shift), so the normalisation is unnecessary here.
        y_hat = self.forward(x).argmax(dim=1, keepdim=False)
        return {'y': y, 'y_hat': y_hat}

    @staticmethod
    def _accuracy(outputs):
        """Fraction of correct predictions over all collected batches.

        Returns None when ``outputs`` is empty (nothing to evaluate).
        """
        if not outputs:
            return None
        y = torch.cat([o['y'].reshape(-1) for o in outputs]).detach()
        y_hat = torch.cat([o['y_hat'].reshape(-1) for o in outputs]).detach()
        return (y == y_hat).float().mean().item()

    def validation_step(self, batch, batch_idx):
        return self._predict_batch(batch)

    # NOTE(review): the *_epoch_end hooks were removed in Lightning 2.0; this
    # code presumably targets an earlier pytorch_lightning — confirm the pin.
    def validation_epoch_end(self, outputs):
        accuracy = self._accuracy(outputs)
        if accuracy is not None:
            # The metric is accuracy; the historical key is kept so existing
            # loggers / checkpoint monitors keep working.
            self.log('val_precision', accuracy)

    def test_step(self, batch, batch_idx):
        return self._predict_batch(batch)

    def test_epoch_end(self, outputs):
        accuracy = self._accuracy(outputs)
        if accuracy is not None:
            self.log('test_precision', accuracy)

    def adv_test_step(self, x):
        """Return log-probabilities (batch, 10) for ``x``, e.g. for adversarial evaluation."""
        out = F.log_softmax(self.forward(x), dim=1)
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

class MagnitudeLayerProduct(pl.LightningModule):
    """Replace every image by its magnitude vector under a pixel-similarity kernel.

    For each (sample, channel) image with N = H * W pixels this builds the
    N x N matrix

        Z_ij = exp(-(l_grid * d_p(pos_i, pos_j) + l_pixel * |x_i - x_j| ** power))

    where d_p is the p-norm distance between pixel coordinates. It returns
    Z^{-1} 1 (the row sums of Z's inverse), reshaped back to the input shape.
    Intermediate tensors have shape (B, C, N, N), so memory grows with the
    square of the pixel count.

    Args:
        p: norm passed to ``torch.cdist`` for the coordinate distance.
        power: exponent applied to absolute intensity differences.
        l_grid: weight of the spatial-distance term.
        l_pixel: weight of the intensity-difference term.
        hamming: if True, any nonzero intensity difference is set to 1 before
            the power is applied.
    """

    def __init__(self, p=1, power=1, l_grid=1., l_pixel=1., hamming=False):
        super().__init__()
        # Cached (1, N, N) coordinate-distance matrix and the (H, W, device)
        # it was built for; rebuilt lazily in forward when that key changes.
        self.grid = None
        self._grid_key = None
        self.p = p
        self.power = power
        self.l_grid = l_grid
        self.l_pixel = l_pixel
        self.hamming = hamming

    def forward(self, x):
        """Return the magnitude vector of each (sample, channel) image, shaped like ``x``."""
        key = (x.shape[2], x.shape[3], x.device)
        # self.grid is a plain attribute, so .to()/.cuda() do not move it;
        # rebuild it whenever the spatial size or the input's device changes.
        if self.grid is None or self._grid_key != key:
            self.grid = self._generate_grid(x.shape[2], x.shape[3]).to(x.device)
            self._grid_key = key
        return self._magnitude_vec(x)

    def _generate_grid(self, x_pixel, y_pixel):
        """Return a (1, N, N) tensor of p-norm distances between pixel coordinates.

        Pixels are ordered row-major over (x_pixel, y_pixel), matching a plain
        reshape of an (..., x_pixel, y_pixel) image to (..., N).
        """
        xx = torch.linspace(0, x_pixel - 1, x_pixel)
        yy = torch.linspace(0, y_pixel - 1, y_pixel)
        grid = torch.meshgrid(xx, yy)
        grid_t = torch.stack(grid).view(2, -1).permute(1, 0)
        return torch.cdist(grid_t, grid_t, p=self.p).unsqueeze(0)

    def _magnitude_vec(self, x):
        """Compute Z^{-1} 1 for every (sample, channel) image of ``x``."""
        batch, channels = x.shape[0], x.shape[1]
        flat = x.reshape(batch, channels, -1)
        # Pairwise absolute intensity differences, shape (B, C, N, N).
        diff = torch.abs(flat.unsqueeze(-2) - flat.unsqueeze(-1))
        if self.hamming:
            # Binarise without in-place mutation: any mismatch counts as 1.
            diff = torch.where(diff > 0., torch.ones_like(diff), diff)
        dist = self.l_grid * self.grid + self.l_pixel * torch.pow(diff, self.power)
        similarity = torch.exp(-dist)
        # Row sums of Z^{-1} are exactly Z^{-1} @ 1. Solving the linear system
        # is cheaper and more accurate than forming the full inverse.
        ones = torch.ones(
            similarity.shape[:-1] + (1,),
            dtype=similarity.dtype,
            device=similarity.device,
        )
        weights = torch.linalg.solve(similarity, ones)
        return weights.reshape(x.shape)

class MinMaxLayer(nn.Module):
    """Linearly rescale each (sample, channel) slice of the input onto [0, 1].

    For every index pair (b, c), the minimum of ``x[b, c]`` maps to 0 and the
    maximum maps to 1. A slice whose values are all equal has no range to
    divide by. It maps to all zeros rather than the NaNs a 0/0 division would
    produce. Accepts any input of shape (B, C, *rest) with at least one
    trailing dimension.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # One statistic per (sample, channel), broadcastable over the remaining
        # dimensions, e.g. (B, C, 1, 1) for a 4-D input.
        stat_shape = tuple(x.shape[:2]) + (1,) * (x.dim() - 2)
        flat = x.flatten(2)
        mins = flat.min(-1).values.reshape(stat_shape)
        maxs = flat.max(-1).values.reshape(stat_shape)
        span = maxs - mins
        # Constant slices: divide by 1 so (x - min) == 0 stays 0 instead of NaN.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return (x - mins) / span
