import os
import csv
from time import time
import random
import pathlib
import argparse
import numpy as np
from tqdm import tqdm
import pandas as pd
import torch
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import pickle
import layers as layers
import utils as utils
import data as dataset

# Target torch device for this script: the first CUDA GPU (a CUDA device is assumed).
device = "cuda:0"

class DiffractiveLayer(torch.nn.Module):
    """Diffractive layer enabling device-quantization-aware training via Gumbel-Softmax.

    The layer propagates its input field over ``distance`` with the selected
    diffraction approximation. When ``phase_mod`` is True it then multiplies the
    field by a per-pixel modulation. The phase and amplitude of that modulation
    are both looked up from the device response tables using ONE hard
    Gumbel-Softmax sample of the trainable ``voltage`` logits.

    Args:
        phase_func: phase response of the hardware device (e.g., an SLM), indexed by
            the applied voltage/gray value; length ``precision``. Unused when
            ``phase_mod`` is False.
        intensity_func: amplitude response of the device, indexed the same way.
        tau: Gumbel-Softmax temperature.
        wavelength: illumination wavelength (meters).
        pixel_size: pixel pitch (meters).
        fill_factor: pixel fill factor; only used by the "Fraunhofer" kernel.
        size: number of pixels per side of the layer.
        pad: zero padding added on each side before propagation; must be > 0.
        distance: diffraction distance (meters).
        name: additional name under which ``voltage`` is registered.
        precision: number of quantization levels of phase/amplitude. default: 256 (8-bit)
        amplitude_factor: scale applied to the looked-up amplitude (training
            regularization of amplitude vs. phase in backpropagation).
        mesh_size: sub-pixel oversampling factor; only used by "Fresnel2"
            (which multiplies ``size`` by it).
        approx: "Fresnel", "Fraunhofer", "Fresnel2" or "Sommerfeld".
        phase_mod: enable modulation, or just diffraction. default: True

    Shape:
        - Input: complex tensor ``(*, size, size)`` (``size`` after any "Fresnel2" oversampling).
        - Output: same shape as input.

    Raises:
        AssertionError: if ``pad`` is not > 0.
        ValueError: if ``approx`` is not one of the supported names.
    """

    def __init__(self, phase_func, intensity_func, tau=10, wavelength=5.32e-7, pixel_size=3.6e-5, fill_factor=0.58, size=200, pad=0, distance=0.1, name="diffractive_layer",
                 precision=256, amplitude_factor=6, mesh_size=1, approx="Fresnel", phase_mod=True):
        super(DiffractiveLayer, self).__init__()
        # Padding suppresses periodic wrap-around of the FFT-based propagation, so it is mandatory.
        # Checked first so a bad value fails before the (large) kernel is built.
        assert pad > 0, "padding in forward diffraction has to be greater than 0 (it suppresses periodic wrap-around of the FFT-based propagation)"
        self.size = size                        # neurons per side of one layer
        self.distance = distance                # distance between two layers
        self.pad = pad                          # zero padding on each side before propagation
        self.ll = pixel_size * (self.size + self.pad * 2)  # padded layer length
        self.wl = wavelength                    # wavelength
        self.fi = 1 / self.ll                   # frequency interval
        self.wn = 2 * np.pi / self.wl           # wave number (full-precision pi)
        self.pixel_size = pixel_size
        self.ddi = 1 / self.pixel_size          # pixels per unit length (used by "Sommerfeld")
        self.fill_factor = fill_factor
        self.approx = approx
        self.mesh_size = mesh_size
        self.tau = tau
        self.precision = precision
        n = self.size + self.pad * 2            # padded grid size
        # Squared spatial frequency (fx^2 + fy^2) on the centred padded grid, shape (n, n).
        self.phi = np.fromfunction(
            lambda x, y: np.square((x - (n // 2)) * self.fi) + np.square((y - (n // 2)) * self.fi),
            shape=(n, n), dtype=np.complex64)

        if self.approx == "Fresnel":
            print("Network is constructed using Fresnel approximation")
            # Frequency-domain Fresnel transfer function: exp(jkz) * exp(-j*pi*lambda*z*(fx^2 + fy^2)).
            h = np.fft.fftshift(np.exp(1.0j * self.wn * self.distance) * np.exp(-1.0j * self.wl * np.pi * self.distance * self.phi))
            self.h = torch.nn.Parameter(torch.view_as_complex(torch.stack((torch.from_numpy(h.real), torch.from_numpy(h.imag)), dim=-1)), requires_grad=False)
        elif self.approx == "Fraunhofer":
            print('Network is constructed using fraunhofer.4 approximation')
            wn = self.wn
            distance = self.distance
            r = np.fromfunction(
                lambda x, y: np.square((x - n / 2) * self.pixel_size) + np.square((y - n / 2) * self.pixel_size),
                shape=(n, n), dtype=np.float64)
            # Finite-pixel aperture (sinc) envelope scaled by the fill factor.
            temp = np.fromfunction(
                lambda x, y: np.sinc(wn * (x - n / 2) * self.pixel_size / distance * self.pixel_size * np.sqrt(self.fill_factor) / 2 / np.pi)
                * np.sinc(wn * (y - n / 2) * self.pixel_size / distance * self.pixel_size * np.sqrt(self.fill_factor) / 2 / np.pi),
                shape=(n, n), dtype=np.float64)
            h = temp * np.exp(1.0j * wn * r / (2 * distance)) * np.exp(1.0j * wn * distance) / (1.0j * 2 * np.pi / wn * distance) * self.pixel_size * self.pixel_size * self.fill_factor
            h = torch.fft.fftshift(torch.from_numpy(h))
            self.h = torch.nn.Parameter(torch.fft.fft2(h.to(torch.complex64)), requires_grad=False)
        elif self.approx == "Fresnel2":
            print('Network is constructed using Fresnel.2 approximation')
            distance = self.distance
            # Grid mesh: oversample each pixel by mesh_size (still Fresnel for each point).
            self.size = size * self.mesh_size
            self.pixel_size = pixel_size / self.mesh_size
            self.ll = self.pixel_size * self.size
            self.fi = 1 / self.ll
            n = self.size + self.pad * 2        # padded grid size after oversampling
            r = np.fromfunction(
                lambda x, y: np.square((x - n // 2) * self.pixel_size) + np.square((y - n // 2) * self.pixel_size),
                shape=(n, n), dtype=np.float64)
            # Spatial-domain Fresnel impulse response, converted to a transfer function with fft2.
            h = np.exp(1.0j * self.wn * self.distance) * np.exp(1.0j * self.wn / 2 / distance * r) / (1.0j * self.wl * distance)
            h = torch.fft.ifftshift(torch.from_numpy(h))
            self.h = torch.nn.Parameter(torch.fft.fft2(h.to(torch.complex64)), requires_grad=False)
        elif self.approx == "Sommerfeld":
            print("Network is constructed using Sommerfeld approximation")
            # Work in pixel units: wave number per pixel and distance measured in pixels.
            wn = self.wn * self.pixel_size
            distance = self.distance * self.ddi
            r2 = np.fromfunction(
                lambda x, y: np.square(x - n // 2) + np.square(y - n // 2) + np.square(distance),
                shape=(n, n), dtype=np.float64)
            r2 = torch.from_numpy(r2)
            r = torch.sqrt(r2)
            # Rayleigh-Sommerfeld impulse response: z / (2*pi*r^2) * (1/r - j*k) * exp(j*k*r).
            h = 1 / (2 * np.pi) * distance / r2
            temp = wn * r
            temp = torch.view_as_complex(torch.stack((torch.cos(temp), torch.sin(temp)), dim=-1))
            h = h * (1 / r - 1.0j * wn) * temp
            h = torch.fft.fftshift(h)
            self.h = torch.nn.Parameter(torch.fft.fft2(h.to(torch.complex64)), requires_grad=False)
        else:
            raise ValueError("approximation function %s is not implemented; currently supporting Fresnel, Fresnel2, Sommerfeld, Fraunhofer" % approx)

        # Trainable per-pixel logits over the `precision` device levels, shape (size, size, precision),
        # initialised with a random hard (one-hot) Gumbel-Softmax sample.
        self.voltage = torch.nn.Parameter(torch.nn.functional.gumbel_softmax(
            torch.from_numpy(np.random.uniform(low=0, high=1, size=(self.size, self.size, self.precision)).astype('float32')),
            tau=self.tau, hard=True))
        # The same tensor is additionally registered under `name`.
        self.register_parameter(name, self.voltage)
        self.phase_func = phase_func
        self.intensity_func = intensity_func
        self.phase_model = phase_mod
        self.amplitude_factor = amplitude_factor

    def forward(self, waves):
        """Propagate ``waves`` over ``distance`` and, if enabled, apply the device modulation.

        Args:
            waves: complex field whose last two dims are ``(size, size)``.

        Returns:
            Complex tensor of the same shape as ``waves``.
        """
        padding = (self.pad, self.pad, self.pad, self.pad)
        # Zero-pad to suppress periodic (circular-convolution) effects of the FFT propagation.
        waves = torch.nn.functional.pad(waves, padding)
        temp = torch.fft.ifft2(torch.fft.fft2(waves) * self.h)  # propagation via the transfer function
        # Negative padding crops the border back off (center crop) for the next layer.
        temp = torch.nn.functional.pad(temp, tuple(-p for p in padding))
        if not self.phase_model:
            return temp
        # Draw ONE hard (one-hot) Gumbel-Softmax sample and use it for both look-ups, so the
        # phase and the amplitude of every pixel come from the same device voltage.
        voltage_one_hot = torch.nn.functional.gumbel_softmax(self.voltage, tau=self.tau, hard=True)
        phase = torch.matmul(voltage_one_hot, self.phase_func)
        # Look-up-table matching for amplitude; amplitude_factor is a training regularization term.
        amplitude = torch.matmul(voltage_one_hot, self.intensity_func) * self.amplitude_factor
        modulation = torch.view_as_complex(torch.stack((amplitude * torch.cos(phase), amplitude * torch.sin(phase)), dim=-1))
        return temp * modulation

class DiffractiveClassifier_CoDesign(torch.nn.Module):
    """Hardware co-designed diffractive classifier.

    ``num_layers`` modulating ``DiffractiveLayer`` instances (sharing the device
    phase/intensity response tables) are followed by one propagation-only
    ``DiffractiveLayer`` and a ``layers.Detector_10`` read-out.

    Args:
        phase_func, intensity_func: device response tables (tensors); moved to ``device``.
        device: device the response tables are moved to.
        tau: Gumbel-Softmax temperature of the modulating layers.
        wavelength, pixel_size, sys_size, pad, distance, approx: forwarded to every
            ``DiffractiveLayer`` (``sys_size`` becomes the layer ``size``).
        batch_norm: accepted for call-site compatibility; not used.
        num_layers: number of modulating layers.
        precision: number of device quantization levels.
        amp_factor: amplitude regularization factor of the modulating layers; also
            scales the input in ``prop_view``.
    """

    def __init__(self, phase_func, intensity_func, device, tau=10, wavelength=5.32e-7, pixel_size=0.000036, batch_norm=False, sys_size = 200, pad = 100, distance=0.1, num_layers=2, precision=256, amp_factor=6, approx="Fresnel"):
        super().__init__()
        self.amp_factor = amp_factor
        self.tau = tau
        self.size = sys_size
        self.distance = distance
        self.wavelength = wavelength
        self.pixel_size = pixel_size
        self.pad = pad
        self.approx = approx
        self.phase_func = phase_func.to(device)
        self.intensity_func = intensity_func.to(device)
        self.precision = precision
        self.diffractive_layers = torch.nn.ModuleList([
            DiffractiveLayer(self.phase_func, self.intensity_func, tau=self.tau, wavelength=self.wavelength, pixel_size=self.pixel_size,
                             size=self.size, pad=self.pad, precision=self.precision, distance=self.distance,
                             amplitude_factor=amp_factor, approx=self.approx, phase_mod=True)
            for _ in range(num_layers)])
        # Final free-space propagation onto the detector plane (no modulation).
        self.last_diffraction = DiffractiveLayer(None, None, tau=self.tau, wavelength=self.wavelength, pixel_size=self.pixel_size,
                                                 size=self.size, pad=self.pad, precision=self.precision, distance=self.distance,
                                                 approx=self.approx, phase_mod=False)

        # Detector geometry is hard-coded (presumably tuned for the default 200-pixel system; confirm for other sizes).
        self.detector = layers.Detector_10(start_x=[46, 46, 46], start_y=[46, 46, 46], det_size=20,
                                           gap_x=[19, 20], gap_y=[27, 12, 27])

    def forward(self, x):
        """Propagate ``x`` through all modulating layers, the final diffraction and the detector."""
        for layer in self.diffractive_layers:
            x = layer(x)
        x = self.last_diffraction(x)
        return self.detector(x)

    @staticmethod
    def _sample_voltage_indices(layer):
        """Return the (CPU) argmax index map of one hard Gumbel-Softmax sample of ``layer.voltage``.

        NOTE: gumbel_softmax adds Gumbel noise before the argmax, so this is a random
        sample from softmax(voltage), not its deterministic mode.
        """
        with torch.no_grad():
            sample = torch.nn.functional.gumbel_softmax(layer.voltage, tau=1, hard=True)
        return torch.argmax(sample.cpu(), dim=-1)

    def _export_voltages(self, fname):
        """Save one ``<fname><index>.npy`` file of voltage indices per modulating layer."""
        for index, layer in enumerate(self.diffractive_layers):
            w = self._sample_voltage_indices(layer)
            path = fname + str(index) + ".npy"
            with open(path, 'wb') as f:
                print("saving model %s layer(%d) into numpy array %s" % (fname, index, path), w.shape)
                np.save(f, w.numpy())

    def to_slm(self, fname):
        """Export the voltage index map of every modulating layer for SLM deployment."""
        self._export_voltages(fname)

    def save_weights_numpy(self, fname):
        """Export the voltage index map of every modulating layer as ``.npy`` files."""
        self._export_voltages(fname)

    def prop_view(self, x):
        """Visualize the field after every layer for each sample of the batch ``x``.

        Calls ``utils.forward_func_visualization`` once per sample ``i`` with
        ``fname="mnist_<i>.pdf"``. Returns None.
        """
        prop_list = [x]
        # NOTE(review): the input is scaled by amp_factor here but not in forward(); confirm intended.
        x = x * self.amp_factor
        for layer in self.diffractive_layers:
            x = layer(x)
            prop_list.append(x)
        x = self.last_diffraction(x)
        prop_list.append(x)
        for i in range(x.shape[0]):
            print(i)
            utils.forward_func_visualization(prop_list, self.size, fname="mnist_%s.pdf" % i, idx=i)
        return

    def phase_view(self, x, cmap="hsv"):
        """Plot the sampled voltage-index maps of all modulating layers to ``prop_view_reflection.pdf``.

        ``x`` is unused. NOTE(review): ``cmap`` is ignored and the figure is always drawn
        with ``"gray"``; kept so existing calls produce the same output.
        """
        phase_list = [self._sample_voltage_indices(layer) for layer in self.diffractive_layers]
        print(phase_list[0].shape)
        utils.phase_visualization(phase_list, size=self.size, cmap="gray", fname="prop_view_reflection.pdf")
        return

def prop_vis(model, val_dataloader, epoch, args):
    """Visualize light propagation through ``model`` for the first validation batch.

    Args:
        model: network providing ``prop_view``.
        val_dataloader: validation loader; only its first batch is used.
        epoch, args: unused; kept for call-site compatibility.

    Returns:
        None. Does nothing if the loader yields no batches.
    """
    with torch.no_grad():
        model.eval()
        first_batch = next(iter(val_dataloader), None)
        if first_batch is None:
            return
        val_images, _ = utils.data_to_cplex(first_batch, device=device)
        model.prop_view(val_images)

def train(model, train_dataloader, val_dataloader, lambda1, tau, epochs, lr, args):
    """Train ``model`` for ``epochs`` epochs with Adam, checkpointing and validating every epoch.

    Epochs are numbered ``args.start_epoch + 1 .. args.start_epoch + epochs``. After each
    epoch the ``state_dict`` is saved to ``args.model_save_path + str(epoch) + args.model_name``,
    the model is validated with ``eval``, and one row
    ``[train_loss, train_acc, val_loss, val_acc]`` is appended to ``args.result_record_path``.

    Args:
        model: network mapping complex images to per-class outputs.
        train_dataloader, val_dataloader: batches accepted by ``utils.data_to_cplex``.
        lambda1: weight applied to the summed MSE loss.
        tau: unused here (the temperature is fixed when the model is built); kept for compatibility.
        epochs: number of epochs to run.
        lr: Adam learning rate; halved every 20 epochs by StepLR.
        args: parsed CLI namespace.

    Returns:
        ``(train_loss, train_accuracy, val_loss, val_accuracy, log)`` of the last epoch;
        ``log`` is the last row written to the result file.

    Raises:
        ValueError: if the training dataloader yields no batches.
    """
    criterion = torch.nn.MSELoss(reduction='sum').to(device)
    print('training starts.')
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    last_epoch = args.start_epoch + epochs

    for epoch in range(args.start_epoch + 1, last_epoch + 1):
        model.train()
        train_len = 0.0
        train_running_counter = 0.0
        train_running_loss = 0.0
        tk0 = tqdm(train_dataloader, ncols=150, total=int(len(train_dataloader)))
        for train_data_batch in tk0:
            train_images, train_labels = utils.data_to_cplex(train_data_batch, device=device)
            train_outputs = model(train_images)
            train_loss_ = lambda1 * criterion(train_outputs, train_labels)
            train_counter_ = torch.eq(torch.argmax(train_labels, dim=1), torch.argmax(train_outputs, dim=1)).float().sum()

            optimizer.zero_grad()
            # Each graph is used for exactly one backward pass, so it need not be retained.
            train_loss_.backward()
            optimizer.step()
            train_len += len(train_labels)
            train_running_loss += train_loss_.item()
            train_running_counter += train_counter_

            # Running per-sample averages shown in the progress bar.
            train_loss = train_running_loss / train_len
            train_accuracy = train_running_counter / train_len

            tk0.set_description_str('Epoch {}/{} : Training'.format(epoch, last_epoch))
            tk0.set_postfix({'Train_Loss': '{:.2f}'.format(train_loss), 'Train_Accuracy': '{:.5f}'.format(train_accuracy)})
        if train_len == 0:
            raise ValueError("training dataloader yielded no batches")
        scheduler.step()
        checkpoint = args.model_save_path + str(epoch) + args.model_name
        torch.save(model.state_dict(), checkpoint)
        print('Model : "' + checkpoint + '" saved.')

        val_loss, val_accuracy = eval(model, val_dataloader, epoch, epochs, args)
        log = [train_loss, float(train_accuracy), val_loss, float(val_accuracy)]
        with open(args.result_record_path, 'ab') as f:
            np.savetxt(f, np.array(log).reshape(1, 4), fmt='%.4f')
    return train_loss, train_accuracy, val_loss, val_accuracy, log

def eval(model, val_dataloader, epoch, epochs, args):
    """Evaluate ``model`` on ``val_dataloader`` without gradients.

    Note: this module-level name shadows the builtin ``eval`` within this file.

    Args:
        model: network mapping complex images to per-class outputs.
        val_dataloader: batches accepted by ``utils.data_to_cplex``.
        epoch: epoch number shown in the progress bar.
        epochs: epochs in the current run; only used for the progress-bar total
            (``args.start_epoch + epochs``).
        args: parsed CLI namespace (reads ``start_epoch``).

    Returns:
        ``(val_loss, val_accuracy)``: summed MSE divided by the number of samples, and
        the fraction of samples whose argmax output matches the argmax label.

    Raises:
        ValueError: if the loader yields no batches.
    """
    criterion = torch.nn.MSELoss(reduction='sum').to(device)
    with torch.no_grad():
        model.eval()
        val_len = 0.0
        val_running_counter = 0.0
        val_running_loss = 0.0

        tk1 = tqdm(val_dataloader, ncols=100, total=int(len(val_dataloader)))
        for val_data_batch in tk1:
            val_images, val_labels = utils.data_to_cplex(val_data_batch, device=device)
            val_outputs = model(val_images)

            val_loss_ = criterion(val_outputs, val_labels)
            val_counter_ = torch.eq(torch.argmax(val_labels, dim=1), torch.argmax(val_outputs, dim=1)).float().sum()

            val_len += len(val_labels)
            val_running_loss += val_loss_.item()
            val_running_counter += val_counter_

            # Running per-sample averages shown in the progress bar.
            val_loss = val_running_loss / val_len
            val_accuracy = val_running_counter / val_len

            tk1.set_description_str('Epoch {}/{} : Validating'.format(epoch, args.start_epoch + epochs))
            tk1.set_postfix({'Val_Loss': '{:.5f}'.format(val_loss), 'Val_Accuarcy': '{:.5f}'.format(val_accuracy)})
    if val_len == 0:
        raise ValueError("validation dataloader yielded no batches")
    return val_loss, val_accuracy
   
def _build_model(args, phase_function, intensity_function, tau):
    """Construct the co-designed classifier from CLI ``args`` at temperature ``tau``, moved to ``device``."""
    model = DiffractiveClassifier_CoDesign(num_layers=args.depth, batch_norm=args.use_batch_norm, device=device,
                                           wavelength=args.wavelength, pixel_size=args.pixel_size, sys_size=args.sys_size,
                                           pad=args.pad, distance=args.distance, phase_func=phase_function,
                                           intensity_func=intensity_function, tau=tau, precision=args.precision,
                                           amp_factor=args.amp_factor, approx=args.approx)
    model.to(device)
    return model


def main(args):
    """Build the data loaders and model, then run the mode selected by ``args``.

    Modes, checked in order:
      1. Weight export (``--save-w``; requires ``--whether-load-model``).
      2. Propagation visualisation (``--vis``).
      3. Evaluation only (``--evaluation``).
      4. Cooling-schedule training over ``args.tau`` (``--cooling-schedule``).
      5. Otherwise, single-temperature training.

    Raises:
        ValueError: unsupported ``args.dataset``, or (cooling schedule) fewer
            ``--epochs``/``--lr`` entries than ``--tau`` entries.
    """
    # Fail fast, before loading data, if a cooling stage would lack its epochs/lr entry.
    if args.cooling_schedule and (len(args.epochs) < len(args.tau) or len(args.lr) < len(args.tau)):
        raise ValueError("--epochs and --lr need at least one entry per --tau value for the cooling schedule")
    # Debugging aid that makes autograd noticeably slower; kept enabled.
    torch.autograd.set_detect_anomaly(True)
    os.makedirs(args.model_save_path, exist_ok=True)

    if args.dataset == "mnist":
        print("training and testing on MNIST10 dataset")
        load_dataset = dataset.load_dataset(batch_size=args.batch_size, system_size=args.sys_size, datapath="./data")
        train_dataloader, val_dataloader = load_dataset.MNIST()
    elif args.dataset == "Fmnist":
        print("training and testing on FashionMNIST10 dataset")
        load_dataset = dataset.load_dataset(batch_size=args.batch_size, system_size=args.sys_size, datapath="./Fdata")
        train_dataloader, val_dataloader = load_dataset.FMNIST()
    else:
        raise ValueError("current version only supports MNIST10 and FashionMNIST10")

    # Device response tables, also dumped to .npy files for inspection.
    phase_function = utils.phase_func(args.phase_file, i_k=args.precision)
    with open('phase_file.npy', 'wb') as f_phase:
        np.save(f_phase, phase_function.cpu().numpy())
    intensity_function = utils.intensity_func(args.intensity_file, i_k=args.precision)
    with open('intensity_file.npy', 'wb') as f_amp:
        np.save(f_amp, intensity_function.cpu().numpy())

    model = _build_model(args, phase_function, intensity_function, args.tau[0])
    if args.whether_load_model:
        checkpoint = args.model_save_path + str(args.start_epoch) + args.model_name
        model.load_state_dict(torch.load(checkpoint))
        print('Model1 : "' + checkpoint + '" loaded.')
        if args.get_phase:
            utils.get_phase(model, args)
        if args.save_w:
            # eval's signature is (model, loader, epoch, epochs, args).
            eval(model, val_dataloader, 0, 0, args)
            model.save_weights_numpy(fname="SLM_MNIST")
            return
    # Snapshot after any checkpoint load; cooling stages start from these weights.
    sd = model.state_dict()

    if args.vis:
        prop_vis(model, val_dataloader, 0, args)
        return

    if args.evaluation:
        # Runs before the result-log cleanup below so an evaluation-only run keeps the training log.
        print('evaluation only!')
        eval(model, val_dataloader, 0, 0, args)
        model.phase_view(None)
        return

    # Start a fresh result log for this training run.
    # NOTE(review): the header goes to 'init.csv' (not the result file), only when no result file
    # existed, and its 6 columns differ from the 4 values written per epoch; confirm intent.
    if os.path.exists(args.result_record_path):
        os.remove(args.result_record_path)
    else:
        with open('init.csv', 'w') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(
                ['Epoch', 'Train_Loss', "Train_Acc", 'Val_Loss', "Val_Acc", "LR"])
    lambda1 = args.lambda1

    if args.cooling_schedule:
        print('cooling schedule!')
        start_time = time()
        epoch_offset = args.start_epoch
        for k, tau in enumerate(args.tau):
            # Rebuild at the new temperature and carry over the previous stage's weights.
            model = _build_model(args, phase_function, intensity_function, tau)
            model.load_state_dict(sd)
            # Continue epoch numbering so this stage's checkpoints don't overwrite the previous stage's.
            stage_args = argparse.Namespace(**vars(args))
            stage_args.start_epoch = epoch_offset
            train(model, train_dataloader, val_dataloader, lambda1, tau, args.epochs[k], args.lr[k], stage_args)
            sd = model.state_dict()
            epoch_offset += args.epochs[k]

        print('run time', time() - start_time)
    else:
        print('train with single temperature.')
        start_time = time()
        train(model, train_dataloader, val_dataloader, lambda1, args.tau[0], args.epochs[0], args.lr[0], args)
        print('run time', time() - start_time)
        return
       
def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as True. This converter fixes that while still accepting the
    ``--flag True`` spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', type=int, default=350)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset', type=str, default="mnist", help='define train/test dataset (mnist, Fmnist)')
    parser.add_argument('--depth', type=int, default=4, help='number of fourier optic transformations/num of layers')
    parser.add_argument('--whether-load-model', type=str2bool, default=False, help="load pre-train model")
    parser.add_argument('--get-phase', type=str2bool, default=False, help="load pre-train model and extract phase parameters")
    parser.add_argument('--save-w', type=str2bool, default=False, help="save voltage parameters for SLM deployment")
    parser.add_argument('--evaluation', type=str2bool, default=False, help="Evaluation only")
    parser.add_argument('--cooling-schedule', type=str2bool, default=False, help="train with temperature cooling.")
    parser.add_argument('--start-epoch', type=int, default=0, help='load pre-train model at which epoch')
    parser.add_argument('--model-name', type=str, default='_model.pth')
    parser.add_argument('--model-save-path', type=str, default="./saved_model/")
    parser.add_argument('--result-record-path', type=pathlib.Path, default="./result.csv", help="save training result.")
    parser.add_argument('--lambda1', type=float, default=1, help="loss weight for the model.")
    parser.add_argument('--phase-file', type=str, default='./device_parameters/phase.csv', help="the experimental data collected for phase function.")
    parser.add_argument('--intensity-file', type=str, default='./device_parameters/intensity.csv', help="the experimental data collected for intensity function.")
    parser.add_argument('--use-batch-norm', type=str2bool, default=False, help="use BN layer in modulation")
    parser.add_argument('--vis', type=str2bool, default=False, help="visualize propagation of the first validation batch")
    parser.add_argument('--sys-size', type=int, default=200, help='system size (dim of each diffractive layer)')
    parser.add_argument('--distance', type=float, default=0.6604, help='layer distance in meters (default=0.6604)')
    parser.add_argument('--precision', type=int, default=256, help='number of quantization levels of the phase/intensity of given HW (e.g., 256 = 2**8 levels)')
    parser.add_argument('--amp-factor', type=float, default=6, help='regularization factors to balance phase-amplitude where they share same downstream gradients')
    parser.add_argument('--pixel-size', type=float, default=0.000036, help='the size of pixel in diffractive layers')
    parser.add_argument('--pad', type=int, default=100, help='the padding size ')
    parser.add_argument('--approx', type=str, default='Sommerfeld', help="Use which Approximation: Sommerfeld, Fresnel, Fresnel2 or Fraunhofer.")
    parser.add_argument('--wavelength', type=float, default=5.32e-7, help='wavelength')
    parser.add_argument('--tau', nargs='+', type=float, default=[10, 5, 2, 25], help='temperature in gumbel softmax')
    parser.add_argument('--epochs', nargs='+', type=int, default=[10, 10, 10, 10], help='training epochs for each temperature in gumbel_softmax.')
    parser.add_argument('--lr', nargs='+', type=float, default=[0.9, 0.9, 0.9, 0.9], help='learning rate for each temperature in gumbel_softmax.')

    torch.backends.cudnn.benchmark = True
    args_ = parser.parse_args()
    random.seed(args_.seed)
    np.random.seed(args_.seed)
    torch.manual_seed(args_.seed)
    main(args_)


