

import matplotlib.pyplot as plt
import torch
import numpy as np
import math

def detector_region(x):
    """Average the field magnitude over the ten fixed detector windows.

    Args:
        x: tensor indexed as (batch, rows, cols), real or complex. Every window
            lies inside rows/cols [46, 160), so the spatial extent must cover that
            range (presumably the 200 x 200 layer -- confirm with the caller).

    Returns:
        Tensor of shape (batch, 10): column ``k`` is the mean of ``|x|`` over the
        k-th 20 x 20 window (three windows in the top row, four in the middle row,
        three in the bottom row).
    """
    magnitude = x.abs()
    # (row_start, row_stop, col_start, col_stop) per detector, in output order.
    windows = (
        (46, 66, 46, 66),
        (46, 66, 93, 113),
        (46, 66, 140, 160),
        (85, 105, 46, 66),
        (85, 105, 78, 98),
        (85, 105, 109, 129),
        (85, 105, 140, 160),
        (125, 145, 46, 66),
        (125, 145, 93, 113),
        (125, 145, 140, 160),
    )
    readings = [
        magnitude[:, r0:r1, c0:c1].mean(dim=(1, 2))
        for r0, r1, c0, c1 in windows
    ]
    return torch.stack(readings, dim=-1)
class DiffractiveLayer(torch.nn.Module):
    """Free-space propagation of a complex field between two diffractive planes.

    Propagation uses the Fresnel transfer function
    ``h = exp(j*k*d) * exp(-j*pi*wl*d*(fx**2 + fy**2))`` applied in the Fourier
    domain. The field is zero-padded to ``size x size`` samples and multiplied by
    ``h`` after ``fft2``. After ``ifft2`` it is cropped back to its original extent.

    NOTE(review): all lengths are assumed to be in meters (the default
    wavelength 5.32e-7 then corresponds to 532 nm) -- confirm against the setup.

    Args:
        size: samples per side of the padded simulation window. The default 400
            fits a 200 x 200 layer plus 100 samples of zero padding per side.
        distance: propagation distance between two consecutive layers.
        layer_length: physical side length of the padded window; its inverse is
            the spatial-frequency sampling interval.
        wavelength: optical wavelength.
    """

    def __init__(self, size=400, distance=0.2794, layer_length=0.0144,
                 wavelength=5.32e-7):
        super().__init__()
        self.size = size              # samples per side of the padded window
        self.distance = distance      # propagation distance between two layers
        self.ll = layer_length        # physical side length of the padded window
        self.wl = wavelength          # wavelength
        self.fi = 1 / self.ll         # spatial-frequency sampling interval
        # Use np.pi (not a truncated literal) so k agrees with the pi used in h.
        self.wn = 2 * np.pi / self.wl  # wave number
        # Squared spatial frequency fx**2 + fy**2 on a grid whose zero frequency is
        # at index size // 2. Computed in float64 for accuracy; the attribute keeps
        # its historical complex64 dtype.
        freq = (np.arange(self.size) - self.size // 2) * self.fi
        phi = np.square(freq)[:, None] + np.square(freq)[None, :]
        self.phi = phi.astype(np.complex64)
        print('constructed with Fresnel!')
        h = np.exp(1.0j * self.wn * self.distance) * np.exp(
            -1.0j * self.wl * np.pi * self.distance * phi)
        # Move the zero frequency from index size // 2 to index 0, which is the
        # layout torch.fft.fft2 produces. ifftshift is also correct for odd sizes;
        # for even sizes it is identical to fftshift.
        h = np.fft.ifftshift(h).astype(np.complex64)
        # Fixed (non-trainable) transfer function of shape (size, size).
        self.h = torch.nn.Parameter(torch.from_numpy(h), requires_grad=False)

    def forward(self, waves):
        """Propagate a complex field over ``self.distance``.

        Args:
            waves: complex tensor whose last two dims form a square field,
                e.g. (batch, 200, 200). It is zero-padded symmetrically up to
                ``self.size`` samples per side before propagation.

        Returns:
            Complex tensor with the same shape as ``waves``.

        Raises:
            ValueError: if the field is not square, is larger than ``self.size``,
                or the padding cannot be split evenly between both sides.
        """
        n_rows, n_cols = waves.shape[-2:]
        margin = self.size - n_rows
        if n_rows != n_cols or margin < 0 or margin % 2:
            raise ValueError(
                f"expected a square field of even margin inside {self.size}x{self.size}, "
                f"got {n_rows}x{n_cols}")
        pad = margin // 2
        padded = torch.nn.functional.pad(waves, (pad, pad, pad, pad))
        propagated = torch.fft.ifft2(torch.fft.fft2(padded) * self.h)
        # Negative padding crops into a fresh tensor, so the padded buffer is not
        # kept alive by a view.
        return torch.nn.functional.pad(propagated, (-pad, -pad, -pad, -pad))

class Net(torch.nn.Module):
    """Stack of voltage-driven modulating layers followed by detector readout.

    Each layer propagates the field with a ``DiffractiveLayer``. It then multiplies
    the field by ``amplitude(v) * exp(j * phase(v))``, where ``v`` is a trainable
    per-pixel voltage map and ``phase``/``amplitude`` are fitted polynomials
    selected by ``bits``. After a final propagation, the field magnitude is
    averaged over the windows of ``detector_region`` and passed through a softmax.

    Args:
        num_layers: number of modulating layers.
        bits: voltage-response calibration to use; one of 8, 12 or 16.

    Raises:
        ValueError: if ``bits`` has no calibration.
    """

    # bits -> (phase coefficients of v, v**3, v**5, v**7,
    #          amplitude coefficients of v, v**2, v**3, v**4, then the constant).
    _RESPONSE_COEFFS = {
        8: ((-5.79276e-3, 2.43589e-6, -5.7496e-11, 4.0797e-16),
            (1.88716e-3, -1.1558e-4, 1.0287e-6, -2.284437e-9, 0.58789)),
        12: ((-5.56470e-3, 2.302377e-6, -5.197037e-11, 3.5165e-16),
             (2.06236e-3, -1.154027e-4, 1.002995e-6, -2.191565e-9, 0.5832887)),
        16: ((-6.64304e-3, 2.37064e-6, -5.328276e-11, 3.5963e-16),
             (2.387192e-3, -1.21138558e-4, 1.03988555e-6, -2.26933e-9, 0.57865)),
    }

    def __init__(self, num_layers=5, bits=8):
        super().__init__()
        # Fail fast: an unknown calibration previously only surfaced in forward()
        # as an UnboundLocalError.
        if bits not in self._RESPONSE_COEFFS:
            raise ValueError(
                f"unsupported bits={bits!r}; expected one of {sorted(self._RESPONSE_COEFFS)}")
        self.size = 200
        # NOTE(review): ``phase`` is never used in forward and is not registered as a
        # parameter. It is kept so that attribute access and the NumPy RNG draw order
        # stay unchanged; that order determines the voltage init under a fixed seed.
        self.phase = [
            torch.nn.Parameter(torch.from_numpy(
                2 * np.pi * np.random.random(size=(self.size, self.size)).astype('float32')))
            for _ in range(num_layers)
        ]
        # Trainable per-pixel voltage map of each layer, initialised uniformly in [0, 255).
        self.voltage = [
            torch.nn.Parameter(torch.from_numpy(
                np.random.uniform(low=0, high=255, size=(self.size, self.size)).astype('float32')))
            for _ in range(num_layers)
        ]
        for i, voltage in enumerate(self.voltage):
            # Explicit names keep the state_dict keys ``voltage_<i>`` stable.
            self.register_parameter("voltage" + "_" + str(i), voltage)
        self.diffractive_layers = torch.nn.ModuleList([DiffractiveLayer() for _ in range(num_layers)])
        self.last_diffractive_layer = DiffractiveLayer()
        self.softmax = torch.nn.Softmax(dim=-1)
        # NOTE(review): bn / bn2 are unused and unregistered. They used to be moved to
        # the GPU with ``.cuda()``, which made ``Net()`` fail on CPU-only machines.
        # They now stay on the CPU and are kept only for attribute compatibility.
        self.bn = [torch.nn.BatchNorm1d(self.size) for _ in range(num_layers)]
        self.bn2 = [torch.nn.BatchNorm1d(self.size) for _ in range(num_layers)]
        self.bits = bits

    @classmethod
    def _voltage_response(cls, voltage, bits):
        """Return ``(phase, amplitude)`` of the modulator for a voltage map.

        The terms are summed in the same order as the original hand-written
        polynomials, with the constant added last, so results match them.
        """
        (p1, p3, p5, p7), (a1, a2, a3, a4, a0) = cls._RESPONSE_COEFFS[bits]
        phase = p1 * voltage + p3 * voltage**3 + p5 * voltage**5 + p7 * voltage**7
        amplitude = a1 * voltage + a2 * voltage**2 + a3 * voltage**3 + a4 * voltage**4 + a0
        return phase, amplitude

    def forward(self, x):
        """Run the field through all layers and return detector probabilities.

        Args:
            x: complex input field; presumably (batch, 200, 200) -- TODO confirm.

        Returns:
            (batch, 10) tensor: softmax over the ten detector readings.
        """
        for layer, voltage in zip(self.diffractive_layers, self.voltage):
            propagated = layer(x)
            phase, amplitude = self._voltage_response(voltage, self.bits)
            # amplitude * exp(j * phase), built from its real and imaginary parts.
            # torch.polar is avoided because amplitude may become negative.
            modulation = torch.view_as_complex(torch.stack(
                (amplitude * torch.cos(phase), amplitude * torch.sin(phase)), dim=-1))
            x = propagated * modulation
        x = self.last_diffractive_layer(x)
        return self.softmax(detector_region(x.abs()))

if __name__ == '__main__':
    # Build a model with the default configuration and show its module tree.
    model = Net()
    print(model)


